diff --git a/api/ragflow_server.py b/api/ragflow_server.py
index ac05dadd..8f2697f2 100644
--- a/api/ragflow_server.py
+++ b/api/ragflow_server.py
@@ -42,16 +42,23 @@ from api.db.init_data import init_web_data
 from api.versions import get_ragflow_version
 from api.utils import show_configs
 from rag.settings import print_rag_settings
+from rag.utils.redis_conn import RedisDistributedLock
 
 stop_event = threading.Event()
 
 def update_progress():
+    redis_lock = RedisDistributedLock("update_progress", timeout=60)
     while not stop_event.is_set():
+        if not redis_lock.acquire():
+            stop_event.wait(6)
+            continue
         try:
             DocumentService.update_progress()
             stop_event.wait(6)
         except Exception:
             logging.exception("update_progress exception")
+        finally:
+            redis_lock.release()
 
 def signal_handler(sig, frame):
     logging.info("Received interrupt signal, shutting down...")
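A note on the ragflow_server.py hunk above: `update_progress` used to run in every server process; the Redis lock turns it into a cluster-wide singleton, so whichever instance holds the key does the sweep and the others skip the cycle. A minimal sketch of that acquire/skip/release pattern, using an in-process `threading.Lock` as a stand-in for the Redis lock (the class and names below are illustrative, not RAGFlow APIs):

```python
import threading
import time

class FakeDistributedLock:
    """Illustrative stand-in for RedisDistributedLock: one shared, non-blocking lock."""
    _held = threading.Lock()

    def acquire(self) -> bool:
        return self._held.acquire(blocking=False)

    def release(self) -> None:
        self._held.release()

def progress_loop(worker: str, stop: threading.Event) -> None:
    lock = FakeDistributedLock()
    while not stop.is_set():
        if not lock.acquire():   # someone else owns the sweep this cycle
            stop.wait(0.1)
            continue
        try:
            print(f"{worker} runs the progress sweep")
            stop.wait(0.1)
        finally:
            lock.release()       # released only after a successful acquire

stop = threading.Event()
workers = [threading.Thread(target=progress_loop, args=(f"w{i}", stop)) for i in range(3)]
for w in workers:
    w.start()
time.sleep(0.5)
stop.set()
for w in workers:
    w.join()
```

The key design point is that `release()` sits behind a successful `acquire()`; releasing a lock you never acquired raises with real Redis client locks.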
diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py
index 81e597e4..ff3c81ad 100644
--- a/graphrag/general/extractor.py
+++ b/graphrag/general/extractor.py
@@ -93,7 +93,7 @@ class Extractor:
         return dict(maybe_nodes), dict(maybe_edges)
 
     async def __call__(
-        self, chunks: list[tuple[str, str]],
+        self, doc_id: str, chunks: list[str],
         callback: Callable | None = None
     ):
 
@@ -101,9 +101,9 @@ class Extractor:
         start_ts = trio.current_time()
         out_results = []
         async with trio.open_nursery() as nursery:
-            for i, (cid, ck) in enumerate(chunks):
+            for i, ck in enumerate(chunks):
                 ck = truncate(ck, int(self._llm.max_length*0.8))
-                nursery.start_soon(lambda: self._process_single_content((cid, ck), i, len(chunks), out_results))
+                nursery.start_soon(lambda: self._process_single_content((doc_id, ck), i, len(chunks), out_results))
 
         maybe_nodes = defaultdict(list)
         maybe_edges = defaultdict(list)
@@ -241,10 +241,13 @@ class Extractor:
     ) -> str:
         summary_max_tokens = 512
         use_description = truncate(description, summary_max_tokens)
+        description_list = use_description.split(GRAPH_FIELD_SEP)
+        if len(description_list) <= 12:
+            return use_description
         prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
         context_base = dict(
             entity_name=entity_or_relation_name,
-            description_list=use_description.split(GRAPH_FIELD_SEP),
+            description_list=description_list,
             language=self._language,
         )
         use_prompt = prompt_template.format(**context_base)
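A note on the extractor.py hunk above: `_handle_entity_relation_summary` now short-circuits when an entity or relation has accumulated 12 or fewer description fragments, and only pays for an LLM call beyond that. (The original patch had a trailing comma after the `split(...)`, which made `description_list` a one-element tuple and defeated the check; it is removed here.) A rough sketch of the control flow, assuming a `<SEP>`-style separator and a stub summarizer (RAGFlow's actual `GRAPH_FIELD_SEP` constant and prompt differ):

```python
GRAPH_FIELD_SEP = "<SEP>"  # assumption: stand-in for RAGFlow's real separator constant

def maybe_summarize(name: str, description: str, llm_summarize) -> str:
    """Skip the LLM while an entity has 12 or fewer accumulated descriptions."""
    fragments = description.split(GRAPH_FIELD_SEP)
    if len(fragments) <= 12:
        return description          # cheap path: no model call at all
    return llm_summarize(name, fragments)

# Usage with a stub in place of the chat model:
def stub(name, fragments):
    return f"{name}: merged summary of {len(fragments)} descriptions"

short = GRAPH_FIELD_SEP.join(["desc"] * 3)
long = GRAPH_FIELD_SEP.join(["desc"] * 30)
print(maybe_summarize("Entity A", short, stub))  # returned verbatim
print(maybe_summarize("Entity A", long, stub))   # goes through the summarizer
```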
diff --git a/graphrag/general/index.py b/graphrag/general/index.py
index 130bf10d..179478c9 100644
--- a/graphrag/general/index.py
+++ b/graphrag/general/index.py
@@ -15,196 +15,353 @@
 #
 import json
 import logging
-from functools import reduce, partial
+from functools import partial
 import networkx as nx
 import trio
 
 from api import settings
+from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
+from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
 from graphrag.general.community_reports_extractor import CommunityReportsExtractor
 from graphrag.entity_resolution import EntityResolution
 from graphrag.general.extractor import Extractor
-from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
-from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
-    chunk_id, update_nodes_pagerank_nhop_neighbour
+from graphrag.utils import (
+    graph_merge,
+    set_entity,
+    get_relation,
+    set_relation,
+    get_entity,
+    get_graph,
+    set_graph,
+    chunk_id,
+    update_nodes_pagerank_nhop_neighbour,
+    does_graph_contains,
+    get_graph_doc_ids,
+)
 from rag.nlp import rag_tokenizer, search
-from rag.utils.redis_conn import RedisDistributedLock
+from rag.utils.redis_conn import REDIS_CONN
 
 
-class Dealer:
-    def __init__(self,
-                 extractor: Extractor,
-                 tenant_id: str,
-                 kb_id: str,
-                 llm_bdl,
-                 chunks: list[tuple[str, str]],
-                 language,
-                 entity_types=DEFAULT_ENTITY_TYPES,
-                 embed_bdl=None,
-                 callback=None
-                 ):
-        self.tenant_id = tenant_id
-        self.kb_id = kb_id
-        self.chunks = chunks
-        self.llm_bdl = llm_bdl
-        self.embed_bdl = embed_bdl
-        self.ext = extractor(self.llm_bdl, language=language,
-                             entity_types=entity_types,
-                             get_entity=partial(get_entity, tenant_id, kb_id),
-                             set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
-                             get_relation=partial(get_relation, tenant_id, kb_id),
-                             set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
-                             )
-        self.graph = nx.Graph()
-        self.callback = callback
+def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool:
+    key = f"graphrag:{tenant_id}:{kb_id}"
+    ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
+    if not ok:
+        raise Exception(f"Failed to set {key} to {doc_id}")
 
-    async def __call__(self):
-        docids = list(set([docid for docid, _ in self.chunks]))
-        ents, rels = await self.ext(self.chunks, self.callback)
-        for en in ents:
-            self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
-        for rel in rels:
-            self.graph.add_edge(
-                rel["src_id"],
-                rel["tgt_id"],
-                weight=rel["weight"],
-                #description=rel["description"]
+
+def graphrag_task_get(tenant_id, kb_id) -> str | None:
+    key = f"graphrag:{tenant_id}:{kb_id}"
+    doc_id = REDIS_CONN.get(key)
+    return doc_id
+
+
+async def run_graphrag(
+    row: dict,
+    language,
+    with_resolution: bool,
+    with_community: bool,
+    chat_model,
+    embedding_model,
+    callback,
+):
+    start = trio.current_time()
+    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
+    chunks = []
+    for d in settings.retrievaler.chunk_list(
+        doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
+    ):
+        chunks.append(d["content_with_weight"])
+
+    graph, doc_ids = await update_graph(
+        LightKGExt
+        if row["parser_config"]["graphrag"]["method"] != "general"
+        else GeneralKGExt,
+        tenant_id,
+        kb_id,
+        doc_id,
+        chunks,
+        language,
+        row["parser_config"]["graphrag"]["entity_types"],
+        chat_model,
+        embedding_model,
+        callback,
+    )
+    if not graph:
+        return
+    if with_resolution or with_community:
+        graphrag_task_set(tenant_id, kb_id, doc_id)
+    if with_resolution:
+        await resolve_entities(
+            graph,
+            doc_ids,
+            tenant_id,
+            kb_id,
+            doc_id,
+            chat_model,
+            embedding_model,
+            callback,
+        )
+    if with_community:
+        await extract_community(
+            graph,
+            doc_ids,
+            tenant_id,
+            kb_id,
+            doc_id,
+            chat_model,
+            embedding_model,
+            callback,
+        )
+    now = trio.current_time()
+    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
+    return
+
+
+async def update_graph(
+    extractor: Extractor,
+    tenant_id: str,
+    kb_id: str,
+    doc_id: str,
+    chunks: list[str],
+    language,
+    entity_types,
+    llm_bdl,
+    embed_bdl,
+    callback,
+):
+    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
+    if contains:
+        callback(msg=f"Graph already contains {doc_id}, cancel myself")
+        return None, None
+    start = trio.current_time()
+    ext = extractor(
+        llm_bdl,
+        language=language,
+        entity_types=entity_types,
+        get_entity=partial(get_entity, tenant_id, kb_id),
+        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
+        get_relation=partial(get_relation, tenant_id, kb_id),
+        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
+    )
+    ents, rels = await ext(doc_id, chunks, callback)
+    subgraph = nx.Graph()
+    for en in ents:
+        subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])
+
+    for rel in rels:
+        subgraph.add_edge(
+            rel["src_id"],
+            rel["tgt_id"],
+            weight=rel["weight"],
+            # description=rel["description"]
+        )
+    # TODO: infinity doesn't support array search
+    chunk = {
+        "content_with_weight": json.dumps(
+            nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
+        ),
+        "knowledge_graph_kwd": "subgraph",
+        "kb_id": kb_id,
+        "source_id": [doc_id],
+        "available_int": 0,
+        "removed_kwd": "N",
+    }
+    cid = chunk_id(chunk)
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.insert(
+            [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
+        )
+    )
+    now = trio.current_time()
+    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
+    start = now
+
+    while True:
+        new_graph = subgraph
+        now_docids = set([doc_id])
+        old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
+        if old_graph is not None:
+            logging.info("Merge with an existing graph...................")
+            new_graph = graph_merge(old_graph, subgraph)
+        await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
+        if old_doc_ids:
+            for old_doc_id in old_doc_ids:
+                now_docids.add(old_doc_id)
+        old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
+        delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
+        if delta_doc_ids:
+            callback(
+                msg="The global graph has changed during merging, try again"
             )
-
-        with RedisDistributedLock(self.kb_id, 60*60):
-            old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
-            if old_graph is not None:
-                logging.info("Merge with an exiting graph...................")
-                self.graph = reduce(graph_merge, [old_graph, self.graph])
-            update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
-            if old_doc_ids:
-                docids.extend(old_doc_ids)
-                docids = list(set(docids))
-            set_graph(self.tenant_id, self.kb_id, self.graph, docids)
+            await trio.sleep(1)
+            continue
+        break
+    await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
+    now = trio.current_time()
+    callback(
+        msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
+    )
+    return new_graph, now_docids
 
 
-class WithResolution(Dealer):
-    def __init__(self,
-                 tenant_id: str,
-                 kb_id: str,
-                 llm_bdl,
-                 embed_bdl=None,
-                 callback=None
-                 ):
-        self.tenant_id = tenant_id
-        self.kb_id = kb_id
-        self.llm_bdl = llm_bdl
-        self.embed_bdl = embed_bdl
-        self.callback = callback
 
-    async def __call__(self):
-        with RedisDistributedLock(self.kb_id, 60*60):
-            self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
-            if not self.graph:
-                logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
-                if self.callback:
-                    self.callback(-1, msg="Faild to fetch the graph.")
-                return
+async def resolve_entities(
+    graph,
+    doc_ids,
+    tenant_id: str,
+    kb_id: str,
+    doc_id: str,
+    llm_bdl,
+    embed_bdl,
+    callback,
+):
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    start = trio.current_time()
+    er = EntityResolution(
+        llm_bdl,
+        get_entity=partial(get_entity, tenant_id, kb_id),
+        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
+        get_relation=partial(get_relation, tenant_id, kb_id),
+        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
+    )
+    reso = await er(graph)
+    graph = reso.graph
+    callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
+    await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
+    callback(msg="Graph resolution updated pagerank.")
 
-            if self.callback:
-                self.callback(msg="Fetch the existing graph.")
-            er = EntityResolution(self.llm_bdl,
-                                  get_entity=partial(get_entity, self.tenant_id, self.kb_id),
-                                  set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
-                                  get_relation=partial(get_relation, self.tenant_id, self.kb_id),
-                                  set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
-            reso = await er(self.graph)
-            self.graph = reso.graph
-            logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-            if self.callback:
-                self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-            await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
-            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    await set_graph(tenant_id, kb_id, graph, doc_ids)
 
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "relation",
-                "kb_id": self.kb_id,
-                "from_entity_kwd": reso.removed_entities
-            }, search.index_name(self.tenant_id), self.kb_id))
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "relation",
-                "kb_id": self.kb_id,
-                "to_entity_kwd": reso.removed_entities
-            }, search.index_name(self.tenant_id), self.kb_id))
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "entity",
-                "kb_id": self.kb_id,
-                "entity_kwd": reso.removed_entities
-            }, search.index_name(self.tenant_id), self.kb_id))
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {
+                "knowledge_graph_kwd": "relation",
+                "kb_id": kb_id,
+                "from_entity_kwd": reso.removed_entities,
+            },
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {
+                "knowledge_graph_kwd": "relation",
+                "kb_id": kb_id,
+                "to_entity_kwd": reso.removed_entities,
+            },
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {
+                "knowledge_graph_kwd": "entity",
+                "kb_id": kb_id,
+                "entity_kwd": reso.removed_entities,
+            },
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    now = trio.current_time()
+    callback(msg=f"Graph resolution done in {now - start:.2f}s.")
 
 
-class WithCommunity(Dealer):
-    def __init__(self,
-                 tenant_id: str,
-                 kb_id: str,
-                 llm_bdl,
-                 embed_bdl=None,
-                 callback=None
-                 ):
+async def extract_community(
+    graph,
+    doc_ids,
+    tenant_id: str,
+    kb_id: str,
+    doc_id: str,
+    llm_bdl,
+    embed_bdl,
+    callback,
+):
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    start = trio.current_time()
+    ext = CommunityReportsExtractor(
+        llm_bdl,
+        get_entity=partial(get_entity, tenant_id, kb_id),
+        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
+        get_relation=partial(get_relation, tenant_id, kb_id),
+        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
+    )
+    cr = await ext(graph, callback=callback)
+    community_structure = cr.structured_output
+    community_reports = cr.output
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    await set_graph(tenant_id, kb_id, graph, doc_ids)
 
-        self.tenant_id = tenant_id
-        self.kb_id = kb_id
-        self.community_structure = None
-        self.community_reports = None
-        self.llm_bdl = llm_bdl
-        self.embed_bdl = embed_bdl
-        self.callback = callback
 
-    async def __call__(self):
-        with RedisDistributedLock(self.kb_id, 60*60):
-            self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
-            if not self.graph:
-                logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
-                if self.callback:
-                    self.callback(-1, msg="Faild to fetch the graph.")
-                return
-            if self.callback:
-                self.callback(msg="Fetch the existing graph.")
-
-            cr = CommunityReportsExtractor(self.llm_bdl,
-                                           get_entity=partial(get_entity, self.tenant_id, self.kb_id),
-                                           set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
-                                           get_relation=partial(get_relation, self.tenant_id, self.kb_id),
-                                           set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
-            cr = await cr(self.graph, callback=self.callback)
-            self.community_structure = cr.structured_output
-            self.community_reports = cr.output
-            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
-
-            if self.callback:
-                self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
-
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
+    now = trio.current_time()
+    callback(
+        msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
+    )
+    start = now
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    for stru, rep in zip(community_structure, community_reports):
+        obj = {
+            "report": rep,
+            "evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
+        }
+        chunk = {
+            "docnm_kwd": stru["title"],
+            "title_tks": rag_tokenizer.tokenize(stru["title"]),
+            "content_with_weight": json.dumps(obj, ensure_ascii=False),
+            "content_ltks": rag_tokenizer.tokenize(
+                obj["report"] + " " + obj["evidences"]
+            ),
             "knowledge_graph_kwd": "community_report",
-                "kb_id": self.kb_id
-            }, search.index_name(self.tenant_id), self.kb_id))
-
-            for stru, rep in zip(self.community_structure, self.community_reports):
-                obj = {
-                    "report": rep,
-                    "evidences": "\n".join([f["explanation"] for f in stru["findings"]])
-                }
-                chunk = {
-                    "docnm_kwd": stru["title"],
-                    "title_tks": rag_tokenizer.tokenize(stru["title"]),
-                    "content_with_weight": json.dumps(obj, ensure_ascii=False),
-                    "content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]),
-                    "knowledge_graph_kwd": "community_report",
-                    "weight_flt": stru["weight"],
-                    "entities_kwd": stru["entities"],
-                    "important_kwd": stru["entities"],
-                    "kb_id": self.kb_id,
-                    "source_id": doc_ids,
-                    "available_int": 0
-                }
-                chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
-                #try:
-                #    ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
-                #    chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
-                #except Exception as e:
-                #    logging.exception(f"Fail to embed entity relation: {e}")
-                await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id)))
+            "weight_flt": stru["weight"],
+            "entities_kwd": stru["entities"],
+            "important_kwd": stru["entities"],
+            "kb_id": kb_id,
+            "source_id": doc_ids,
+            "available_int": 0,
+        }
+        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
+            chunk["content_ltks"]
+        )
+        # try:
+        #    ebd, _ = embed_bdl.encode([", ".join(community["entities"])])
+        #    chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
+        # except Exception as e:
+        #    logging.exception(f"Fail to embed entity relation: {e}")
+        await trio.to_thread.run_sync(
+            lambda: settings.docStoreConn.insert(
+                [{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)
+            )
+        )
+    now = trio.current_time()
+    callback(
+        msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
+    )
+    return community_structure, community_reports
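A note on the index.py rewrite above: the old `Dealer` held a KB-wide Redis lock for the whole merge, while `update_graph` merges optimistically. It re-reads the global graph, merges, then compares the stored doc-id set against the snapshot it read, and retries if another task slipped in. A toy sketch of that check-then-retry shape (an in-memory dict stands in for the doc store; this mirrors the pattern but is not a race-free lock replacement):

```python
import time

# In-memory stand-in for the stored KB-level graph document.
store = {"nodes": set(), "doc_ids": set()}

def merge_subgraph(doc_id: str, subgraph_nodes: set) -> None:
    """Check-then-retry merge: redo the merge if the stored doc-id set moved."""
    while True:
        snapshot_nodes = set(store["nodes"])
        snapshot_doc_ids = set(store["doc_ids"])
        merged = snapshot_nodes | subgraph_nodes      # expensive merge/pagerank here
        if set(store["doc_ids"]) - snapshot_doc_ids:  # someone else merged meanwhile
            time.sleep(1)
            continue
        store["nodes"] = merged
        store["doc_ids"] = snapshot_doc_ids | {doc_id}
        break

merge_subgraph("doc_1", {"a", "b"})
merge_subgraph("doc_2", {"b", "c"})
print(store)
```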
diff --git a/graphrag/general/smoke.py b/graphrag/general/smoke.py
index f8f10e74..3f282fb0 100644
--- a/graphrag/general/smoke.py
+++ b/graphrag/general/smoke.py
@@ -16,7 +16,7 @@
 
 import argparse
 import json
-
+import logging
 import networkx as nx
 import trio
 
@@ -26,42 +26,87 @@ from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
-from graphrag.general.index import WithCommunity, Dealer, WithResolution
-from graphrag.light.graph_extractor import GraphExtractor
-from rag.utils.redis_conn import RedisDistributedLock
+from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES, GraphExtractor
+from graphrag.general.index import update_graph, graphrag_task_set, resolve_entities, extract_community
 
 settings.init_settings()
 
-if __name__ == "__main__":
+
+def callback(prog=None, msg="Processing..."):
+    logging.info(msg)
+
+
+async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
+    parser.add_argument(
+        "-t",
+        "--tenant_id",
+        default=False,
+        help="Tenant ID",
+        action="store",
+        required=True,
+    )
+    parser.add_argument(
+        "-d",
+        "--doc_id",
+        default=False,
+        help="Document ID",
+        action="store",
+        required=True,
+    )
    args = parser.parse_args()
 
     e, doc = DocumentService.get_by_id(args.doc_id)
     if not e:
         raise LookupError("Document not found.")
     kb_id = doc.kb_id
 
-    chunks = [d["content_with_weight"] for d in
-              settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
-                                              fields=["content_with_weight"])]
-    chunks = [("x", c) for c in chunks]
-
-    RedisDistributedLock.clean_lock(kb_id)
+    chunks = [
+        d["content_with_weight"]
+        for d in settings.retrievaler.chunk_list(
+            args.doc_id,
+            args.tenant_id,
+            [kb_id],
+            max_count=6,
+            fields=["content_with_weight"],
+        )
+    ]
 
     _, tenant = TenantService.get_by_id(args.tenant_id)
     llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
 
-    dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
-    trio.run(dealer())
-    print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
+    graph, doc_ids = await update_graph(
+        GraphExtractor,
+        args.tenant_id,
+        kb_id,
+        args.doc_id,
+        chunks,
+        "English",
+        DEFAULT_ENTITY_TYPES,
+        llm_bdl,
+        embed_bdl,
+        callback,
+    )
+    print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
 
-    dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
-    trio.run(dealer())
-    dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
-    trio.run(dealer())
+    graphrag_task_set(args.tenant_id, kb_id, args.doc_id)
+    await resolve_entities(
+        graph, doc_ids, args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
+    )
+    community_structure, community_reports = await extract_community(
+        graph, doc_ids, args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
+    )
 
-    print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
-    print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))
+    print(
+        "------------------ COMMUNITY STRUCTURE--------------------\n",
+        json.dumps(community_structure, ensure_ascii=False, indent=2),
+    )
+    print(
+        "------------------ COMMUNITY REPORTS----------------------\n",
+        community_reports,
+    )
+
+
+if __name__ == "__main__":
+    trio.run(main)
diff --git a/graphrag/light/smoke.py b/graphrag/light/smoke.py
index 20dc5615..504f09ce 100644
--- a/graphrag/light/smoke.py
+++ b/graphrag/light/smoke.py
@@ -18,22 +18,43 @@
 
 import argparse
 import json
 from api import settings
 import networkx as nx
+import logging
+import trio
 
 from api.db import LLMType
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
-from graphrag.general.index import Dealer
+from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
+from graphrag.general.index import update_graph
 from graphrag.light.graph_extractor import GraphExtractor
-from rag.utils.redis_conn import RedisDistributedLock
 
 settings.init_settings()
 
-if __name__ == "__main__":
+
+def callback(prog=None, msg="Processing..."):
+    logging.info(msg)
+
+
+async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
+    parser.add_argument(
+        "-t",
+        "--tenant_id",
+        default=False,
+        help="Tenant ID",
+        action="store",
+        required=True,
+    )
+    parser.add_argument(
+        "-d",
+        "--doc_id",
+        default=False,
+        help="Document ID",
+        action="store",
+        required=True,
+    )
     args = parser.parse_args()
 
     e, doc = DocumentService.get_by_id(args.doc_id)
@@ -41,18 +62,37 @@ if __name__ == "__main__":
         raise LookupError("Document not found.")
     kb_id = doc.kb_id
 
-    chunks = [d["content_with_weight"] for d in
-              settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
-                                              fields=["content_with_weight"])]
-    chunks = [("x", c) for c in chunks]
-
-    RedisDistributedLock.clean_lock(kb_id)
+    chunks = [
+        d["content_with_weight"]
+        for d in settings.retrievaler.chunk_list(
+            args.doc_id,
+            args.tenant_id,
+            [kb_id],
+            max_count=6,
+            fields=["content_with_weight"],
+        )
+    ]
 
     _, tenant = TenantService.get_by_id(args.tenant_id)
     llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
 
-    dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
+    graph, doc_ids = await update_graph(
+        GraphExtractor,
+        args.tenant_id,
+        kb_id,
+        args.doc_id,
+        chunks,
+        "English",
+        DEFAULT_ENTITY_TYPES,
+        llm_bdl,
+        embed_bdl,
+        callback,
+    )
 
-    print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
+    print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
+
+
+if __name__ == "__main__":
+    trio.run(main)
diff --git a/graphrag/utils.py b/graphrag/utils.py
index 45cb7c79..29ed9425 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -352,25 +352,57 @@ def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
         chunk["q_%d_vec" % len(ebd)] = ebd
     settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
 
+async def does_graph_contains(tenant_id, kb_id, doc_id):
+    # Get doc_ids of graph
+    fields = ["source_id"]
+    condition = {
+        "knowledge_graph_kwd": ["graph"],
+        "removed_kwd": "N",
+    }
+    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
+    fields2 = settings.docStoreConn.getFields(res, fields)
+    graph_doc_ids = set()
+    for cid in fields2.keys():
+        graph_doc_ids = set(fields2[cid]["source_id"])
+    return doc_id in graph_doc_ids
 
-def get_graph(tenant_id, kb_id):
+
+async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
+    conds = {
+        "fields": ["source_id"],
+        "removed_kwd": "N",
+        "size": 1,
+        "knowledge_graph_kwd": ["graph"]
+    }
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
+    doc_ids = []
+    if res.total == 0:
+        return doc_ids
+    for id in res.ids:
+        doc_ids = res.field[id]["source_id"]
+    return doc_ids
+
+
+async def get_graph(tenant_id, kb_id):
     conds = {
         "fields": ["content_with_weight", "source_id"],
         "removed_kwd": "N",
         "size": 1,
         "knowledge_graph_kwd": ["graph"]
     }
-    res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
+    if res.total == 0:
+        return None, []
     for id in res.ids:
         try:
             return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
                    res.field[id]["source_id"]
         except Exception:
             continue
-    return rebuild_graph(tenant_id, kb_id)
+    result = await rebuild_graph(tenant_id, kb_id)
+    return result
 
 
-def set_graph(tenant_id, kb_id, graph, docids):
+async def set_graph(tenant_id, kb_id, graph, docids):
     chunk = {
         "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
                                           indent=2),
@@ -379,13 +411,13 @@ def set_graph(tenant_id, kb_id, graph, docids):
         "source_id": list(docids),
         "available_int": 0,
         "removed_kwd": "N"
-    }
-    res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
+    }
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]))
     if res.ids:
-        settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
-                                     search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
+                                     search.index_name(tenant_id), kb_id))
     else:
-        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
 
 
 def is_continuous_subsequence(subseq, seq):
@@ -430,7 +462,7 @@ def merge_tuples(list1, list2):
     return result
 
 
-def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
+async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
     def n_neighbor(id):
         nonlocal graph, n_hop
         count = 0
@@ -460,10 +492,10 @@ def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
     for n, p in pr.items():
         graph.nodes[n]["pagerank"] = p
         try:
-            settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
+            await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
                                          {"rank_flt": p,
-                                          "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
-                                         search.index_name(tenant_id), kb_id)
+                                          "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
+                                         search.index_name(tenant_id), kb_id))
         except Exception as e:
             logging.exception(e)
 
@@ -480,21 +512,21 @@ def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
         "knowledge_graph_kwd": "ty2ents",
         "available_int": 0
     }
-    res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
-                                      search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
+                                      search.index_name(tenant_id), [kb_id]))
     if res.ids:
-        settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
                                      chunk,
-                                     search.index_name(tenant_id), kb_id)
+                                     search.index_name(tenant_id), kb_id))
     else:
-        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
 
 
-def get_entity_type2sampels(idxnms, kb_ids: list):
-    es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
-                                          "size": 10000,
-                                          "fields": ["content_with_weight"]},
-                                         idxnms, kb_ids)
+async def get_entity_type2sampels(idxnms, kb_ids: list):
+    es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
+                                          "size": 10000,
+                                          "fields": ["content_with_weight"]},
+                                         idxnms, kb_ids))
 
     res = defaultdict(list)
     for id in es_res.ids:
@@ -522,18 +554,18 @@ def flat_uniq_list(arr, key):
     return list(set(res))
 
 
-def rebuild_graph(tenant_id, kb_id):
+async def rebuild_graph(tenant_id, kb_id):
     graph = nx.Graph()
     src_ids = []
     flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
     bs = 256
     for i in range(0, 39*bs, bs):
-        es_res = settings.docStoreConn.search(flds, [],
-                                              {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
-                                              [],
-                                              OrderByExpr(),
-                                              i, bs, search.index_name(tenant_id), [kb_id]
-                                              )
+        es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
+                                              {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
+                                              [],
+                                              OrderByExpr(),
+                                              i, bs, search.index_name(tenant_id), [kb_id]
+                                              ))
         tot = settings.docStoreConn.getTotal(es_res)
         if tot == 0:
             return None, None
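A note on the graphrag/utils.py hunks above: every blocking `retrievaler`/`docStoreConn` call is now pushed onto a worker thread with `trio.to_thread.run_sync`, so the trio event loop stays responsive. Wrapping the call in a lambda is safe here because each `run_sync` is awaited immediately. A self-contained sketch of the same wrapping (`blocking_search` is made up for illustration):

```python
import time
import trio

def blocking_search(index: str, query: str) -> str:
    """Made-up stand-in for a blocking docStoreConn/retrievaler call."""
    time.sleep(0.2)   # blocks its worker thread, not the event loop
    return f"{index}:{query}"

async def main():
    results = []
    for q in ("alpha", "beta"):
        # run_sync is awaited immediately, so the lambda sees the current q
        res = await trio.to_thread.run_sync(lambda: blocking_search("kb_1", q))
        results.append(res)
    print(results)

trio.run(main)
```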
diff --git a/rag/raptor.py b/rag/raptor.py
index ba273578..99155e19 100644
--- a/rag/raptor.py
+++ b/rag/raptor.py
@@ -15,18 +15,25 @@
 #
 import logging
 import re
-from threading import Lock
 import umap
 import numpy as np
 from sklearn.mixture import GaussianMixture
 import trio
 
-from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
+from graphrag.utils import (
+    get_llm_cache,
+    get_embed_cache,
+    set_embed_cache,
+    set_llm_cache,
+    chat_limiter,
+)
 from rag.utils import truncate
 
 
 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
-    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1):
+    def __init__(
+        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
+    ):
         self._max_cluster = max_cluster
         self._llm_model = llm_model
         self._embd_model = embd_model
@@ -34,22 +41,24 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._prompt = prompt
         self._max_token = max_token
 
-    def _chat(self, system, history, gen_conf):
+    async def _chat(self, system, history, gen_conf):
         response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
         if response:
             return response
-        response = self._llm_model.chat(system, history, gen_conf)
+        response = await trio.to_thread.run_sync(
+            lambda: self._llm_model.chat(system, history, gen_conf)
+        )
         response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
         if response.find("**ERROR**") >= 0:
             raise Exception(response)
         set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
         return response
 
-    def _embedding_encode(self, txt):
+    async def _embedding_encode(self, txt):
         response = get_embed_cache(self._embd_model.llm_name, txt)
         if response is not None:
             return response
-        embds, _ = self._embd_model.encode([txt])
+        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
         if len(embds) < 1 or len(embds[0]) < 1:
             raise Exception("Embedding error: ")
         embds = embds[0]
@@ -74,36 +83,48 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             return []
         chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
 
-        async def summarize(ck_idx, lock):
+        async def summarize(ck_idx: list[int]):
             nonlocal chunks
-            try:
-                texts = [chunks[i][0] for i in ck_idx]
-                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
-                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
-                async with chat_limiter:
-                    cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
-                                                                           [{"role": "user",
-                                                                             "content": self._prompt.format(cluster_content=cluster_content)}],
-                                                                           {"temperature": 0.3, "max_tokens": self._max_token}
-                                                                           ))
-                cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
-                             cnt)
-                logging.debug(f"SUM: {cnt}")
-                embds, _ = self._embd_model.encode([cnt])
-                with lock:
-                    chunks.append((cnt, self._embedding_encode(cnt)))
-            except Exception as e:
-                logging.exception("summarize got exception")
-                return e
+            texts = [chunks[i][0] for i in ck_idx]
+            len_per_chunk = int(
+                (self._llm_model.max_length - self._max_token) / len(texts)
+            )
+            cluster_content = "\n".join(
+                [truncate(t, max(1, len_per_chunk)) for t in texts]
+            )
+            async with chat_limiter:
+                cnt = await self._chat(
+                    "You're a helpful assistant.",
+                    [
+                        {
+                            "role": "user",
+                            "content": self._prompt.format(
+                                cluster_content=cluster_content
+                            ),
+                        }
+                    ],
+                    {"temperature": 0.3, "max_tokens": self._max_token},
+                )
+            cnt = re.sub(
+                "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
+                "",
+                cnt,
+            )
+            logging.debug(f"SUM: {cnt}")
+            embds = await self._embedding_encode(cnt)
+            chunks.append((cnt, embds))
 
         labels = []
-        lock = Lock()
         while end - start > 1:
-            embeddings = [embd for _, embd in chunks[start: end]]
+            embeddings = [embd for _, embd in chunks[start:end]]
             if len(embeddings) == 2:
-                await summarize([start, start + 1], lock)
+                await summarize([start, start + 1])
                 if callback:
-                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                    callback(
+                        msg="Cluster one layer: {} -> {}".format(
+                            end - start, len(chunks) - end
+                        )
+                    )
                 labels.extend([0, 0])
                 layers.append((end, len(chunks)))
                 start = end
@@ -112,7 +133,9 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
 
             n_neighbors = int((len(embeddings) - 1) ** 0.8)
             reduced_embeddings = umap.UMAP(
-                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
+                n_neighbors=max(2, n_neighbors),
+                n_components=min(12, len(embeddings) - 2),
+                metric="cosine",
             ).fit_transform(embeddings)
             n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
             if n_clusters == 1:
@@ -127,18 +150,22 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             async with trio.open_nursery() as nursery:
                 for c in range(n_clusters):
                     ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
-                    if not ck_idx:
-                        continue
+                    assert len(ck_idx) > 0
                     async with chat_limiter:
-                        nursery.start_soon(lambda: summarize(ck_idx, lock))
+                        nursery.start_soon(lambda: summarize(ck_idx))
 
-            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
+            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
+                len(chunks) - end, n_clusters
+            )
             labels.extend(lbls)
             layers.append((end, len(chunks)))
             if callback:
-                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                callback(
+                    msg="Cluster one layer: {} -> {}".format(
+                        end - start, len(chunks) - end
+                    )
+                )
             start = end
             end = len(chunks)
         return chunks
-
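A note on the raptor.py hunks above: summaries are scheduled with `nursery.start_soon` inside a trio nursery, and the shared `chunks` list is appended to without the old `threading.Lock`, which is safe because trio tasks are cooperatively scheduled on a single thread. The patch passes a lambda to `start_soon`, which trio accepts; the documented convention is to pass the async function and its arguments separately, as in this sketch:

```python
import trio

async def fetch(idx: int, out: list) -> None:
    await trio.sleep(0.01 * idx)
    out.append(idx)

async def main():
    out = []
    async with trio.open_nursery() as nursery:
        for i in range(5):
            # trio's documented form: async fn plus its args;
            # the nursery instantiates fetch(i, out) itself.
            nursery.start_soon(fetch, i, out)
    # the nursery block only exits once every child task has finished
    print(sorted(out))

trio.run(main)
```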
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 0f254e92..320ac130 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -20,9 +20,7 @@ import random
 import sys
 from api.utils.log_utils import initRootLogger, get_project_base_directory
 
-from graphrag.general.index import WithCommunity, WithResolution, Dealer
-from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
-from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
+from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
 
@@ -45,6 +43,7 @@ import tracemalloc
 import resource
 import signal
 import trio
+import exceptiongroup
 import numpy as np
 from peewee import DoesNotExist
 
@@ -453,24 +452,6 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count
 
 
-async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
-    chunks = []
-    for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
-                                             fields=["content_with_weight", "doc_id"]):
-        chunks.append((d["doc_id"], d["content_with_weight"]))
-
-    dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
-                    row["tenant_id"],
-                    str(row["kb_id"]),
-                    chat_model,
-                    chunks=chunks,
-                    language=language,
-                    entity_types=row["parser_config"]["graphrag"]["entity_types"],
-                    embed_bdl=embedding_model,
-                    callback=callback)
-    await dealer()
-
-
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
@@ -526,24 +507,10 @@ async def do_handle_task(task):
             return
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
-        progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
-        if graphrag_conf.get("resolution", False):
-            start_ts = timer()
-            with_res = WithResolution(
-                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-                progress_callback
-            )
-            await with_res()
-            progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
-        if graphrag_conf.get("community", False):
-            start_ts = timer()
-            with_comm = WithCommunity(
-                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-                progress_callback
-            )
-            await with_comm()
-            progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
+        with_resolution = graphrag_conf.get("resolution", False)
+        with_community = graphrag_conf.get("community", False)
+        await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
+        progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
         return
     else:
         # Standard chunking methods
@@ -622,7 +589,11 @@ async def handle_task():
             FAILED_TASKS += 1
             CURRENT_TASKS.pop(task["id"], None)
             try:
-                set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
+                err_msg = str(e)
+                while isinstance(e, exceptiongroup.ExceptionGroup):
+                    e = e.exceptions[0]
+                    err_msg += ' -- ' + str(e)
+                set_progress(task["id"], prog=-1, msg=f"[Exception]: {err_msg}")
             except Exception:
                 pass
             logging.exception(f"handle_task got exception for task {json.dumps(task)}")
diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py
index 78162d9f..75acb483 100644
--- a/rag/utils/redis_conn.py
+++ b/rag/utils/redis_conn.py
@@ -16,13 +16,12 @@
 
 import logging
 import json
-import time
 import uuid
 
 import valkey as redis
 from rag import settings
 from rag.utils import singleton
-
+from valkey.lock import Lock
 
 class RedisMsg:
     def __init__(self, consumer, queue_name, group_name, msg_id, message):
@@ -281,29 +280,23 @@ REDIS_CONN = RedisDB()
 
 
 class RedisDistributedLock:
-    def __init__(self, lock_key, timeout=10):
+    def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
         self.lock_key = lock_key
-        self.lock_value = str(uuid.uuid4())
+        if lock_value:
+            self.lock_value = lock_value
+        else:
+            self.lock_value = str(uuid.uuid4())
         self.timeout = timeout
+        self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout)
 
-    @staticmethod
-    def clean_lock(lock_key):
-        REDIS_CONN.REDIS.delete(lock_key)
+    def acquire(self):
+        return self.lock.acquire()
 
-    def acquire_lock(self):
-        end_time = time.time() + self.timeout
-        while time.time() < end_time:
-            if REDIS_CONN.REDIS.setnx(self.lock_key, self.lock_value):
-                return True
-            time.sleep(1)
-        return False
-
-    def release_lock(self):
-        if REDIS_CONN.REDIS.get(self.lock_key) == self.lock_value:
-            REDIS_CONN.REDIS.delete(self.lock_key)
+    def release(self):
+        return self.lock.release()
 
     def __enter__(self):
-        self.acquire_lock()
+        self.acquire()
 
     def __exit__(self, exception_type, exception_value, exception_traceback):
-        self.release_lock()
\ No newline at end of file
+        self.release()
\ No newline at end of file
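A note on the redis_conn.py hunk above: the hand-rolled `setnx`/sleep loop is replaced by valkey's `Lock`, which manages the token and auto-expiry itself. A sketch of how the rewritten wrapper is intended to be used, assuming a reachable Valkey/Redis server and valkey-py's redis-py-compatible Lock API (`timeout` is the key's auto-expiry so a crashed holder cannot wedge the key forever; `blocking_timeout` caps how long `acquire()` waits before returning False):

```python
import valkey
from valkey.lock import Lock

# Assumption: a local Valkey/Redis instance; RAGFlow builds its client from settings.
client = valkey.Valkey(host="localhost", port=6379)

lock = Lock(client, "update_progress", timeout=60, blocking_timeout=1)

if lock.acquire():
    try:
        pass  # do the singleton work here
    finally:
        lock.release()  # only release what was actually acquired
else:
    pass  # another process holds the key; skip this cycle
```

Note that releasing a lock you do not own raises a LockError with this API, which is why callers such as the reworked `update_progress` loop should only reach `release()` after a successful `acquire()`.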