mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-11 00:29:03 +08:00
fix(nursery): Fix Closure Trap Issues in Trio Concurrent Tasks (#7106)
## Problem Description Multiple files in the RAGFlow project contain closure trap issues when using lambda functions with `trio.open_nursery()`. This problem causes concurrent tasks created in loops to reference the same variable, resulting in all tasks processing the same data (the data from the last iteration) rather than each task processing its corresponding data from the loop. ## Issue Details When using a `lambda` to create a closure function and passing it to `nursery.start_soon()` within a loop, the lambda function captures a reference to the loop variable rather than its value. For example: ```python # Problematic code async with trio.open_nursery() as nursery: for d in docs: nursery.start_soon(lambda: doc_keyword_extraction(chat_mdl, d, topn)) ``` In this pattern, when concurrent tasks begin execution, `d` has already become the value after the loop ends (typically the last element), causing all tasks to use the same data. ## Fix Solution Changed the way concurrent tasks are created with `nursery.start_soon()` by leveraging Trio's API design to directly pass the function and its arguments separately: ```python # Fixed code async with trio.open_nursery() as nursery: for d in docs: nursery.start_soon(doc_keyword_extraction, chat_mdl, d, topn) ``` This way, each task uses the parameter values at the time of the function call, rather than references captured through closures. ## Fixed Files Fixed closure traps in the following files: 1. `rag/svr/task_executor.py`: 3 fixes, involving document keyword extraction, question generation, and tag processing 2. `rag/raptor.py`: 1 fix, involving document summarization 3. `graphrag/utils.py`: 2 fixes, involving graph node and edge processing 4. `graphrag/entity_resolution.py`: 2 fixes, involving entity resolution and graph node merging 5. `graphrag/general/mind_map_extractor.py`: 2 fixes, involving document processing 6. `graphrag/general/extractor.py`: 3 fixes, involving content processing and graph node/edge merging 7. `graphrag/general/community_reports_extractor.py`: 1 fix, involving community report extraction ## Potential Impact This fix resolves a serious concurrency issue that could have caused: - Data processing errors (processing duplicate data) - Performance degradation (all tasks working on the same data) - Inconsistent results (some data not being processed) After the fix, all concurrent tasks should correctly process their respective data, improving system correctness and reliability.
This commit is contained in:
parent
42e236f464
commit
8b8a2f2949
@ -103,7 +103,7 @@ class EntityResolution(Extractor):
|
|||||||
continue
|
continue
|
||||||
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
|
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
|
||||||
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
|
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
|
||||||
nursery.start_soon(lambda: self._resolve_candidate(candidate_batch, resolution_result))
|
nursery.start_soon(self._resolve_candidate, candidate_batch, resolution_result)
|
||||||
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
|
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
|
||||||
|
|
||||||
change = GraphChange()
|
change = GraphChange()
|
||||||
@ -112,7 +112,7 @@ class EntityResolution(Extractor):
|
|||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||||
merging_nodes = list(sub_connect_graph)
|
merging_nodes = list(sub_connect_graph)
|
||||||
nursery.start_soon(lambda: self._merge_graph_nodes(graph, merging_nodes, change))
|
nursery.start_soon(self._merge_graph_nodes, graph, merging_nodes, change)
|
||||||
|
|
||||||
# Update pagerank
|
# Update pagerank
|
||||||
pr = nx.pagerank(graph)
|
pr = nx.pagerank(graph)
|
||||||
|
@ -124,7 +124,7 @@ class CommunityReportsExtractor(Extractor):
|
|||||||
for level, comm in communities.items():
|
for level, comm in communities.items():
|
||||||
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
||||||
for community in comm.items():
|
for community in comm.items():
|
||||||
nursery.start_soon(lambda: extract_community_report(community))
|
nursery.start_soon(extract_community_report, community)
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
|
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
|
||||||
|
|
||||||
|
@ -97,7 +97,7 @@ class Extractor:
|
|||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for i, ck in enumerate(chunks):
|
for i, ck in enumerate(chunks):
|
||||||
ck = truncate(ck, int(self._llm.max_length*0.8))
|
ck = truncate(ck, int(self._llm.max_length*0.8))
|
||||||
nursery.start_soon(lambda: self._process_single_content((doc_id, ck), i, len(chunks), out_results))
|
nursery.start_soon(self._process_single_content, (doc_id, ck), i, len(chunks), out_results)
|
||||||
|
|
||||||
maybe_nodes = defaultdict(list)
|
maybe_nodes = defaultdict(list)
|
||||||
maybe_edges = defaultdict(list)
|
maybe_edges = defaultdict(list)
|
||||||
@ -116,7 +116,7 @@ class Extractor:
|
|||||||
all_entities_data = []
|
all_entities_data = []
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for en_nm, ents in maybe_nodes.items():
|
for en_nm, ents in maybe_nodes.items():
|
||||||
nursery.start_soon(lambda: self._merge_nodes(en_nm, ents, all_entities_data))
|
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
|
||||||
now = trio.current_time()
|
now = trio.current_time()
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
|
callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
|
||||||
@ -126,7 +126,7 @@ class Extractor:
|
|||||||
all_relationships_data = []
|
all_relationships_data = []
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for (src, tgt), rels in maybe_edges.items():
|
for (src, tgt), rels in maybe_edges.items():
|
||||||
nursery.start_soon(lambda: self._merge_edges(src, tgt, rels, all_relationships_data))
|
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
|
||||||
now = trio.current_time()
|
now = trio.current_time()
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
|
callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
|
||||||
|
@ -93,13 +93,13 @@ class MindMapExtractor(Extractor):
|
|||||||
for i in range(len(sections)):
|
for i in range(len(sections)):
|
||||||
section_cnt = num_tokens_from_string(sections[i])
|
section_cnt = num_tokens_from_string(sections[i])
|
||||||
if cnt + section_cnt >= token_count and texts:
|
if cnt + section_cnt >= token_count and texts:
|
||||||
nursery.start_soon(lambda: self._process_document("".join(texts), prompt_variables, res))
|
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
|
||||||
texts = []
|
texts = []
|
||||||
cnt = 0
|
cnt = 0
|
||||||
texts.append(sections[i])
|
texts.append(sections[i])
|
||||||
cnt += section_cnt
|
cnt += section_cnt
|
||||||
if texts:
|
if texts:
|
||||||
nursery.start_soon(lambda: self._process_document("".join(texts), prompt_variables, res))
|
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
|
||||||
if not res:
|
if not res:
|
||||||
return MindMapResult(output={"id": "root", "children": []})
|
return MindMapResult(output={"id": "root", "children": []})
|
||||||
merge_json = reduce(self._merge, res)
|
merge_json = reduce(self._merge, res)
|
||||||
|
@ -439,7 +439,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
|||||||
if change.removed_edges:
|
if change.removed_edges:
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for from_node, to_node in change.removed_edges:
|
for from_node, to_node in change.removed_edges:
|
||||||
nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
|
nursery.start_soon(lambda from_node=from_node, to_node=to_node: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
|
||||||
now = trio.current_time()
|
now = trio.current_time()
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
|
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
|
||||||
@ -457,13 +457,13 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
|||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for node in change.added_updated_nodes:
|
for node in change.added_updated_nodes:
|
||||||
node_attrs = graph.nodes[node]
|
node_attrs = graph.nodes[node]
|
||||||
nursery.start_soon(lambda: graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks))
|
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
|
||||||
for from_node, to_node in change.added_updated_edges:
|
for from_node, to_node in change.added_updated_edges:
|
||||||
edge_attrs = graph.get_edge_data(from_node, to_node)
|
edge_attrs = graph.get_edge_data(from_node, to_node)
|
||||||
if not edge_attrs:
|
if not edge_attrs:
|
||||||
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
|
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
|
||||||
continue
|
continue
|
||||||
nursery.start_soon(lambda: graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks))
|
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
|
||||||
now = trio.current_time()
|
now = trio.current_time()
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
|
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
|
||||||
|
@ -152,7 +152,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
|||||||
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
|
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
|
||||||
assert len(ck_idx) > 0
|
assert len(ck_idx) > 0
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
nursery.start_soon(lambda: summarize(ck_idx))
|
nursery.start_soon(summarize, ck_idx)
|
||||||
|
|
||||||
assert len(chunks) - end == n_clusters, "{} vs. {}".format(
|
assert len(chunks) - end == n_clusters, "{} vs. {}".format(
|
||||||
len(chunks) - end, n_clusters
|
len(chunks) - end, n_clusters
|
||||||
|
@ -309,7 +309,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
return
|
return
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for d in docs:
|
for d in docs:
|
||||||
nursery.start_soon(lambda: doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]))
|
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
|
||||||
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||||
|
|
||||||
if task["parser_config"].get("auto_questions", 0):
|
if task["parser_config"].get("auto_questions", 0):
|
||||||
@ -328,7 +328,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for d in docs:
|
for d in docs:
|
||||||
nursery.start_soon(lambda: doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]))
|
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
|
||||||
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||||
|
|
||||||
if task["kb_parser_config"].get("tag_kb_ids", []):
|
if task["kb_parser_config"].get("tag_kb_ids", []):
|
||||||
@ -370,7 +370,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
d[TAG_FLD] = json.loads(cached)
|
d[TAG_FLD] = json.loads(cached)
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
for d in docs_to_tag:
|
for d in docs_to_tag:
|
||||||
nursery.start_soon(lambda: doc_content_tagging(chat_mdl, d, topn_tags))
|
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
|
||||||
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||||
|
|
||||||
return docs
|
return docs
|
||||||
|
Loading…
x
Reference in New Issue
Block a user