Refactor graphrag to remove redis lock (#5828)

### What problem does this PR solve?

Refactor graphrag to remove redis lock

### Type of change

- [x] Refactoring
This commit is contained in:
Zhichang Yu 2025-03-10 15:15:06 +08:00 committed by GitHub
parent 1163e9e409
commit 6ec6ca6971
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 602 additions and 332 deletions

View File

@ -42,16 +42,22 @@ from api.db.init_data import init_web_data
from api.versions import get_ragflow_version from api.versions import get_ragflow_version
from api.utils import show_configs from api.utils import show_configs
from rag.settings import print_rag_settings from rag.settings import print_rag_settings
from rag.utils.redis_conn import RedisDistributedLock
stop_event = threading.Event() stop_event = threading.Event()
def update_progress(): def update_progress():
redis_lock = RedisDistributedLock("update_progress", timeout=60)
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
if not redis_lock.acquire():
continue
DocumentService.update_progress() DocumentService.update_progress()
stop_event.wait(6) stop_event.wait(6)
except Exception: except Exception:
logging.exception("update_progress exception") logging.exception("update_progress exception")
finally:
redis_lock.release()
def signal_handler(sig, frame): def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...") logging.info("Received interrupt signal, shutting down...")

View File

@ -93,7 +93,7 @@ class Extractor:
return dict(maybe_nodes), dict(maybe_edges) return dict(maybe_nodes), dict(maybe_edges)
async def __call__( async def __call__(
self, chunks: list[tuple[str, str]], self, doc_id: str, chunks: list[str],
callback: Callable | None = None callback: Callable | None = None
): ):
@ -101,9 +101,9 @@ class Extractor:
start_ts = trio.current_time() start_ts = trio.current_time()
out_results = [] out_results = []
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for i, (cid, ck) in enumerate(chunks): for i, ck in enumerate(chunks):
ck = truncate(ck, int(self._llm.max_length*0.8)) ck = truncate(ck, int(self._llm.max_length*0.8))
nursery.start_soon(lambda: self._process_single_content((cid, ck), i, len(chunks), out_results)) nursery.start_soon(lambda: self._process_single_content((doc_id, ck), i, len(chunks), out_results))
maybe_nodes = defaultdict(list) maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list) maybe_edges = defaultdict(list)
@ -241,10 +241,13 @@ class Extractor:
) -> str: ) -> str:
summary_max_tokens = 512 summary_max_tokens = 512
use_description = truncate(description, summary_max_tokens) use_description = truncate(description, summary_max_tokens)
description_list=use_description.split(GRAPH_FIELD_SEP),
if len(description_list) <= 12:
return use_description
prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
context_base = dict( context_base = dict(
entity_name=entity_or_relation_name, entity_name=entity_or_relation_name,
description_list=use_description.split(GRAPH_FIELD_SEP), description_list=description_list,
language=self._language, language=self._language,
) )
use_prompt = prompt_template.format(**context_base) use_prompt = prompt_template.format(**context_base)

View File

