mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 11:24:00 +08:00
Refactored DocumentService.update_progress (#5642)
### What problem does this PR solve? Refactored DocumentService.update_progress ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
02c955babb
commit
f65c3ae62b
@ -843,8 +843,8 @@ class Task(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
doc_id = CharField(max_length=32, null=False, index=True)
|
||||
from_page = IntegerField(default=0)
|
||||
|
||||
to_page = IntegerField(default=100000000)
|
||||
task_type = CharField(max_length=32, null=False, default="")
|
||||
|
||||
begin_at = DateTimeField(null=True, index=True)
|
||||
process_duation = FloatField(default=0)
|
||||
@ -1115,3 +1115,10 @@ def migrate_db():
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column("task", "task_type",
|
||||
CharField(max_length=32, null=False, default=""))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -381,12 +381,6 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_progress(cls):
|
||||
MSG = {
|
||||
"raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
|
||||
"graphrag": "Entities",
|
||||
"graph_resolution": "Resolution",
|
||||
"graph_community": "Communities"
|
||||
}
|
||||
docs = cls.get_unfinished_docs()
|
||||
for d in docs:
|
||||
try:
|
||||
@ -397,37 +391,31 @@ class DocumentService(CommonService):
|
||||
prg = 0
|
||||
finished = True
|
||||
bad = 0
|
||||
has_raptor = False
|
||||
has_graphrag = False
|
||||
e, doc = DocumentService.get_by_id(d["id"])
|
||||
status = doc.run # TaskStatus.RUNNING.value
|
||||
for t in tsks:
|
||||
if 0 <= t.progress < 1:
|
||||
finished = False
|
||||
prg += t.progress if t.progress >= 0 else 0
|
||||
if t.progress_msg not in msg:
|
||||
msg.append(t.progress_msg)
|
||||
if t.progress == -1:
|
||||
bad += 1
|
||||
prg += t.progress if t.progress >= 0 else 0
|
||||
msg.append(t.progress_msg)
|
||||
if t.task_type == "raptor":
|
||||
has_raptor = True
|
||||
elif t.task_type == "graphrag":
|
||||
has_graphrag = True
|
||||
prg /= len(tsks)
|
||||
if finished and bad:
|
||||
prg = -1
|
||||
status = TaskStatus.FAIL.value
|
||||
elif finished:
|
||||
m = "\n".join(sorted(msg))
|
||||
if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
|
||||
queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
|
||||
if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
|
||||
queue_raptor_o_graphrag_tasks(d, "raptor")
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
|
||||
queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
|
||||
and d["parser_config"].get("graphrag", {}).get("resolution") \
|
||||
and m.find(MSG["graph_resolution"]) < 0:
|
||||
queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
|
||||
and d["parser_config"].get("graphrag", {}).get("community") \
|
||||
and m.find(MSG["graph_community"]) < 0:
|
||||
queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
|
||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
|
||||
queue_raptor_o_graphrag_tasks(d, "graphrag")
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
else:
|
||||
status = TaskStatus.DONE.value
|
||||
@ -464,7 +452,7 @@ class DocumentService(CommonService):
|
||||
return False
|
||||
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty):
|
||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||
hasher = xxhash.xxh64()
|
||||
for field in sorted(chunking_config.keys()):
|
||||
@ -477,7 +465,8 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
||||
"doc_id": doc["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
|
||||
}
|
||||
|
||||
task = new_task()
|
||||
@ -486,7 +475,6 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
||||
hasher.update(ty.encode("utf-8"))
|
||||
task["digest"] = hasher.hexdigest()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
task["task_type"] = ty
|
||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
|
||||
|
@ -8,6 +8,7 @@ export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PRO
|
||||
export PYTHONPATH=$(pwd)
|
||||
|
||||
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
|
||||
JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so
|
||||
|
||||
PY=python3
|
||||
|
||||
@ -48,7 +49,7 @@ task_exe(){
|
||||
local retry_count=0
|
||||
while ! $STOP && [ $retry_count -lt $MAX_RETRIES ]; do
|
||||
echo "Starting task_executor.py for task $task_id (Attempt $((retry_count+1)))"
|
||||
$PY rag/svr/task_executor.py "$task_id"
|
||||
LD_PRELOAD=$JEMALLOC_PATH $PY rag/svr/task_executor.py "$task_id"
|
||||
EXIT_CODE=$?
|
||||
if [ $EXIT_CODE -eq 0 ]; then
|
||||
echo "task_executor.py for task $task_id exited successfully."
|
||||
|
@ -104,14 +104,14 @@ class EntityResolution(Extractor):
|
||||
connect_graph = nx.Graph()
|
||||
removed_entities = []
|
||||
connect_graph.add_edges_from(resolution_result)
|
||||
# for issue #5600
|
||||
all_entities_data = []
|
||||
all_relationships_data = []
|
||||
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
||||
remove_nodes = list(sub_connect_graph.nodes)
|
||||
keep_node = remove_nodes.pop()
|
||||
await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_relationships_data=all_relationships_data)
|
||||
await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)
|
||||
for remove_node in remove_nodes:
|
||||
removed_entities.append(remove_node)
|
||||
remove_node_neighbors = graph[remove_node]
|
||||
@ -127,7 +127,7 @@ class EntityResolution(Extractor):
|
||||
if not rel:
|
||||
continue
|
||||
if graph.has_edge(keep_node, remove_node_neighbor):
|
||||
self._merge_edges(keep_node, remove_node_neighbor, [rel])
|
||||
await self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)
|
||||
else:
|
||||
pair = sorted([keep_node, remove_node_neighbor])
|
||||
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
|
||||
|
@ -193,7 +193,7 @@ async def collect():
|
||||
FAILED_TASKS += 1
|
||||
logging.warning(f"collect task {msg['id']} {state}")
|
||||
redis_msg.ack()
|
||||
return None
|
||||
return None, None
|
||||
task["task_type"] = msg.get("task_type", "")
|
||||
return redis_msg, task
|
||||
|
||||
@ -521,30 +521,29 @@ async def do_handle_task(task):
|
||||
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
|
||||
# Either using graphrag or Standard chunking methods
|
||||
elif task.get("task_type", "") == "graphrag":
|
||||
graphrag_conf = task_parser_config.get("graphrag", {})
|
||||
if not graphrag_conf.get("use_graphrag", False):
|
||||
return
|
||||
start_ts = timer()
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
elif task.get("task_type", "") == "graph_resolution":
|
||||
start_ts = timer()
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
with_res = WithResolution(
|
||||
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
await with_res()
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
elif task.get("task_type", "") == "graph_community":
|
||||
start_ts = timer()
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
with_comm = WithCommunity(
|
||||
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
await with_comm()
|
||||
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
|
||||
if graphrag_conf.get("resolution", False):
|
||||
start_ts = timer()
|
||||
with_res = WithResolution(
|
||||
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
await with_res()
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
|
||||
if graphrag_conf.get("community", False):
|
||||
start_ts = timer()
|
||||
with_comm = WithCommunity(
|
||||
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
await with_comm()
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
else:
|
||||
# Standard chunking methods
|
||||
|
Loading…
x
Reference in New Issue
Block a user