Optimize graphrag cache get entity (#6018)

### What problem does this PR solve?

Add a Redis-backed cache in front of graphrag entity lookups. `get_entity()` now consults `get_entity_cache()` before querying Elasticsearch, and both the Elasticsearch read path and `set_entity()` populate the cache via `set_entity_cache()`. Cache keys are an xxhash64 digest of `tenant_id`, `kb_id`, and the entity name, with a one-hour TTL. The task executor is also hardened: `set_progress()` is wrapped in try/except so a task deleted mid-run (`DoesNotExist`) no longer crashes the worker, `faulthandler` is enabled at startup, and `initRootLogger()` is deferred from import time to the `__main__` entry point.

### Type of change

- [x] Performance Improvement
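
For reviewers, here is a minimal sketch of the read-through pattern the new helpers implement, written against `redis-py` and `xxhash` directly. The `redis.Redis()` client and the helper names below are stand-ins for illustration; RAGFlow routes these calls through its own `REDIS_CONN` wrapper.

```python
import json

import redis
import xxhash

r = redis.Redis()  # stand-in for RAGFlow's REDIS_CONN wrapper


def entity_cache_key(tenant_id, kb_id, ent_name) -> str:
    # Hash all three identifiers into one xxhash64 digest so the Redis key
    # stays short and fixed-length regardless of the entity name's length.
    hasher = xxhash.xxh64()
    hasher.update(str(tenant_id).encode("utf-8"))
    hasher.update(str(kb_id).encode("utf-8"))
    hasher.update(str(ent_name).encode("utf-8"))
    return hasher.hexdigest()


def get_cached_entity(tenant_id, kb_id, ent_name):
    raw = r.get(entity_cache_key(tenant_id, kb_id, ent_name))
    return json.loads(raw) if raw else None  # miss -> None, caller falls back to ES


def set_cached_entity(tenant_id, kb_id, ent_name, content_with_weight: str):
    # One-hour expiry, matching the 3600 s the PR passes to REDIS_CONN.set().
    r.set(entity_cache_key(tenant_id, kb_id, ent_name),
          content_with_weight.encode("utf-8"), ex=3600)
```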
Zhichang Yu 2025-03-13 14:37:59 +08:00 committed by GitHub
parent 56acb340d2
commit e213873852
2 changed files with 54 additions and 20 deletions

```diff
@@ -237,8 +237,33 @@ def is_float_regex(value):
 def chunk_id(chunk):
     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
 
 
+def get_entity_cache(tenant_id, kb_id, ent_name) -> str | list[str]:
+    hasher = xxhash.xxh64()
+    hasher.update(str(tenant_id).encode("utf-8"))
+    hasher.update(str(kb_id).encode("utf-8"))
+    hasher.update(str(ent_name).encode("utf-8"))
+    k = hasher.hexdigest()
+    bin = REDIS_CONN.get(k)
+    if not bin:
+        return
+    return json.loads(bin)
+
+
+def set_entity_cache(tenant_id, kb_id, ent_name, content_with_weight):
+    hasher = xxhash.xxh64()
+    hasher.update(str(tenant_id).encode("utf-8"))
+    hasher.update(str(kb_id).encode("utf-8"))
+    hasher.update(str(ent_name).encode("utf-8"))
+    k = hasher.hexdigest()
+    REDIS_CONN.set(k, content_with_weight.encode("utf-8"), 3600)
+
+
 def get_entity(tenant_id, kb_id, ent_name):
+    cache = get_entity_cache(tenant_id, kb_id, ent_name)
+    if cache:
+        return cache
+
     conds = {
         "fields": ["content_with_weight"],
         "entity_kwd": ent_name,
@@ -250,6 +275,7 @@ def get_entity(tenant_id, kb_id, ent_name):
     for id in es_res.ids:
         try:
             if isinstance(ent_name, str):
+                set_entity_cache(tenant_id, kb_id, ent_name, es_res.field[id]["content_with_weight"])
                 return json.loads(es_res.field[id]["content_with_weight"])
             res.append(json.loads(es_res.field[id]["content_with_weight"]))
         except Exception:
@@ -272,6 +298,7 @@ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
         "available_int": 0
     }
     chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
+    set_entity_cache(tenant_id, kb_id, ent_name, chunk["content_with_weight"])
     res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
                                       search.index_name(tenant_id), [kb_id])
     if res.ids:
```
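
With the helpers in place, `get_entity()` becomes a classic read-through cache: return the Redis copy on a hit, otherwise fall back to Elasticsearch and populate the cache on the way out, while `set_entity()` writes the cache eagerly so fresh writes are warm. A toy model of that control flow, with a plain dict standing in for Redis and a hypothetical `search_es()` standing in for `settings.retrievaler.search()`:

```python
import json

_cache: dict[str, str] = {}  # dict standing in for Redis


def search_es(ent_name: str) -> str:
    # hypothetical stand-in for the Elasticsearch lookup; returns the
    # entity's content_with_weight JSON string
    return json.dumps({"entity_name": ent_name, "rank": 0})


def get_entity_flow(key: str, ent_name: str):
    if key in _cache:              # hit: skip Elasticsearch entirely
        return json.loads(_cache[key])
    raw = search_es(ent_name)      # miss: one ES round trip
    _cache[key] = raw              # populate so the next call hits
    return json.loads(raw)
```

Note that `if cache:` treats falsy cached values (an empty list or string) as misses, so such entities are simply re-fetched, and the 3600-second TTL bounds staleness if the index is modified without going through `set_entity()`.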

```diff
@@ -26,7 +26,6 @@ from rag.prompts import keyword_extraction, question_proposal, content_tagging
 
 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
 CONSUMER_NAME = "task_executor_" + CONSUMER_NO
-initRootLogger(CONSUMER_NAME)
 
 import logging
 import os
@@ -43,6 +42,7 @@ import tracemalloc
 import signal
 import trio
 import exceptiongroup
+import faulthandler
 
 import numpy as np
 from peewee import DoesNotExist
@@ -139,30 +139,35 @@ class TaskCanceledException(Exception):
 
 
 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
-    if prog is not None and prog < 0:
-        msg = "[ERROR]" + msg
-    cancel = TaskService.do_cancel(task_id)
-    if cancel:
-        msg += " [Canceled]"
-        prog = -1
+    try:
+        if prog is not None and prog < 0:
+            msg = "[ERROR]" + msg
+        cancel = TaskService.do_cancel(task_id)
+        if cancel:
+            msg += " [Canceled]"
+            prog = -1
 
-    if to_page > 0:
-        if msg:
-            if from_page < to_page:
-                msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
-    if msg:
-        msg = datetime.now().strftime("%H:%M:%S") + " " + msg
-    d = {"progress_msg": msg}
-    if prog is not None:
-        d["progress"] = prog
+        if to_page > 0:
+            if msg:
+                if from_page < to_page:
+                    msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
+        if msg:
+            msg = datetime.now().strftime("%H:%M:%S") + " " + msg
+        d = {"progress_msg": msg}
+        if prog is not None:
+            d["progress"] = prog
 
-    logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
-    TaskService.update_progress(task_id, d)
+        TaskService.update_progress(task_id, d)
 
-    close_connection()
-    if cancel:
-        raise TaskCanceledException(msg)
+        close_connection()
+        if cancel:
+            raise TaskCanceledException(msg)
+        logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
+    except DoesNotExist:
+        logging.warning(f"set_progress({task_id}) got exception DoesNotExist")
+    except Exception:
+        logging.exception(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}, got exception")
 
 
 async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
@@ -664,4 +669,6 @@ async def main():
     logging.error("BUG!!! You should not reach here!!!")
 
 if __name__ == "__main__":
+    faulthandler.enable()
+    initRootLogger(CONSUMER_NAME)
     trio.run(main)
```
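
The second file makes two startup changes: `initRootLogger()` no longer runs at import time but inside the `__main__` guard, keeping the module import side-effect free, and `faulthandler.enable()` is turned on so native crashes (e.g., a segfault in a C extension) dump Python tracebacks. `set_progress()` is also wrapped in try/except so a task row deleted mid-run (peewee's `DoesNotExist`) is logged rather than killing the executor. A minimal sketch of the same startup pattern; the helper names are stand-ins, not RAGFlow's actual implementations:

```python
import faulthandler
import logging
import sys


def init_root_logger(name: str) -> None:
    # stand-in for RAGFlow's initRootLogger() helper
    logging.basicConfig(level=logging.INFO,
                        format=f"%(asctime)s {name} %(levelname)s %(message)s")


def main() -> None:
    logging.info("task executor started")


if __name__ == "__main__":
    # enable first, so even a crash during logger setup dumps a traceback
    faulthandler.enable()
    init_root_logger("task_executor_" + ("0" if len(sys.argv) < 2 else sys.argv[1]))
    main()
```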