@ -15,196 +15,353 @@
# #
import json import json
import logging import logging
from functools import reduce, partial from functools import partial
import networkx as nx import networkx as nx
import trio import trio
from api import settings from api import settings
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor from graphrag.general.extractor import Extractor
from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES from graphrag.utils import (
from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \ graph_merge,
chunk_id, update_nodes_pagerank_nhop_neighbour set_entity,
get_relation,
set_relation,
get_entity,
get_graph,
set_graph,
chunk_id,
update_nodes_pagerank_nhop_neighbour,
does_graph_contains,
get_graph_doc_ids,
)
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock from rag.utils.redis_conn import REDIS_CONN
class Dealer: def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool:
def __init__(self, key = f"graphrag:{tenant_id}:{kb_id}"
ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
if not ok:
raise Exception(f"Faild to set the {key} to {doc_id}")
def graphrag_task_get(tenant_id, kb_id) -> str | None:
key = f"graphrag:{tenant_id}:{kb_id}"
doc_id = REDIS_CONN.get(key)
return doc_id
async def run_graphrag(
row: dict,
language,
with_resolution: bool,
with_community: bool,
chat_model,
embedding_model,
callback,
):
start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retrievaler.chunk_list(
doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
):
chunks.append(d["content_with_weight"])
graph, doc_ids = await update_graph(
LightKGExt
if row["parser_config"]["graphrag"]["method"] != "general"
else GeneralKGExt,
tenant_id,
kb_id,
doc_id,
chunks,
language,
row["parser_config"]["graphrag"]["entity_types"],
chat_model,
embedding_model,
callback,
)
if not graph:
return
if with_resolution or with_community:
graphrag_task_set(tenant_id, kb_id, doc_id)
if with_resolution:
await resolve_entities(
graph,
doc_ids,
tenant_id,
kb_id,
doc_id,
chat_model,
embedding_model,
callback,
)
if with_community:
await extract_community(
graph,
doc_ids,
tenant_id,
kb_id,
doc_id,
chat_model,
embedding_model,
callback,
)
now = trio.current_time()
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
return
async def update_graph(
extractor: Extractor, extractor: Extractor,
tenant_id: str, tenant_id: str,
kb_id: str, kb_id: str,
llm_bdl, doc_id: str,
chunks: list[tuple[str, str]], chunks: list[str],
language, language,
entity_types=DEFAULT_ENTITY_TYPES, entity_types,
embed_bdl=None, llm_bdl,
callback=None embed_bdl,
callback,
): ):
self.tenant_id = tenant_id contains = await does_graph_contains(tenant_id, kb_id, doc_id)
self.kb_id = kb_id if contains:
self.chunks = chunks callback(msg=f"Graph already contains {doc_id}, cancel myself")
self.llm_bdl = llm_bdl return None, None
self.embed_bdl = embed_bdl start = trio.current_time()
self.ext = extractor(self.llm_bdl, language=language, ext = extractor(
llm_bdl,
language=language,
entity_types=entity_types, entity_types=entity_types,
get_entity=partial(get_entity, tenant_id, kb_id), get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl), set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id), get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl) set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
) )
self.graph = nx.Graph() ents, rels = await ext(doc_id, chunks, callback)
self.callback = callback subgraph = nx.Graph()
async def __call__(self):
docids = list(set([docid for docid, _ in self.chunks]))
ents, rels = await self.ext(self.chunks, self.callback)
for en in ents: for en in ents:
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"]) subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])
for rel in rels: for rel in rels:
self.graph.add_edge( subgraph.add_edge(
rel["src_id"], rel["src_id"],
rel["tgt_id"], rel["tgt_id"],
weight=rel["weight"], weight=rel["weight"],
# description=rel["description"] # description=rel["description"]
) )
# TODO: infinity doesn't support array search
chunk = {
"content_with_weight": json.dumps(
nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
),
"knowledge_graph_kwd": "subgraph",
"kb_id": kb_id,
"source_id": [doc_id],
"available_int": 0,
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.insert(
[{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
)
)
now = trio.current_time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
start = now
with RedisDistributedLock(self.kb_id, 60*60): while True:
old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id) new_graph = subgraph
now_docids = set([doc_id])
old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
if old_graph is not None: if old_graph is not None:
logging.info("Merge with an exiting graph...................") logging.info("Merge with an exiting graph...................")
self.graph = reduce(graph_merge, [old_graph, self.graph]) new_graph = graph_merge(old_graph, subgraph)
update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2) await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
if old_doc_ids: if old_doc_ids:
docids.extend(old_doc_ids) for old_doc_id in old_doc_ids:
docids = list(set(docids)) now_docids.add(old_doc_id)
set_graph(self.tenant_id, self.kb_id, self.graph, docids) old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
if delta_doc_ids:
callback(
msg="The global graph has changed during merging, try again"
)
await trio.sleep(1)
continue
break
await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
now = trio.current_time()
callback(
msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
)
return new_graph, now_docids
class WithResolution(Dealer): async def resolve_entities(
def __init__(self, graph,
doc_ids,
tenant_id: str, tenant_id: str,
kb_id: str, kb_id: str,
doc_id: str,
llm_bdl, llm_bdl,
embed_bdl=None, embed_bdl,
callback=None callback,
): ):
self.tenant_id = tenant_id working_doc_id = graphrag_task_get(tenant_id, kb_id)
self.kb_id = kb_id if doc_id != working_doc_id:
self.llm_bdl = llm_bdl callback(
self.embed_bdl = embed_bdl msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
self.callback = callback )
async def __call__(self):
with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if self.callback:
self.callback(-1, msg="Faild to fetch the graph.")
return return
start = trio.current_time()
er = EntityResolution(
llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
)
reso = await er(graph)
graph = reso.graph
callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
callback(msg="Graph resolution updated pagerank.")
if self.callback: working_doc_id = graphrag_task_get(tenant_id, kb_id)
self.callback(msg="Fetch the existing graph.") if doc_id != working_doc_id:
er = EntityResolution(self.llm_bdl, callback(
get_entity=partial(get_entity, self.tenant_id, self.kb_id), msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl), )
get_relation=partial(get_relation, self.tenant_id, self.kb_id), return
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl)) await set_graph(tenant_id, kb_id, graph, doc_ids)
reso = await er(self.graph)
self.graph = reso.graph
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
if self.callback:
self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({ await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{
"knowledge_graph_kwd": "relation", "knowledge_graph_kwd": "relation",
"kb_id": self.kb_id, "kb_id": kb_id,
"from_entity_kwd": reso.removed_entities "from_entity_kwd": reso.removed_entities,
}, search.index_name(self.tenant_id), self.kb_id)) },
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({ search.index_name(tenant_id),
kb_id,
)
)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{
"knowledge_graph_kwd": "relation", "knowledge_graph_kwd": "relation",
"kb_id": self.kb_id, "kb_id": kb_id,
"to_entity_kwd": reso.removed_entities "to_entity_kwd": reso.removed_entities,
}, search.index_name(self.tenant_id), self.kb_id)) },
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({ search.index_name(tenant_id),
kb_id,
)
)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{
"knowledge_graph_kwd": "entity", "knowledge_graph_kwd": "entity",
"kb_id": self.kb_id, "kb_id": kb_id,
"entity_kwd": reso.removed_entities "entity_kwd": reso.removed_entities,
}, search.index_name(self.tenant_id), self.kb_id)) },
search.index_name(tenant_id),
kb_id,
)
)
now = trio.current_time()
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
class WithCommunity(Dealer): async def extract_community(
def __init__(self, graph,
doc_ids,
tenant_id: str, tenant_id: str,
kb_id: str, kb_id: str,
doc_id: str,
llm_bdl, llm_bdl,
embed_bdl=None, embed_bdl,
callback=None callback,
): ):
working_doc_id = graphrag_task_get(tenant_id, kb_id)
self.tenant_id = tenant_id if doc_id != working_doc_id:
self.kb_id = kb_id callback(
self.community_structure = None msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
self.community_reports = None )
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
self.callback = callback
async def __call__(self):
with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if self.callback:
self.callback(-1, msg="Faild to fetch the graph.")
return return
if self.callback: start = trio.current_time()
self.callback(msg="Fetch the existing graph.") ext = CommunityReportsExtractor(
llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
)
cr = await ext(graph, callback=callback)
community_structure = cr.structured_output
community_reports = cr.output
working_doc_id = graphrag_task_get(tenant_id, kb_id)
if doc_id != working_doc_id:
callback(
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
)
return
await set_graph(tenant_id, kb_id, graph, doc_ids)
cr = CommunityReportsExtractor(self.llm_bdl, now = trio.current_time()
get_entity=partial(get_entity, self.tenant_id, self.kb_id), callback(
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl), msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
get_relation=partial(get_relation, self.tenant_id, self.kb_id), )
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl)) start = now
cr = await cr(self.graph, callback=self.callback) await trio.to_thread.run_sync(
self.community_structure = cr.structured_output lambda: settings.docStoreConn.delete(
self.community_reports = cr.output {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids)) search.index_name(tenant_id),
kb_id,
if self.callback: )
self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output))) )
for stru, rep in zip(community_structure, community_reports):
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "community_report",
"kb_id": self.kb_id
}, search.index_name(self.tenant_id), self.kb_id))
for stru, rep in zip(self.community_structure, self.community_reports):
obj = { obj = {
"report": rep, "report": rep,
"evidences": "\n".join([f["explanation"] for f in stru["findings"]]) "evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
} }
chunk = { chunk = {
"docnm_kwd": stru["title"], "docnm_kwd": stru["title"],
"title_tks": rag_tokenizer.tokenize(stru["title"]), "title_tks": rag_tokenizer.tokenize(stru["title"]),
"content_with_weight": json.dumps(obj, ensure_ascii=False), "content_with_weight": json.dumps(obj, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]), "content_ltks": rag_tokenizer.tokenize(
obj["report"] + " " + obj["evidences"]
),
"knowledge_graph_kwd": "community_report", "knowledge_graph_kwd": "community_report",
"weight_flt": stru["weight"], "weight_flt": stru["weight"],
"entities_kwd": stru["entities"], "entities_kwd": stru["entities"],
"important_kwd": stru["entities"], "important_kwd": stru["entities"],
"kb_id": self.kb_id, "kb_id": kb_id,
"source_id": doc_ids, "source_id": doc_ids,
"available_int": 0 "available_int": 0,
} }
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
chunk["content_ltks"]
)
# try: # try:
# ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])]) # ebd, _ = embed_bdl.encode([", ".join(community["entities"])])
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0] # chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
# except Exception as e: # except Exception as e:
# logging.exception(f"Fail to embed entity relation: {e}") # logging.exception(f"Fail to embed entity relation: {e}")
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id))) await trio.to_thread.run_sync(
lambda: settings.docStoreConn.insert(
[{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)
)
)
now = trio.current_time()
callback(
msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
)
return community_structure, community_reports

View File

@ -16,7 +16,7 @@
import argparse import argparse
import json import json
import logging
import networkx as nx import networkx as nx
import trio import trio
@ -26,42 +26,85 @@ from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from graphrag.general.index import WithCommunity, Dealer, WithResolution from graphrag.general.graph_extractor import GraphExtractor
from graphrag.light.graph_extractor import GraphExtractor from graphrag.general.index import update_graph, with_resolution, with_community
from rag.utils.redis_conn import RedisDistributedLock
settings.init_settings() settings.init_settings()
if __name__ == "__main__":
def callback(prog=None, msg="Processing..."):
logging.info(msg)
async def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True) parser.add_argument(
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True) "-t",
"--tenant_id",
default=False,
help="Tenant ID",
action="store",
required=True,
)
parser.add_argument(
"-d",
"--doc_id",
default=False,
help="Document ID",
action="store",
required=True,
)
args = parser.parse_args() args = parser.parse_args()
e, doc = DocumentService.get_by_id(args.doc_id) e, doc = DocumentService.get_by_id(args.doc_id)
if not e: if not e:
raise LookupError("Document not found.") raise LookupError("Document not found.")
kb_id = doc.kb_id kb_id = doc.kb_id
chunks = [d["content_with_weight"] for d in chunks = [
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6, d["content_with_weight"]
fields=["content_with_weight"])] for d in settings.retrievaler.chunk_list(
chunks = [("x", c) for c in chunks] args.doc_id,
args.tenant_id,
RedisDistributedLock.clean_lock(kb_id) [kb_id],
max_count=6,
fields=["content_with_weight"],
)
]
_, tenant = TenantService.get_by_id(args.tenant_id) _, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id) llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
_, kb = KnowledgebaseService.get_by_id(kb_id) _, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl) graph, doc_ids = await update_graph(
trio.run(dealer()) GraphExtractor,
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2)) args.tenant_id,
kb_id,
args.doc_id,
chunks,
"English",
llm_bdl,
embed_bdl,
callback,
)
print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl) await with_resolution(
trio.run(dealer()) args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl) )
trio.run(dealer()) community_structure, community_reports = await with_community(
args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
)
print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports) print(
print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2)) "------------------ COMMUNITY STRUCTURE--------------------\n",
json.dumps(community_structure, ensure_ascii=False, indent=2),
)
print(
"------------------ COMMUNITY REPORTS----------------------\n",
community_reports,
)
if __name__ == "__main__":
trio.run(main)

