Fix raptor reuse issue. (#4063)

### What problem does this PR solve?

#4045

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Kevin Hu authored 2024-12-17 15:28:35 +08:00 (committed by GitHub)
parent 4a95349492 · commit fddac1345d
3 changed files with 30 additions and 16 deletions

```diff
@@ -344,6 +344,8 @@ class DocumentService(CommonService):
                     old[k] = v
         dfs_update(d.parser_config, config)
+        if not config.get("raptor") and d.parser_config.get("raptor"):
+            del d.parser_config["raptor"]
         cls.update_by_id(id, {"parser_config": d.parser_config})

     @classmethod
```
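The explicit `del` is needed because `dfs_update` is a merge, not a sync: it adds and overwrites keys but never removes them, so a raptor section that the user switched off would otherwise linger in the stored `parser_config`. A minimal sketch, assuming simplified `dfs_update` semantics (the real helper lives in the service layer):

```python
# Hedged sketch of why the `del` above is needed. dfs_update semantics are
# assumed (simplified): a recursive merge that only adds or overwrites keys.
def dfs_update(old: dict, new: dict) -> None:
    for k, v in new.items():
        if isinstance(v, dict) and isinstance(old.get(k), dict):
            dfs_update(old[k], v)
        else:
            old[k] = v

stored = {"raptor": {"use_raptor": True}, "chunk_token_num": 128}
incoming = {"chunk_token_num": 256}   # user turned raptor off
dfs_update(stored, incoming)
assert "raptor" in stored             # the stale key survives the merge
# hence the commit's follow-up:
if not incoming.get("raptor") and stored.get("raptor"):
    del stored["raptor"]
```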
```diff
@@ -432,6 +434,11 @@ class DocumentService(CommonService):
 def queue_raptor_tasks(doc):
     chunking_config = DocumentService.get_chunking_config(doc["id"])
+    hasher = xxhash.xxh64()
+    for field in sorted(chunking_config.keys()):
+        hasher.update(str(chunking_config[field]).encode("utf-8"))
+
     def new_task():
         nonlocal doc
         return {
@@ -443,6 +450,9 @@ def queue_raptor_tasks(doc):
         }

     task = new_task()
+    for field in ["doc_id", "from_page", "to_page"]:
+        hasher.update(str(task.get(field, "")).encode("utf-8"))
+    task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
     task["type"] = "raptor"
     assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
```
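Taken together, the two hunks above give each raptor task a deterministic fingerprint: the sorted chunking config plus the task's identity fields are folded into an xxhash64 digest, which later lets reuse detection recognize unchanged work. A minimal sketch of the scheme (`compute_task_digest` is an illustrative name, not from the commit):

```python
# Illustrative sketch of the digest computed above; xxhash is the library the
# commit uses, but the helper name and simplified inputs are assumptions.
import xxhash

def compute_task_digest(chunking_config: dict, task: dict) -> str:
    hasher = xxhash.xxh64()
    # Fold in the chunking config in sorted key order so the digest does not
    # depend on dict ordering.
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))
    # Then the fields that identify this specific unit of work.
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    return hasher.hexdigest()

# Same document + same config -> same digest; any config change breaks reuse.
```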

```diff
@@ -34,15 +34,17 @@ from rag.utils.redis_conn import REDIS_CONN
 from api import settings
 from rag.nlp import search

 def trim_header_by_lines(text: str, max_length) -> str:
     len_text = len(text)
     if len_text <= max_length:
         return text
     for i in range(len_text):
         if text[i] == '\n' and len_text - i <= max_length:
-            return text[i+1:]
+            return text[i + 1:]
     return text

 class TaskService(CommonService):
     model = Task
```
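For context, `trim_header_by_lines` keeps at most roughly `max_length` trailing characters, cutting only at a newline so no line is split in half. A quick illustrative check, assuming the function above is in scope:

```python
# Illustrative behavior of trim_header_by_lines as defined above.
text = "line1\nline2\nline3"            # 17 characters
print(trim_header_by_lines(text, 100))  # short enough: returned unchanged
print(trim_header_by_lines(text, 12))   # -> "line2\nline3" (leading line dropped)
```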
```diff
@@ -73,10 +75,10 @@ class TaskService(CommonService):
         ]
         docs = (
             cls.model.select(*fields)
-                .join(Document, on=(cls.model.doc_id == Document.id))
-                .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
-                .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
-                .where(cls.model.id == task_id)
+            .join(Document, on=(cls.model.doc_id == Document.id))
+            .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
+            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
+            .where(cls.model.id == task_id)
         )
         docs = list(docs.dicts())
         if not docs:
@@ -111,7 +113,7 @@ class TaskService(CommonService):
         ]
         tasks = (
             cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
-                .where(cls.model.doc_id == doc_id)
+            .where(cls.model.doc_id == doc_id)
         )
         tasks = list(tasks.dicts())
         if not tasks:
@@ -131,18 +133,18 @@ class TaskService(CommonService):
             cls.model.select(
                 *[Document.id, Document.kb_id, Document.location, File.parent_id]
             )
-                .join(Document, on=(cls.model.doc_id == Document.id))
-                .join(
+            .join(Document, on=(cls.model.doc_id == Document.id))
+            .join(
                 File2Document,
                 on=(File2Document.document_id == Document.id),
                 join_type=JOIN.LEFT_OUTER,
             )
-                .join(
+            .join(
                 File,
                 on=(File2Document.file_id == File.id),
                 join_type=JOIN.LEFT_OUTER,
             )
-                .where(
+            .where(
                 Document.status == StatusEnum.VALID.value,
                 Document.run == TaskStatus.RUNNING.value,
                 ~(Document.type == FileType.VIRTUAL.value),
@@ -212,8 +214,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
     if doc["parser_id"] == "paper":
         page_size = doc["parser_config"].get("task_page_size", 22)
     if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
-        page_size = 10**9
-    page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)]
+        page_size = 10 ** 9
+    page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
     for s, e in page_ranges:
         s -= 1
         s = max(0, s)
```
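The hunk above only reformats the power operators, but it shows the page-range setup: user-supplied 1-based `(start, end)` ranges, with `page_size` forced to a huge value when the whole document should become a single task. A minimal sketch (assumed, simplified; `split_pages` is an illustrative name) of how such ranges are typically windowed into per-task spans:

```python
# Hedged sketch of windowing page ranges into tasks of at most page_size
# pages, mirroring the 1-based-to-0-based adjustment visible in the hunk.
def split_pages(page_ranges, page_size):
    spans = []
    for s, e in page_ranges:
        s -= 1                # convert 1-based start to 0-based
        s = max(0, s)
        while s < e:
            spans.append((s, min(s + page_size, e)))
            s += page_size
    return spans

print(split_pages([(1, 12)], 5))  # -> [(0, 5), (5, 10), (10, 12)]
```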
```diff
@@ -257,7 +259,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
         if task["chunk_ids"]:
             chunk_ids.extend(task["chunk_ids"].split())
     if chunk_ids:
-        settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"])
+        settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]),
+                                     chunking_config["kb_id"])
     DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})

     bulk_insert_into_db(Task, tsks, True)
@@ -271,7 +274,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
 def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
-    idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0), key=lambda x: x.get("from_page",0))
+    idx = bisect.bisect_left(prev_tasks, (task.get("from_page", 0), task.get("digest", "")),
+                             key=lambda x: (x.get("from_page", 0), x.get("digest", "")))
     if idx >= len(prev_tasks):
         return 0

     prev_task = prev_tasks[idx]
@@ -286,4 +290,4 @@ def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config:
     task["progress_msg"] += "reused previous task's chunks."
     prev_task["chunk_ids"] = ""
-    return len(task["chunk_ids"].split())
\ No newline at end of file
+    return len(task["chunk_ids"].split())
```
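This is the core of the fix: reuse now keys on `(from_page, digest)` instead of `from_page` alone, so a previous task is matched only when both its page range and its configuration fingerprint agree. A minimal sketch of the lookup (`find_reusable_task` is an illustrative name; `prev_tasks` is assumed sorted by this same key; `bisect`'s `key=` parameter requires Python 3.10+):

```python
import bisect

# Hedged sketch of the (from_page, digest) reuse lookup shown in the hunk.
def find_reusable_task(task: dict, prev_tasks: list[dict]) -> dict | None:
    key = (task.get("from_page", 0), task.get("digest", ""))
    idx = bisect.bisect_left(
        prev_tasks, key,
        key=lambda x: (x.get("from_page", 0), x.get("digest", "")),
    )
    if idx >= len(prev_tasks):
        return None
    prev = prev_tasks[idx]
    # bisect_left only locates the insertion point; confirm an exact match,
    # otherwise the digest (i.e. the config) changed and nothing is reusable.
    if (prev.get("from_page", 0), prev.get("digest", "")) != key:
        return None
    return prev
```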

```diff
@@ -78,7 +78,7 @@ def get_llm_cache(llmnm, txt, history, genconf):
     bin = REDIS_CONN.get(k)
     if not bin:
         return
-    return bin.decode("utf-8")
+    return bin

 def set_llm_cache(llmnm, txt, v: str, history, genconf):
```
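The cache getter now returns the value as-is, presumably because the Redis connection already decodes responses to `str`, in which case the old `.decode("utf-8")` would raise on an already-decoded value. A minimal sketch of that behavior with redis-py (connection parameters are illustrative):

```python
import redis

# With decode_responses=True, redis-py returns str (or None), never bytes,
# so calling .decode("utf-8") on the result raises AttributeError.
r = redis.Redis(host="localhost", port=6379, decode_responses=True)
r.set("k", "cached answer")
v = r.get("k")      # -> "cached answer" as str
print(type(v))      # <class 'str'>; v.decode("utf-8") would fail here
```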