From a44ed9626a3c8f5b3c7b041b814a14d3f22ec39e Mon Sep 17 00:00:00 2001
From: yqkcn <410728991@qq.com>
Date: Sun, 29 Sep 2024 09:49:45 +0800
Subject: [PATCH] handle nits in task_executor (#2637)

### What problem does this PR solve?

- fix typos
- fix string formatting
- format imports

### Type of change

- [x] Refactoring
---
 rag/svr/task_executor.py | 78 ++++++++++++++++++++--------------------
 1 file changed, 38 insertions(+), 40 deletions(-)

diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index b25bbbea2..48d00b13e 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -25,34 +25,31 @@ import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-
-from api.db.services.file2document_service import File2DocumentService
-from api.settings import retrievaler
-from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.utils.storage_factory import STORAGE_IMPL
-from api.db.db_models import close_connection
-from rag.settings import database_logger, SVR_QUEUE_NAME
-from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
-from multiprocessing import Pool
-import numpy as np
-from elasticsearch_dsl import Q, Search
-from multiprocessing.context import TimeoutError
-from api.db.services.task_service import TaskService
-from rag.utils.es_conn import ELASTICSEARCH
-from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
-
-from rag.nlp import search, rag_tokenizer
 from io import BytesIO
-import pandas as pd
+from multiprocessing.context import TimeoutError
+from timeit import default_timer as timer
 
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+import numpy as np
+import pandas as pd
+from elasticsearch_dsl import Q
 
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.task_service import TaskService
+from api.db.services.file2document_service import File2DocumentService
+from api.settings import retrievaler
 from api.utils.file_utils import get_project_base_directory
-from rag.utils.redis_conn import REDIS_CONN
+from api.db.db_models import close_connection
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from rag.nlp import search, rag_tokenizer
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
+from rag.settings import database_logger, SVR_QUEUE_NAME
+from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
+from rag.utils import rmSpace, num_tokens_from_string
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.redis_conn import REDIS_CONN, Payload
+from rag.utils.storage_factory import STORAGE_IMPL
 
 BATCH_SIZE = 64
 
@@ -74,11 +71,11 @@ FACTORY = {
     ParserType.KG.value: knowledge_graph
 }
 
-CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
-PAYLOAD = None
+CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+PAYLOAD: Payload | None = None
 
-def set_progress(task_id, from_page=0, to_page=-1,
-                 prog=None, msg="Processing..."):
+
+def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     global PAYLOAD
     if prog is not None and prog < 0:
         msg = "[ERROR]" + msg
@@ -107,11 +104,11 @@ def collect():
-    global CONSUMEER_NAME, PAYLOAD
+    global CONSUMER_NAME, PAYLOAD
     try:
-        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
         if not PAYLOAD:
-            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
+            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
         if not PAYLOAD:
             time.sleep(1)
             return pd.DataFrame()
@@ -159,8 +156,8 @@ def build(row):
         binary = get_storage_binary(bucket, name)
         cron_logger.info(
             "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
-    except TimeoutError as e:
-        callback(-1, f"Internal server error: Fetch file from minio timeout. Could you try it again.")
+    except TimeoutError:
+        callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
         cron_logger.error(
             "Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
         return
@@ -168,8 +165,7 @@ def build(row):
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
         else:
-            callback(-1, f"Get file from minio: %s" %
-                     str(e).replace("'", ""))
+            callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
         traceback.print_exc()
         return
@@ -180,7 +176,7 @@ def build(row):
         cron_logger.info(
             "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except Exception as e:
-        callback(-1, f"Internal server error while chunking: %s" %
+        callback(-1, "Internal server error while chunking: %s" %
                  str(e).replace("'", ""))
         cron_logger.error(
             "Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
@@ -236,7 +232,9 @@ def init_kb(row):
             open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
 
 
-def embedding(docs, mdl, parser_config={}, callback=None):
+def embedding(docs, mdl, parser_config=None, callback=None):
+    if parser_config is None:
+        parser_config = {}
     batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
@@ -277,7 +275,7 @@
 
 def run_raptor(row, chat_mdl, embd_mdl, callback=None):
     vts, _ = embd_mdl.encode(["ok"])
-    vctr_nm = "q_%d_vec"%len(vts[0])
+    vctr_nm = "q_%d_vec" % len(vts[0])
     chunks = []
     for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
         chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -374,7 +372,7 @@ def main():
         cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
-            callback(-1, f"Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
+            callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
            ELASTICSEARCH.deleteByQuery(
                Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
            cron_logger.error(str(es_r))
@@ -392,15 +390,15 @@
 
 
 def report_status():
-    global CONSUMEER_NAME
+    global CONSUMER_NAME
     while True:
         try:
             obj = REDIS_CONN.get("TASKEXE")
             if not obj: obj = {}
             else: obj = json.loads(obj)
-            if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
-            obj[CONSUMEER_NAME].append(timer())
-            obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
+            if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
+            obj[CONSUMER_NAME].append(timer())
+            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
             REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
         except Exception as e:
            print("[Exception]:", str(e))
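
Reviewer note on the `embedding()` signature change: `parser_config={}` is a mutable default argument. Python evaluates default values once, at `def` time, so every call that omits the argument shares a single dict, and any mutation leaks into later calls. The patch switches to the standard `None`-sentinel idiom. A minimal standalone sketch of the pitfall (toy functions for illustration, not RAGFlow code):

```python
def broken(item, bucket=[]):
    # The default list is created once, when the function is defined,
    # so every call that omits `bucket` mutates the same object.
    bucket.append(item)
    return bucket


def fixed(item, bucket=None):
    # None sentinel: a fresh list is created on each call.
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket


print(broken(1), broken(2))  # [1, 2] [1, 2] -- state leaked across calls
print(fixed(1), fixed(2))    # [1] [2]
```

Nothing in the diff suggests `embedding()` actually mutates `parser_config`, so the old default was likely a latent hazard rather than an active bug; the sentinel form removes the risk either way.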
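On the string-formatting fixes: several messages were written as `f"... %s" % value`. An f-string with no `{}` placeholders is just an ordinary string, so the `%` operator still did the real formatting and the output was unchanged; the stray `f` prefix only misleads readers, and pyflakes/ruff flag it as F541 (f-string without placeholders). A quick illustration:

```python
err = "connection reset"

# Before: the f-prefix suggests interpolation that never happens;
# the % operator does all the work.
before = f"Get file from minio: %s" % err

# After (as in the patch): plain %-formatting, no misleading prefix.
after = "Get file from minio: %s" % err

assert before == after  # identical result; only readability changes
```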
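One compatibility note: the new annotation `PAYLOAD: Payload | None = None` uses PEP 604 union syntax. Module-level variable annotations are evaluated at runtime, so the `|` form requires Python 3.10+ unless the module opts into lazy annotations. A sketch of the equivalent spellings (the `Payload` class below is a stand-in for `rag.utils.redis_conn.Payload`, not its real definition):

```python
from __future__ import annotations  # makes "Payload | None" lazy on Python 3.7-3.9
from typing import Optional


class Payload:  # hypothetical stand-in for rag.utils.redis_conn.Payload
    pass


PAYLOAD: Payload | None = None    # PEP 604; safe here thanks to the future import
LEGACY: Optional[Payload] = None  # pre-3.10 spelling, no future import needed
```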