View File

@ -18,22 +18,42 @@ import argparse
import json import json
from api import settings from api import settings
import networkx as nx import networkx as nx
import logging
import trio
from api.db import LLMType from api.db import LLMType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from graphrag.general.index import Dealer from graphrag.general.index import update_graph
from graphrag.light.graph_extractor import GraphExtractor from graphrag.light.graph_extractor import GraphExtractor
from rag.utils.redis_conn import RedisDistributedLock
settings.init_settings() settings.init_settings()
if __name__ == "__main__":
def callback(prog=None, msg="Processing..."):
logging.info(msg)
async def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True) parser.add_argument(
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True) "-t",
"--tenant_id",
default=False,
help="Tenant ID",
action="store",
required=True,
)
parser.add_argument(
"-d",
"--doc_id",
default=False,
help="Document ID",
action="store",
required=True,
)
args = parser.parse_args() args = parser.parse_args()
e, doc = DocumentService.get_by_id(args.doc_id) e, doc = DocumentService.get_by_id(args.doc_id)
@ -41,18 +61,36 @@ if __name__ == "__main__":
raise LookupError("Document not found.") raise LookupError("Document not found.")
kb_id = doc.kb_id kb_id = doc.kb_id
chunks = [d["content_with_weight"] for d in chunks = [
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6, d["content_with_weight"]
fields=["content_with_weight"])] for d in settings.retrievaler.chunk_list(
chunks = [("x", c) for c in chunks] args.doc_id,
args.tenant_id,
RedisDistributedLock.clean_lock(kb_id) [kb_id],
max_count=6,
fields=["content_with_weight"],
)
]
_, tenant = TenantService.get_by_id(args.tenant_id) _, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id) llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
_, kb = KnowledgebaseService.get_by_id(kb_id) _, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl) graph, doc_ids = await update_graph(
GraphExtractor,
args.tenant_id,
kb_id,
args.doc_id,
chunks,
"English",
llm_bdl,
embed_bdl,
callback,
)
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2)) print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
if __name__ == "__main__":
trio.run(main)

