Refactored DocumentService.update_progress (#5642)

### What problem does this PR solve?

Refactored `DocumentService.update_progress`: follow-up RAPTOR/GraphRAG work is now tracked through a persisted `task_type` column on `Task` (with a DB migration) instead of string-matching sentinel progress messages, and the separate `graph_resolution`/`graph_community` task types are folded into the single `graphrag` task, gated by `parser_config` flags. Also fixes the entity-resolution merge calls and preloads jemalloc in the task-executor launcher.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
Zhichang Yu 2025-03-05 14:48:03 +08:00 committed by GitHub
parent 02c955babb
commit f65c3ae62b
5 changed files with 49 additions and 54 deletions


@@ -843,8 +843,8 @@ class Task(DataBaseModel):
     id = CharField(max_length=32, primary_key=True)
     doc_id = CharField(max_length=32, null=False, index=True)
     from_page = IntegerField(default=0)
     to_page = IntegerField(default=100000000)
+    task_type = CharField(max_length=32, null=False, default="")
     begin_at = DateTimeField(null=True, index=True)
     process_duation = FloatField(default=0)
@@ -1115,3 +1115,10 @@ def migrate_db():
         )
     except Exception:
         pass
+    try:
+        migrate(
+            migrator.add_column("task", "task_type",
+                                CharField(max_length=32, null=False, default=""))
+        )
+    except Exception:
+        pass
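
For context, the guarded add-column idiom above is peewee's `playhouse.migrate`. A minimal standalone sketch of the same pattern, assuming a throwaway SQLite database (the real `migrate_db()` runs against RAGFlow's configured database and migrator):

```python
# Minimal sketch of the add-column migration idiom used above.
# The SQLite file here is hypothetical; RAGFlow targets its own DB.
from peewee import CharField, SqliteDatabase
from playhouse.migrate import SqliteMigrator, migrate

db = SqliteDatabase("demo.db")  # hypothetical database
migrator = SqliteMigrator(db)

try:
    migrate(
        migrator.add_column("task", "task_type",
                            CharField(max_length=32, null=False, default=""))
    )
except Exception:
    # Re-running after the column exists raises; swallowing the error
    # keeps migrate_db() idempotent, matching the pattern above.
    pass
```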


@@ -381,12 +381,6 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def update_progress(cls):
-        MSG = {
-            "raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
-            "graphrag": "Entities",
-            "graph_resolution": "Resolution",
-            "graph_community": "Communities"
-        }
         docs = cls.get_unfinished_docs()
         for d in docs:
             try:
@@ -397,37 +391,31 @@ class DocumentService(CommonService):
                 prg = 0
                 finished = True
                 bad = 0
+                has_raptor = False
+                has_graphrag = False
                 e, doc = DocumentService.get_by_id(d["id"])
                 status = doc.run  # TaskStatus.RUNNING.value
                 for t in tsks:
                     if 0 <= t.progress < 1:
                         finished = False
-                    prg += t.progress if t.progress >= 0 else 0
-                    if t.progress_msg not in msg:
-                        msg.append(t.progress_msg)
                     if t.progress == -1:
                         bad += 1
+                    prg += t.progress if t.progress >= 0 else 0
+                    msg.append(t.progress_msg)
+                    if t.task_type == "raptor":
+                        has_raptor = True
+                    elif t.task_type == "graphrag":
+                        has_graphrag = True
                 prg /= len(tsks)
                 if finished and bad:
                     prg = -1
                     status = TaskStatus.FAIL.value
                 elif finished:
-                    m = "\n".join(sorted(msg))
-                    if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
-                        queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
+                    if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
+                        queue_raptor_o_graphrag_tasks(d, "raptor")
                         prg = 0.98 * len(tsks) / (len(tsks) + 1)
-                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
-                        queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
-                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
-                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
-                            and d["parser_config"].get("graphrag", {}).get("resolution") \
-                            and m.find(MSG["graph_resolution"]) < 0:
-                        queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
-                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
-                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
-                            and d["parser_config"].get("graphrag", {}).get("community") \
-                            and m.find(MSG["graph_community"]) < 0:
-                        queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
+                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
+                        queue_raptor_o_graphrag_tasks(d, "graphrag")
                         prg = 0.98 * len(tsks) / (len(tsks) + 1)
                 else:
                     status = TaskStatus.DONE.value
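
This hunk is the heart of the refactor: instead of grepping accumulated progress messages for the `MSG` sentinels, completion of RAPTOR/GraphRAG follow-ups is read off the tasks' persisted `task_type`. A condensed, self-contained sketch of the new decision, using plain dicts as stand-ins for `Task` rows (not the actual service code):

```python
# Condensed sketch of the new follow-up decision in update_progress().
# Tasks are plain dicts here; the real code reads Task rows from the DB.
def next_followup(parser_config: dict, tasks: list[dict]) -> str | None:
    has_raptor = any(t["task_type"] == "raptor" for t in tasks)
    has_graphrag = any(t["task_type"] == "graphrag" for t in tasks)
    if parser_config.get("raptor", {}).get("use_raptor") and not has_raptor:
        return "raptor"
    if parser_config.get("graphrag", {}).get("use_graphrag") and not has_graphrag:
        return "graphrag"
    return None  # nothing left to queue: the document can be marked DONE

# Chunking finished, RAPTOR enabled but not yet queued -> queue it next.
assert next_followup({"raptor": {"use_raptor": True}},
                     [{"task_type": ""}]) == "raptor"
# RAPTOR already ran and graphrag is not enabled -> done.
assert next_followup({"raptor": {"use_raptor": True}},
                     [{"task_type": ""}, {"task_type": "raptor"}]) is None
```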
@@ -464,7 +452,7 @@ class DocumentService(CommonService):
         return False


-def queue_raptor_o_graphrag_tasks(doc, ty, msg):
+def queue_raptor_o_graphrag_tasks(doc, ty):
     chunking_config = DocumentService.get_chunking_config(doc["id"])
     hasher = xxhash.xxh64()
     for field in sorted(chunking_config.keys()):
@@ -477,7 +465,8 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
             "doc_id": doc["id"],
             "from_page": 100000000,
             "to_page": 100000000,
-            "progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
+            "task_type": ty,
+            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
         }

     task = new_task()
@@ -486,7 +475,6 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
     hasher.update(ty.encode("utf-8"))
     task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
-    task["task_type"] = ty
     assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
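
Note that the task type is now hashed into the task digest before the row is inserted, so a re-queued follow-up of the same type hashes identically and can be deduplicated, while "raptor" and "graphrag" tasks for the same document get distinct digests. A hedged sketch of just the hashing step (illustrative config values; the real code fetches the config via `DocumentService.get_chunking_config`):

```python
# Sketch of the digest: hash the document's chunking config in sorted-key
# order, then the task type, as in the diff above. Values are illustrative.
import xxhash

def task_digest(chunking_config: dict, ty: str) -> str:
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))
    hasher.update(ty.encode("utf-8"))
    return hasher.hexdigest()

cfg = {"chunk_token_num": 128, "layout_recognize": True}
assert task_digest(cfg, "raptor") == task_digest(cfg, "raptor")
assert task_digest(cfg, "raptor") != task_digest(cfg, "graphrag")
```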


@@ -8,6 +8,7 @@ export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PRO
 export PYTHONPATH=$(pwd)
 export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
+JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so
 PY=python3
@@ -48,7 +49,7 @@ task_exe(){
     local retry_count=0
     while ! $STOP && [ $retry_count -lt $MAX_RETRIES ]; do
         echo "Starting task_executor.py for task $task_id (Attempt $((retry_count+1)))"
-        $PY rag/svr/task_executor.py "$task_id"
+        LD_PRELOAD=$JEMALLOC_PATH $PY rag/svr/task_executor.py "$task_id"
         EXIT_CODE=$?
         if [ $EXIT_CODE -eq 0 ]; then
             echo "task_executor.py for task $task_id exited successfully."


@@ -104,14 +104,14 @@ class EntityResolution(Extractor):
         connect_graph = nx.Graph()
         removed_entities = []
         connect_graph.add_edges_from(resolution_result)
         # for issue #5600
         all_entities_data = []
         all_relationships_data = []
         for sub_connect_graph in nx.connected_components(connect_graph):
             sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
             remove_nodes = list(sub_connect_graph.nodes)
             keep_node = remove_nodes.pop()
-            await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_relationships_data=all_relationships_data)
+            await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)
             for remove_node in remove_nodes:
                 removed_entities.append(remove_node)
                 remove_node_neighbors = graph[remove_node]
@@ -127,7 +127,7 @@ class EntityResolution(Extractor):
                 if not rel:
                     continue
                 if graph.has_edge(keep_node, remove_node_neighbor):
-                    self._merge_edges(keep_node, remove_node_neighbor, [rel])
+                    await self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)
                 else:
                     pair = sorted([keep_node, remove_node_neighbor])
                     graph.add_edge(pair[0], pair[1], weight=rel['weight'])
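
For orientation: pairs of entities judged to be duplicates form the edges of `connect_graph`, each connected component collapses onto one kept node, and the merge calls (now correctly awaited and fed the right accumulator lists) collect the merged records for downstream indexing. A toy sketch of just the component-collapse step, with the merge side effects stubbed out (networkx only; sample names are illustrative):

```python
# Toy sketch of the collapse step: duplicate pairs -> connected components,
# keep one node per component and merge the rest into it.
import networkx as nx

resolution_result = {("Alice Smith", "alice smith"),
                     ("alice smith", "A. Smith"),
                     ("Bob Jones", "bob jones")}
connect_graph = nx.Graph()
connect_graph.add_edges_from(resolution_result)

for component in nx.connected_components(connect_graph):
    nodes = sorted(component)  # deterministic order for the demo
    keep_node = nodes.pop()
    print(f"merge {nodes} into {keep_node!r}")  # _merge_nodes(...) in the real code
```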


@@ -193,7 +193,7 @@ async def collect():
         FAILED_TASKS += 1
         logging.warning(f"collect task {msg['id']} {state}")
         redis_msg.ack()
-        return None
+        return None, None
     task["task_type"] = msg.get("task_type", "")
     return redis_msg, task
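
The `return None` to `return None, None` fix matters because `collect()`'s success path returns a pair that the caller unpacks; a bare `None` would raise `TypeError` at the unpacking site. A minimal synchronous illustration (hypothetical stand-in, not the real async code):

```python
# Why collect() must return a 2-tuple even on failure: the caller unpacks it.
# Hypothetical synchronous stand-in for the async collect().
def collect(queue_empty: bool):
    if queue_empty:
        return None, None  # a bare `return None` would break the unpacking below
    return "redis_msg", {"id": "task-1", "task_type": "graphrag"}

redis_msg, task = collect(queue_empty=True)
if redis_msg is None:
    print("no task available; try again later")
```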
@@ -521,30 +521,29 @@ async def do_handle_task(task):
         chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
     # Either using graphrag or Standard chunking methods
     elif task.get("task_type", "") == "graphrag":
+        graphrag_conf = task_parser_config.get("graphrag", {})
+        if not graphrag_conf.get("use_graphrag", False):
+            return
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
-        progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
-        return
-    elif task.get("task_type", "") == "graph_resolution":
-        start_ts = timer()
-        chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        with_res = WithResolution(
-            task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
-            progress_callback
-        )
-        await with_res()
-        progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
-        return
-    elif task.get("task_type", "") == "graph_community":
-        start_ts = timer()
-        chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        with_comm = WithCommunity(
-            task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-            progress_callback
-        )
-        await with_comm()
-        progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
+        progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
+        if graphrag_conf.get("resolution", False):
+            start_ts = timer()
+            with_res = WithResolution(
+                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
+                progress_callback
+            )
+            await with_res()
+            progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
+        if graphrag_conf.get("community", False):
+            start_ts = timer()
+            with_comm = WithCommunity(
+                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
+                progress_callback
+            )
+            await with_comm()
+            progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
         return
     else:
         # Standard chunking methods
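
Net effect on the executor: the standalone `graph_resolution` and `graph_community` task types disappear, and a single `graphrag` task runs the base graph build and then, gated by parser-config flags, the resolution and community stages in sequence. A condensed sketch of the new control flow with stubbed pipeline stages (the stubs are hypothetical; the real branch calls `run_graphrag`, `WithResolution`, and `WithCommunity` as shown above):

```python
# Condensed control flow of the new "graphrag" branch, with stubbed stages.
import asyncio

async def handle_graphrag(graphrag_conf: dict) -> None:
    if not graphrag_conf.get("use_graphrag", False):
        return
    print("build knowledge graph")           # run_graphrag(...) in the real code
    if graphrag_conf.get("resolution", False):
        print("resolve duplicate entities")  # WithResolution(...)
    if graphrag_conf.get("community", False):
        print("generate community reports")  # WithCommunity(...)

asyncio.run(handle_graphrag({"use_graphrag": True,
                             "resolution": True,
                             "community": False}))
```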