diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py
index 01d99e2be..e00d3bf77 100644
--- a/graphrag/entity_resolution.py
+++ b/graphrag/entity_resolution.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 import itertools
 import re
 import time
@@ -67,7 +68,7 @@ class EntityResolution(Extractor):
         self._resolution_result_delimiter_key = "resolution_result_delimiter"
         self._input_text_key = "input_text"

-    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
+    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
         """Call method definition."""
         if prompt_variables is None:
             prompt_variables = {}
@@ -93,6 +94,8 @@ class EntityResolution(Extractor):
         candidate_resolution = {entity_type: [] for entity_type in entity_types}
         for k, v in node_clusters.items():
             candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
+        num_candidates = sum(len(candidates) for candidates in candidate_resolution.values())
+        if callback: callback(msg=f"Identified {num_candidates} candidate pairs")

         resolution_result = set()
         async with trio.open_nursery() as nursery:
@@ -100,48 +103,52 @@ class EntityResolution(Extractor):
             if not candidate_resolution_i[1]:
                 continue
             nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
+        if callback: callback(msg=f"Resolved {num_candidates} candidate pairs; {len(resolution_result)} of them selected to merge.")

         connect_graph = nx.Graph()
         removed_entities = []
         connect_graph.add_edges_from(resolution_result)
         all_entities_data = []
         all_relationships_data = []
+        all_remove_nodes = []

-        for sub_connect_graph in nx.connected_components(connect_graph):
-            sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
-            remove_nodes = list(sub_connect_graph.nodes)
-            keep_node = remove_nodes.pop()
-            await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)
-            for remove_node in remove_nodes:
-                removed_entities.append(remove_node)
-                remove_node_neighbors = graph[remove_node]
-                remove_node_neighbors = list(remove_node_neighbors)
-                for remove_node_neighbor in remove_node_neighbors:
-                    rel = self._get_relation_(remove_node, remove_node_neighbor)
-                    if graph.has_edge(remove_node, remove_node_neighbor):
-                        graph.remove_edge(remove_node, remove_node_neighbor)
-                    if remove_node_neighbor == keep_node:
-                        if graph.has_edge(keep_node, remove_node):
-                            graph.remove_edge(keep_node, remove_node)
-                        continue
-                    if not rel:
-                        continue
-                    if graph.has_edge(keep_node, remove_node_neighbor):
-                        await self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)
-                    else:
-                        pair = sorted([keep_node, remove_node_neighbor])
-                        graph.add_edge(pair[0], pair[1], weight=rel['weight'])
-                        self._set_relation_(pair[0], pair[1],
-                                            dict(
-                                                src_id=pair[0],
-                                                tgt_id=pair[1],
-                                                weight=rel['weight'],
-                                                description=rel['description'],
-                                                keywords=[],
-                                                source_id=rel.get("source_id", ""),
-                                                metadata={"created_at": time.time()}
-                                            ))
-                graph.remove_node(remove_node)
+        async with trio.open_nursery() as nursery:
+            for sub_connect_graph in nx.connected_components(connect_graph):
+                sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
+                remove_nodes = list(sub_connect_graph.nodes)
+                keep_node = remove_nodes.pop()
+                all_remove_nodes.append(remove_nodes)
+                nursery.start_soon(self._merge_nodes, keep_node, self._get_entity_(remove_nodes), all_entities_data)
+                for remove_node in remove_nodes:
+                    removed_entities.append(remove_node)
+                    remove_node_neighbors = graph[remove_node]
+                    remove_node_neighbors = list(remove_node_neighbors)
+                    for remove_node_neighbor in remove_node_neighbors:
+                        rel = self._get_relation_(remove_node, remove_node_neighbor)
+                        if graph.has_edge(remove_node, remove_node_neighbor):
+                            graph.remove_edge(remove_node, remove_node_neighbor)
+                        if remove_node_neighbor == keep_node:
+                            if graph.has_edge(keep_node, remove_node):
+                                graph.remove_edge(keep_node, remove_node)
+                            continue
+                        if not rel:
+                            continue
+                        if graph.has_edge(keep_node, remove_node_neighbor):
+                            nursery.start_soon(self._merge_edges, keep_node, remove_node_neighbor, [rel], all_relationships_data)
+                        else:
+                            pair = sorted([keep_node, remove_node_neighbor])
+                            graph.add_edge(pair[0], pair[1], weight=rel['weight'])
+                            self._set_relation_(pair[0], pair[1],
+                                                dict(
+                                                    src_id=pair[0],
+                                                    tgt_id=pair[1],
+                                                    weight=rel['weight'],
+                                                    description=rel['description'],
+                                                    keywords=[],
+                                                    source_id=rel.get("source_id", ""),
+                                                    metadata={"created_at": time.time()}
+                                                ))
+                    graph.remove_node(remove_node)

         return EntityResolutionResult(
             graph=graph,
@@ -164,8 +171,10 @@ class EntityResolution(Extractor):
             self._input_text_key: pair_prompt
         }
         text = perform_variable_replacements(self._resolution_prompt, variables=variables)
+        logging.info(f"Created a {len(text)}-character resolution prompt for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
         async with chat_limiter:
             response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+        logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
         result = self._process_results(len(candidate_resolution_i[1]), response,
                                        self.prompt_variables.get(self._record_delimiter_key,
                                                                  DEFAULT_RECORD_DELIMITER),
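A note on the fan-out pattern this file now leans on: `nursery.start_soon(async_fn, *args)` is trio's preferred spelling — the arguments are bound at the call site, and the nursery block does not exit until every child task has finished, which is why the "Resolved ..." callback can safely read `resolution_result` right after it. A minimal, self-contained sketch of the same shape; `resolve_pair` and the sample data are illustrative stand-ins, not RAGFlow code:

```python
import trio

async def resolve_pair(pair, results):
    await trio.sleep(0.01)   # stand-in for the LLM round-trip
    results.add(pair)        # safe: trio tasks interleave only at await points

async def main():
    results = set()
    async with trio.open_nursery() as nursery:
        for pair in [("USA", "U.S.A."), ("IBM", "I.B.M.")]:
            # arguments are bound here, one child task per candidate pair
            nursery.start_soon(resolve_pair, pair, results)
    # reaching this line means every child task has completed
    print(f"resolved {len(results)} pairs")

trio.run(main)
```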
ents["nodes"] + ent_df = pd.DataFrame(self._get_entity_(ents)).dropna() + if ent_df.empty or "entity_name" not in ent_df.columns: + return + ent_df["entity"] = ent_df["entity_name"] + del ent_df["entity_name"] + rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000)) + if rela_df.empty: + return + rela_df["source"] = rela_df["src_id"] + rela_df["target"] = rela_df["tgt_id"] + del rela_df["src_id"] + del rela_df["tgt_id"] - prompt_variables = { - "entity_df": ent_df.to_csv(index_label="id"), - "relation_df": rela_df.to_csv(index_label="id") - } - text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) - gen_conf = {"temperature": 0.3} - async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)) - token_count += num_tokens_from_string(text + response) - response = re.sub(r"^[^\{]*", "", response) - response = re.sub(r"[^\}]*$", "", response) - response = re.sub(r"\{\{", "{", response) - response = re.sub(r"\}\}", "}", response) - logging.debug(response) - try: - response = json.loads(response) - except json.JSONDecodeError as e: - logging.error(f"Failed to parse JSON response: {e}") - logging.error(f"Response content: {response}") - continue - if not dict_has_keys_with_types(response, [ - ("title", str), - ("summary", str), - ("findings", list), - ("rating", float), - ("rating_explanation", str), - ]): - continue - response["weight"] = weight - response["entities"] = ents + prompt_variables = { + "entity_df": ent_df.to_csv(index_label="id"), + "relation_df": rela_df.to_csv(index_label="id") + } + text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) + gen_conf = {"temperature": 0.3} + async with chat_limiter: + response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)) + token_count += num_tokens_from_string(text + response) + response = re.sub(r"^[^\{]*", "", response) + response = re.sub(r"[^\}]*$", "", response) + response = re.sub(r"\{\{", "{", response) + response = re.sub(r"\}\}", "}", response) + logging.debug(response) + try: + response = json.loads(response) + except json.JSONDecodeError as e: + logging.error(f"Failed to parse JSON response: {e}") + logging.error(f"Response content: {response}") + return + if not dict_has_keys_with_types(response, [ + ("title", str), + ("summary", str), + ("findings", list), + ("rating", float), + ("rating_explanation", str), + ]): + return + response["weight"] = weight + response["entities"] = ents + add_community_info2graph(graph, ents, response["title"]) + res_str.append(self._get_text_output(response)) + res_dict.append(response) + over += 1 + if callback: + callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}") - add_community_info2graph(graph, ents, response["title"]) - res_str.append(self._get_text_output(response)) - res_dict.append(response) - over += 1 - if callback: - callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}") + st = trio.current_time() + async with trio.open_nursery() as nursery: + for level, comm in communities.items(): + logging.info(f"Level {level}: Community: {len(comm.keys())}") + for community in comm.items(): + nursery.start_soon(lambda: extract_community_report(community)) + if callback: + callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}") return 
diff --git a/graphrag/general/index.py b/graphrag/general/index.py
index 179478c99..8b63b0d02 100644
--- a/graphrag/general/index.py
+++ b/graphrag/general/index.py
@@ -228,7 +228,7 @@ async def resolve_entities(
         get_relation=partial(get_relation, tenant_id, kb_id),
         set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
     )
-    reso = await er(graph)
+    reso = await er(graph, callback=callback)
     graph = reso.graph
     callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
     await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
diff --git a/graphrag/utils.py b/graphrag/utils.py
index 29ed94252..ab09e536f 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -489,15 +489,16 @@ async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
         return nbrs

     pr = nx.pagerank(graph)
-    for n, p in pr.items():
-        graph.nodes[n]["pagerank"] = p
-        try:
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
-                                                                               {"rank_flt": p,
-                                                                                "n_hop_with_weight": json.dumps( (n), ensure_ascii=False)},
-                                                                               search.index_name(tenant_id), kb_id))
-        except Exception as e:
-            logging.exception(e)
+    try:
+        async with trio.open_nursery() as nursery:
+            for n, p in pr.items():
+                graph.nodes[n]["pagerank"] = p
+                nursery.start_soon(lambda n=n, p=p: trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},  # bind current n, p
+                                                                                                                 {"rank_flt": p,
+                                                                                                                  "n_hop_with_weight": json.dumps((n), ensure_ascii=False)},
+                                                                                                                 search.index_name(tenant_id), kb_id)))
+    except Exception as e:
+        logging.exception(e)

     ty2ents = defaultdict(list)
     for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
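One subtlety behind the `lambda n=n, p=p:` binding in the `graphrag/utils.py` hunk: `start_soon` evaluates its callable immediately, but the inner lambda handed to `trio.to_thread.run_sync` only executes once a worker thread picks it up, and by then the loop variables may have moved on. Default-argument binding (or passing the values as `start_soon` arguments) freezes each task's copy. A minimal sketch with a toy blocking function in place of `docStoreConn.update`:

```python
import trio

def blocking_update(key, value):
    # stand-in for a synchronous datastore call
    return key, value

async def main():
    rows = {}

    async def save(n, p):
        # n and p are per-task parameters here, so the inner lambda is safe
        k, v = await trio.to_thread.run_sync(lambda: blocking_update(n, p))
        rows[k] = v

    async with trio.open_nursery() as nursery:
        for n, p in {"a": 0.3, "b": 0.7}.items():
            nursery.start_soon(save, n, p)  # not: lambda: save(n, p)
    print(rows)  # {'a': 0.3, 'b': 0.7}

trio.run(main)
```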