View File

@ -352,25 +352,57 @@ def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
chunk["q_%d_vec" % len(ebd)] = ebd chunk["q_%d_vec" % len(ebd)] = ebd
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id) settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
async def does_graph_contains(tenant_id, kb_id, doc_id):
# Get doc_ids of graph
fields = ["source_id"]
condition = {
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
graph_doc_ids = set(fields2[chunk_id]["source_id"])
return doc_id in graph_doc_ids
def get_graph(tenant_id, kb_id): async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {
"fields": ["source_id"],
"removed_kwd": "N",
"size": 1,
"knowledge_graph_kwd": ["graph"]
}
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
doc_ids = []
if res.total == 0:
return doc_ids
for id in res.ids:
doc_ids = res.field[id]["source_id"]
return doc_ids
async def get_graph(tenant_id, kb_id):
conds = { conds = {
"fields": ["content_with_weight", "source_id"], "fields": ["content_with_weight", "source_id"],
"removed_kwd": "N", "removed_kwd": "N",
"size": 1, "size": 1,
"knowledge_graph_kwd": ["graph"] "knowledge_graph_kwd": ["graph"]
} }
res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]) res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
if res.total == 0:
return None, []
for id in res.ids: for id in res.ids:
try: try:
return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \ return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
res.field[id]["source_id"] res.field[id]["source_id"]
except Exception: except Exception:
continue continue
return rebuild_graph(tenant_id, kb_id) result = await rebuild_graph(tenant_id, kb_id)
return result
def set_graph(tenant_id, kb_id, graph, docids): async def set_graph(tenant_id, kb_id, graph, docids):
chunk = { chunk = {
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False, "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
indent=2), indent=2),
@ -380,12 +412,12 @@ def set_graph(tenant_id, kb_id, graph, docids):
"available_int": 0, "available_int": 0,
"removed_kwd": "N" "removed_kwd": "N"
} }
res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]) res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]))
if res.ids: if res.ids:
settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk, await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
search.index_name(tenant_id), kb_id) search.index_name(tenant_id), kb_id))
else: else:
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id) await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
def is_continuous_subsequence(subseq, seq): def is_continuous_subsequence(subseq, seq):
@ -430,7 +462,7 @@ def merge_tuples(list1, list2):
return result return result
def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop): async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
def n_neighbor(id): def n_neighbor(id):
nonlocal graph, n_hop nonlocal graph, n_hop
count = 0 count = 0
@ -460,10 +492,10 @@ def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
for n, p in pr.items(): for n, p in pr.items():
graph.nodes[n]["pagerank"] = p graph.nodes[n]["pagerank"] = p
try: try:
settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id}, await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
{"rank_flt": p, {"rank_flt": p,
"n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)}, "n_hop_with_weight": json.dumps( (n), ensure_ascii=False)},
search.index_name(tenant_id), kb_id) search.index_name(tenant_id), kb_id))
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
@ -480,21 +512,21 @@ def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
"knowledge_graph_kwd": "ty2ents", "knowledge_graph_kwd": "ty2ents",
"available_int": 0 "available_int": 0
} }
res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []}, res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id]) search.index_name(tenant_id), [kb_id]))
if res.ids: if res.ids:
settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"}, await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
chunk, chunk,
search.index_name(tenant_id), kb_id) search.index_name(tenant_id), kb_id))
else: else:
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id) await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
def get_entity_type2sampels(idxnms, kb_ids: list): async def get_entity_type2sampels(idxnms, kb_ids: list):
es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
"size": 10000, "size": 10000,
"fields": ["content_with_weight"]}, "fields": ["content_with_weight"]},
idxnms, kb_ids) idxnms, kb_ids))
res = defaultdict(list) res = defaultdict(list)
for id in es_res.ids: for id in es_res.ids:
@ -522,18 +554,18 @@ def flat_uniq_list(arr, key):
return list(set(res)) return list(set(res))
def rebuild_graph(tenant_id, kb_id): async def rebuild_graph(tenant_id, kb_id):
graph = nx.Graph() graph = nx.Graph()
src_ids = [] src_ids = []
flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"] flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
bs = 256 bs = 256
for i in range(0, 39*bs, bs): for i in range(0, 39*bs, bs):
es_res = settings.docStoreConn.search(flds, [], es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
{"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]}, {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
[], [],
OrderByExpr(), OrderByExpr(),
i, bs, search.index_name(tenant_id), [kb_id] i, bs, search.index_name(tenant_id), [kb_id]
) ))
tot = settings.docStoreConn.getTotal(es_res) tot = settings.docStoreConn.getTotal(es_res)
if tot == 0: if tot == 0:
return None, None return None, None

