From e2138738528eb3cbdb4eb12f3a0f176a750380b0 Mon Sep 17 00:00:00 2001
From: Zhichang Yu
Date: Thu, 13 Mar 2025 14:37:59 +0800
Subject: [PATCH] Optimize graphrag cache get entity (#6018)

### What problem does this PR solve?

Optimize graphrag cache get entity

### Type of change

- [x] Performance Improvement
---
 graphrag/utils.py        | 27 +++++++++++++++++++++++
 rag/svr/task_executor.py | 47 +++++++++++++++++++++++-----------------
 2 files changed, 54 insertions(+), 20 deletions(-)

diff --git a/graphrag/utils.py b/graphrag/utils.py
index ab09e536f..efac9783a 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -237,8 +237,33 @@ def is_float_regex(value):
 def chunk_id(chunk):
     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
 
 
+def get_entity_cache(tenant_id, kb_id, ent_name) -> str | list[str]:
+    hasher = xxhash.xxh64()
+    hasher.update(str(tenant_id).encode("utf-8"))
+    hasher.update(str(kb_id).encode("utf-8"))
+    hasher.update(str(ent_name).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    bin = REDIS_CONN.get(k)
+    if not bin:
+        return
+    return json.loads(bin)
+
+
+def set_entity_cache(tenant_id, kb_id, ent_name, content_with_weight):
+    hasher = xxhash.xxh64()
+    hasher.update(str(tenant_id).encode("utf-8"))
+    hasher.update(str(kb_id).encode("utf-8"))
+    hasher.update(str(ent_name).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    REDIS_CONN.set(k, content_with_weight.encode("utf-8"), 3600)
+
 def get_entity(tenant_id, kb_id, ent_name):
+    cache = get_entity_cache(tenant_id, kb_id, ent_name)
+    if cache:
+        return cache
     conds = {
         "fields": ["content_with_weight"],
         "entity_kwd": ent_name,
@@ -250,6 +275,7 @@ def get_entity(tenant_id, kb_id, ent_name):
     for id in es_res.ids:
         try:
             if isinstance(ent_name, str):
+                set_entity_cache(tenant_id, kb_id, ent_name, es_res.field[id]["content_with_weight"])
                 return json.loads(es_res.field[id]["content_with_weight"])
             res.append(json.loads(es_res.field[id]["content_with_weight"]))
         except Exception:
@@ -272,6 +298,7 @@ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
         "available_int": 0
     }
     chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
+    set_entity_cache(tenant_id, kb_id, ent_name, chunk["content_with_weight"])
     res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
                                       search.index_name(tenant_id), [kb_id])
     if res.ids:
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 247042641..ea2ad7a74 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -26,7 +26,6 @@ from rag.prompts import keyword_extraction, question_proposal, content_tagging
 
 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
 CONSUMER_NAME = "task_executor_" + CONSUMER_NO
-initRootLogger(CONSUMER_NAME)
 
 import logging
 import os
@@ -43,6 +42,7 @@ import tracemalloc
 import signal
 import trio
 import exceptiongroup
+import faulthandler
 
 import numpy as np
 from peewee import DoesNotExist
@@ -139,30 +139,35 @@ class TaskCanceledException(Exception):
 
 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
-    if prog is not None and prog < 0:
-        msg = "[ERROR]" + msg
-    cancel = TaskService.do_cancel(task_id)
+    try:
+        if prog is not None and prog < 0:
+            msg = "[ERROR]" + msg
+        cancel = TaskService.do_cancel(task_id)
 
-    if cancel:
-        msg += " [Canceled]"
-        prog = -1
+        if cancel:
+            msg += " [Canceled]"
+            prog = -1
 
-    if to_page > 0:
+        if to_page > 0:
+            if msg:
+                if from_page < to_page:
+                    msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
         if msg:
-            if from_page < to_page:
-                msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
-    if msg:
-        msg = datetime.now().strftime("%H:%M:%S") + " " + msg
-    d = {"progress_msg": msg}
-    if prog is not None:
-        d["progress"] = prog
+            msg = datetime.now().strftime("%H:%M:%S") + " " + msg
+        d = {"progress_msg": msg}
+        if prog is not None:
+            d["progress"] = prog
 
-    logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
-    TaskService.update_progress(task_id, d)
+        TaskService.update_progress(task_id, d)
 
-    close_connection()
-    if cancel:
-        raise TaskCanceledException(msg)
+        close_connection()
+        if cancel:
+            raise TaskCanceledException(msg)
+        logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
+    except DoesNotExist:
+        logging.warning(f"set_progress({task_id}) got exception DoesNotExist")
+    except Exception:
+        logging.exception(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}, got exception")
 
 
 async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
@@ -664,4 +669,6 @@ async def main():
         logging.error("BUG!!! You should not reach here!!!")
 
 if __name__ == "__main__":
+    faulthandler.enable()
+    initRootLogger(CONSUMER_NAME)
     trio.run(main)
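
### How the cache key is built

The new `get_entity_cache`/`set_entity_cache` helpers key each Redis entry on an xxh64 digest of the tenant id, knowledge-base id, and entity name. The `xxhash` streaming API hashes the concatenation of all bytes fed to `update()`, so feeding the three identifiers one by one is equivalent to hashing them pre-joined. A minimal standalone sketch of that equivalence (the literal values below are made up for illustration):

```python
import xxhash

# Streaming updates hash the concatenated byte stream, so three update()
# calls produce the same digest as one call on the joined bytes.
hasher = xxhash.xxh64()
hasher.update(b"tenant-1")
hasher.update(b"kb-9")
hasher.update(b"Alan Turing")

one_shot = xxhash.xxh64(b"tenant-1kb-9Alan Turing")
assert hasher.hexdigest() == one_shot.hexdigest()
print(hasher.hexdigest())  # 16-hex-char key addressing the Redis entry
```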
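The optimization itself is a read-through cache: `get_entity` now consults Redis before querying Elasticsearch, and both `get_entity` (on a miss) and `set_entity` (on a write) warm the cache with the one-hour expiry passed to `REDIS_CONN.set`. Below is a self-contained sketch of that flow, not RAGFlow code: a plain dict stands in for Redis, and `slow_lookup`/`fake_es` are hypothetical stand-ins for the Elasticsearch query.

```python
import json
import time
import xxhash

# In-memory stand-in for Redis: key -> (expiry timestamp, JSON payload).
_CACHE: dict[str, tuple[float, str]] = {}
TTL_SECONDS = 3600  # mirrors the TTL the patch passes to REDIS_CONN.set


def _cache_key(tenant_id, kb_id, ent_name) -> str:
    hasher = xxhash.xxh64()
    for part in (tenant_id, kb_id, ent_name):
        hasher.update(str(part).encode("utf-8"))
    return hasher.hexdigest()


def get_entity_read_through(tenant_id, kb_id, ent_name, slow_lookup):
    """slow_lookup is a hypothetical callable standing in for the ES query."""
    k = _cache_key(tenant_id, kb_id, ent_name)
    hit = _CACHE.get(k)
    if hit and hit[0] > time.time():
        return json.loads(hit[1])  # hit: skip the slow store entirely
    value = slow_lookup(tenant_id, kb_id, ent_name)
    if value is not None:
        # Miss: warm the cache so repeated lookups of the same entity
        # during graph extraction are served from memory.
        _CACHE[k] = (time.time() + TTL_SECONDS, json.dumps(value))
    return value


# Usage: the second lookup is served from the cache.
lookup_calls = []

def fake_es(tenant, kb, name):
    lookup_calls.append(name)
    return {"entity_name": name, "description": "stub"}

get_entity_read_through("t1", "kb1", "Alan Turing", fake_es)
get_entity_read_through("t1", "kb1", "Alan Turing", fake_es)
assert lookup_calls == ["Alan Turing"]
```

This pays off because graph extraction resolves the same entities repeatedly; each repeat hit saves an Elasticsearch round trip at the cost of one Redis lookup.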