feat: Recover pending tasks when a pod restarts. (#7073)

### What problem does this PR solve?

If you deploy Ragflow using Kubernetes, the hostname will change during
a rolling update. This causes the consumer name of the task executor to
change, making it impossible to schedule tasks that were previously in a
pending state.
To address this, I introduced a recovery task that scans these pending
messages and re-publishes them, allowing the tasks to continue being
processed.

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

---------

Co-authored-by: liuzhenghua-jk <liuzhenghua-jk@360shuke.com>
This commit is contained in:
liuzhenghua 2025-04-19 16:18:51 +08:00 committed by GitHub
parent 487aed419e
commit d4dbdfb61d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 92 additions and 2 deletions

View File

@ -18,6 +18,8 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import random
import sys
import threading
import time
from api.utils.log_utils import initRootLogger, get_project_base_directory
from graphrag.general.index import run_graphrag
@ -58,7 +60,7 @@ from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter
@ -99,6 +101,15 @@ MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
# Max concurrent chunk-builder coroutines per executor (env: MAX_CONCURRENT_CHUNK_BUILDERS).
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
# Trio capacity limiters gating concurrent task handling and chunk building.
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
# Seconds without a heartbeat before a task executor is considered dead (env-overridable).
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
# Set when a shutdown signal arrives; background loops poll/wait on this flag.
stop_event = threading.Event()


def signal_handler(signum, frame):
    """Shutdown handler for SIGINT/SIGTERM.

    Flags `stop_event` so background loops can wind down, pauses one second
    to give them a chance to observe the flag, then exits the process.
    """
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
    # Short grace period before exiting so worker loops can notice the event.
    time.sleep(1)
    sys.exit(0)
# SIGUSR1 handler: start tracemalloc and take snapshot
@ -621,6 +632,7 @@ async def handle_task():
async def report_status():
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
redis_lock = RedisDistributedLock("clean_task_executor", lock_value=CONSUMER_NAME, timeout=60)
while True:
try:
now = datetime.now()
@ -646,11 +658,52 @@ async def report_status():
expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
if expired > 0:
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
# clean task executor
if redis_lock.acquire():
task_executors = REDIS_CONN.smembers("TASKEXE")
for consumer_name in task_executors:
if consumer_name == CONSUMER_NAME:
continue
expired = REDIS_CONN.zcount(
consumer_name, now.timestamp() - WORKER_HEARTBEAT_TIMEOUT, now.timestamp() + 10
)
if expired == 0:
logging.info(f"{consumer_name} expired, removed")
REDIS_CONN.srem("TASKEXE", consumer_name)
REDIS_CONN.delete(consumer_name)
except Exception:
logging.exception("report_status got exception")
await trio.sleep(30)
def recover_pending_tasks():
    """Periodically re-publish tasks stuck in the Redis streams' pending state.

    When a pod restarts (e.g. during a Kubernetes rolling update) the
    executor's consumer name changes, so messages delivered to the old
    consumer stay pending forever. This loop scans the pending entries of
    every task queue and requeues those whose consumer is no longer a
    registered task executor.

    Runs in a daemon-style thread until `stop_event` is set. A distributed
    lock ("recover_pending_tasks") ensures only one executor performs the
    scan at a time.
    """
    redis_lock = RedisDistributedLock("recover_pending_tasks", lock_value=CONSUMER_NAME, timeout=60)
    svr_queue_names = get_svr_queue_names()
    while not stop_event.is_set():
        try:
            if redis_lock.acquire():
                for queue_name in svr_queue_names:
                    msgs = REDIS_CONN.get_pending_msg(queue=queue_name, group_name=SVR_CONSUMER_GROUP_NAME)
                    # Skip messages still owned by this (alive) consumer.
                    msgs = [msg for msg in msgs if msg['consumer'] != CONSUMER_NAME]
                    if not msgs:
                        continue
                    # Only recover messages whose consumer is no longer registered
                    # in the TASKEXE set (i.e. the owning executor is gone).
                    live_executors = set(REDIS_CONN.smembers("TASKEXE"))
                    msgs = [msg for msg in msgs if msg['consumer'] not in live_executors]
                    for msg in msgs:
                        logging.info(
                            f"Recover pending task: {msg['message_id']}, consumer: {msg['consumer']}, "
                            f"time since delivered: {msg['time_since_delivered'] / 1000} s"
                        )
                        REDIS_CONN.requeue_msg(queue_name, SVR_CONSUMER_GROUP_NAME, msg['message_id'])
        except Exception:
            # Keep the traceback; this thread must never die on a transient error.
            logging.exception("recover_pending_tasks got exception")
        finally:
            # Wait on every iteration — including the error path — so a
            # persistent failure cannot turn this loop into a busy spin.
            stop_event.wait(60)
async def main():
logging.info(r"""
______ __ ______ __
@ -669,9 +722,14 @@ async def main():
if TRACE_MALLOC_ENABLED:
start_tracemalloc_and_snapshot(None, None)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
threading.Thread(name="RecoverPendingTask", target=recover_pending_tasks).start()
async with trio.open_nursery() as nursery:
nursery.start_soon(report_status)
while True:
while not stop_event.is_set():
async with task_limiter:
nursery.start_soon(handle_task)
logging.error("BUG!!! You should not reach here!!!")

View File

@ -282,6 +282,28 @@ class RedisDB:
)
self.__open__()
def get_pending_msg(self, queue, group_name, count=10):
    """Return up to `count` pending entries of `group_name` on stream `queue`.

    Each entry is a dict as returned by redis-py's ``xpending_range``
    (keys include ``message_id``, ``consumer``, ``time_since_delivered``).
    Returns an empty list on any error; a missing stream key ("No such key")
    is expected after cleanup and is not logged.

    :param queue: Redis stream name.
    :param group_name: consumer-group name.
    :param count: maximum number of pending entries to fetch (default 10,
        matching the original hard-coded scan size).
    """
    try:
        return self.REDIS.xpending_range(queue, group_name, '-', '+', count)
    except Exception as e:
        # str(e) is always a string, so no "or ''" guard is needed.
        if 'No such key' not in str(e):
            logging.warning(
                "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
            )
        return []
def requeue_msg(self, queue: str, group_name: str, msg_id: str):
    """Re-publish pending message `msg_id` to the tail of `queue` and ack the original.

    Reads the original entry's field dict via XRANGE and, if it still exists,
    XADDs it as a brand-new message so any live consumer can claim it. The
    original id is XACKed unconditionally: even when the entry has vanished
    from the stream, it must be removed from the group's pending list.

    Errors are logged and swallowed — requeueing is best-effort.
    """
    try:
        messages = self.REDIS.xrange(queue, msg_id, msg_id)
        if messages:
            self.REDIS.xadd(queue, messages[0][1])
        self.REDIS.xack(queue, group_name, msg_id)
    except Exception as e:
        # Fixed copy-paste bug: this previously logged "RedisDB.get_pending_msg".
        logging.warning(
            "RedisDB.requeue_msg " + str(queue) + " got exception: " + str(e)
        )
def queue_info(self, queue, group_name) -> dict | None:
try:
groups = self.REDIS.xinfo_groups(queue)
@ -301,6 +323,16 @@ class RedisDB:
"""
return bool(self.lua_delete_if_equal(keys=[key], args=[expected_value], client=self.REDIS))
def delete(self, key) -> bool:
    """Delete `key` from Redis, returning True on success.

    On any error the exception is logged, the connection is reopened, and
    False is returned. Note the result reflects command success, not whether
    the key actually existed.
    """
    try:
        self.REDIS.delete(key)
    except Exception as exc:
        logging.warning("RedisDB.delete " + str(key) + " got exception: " + str(exc))
        # Assume the connection went bad and re-establish it.
        self.__open__()
        return False
    return True
# Module-level singleton; import REDIS_CONN instead of instantiating RedisDB directly.
REDIS_CONN = RedisDB()