View File

@ -15,18 +15,25 @@
# #
import logging import logging
import re import re
from threading import Lock
import umap import umap
import numpy as np import numpy as np
from sklearn.mixture import GaussianMixture from sklearn.mixture import GaussianMixture
import trio import trio
from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter from graphrag.utils import (
get_llm_cache,
get_embed_cache,
set_embed_cache,
set_llm_cache,
chat_limiter,
)
from rag.utils import truncate from rag.utils import truncate
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1): def __init__(
self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
):
self._max_cluster = max_cluster self._max_cluster = max_cluster
self._llm_model = llm_model self._llm_model = llm_model
self._embd_model = embd_model self._embd_model = embd_model
@ -34,22 +41,24 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._prompt = prompt self._prompt = prompt
self._max_token = max_token self._max_token = max_token
def _chat(self, system, history, gen_conf): async def _chat(self, system, history, gen_conf):
response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
if response: if response:
return response return response
response = self._llm_model.chat(system, history, gen_conf) response = await trio.to_thread.run_sync(
lambda: self._llm_model.chat(system, history, gen_conf)
)
response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL) response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0: if response.find("**ERROR**") >= 0:
raise Exception(response) raise Exception(response)
set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf) set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
return response return response
def _embedding_encode(self, txt): async def _embedding_encode(self, txt):
response = get_embed_cache(self._embd_model.llm_name, txt) response = get_embed_cache(self._embd_model.llm_name, txt)
if response is not None: if response is not None:
return response return response
embds, _ = self._embd_model.encode([txt]) embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
if len(embds) < 1 or len(embds[0]) < 1: if len(embds) < 1 or len(embds[0]) < 1:
raise Exception("Embedding error: ") raise Exception("Embedding error: ")
embds = embds[0] embds = embds[0]
@ -74,36 +83,48 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
return [] return []
chunks = [(s, a) for s, a in chunks if s and len(a) > 0] chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
async def summarize(ck_idx, lock): async def summarize(ck_idx: list[int]):
nonlocal chunks nonlocal chunks
try:
texts = [chunks[i][0] for i in ck_idx] texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) len_per_chunk = int(
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) (self._llm_model.max_length - self._max_token) / len(texts)
)
cluster_content = "\n".join(
[truncate(t, max(1, len_per_chunk)) for t in texts]
)
async with chat_limiter: async with chat_limiter:
cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.", cnt = await self._chat(
[{"role": "user", "You're a helpful assistant.",
"content": self._prompt.format(cluster_content=cluster_content)}], [
{"temperature": 0.3, "max_tokens": self._max_token} {
)) "role": "user",
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", "content": self._prompt.format(
cnt) cluster_content=cluster_content
),
}
],
{"temperature": 0.3, "max_tokens": self._max_token},
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}") logging.debug(f"SUM: {cnt}")
embds, _ = self._embd_model.encode([cnt]) embds = await self._embedding_encode(cnt)
with lock: chunks.append((cnt, embds))
chunks.append((cnt, self._embedding_encode(cnt)))
except Exception as e:
logging.exception("summarize got exception")
return e
labels = [] labels = []
lock = Lock()
while end - start > 1: while end - start > 1:
embeddings = [embd for _, embd in chunks[start:end]] embeddings = [embd for _, embd in chunks[start:end]]
if len(embeddings) == 2: if len(embeddings) == 2:
await summarize([start, start + 1], lock) await summarize([start, start + 1])
if callback: if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
labels.extend([0, 0]) labels.extend([0, 0])
layers.append((end, len(chunks))) layers.append((end, len(chunks)))
start = end start = end
@ -112,7 +133,9 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_neighbors = int((len(embeddings) - 1) ** 0.8) n_neighbors = int((len(embeddings) - 1) ** 0.8)
reduced_embeddings = umap.UMAP( reduced_embeddings = umap.UMAP(
n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine" n_neighbors=max(2, n_neighbors),
n_components=min(12, len(embeddings) - 2),
metric="cosine",
).fit_transform(embeddings) ).fit_transform(embeddings)
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state) n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
if n_clusters == 1: if n_clusters == 1:
@ -127,18 +150,22 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for c in range(n_clusters): for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
if not ck_idx: assert len(ck_idx) > 0
continue
async with chat_limiter: async with chat_limiter:
nursery.start_soon(lambda: summarize(ck_idx, lock)) nursery.start_soon(lambda: summarize(ck_idx))
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) assert len(chunks) - end == n_clusters, "{} vs. {}".format(
len(chunks) - end, n_clusters
)
labels.extend(lbls) labels.extend(lbls)
layers.append((end, len(chunks))) layers.append((end, len(chunks)))
if callback: if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
start = end start = end
end = len(chunks) end = len(chunks)
return chunks return chunks

