diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py
index 5d75027ea..60792d381 100644
--- a/graphrag/entity_resolution.py
+++ b/graphrag/entity_resolution.py
@@ -103,7 +103,7 @@ class EntityResolution(Extractor):
                     continue
                 for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
                     candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
-                    nursery.start_soon(lambda: self._resolve_candidate(candidate_batch, resolution_result))
+                    nursery.start_soon(self._resolve_candidate, candidate_batch, resolution_result)
         callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
 
         change = GraphChange()
@@ -112,7 +112,7 @@ class EntityResolution(Extractor):
         async with trio.open_nursery() as nursery:
             for sub_connect_graph in nx.connected_components(connect_graph):
                 merging_nodes = list(sub_connect_graph)
-                nursery.start_soon(lambda: self._merge_graph_nodes(graph, merging_nodes, change))
+                nursery.start_soon(self._merge_graph_nodes, graph, merging_nodes, change)
 
         # Update pagerank
         pr = nx.pagerank(graph)
diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py
index 4b0989f96..14966af02 100644
--- a/graphrag/general/community_reports_extractor.py
+++ b/graphrag/general/community_reports_extractor.py
@@ -124,7 +124,7 @@ class CommunityReportsExtractor(Extractor):
             for level, comm in communities.items():
                 logging.info(f"Level {level}: Community: {len(comm.keys())}")
                 for community in comm.items():
-                    nursery.start_soon(lambda: extract_community_report(community))
+                    nursery.start_soon(extract_community_report, community)
 
         if callback:
             callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py
index 23caccb8a..c86cc1a95 100644
--- a/graphrag/general/extractor.py
+++ b/graphrag/general/extractor.py
@@ -97,7 +97,7 @@ class Extractor:
         async with trio.open_nursery() as nursery:
             for i, ck in enumerate(chunks):
                 ck = truncate(ck, int(self._llm.max_length*0.8))
-                nursery.start_soon(lambda: self._process_single_content((doc_id, ck), i, len(chunks), out_results))
+                nursery.start_soon(self._process_single_content, (doc_id, ck), i, len(chunks), out_results)
 
         maybe_nodes = defaultdict(list)
         maybe_edges = defaultdict(list)
@@ -116,7 +116,7 @@ class Extractor:
         all_entities_data = []
         async with trio.open_nursery() as nursery:
             for en_nm, ents in maybe_nodes.items():
-                nursery.start_soon(lambda: self._merge_nodes(en_nm, ents, all_entities_data))
+                nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
         now = trio.current_time()
         if callback:
             callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
@@ -126,7 +126,7 @@ class Extractor:
         all_relationships_data = []
         async with trio.open_nursery() as nursery:
             for (src, tgt), rels in maybe_edges.items():
-                nursery.start_soon(lambda: self._merge_edges(src, tgt, rels, all_relationships_data))
+                nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
         now = trio.current_time()
         if callback:
             callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
diff --git a/graphrag/general/mind_map_extractor.py b/graphrag/general/mind_map_extractor.py
index c9ac2a64b..b4ee6343e 100644
--- a/graphrag/general/mind_map_extractor.py
+++ b/graphrag/general/mind_map_extractor.py
@@ -93,13 +93,13 @@ class MindMapExtractor(Extractor):
                 for i in range(len(sections)):
                     section_cnt = num_tokens_from_string(sections[i])
                     if cnt + section_cnt >= token_count and texts:
-                        nursery.start_soon(lambda: self._process_document("".join(texts), prompt_variables, res))
+                        nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
                         texts = []
                         cnt = 0
                     texts.append(sections[i])
                     cnt += section_cnt
                 if texts:
-                    nursery.start_soon(lambda: self._process_document("".join(texts), prompt_variables, res))
+                    nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
             if not res:
                 return MindMapResult(output={"id": "root", "children": []})
             merge_json = reduce(self._merge, res)
diff --git a/graphrag/utils.py b/graphrag/utils.py
index 20313e702..13472e414 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -439,7 +439,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     if change.removed_edges:
         async with trio.open_nursery() as nursery:
             for from_node, to_node in change.removed_edges:
-                nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
+                nursery.start_soon(lambda from_node=from_node, to_node=to_node: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
     now = trio.current_time()
     if callback:
         callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
@@ -457,13 +457,13 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     async with trio.open_nursery() as nursery:
         for node in change.added_updated_nodes:
             node_attrs = graph.nodes[node]
-            nursery.start_soon(lambda: graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks))
+            nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
         for from_node, to_node in change.added_updated_edges:
             edge_attrs = graph.get_edge_data(from_node, to_node)
             if not edge_attrs:
                 # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
                 continue
-            nursery.start_soon(lambda: graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks))
+            nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
     now = trio.current_time()
     if callback:
         callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
diff --git a/rag/raptor.py b/rag/raptor.py
index d09ea5719..dc8fbc70f 100644
--- a/rag/raptor.py
+++ b/rag/raptor.py
@@ -152,7 +152,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                     ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                     assert len(ck_idx) > 0
                     async with chat_limiter:
-                        nursery.start_soon(lambda: summarize(ck_idx))
+                        nursery.start_soon(summarize, ck_idx)
             assert len(chunks) - end == n_clusters, "{} vs. {}".format(
                 len(chunks) - end, n_clusters
             )
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 49fb137da..cd285d5fa 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -309,7 +309,7 @@ async def build_chunks(task, progress_callback):
             return
         async with trio.open_nursery() as nursery:
             for d in docs:
-                nursery.start_soon(lambda: doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]))
+                nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
         progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
     if task["parser_config"].get("auto_questions", 0):
@@ -328,7 +328,7 @@ async def build_chunks(task, progress_callback):
                 d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
         async with trio.open_nursery() as nursery:
             for d in docs:
-                nursery.start_soon(lambda: doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]))
+                nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
         progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
     if task["kb_parser_config"].get("tag_kb_ids", []):
@@ -370,7 +370,7 @@ async def build_chunks(task, progress_callback):
                     d[TAG_FLD] = json.loads(cached)
         async with trio.open_nursery() as nursery:
             for d in docs_to_tag:
-                nursery.start_soon(lambda: doc_content_tagging(chat_mdl, d, topn_tags))
+                nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
         progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
     return docs
@@ -653,11 +653,11 @@ async def report_status():
 
 async def main():
     logging.info(r"""
-  ______           __      ______                     __
+  ______           __      ______                     __
  /_  __/___ ______/ /__   / ____/  _____  _______  __/ /_____  _____
   / / / __ `/ ___/ //_/  / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
- / / / /_/ (__  ) ,<   / /____>