mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 11:24:00 +08:00
Refactored DocumentService.update_progress (#5642)
### What problem does this PR solve? Refactored DocumentService.update_progress ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
02c955babb
commit
f65c3ae62b
@ -843,8 +843,8 @@ class Task(DataBaseModel):
|
|||||||
id = CharField(max_length=32, primary_key=True)
|
id = CharField(max_length=32, primary_key=True)
|
||||||
doc_id = CharField(max_length=32, null=False, index=True)
|
doc_id = CharField(max_length=32, null=False, index=True)
|
||||||
from_page = IntegerField(default=0)
|
from_page = IntegerField(default=0)
|
||||||
|
|
||||||
to_page = IntegerField(default=100000000)
|
to_page = IntegerField(default=100000000)
|
||||||
|
task_type = CharField(max_length=32, null=False, default="")
|
||||||
|
|
||||||
begin_at = DateTimeField(null=True, index=True)
|
begin_at = DateTimeField(null=True, index=True)
|
||||||
process_duation = FloatField(default=0)
|
process_duation = FloatField(default=0)
|
||||||
@ -1115,3 +1115,10 @@ def migrate_db():
|
|||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
migrate(
|
||||||
|
migrator.add_column("task", "task_type",
|
||||||
|
CharField(max_length=32, null=False, default=""))
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
@ -381,12 +381,6 @@ class DocumentService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def update_progress(cls):
|
def update_progress(cls):
|
||||||
MSG = {
|
|
||||||
"raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
|
|
||||||
"graphrag": "Entities",
|
|
||||||
"graph_resolution": "Resolution",
|
|
||||||
"graph_community": "Communities"
|
|
||||||
}
|
|
||||||
docs = cls.get_unfinished_docs()
|
docs = cls.get_unfinished_docs()
|
||||||
for d in docs:
|
for d in docs:
|
||||||
try:
|
try:
|
||||||
@ -397,37 +391,31 @@ class DocumentService(CommonService):
|
|||||||
prg = 0
|
prg = 0
|
||||||
finished = True
|
finished = True
|
||||||
bad = 0
|
bad = 0
|
||||||
|
has_raptor = False
|
||||||
|
has_graphrag = False
|
||||||
e, doc = DocumentService.get_by_id(d["id"])
|
e, doc = DocumentService.get_by_id(d["id"])
|
||||||
status = doc.run # TaskStatus.RUNNING.value
|
status = doc.run # TaskStatus.RUNNING.value
|
||||||
for t in tsks:
|
for t in tsks:
|
||||||
if 0 <= t.progress < 1:
|
if 0 <= t.progress < 1:
|
||||||
finished = False
|
finished = False
|
||||||
prg += t.progress if t.progress >= 0 else 0
|
|
||||||
if t.progress_msg not in msg:
|
|
||||||
msg.append(t.progress_msg)
|
|
||||||
if t.progress == -1:
|
if t.progress == -1:
|
||||||
bad += 1
|
bad += 1
|
||||||
|
prg += t.progress if t.progress >= 0 else 0
|
||||||
|
msg.append(t.progress_msg)
|
||||||
|
if t.task_type == "raptor":
|
||||||
|
has_raptor = True
|
||||||
|
elif t.task_type == "graphrag":
|
||||||
|
has_graphrag = True
|
||||||
prg /= len(tsks)
|
prg /= len(tsks)
|
||||||
if finished and bad:
|
if finished and bad:
|
||||||
prg = -1
|
prg = -1
|
||||||
status = TaskStatus.FAIL.value
|
status = TaskStatus.FAIL.value
|
||||||
elif finished:
|
elif finished:
|
||||||
m = "\n".join(sorted(msg))
|
if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
|
||||||
if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
|
queue_raptor_o_graphrag_tasks(d, "raptor")
|
||||||
queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
|
|
||||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
|
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
|
||||||
queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
|
queue_raptor_o_graphrag_tasks(d, "graphrag")
|
||||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
|
||||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
|
|
||||||
and d["parser_config"].get("graphrag", {}).get("resolution") \
|
|
||||||
and m.find(MSG["graph_resolution"]) < 0:
|
|
||||||
queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
|
|
||||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
|
||||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
|
|
||||||
and d["parser_config"].get("graphrag", {}).get("community") \
|
|
||||||
and m.find(MSG["graph_community"]) < 0:
|
|
||||||
queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
|
|
||||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||||
else:
|
else:
|
||||||
status = TaskStatus.DONE.value
|
status = TaskStatus.DONE.value
|
||||||
@ -464,7 +452,7 @@ class DocumentService(CommonService):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
def queue_raptor_o_graphrag_tasks(doc, ty):
|
||||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||||
hasher = xxhash.xxh64()
|
hasher = xxhash.xxh64()
|
||||||
for field in sorted(chunking_config.keys()):
|
for field in sorted(chunking_config.keys()):
|
||||||
@ -477,7 +465,8 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
|||||||
"doc_id": doc["id"],
|
"doc_id": doc["id"],
|
||||||
"from_page": 100000000,
|
"from_page": 100000000,
|
||||||
"to_page": 100000000,
|
"to_page": 100000000,
|
||||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
|
"task_type": ty,
|
||||||
|
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
|
||||||
}
|
}
|
||||||
|
|
||||||
task = new_task()
|
task = new_task()
|
||||||
@ -486,7 +475,6 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
|||||||
hasher.update(ty.encode("utf-8"))
|
hasher.update(ty.encode("utf-8"))
|
||||||
task["digest"] = hasher.hexdigest()
|
task["digest"] = hasher.hexdigest()
|
||||||
bulk_insert_into_db(Task, [task], True)
|
bulk_insert_into_db(Task, [task], True)
|
||||||
task["task_type"] = ty
|
|
||||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PRO
|
|||||||
export PYTHONPATH=$(pwd)
|
export PYTHONPATH=$(pwd)
|
||||||
|
|
||||||
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
|
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
|
||||||
|
JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so
|
||||||
|
|
||||||
PY=python3
|
PY=python3
|
||||||
|
|
||||||
@ -48,7 +49,7 @@ task_exe(){
|
|||||||
local retry_count=0
|
local retry_count=0
|
||||||
while ! $STOP && [ $retry_count -lt $MAX_RETRIES ]; do
|
while ! $STOP && [ $retry_count -lt $MAX_RETRIES ]; do
|
||||||
echo "Starting task_executor.py for task $task_id (Attempt $((retry_count+1)))"
|
echo "Starting task_executor.py for task $task_id (Attempt $((retry_count+1)))"
|
||||||
$PY rag/svr/task_executor.py "$task_id"
|
LD_PRELOAD=$JEMALLOC_PATH $PY rag/svr/task_executor.py "$task_id"
|
||||||
EXIT_CODE=$?
|
EXIT_CODE=$?
|
||||||
if [ $EXIT_CODE -eq 0 ]; then
|
if [ $EXIT_CODE -eq 0 ]; then
|
||||||
echo "task_executor.py for task $task_id exited successfully."
|
echo "task_executor.py for task $task_id exited successfully."
|
||||||
|
@ -104,14 +104,14 @@ class EntityResolution(Extractor):
|
|||||||
connect_graph = nx.Graph()
|
connect_graph = nx.Graph()
|
||||||
removed_entities = []
|
removed_entities = []
|
||||||
connect_graph.add_edges_from(resolution_result)
|
connect_graph.add_edges_from(resolution_result)
|
||||||
# for issue #5600
|
all_entities_data = []
|
||||||
all_relationships_data = []
|
all_relationships_data = []
|
||||||
|
|
||||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||||
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
||||||
remove_nodes = list(sub_connect_graph.nodes)
|
remove_nodes = list(sub_connect_graph.nodes)
|
||||||
keep_node = remove_nodes.pop()
|
keep_node = remove_nodes.pop()
|
||||||
await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_relationships_data=all_relationships_data)
|
await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)
|
||||||
for remove_node in remove_nodes:
|
for remove_node in remove_nodes:
|
||||||
removed_entities.append(remove_node)
|
removed_entities.append(remove_node)
|
||||||
remove_node_neighbors = graph[remove_node]
|
remove_node_neighbors = graph[remove_node]
|
||||||
@ -127,7 +127,7 @@ class EntityResolution(Extractor):
|
|||||||
if not rel:
|
if not rel:
|
||||||
continue
|
continue
|
||||||
if graph.has_edge(keep_node, remove_node_neighbor):
|
if graph.has_edge(keep_node, remove_node_neighbor):
|
||||||
self._merge_edges(keep_node, remove_node_neighbor, [rel])
|
await self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)
|
||||||
else:
|
else:
|
||||||
pair = sorted([keep_node, remove_node_neighbor])
|
pair = sorted([keep_node, remove_node_neighbor])
|
||||||
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
|
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
|
||||||
|
@ -193,7 +193,7 @@ async def collect():
|
|||||||
FAILED_TASKS += 1
|
FAILED_TASKS += 1
|
||||||
logging.warning(f"collect task {msg['id']} {state}")
|
logging.warning(f"collect task {msg['id']} {state}")
|
||||||
redis_msg.ack()
|
redis_msg.ack()
|
||||||
return None
|
return None, None
|
||||||
task["task_type"] = msg.get("task_type", "")
|
task["task_type"] = msg.get("task_type", "")
|
||||||
return redis_msg, task
|
return redis_msg, task
|
||||||
|
|
||||||
@ -521,30 +521,29 @@ async def do_handle_task(task):
|
|||||||
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
|
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
|
||||||
# Either using graphrag or Standard chunking methods
|
# Either using graphrag or Standard chunking methods
|
||||||
elif task.get("task_type", "") == "graphrag":
|
elif task.get("task_type", "") == "graphrag":
|
||||||
|
graphrag_conf = task_parser_config.get("graphrag", {})
|
||||||
|
if not graphrag_conf.get("use_graphrag", False):
|
||||||
|
return
|
||||||
start_ts = timer()
|
start_ts = timer()
|
||||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||||
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
|
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
|
||||||
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
|
progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
|
||||||
return
|
if graphrag_conf.get("resolution", False):
|
||||||
elif task.get("task_type", "") == "graph_resolution":
|
start_ts = timer()
|
||||||
start_ts = timer()
|
with_res = WithResolution(
|
||||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
||||||
with_res = WithResolution(
|
progress_callback
|
||||||
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
|
)
|
||||||
progress_callback
|
await with_res()
|
||||||
)
|
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
|
||||||
await with_res()
|
if graphrag_conf.get("community", False):
|
||||||
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
|
start_ts = timer()
|
||||||
return
|
with_comm = WithCommunity(
|
||||||
elif task.get("task_type", "") == "graph_community":
|
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
||||||
start_ts = timer()
|
progress_callback
|
||||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
)
|
||||||
with_comm = WithCommunity(
|
await with_comm()
|
||||||
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
|
||||||
progress_callback
|
|
||||||
)
|
|
||||||
await with_comm()
|
|
||||||
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
|
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
# Standard chunking methods
|
# Standard chunking methods
|
||||||
|
Loading…
x
Reference in New Issue
Block a user