From d4dbdfb61d42d7a68cd297d95757ecfd2c487289 Mon Sep 17 00:00:00 2001
From: liuzhenghua <1090179900@qq.com>
Date: Sat, 19 Apr 2025 16:18:51 +0800
Subject: [PATCH] feat: Recover pending tasks while pod restart. (#7073)

### What problem does this PR solve?

If you deploy RAGFlow on Kubernetes, the pod's hostname changes during a rolling update. The task executor derives its consumer name from the hostname, so tasks that were pending under the old consumer name can never be picked up again. To address this, I introduced a recovery task that scans these pending messages and re-publishes them, allowing the tasks to continue being processed.
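
Roughly, the recovery thread inspects the consumer group's pending entries list (PEL), keeps only entries whose owning consumer no longer appears in the `TASKEXE` heartbeat set, re-publishes the original payload, and acknowledges the stale delivery. A simplified, self-contained sketch with placeholder queue/group/consumer names (not the exact code; see the diff below):

```python
import redis

QUEUE, GROUP, MY_NAME = "svr_queue", "svr_task_broker", "task_executor_0"  # placeholders
r = redis.Redis(decode_responses=True)

def recover_pending(limit: int = 10):
    # Entries that were delivered to some consumer but never acknowledged.
    pending = r.xpending_range(QUEUE, GROUP, "-", "+", limit)
    alive = r.smembers("TASKEXE")  # heartbeat registry of live executors
    for p in pending:
        if p["consumer"] == MY_NAME or p["consumer"] in alive:
            continue  # owner is still alive; leave the entry alone
        # Re-publish the original payload so a live consumer can pick it up,
        # then ack the stale delivery so it leaves the PEL.
        body = r.xrange(QUEUE, p["message_id"], p["message_id"])
        if body:
            r.xadd(QUEUE, body[0][1])
        r.xack(QUEUE, GROUP, p["message_id"])
```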

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

---------

Co-authored-by: liuzhenghua-jk
---
 rag/svr/task_executor.py | 62 ++++++++++++++++++++++++++++++++++++++--
 rag/utils/redis_conn.py  | 32 +++++++++++++++++++++
 2 files changed, 92 insertions(+), 2 deletions(-)

diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index cd285d5fa..b1024fca6 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -18,6 +18,8 @@
 # beartype_all(conf=BeartypeConf(violation_type=UserWarning))  # <-- emit warnings from all code
 import random
 import sys
+import threading
+import time

 from api.utils.log_utils import initRootLogger, get_project_base_directory
 from graphrag.general.index import run_graphrag
@@ -58,7 +60,7 @@ from rag.nlp import search, rag_tokenizer
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
 from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
 from rag.utils import num_tokens_from_string, truncate
-from rag.utils.redis_conn import REDIS_CONN
+from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
 from rag.utils.storage_factory import STORAGE_IMPL
 from graphrag.utils import chat_limiter

@@ -99,6 +101,15 @@ MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
 MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
 task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
 chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
+WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
+stop_event = threading.Event()
+
+
+def signal_handler(sig, frame):
+    logging.info("Received interrupt signal, shutting down...")
+    stop_event.set()
+    time.sleep(1)
+    sys.exit(0)


 # SIGUSR1 handler: start tracemalloc and take snapshot
@@ -621,6 +632,7 @@ async def handle_task():
 async def report_status():
     global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
     REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
+    redis_lock = RedisDistributedLock("clean_task_executor", lock_value=CONSUMER_NAME, timeout=60)
     while True:
         try:
             now = datetime.now()
@@ -646,11 +658,52 @@ async def report_status():
             expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
             if expired > 0:
                 REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
+
+            # clean task executor
+            if redis_lock.acquire():
+                task_executors = REDIS_CONN.smembers("TASKEXE")
+                for consumer_name in task_executors:
+                    if consumer_name == CONSUMER_NAME:
+                        continue
+                    expired = REDIS_CONN.zcount(
+                        consumer_name, now.timestamp() - WORKER_HEARTBEAT_TIMEOUT, now.timestamp() + 10
+                    )
+                    if expired == 0:
+                        logging.info(f"{consumer_name} expired, removed")
+                        REDIS_CONN.srem("TASKEXE", consumer_name)
+                        REDIS_CONN.delete(consumer_name)
         except Exception:
             logging.exception("report_status got exception")
         await trio.sleep(30)


+def recover_pending_tasks():
+    redis_lock = RedisDistributedLock("recover_pending_tasks", lock_value=CONSUMER_NAME, timeout=60)
+    svr_queue_names = get_svr_queue_names()
+    while not stop_event.is_set():
+        try:
+            if redis_lock.acquire():
+                for queue_name in svr_queue_names:
+                    msgs = REDIS_CONN.get_pending_msg(queue=queue_name, group_name=SVR_CONSUMER_GROUP_NAME)
+                    msgs = [msg for msg in msgs if msg['consumer'] != CONSUMER_NAME]
+                    if len(msgs) == 0:
+                        continue
+
+                    task_executors = REDIS_CONN.smembers("TASKEXE")
+                    task_executor_set = {t for t in task_executors}
+                    msgs = [msg for msg in msgs if msg['consumer'] not in task_executor_set]
+                    for msg in msgs:
+                        logging.info(
+                            f"Recover pending task: {msg['message_id']}, consumer: {msg['consumer']}, "
+                            f"time since delivered: {msg['time_since_delivered'] / 1000} s"
+                        )
+                        REDIS_CONN.requeue_msg(queue_name, SVR_CONSUMER_GROUP_NAME, msg['message_id'])
+
+            stop_event.wait(60)
+        except Exception:
+            logging.warning("recover_pending_tasks got exception")
+
+
 async def main():
     logging.info(r"""
   ______           __      ______                     __
@@ -669,9 +722,14 @@ async def main():
     if TRACE_MALLOC_ENABLED:
         start_tracemalloc_and_snapshot(None, None)

+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    threading.Thread(name="RecoverPendingTask", target=recover_pending_tasks).start()
+
     async with trio.open_nursery() as nursery:
         nursery.start_soon(report_status)
-        while True:
+        while not stop_event.is_set():
             async with task_limiter:
                 nursery.start_soon(handle_task)
     logging.error("BUG!!! You should not reach here!!!")
diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py
index 08f889f5f..abfb26fb7 100644
--- a/rag/utils/redis_conn.py
+++ b/rag/utils/redis_conn.py
@@ -282,6 +282,28 @@ class RedisDB:
             )
             self.__open__()

+    def get_pending_msg(self, queue, group_name):
+        try:
+            messages = self.REDIS.xpending_range(queue, group_name, '-', '+', 10)
+            return messages
+        except Exception as e:
+            if 'No such key' not in (str(e) or ''):
+                logging.warning(
+                    "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
+                )
+        return []
+
+    def requeue_msg(self, queue: str, group_name: str, msg_id: str):
+        try:
+            messages = self.REDIS.xrange(queue, msg_id, msg_id)
+            if messages:
+                self.REDIS.xadd(queue, messages[0][1])
+                self.REDIS.xack(queue, group_name, msg_id)
+        except Exception as e:
+            logging.warning(
+                "RedisDB.requeue_msg " + str(queue) + " got exception: " + str(e)
+            )
+
     def queue_info(self, queue, group_name) -> dict | None:
         try:
             groups = self.REDIS.xinfo_groups(queue)
@@ -301,6 +323,16 @@
         """
         return bool(self.lua_delete_if_equal(keys=[key], args=[expected_value], client=self.REDIS))

+    def delete(self, key) -> bool:
+        try:
+            self.REDIS.delete(key)
+            return True
+        except Exception as e:
+            logging.warning("RedisDB.delete " + str(key) + " got exception: " + str(e))
+            self.__open__()
+            return False
+
+
 REDIS_CONN = RedisDB()