# Optimize graphrag cache get entity (#6018)

### What problem does this PR solve?

Optimize graphrag cache get entity

### Type of change

- [x] Performance Improvement
Commit e213873852 (parent 56acb340d2)
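The diff below touches two areas (file paths are not preserved in this mirror view; judging by the functions shown, they are the GraphRAG entity utilities and the task executor). The first three hunks give `get_entity`/`set_entity` a Redis-backed cache keyed by an xxh64 hash of tenant, knowledge base, and entity name, with a one-hour TTL. The remaining hunks harden the task executor: `initRootLogger` moves out of import time, `faulthandler` is enabled at startup, and `set_progress` no longer lets a reporting failure escape.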
```diff
@@ -237,8 +237,33 @@ def is_float_regex(value):
 
 def chunk_id(chunk):
     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
 
 
+def get_entity_cache(tenant_id, kb_id, ent_name) -> str | list[str]:
+    hasher = xxhash.xxh64()
+    hasher.update(str(tenant_id).encode("utf-8"))
+    hasher.update(str(kb_id).encode("utf-8"))
+    hasher.update(str(ent_name).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    bin = REDIS_CONN.get(k)
+    if not bin:
+        return
+    return json.loads(bin)
+
+
+def set_entity_cache(tenant_id, kb_id, ent_name, content_with_weight):
+    hasher = xxhash.xxh64()
+    hasher.update(str(tenant_id).encode("utf-8"))
+    hasher.update(str(kb_id).encode("utf-8"))
+    hasher.update(str(ent_name).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    REDIS_CONN.set(k, content_with_weight.encode("utf-8"), 3600)
+
+
 def get_entity(tenant_id, kb_id, ent_name):
+    cache = get_entity_cache(tenant_id, kb_id, ent_name)
+    if cache:
+        return cache
     conds = {
         "fields": ["content_with_weight"],
         "entity_kwd": ent_name,
```
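For readers outside the codebase, here is a self-contained sketch of the same pattern, substituting a plain `redis-py` client for ragflow's `REDIS_CONN` wrapper (the client setup and the `ex=3600` spelling of the TTL are assumptions; the key derivation mirrors the functions above):

```python
import json

import redis   # assumption: plain redis-py stands in for ragflow's REDIS_CONN
import xxhash

r = redis.Redis(host="localhost", port=6379)


def entity_cache_key(tenant_id, kb_id, ent_name) -> str:
    # One xxh64 over the three identifiers, matching get/set_entity_cache above.
    hasher = xxhash.xxh64()
    hasher.update(str(tenant_id).encode("utf-8"))
    hasher.update(str(kb_id).encode("utf-8"))
    hasher.update(str(ent_name).encode("utf-8"))
    return hasher.hexdigest()


def get_cached(tenant_id, kb_id, ent_name):
    raw = r.get(entity_cache_key(tenant_id, kb_id, ent_name))
    return json.loads(raw) if raw else None


def set_cached(tenant_id, kb_id, ent_name, content_with_weight: str):
    # One-hour TTL, as in the diff's REDIS_CONN.set(k, ..., 3600).
    r.set(entity_cache_key(tenant_id, kb_id, ent_name),
          content_with_weight.encode("utf-8"), ex=3600)
```

xxh64 is a fast non-cryptographic hash, which is a reasonable choice here: the worst case for a collision is a stale or mismatched cache hit for an hour, not a security problem.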
```diff
@@ -250,6 +275,7 @@ def get_entity(tenant_id, kb_id, ent_name):
     for id in es_res.ids:
         try:
             if isinstance(ent_name, str):
+                set_entity_cache(tenant_id, kb_id, ent_name, es_res.field[id]["content_with_weight"])
                 return json.loads(es_res.field[id]["content_with_weight"])
             res.append(json.loads(es_res.field[id]["content_with_weight"]))
         except Exception:
```
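This is the read side of a read-through cache: on a miss, `get_entity` falls through to the Elasticsearch lookup and back-fills Redis before returning (only for single-name lookups, hence the `isinstance(ent_name, str)` guard). A generic sketch of the pattern, with illustrative names that are not ragflow APIs:

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def read_through(cache_get: Callable[[str], Optional[T]],
                 cache_set: Callable[[str, T], None],
                 load: Callable[[str], Optional[T]],
                 key: str) -> Optional[T]:
    # Serve from cache when possible; otherwise load and back-fill.
    value = cache_get(key)
    if value is not None:
        return value
    value = load(key)  # the Elasticsearch query, in get_entity's case
    if value is not None:
        cache_set(key, value)
    return value
```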
```diff
@@ -272,6 +298,7 @@ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
         "available_int": 0
     }
     chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
+    set_entity_cache(tenant_id, kb_id, ent_name, chunk["content_with_weight"])
     res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
                                       search.index_name(tenant_id), [kb_id])
     if res.ids:
```
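The write side is symmetric: `set_entity` refreshes the cache with the chunk it is about to index, so a read that races Elasticsearch's near-real-time refresh can still be served from Redis. That write-through step is likely the motivation for caching here rather than only on read.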
```diff
@@ -26,7 +26,6 @@ from rag.prompts import keyword_extraction, question_proposal, content_tagging
 
 CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
 CONSUMER_NAME = "task_executor_" + CONSUMER_NO
-initRootLogger(CONSUMER_NAME)
 
 import logging
 import os
```
```diff
@@ -43,6 +42,7 @@ import tracemalloc
 import signal
 import trio
 import exceptiongroup
+import faulthandler
 
 import numpy as np
 from peewee import DoesNotExist
```
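`import faulthandler` lands with the other standard-library imports; it is put to use in the `__main__` block in the final hunk below. Together with the previous hunk, which deletes the module-level `initRootLogger(CONSUMER_NAME)` call, startup side effects now happen only when the module is run as a script, not when it is merely imported.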
```diff
@@ -139,30 +139,35 @@ class TaskCanceledException(Exception):
 
 
 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
-    if prog is not None and prog < 0:
-        msg = "[ERROR]" + msg
-    cancel = TaskService.do_cancel(task_id)
-
-    if cancel:
-        msg += " [Canceled]"
-        prog = -1
-
-    if to_page > 0:
-        if msg:
-            if from_page < to_page:
-                msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
-    if msg:
-        msg = datetime.now().strftime("%H:%M:%S") + " " + msg
-    d = {"progress_msg": msg}
-    if prog is not None:
-        d["progress"] = prog
-
-    logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
-    TaskService.update_progress(task_id, d)
-
-    close_connection()
-    if cancel:
-        raise TaskCanceledException(msg)
+    try:
+        if prog is not None and prog < 0:
+            msg = "[ERROR]" + msg
+        cancel = TaskService.do_cancel(task_id)
+
+        if cancel:
+            msg += " [Canceled]"
+            prog = -1
+
+        if to_page > 0:
+            if msg:
+                if from_page < to_page:
+                    msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
+        if msg:
+            msg = datetime.now().strftime("%H:%M:%S") + " " + msg
+        d = {"progress_msg": msg}
+        if prog is not None:
+            d["progress"] = prog
+
+        TaskService.update_progress(task_id, d)
+
+        close_connection()
+        if cancel:
+            raise TaskCanceledException(msg)
+        logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
+    except DoesNotExist:
+        logging.warning(f"set_progress({task_id}) got exception DoesNotExist")
+    except Exception:
+        logging.exception(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}, got exception")
 
 
 async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
```
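The whole body of `set_progress` is now wrapped in `try`/`except`: a vanished task row (peewee's `DoesNotExist`) is downgraded to a warning, and anything else is logged with a full traceback, so a failed progress update can no longer kill the executor. One side effect worth noting: since `TaskCanceledException` subclasses `Exception` (see the hunk header), the `raise` at the end of the `try` block now appears to be caught by the generic handler and merely logged, so callers will no longer see it from this function. A minimal sketch of the swallow-and-log pattern, with hypothetical stand-ins for the service layer:

```python
import logging


def persist(task_id: str, update: dict) -> None:
    # Hypothetical stand-in for TaskService.update_progress; may raise on DB errors.
    raise LookupError(task_id)


def report_progress(task_id: str, update: dict) -> None:
    # Best-effort reporting: failures are logged, never propagated.
    try:
        persist(task_id, update)
    except LookupError:  # plays the role of peewee's DoesNotExist here
        logging.warning("report_progress(%s): task row is gone", task_id)
    except Exception:
        logging.exception("report_progress(%s) failed", task_id)


report_progress("42", {"progress": 0.5})  # logs a warning instead of raising
```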
```diff
@@ -664,4 +669,6 @@ async def main():
     logging.error("BUG!!! You should not reach here!!!")
 
 if __name__ == "__main__":
+    faulthandler.enable()
+    initRootLogger(CONSUMER_NAME)
     trio.run(main)
```
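`faulthandler` is standard library: once enabled, it prints every thread's Python traceback to stderr on fatal signals (SIGSEGV, SIGFPE, SIGABRT, SIGBUS, SIGILL), which is exactly what you want in a long-running executor that loads native extensions. A quick tour of the module (all real stdlib calls):

```python
import faulthandler
import sys

faulthandler.enable()                         # dump all threads on fatal signals

faulthandler.dump_traceback(file=sys.stderr)  # dump all threads right now

# Watchdog variant: dump a traceback if the process wedges for 60 seconds;
# cancel once the risky section completes.
faulthandler.dump_traceback_later(60, exit=False)
faulthandler.cancel_dump_traceback_later()
```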