From 86892959a0808005362f70c24cb6e4d5002d123f Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 23 Jan 2025 17:26:20 +0800 Subject: [PATCH] Rebuild graph when it's out of time. (#4607) ### What problem does this PR solve? #4543 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring --- api/db/services/dialog_service.py | 15 ++++++++++--- graphrag/search.py | 6 +++--- graphrag/utils.py | 36 ++++++++++++++++++++++++++++++- rag/nlp/search.py | 2 +- rag/svr/task_executor.py | 6 ++++-- 5 files changed, 55 insertions(+), 10 deletions(-) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index fa03ba65f..e30bfedd8 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -17,6 +17,7 @@ import logging import binascii import os import json +import json_repair import re from collections import defaultdict from copy import deepcopy @@ -353,7 +354,7 @@ def chat(dialog, messages, stream=True, **kwargs): generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000 prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms" - return {"answer": answer, "reference": refs, "prompt": prompt} + return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt)} if stream: last_ans = "" @@ -795,5 +796,13 @@ Output: if kwd.find("**ERROR**") >= 0: raise Exception(kwd) - kwd = re.sub(r".*?\{", "{", kwd) - return json.loads(kwd) \ No newline at end of file + try: + return json_repair.loads(kwd) + except json_repair.JSONDecodeError: + try: + result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip() + result = '{' + result.split('{')[1].split('}')[0] + '}' + return json_repair.loads(result) + except Exception as e: + logging.exception(f"JSON parsing error: {result} -> {e}") + raise e diff --git a/graphrag/search.py b/graphrag/search.py index 5b37a4c99..c0cd1098c 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -251,11 +251,11 @@ class KGSearch(Dealer): break if ents: - ents = "\n-Entities-\n{}".format(pd.DataFrame(ents).to_csv()) + ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv()) else: ents = "" if relas: - relas = "\n-Relations-\n{}".format(pd.DataFrame(relas).to_csv()) + relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv()) else: relas = "" @@ -296,7 +296,7 @@ class KGSearch(Dealer): if not txts: return "" - return "\n-Community Report-\n" + "\n".join(txts) + return "\n---- Community Report ----\n" + "\n".join(txts) if __name__ == "__main__": diff --git a/graphrag/utils.py b/graphrag/utils.py index 24deb6b69..e0227aeab 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -23,6 +23,7 @@ from networkx.readwrite import json_graph from api import settings from rag.nlp import search, rag_tokenizer +from rag.utils.doc_store_conn import OrderByExpr from rag.utils.redis_conn import REDIS_CONN ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] @@ -363,7 +364,7 @@ def get_graph(tenant_id, kb_id): res.field[id]["source_id"] except Exception: continue - return None, None + return rebuild_graph(tenant_id, kb_id) def set_graph(tenant_id, kb_id, graph, docids): @@ -517,3 +518,36 @@ def flat_uniq_list(arr, key): res.append(a) return list(set(res)) + +def rebuild_graph(tenant_id, kb_id): + graph = nx.Graph() + src_ids = [] + flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"] + bs = 256 + for i in range(0, 10000000, bs): + es_res = settings.docStoreConn.search(flds, [], + {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]}, + [], + OrderByExpr(), + i, bs, search.index_name(tenant_id), [kb_id] + ) + tot = settings.docStoreConn.getTotal(es_res) + if tot == 0: + return None, None + + es_res = settings.docStoreConn.getFields(es_res, flds) + for id, d in es_res.items(): + src_ids.extend(d.get("source_id", [])) + if d["knowledge_graph_kwd"] == "entity": + graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"]) + else: + graph.add_edge( + d["from_entity_kwd"], + d["to_entity_kwd"], + weight=int(d["weight_int"]) + ) + + if len(es_res.keys()) < 128: + return graph, list(set(src_ids)) + + return graph, list(set(src_ids)) diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 922ab66a1..5d67d9da7 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -483,4 +483,4 @@ class Dealer: cnt = np.sum([c for _, c in aggs]) tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags] - return {a: c for a, c in tag_fea if c > 0} + return {a: max(1, c) for a, c in tag_fea} diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index dd62b315c..dda396c1b 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -327,8 +327,10 @@ def build_chunks(task, progress_callback): random.choices(examples, k=2) if len(examples)>2 else examples, topn=topn_tags) if cached: - set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) - d[TAG_FLD] = json.loads(cached) + cached = json.dumps(cached) + if cached: + set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) + d[TAG_FLD] = json.loads(cached) progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))