diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index a01460ace..ccef8bfb1 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -83,11 +83,14 @@ FACTORY = { CONSUMER_NAME = "task_consumer_" + CONSUMER_NO PAYLOAD: Payload | None = None BOOT_AT = datetime.now().isoformat() -DONE_TASKS = 0 -FAILED_TASKS = 0 PENDING_TASKS = 0 LAG_TASKS = 0 +mt_lock = threading.Lock() +DONE_TASKS = 0 +FAILED_TASKS = 0 +CURRENT_TASK = None + def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): global PAYLOAD @@ -135,12 +138,14 @@ def collect(): return None if TaskService.do_cancel(msg["id"]): - DONE_TASKS += 1 + with mt_lock: + DONE_TASKS += 1 logging.info("Task {} has been canceled.".format(msg["id"])) return None task = TaskService.get_task(msg["id"]) if not task: - DONE_TASKS += 1 + with mt_lock: + DONE_TASKS += 1 logging.warning("{} empty task!".format(msg["id"])) return None @@ -427,16 +432,22 @@ def do_handle_task(r): def handle_task(): - global PAYLOAD, DONE_TASKS, FAILED_TASKS + global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK task = collect() if task: try: logging.info(f"handle_task begin for task {json.dumps(task)}") + with mt_lock: + CURRENT_TASK = copy.deepcopy(task) do_handle_task(task) - DONE_TASKS += 1 - logging.exception(f"handle_task done for task {json.dumps(task)}") + with mt_lock: + DONE_TASKS += 1 + CURRENT_TASK = None + logging.info(f"handle_task done for task {json.dumps(task)}") except Exception: - FAILED_TASKS += 1 + with mt_lock: + FAILED_TASKS += 1 + CURRENT_TASK = None logging.exception(f"handle_task got exception for task {json.dumps(task)}") if PAYLOAD: PAYLOAD.ack() @@ -444,7 +455,7 @@ def handle_task(): def report_status(): - global CONSUMER_NAME, BOOT_AT, DONE_TASKS, FAILED_TASKS, PENDING_TASKS, LAG_TASKS + global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME) while True: try: @@ -454,15 +465,17 @@ def report_status(): PENDING_TASKS = int(group_info["pending"]) LAG_TASKS = int(group_info["lag"]) - heartbeat = json.dumps({ - "name": CONSUMER_NAME, - "now": now.isoformat(), - "boot_at": BOOT_AT, - "done": DONE_TASKS, - "failed": FAILED_TASKS, - "pending": PENDING_TASKS, - "lag": LAG_TASKS, - }) + with mt_lock: + heartbeat = json.dumps({ + "name": CONSUMER_NAME, + "now": now.isoformat(), + "boot_at": BOOT_AT, + "pending": PENDING_TASKS, + "lag": LAG_TASKS, + "done": DONE_TASKS, + "failed": FAILED_TASKS, + "current": CURRENT_TASK, + }) REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp()) logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}") @@ -474,6 +487,7 @@ def report_status(): time.sleep(30) def main(): + settings.init_settings() background_thread = threading.Thread(target=report_status) background_thread.daemon = True background_thread.start()