View File

@ -20,9 +20,7 @@ import random
import sys import sys
from api.utils.log_utils import initRootLogger, get_project_base_directory from api.utils.log_utils import initRootLogger, get_project_base_directory
from graphrag.general.index import WithCommunity, WithResolution, Dealer from graphrag.general.index import run_graphrag
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.prompts import keyword_extraction, question_proposal, content_tagging from rag.prompts import keyword_extraction, question_proposal, content_tagging
@ -45,6 +43,7 @@ import tracemalloc
import resource import resource
import signal import signal
import trio import trio
import exceptiongroup
import numpy as np import numpy as np
from peewee import DoesNotExist from peewee import DoesNotExist
@ -453,24 +452,6 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count return res, tk_count
async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
chunks = []
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", "doc_id"]):
chunks.append((d["doc_id"], d["content_with_weight"]))
dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
row["tenant_id"],
str(row["kb_id"]),
chat_model,
chunks=chunks,
language=language,
entity_types=row["parser_config"]["graphrag"]["entity_types"],
embed_bdl=embedding_model,
callback=callback)
await dealer()
async def do_handle_task(task): async def do_handle_task(task):
task_id = task["id"] task_id = task["id"]
task_from_page = task["from_page"] task_from_page = task["from_page"]
@ -526,24 +507,10 @@ async def do_handle_task(task):
return return
start_ts = timer() start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback) with_resolution = graphrag_conf.get("resolution", False)
progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts)) with_community = graphrag_conf.get("community", False)
if graphrag_conf.get("resolution", False): await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
start_ts = timer() progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
with_res = WithResolution(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
await with_res()
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
if graphrag_conf.get("community", False):
start_ts = timer()
with_comm = WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
await with_comm()
progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
return return
else: else:
# Standard chunking methods # Standard chunking methods
@ -622,7 +589,11 @@ async def handle_task():
FAILED_TASKS += 1 FAILED_TASKS += 1
CURRENT_TASKS.pop(task["id"], None) CURRENT_TASKS.pop(task["id"], None)
try: try:
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}") err_msg = str(e)
while isinstance(e, exceptiongroup.ExceptionGroup):
e = e.exceptions[0]
err_msg += ' -- ' + str(e)
set_progress(task["id"], prog=-1, msg=f"[Exception]: {err_msg}")
except Exception: except Exception:
pass pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}") logging.exception(f"handle_task got exception for task {json.dumps(task)}")

