From 89a69eed721354ac99afcfec76d81fe3296190a5 Mon Sep 17 00:00:00 2001
From: Zhichang Yu
Date: Fri, 14 Mar 2025 23:43:46 +0800
Subject: [PATCH] Introduced task priority (#6118)

### What problem does this PR solve?

Introduced task priority

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 api/apps/api_app.py                 |  2 +-
 api/apps/document_app.py            |  2 +-
 api/apps/sdk/doc.py                 |  2 +-
 api/db/db_models.py                 |  8 ++++++++
 api/db/services/document_service.py | 12 ++++++-----
 api/db/services/task_service.py     |  8 +++++---
 rag/settings.py                     | 18 +++++++++-------
 rag/svr/task_executor.py            | 17 ++++++++-------
 rag/utils/redis_conn.py             | 32 ++++++++++++++---------------
 9 files changed, 59 insertions(+), 42 deletions(-)

diff --git a/api/apps/api_app.py b/api/apps/api_app.py
index 533cffaee..d30da1d82 100644
--- a/api/apps/api_app.py
+++ b/api/apps/api_app.py
@@ -479,7 +479,7 @@ def upload():
             doc = doc.to_dict()
             doc["tenant_id"] = tenant_id
             bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-            queue_tasks(doc, bucket, name)
+            queue_tasks(doc, bucket, name, 0)
         except Exception as e:
             return server_error_response(e)
 
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index ec7db5f70..16ad51ee9 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -380,7 +380,7 @@ def run():
                 doc = doc.to_dict()
                 doc["tenant_id"] = tenant_id
                 bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-                queue_tasks(doc, bucket, name)
+                queue_tasks(doc, bucket, name, 0)
 
         return get_json_result(data=True)
     except Exception as e:
diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py
index 6ecb2fa97..4165751bd 100644
--- a/api/apps/sdk/doc.py
+++ b/api/apps/sdk/doc.py
@@ -693,7 +693,7 @@ def parse(tenant_id, dataset_id):
         doc = doc.to_dict()
         doc["tenant_id"] = tenant_id
         bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-        queue_tasks(doc, bucket, name)
+        queue_tasks(doc, bucket, name, 0)
 
     return get_result()
 
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 524cb8b58..4bc8e975d 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -845,6 +845,7 @@ class Task(DataBaseModel):
     from_page = IntegerField(default=0)
     to_page = IntegerField(default=100000000)
     task_type = CharField(max_length=32, null=False, default="")
+    priority = IntegerField(default=0)
 
     begin_at = DateTimeField(null=True, index=True)
     process_duation = FloatField(default=0)
@@ -1122,3 +1123,10 @@ def migrate_db():
         )
     except Exception:
         pass
+    try:
+        migrate(
+            migrator.add_column("task", "priority",
+                                IntegerField(default=0))
+        )
+    except Exception:
+        pass
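The new column only carries a default, so the migration leaves existing rows readable and callers that never mention priority keep working. Below is a minimal sketch of that behaviour with a throwaway in-memory model; this stripped-down `Task` and its SQLite database are illustrative stand-ins, not the repository's `DataBaseModel`:

```python
# Illustrative only: IntegerField(default=0) means priority is optional for callers
# and rows created without it read back as 0.
from peewee import CharField, IntegerField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")


class Task(Model):
    doc_id = CharField()
    priority = IntegerField(default=0)  # same default the migration adds

    class Meta:
        database = db


db.create_tables([Task])
Task.create(doc_id="doc-a")               # no priority given: defaults to 0
Task.create(doc_id="doc-b", priority=1)   # expedited task
assert Task.get(Task.doc_id == "doc-a").priority == 0
assert Task.get(Task.doc_id == "doc-b").priority == 1
```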
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index b0c560e32..81203f086 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -34,7 +34,7 @@ from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils import current_timestamp, get_format_time, get_uuid
 from rag.nlp import rag_tokenizer, search
-from rag.settings import SVR_QUEUE_NAME
+from rag.settings import get_svr_queue_name
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 
@@ -392,6 +392,7 @@ class DocumentService(CommonService):
                 has_graphrag = False
                 e, doc = DocumentService.get_by_id(d["id"])
                 status = doc.run  # TaskStatus.RUNNING.value
+                priority = 0
                 for t in tsks:
                     if 0 <= t.progress < 1:
                         finished = False
@@ -403,16 +404,17 @@ class DocumentService(CommonService):
                         has_raptor = True
                     elif t.task_type == "graphrag":
                         has_graphrag = True
+                    priority = max(priority, t.priority)
                 prg /= len(tsks)
                 if finished and bad:
                     prg = -1
                     status = TaskStatus.FAIL.value
                 elif finished:
                     if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
-                        queue_raptor_o_graphrag_tasks(d, "raptor")
+                        queue_raptor_o_graphrag_tasks(d, "raptor", priority)
                         prg = 0.98 * len(tsks) / (len(tsks) + 1)
                     elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
-                        queue_raptor_o_graphrag_tasks(d, "graphrag")
+                        queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                         prg = 0.98 * len(tsks) / (len(tsks) + 1)
                     else:
                         status = TaskStatus.DONE.value
@@ -449,7 +451,7 @@ class DocumentService(CommonService):
         return False
 
 
-def queue_raptor_o_graphrag_tasks(doc, ty):
+def queue_raptor_o_graphrag_tasks(doc, ty, priority):
     chunking_config = DocumentService.get_chunking_config(doc["id"])
     hasher = xxhash.xxh64()
     for field in sorted(chunking_config.keys()):
@@ -472,7 +474,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
     hasher.update(ty.encode("utf-8"))
     task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
-    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
+    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
 
 
 def doc_upload_and_parse(conversation_id, file_objs, user_id):
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index d75954365..4693a6684 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -28,7 +28,7 @@ from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.utils import current_timestamp, get_uuid
 from deepdoc.parser.excel_parser import RAGFlowExcelParser
-from rag.settings import SVR_QUEUE_NAME
+from rag.settings import get_svr_queue_name
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
 from api import settings
@@ -289,7 +289,7 @@ class TaskService(CommonService):
         ).execute()
 
 
-def queue_tasks(doc: dict, bucket: str, name: str):
+def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
     """Create and queue document processing tasks.
 
     This function creates processing tasks for a document based on its type and configuration.
@@ -301,6 +301,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
         doc (dict): Document dictionary containing metadata and configuration.
         bucket (str): Storage bucket name where the document is stored.
         name (str): File name of the document.
+        priority (int, optional): Priority level for task queueing (default is 0).
 
     Note:
         - For PDF documents, tasks are created per page range based on configuration
@@ -358,6 +359,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
     task_digest = hasher.hexdigest()
     task["digest"] = task_digest
     task["progress"] = 0.0
+    task["priority"] = priority
 
     prev_tasks = TaskService.get_tasks(doc["id"])
     ck_num = 0
@@ -380,7 +382,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
     unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
     for unfinished_task in unfinished_task_array:
         assert REDIS_CONN.queue_product(
-            SVR_QUEUE_NAME, message=unfinished_task
+            get_svr_queue_name(priority), message=unfinished_task
         ), "Can't access Redis. Please check the Redis' status."
diff --git a/rag/settings.py b/rag/settings.py
index 72f48c04e..a54d65826 100644
--- a/rag/settings.py
+++ b/rag/settings.py
@@ -35,16 +35,20 @@ except Exception:
 DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
 
 SVR_QUEUE_NAME = "rag_flow_svr_queue"
-SVR_QUEUE_RETENTION = 60*60
-SVR_QUEUE_MAX_LEN = 1024
-SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
-SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
+SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
 PAGERANK_FLD = "pagerank_fea"
 TAG_FLD = "tag_feas"
 
 
 def print_rag_settings():
     logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
-    logging.info(f"SERVER_QUEUE_MAX_LEN: {SVR_QUEUE_MAX_LEN}")
-    logging.info(f"SERVER_QUEUE_RETENTION: {SVR_QUEUE_RETENTION}")
-    logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")
\ No newline at end of file
+    logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")
+
+
+def get_svr_queue_name(priority: int) -> str:
+    if priority == 0:
+        return SVR_QUEUE_NAME
+    return f"{SVR_QUEUE_NAME}_{priority}"
+
+def get_svr_queue_names():
+    return [get_svr_queue_name(priority) for priority in [1, 0]]
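With these helpers, priority 0 keeps the original stream name while any higher priority gets its own suffixed stream, and `get_svr_queue_names()` lists the priority-1 stream first, which is what gives it precedence when workers poll. A quick standalone check of those values (the two functions are restated here only so the snippet runs on its own):

```python
# Standalone restatement of the helpers added above, plus the values they produce.
SVR_QUEUE_NAME = "rag_flow_svr_queue"


def get_svr_queue_name(priority: int) -> str:
    if priority == 0:
        return SVR_QUEUE_NAME
    return f"{SVR_QUEUE_NAME}_{priority}"


def get_svr_queue_names():
    return [get_svr_queue_name(priority) for priority in [1, 0]]


assert get_svr_queue_name(0) == "rag_flow_svr_queue"     # default queue keeps its old name
assert get_svr_queue_name(1) == "rag_flow_svr_queue_1"   # high-priority queue
# Consumers poll in this order, so priority-1 messages are picked up first.
assert get_svr_queue_names() == ["rag_flow_svr_queue_1", "rag_flow_svr_queue"]
```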
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index d14ebc51a..bc5157ee4 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -56,7 +56,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume, email, tag
 from rag.nlp import search, rag_tokenizer
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
+from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
 from rag.utils import num_tokens_from_string
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 
@@ -171,20 +171,23 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
 
 async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
     global UNACKED_ITERATOR
+    svr_queue_names = get_svr_queue_names()
     try:
         if not UNACKED_ITERATOR:
-            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
+            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
         try:
             redis_msg = next(UNACKED_ITERATOR)
         except StopIteration:
-            redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
-            if not redis_msg:
-                await trio.sleep(1)
-                return None, None
+            for svr_queue_name in svr_queue_names:
+                redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
+                if redis_msg:
+                    break
     except Exception:
         logging.exception("collect got exception")
         return None, None
+    if not redis_msg:
+        return None, None
     msg = redis_msg.get_message()
     if not msg:
         logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
@@ -615,7 +618,7 @@ async def report_status():
     while True:
         try:
             now = datetime.now()
-            group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+            group_info = REDIS_CONN.queue_info(get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
             if group_info is not None:
                 PENDING_TASKS = int(group_info.get("pending", 0))
                 LAG_TASKS = int(group_info.get("lag", 0))
diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py
index 75acb4837..98a152c89 100644
--- a/rag/utils/redis_conn.py
+++ b/rag/utils/redis_conn.py
@@ -193,14 +193,11 @@ class RedisDB:
             self.__open__()
         return False
 
-    def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
+    def queue_product(self, queue, message) -> bool:
         for _ in range(3):
             try:
                 payload = {"message": json.dumps(message)}
-                pipeline = self.REDIS.pipeline()
-                pipeline.xadd(queue, payload)
-                # pipeline.expire(queue, exp)
-                pipeline.execute()
+                self.REDIS.xadd(queue, payload)
                 return True
             except Exception as e:
                 logging.exception(
@@ -242,19 +239,20 @@ class RedisDB:
             )
         return None
 
-    def get_unacked_iterator(self, queue_name, group_name, consumer_name):
+    def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
         try:
-            group_info = self.REDIS.xinfo_groups(queue_name)
-            if not any(e["name"] == group_name for e in group_info):
-                return
-            current_min = 0
-            while True:
-                payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
-                if not payload:
-                    return
-                current_min = payload.get_msg_id()
-                logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
-                yield payload
+            for queue_name in queue_names:
+                group_info = self.REDIS.xinfo_groups(queue_name)
+                if not any(e["name"] == group_name for e in group_info):
+                    continue
+                current_min = 0
+                while True:
+                    payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
+                    if not payload:
+                        break
+                    current_min = payload.get_msg_id()
+                    logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
+                    yield payload
         except Exception as e:
             if "key" in str(e):
                 return
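The consumer side of the change is just a loop over that queue list with an early break: a worker drains priority-1 work before falling back to the default stream, and low-priority tasks still run once the high-priority stream is empty. A rough simulation of that polling order follows; the in-memory deques and the `fetch_one` helper are made up for illustration and are not part of the `RedisDB` API:

```python
# Simulation only: deques stand in for the two Redis streams; fetch_one() is hypothetical.
from collections import deque

queues = {
    "rag_flow_svr_queue_1": deque([{"doc_id": "urgent"}]),   # priority 1
    "rag_flow_svr_queue": deque([{"doc_id": "normal"}]),     # priority 0
}


def fetch_one(queue_name: str):
    """Pop one pending message from a stream, or return None if it is empty."""
    q = queues[queue_name]
    return q.popleft() if q else None


def collect_once():
    # Mirrors collect(): try the queues in the order returned by get_svr_queue_names().
    for queue_name in ("rag_flow_svr_queue_1", "rag_flow_svr_queue"):
        msg = fetch_one(queue_name)
        if msg:
            return msg
    return None  # nothing pending; the real worker sleeps and retries


assert collect_once() == {"doc_id": "urgent"}   # high-priority stream is served first
assert collect_once() == {"doc_id": "normal"}   # then the default stream
assert collect_once() is None
```

As wired up here, only priorities 0 and 1 have a consumer, and every caller in this patch passes 0, so behaviour is unchanged until something opts into priority 1.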