ragflow/api/db/services/task_service.py
Zhichang Yu 89a69eed72
Introduced task priority (#6118)
### What problem does this PR solve?

Introduced task priority

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-03-14 23:43:46 +08:00

434 lines
17 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import random
import xxhash
from datetime import datetime
from api.db.db_utils import bulk_insert_into_db
from deepdoc.parser import PdfParser
from peewee import JOIN
from api.db.db_models import DB, File2Document, File
from api.db import StatusEnum, FileType, TaskStatus
from api.db.db_models import Task, Document, Knowledgebase, Tenant
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.utils import current_timestamp, get_uuid
from deepdoc.parser.excel_parser import RAGFlowExcelParser
from rag.settings import get_svr_queue_name
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from api import settings
from rag.nlp import search
def trim_header_by_lines(text: str, max_length) -> str:
# Trim header text to maximum length while preserving line breaks
# Args:
# text: Input text to trim
# max_length: Maximum allowed length
# Returns:
# Trimmed text
len_text = len(text)
if len_text <= max_length:
return text
for i in range(len_text):
if text[i] == '\n' and len_text - i <= max_length:
return text[i + 1:]
return text
class TaskService(CommonService):
"""Service class for managing document processing tasks.
This class extends CommonService to provide specialized functionality for document
processing task management, including task creation, progress tracking, and chunk
management. It handles various document types (PDF, Excel, etc.) and manages their
processing lifecycle.
The class implements a robust task queue system with retry mechanisms and progress
tracking, supporting both synchronous and asynchronous task execution.
Attributes:
model: The Task model class for database operations.
"""
model = Task
@classmethod
@DB.connection_context()
def get_task(cls, task_id):
"""Retrieve detailed task information by task ID.
This method fetches comprehensive task details including associated document,
knowledge base, and tenant information. It also handles task retry logic and
progress updates.
Args:
task_id (str): The unique identifier of the task to retrieve.
Returns:
dict: Task details dictionary containing all task information and related metadata.
Returns None if task is not found or has exceeded retry limit.
"""
fields = [
cls.model.id,
cls.model.doc_id,
cls.model.from_page,
cls.model.to_page,
cls.model.retry_count,
Document.kb_id,
Document.parser_id,
Document.parser_config,
Document.name,
Document.type,
Document.location,
Document.size,
Knowledgebase.tenant_id,
Knowledgebase.language,
Knowledgebase.embd_id,
Knowledgebase.pagerank,
Knowledgebase.parser_config.alias("kb_parser_config"),
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,
cls.model.update_time,
]
docs = (
cls.model.select(*fields)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == task_id)
)
docs = list(docs.dicts())
if not docs:
return None
msg = f"\n{datetime.now().strftime('%H:%M:%S')} Task has been received."
prog = random.random() / 10.0
if docs[0]["retry_count"] >= 3:
msg = "\nERROR: Task is abandoned after 3 times attempts."
prog = -1
cls.model.update(
progress_msg=cls.model.progress_msg + msg,
progress=prog,
retry_count=docs[0]["retry_count"] + 1,
).where(cls.model.id == docs[0]["id"]).execute()
if docs[0]["retry_count"] >= 3:
return None
return docs[0]
@classmethod
@DB.connection_context()
def get_tasks(cls, doc_id: str):
"""Retrieve all tasks associated with a document.
This method fetches all processing tasks for a given document, ordered by page
number and creation time. It includes task progress and chunk information.
Args:
doc_id (str): The unique identifier of the document.
Returns:
list[dict]: List of task dictionaries containing task details.
Returns None if no tasks are found.
"""
fields = [
cls.model.id,
cls.model.from_page,
cls.model.progress,
cls.model.digest,
cls.model.chunk_ids,
]
tasks = (
cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
.where(cls.model.doc_id == doc_id)
)
tasks = list(tasks.dicts())
if not tasks:
return None
return tasks
@classmethod
@DB.connection_context()
def update_chunk_ids(cls, id: str, chunk_ids: str):
"""Update the chunk IDs associated with a task.
This method updates the chunk_ids field of a task, which stores the IDs of
processed document chunks in a space-separated string format.
Args:
id (str): The unique identifier of the task.
chunk_ids (str): Space-separated string of chunk identifiers.
"""
cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()
@classmethod
@DB.connection_context()
def get_ongoing_doc_name(cls):
"""Get names of documents that are currently being processed.
This method retrieves information about documents that are in the processing state,
including their locations and associated IDs. It uses database locking to ensure
thread safety when accessing the task information.
Returns:
list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
for documents currently being processed. Returns empty list if
no documents are being processed.
"""
with DB.lock("get_task", -1):
docs = (
cls.model.select(
*[Document.id, Document.kb_id, Document.location, File.parent_id]
)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(
File2Document,
on=(File2Document.document_id == Document.id),
join_type=JOIN.LEFT_OUTER,
)
.join(
File,
on=(File2Document.file_id == File.id),
join_type=JOIN.LEFT_OUTER,
)
.where(
Document.status == StatusEnum.VALID.value,
Document.run == TaskStatus.RUNNING.value,
~(Document.type == FileType.VIRTUAL.value),
cls.model.progress < 1,
cls.model.create_time >= current_timestamp() - 1000 * 600,
)
)
docs = list(docs.dicts())
if not docs:
return []
return list(
set(
[
(
d["parent_id"] if d["parent_id"] else d["kb_id"],
d["location"],
)
for d in docs
]
)
)
@classmethod
@DB.connection_context()
def do_cancel(cls, id):
"""Check if a task should be cancelled based on its document status.
This method determines whether a task should be cancelled by checking the
associated document's run status and progress. A task should be cancelled
if its document is marked for cancellation or has negative progress.
Args:
id (str): The unique identifier of the task to check.
Returns:
bool: True if the task should be cancelled, False otherwise.
"""
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
@classmethod
@DB.connection_context()
def update_progress(cls, id, info):
"""Update the progress information for a task.
This method updates both the progress message and completion percentage of a task.
It handles platform-specific behavior (macOS vs others) and uses database locking
when necessary to ensure thread safety.
Args:
id (str): The unique identifier of the task to update.
info (dict): Dictionary containing progress information with keys:
- progress_msg (str, optional): Progress message to append
- progress (float, optional): Progress percentage (0.0 to 1.0)
"""
if os.environ.get("MACOS"):
if info["progress_msg"]:
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
).execute()
return
with DB.lock("update_progress", -1):
if info["progress_msg"]:
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
).execute()
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
"""Create and queue document processing tasks.
This function creates processing tasks for a document based on its type and configuration.
It handles different document types (PDF, Excel, etc.) differently and manages task
chunking and configuration. It also implements task reuse optimization by checking
for previously completed tasks.
Args:
doc (dict): Document dictionary containing metadata and configuration.
bucket (str): Storage bucket name where the document is stored.
name (str): File name of the document.
priority (int, optional): Priority level for task queueing (default is 0).
Note:
- For PDF documents, tasks are created per page range based on configuration
- For Excel documents, tasks are created per row range
- Task digests are calculated for optimization and reuse
- Previous task chunks may be reused if available
"""
def new_task():
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}
parse_task_array = []
if doc["type"] == FileType.PDF.value:
file_bin = STORAGE_IMPL.get(bucket, name)
do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
pages = PdfParser.total_page_number(doc["name"], file_bin)
page_size = doc["parser_config"].get("task_page_size", 12)
if doc["parser_id"] == "paper":
page_size = doc["parser_config"].get("task_page_size", 22)
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC":
page_size = 10 ** 9
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
for s, e in page_ranges:
s -= 1
s = max(0, s)
e = min(e - 1, pages)
for p in range(s, e, page_size):
task = new_task()
task["from_page"] = p
task["to_page"] = min(p + page_size, e)
parse_task_array.append(task)
elif doc["parser_id"] == "table":
file_bin = STORAGE_IMPL.get(bucket, name)
rn = RAGFlowExcelParser.row_number(doc["name"], file_bin)
for i in range(0, rn, 3000):
task = new_task()
task["from_page"] = i
task["to_page"] = min(i + 3000, rn)
parse_task_array.append(task)
else:
parse_task_array.append(new_task())
chunking_config = DocumentService.get_chunking_config(doc["id"])
for task in parse_task_array:
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
if field == "parser_config":
for k in ["raptor", "graphrag"]:
if k in chunking_config[field]:
del chunking_config[field][k]
hasher.update(str(chunking_config[field]).encode("utf-8"))
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
task_digest = hasher.hexdigest()
task["digest"] = task_digest
task["progress"] = 0.0
task["priority"] = priority
prev_tasks = TaskService.get_tasks(doc["id"])
ck_num = 0
if prev_tasks:
for task in parse_task_array:
ck_num += reuse_prev_task_chunks(task, prev_tasks, chunking_config)
TaskService.filter_delete([Task.doc_id == doc["id"]])
chunk_ids = []
for task in prev_tasks:
if task["chunk_ids"]:
chunk_ids.extend(task["chunk_ids"].split())
if chunk_ids:
settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]),
chunking_config["kb_id"])
DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
bulk_insert_into_db(Task, parse_task_array, True)
DocumentService.begin2parse(doc["id"])
unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
for unfinished_task in unfinished_task_array:
assert REDIS_CONN.queue_product(
get_svr_queue_name(priority), message=unfinished_task
), "Can't access Redis. Please check the Redis' status."
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
"""Attempt to reuse chunks from previous tasks for optimization.
This function checks if chunks from previously completed tasks can be reused for
the current task, which can significantly improve processing efficiency. It matches
tasks based on page ranges and configuration digests.
Args:
task (dict): Current task dictionary to potentially reuse chunks for.
prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
chunking_config (dict): Configuration dictionary for chunk processing.
Returns:
int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.
Note:
Chunks can only be reused if:
- A previous task exists with matching page range and configuration digest
- The previous task was completed successfully (progress = 1.0)
- The previous task has valid chunk IDs
"""
idx = 0
while idx < len(prev_tasks):
prev_task = prev_tasks[idx]
if prev_task.get("from_page", 0) == task.get("from_page", 0) \
and prev_task.get("digest", 0) == task.get("digest", ""):
break
idx += 1
if idx >= len(prev_tasks):
return 0
prev_task = prev_tasks[idx]
if prev_task["progress"] < 1.0 or not prev_task["chunk_ids"]:
return 0
task["chunk_ids"] = prev_task["chunk_ids"]
task["progress"] = 1.0
if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
else:
task["progress_msg"] = ""
task["progress_msg"] = " ".join(
[datetime.now().strftime("%H:%M:%S"), task["progress_msg"], "Reused previous task's chunks."])
prev_task["chunk_ids"] = ""
return len(task["chunk_ids"].split())