fix task canceling bug (#98)

KevinHuSh authored 2024-03-05 16:33:47 +08:00, committed by GitHub
commit 602038ac49 (parent 07d76ea18d)
11 changed files with 24 additions and 15 deletions


@@ -316,8 +316,7 @@ def change_parser():
             return get_data_error_result(retmsg="Not supported yet!")
         e = DocumentService.update_by_id(doc.id,
-                                         {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": "0",
-                                          "token_num": 0, "chunk_num": 0, "process_duation": 0})
+                                         {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": "0"})
         if not e:
             return get_data_error_result(retmsg="Document not found!")
         if doc.token_num > 0:


@@ -73,8 +73,9 @@ class TaskService(CommonService):
     @classmethod
     @DB.connection_context()
     def update_progress(cls, id, info):
-        cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
-            cls.model.id == id).execute()
+        if info["progress_msg"]:
+            cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
+                cls.model.id == id).execute()
         if "progress" in info:
             cls.model.update(progress=info["progress"]).where(
                 cls.model.id == id).execute()
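
The new guard matters because callers may now send an empty message as a pure heartbeat (see callback(msg="") in the task_executor hunk below). A minimal sketch of how such a callback could feed this method; the set_progress name and the negative-progress "cancelled" convention are illustrative assumptions, only the TaskService import comes from this commit:

    from api.db.services.task_service import TaskService

    def set_progress(task_id, prog=None, msg="Processing..."):
        # Illustrative helper, not the repository's actual implementation.
        info = {"progress_msg": msg}      # msg may be "" for a heartbeat-only call
        if prog is not None:
            info["progress"] = prog       # e.g. a negative value to flag a cancelled task (assumption)
        # With the guard above, an empty msg no longer appends blank lines to the
        # stored progress_msg; only the numeric progress is written.
        TaskService.update_progress(task_id, info)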


@@ -725,7 +725,7 @@ class HuParser:
                     (cropout(
                         bxs,
                         "figure", poss),
-                     [txt] if not return_html else [f"<p>{txt}</p>"]))
+                     [txt]))
                 positions.append(poss)
         for k, bxs in tables.items():


@@ -16,7 +16,7 @@ MEM_LIMIT=4073741824
 MYSQL_PASSWORD=infini_rag_flow
 MYSQL_PORT=5455
-MINIO_USER=infiniflow
+MINIO_USER=rag_flow
 MINIO_PASSWORD=infini_rag_flow
 SVR_HTTP_PORT=9380


@@ -28,7 +28,7 @@ class Pdf(PdfParser):
             from_page,
             to_page,
             callback)
-        callback("OCR finished")
+        callback(msg="OCR finished")
         from timeit import default_timer as timer
         start = timer()


@@ -57,7 +57,7 @@ class Pdf(PdfParser):
             to_page,
             callback
         )
-        callback("OCR finished")
+        callback(msg="OCR finished")
         from timeit import default_timer as timer
         start = timer()
@@ -135,6 +135,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
 if __name__ == "__main__":
     import sys
-    def dummy(a, b):
+    def dummy(prog=None, msg=""):
         pass
     chunk(sys.argv[1], callback=dummy)
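
These signature changes all serve the same fix: the parsers now pass the message as a keyword, callback(msg="OCR finished"), so the old positional-only dummy(a, b) raised a TypeError. Any callback just needs to accept prog and msg keywords; a minimal sketch (the printing and the second call are illustrative):

    def progress_callback(prog=None, msg=""):
        # Accepts both call styles: a keyword message alone, or a
        # positional progress value plus message.
        if prog is not None:
            print(f"progress: {prog:.0%}")
        if msg:
            print(f"message: {msg}")

    progress_callback(msg="OCR finished")
    progress_callback(0.5, "layout analysis done")   # hypothetical progress update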


@@ -22,7 +22,7 @@ class Pdf(PdfParser):
             to_page,
             callback
         )
-        callback("OCR finished.")
+        callback(msg="OCR finished.")
         from timeit import default_timer as timer
         start = timer()


@@ -29,7 +29,7 @@ class Pdf(PdfParser):
             to_page,
             callback
         )
-        callback("OCR finished")
+        callback(msg="OCR finished")
         from timeit import default_timer as timer
         start = timer()


@@ -36,7 +36,7 @@ class Pdf(PdfParser):
             to_page,
             callback
         )
-        callback("OCR finished.")
+        callback(msg="OCR finished.")
         from timeit import default_timer as timer
         start = timer()


@@ -305,8 +305,15 @@ class Dealer:
                 "similarity": sim[i],
                 "vector_similarity": vsim[i],
                 "term_similarity": tsim[i],
-                "vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim)))
+                "vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
+                "positions": sres.field[id].get("position_int", "").split("\t")
             }
+            if len(d["positions"]) % 5 == 0:
+                poss = []
+                for i in range(0, len(d["positions"]), 5):
+                    poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
+                                 float(d["positions"][i + 3]), float(d["positions"][i + 4])])
+                d["positions"] = poss
             ranks["chunks"].append(d)
             if dnm not in ranks["doc_aggs"]:
                 ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}


@@ -25,6 +25,7 @@ import traceback
 from functools import partial
 from timeit import default_timer as timer

+import numpy as np
 from elasticsearch_dsl import Q

 from api.db.services.task_service import TaskService
@@ -177,10 +178,11 @@ def embedding(docs, mdl, parser_config={}, callback=None):
         tts, c = mdl.encode(tts)
         tk_count += c

-    cnts_ = []
+    cnts_ = np.array([])
     for i in range(0, len(cnts), 32):
         vts, c = mdl.encode(cnts[i: i+32])
-        cnts_.extend(vts)
+        if len(cnts_) == 0: cnts_ = vts
+        else: cnts_ = np.concatenate((cnts_, vts), axis=0)
         tk_count += c
         callback(msg="")
     cnts = cnts_
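
The numpy change keeps the batched chunk vectors as a single 2-D array instead of a Python list of row vectors, which is the shape that elementwise arithmetic on the embeddings expects downstream. A runnable sketch of the same batching pattern with a stub encoder; the stub and the 4-dimensional vectors are assumptions, while mdl.encode returning (vectors, token_count) follows the usage above:

    import numpy as np

    def encode_stub(texts):                      # stands in for mdl.encode
        return np.random.rand(len(texts), 4), len(texts)

    cnts = [f"chunk {i}" for i in range(70)]     # 70 chunks -> batches of 32, 32, 6
    cnts_, tk_count = np.array([]), 0
    for i in range(0, len(cnts), 32):
        vts, c = encode_stub(cnts[i: i + 32])
        # np.concatenate keeps one (N, dim) array; list.extend would instead
        # build a Python list of separate row arrays.
        cnts_ = vts if len(cnts_) == 0 else np.concatenate((cnts_, vts), axis=0)
        tk_count += c

    print(cnts_.shape, tk_count)                 # (70, 4) 70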