diff --git a/api/ragflow_server.py b/api/ragflow_server.py index b6407b52..52e565b4 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -47,6 +47,8 @@ from rag.utils.redis_conn import RedisDistributedLock stop_event = threading.Event() +RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get('RAGFLOW_DEBUGPY_LISTEN', "0")) + def update_progress(): lock_value = str(uuid.uuid4()) redis_lock = RedisDistributedLock("update_progress", lock_value=lock_value, timeout=60) @@ -85,6 +87,11 @@ if __name__ == '__main__': settings.init_settings() print_rag_settings() + if RAGFLOW_DEBUGPY_LISTEN > 0: + logging.info(f"debugpy listen on {RAGFLOW_DEBUGPY_LISTEN}") + import debugpy + debugpy.listen(("0.0.0.0", RAGFLOW_DEBUGPY_LISTEN)) + # init db init_web_db() init_web_data() diff --git a/conf/infinity_mapping.json b/conf/infinity_mapping.json index 95d5f2c6..06c89cad 100644 --- a/conf/infinity_mapping.json +++ b/conf/infinity_mapping.json @@ -8,7 +8,7 @@ "docnm_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "title_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, - "name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, + "name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, "tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, "important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, @@ -27,16 +27,16 @@ "rank_int": {"type": "integer", "default": 0}, "rank_flt": {"type": "float", "default": 0}, "available_int": {"type": "integer", "default": 1}, - "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, + "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, "pagerank_fea": {"type": "integer", "default": 0}, "tag_feas": {"type": "varchar", "default": ""}, - "from_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, - "to_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, - "entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, - "entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, - "source_id": {"type": "varchar", "default": ""}, + "from_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, + "to_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, + "entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, + "entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, + "source_id": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, "n_hop_with_weight": {"type": "varchar", "default": ""}, - "removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"} + "removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"} } diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index fad404c8..e5ff15e3 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -12,6 +12,8 @@ services: - ${SVR_HTTP_PORT}:9380 - 80:80 - 443:443 + - 5678:5678 + - 5679:5679 volumes: - ./ragflow-logs:/ragflow/logs - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index fad9513a..05229853 
100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -16,7 +16,6 @@ import logging import itertools import re -import time from dataclasses import dataclass from typing import Any, Callable @@ -28,7 +27,7 @@ from rag.nlp import is_english import editdistance from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT from rag.llm.chat_model import Base as CompletionLLM -from graphrag.utils import perform_variable_replacements, chat_limiter +from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange DEFAULT_RECORD_DELIMITER = "##" DEFAULT_ENTITY_INDEX_DELIMITER = "<|>" @@ -39,7 +38,7 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&" class EntityResolutionResult: """Entity resolution result class definition.""" graph: nx.Graph - removed_entities: list + change: GraphChange class EntityResolution(Extractor): @@ -54,12 +53,8 @@ class EntityResolution(Extractor): def __init__( self, llm_invoker: CompletionLLM, - get_entity: Callable | None = None, - set_entity: Callable | None = None, - get_relation: Callable | None = None, - set_relation: Callable | None = None ): - super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation) + super().__init__(llm_invoker) """Init method definition.""" self._llm = llm_invoker self._resolution_prompt = ENTITY_RESOLUTION_PROMPT @@ -84,8 +79,8 @@ class EntityResolution(Extractor): or DEFAULT_RESOLUTION_RESULT_DELIMITER, } - nodes = graph.nodes - entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes)) + nodes = sorted(graph.nodes()) + entity_types = sorted(set(graph.nodes[node].get('entity_type', '-') for node in nodes)) node_clusters = {entity_type: [] for entity_type in entity_types} for node in nodes: @@ -105,54 +100,22 @@ class EntityResolution(Extractor): nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result)) callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.") + change = GraphChange() connect_graph = nx.Graph() - removed_entities = [] connect_graph.add_edges_from(resolution_result) - all_entities_data = [] - all_relationships_data = [] - all_remove_nodes = [] - async with trio.open_nursery() as nursery: for sub_connect_graph in nx.connected_components(connect_graph): - sub_connect_graph = connect_graph.subgraph(sub_connect_graph) - remove_nodes = list(sub_connect_graph.nodes) - keep_node = remove_nodes.pop() - all_remove_nodes.append(remove_nodes) - nursery.start_soon(lambda: self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)) - for remove_node in remove_nodes: - removed_entities.append(remove_node) - remove_node_neighbors = graph[remove_node] - remove_node_neighbors = list(remove_node_neighbors) - for remove_node_neighbor in remove_node_neighbors: - rel = self._get_relation_(remove_node, remove_node_neighbor) - if graph.has_edge(remove_node, remove_node_neighbor): - graph.remove_edge(remove_node, remove_node_neighbor) - if remove_node_neighbor == keep_node: - if graph.has_edge(keep_node, remove_node): - graph.remove_edge(keep_node, remove_node) - continue - if not rel: - continue - if graph.has_edge(keep_node, remove_node_neighbor): - nursery.start_soon(lambda: self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)) - else: - pair = sorted([keep_node, remove_node_neighbor]) - graph.add_edge(pair[0], pair[1], weight=rel['weight']) - 
self._set_relation_(pair[0], pair[1], - dict( - src_id=pair[0], - tgt_id=pair[1], - weight=rel['weight'], - description=rel['description'], - keywords=[], - source_id=rel.get("source_id", ""), - metadata={"created_at": time.time()} - )) - graph.remove_node(remove_node) + merging_nodes = list(sub_connect_graph.nodes) + nursery.start_soon(lambda: self._merge_graph_nodes(graph, merging_nodes, change)) + + # Update pagerank + pr = nx.pagerank(graph) + for node_name, pagerank in pr.items(): + graph.nodes[node_name]["pagerank"] = pagerank return EntityResolutionResult( graph=graph, - removed_entities=removed_entities + change=change, ) async def _resolve_candidate(self, candidate_resolution_i, resolution_result): diff --git a/graphrag/general/community_report_prompt.py b/graphrag/general/community_report_prompt.py index 554ea367..8b9fa2f6 100644 --- a/graphrag/general/community_report_prompt.py +++ b/graphrag/general/community_report_prompt.py @@ -2,7 +2,7 @@ # Licensed under the MIT License """ Reference: - - [graphrag](https://github.com/microsoft/graphrag) + - [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/community_report.py) """ COMMUNITY_REPORT_PROMPT = """ diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index 5efc8e62..4b0989f9 100644 --- a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -40,13 +40,9 @@ class CommunityReportsExtractor(Extractor): def __init__( self, llm_invoker: CompletionLLM, - get_entity: Callable | None = None, - set_entity: Callable | None = None, - get_relation: Callable | None = None, - set_relation: Callable | None = None, max_report_length: int | None = None, ): - super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation) + super().__init__(llm_invoker) """Init method definition.""" self._llm = llm_invoker self._extraction_prompt = COMMUNITY_REPORT_PROMPT @@ -63,21 +59,28 @@ class CommunityReportsExtractor(Extractor): over, token_count = 0, 0 async def extract_community_report(community): nonlocal res_str, res_dict, over, token_count - cm_id, ents = community - weight = ents["weight"] - ents = ents["nodes"] - ent_df = pd.DataFrame(self._get_entity_(ents)).dropna() - if ent_df.empty or "entity_name" not in ent_df.columns: + cm_id, cm = community + weight = cm["weight"] + ents = cm["nodes"] + if len(ents) < 2: return - ent_df["entity"] = ent_df["entity_name"] - del ent_df["entity_name"] - rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000)) - if rela_df.empty: - return - rela_df["source"] = rela_df["src_id"] - rela_df["target"] = rela_df["tgt_id"] - del rela_df["src_id"] - del rela_df["tgt_id"] + ent_list = [{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents] + ent_df = pd.DataFrame(ent_list) + + rela_list = [] + k = 0 + for i in range(0, len(ents)): + if k >= 10000: + break + for j in range(i + 1, len(ents)): + if k >= 10000: + break + edge = graph.get_edge_data(ents[i], ents[j]) + if edge is None: + continue + rela_list.append({"source": ents[i], "target": ents[j], "description": edge["description"]}) + k += 1 + rela_df = pd.DataFrame(rela_list) prompt_variables = { "entity_df": ent_df.to_csv(index_label="id"), diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 5f62d137..4b906903 100644 --- a/graphrag/general/extractor.py +++ 
b/graphrag/general/extractor.py @@ -19,10 +19,11 @@ from collections import defaultdict, Counter from copy import deepcopy from typing import Callable import trio +import networkx as nx from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \ - handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter + handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange from rag.llm.chat_model import Base as CompletionLLM from rag.prompts import message_fit_in from rag.utils import truncate @@ -40,18 +41,10 @@ class Extractor: llm_invoker: CompletionLLM, language: str | None = "English", entity_types: list[str] | None = None, - get_entity: Callable | None = None, - set_entity: Callable | None = None, - get_relation: Callable | None = None, - set_relation: Callable | None = None, ): self._llm = llm_invoker self._language = language self._entity_types = entity_types or DEFAULT_ENTITY_TYPES - self._get_entity_ = get_entity - self._set_entity_ = set_entity - self._get_relation_ = get_relation - self._set_relation_ = set_relation def _chat(self, system, history, gen_conf): hist = deepcopy(history) @@ -152,25 +145,15 @@ class Extractor: async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data): if not entities: return - already_entity_types = [] - already_source_ids = [] - already_description = [] - - already_node = self._get_entity_(entity_name) - if already_node: - already_entity_types.append(already_node["entity_type"]) - already_source_ids.extend(already_node["source_id"]) - already_description.append(already_node["description"]) - entity_type = sorted( Counter( - [dp["entity_type"] for dp in entities] + already_entity_types + [dp["entity_type"] for dp in entities] ).items(), key=lambda x: x[1], reverse=True, )[0][0] description = GRAPH_FIELD_SEP.join( - sorted(set([dp["description"] for dp in entities] + already_description)) + sorted(set([dp["description"] for dp in entities])) ) already_source_ids = flat_uniq_list(entities, "source_id") description = await self._handle_entity_relation_summary(entity_name, description) @@ -180,7 +163,6 @@ class Extractor: source_id=already_source_ids, ) node_data["entity_name"] = entity_name - self._set_entity_(entity_name, node_data) all_relationships_data.append(node_data) async def _merge_edges( @@ -192,36 +174,11 @@ class Extractor: ): if not edges_data: return - already_weights = [] - already_source_ids = [] - already_description = [] - already_keywords = [] - - relation = self._get_relation_(src_id, tgt_id) - if relation: - already_weights = [relation["weight"]] - already_source_ids = relation["source_id"] - already_description = [relation["description"]] - already_keywords = relation["keywords"] - - weight = sum([dp["weight"] for dp in edges_data] + already_weights) - description = GRAPH_FIELD_SEP.join( - sorted(set([dp["description"] for dp in edges_data] + already_description)) - ) - keywords = flat_uniq_list(edges_data, "keywords") + already_keywords - source_id = flat_uniq_list(edges_data, "source_id") + already_source_ids - - for need_insert_id in [src_id, tgt_id]: - if self._get_entity_(need_insert_id): - continue - self._set_entity_(need_insert_id, { - "source_id": source_id, - "description": description, - "entity_type": 'UNKNOWN' - }) - description = await self._handle_entity_relation_summary( - 
f"({src_id}, {tgt_id})", description - ) + weight = sum([edge["weight"] for edge in edges_data]) + description = GRAPH_FIELD_SEP.join(sorted(set([edge["description"] for edge in edges_data]))) + description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description) + keywords = flat_uniq_list(edges_data, "keywords") + source_id = flat_uniq_list(edges_data, "source_id") edge_data = dict( src_id=src_id, tgt_id=tgt_id, @@ -230,9 +187,41 @@ class Extractor: weight=weight, source_id=source_id ) - self._set_relation_(src_id, tgt_id, edge_data) - if all_relationships_data is not None: - all_relationships_data.append(edge_data) + all_relationships_data.append(edge_data) + + async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange): + if len(nodes) <= 1: + return + change.added_updated_nodes.add(nodes[0]) + change.removed_nodes.extend(nodes[1:]) + nodes_set = set(nodes) + node0_attrs = graph.nodes[nodes[0]] + node0_neighbors = set(graph.neighbors(nodes[0])) + for node1 in nodes[1:]: + # Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged. + node1_attrs = graph.nodes[node1] + node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}" + for attr in ["keywords", "source_id"]: + node0_attrs[attr] = sorted(set(node0_attrs[attr].extend(node1_attrs[attr]))) + for neighbor in graph.neighbors(node1): + change.removed_edges.add(get_from_to(node1, neighbor)) + if neighbor not in nodes_set: + edge1_attrs = graph.get_edge_data(node1, neighbor) + if neighbor in node0_neighbors: + # Merge two edges + change.added_updated_edges.add(get_from_to(nodes[0], neighbor)) + edge0_attrs = graph.get_edge_data(nodes[0], neighbor) + edge0_attrs["weight"] += edge1_attrs["weight"] + edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}" + edge0_attrs["keywords"] = list(set(edge0_attrs["keywords"].extend(edge1_attrs["keywords"]))) + edge0_attrs["source_id"] = list(set(edge0_attrs["source_id"].extend(edge1_attrs["source_id"]))) + edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"]) + graph.add_edge(nodes[0], neighbor, **edge0_attrs) + else: + graph.add_edge(nodes[0], neighbor, **edge1_attrs) + graph.remove_node(node1) + node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"]) + graph.nodes[nodes[0]].update(node0_attrs) async def _handle_entity_relation_summary( self, diff --git a/graphrag/general/graph_extractor.py b/graphrag/general/graph_extractor.py index e3c91126..b2a3948a 100644 --- a/graphrag/general/graph_extractor.py +++ b/graphrag/general/graph_extractor.py @@ -6,7 +6,7 @@ Reference: """ import re -from typing import Any, Callable +from typing import Any from dataclasses import dataclass import tiktoken import trio @@ -53,10 +53,6 @@ class GraphExtractor(Extractor): llm_invoker: CompletionLLM, language: str | None = "English", entity_types: list[str] | None = None, - get_entity: Callable | None = None, - set_entity: Callable | None = None, - get_relation: Callable | None = None, - set_relation: Callable | None = None, tuple_delimiter_key: str | None = None, record_delimiter_key: str | None = None, input_text_key: str | None = None, @@ -66,7 +62,7 @@ class GraphExtractor(Extractor): max_gleanings: int | None = None, on_error: ErrorHandlerFn | None = None, ): - super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation) + 
super().__init__(llm_invoker, language, entity_types) """Init method definition.""" # TODO: streamline construction self._llm = llm_invoker diff --git a/graphrag/general/graph_prompt.py b/graphrag/general/graph_prompt.py index 3a1f6483..3472bc73 100644 --- a/graphrag/general/graph_prompt.py +++ b/graphrag/general/graph_prompt.py @@ -2,7 +2,7 @@ # Licensed under the MIT License """ Reference: - - [graphrag](https://github.com/microsoft/graphrag) + - [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/extract_graph.py) """ GRAPH_EXTRACTION_PROMPT = """ diff --git a/graphrag/general/index.py b/graphrag/general/index.py index dabb8a09..8b41eb5c 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -15,11 +15,11 @@ # import json import logging -from functools import partial import networkx as nx import trio from api import settings +from api.utils import get_uuid from graphrag.light.graph_extractor import GraphExtractor as LightKGExt from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt from graphrag.general.community_reports_extractor import CommunityReportsExtractor @@ -27,32 +27,15 @@ from graphrag.entity_resolution import EntityResolution from graphrag.general.extractor import Extractor from graphrag.utils import ( graph_merge, - set_entity, - get_relation, - set_relation, - get_entity, get_graph, set_graph, chunk_id, - update_nodes_pagerank_nhop_neighbour, does_graph_contains, - get_graph_doc_ids, + tidy_graph, + GraphChange, ) from rag.nlp import rag_tokenizer, search -from rag.utils.redis_conn import REDIS_CONN - - -def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool: - key = f"graphrag:{tenant_id}:{kb_id}" - ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24) - if not ok: - raise Exception(f"Faild to set the {key} to {doc_id}") - - -def graphrag_task_get(tenant_id, kb_id) -> str | None: - key = f"graphrag:{tenant_id}:{kb_id}" - doc_id = REDIS_CONN.get(key) - return doc_id +from rag.utils.redis_conn import RedisDistributedLock async def run_graphrag( @@ -72,7 +55,7 @@ async def run_graphrag( ): chunks.append(d["content_with_weight"]) - graph, doc_ids = await update_graph( + subgraph = await generate_subgraph( LightKGExt if row["parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt, @@ -86,14 +69,26 @@ async def run_graphrag( embedding_model, callback, ) - if not graph: + new_graph = None + if subgraph: + new_graph = await merge_subgraph( + tenant_id, + kb_id, + doc_id, + subgraph, + embedding_model, + callback, + ) + + if not with_resolution or not with_community: return - if with_resolution or with_community: - graphrag_task_set(tenant_id, kb_id, doc_id) - if with_resolution: + + if new_graph is None: + new_graph = await get_graph(tenant_id, kb_id) + + if with_resolution and new_graph is not None: await resolve_entities( - graph, - doc_ids, + new_graph, tenant_id, kb_id, doc_id, @@ -101,10 +96,9 @@ async def run_graphrag( embedding_model, callback, ) - if with_community: + if with_community and new_graph is not None: await extract_community( - graph, - doc_ids, + new_graph, tenant_id, kb_id, doc_id, @@ -117,7 +111,7 @@ async def run_graphrag( return -async def update_graph( +async def generate_subgraph( extractor: Extractor, tenant_id: str, kb_id: str, @@ -131,34 +125,41 @@ async def update_graph( ): contains = await does_graph_contains(tenant_id, kb_id, doc_id) if contains: - callback(msg=f"Graph already contains {doc_id}, cancel myself") - return None, None + callback(msg=f"Graph already contains 
{doc_id}") + return None start = trio.current_time() ext = extractor( llm_bdl, language=language, entity_types=entity_types, - get_entity=partial(get_entity, tenant_id, kb_id), - set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl), - get_relation=partial(get_relation, tenant_id, kb_id), - set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl), ) ents, rels = await ext(doc_id, chunks, callback) subgraph = nx.Graph() - for en in ents: - subgraph.add_node(en["entity_name"], entity_type=en["entity_type"]) + for ent in ents: + assert "description" in ent, f"entity {ent} does not have description" + ent["source_id"] = [doc_id] + subgraph.add_node(ent["entity_name"], **ent) + ignored_rels = 0 for rel in rels: + assert "description" in rel, f"relation {rel} does not have description" + if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]): + ignored_rels += 1 + continue + rel["source_id"] = [doc_id] subgraph.add_edge( rel["src_id"], rel["tgt_id"], - weight=rel["weight"], - # description=rel["description"] + **rel, ) - # TODO: infinity doesn't support array search + if ignored_rels: + callback(msg=f"ignored {ignored_rels} relations due to missing entities.") + tidy_graph(subgraph, callback) + + subgraph.graph["source_id"] = [doc_id] chunk = { "content_with_weight": json.dumps( - nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2 + nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False ), "knowledge_graph_kwd": "subgraph", "kb_id": kb_id, @@ -167,6 +168,11 @@ async def update_graph( "removed_kwd": "N", } cid = chunk_id(chunk) + await trio.to_thread.run_sync( + lambda: settings.docStoreConn.delete( + {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id + ) + ) await trio.to_thread.run_sync( lambda: settings.docStoreConn.insert( [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id @@ -174,39 +180,49 @@ async def update_graph( ) now = trio.current_time() callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.") - start = now + return subgraph +async def merge_subgraph( + tenant_id: str, + kb_id: str, + doc_id: str, + subgraph: nx.Graph, + embedding_model, + callback, +): + graphrag_task_lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600) while True: + if graphrag_task_lock.acquire(): + break + callback(msg=f"merge_subgraph {doc_id} is waiting graphrag_task_lock") + await trio.sleep(10) + + start = trio.current_time() + change = GraphChange() + old_graph = await get_graph(tenant_id, kb_id) + if old_graph is not None: + logging.info("Merge with an exiting graph...................") + tidy_graph(old_graph, callback) + new_graph = graph_merge(old_graph, subgraph, change) + else: new_graph = subgraph - now_docids = set([doc_id]) - old_graph, old_doc_ids = await get_graph(tenant_id, kb_id) - if old_graph is not None: - logging.info("Merge with an exiting graph...................") - new_graph = graph_merge(old_graph, subgraph) - await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2) - if old_doc_ids: - for old_doc_id in old_doc_ids: - now_docids.add(old_doc_id) - old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id) - delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids) - if delta_doc_ids: - callback( - msg="The global graph has changed during merging, try again" - ) - await trio.sleep(1) - continue - break - await set_graph(tenant_id, kb_id, new_graph, list(now_docids)) + change.added_updated_nodes = 
set(new_graph.nodes()) + change.added_updated_edges = set(new_graph.edges()) + pr = nx.pagerank(new_graph) + for node_name, pagerank in pr.items(): + new_graph.nodes[node_name]["pagerank"] = pagerank + + await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback) + graphrag_task_lock.release() now = trio.current_time() callback( msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds." ) - return new_graph, now_docids + return new_graph async def resolve_entities( graph, - doc_ids, tenant_id: str, kb_id: str, doc_id: str, @@ -214,74 +230,30 @@ async def resolve_entities( embed_bdl, callback, ): - working_doc_id = graphrag_task_get(tenant_id, kb_id) - if doc_id != working_doc_id: - callback( - msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself" - ) - return + graphrag_task_lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600) + while True: + if graphrag_task_lock.acquire(): + break + await trio.sleep(10) + start = trio.current_time() er = EntityResolution( llm_bdl, - get_entity=partial(get_entity, tenant_id, kb_id), - set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl), - get_relation=partial(get_relation, tenant_id, kb_id), - set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl), ) reso = await er(graph, callback=callback) graph = reso.graph - callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.") - await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2) + change = reso.change + callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.") callback(msg="Graph resolution updated pagerank.") - working_doc_id = graphrag_task_get(tenant_id, kb_id) - if doc_id != working_doc_id: - callback( - msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself" - ) - return - await set_graph(tenant_id, kb_id, graph, doc_ids) - - await trio.to_thread.run_sync( - lambda: settings.docStoreConn.delete( - { - "knowledge_graph_kwd": "relation", - "kb_id": kb_id, - "from_entity_kwd": reso.removed_entities, - }, - search.index_name(tenant_id), - kb_id, - ) - ) - await trio.to_thread.run_sync( - lambda: settings.docStoreConn.delete( - { - "knowledge_graph_kwd": "relation", - "kb_id": kb_id, - "to_entity_kwd": reso.removed_entities, - }, - search.index_name(tenant_id), - kb_id, - ) - ) - await trio.to_thread.run_sync( - lambda: settings.docStoreConn.delete( - { - "knowledge_graph_kwd": "entity", - "kb_id": kb_id, - "entity_kwd": reso.removed_entities, - }, - search.index_name(tenant_id), - kb_id, - ) - ) + await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback) + graphrag_task_lock.release() now = trio.current_time() callback(msg=f"Graph resolution done in {now - start:.2f}s.") async def extract_community( graph, - doc_ids, tenant_id: str, kb_id: str, doc_id: str, @@ -289,49 +261,34 @@ async def extract_community( embed_bdl, callback, ): - working_doc_id = graphrag_task_get(tenant_id, kb_id) - if doc_id != working_doc_id: - callback( - msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself" - ) - return + graphrag_task_lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600) + while True: + if graphrag_task_lock.acquire(): + break + await trio.sleep(10) + start = trio.current_time() ext = CommunityReportsExtractor( llm_bdl, - get_entity=partial(get_entity, tenant_id, kb_id), - 
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl), - get_relation=partial(get_relation, tenant_id, kb_id), - set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl), ) cr = await ext(graph, callback=callback) community_structure = cr.structured_output community_reports = cr.output - working_doc_id = graphrag_task_get(tenant_id, kb_id) - if doc_id != working_doc_id: - callback( - msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself" - ) - return - await set_graph(tenant_id, kb_id, graph, doc_ids) + doc_ids = graph.graph["source_id"] now = trio.current_time() callback( msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s." ) start = now - await trio.to_thread.run_sync( - lambda: settings.docStoreConn.delete( - {"knowledge_graph_kwd": "community_report", "kb_id": kb_id}, - search.index_name(tenant_id), - kb_id, - ) - ) + chunks = [] for stru, rep in zip(community_structure, community_reports): obj = { "report": rep, "evidences": "\n".join([f["explanation"] for f in stru["findings"]]), } chunk = { + "id": get_uuid(), "docnm_kwd": stru["title"], "title_tks": rag_tokenizer.tokenize(stru["title"]), "content_with_weight": json.dumps(obj, ensure_ascii=False), @@ -349,17 +306,23 @@ async def extract_community( chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize( chunk["content_ltks"] ) - # try: - # ebd, _ = embed_bdl.encode([", ".join(community["entities"])]) - # chunk["q_%d_vec" % len(ebd[0])] = ebd[0] - # except Exception as e: - # logging.exception(f"Fail to embed entity relation: {e}") - await trio.to_thread.run_sync( - lambda: settings.docStoreConn.insert( - [{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id) - ) - ) + chunks.append(chunk) + await trio.to_thread.run_sync( + lambda: settings.docStoreConn.delete( + {"knowledge_graph_kwd": "community_report", "kb_id": kb_id}, + search.index_name(tenant_id), + kb_id, + ) + ) + es_bulk_size = 4 + for b in range(0, len(chunks), es_bulk_size): + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id)) + if doc_store_result: + error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" + raise Exception(error_message) + + graphrag_task_lock.release() now = trio.current_time() callback( msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s." 
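The three phases above (merge_subgraph, resolve_entities, extract_community) each repeat the same acquire-then-sleep loop around RedisDistributedLock. Below is a minimal sketch of that shared pattern as a helper; the helper name `acquire_graphrag_lock` and the `poll_secs` parameter are illustrative assumptions, not part of this diff.

```python
import trio
from rag.utils.redis_conn import RedisDistributedLock


async def acquire_graphrag_lock(doc_id: str, callback=None, poll_secs: int = 10) -> RedisDistributedLock:
    """Hypothetical helper: poll until the per-KB graphrag lock is acquired, without blocking the event loop."""
    lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600)
    while not lock.acquire():
        if callback:
            callback(msg=f"doc {doc_id} is waiting for the graphrag_task lock")
        await trio.sleep(poll_secs)
    return lock

# Usage, mirroring merge_subgraph / resolve_entities / extract_community:
#     lock = await acquire_graphrag_lock(doc_id, callback)
#     try:
#         ...  # mutate and persist the global graph
#     finally:
#         lock.release()
```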
diff --git a/graphrag/general/leiden.py b/graphrag/general/leiden.py index 98ecca70..b859e7e6 100644 --- a/graphrag/general/leiden.py +++ b/graphrag/general/leiden.py @@ -100,7 +100,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]: logging.debug( "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc ) - if not graph.nodes(): + nodes = set(graph.nodes()) + if not nodes: return {} node_id_to_community_map = _compute_leiden_communities( @@ -120,7 +121,7 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]: result = {} results_by_level[level] = result for node_id, raw_community_id in node_id_to_community_map[level].items(): - if node_id not in graph.nodes: + if node_id not in nodes: logging.warning(f"Node {node_id} not found in the graph.") continue community_id = str(raw_community_id) diff --git a/graphrag/light/graph_extractor.py b/graphrag/light/graph_extractor.py index 5c3aa3e5..8b809b83 100644 --- a/graphrag/light/graph_extractor.py +++ b/graphrag/light/graph_extractor.py @@ -5,7 +5,7 @@ Reference: - [graphrag](https://github.com/microsoft/graphrag) """ import re -from typing import Any, Callable +from typing import Any from dataclasses import dataclass from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS from graphrag.light.graph_prompt import PROMPTS @@ -33,14 +33,10 @@ class GraphExtractor(Extractor): llm_invoker: CompletionLLM, language: str | None = "English", entity_types: list[str] | None = None, - get_entity: Callable | None = None, - set_entity: Callable | None = None, - get_relation: Callable | None = None, - set_relation: Callable | None = None, example_number: int = 2, max_gleanings: int | None = None, ): - super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation) + super().__init__(llm_invoker, language, entity_types) """Init method definition.""" self._max_gleanings = ( max_gleanings diff --git a/graphrag/light/graph_prompt.py b/graphrag/light/graph_prompt.py index dcfe470e..80d9f3cc 100644 --- a/graphrag/light/graph_prompt.py +++ b/graphrag/light/graph_prompt.py @@ -1,7 +1,7 @@ # Licensed under the MIT License """ Reference: - - [LightRag](https://github.com/HKUDS/LightRAG) + - [LightRAG](https://github.com/HKUDS/LightRAG/blob/main/lightrag/prompt.py) """ diff --git a/graphrag/utils.py b/graphrag/utils.py index efac9783..8ed3890e 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -12,26 +12,37 @@ import logging import re import time from collections import defaultdict -from copy import deepcopy from hashlib import md5 from typing import Any, Callable import os import trio +from typing import Set, Tuple import networkx as nx import numpy as np import xxhash from networkx.readwrite import json_graph +import dataclasses from api import settings +from api.utils import get_uuid from rag.nlp import search, rag_tokenizer from rag.utils.doc_store_conn import OrderByExpr from rag.utils.redis_conn import REDIS_CONN +GRAPH_FIELD_SEP = "" + ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10))) +@dataclasses.dataclass +class GraphChange: + removed_nodes: Set[str] = dataclasses.field(default_factory=set) + added_updated_nodes: Set[str] = dataclasses.field(default_factory=set) + removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set) + added_updated_edges: Set[Tuple[str, str]] = 
dataclasses.field(default_factory=set) + def perform_variable_replacements( input: str, history: list[dict] | None = None, variables: dict | None = None ) -> str: @@ -146,24 +157,74 @@ def set_tags_to_cache(kb_ids, tags): k = hasher.hexdigest() REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600) +def tidy_graph(graph: nx.Graph, callback): + """ + Ensure all nodes and edges in the graph have some essential attribute. + """ + def is_valid_node(node_attrs: dict) -> bool: + valid_node = True + for attr in ["description", "source_id"]: + if attr not in node_attrs: + valid_node = False + break + return valid_node + purged_nodes = [] + for node, node_attrs in graph.nodes(data=True): + if not is_valid_node(node_attrs): + purged_nodes.append(node) + for node in purged_nodes: + graph.remove_node(node) + if purged_nodes and callback: + callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.") -def graph_merge(g1, g2): - g = g2.copy() - for n, attr in g1.nodes(data=True): - if n not in g2.nodes(): - g.add_node(n, **attr) + purged_edges = [] + for source, target, attr in graph.edges(data=True): + if not is_valid_node(attr): + purged_edges.append((source, target)) + if "keywords" not in attr: + attr["keywords"] = [] + for source, target in purged_edges: + graph.remove_edge(source, target) + if purged_edges and callback: + callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.") + +def get_from_to(node1, node2): + if node1 < node2: + return (node1, node2) + else: + return (node2, node1) + +def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange): + """Merge graph g2 into g1 in place.""" + for node_name, attr in g2.nodes(data=True): + change.added_updated_nodes.add(node_name) + if not g1.has_node(node_name): + g1.add_node(node_name, **attr) continue + node = g1.nodes[node_name] + node["description"] += GRAPH_FIELD_SEP + attr["description"] + # A node's source_id indicates which chunks it came from. + node["source_id"] += attr["source_id"] - for source, target, attr in g1.edges(data=True): - if g.has_edge(source, target): - g[source][target].update({"weight": attr.get("weight", 0)+1}) + for source, target, attr in g2.edges(data=True): + change.added_updated_edges.add(get_from_to(source, target)) + edge = g1.get_edge_data(source, target) + if edge is None: + g1.add_edge(source, target, **attr) continue - g.add_edge(source, target)#, **attr) - - for node_degree in g.degree: - g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1]) - return g + edge["weight"] += attr.get("weight", 0) + edge["description"] += GRAPH_FIELD_SEP + attr["description"] + edge["keywords"] += attr["keywords"] + # A edge's source_id indicates which chunks it came from. + edge["source_id"] += attr["source_id"] + for node_degree in g1.degree: + g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1]) + # A graph's source_id indicates which documents it came from. 
+ if "source_id" not in g1.graph: + g1.graph["source_id"] = [] + g1.graph["source_id"] += g2.graph.get("source_id", []) + return g1 def compute_args_hash(*args): return md5(str(args).encode()).hexdigest() @@ -237,55 +298,10 @@ def is_float_regex(value): def chunk_id(chunk): return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest() -def get_entity_cache(tenant_id, kb_id, ent_name) -> str | list[str]: - hasher = xxhash.xxh64() - hasher.update(str(tenant_id).encode("utf-8")) - hasher.update(str(kb_id).encode("utf-8")) - hasher.update(str(ent_name).encode("utf-8")) - k = hasher.hexdigest() - bin = REDIS_CONN.get(k) - if not bin: - return - return json.loads(bin) - - -def set_entity_cache(tenant_id, kb_id, ent_name, content_with_weight): - hasher = xxhash.xxh64() - hasher.update(str(tenant_id).encode("utf-8")) - hasher.update(str(kb_id).encode("utf-8")) - hasher.update(str(ent_name).encode("utf-8")) - - k = hasher.hexdigest() - REDIS_CONN.set(k, content_with_weight.encode("utf-8"), 3600) - - -def get_entity(tenant_id, kb_id, ent_name): - cache = get_entity_cache(tenant_id, kb_id, ent_name) - if cache: - return cache - conds = { - "fields": ["content_with_weight"], - "entity_kwd": ent_name, - "size": 10000, - "knowledge_graph_kwd": ["entity"] - } - res = [] - es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]) - for id in es_res.ids: - try: - if isinstance(ent_name, str): - set_entity_cache(tenant_id, kb_id, ent_name, es_res.field[id]["content_with_weight"]) - return json.loads(es_res.field[id]["content_with_weight"]) - res.append(json.loads(es_res.field[id]["content_with_weight"])) - except Exception: - continue - - return res - - -def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta): +async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks): chunk = { + "id": get_uuid(), "important_kwd": [ent_name], "title_tks": rag_tokenizer.tokenize(ent_name), "entity_kwd": ent_name, @@ -293,28 +309,19 @@ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta): "entity_type_kwd": meta["entity_type"], "content_with_weight": json.dumps(meta, ensure_ascii=False), "content_ltks": rag_tokenizer.tokenize(meta["description"]), - "source_id": list(set(meta["source_id"])), + "source_id": meta["source_id"], "kb_id": kb_id, "available_int": 0 } chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) - set_entity_cache(tenant_id, kb_id, ent_name, chunk["content_with_weight"]) - res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []}, - search.index_name(tenant_id), [kb_id]) - if res.ids: - settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id) - else: - ebd = get_embed_cache(embd_mdl.llm_name, ent_name) - if ebd is None: - try: - ebd, _ = embd_mdl.encode([ent_name]) - ebd = ebd[0] - set_embed_cache(embd_mdl.llm_name, ent_name, ebd) - except Exception as e: - logging.exception(f"Fail to embed entity: {e}") - if ebd is not None: - chunk["q_%d_vec" % len(ebd)] = ebd - settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id) + ebd = get_embed_cache(embd_mdl.llm_name, ent_name) + if ebd is None: + ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name])) + ebd = ebd[0] + set_embed_cache(embd_mdl.llm_name, ent_name, ebd) + assert ebd is not None + chunk["q_%d_vec" % len(ebd)] = ebd + chunks.append(chunk) def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, 
size=1): @@ -344,40 +351,30 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1): return res -def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta): +async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks): chunk = { + "id": get_uuid(), "from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name, "knowledge_graph_kwd": "relation", "content_with_weight": json.dumps(meta, ensure_ascii=False), "content_ltks": rag_tokenizer.tokenize(meta["description"]), "important_kwd": meta["keywords"], - "source_id": list(set(meta["source_id"])), + "source_id": meta["source_id"], "weight_int": int(meta["weight"]), "kb_id": kb_id, "available_int": 0 } chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) - res = settings.retrievaler.search({"from_entity_kwd": to_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []}, - search.index_name(tenant_id), [kb_id]) - - if res.ids: - settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name}, - chunk, - search.index_name(tenant_id), kb_id) - else: - txt = f"{from_ent_name}->{to_ent_name}" - ebd = get_embed_cache(embd_mdl.llm_name, txt) - if ebd is None: - try: - ebd, _ = embd_mdl.encode([txt+f": {meta['description']}"]) - ebd = ebd[0] - set_embed_cache(embd_mdl.llm_name, txt, ebd) - except Exception as e: - logging.exception(f"Fail to embed entity relation: {e}") - if ebd is not None: - chunk["q_%d_vec" % len(ebd)] = ebd - settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id) + txt = f"{from_ent_name}->{to_ent_name}" + ebd = get_embed_cache(embd_mdl.llm_name, txt) + if ebd is None: + ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"])) + ebd = ebd[0] + set_embed_cache(embd_mdl.llm_name, txt, ebd) + assert ebd is not None + chunk["q_%d_vec" % len(ebd)] = ebd + chunks.append(chunk) async def does_graph_contains(tenant_id, kb_id, doc_id): # Get doc_ids of graph @@ -418,33 +415,68 @@ async def get_graph(tenant_id, kb_id): } res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])) if res.total == 0: - return None, [] + return None for id in res.ids: try: - return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \ - res.field[id]["source_id"] + g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges") + if "source_id" not in g.graph: + g.graph["source_id"] = res.field[id]["source_id"] + return g except Exception: continue result = await rebuild_graph(tenant_id, kb_id) return result -async def set_graph(tenant_id, kb_id, graph, docids): - chunk = { - "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False, - indent=2), +async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback): + start = trio.current_time() + + await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph"]}, search.index_name(tenant_id), kb_id)) + + if change.removed_nodes: + await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)) + + if change.removed_edges: + async with trio.open_nursery() as nursery: + for from_node, to_node in change.removed_edges: + 
nursery.start_soon(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)) + now = trio.current_time() + if callback: + callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.") + start = now + + chunks = [{ + "id": get_uuid(), + "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False), "knowledge_graph_kwd": "graph", "kb_id": kb_id, - "source_id": list(docids), + "source_id": graph.graph.get("source_id", []), "available_int": 0, "removed_kwd": "N" - } - res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])) - if res.ids: - await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk, - search.index_name(tenant_id), kb_id)) - else: - await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)) + }] + async with trio.open_nursery() as nursery: + for node in change.added_updated_nodes: + node_attrs = graph.nodes[node] + nursery.start_soon(lambda: graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)) + for from_node, to_node in change.added_updated_edges: + edge_attrs = graph.edges[from_node, to_node] + nursery.start_soon(lambda: graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)) + now = trio.current_time() + if callback: + callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.") + start = now + + await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "entity", "relation"]}, search.index_name(tenant_id), kb_id)) + + es_bulk_size = 4 + for b in range(0, len(chunks), es_bulk_size): + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id)) + if doc_store_result: + error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" 
+ raise Exception(error_message) + now = trio.current_time() + if callback: + callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.") def is_continuous_subsequence(subseq, seq): @@ -489,67 +521,6 @@ def merge_tuples(list1, list2): return result -async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop): - def n_neighbor(id): - nonlocal graph, n_hop - count = 0 - source_edge = list(graph.edges(id)) - if not source_edge: - return [] - count = count + 1 - while count < n_hop: - count = count + 1 - sc_edge = deepcopy(source_edge) - source_edge = [] - for pair in sc_edge: - append_edge = list(graph.edges(pair[-1])) - for tuples in merge_tuples([pair], append_edge): - source_edge.append(tuples) - nbrs = [] - for path in source_edge: - n = {"path": path, "weights": []} - wts = nx.get_edge_attributes(graph, 'weight') - for i in range(len(path)-1): - f, t = path[i], path[i+1] - n["weights"].append(wts.get((f, t), 0)) - nbrs.append(n) - return nbrs - - pr = nx.pagerank(graph) - try: - async with trio.open_nursery() as nursery: - for n, p in pr.items(): - graph.nodes[n]["pagerank"] = p - nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id}, - {"rank_flt": p, - "n_hop_with_weight": json.dumps((n), ensure_ascii=False)}, - search.index_name(tenant_id), kb_id))) - except Exception as e: - logging.exception(e) - - ty2ents = defaultdict(list) - for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True): - ty = graph.nodes[p].get("entity_type") - if not ty or len(ty2ents[ty]) > 12: - continue - ty2ents[ty].append(p) - - chunk = { - "content_with_weight": json.dumps(ty2ents, ensure_ascii=False), - "kb_id": kb_id, - "knowledge_graph_kwd": "ty2ents", - "available_int": 0 - } - res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []}, - search.index_name(tenant_id), [kb_id])) - if res.ids: - await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"}, - chunk, - search.index_name(tenant_id), kb_id)) - else: - await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)) - - async def get_entity_type2sampels(idxnms, kb_ids: list): es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, @@ -584,33 +555,46 @@ def flat_uniq_list(arr, key): async def rebuild_graph(tenant_id, kb_id): graph = nx.Graph() - src_ids = [] - flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"] + src_ids = set() + flds = ["entity_kwd", "from_entity_kwd", "to_entity_kwd", "knowledge_graph_kwd", "content_with_weight", "source_id"] bs = 256 - for i in range(0, 39*bs, bs): + for i in range(0, 1024*bs, bs): es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [], - {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]}, + {"kb_id": kb_id, "knowledge_graph_kwd": ["entity"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id] )) tot = settings.docStoreConn.getTotal(es_res) if tot == 0: - return None, None + break es_res = settings.docStoreConn.getFields(es_res, flds) for id, d in es_res.items(): - src_ids.extend(d.get("source_id", [])) - if 
d["knowledge_graph_kwd"] == "entity": - graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"]) - elif "from_entity_kwd" in d and "to_entity_kwd" in d: - graph.add_edge( - d["from_entity_kwd"], - d["to_entity_kwd"], - weight=int(d["weight_int"]) - ) + assert d["knowledge_graph_kwd"] == "relation" + src_ids.update(d.get("source_id", [])) + attrs = json.load(d["content_with_weight"]) + graph.add_node(d["entity_kwd"], **attrs) - if len(es_res.keys()) < 128: - return graph, list(set(src_ids)) + for i in range(0, 1024*bs, bs): + es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [], + {"kb_id": kb_id, "knowledge_graph_kwd": ["relation"]}, + [], + OrderByExpr(), + i, bs, search.index_name(tenant_id), [kb_id] + )) + tot = settings.docStoreConn.getTotal(es_res) + if tot == 0: + return None - return graph, list(set(src_ids)) + es_res = settings.docStoreConn.getFields(es_res, flds) + for id, d in es_res.items(): + assert d["knowledge_graph_kwd"] == "relation" + src_ids.update(d.get("source_id", [])) + if graph.has_node(d["from_entity_kwd"]) and graph.has_node(d["to_entity_kwd"]): + attrs = json.load(d["content_with_weight"]) + graph.add_edge(d["from_entity_kwd"], d["to_entity_kwd"], **attrs) + + src_ids = sorted(src_ids) + graph.graph["source_id"] = src_ids + return graph diff --git a/pyproject.toml b/pyproject.toml index 30c5e564..59afacdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,7 @@ dependencies = [ "xxhash>=3.5.0,<4.0.0", "trio>=0.29.0", "langfuse>=2.60.0", + "debugpy>=1.8.13", ] [project.optional-dependencies] diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 4d0e860b..49fb137d 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -517,6 +517,8 @@ async def do_handle_task(task): chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback) # Either using graphrag or Standard chunking methods elif task.get("task_type", "") == "graphrag": + global task_limiter + task_limiter = trio.CapacityLimiter(2) graphrag_conf = task_parser_config.get("graphrag", {}) if not graphrag_conf.get("use_graphrag", False): return diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 28d107bc..1edff02c 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -172,6 +172,12 @@ class InfinityConnection(DocStoreConnection): ConflictType.Ignore, ) + def field_keyword(self, field_name: str): + # The "docnm_kwd" field is always a string, not list. 
+ if field_name == "source_id" or (field_name.endswith("_kwd") and field_name != "docnm_kwd"): + return True + return False + """ Database operations """ @@ -480,9 +486,11 @@ class InfinityConnection(DocStoreConnection): assert "_id" not in d assert "id" in d for k, v in d.items(): - if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: - assert isinstance(v, list) - d[k] = "###".join(v) + if self.field_keyword(k): + if isinstance(v, list): + d[k] = "###".join(v) + else: + d[k] = v elif re.search(r"_feas$", k): d[k] = json.dumps(v) elif k == 'kb_id': @@ -495,6 +503,8 @@ class InfinityConnection(DocStoreConnection): elif k in ["page_num_int", "top_int"]: assert isinstance(v, list) d[k] = "_".join(f"{num:08x}" for num in v) + else: + d[k] = v for n, vs in embedding_clmns: if n in d: @@ -525,13 +535,13 @@ class InfinityConnection(DocStoreConnection): # del condition["exists"] filter = equivalent_condition_to_str(condition, table_instance) for k, v in list(newValue.items()): - if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: - assert isinstance(v, list) - newValue[k] = "###".join(v) + if self.field_keyword(k): + if isinstance(v, list): + newValue[k] = "###".join(v) + else: + newValue[k] = v elif re.search(r"_feas$", k): newValue[k] = json.dumps(v) - elif k.endswith("_kwd") and isinstance(v, list): - newValue[k] = " ".join(v) elif k == 'kb_id': if isinstance(newValue[k], list): newValue[k] = newValue[k][0] # since d[k] is a list, but we need a str @@ -546,6 +556,8 @@ class InfinityConnection(DocStoreConnection): del newValue[k] if v in [PAGERANK_FLD]: newValue[v] = 0 + else: + newValue[k] = v logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.") table_instance.update(filter, newValue) @@ -600,7 +612,7 @@ class InfinityConnection(DocStoreConnection): for column in res2.columns: k = column.lower() - if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: + if self.field_keyword(k): res2[column] = res2[column].apply(lambda v:[kwd for kwd in v.split("###") if kwd]) elif k == "position_int": def to_position_int(v): diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index 790df690..0f3799cb 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -319,9 +319,3 @@ class RedisDistributedLock: def release(self): return self.lock.release() - - def __enter__(self): - self.acquire() - - def __exit__(self, exception_type, exception_value, exception_traceback): - self.release() \ No newline at end of file diff --git a/uv.lock b/uv.lock index f44cfee5..9644e92f 100644 --- a/uv.lock +++ b/uv.lock @@ -1100,6 +1100,27 @@ version = "0.8.2" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" } +[[package]] +name = "debugpy" +version = "1.8.13" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/51/d4/f35f539e11c9344652f362c22413ec5078f677ac71229dc9b4f6f85ccaa3/debugpy-1.8.13.tar.gz", hash = "sha256:837e7bef95bdefba426ae38b9a94821ebdc5bea55627879cd48165c90b9e50ce" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/3f/32/901c7204cceb3262fdf38f4c25c9a46372c11661e8490e9ea702bc4ff448/debugpy-1.8.13-cp310-cp310-macosx_14_0_x86_64.whl", hash = 
"sha256:06859f68e817966723ffe046b896b1bd75c665996a77313370336ee9e1de3e90" }, + { url = "https://mirrors.aliyun.com/pypi/packages/95/10/77fe746851c8d84838a807da60c7bd0ac8627a6107d6917dd3293bf8628c/debugpy-1.8.13-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb56c2db69fb8df3168bc857d7b7d2494fed295dfdbde9a45f27b4b152f37520" }, + { url = "https://mirrors.aliyun.com/pypi/packages/a1/ef/28f8db2070e453dda0e49b356e339d0b4e1d38058d4c4ea9e88cdc8ee8e7/debugpy-1.8.13-cp310-cp310-win32.whl", hash = "sha256:46abe0b821cad751fc1fb9f860fb2e68d75e2c5d360986d0136cd1db8cad4428" }, + { url = "https://mirrors.aliyun.com/pypi/packages/89/16/1d53a80caf5862627d3eaffb217d4079d7e4a1df6729a2d5153733661efd/debugpy-1.8.13-cp310-cp310-win_amd64.whl", hash = "sha256:dc7b77f5d32674686a5f06955e4b18c0e41fb5a605f5b33cf225790f114cfeec" }, + { url = "https://mirrors.aliyun.com/pypi/packages/31/90/dd2fcad8364f0964f476537481985198ce6e879760281ad1cec289f1aa71/debugpy-1.8.13-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:eee02b2ed52a563126c97bf04194af48f2fe1f68bb522a312b05935798e922ff" }, + { url = "https://mirrors.aliyun.com/pypi/packages/5c/c9/06ff65f15eb30dbdafd45d1575770b842ce3869ad5580a77f4e5590f1be7/debugpy-1.8.13-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4caca674206e97c85c034c1efab4483f33971d4e02e73081265ecb612af65377" }, + { url = "https://mirrors.aliyun.com/pypi/packages/3b/49/798a4092bde16a4650f17ac5f2301d4d37e1972d65462fb25c80a83b4790/debugpy-1.8.13-cp311-cp311-win32.whl", hash = "sha256:7d9a05efc6973b5aaf076d779cf3a6bbb1199e059a17738a2aa9d27a53bcc888" }, + { url = "https://mirrors.aliyun.com/pypi/packages/cd/d5/3684d7561c8ba2797305cf8259619acccb8d6ebe2117bb33a6897c235eee/debugpy-1.8.13-cp311-cp311-win_amd64.whl", hash = "sha256:62f9b4a861c256f37e163ada8cf5a81f4c8d5148fc17ee31fb46813bd658cdcc" }, + { url = "https://mirrors.aliyun.com/pypi/packages/79/ad/dff929b6b5403feaab0af0e5bb460fd723f9c62538b718a9af819b8fff20/debugpy-1.8.13-cp312-cp312-macosx_14_0_universal2.whl", hash = "sha256:2b8de94c5c78aa0d0ed79023eb27c7c56a64c68217d881bee2ffbcb13951d0c1" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d6/4f/b7d42e6679f0bb525888c278b0c0d2b6dff26ed42795230bb46eaae4f9b3/debugpy-1.8.13-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:887d54276cefbe7290a754424b077e41efa405a3e07122d8897de54709dbe522" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ec/18/d9b3e88e85d41f68f77235112adc31012a784e45a3fcdbb039777d570a0f/debugpy-1.8.13-cp312-cp312-win32.whl", hash = "sha256:3872ce5453b17837ef47fb9f3edc25085ff998ce63543f45ba7af41e7f7d370f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c9/f7/0df18a4f530ed3cc06f0060f548efe9e3316102101e311739d906f5650be/debugpy-1.8.13-cp312-cp312-win_amd64.whl", hash = "sha256:63ca7670563c320503fea26ac688988d9d6b9c6a12abc8a8cf2e7dd8e5f6b6ea" }, + { url = "https://mirrors.aliyun.com/pypi/packages/37/4f/0b65410a08b6452bfd3f7ed6f3610f1a31fb127f46836e82d31797065dcb/debugpy-1.8.13-py2.py3-none-any.whl", hash = "sha256:d4ba115cdd0e3a70942bd562adba9ec8c651fe69ddde2298a1be296fc331906f" }, +] + [[package]] name = "decorator" version = "5.2.1" @@ -1375,17 +1396,17 @@ name = "fastembed-gpu" version = "0.3.6" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ - { name = "huggingface-hub" }, - { name = "loguru" }, - { name = "mmh3" }, - { name = "numpy" }, - 
{ name = "onnxruntime-gpu" }, - { name = "pillow" }, - { name = "pystemmer" }, - { name = "requests" }, - { name = "snowballstemmer" }, - { name = "tokenizers" }, - { name = "tqdm" }, + { name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "loguru", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "mmh3", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "onnxruntime-gpu", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "pystemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "snowballstemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "tokenizers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "tqdm", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" } wheels = [ @@ -3531,12 +3552,12 @@ name = "onnxruntime-gpu" version = "1.19.2" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ - { name = "coloredlogs" }, - { name = "flatbuffers" }, - { name = "numpy" }, - { name = "packaging" }, - { name = "protobuf" }, - { name = "sympy" }, + { name = "coloredlogs", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = 
"https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" }, @@ -4746,6 +4767,7 @@ dependencies = [ { name = "crawl4ai" }, { name = "dashscope" }, { name = "datrie" }, + { name = "debugpy" }, { name = "deepl" }, { name = "demjson3" }, { name = "discord-py" }, @@ -4877,6 +4899,7 @@ requires-dist = [ { name = "crawl4ai", specifier = "==0.3.8" }, { name = "dashscope", specifier = "==1.20.11" }, { name = "datrie", specifier = "==0.8.2" }, + { name = "debugpy", specifier = ">=1.8.13" }, { name = "deepl", specifier = "==1.18.0" }, { name = "demjson3", specifier = "==3.0.6" }, { name = "discord-py", specifier = "==2.3.2" },