Introduced task priority (#6118)

### What problem does this PR solve?

Introduced task priority.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

This commit is contained in:
parent 1842ca0334
commit 89a69eed72
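In short, every parsing task now carries an integer priority that selects the Redis stream it is queued on, and the task executor drains higher-priority streams before the default one. A minimal, self-contained sketch of the queue-name mapping this PR adds (the constants and helpers mirror the `rag.settings` diff below; the asserts are illustrative only):

```python
SVR_QUEUE_NAME = "rag_flow_svr_queue"

def get_svr_queue_name(priority: int) -> str:
    # Priority 0 keeps the legacy queue name; any other priority gets a suffix.
    if priority == 0:
        return SVR_QUEUE_NAME
    return f"{SVR_QUEUE_NAME}_{priority}"

def get_svr_queue_names() -> list[str]:
    # Consumers poll the priority-1 queue before falling back to the default one.
    return [get_svr_queue_name(priority) for priority in [1, 0]]

assert get_svr_queue_name(0) == "rag_flow_svr_queue"
assert get_svr_queue_name(1) == "rag_flow_svr_queue_1"
assert get_svr_queue_names() == ["rag_flow_svr_queue_1", "rag_flow_svr_queue"]
```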
```diff
@@ -479,7 +479,7 @@ def upload():
             doc = doc.to_dict()
             doc["tenant_id"] = tenant_id
             bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-            queue_tasks(doc, bucket, name)
+            queue_tasks(doc, bucket, name, 0)
     except Exception as e:
         return server_error_response(e)

```
```diff
@@ -380,7 +380,7 @@ def run():
                 doc = doc.to_dict()
                 doc["tenant_id"] = tenant_id
                 bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-                queue_tasks(doc, bucket, name)
+                queue_tasks(doc, bucket, name, 0)

         return get_json_result(data=True)
     except Exception as e:
```
```diff
@@ -693,7 +693,7 @@ def parse(tenant_id, dataset_id):
         doc = doc.to_dict()
         doc["tenant_id"] = tenant_id
         bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-        queue_tasks(doc, bucket, name)
+        queue_tasks(doc, bucket, name, 0)
     return get_result()

```
```diff
@@ -845,6 +845,7 @@ class Task(DataBaseModel):
     from_page = IntegerField(default=0)
     to_page = IntegerField(default=100000000)
     task_type = CharField(max_length=32, null=False, default="")
+    priority = IntegerField(default=0)

     begin_at = DateTimeField(null=True, index=True)
     process_duation = FloatField(default=0)
```
```diff
@@ -1122,3 +1123,10 @@ def migrate_db():
         )
     except Exception:
         pass
+    try:
+        migrate(
+            migrator.add_column("task", "priority",
+                                IntegerField(default=0))
+        )
+    except Exception:
+        pass
```
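The new migration follows the existing pattern in `migrate_db()`: each `add_column` is wrapped in its own try/except so reruns against an already migrated database are silent no-ops. A small standalone illustration using peewee's `playhouse.migrate` against an in-memory SQLite database (the `Task` model here is a stand-in, not the ragflow model):

```python
from peewee import SqliteDatabase, Model, CharField, IntegerField
from playhouse.migrate import SqliteMigrator, migrate

db = SqliteDatabase(":memory:")

class Task(Model):
    task_type = CharField(max_length=32, default="")

    class Meta:
        database = db
        table_name = "task"

db.connect()
db.create_tables([Task])

migrator = SqliteMigrator(db)
try:
    # Safe to run repeatedly: a second run fails because the column already
    # exists, and the except swallows it instead of breaking startup.
    migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
except Exception:
    pass

print([c.name for c in db.get_columns("task")])  # includes the new "priority" column
```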
```diff
@@ -34,7 +34,7 @@ from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils import current_timestamp, get_format_time, get_uuid
 from rag.nlp import rag_tokenizer, search
-from rag.settings import SVR_QUEUE_NAME
+from rag.settings import get_svr_queue_name
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL

```
```diff
@@ -392,6 +392,7 @@ class DocumentService(CommonService):
             has_graphrag = False
             e, doc = DocumentService.get_by_id(d["id"])
             status = doc.run  # TaskStatus.RUNNING.value
+            priority = 0
             for t in tsks:
                 if 0 <= t.progress < 1:
                     finished = False
```
```diff
@@ -403,16 +404,17 @@ class DocumentService(CommonService):
                     has_raptor = True
                 elif t.task_type == "graphrag":
                     has_graphrag = True
+                priority = max(priority, t.priority)
             prg /= len(tsks)
             if finished and bad:
                 prg = -1
                 status = TaskStatus.FAIL.value
             elif finished:
                 if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
-                    queue_raptor_o_graphrag_tasks(d, "raptor")
+                    queue_raptor_o_graphrag_tasks(d, "raptor", priority)
                     prg = 0.98 * len(tsks) / (len(tsks) + 1)
                 elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
-                    queue_raptor_o_graphrag_tasks(d, "graphrag")
+                    queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                     prg = 0.98 * len(tsks) / (len(tsks) + 1)
                 else:
                     status = TaskStatus.DONE.value
```
```diff
@@ -449,7 +451,7 @@ class DocumentService(CommonService):
         return False


-def queue_raptor_o_graphrag_tasks(doc, ty):
+def queue_raptor_o_graphrag_tasks(doc, ty, priority):
     chunking_config = DocumentService.get_chunking_config(doc["id"])
     hasher = xxhash.xxh64()
     for field in sorted(chunking_config.keys()):
```
```diff
@@ -472,7 +474,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
     hasher.update(ty.encode("utf-8"))
     task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
-    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
+    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."


 def doc_upload_and_parse(conversation_id, file_objs, user_id):
```
```diff
@@ -28,7 +28,7 @@ from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.utils import current_timestamp, get_uuid
 from deepdoc.parser.excel_parser import RAGFlowExcelParser
-from rag.settings import SVR_QUEUE_NAME
+from rag.settings import get_svr_queue_name
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
 from api import settings
```
```diff
@@ -289,7 +289,7 @@ class TaskService(CommonService):
         ).execute()


-def queue_tasks(doc: dict, bucket: str, name: str):
+def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
     """Create and queue document processing tasks.

     This function creates processing tasks for a document based on its type and configuration.
```
```diff
@@ -301,6 +301,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
         doc (dict): Document dictionary containing metadata and configuration.
         bucket (str): Storage bucket name where the document is stored.
         name (str): File name of the document.
+        priority (int, optional): Priority level for task queueing (default is 0).

     Note:
         - For PDF documents, tasks are created per page range based on configuration
```
```diff
@@ -358,6 +359,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
         task_digest = hasher.hexdigest()
         task["digest"] = task_digest
         task["progress"] = 0.0
+        task["priority"] = priority

     prev_tasks = TaskService.get_tasks(doc["id"])
     ck_num = 0
```
```diff
@@ -380,7 +382,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
         unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
         for unfinished_task in unfinished_task_array:
             assert REDIS_CONN.queue_product(
-                SVR_QUEUE_NAME, message=unfinished_task
+                get_svr_queue_name(priority), message=unfinished_task
            ), "Can't access Redis. Please check the Redis' status."

```
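Putting the producer side together: `queue_tasks()` now stamps each task row with the caller's priority and produces it to the matching stream via `get_svr_queue_name(priority)`. All call sites in this PR pass 0, so behaviour is unchanged by default; a hypothetical high-priority caller would pass 1. A simplified, self-contained sketch of that flow (the `FakeRedis` stand-in only records which stream each message lands on):

```python
from collections import defaultdict

SVR_QUEUE_NAME = "rag_flow_svr_queue"

def get_svr_queue_name(priority: int) -> str:
    return SVR_QUEUE_NAME if priority == 0 else f"{SVR_QUEUE_NAME}_{priority}"

class FakeRedis:
    def __init__(self):
        self.streams = defaultdict(list)

    def queue_product(self, queue, message) -> bool:
        # Stand-in for the real XADD-backed producer.
        self.streams[queue].append(message)
        return True

REDIS_CONN = FakeRedis()

def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
    # Each task row is stamped with the priority before it is queued, mirroring
    # the task["priority"] = priority assignment in the diff above.
    task = {"doc_id": doc["id"], "progress": 0.0, "priority": priority}
    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), \
        "Can't access Redis. Please check the Redis' status."

queue_tasks({"id": "d1"}, "bucket", "a.pdf", 0)   # -> rag_flow_svr_queue
queue_tasks({"id": "d2"}, "bucket", "b.pdf", 1)   # -> rag_flow_svr_queue_1
print({queue: len(msgs) for queue, msgs in REDIS_CONN.streams.items()})
```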
```diff
@@ -35,16 +35,20 @@ except Exception:
 DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))

 SVR_QUEUE_NAME = "rag_flow_svr_queue"
-SVR_QUEUE_RETENTION = 60*60
-SVR_QUEUE_MAX_LEN = 1024
-SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
-SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
+SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
 PAGERANK_FLD = "pagerank_fea"
 TAG_FLD = "tag_feas"


 def print_rag_settings():
     logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
-    logging.info(f"SERVER_QUEUE_MAX_LEN: {SVR_QUEUE_MAX_LEN}")
-    logging.info(f"SERVER_QUEUE_RETENTION: {SVR_QUEUE_RETENTION}")
     logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")
+
+
+def get_svr_queue_name(priority: int) -> str:
+    if priority == 0:
+        return SVR_QUEUE_NAME
+    return f"{SVR_QUEUE_NAME}_{priority}"
+
+def get_svr_queue_names():
+    return [get_svr_queue_name(priority) for priority in [1, 0]]
```
```diff
@@ -56,7 +56,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
     email, tag
 from rag.nlp import search, rag_tokenizer
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
+from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
 from rag.utils import num_tokens_from_string
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
```
```diff
@@ -171,20 +171,23 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
 async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
     global UNACKED_ITERATOR
+    svr_queue_names = get_svr_queue_names()
     try:
         if not UNACKED_ITERATOR:
-            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
+            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
         try:
             redis_msg = next(UNACKED_ITERATOR)
         except StopIteration:
-            redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
-            if not redis_msg:
-                await trio.sleep(1)
-                return None, None
+            for svr_queue_name in svr_queue_names:
+                redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
+                if redis_msg:
+                    break
     except Exception:
         logging.exception("collect got exception")
         return None, None

+    if not redis_msg:
+        return None, None
     msg = redis_msg.get_message()
     if not msg:
         logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
```
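On the consumer side, `collect()` now polls the queues returned by `get_svr_queue_names()` in order, so priority-1 work is picked up before default work whenever both streams have pending messages. A simplified sketch of that ordering (the `queue_consumer` stub below is hypothetical and replaces the Redis stream read):

```python
SVR_QUEUE_NAME = "rag_flow_svr_queue"

def get_svr_queue_name(priority: int) -> str:
    return SVR_QUEUE_NAME if priority == 0 else f"{SVR_QUEUE_NAME}_{priority}"

def get_svr_queue_names():
    return [get_svr_queue_name(p) for p in [1, 0]]

def queue_consumer(queue_name):
    # Stand-in for REDIS_CONN.queue_consumer(...); pretend only the default
    # queue has a pending message right now.
    pending = {"rag_flow_svr_queue": {"task_id": "t-42"}}
    return pending.get(queue_name)

def collect_once():
    redis_msg = None
    for svr_queue_name in get_svr_queue_names():
        redis_msg = queue_consumer(svr_queue_name)
        if redis_msg:
            break  # higher-priority work wins when both queues have messages
    return redis_msg

print(collect_once())  # {'task_id': 't-42'} picked up from the default queue
```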
```diff
@@ -615,7 +618,7 @@ async def report_status():
     while True:
         try:
             now = datetime.now()
-            group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+            group_info = REDIS_CONN.queue_info(get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
             if group_info is not None:
                 PENDING_TASKS = int(group_info.get("pending", 0))
                 LAG_TASKS = int(group_info.get("lag", 0))
```
```diff
@@ -193,14 +193,11 @@ class RedisDB:
             self.__open__()
         return False

-    def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
+    def queue_product(self, queue, message) -> bool:
         for _ in range(3):
             try:
                 payload = {"message": json.dumps(message)}
-                pipeline = self.REDIS.pipeline()
-                pipeline.xadd(queue, payload)
-                # pipeline.expire(queue, exp)
-                pipeline.execute()
+                self.REDIS.xadd(queue, payload)
                 return True
             except Exception as e:
                 logging.exception(
```
```diff
@@ -242,16 +239,17 @@ class RedisDB:
             )
             return None

-    def get_unacked_iterator(self, queue_name, group_name, consumer_name):
+    def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
         try:
+            for queue_name in queue_names:
                 group_info = self.REDIS.xinfo_groups(queue_name)
                 if not any(e["name"] == group_name for e in group_info):
-                    return
+                    continue
                 current_min = 0
                 while True:
                     payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
                     if not payload:
-                        return
+                        break
                     current_min = payload.get_msg_id()
                     logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
                     yield payload
```
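`get_unacked_iterator()` is a generator, so `collect()` can keep a single iterator across calls: it replays pending (un-acknowledged) messages from every priority queue first, and only falls back to live consumption once the generator is exhausted. A simplified sketch of that consumption pattern (queue names and payloads here are placeholders, not real stream entries):

```python
def get_unacked_iterator(queue_names):
    for queue_name in queue_names:
        # In the real method, messages still pending for this consumer are
        # re-read from each stream; two canned payloads stand in for them.
        for payload in {"q1": ["m1"], "q0": []}.get(queue_name, []):
            yield payload

UNACKED_ITERATOR = get_unacked_iterator(["q1", "q0"])

def next_unacked():
    try:
        return next(UNACKED_ITERATOR)   # replay un-acknowledged work first
    except StopIteration:
        return None                     # then collect() polls the live queues

print(next_unacked())  # 'm1'
print(next_unacked())  # None -> fall back to queue_consumer() on fresh messages
```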