View File

@ -16,13 +16,12 @@
import logging import logging
import json import json
import time
import uuid import uuid
import valkey as redis import valkey as redis
from rag import settings from rag import settings
from rag.utils import singleton from rag.utils import singleton
from valkey.lock import Lock
class RedisMsg: class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message): def __init__(self, consumer, queue_name, group_name, msg_id, message):
@ -281,29 +280,23 @@ REDIS_CONN = RedisDB()
class RedisDistributedLock: class RedisDistributedLock:
def __init__(self, lock_key, timeout=10): def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
self.lock_key = lock_key self.lock_key = lock_key
if lock_value:
self.lock_value = lock_value
else:
self.lock_value = str(uuid.uuid4()) self.lock_value = str(uuid.uuid4())
self.timeout = timeout self.timeout = timeout
self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout)
@staticmethod def acquire(self):
def clean_lock(lock_key): return self.lock.acquire()
REDIS_CONN.REDIS.delete(lock_key)
def acquire_lock(self): def release(self):
end_time = time.time() + self.timeout return self.lock.release()
while time.time() < end_time:
if REDIS_CONN.REDIS.setnx(self.lock_key, self.lock_value):
return True
time.sleep(1)
return False
def release_lock(self):
if REDIS_CONN.REDIS.get(self.lock_key) == self.lock_value:
REDIS_CONN.REDIS.delete(self.lock_key)
def __enter__(self): def __enter__(self):
self.acquire_lock() self.acquire()
def __exit__(self, exception_type, exception_value, exception_traceback): def __exit__(self, exception_type, exception_value, exception_traceback):
self.release_lock() self.release()