diff --git a/api/apps/system_app.py b/api/apps/system_app.py
index 5e67ad43e..8b23ff5fa 100644
--- a/api/apps/system_app.py
+++ b/api/apps/system_app.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
-import json
+import logging
 from datetime import datetime
 
 from flask_login import login_required, current_user
@@ -154,26 +154,16 @@ def status():
             "error": str(e),
         }
 
+    task_executor_heartbeats = {}
     try:
-        v = REDIS_CONN.get("TASKEXE")
-        if not v:
-            raise Exception("No task executor running!")
-        obj = json.loads(v)
-        color = "green"
-        for id in obj.keys():
-            arr = obj[id]
-            if len(arr) == 1:
-                obj[id] = [0]
-            else:
-                obj[id] = [arr[i + 1] - arr[i] for i in range(len(arr) - 1)]
-            elapsed = max(obj[id])
-            if elapsed > 50:
-                color = "yellow"
-            if elapsed > 120:
-                color = "red"
-        res["task_executor"] = {"status": color, "elapsed": obj}
-    except Exception as e:
-        res["task_executor"] = {"status": "red", "error": str(e)}
+        task_executors = REDIS_CONN.smembers("TASKEXE")
+        now = datetime.now().timestamp()
+        for task_executor_id in task_executors:
+            heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60*30, now)
+            task_executor_heartbeats[task_executor_id] = heartbeats
+    except Exception:
+        logging.exception("get task executor heartbeats failed!")
+    res["task_executor_heartbeats"] = task_executor_heartbeats
 
     return get_json_result(data=res)
 
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 90e4c9ed2..6af13c75c 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -36,7 +36,7 @@ class TaskService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def get_tasks(cls, task_id):
+    def get_task(cls, task_id):
         fields = [
             cls.model.id,
             cls.model.doc_id,
@@ -63,7 +63,7 @@ class TaskService(CommonService):
             .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
             .where(cls.model.id == task_id)
         docs = list(docs.dicts())
-        if not docs: return []
+        if not docs: return None
 
         msg = "\nTask has been received."
         prog = random.random() / 10.
@@ -77,9 +77,9 @@ class TaskService(CommonService):
         ).where(
             cls.model.id == docs[0]["id"]).execute()
 
-        if docs[0]["retry_count"] >= 3: return []
+        if docs[0]["retry_count"] >= 3: return None
 
-        return docs
+        return docs[0]
 
     @classmethod
     @DB.connection_context()
@@ -108,7 +108,7 @@ class TaskService(CommonService):
             task = cls.model.get_by_id(id)
             _, doc = DocumentService.get_by_id(task.doc_id)
             return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
-        except Exception as e:
+        except Exception:
             pass
         return False
 
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 76568d927..a01460ace 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -42,7 +42,6 @@ from multiprocessing.context import TimeoutError
 from timeit import default_timer as timer
 
 import numpy as np
-import pandas as pd
 
 from api.db import LLMType, ParserType
 from api.db.services.dialog_service import keyword_extraction, question_proposal
@@ -85,10 +84,9 @@ CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
 PAYLOAD: Payload | None = None
 BOOT_AT = datetime.now().isoformat()
 DONE_TASKS = 0
-RETRY_TASKS = 0
+FAILED_TASKS = 0
 PENDING_TASKS = 0
-HEAD_CREATED_AT = ""
-HEAD_DETAIL = ""
+LAG_TASKS = 0
 
 
 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
@@ -120,34 +118,35 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
 def collect():
-    global CONSUMER_NAME, PAYLOAD
+    global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
     try:
         PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
         if not PAYLOAD:
             PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
         if not PAYLOAD:
             time.sleep(1)
-            return pd.DataFrame()
+            return None
     except Exception:
         logging.exception("Get task event from queue exception")
-        return pd.DataFrame()
+        return None
 
     msg = PAYLOAD.get_message()
     if not msg:
-        return pd.DataFrame()
+        return None
 
     if TaskService.do_cancel(msg["id"]):
+        DONE_TASKS += 1
         logging.info("Task {} has been canceled.".format(msg["id"]))
-        return pd.DataFrame()
-    tasks = TaskService.get_tasks(msg["id"])
-    if not tasks:
+        return None
+    task = TaskService.get_task(msg["id"])
+    if not task:
+        DONE_TASKS += 1
         logging.warning("{} empty task!".format(msg["id"]))
-        return []
+        return None
 
-    tasks = pd.DataFrame(tasks)
     if msg.get("type", "") == "raptor":
-        tasks["task_type"] = "raptor"
-    return tasks
+        task["task_type"] = "raptor"
+    return task
 
 
 def get_storage_binary(bucket, name):
@@ -176,14 +175,14 @@ def build(row):
         callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
         logging.exception(
             "Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
-        return
+        raise
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
            callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
        else:
            callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
        logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
-        return
+        raise
 
    try:
        cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
@@ -194,7 +193,7 @@
         callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", ""))
         logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
-        return
+        raise
 
     docs = []
     doc = {
@@ -212,6 +211,7 @@
         d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
         d["create_timestamp_flt"] = datetime.now().timestamp()
         if not d.get("image"):
+            _ = d.pop("image", None)
             d["img_id"] = ""
             d["page_num_list"] = json.dumps([])
             d["position_list"] = json.dumps([])
@@ -232,6 +232,7 @@
         except Exception:
             logging.exception(
                 "Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
+            raise
 
         d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
         del d["image"]
@@ -356,105 +357,111 @@
     return res, tk_count, vector_size
 
 
-def main():
-    rows = collect()
-    if len(rows) == 0:
-        return
-
-    for _, r in rows.iterrows():
-        callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
+def do_handle_task(r):
+    callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
+    try:
+        embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
+    except Exception as e:
+        callback(-1, msg=str(e))
+        raise
+    if r.get("task_type", "") == "raptor":
         try:
-            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
+            chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
+            cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
         except Exception as e:
             callback(-1, msg=str(e))
-            logging.exception("LLMBundle got exception")
-            continue
-
-        if r.get("task_type", "") == "raptor":
-            try:
-                chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
-                cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
-            except Exception as e:
-                callback(-1, msg=str(e))
-                logging.exception("run_raptor got exception")
-                continue
-        else:
-            st = timer()
-            cks = build(r)
-            logging.info("Build chunks({}): {}".format(r["name"], timer() - st))
-            if cks is None:
-                continue
-            if not cks:
-                callback(1., "No chunk! Done!")
-                continue
-            # TODO: exception handler
-            ## set_progress(r["did"], -1, "ERROR: ")
-            callback(
+            raise
+    else:
+        st = timer()
+        cks = build(r)
+        logging.info("Build chunks({}): {}".format(r["name"], timer() - st))
+        if cks is None:
+            return
+        if not cks:
+            callback(1., "No chunk! Done!")
+            return
+        # TODO: exception handler
+        ## set_progress(r["did"], -1, "ERROR: ")
+        callback(
            msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
-            )
-            st = timer()
-            try:
-                tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
-            except Exception as e:
-                callback(-1, "Embedding error:{}".format(str(e)))
-                logging.exception("run_rembedding got exception")
-                tk_count = 0
-            logging.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
-            callback(msg="Finished embedding (in {:.2f}s)! Start to build index!".format(timer() - st))
-
-        # logging.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
-        init_kb(r, vector_size)
-        chunk_count = len(set([c["id"] for c in cks]))
+        )
         st = timer()
-        es_r = ""
-        es_bulk_size = 4
-        for b in range(0, len(cks), es_bulk_size):
-            es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
-            if b % 128 == 0:
-                callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
+        try:
+            tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
+        except Exception as e:
+            callback(-1, "Embedding error:{}".format(str(e)))
+            logging.exception("run_rembedding got exception")
+            tk_count = 0
+            raise
+        logging.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
+        callback(msg="Finished embedding (in {:.2f}s)! Start to build index!".format(timer() - st))
+    # logging.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
+    init_kb(r, vector_size)
+    chunk_count = len(set([c["id"] for c in cks]))
+    st = timer()
+    es_r = ""
+    es_bulk_size = 4
+    for b in range(0, len(cks), es_bulk_size):
+        es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
+        if b % 128 == 0:
+            callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
+    logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
+    if es_r:
+        callback(-1, "Insert chunk error, detail info please check log file. Please also check Elasticsearch/Infinity status!")
+        settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
+        logging.error('Insert chunk error: ' + str(es_r))
+        raise Exception('Insert chunk error: ' + str(es_r))
-        logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
-        if es_r:
-            callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
-            settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
-            logging.error('Insert chunk error: ' + str(es_r))
-        else:
-            if TaskService.do_cancel(r["id"]):
-                settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
-                continue
-            callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
-            callback(1., "Done!")
-            DocumentService.increment_chunk_num(
-                r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
-            logging.info(
-                "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
-                    r["id"], tk_count, len(cks), timer() - st))
+    if TaskService.do_cancel(r["id"]):
+        settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
+        return
+
+    callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
+    callback(1., "Done!")
+    DocumentService.increment_chunk_num(
+        r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
+    logging.info(
+        "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
+            r["id"], tk_count, len(cks), timer() - st))
+
+
+def handle_task():
+    global PAYLOAD, DONE_TASKS, FAILED_TASKS
+    task = collect()
+    if task:
+        try:
+            logging.info(f"handle_task begin for task {json.dumps(task)}")
+            do_handle_task(task)
+            DONE_TASKS += 1
+            logging.exception(f"handle_task done for task {json.dumps(task)}")
+        except Exception:
+            FAILED_TASKS += 1
+            logging.exception(f"handle_task got exception for task {json.dumps(task)}")
+    if PAYLOAD:
+        PAYLOAD.ack()
+        PAYLOAD = None
 
 
 def report_status():
-    global CONSUMER_NAME, BOOT_AT, DONE_TASKS, RETRY_TASKS, PENDING_TASKS, HEAD_CREATED_AT, HEAD_DETAIL
+    global CONSUMER_NAME, BOOT_AT, DONE_TASKS, FAILED_TASKS, PENDING_TASKS, LAG_TASKS
     REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
     while True:
         try:
             now = datetime.now()
-            PENDING_TASKS = REDIS_CONN.queue_length(SVR_QUEUE_NAME)
-            if PENDING_TASKS > 0:
-                head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
-                if head_info is not None:
-                    seconds = int(head_info[0].split("-")[0]) / 1000
-                    HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
-                    HEAD_DETAIL = head_info[1]
+            group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+            if group_info is not None:
+                PENDING_TASKS = int(group_info["pending"])
+                LAG_TASKS = int(group_info["lag"])
 
             heartbeat = json.dumps({
                 "name": CONSUMER_NAME,
                 "now": now.isoformat(),
                 "boot_at": BOOT_AT,
                 "done": DONE_TASKS,
-                "retry": RETRY_TASKS,
+                "failed": FAILED_TASKS,
                 "pending": PENDING_TASKS,
-                "head_created_at": HEAD_CREATED_AT,
-                "head_detail": HEAD_DETAIL,
+                "lag": LAG_TASKS,
             })
             REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
             logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
@@ -466,14 +473,13 @@ def report_status():
             logging.exception("report_status got exception")
         time.sleep(30)
 
-
-if __name__ == "__main__":
+def main():
     background_thread = threading.Thread(target=report_status)
     background_thread.daemon = True
     background_thread.start()
 
     while True:
-        main()
-        if PAYLOAD:
-            PAYLOAD.ack()
-            PAYLOAD = None
+        handle_task()
+
+if __name__ == "__main__":
+    main()
 
diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py
index b77d4eba2..7013a227d 100644
--- a/rag/utils/redis_conn.py
+++ b/rag/utils/redis_conn.py
@@ -225,14 +225,16 @@ class RedisDB:
                 logging.exception("xpending_range: " + consumer_name + " got exception")
                 self.__open__()
 
-    def queue_length(self, queue) -> int:
+    def queue_info(self, queue, group_name) -> dict:
         for _ in range(3):
             try:
-                num = self.REDIS.xlen(queue)
-                return num
+                groups = self.REDIS.xinfo_groups(queue)
+                for group in groups:
+                    if group["name"] == group_name:
+                        return group
             except Exception:
                 logging.exception("queue_length" + str(queue) + " got exception")
-        return 0
+        return None
 
     def queue_head(self, queue) -> int:
         for _ in range(3):
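
A rough sketch (not part of the patch) of the heartbeat protocol this diff sets up, written directly against redis-py for illustration. RAGFlow actually goes through its own REDIS_CONN wrapper in rag/utils/redis_conn.py, whose method signatures differ slightly (for example zadd(key, member, score)), so treat the client calls below as assumptions about the underlying Redis commands rather than RAGFlow's API.

import json
import time
from datetime import datetime

import redis  # assumption: redis-py as the underlying client

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

CONSUMER_NAME = "task_consumer_0"      # each executor registers itself in the TASKEXE set
BOOT_AT = datetime.now().isoformat()
HEARTBEAT_WINDOW = 60 * 30             # /v1/system/status reads the last 30 minutes of beats


def report_heartbeat(done: int, failed: int, pending: int, lag: int) -> None:
    # What report_status() does once per loop: register the consumer, then append
    # one JSON heartbeat to a per-consumer sorted set, scored by the current timestamp.
    now = datetime.now()
    r.sadd("TASKEXE", CONSUMER_NAME)
    beat = json.dumps({
        "name": CONSUMER_NAME,
        "now": now.isoformat(),
        "boot_at": BOOT_AT,
        "done": done,        # tasks handled, including cancelled/empty ones that were skipped
        "failed": failed,    # tasks whose handler raised
        "pending": pending,  # from XINFO GROUPS: delivered but not yet acknowledged
        "lag": lag,          # from XINFO GROUPS: entries not yet delivered to the group
    })
    r.zadd(CONSUMER_NAME, {beat: now.timestamp()})


def read_heartbeats() -> dict:
    # What the reworked /v1/system/status handler does: collect the recent beats
    # of every registered executor and return them raw to the caller.
    now = time.time()
    return {
        executor: r.zrangebyscore(executor, now - HEARTBEAT_WINDOW, now)
        for executor in r.smembers("TASKEXE")
    }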
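
Also for illustration only: one way a monitoring client could interpret the raw task_executor_heartbeats now returned by /v1/system/status. The endpoint no longer precomputes a green/yellow/red status, so the staleness threshold below is an assumption of this sketch (loosely echoing the 120-second "red" threshold the old handler used), not part of the API.

import json
from datetime import datetime


def summarize_executors(task_executor_heartbeats: dict, stale_after: float = 120.0) -> dict:
    # task_executor_heartbeats maps executor id -> list of JSON heartbeat strings,
    # ordered by their sorted-set score (oldest first, newest last).
    summary = {}
    now = datetime.now().timestamp()
    for executor_id, beats in task_executor_heartbeats.items():
        if not beats:
            summary[executor_id] = {"alive": False, "reason": "no heartbeat in the last 30 minutes"}
            continue
        latest = json.loads(beats[-1])
        age = now - datetime.fromisoformat(latest["now"]).timestamp()
        summary[executor_id] = {
            "alive": age <= stale_after,            # hypothetical threshold, not defined by RAGFlow
            "heartbeat_age_seconds": round(age, 1),
            "done": latest.get("done", 0),
            "failed": latest.get("failed", 0),
            "pending": latest.get("pending", 0),
            "lag": latest.get("lag", 0),
        }
    return summary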