Optimize graphrag again (#6513)

### What problem does this PR solve?

Removed set_entity and set_relation to avoid accessing doc engine during
graph computation.
Introduced GraphChange to avoid writing unchanged chunks.

### Type of change

- [x] Performance Improvement
This commit is contained in:
Zhichang Yu 2025-03-26 15:34:42 +08:00 committed by GitHub
parent 7a677cb095
commit 6bf26e2a81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 466 additions and 530 deletions

View File

@ -47,6 +47,8 @@ from rag.utils.redis_conn import RedisDistributedLock
stop_event = threading.Event()
RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get('RAGFLOW_DEBUGPY_LISTEN', "0"))
def update_progress():
lock_value = str(uuid.uuid4())
redis_lock = RedisDistributedLock("update_progress", lock_value=lock_value, timeout=60)
@ -85,6 +87,11 @@ if __name__ == '__main__':
settings.init_settings()
print_rag_settings()
if RAGFLOW_DEBUGPY_LISTEN > 0:
logging.info(f"debugpy listen on {RAGFLOW_DEBUGPY_LISTEN}")
import debugpy
debugpy.listen(("0.0.0.0", RAGFLOW_DEBUGPY_LISTEN))
# init db
init_web_db()
init_web_data()

View File

@ -8,7 +8,7 @@
"docnm_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"title_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
@ -27,16 +27,16 @@
"rank_int": {"type": "integer", "default": 0},
"rank_flt": {"type": "float", "default": 0},
"available_int": {"type": "integer", "default": 1},
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"pagerank_fea": {"type": "integer", "default": 0},
"tag_feas": {"type": "varchar", "default": ""},
"from_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"to_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"source_id": {"type": "varchar", "default": ""},
"from_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"to_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"source_id": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"n_hop_with_weight": {"type": "varchar", "default": ""},
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}
}

View File

@ -12,6 +12,8 @@ services:
- ${SVR_HTTP_PORT}:9380
- 80:80
- 443:443
- 5678:5678
- 5679:5679
volumes:
- ./ragflow-logs:/ragflow/logs
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf

View File

@ -16,7 +16,6 @@
import logging
import itertools
import re
import time
from dataclasses import dataclass
from typing import Any, Callable
@ -28,7 +27,7 @@ from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, chat_limiter
from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@ -39,7 +38,7 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
class EntityResolutionResult:
"""Entity resolution result class definition."""
graph: nx.Graph
removed_entities: list
change: GraphChange
class EntityResolution(Extractor):
@ -54,12 +53,8 @@ class EntityResolution(Extractor):
def __init__(
self,
llm_invoker: CompletionLLM,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None
):
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
super().__init__(llm_invoker)
"""Init method definition."""
self._llm = llm_invoker
self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
@ -84,8 +79,8 @@ class EntityResolution(Extractor):
or DEFAULT_RESOLUTION_RESULT_DELIMITER,
}
nodes = graph.nodes
entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
nodes = sorted(graph.nodes())
entity_types = sorted(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
node_clusters = {entity_type: [] for entity_type in entity_types}
for node in nodes:
@ -105,54 +100,22 @@ class EntityResolution(Extractor):
nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
change = GraphChange()
connect_graph = nx.Graph()
removed_entities = []
connect_graph.add_edges_from(resolution_result)
all_entities_data = []
all_relationships_data = []
all_remove_nodes = []
async with trio.open_nursery() as nursery:
for sub_connect_graph in nx.connected_components(connect_graph):
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
remove_nodes = list(sub_connect_graph.nodes)
keep_node = remove_nodes.pop()
all_remove_nodes.append(remove_nodes)
nursery.start_soon(lambda: self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data))
for remove_node in remove_nodes:
removed_entities.append(remove_node)
remove_node_neighbors = graph[remove_node]
remove_node_neighbors = list(remove_node_neighbors)
for remove_node_neighbor in remove_node_neighbors:
rel = self._get_relation_(remove_node, remove_node_neighbor)
if graph.has_edge(remove_node, remove_node_neighbor):
graph.remove_edge(remove_node, remove_node_neighbor)
if remove_node_neighbor == keep_node:
if graph.has_edge(keep_node, remove_node):
graph.remove_edge(keep_node, remove_node)
continue
if not rel:
continue
if graph.has_edge(keep_node, remove_node_neighbor):
nursery.start_soon(lambda: self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data))
else:
pair = sorted([keep_node, remove_node_neighbor])
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
self._set_relation_(pair[0], pair[1],
dict(
src_id=pair[0],
tgt_id=pair[1],
weight=rel['weight'],
description=rel['description'],
keywords=[],
source_id=rel.get("source_id", ""),
metadata={"created_at": time.time()}
))
graph.remove_node(remove_node)
merging_nodes = list(sub_connect_graph.nodes)
nursery.start_soon(lambda: self._merge_graph_nodes(graph, merging_nodes, change))
# Update pagerank
pr = nx.pagerank(graph)
for node_name, pagerank in pr.items():
graph.nodes[node_name]["pagerank"] = pagerank
return EntityResolutionResult(
graph=graph,
removed_entities=removed_entities
change=change,
)
async def _resolve_candidate(self, candidate_resolution_i, resolution_result):

View File

@ -2,7 +2,7 @@
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
- [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/community_report.py)
"""
COMMUNITY_REPORT_PROMPT = """

View File

@ -40,13 +40,9 @@ class CommunityReportsExtractor(Extractor):
def __init__(
self,
llm_invoker: CompletionLLM,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
max_report_length: int | None = None,
):
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
super().__init__(llm_invoker)
"""Init method definition."""
self._llm = llm_invoker
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
@ -63,21 +59,28 @@ class CommunityReportsExtractor(Extractor):
over, token_count = 0, 0
async def extract_community_report(community):
nonlocal res_str, res_dict, over, token_count
cm_id, ents = community
weight = ents["weight"]
ents = ents["nodes"]
ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()
if ent_df.empty or "entity_name" not in ent_df.columns:
cm_id, cm = community
weight = cm["weight"]
ents = cm["nodes"]
if len(ents) < 2:
return
ent_df["entity"] = ent_df["entity_name"]
del ent_df["entity_name"]
rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
if rela_df.empty:
return
rela_df["source"] = rela_df["src_id"]
rela_df["target"] = rela_df["tgt_id"]
del rela_df["src_id"]
del rela_df["tgt_id"]
ent_list = [{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents]
ent_df = pd.DataFrame(ent_list)
rela_list = []
k = 0
for i in range(0, len(ents)):
if k >= 10000:
break
for j in range(i + 1, len(ents)):
if k >= 10000:
break
edge = graph.get_edge_data(ents[i], ents[j])
if edge is None:
continue
rela_list.append({"source": ents[i], "target": ents[j], "description": edge["description"]})
k += 1
rela_df = pd.DataFrame(rela_list)
prompt_variables = {
"entity_df": ent_df.to_csv(index_label="id"),

View File

@ -19,10 +19,11 @@ from collections import defaultdict, Counter
from copy import deepcopy
from typing import Callable
import trio
import networkx as nx
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange
from rag.llm.chat_model import Base as CompletionLLM
from rag.prompts import message_fit_in
from rag.utils import truncate
@ -40,18 +41,10 @@ class Extractor:
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
):
self._llm = llm_invoker
self._language = language
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
self._get_entity_ = get_entity
self._set_entity_ = set_entity
self._get_relation_ = get_relation
self._set_relation_ = set_relation
def _chat(self, system, history, gen_conf):
hist = deepcopy(history)
@ -152,25 +145,15 @@ class Extractor:
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
if not entities:
return
already_entity_types = []
already_source_ids = []
already_description = []
already_node = self._get_entity_(entity_name)
if already_node:
already_entity_types.append(already_node["entity_type"])
already_source_ids.extend(already_node["source_id"])
already_description.append(already_node["description"])
entity_type = sorted(
Counter(
[dp["entity_type"] for dp in entities] + already_entity_types
[dp["entity_type"] for dp in entities]
).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in entities] + already_description))
sorted(set([dp["description"] for dp in entities]))
)
already_source_ids = flat_uniq_list(entities, "source_id")
description = await self._handle_entity_relation_summary(entity_name, description)
@ -180,7 +163,6 @@ class Extractor:
source_id=already_source_ids,
)
node_data["entity_name"] = entity_name
self._set_entity_(entity_name, node_data)
all_relationships_data.append(node_data)
async def _merge_edges(
@ -192,36 +174,11 @@ class Extractor:
):
if not edges_data:
return
already_weights = []
already_source_ids = []
already_description = []
already_keywords = []
relation = self._get_relation_(src_id, tgt_id)
if relation:
already_weights = [relation["weight"]]
already_source_ids = relation["source_id"]
already_description = [relation["description"]]
already_keywords = relation["keywords"]
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in edges_data] + already_description))
)
keywords = flat_uniq_list(edges_data, "keywords") + already_keywords
source_id = flat_uniq_list(edges_data, "source_id") + already_source_ids
for need_insert_id in [src_id, tgt_id]:
if self._get_entity_(need_insert_id):
continue
self._set_entity_(need_insert_id, {
"source_id": source_id,
"description": description,
"entity_type": 'UNKNOWN'
})
description = await self._handle_entity_relation_summary(
f"({src_id}, {tgt_id})", description
)
weight = sum([edge["weight"] for edge in edges_data])
description = GRAPH_FIELD_SEP.join(sorted(set([edge["description"] for edge in edges_data])))
description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description)
keywords = flat_uniq_list(edges_data, "keywords")
source_id = flat_uniq_list(edges_data, "source_id")
edge_data = dict(
src_id=src_id,
tgt_id=tgt_id,
@ -230,9 +187,41 @@ class Extractor:
weight=weight,
source_id=source_id
)
self._set_relation_(src_id, tgt_id, edge_data)
if all_relationships_data is not None:
all_relationships_data.append(edge_data)
all_relationships_data.append(edge_data)
async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange):
if len(nodes) <= 1:
return
change.added_updated_nodes.add(nodes[0])
change.removed_nodes.extend(nodes[1:])
nodes_set = set(nodes)
node0_attrs = graph.nodes[nodes[0]]
node0_neighbors = set(graph.neighbors(nodes[0]))
for node1 in nodes[1:]:
# Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged.
node1_attrs = graph.nodes[node1]
node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}"
for attr in ["keywords", "source_id"]:
node0_attrs[attr] = sorted(set(node0_attrs[attr].extend(node1_attrs[attr])))
for neighbor in graph.neighbors(node1):
change.removed_edges.add(get_from_to(node1, neighbor))
if neighbor not in nodes_set:
edge1_attrs = graph.get_edge_data(node1, neighbor)
if neighbor in node0_neighbors:
# Merge two edges
change.added_updated_edges.add(get_from_to(nodes[0], neighbor))
edge0_attrs = graph.get_edge_data(nodes[0], neighbor)
edge0_attrs["weight"] += edge1_attrs["weight"]
edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}"
edge0_attrs["keywords"] = list(set(edge0_attrs["keywords"].extend(edge1_attrs["keywords"])))
edge0_attrs["source_id"] = list(set(edge0_attrs["source_id"].extend(edge1_attrs["source_id"])))
edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"])
graph.add_edge(nodes[0], neighbor, **edge0_attrs)
else:
graph.add_edge(nodes[0], neighbor, **edge1_attrs)
graph.remove_node(node1)
node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"])
graph.nodes[nodes[0]].update(node0_attrs)
async def _handle_entity_relation_summary(
self,

View File

@ -6,7 +6,7 @@ Reference:
"""
import re
from typing import Any, Callable
from typing import Any
from dataclasses import dataclass
import tiktoken
import trio
@ -53,10 +53,6 @@ class GraphExtractor(Extractor):
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
input_text_key: str | None = None,
@ -66,7 +62,7 @@ class GraphExtractor(Extractor):
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
super().__init__(llm_invoker, language, entity_types)
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker

View File

@ -2,7 +2,7 @@
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
- [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/extract_graph.py)
"""
GRAPH_EXTRACTION_PROMPT = """

View File

@ -15,11 +15,11 @@
#
import json
import logging
from functools import partial
import networkx as nx
import trio
from api import settings
from api.utils import get_uuid
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@ -27,32 +27,15 @@ from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor
from graphrag.utils import (
graph_merge,
set_entity,
get_relation,
set_relation,
get_entity,
get_graph,
set_graph,
chunk_id,
update_nodes_pagerank_nhop_neighbour,
does_graph_contains,
get_graph_doc_ids,
tidy_graph,
GraphChange,
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN
def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool:
key = f"graphrag:{tenant_id}:{kb_id}"
ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
if not ok:
raise Exception(f"Faild to set the {key} to {doc_id}")
def graphrag_task_get(tenant_id, kb_id) -> str | None:
key = f"graphrag:{tenant_id}:{kb_id}"
doc_id = REDIS_CONN.get(key)
return doc_id
from rag.utils.redis_conn import RedisDistributedLock
async def run_graphrag(
@ -72,7 +55,7 @@ async def run_graphrag(
):
chunks.append(d["content_with_weight"])
graph, doc_ids = await update_graph(
subgraph = await generate_subgraph(
LightKGExt
if row["parser_config"]["graphrag"]["method"] != "general"
else GeneralKGExt,
@ -86,14 +69,26 @@ async def run_graphrag(
embedding_model,
callback,
)
if not graph:
new_graph = None
if subgraph:
new_graph = await merge_subgraph(
tenant_id,
kb_id,
doc_id,
subgraph,
embedding_model,
callback,
)
if not with_resolution or not with_community:
return
if with_resolution or with_community:
graphrag_task_set(tenant_id, kb_id, doc_id)
if with_resolution:
if new_graph is None:
new_graph = await get_graph(tenant_id, kb_id)
if with_resolution and new_graph is not None:
await resolve_entities(
graph,
doc_ids,
new_graph,
tenant_id,
kb_id,
doc_id,
@ -101,10 +96,9 @@ async def run_graphrag(
embedding_model,
callback,
)
if with_community:
if with_community and new_graph is not None:
await extract_community(
graph,
doc_ids,
new_graph,
tenant_id,
kb_id,
doc_id,
@ -117,7 +111,7 @@ async def run_graphrag(
return
async def update_graph(
async def generate_subgraph(
extractor: Extractor,
tenant_id: str,
kb_id: str,
@ -131,34 +125,41 @@ async def update_graph(
):
contains = await does_graph_contains(tenant_id, kb_id, doc_id)
if contains:
callback(msg=f"Graph already contains {doc_id}, cancel myself")
return None, None
callback(msg=f"Graph already contains {doc_id}")
return None
start = trio.current_time()
ext = extractor(
llm_bdl,
language=language,
entity_types=entity_types,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
)
ents, rels = await ext(doc_id, chunks, callback)
subgraph = nx.Graph()
for en in ents:
subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])
for ent in ents:
assert "description" in ent, f"entity {ent} does not have description"
ent["source_id"] = [doc_id]
subgraph.add_node(ent["entity_name"], **ent)
ignored_rels = 0
for rel in rels:
assert "description" in rel, f"relation {rel} does not have description"
if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
ignored_rels += 1
continue
rel["source_id"] = [doc_id]
subgraph.add_edge(
rel["src_id"],
rel["tgt_id"],
weight=rel["weight"],
# description=rel["description"]
**rel,
)
# TODO: infinity doesn't support array search
if ignored_rels:
callback(msg=f"ignored {ignored_rels} relations due to missing entities.")
tidy_graph(subgraph, callback)
subgraph.graph["source_id"] = [doc_id]
chunk = {
"content_with_weight": json.dumps(
nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False
),
"knowledge_graph_kwd": "subgraph",
"kb_id": kb_id,
@ -167,6 +168,11 @@ async def update_graph(
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id
)
)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.insert(
[{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
@ -174,39 +180,49 @@ async def update_graph(
)
now = trio.current_time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
start = now
return subgraph
async def merge_subgraph(
tenant_id: str,
kb_id: str,
doc_id: str,
subgraph: nx.Graph,
embedding_model,
callback,
):
graphrag_task_lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600)
while True:
if graphrag_task_lock.acquire():
break
callback(msg=f"merge_subgraph {doc_id} is waiting graphrag_task_lock")
await trio.sleep(10)
start = trio.current_time()
change = GraphChange()
old_graph = await get_graph(tenant_id, kb_id)
if old_graph is not None:
logging.info("Merge with an exiting graph...................")
tidy_graph(old_graph, callback)
new_graph = graph_merge(old_graph, subgraph, change)
else:
new_graph = subgraph
now_docids = set([doc_id])
old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
if old_graph is not None:
logging.info("Merge with an exiting graph...................")
new_graph = graph_merge(old_graph, subgraph)
await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
if old_doc_ids:
for old_doc_id in old_doc_ids:
now_docids.add(old_doc_id)
old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
if delta_doc_ids:
callback(
msg="The global graph has changed during merging, try again"
)
await trio.sleep(1)
continue
break
await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
change.added_updated_nodes = set(new_graph.nodes())
change.added_updated_edges = set(new_graph.edges())
pr = nx.pagerank(new_graph)
for node_name, pagerank in pr.items():
new_graph.nodes[node_name]["pagerank"] = pagerank
await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
graphrag_task_lock.release()
now = trio.current_time()
callback(
msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
)
return new_graph, now_docids
return new_graph
async def resolve_entities(
graph,
doc_ids,
tenant_id: str,
kb_id: str,
doc_id: str,
@ -214,74 +230,30 @@ async def resolve_entities(
embed_bdl,
callback,
):
working_doc_id = graphrag_task_get(tenant_id, kb_id)
if doc_id != working_doc_id:
callback(
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
)
return
graphrag_task_lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600)
while True:
if graphrag_task_lock.acquire():
break
await trio.sleep(10)
start = trio.current_time()
er = EntityResolution(
llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
)
reso = await er(graph, callback=callback)
graph = reso.graph
callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
change = reso.change
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
callback(msg="Graph resolution updated pagerank.")
working_doc_id = graphrag_task_get(tenant_id, kb_id)
if doc_id != working_doc_id:
callback(
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
)
return
await set_graph(tenant_id, kb_id, graph, doc_ids)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"from_entity_kwd": reso.removed_entities,
},
search.index_name(tenant_id),
kb_id,
)
)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"to_entity_kwd": reso.removed_entities,
},
search.index_name(tenant_id),
kb_id,
)
)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{
"knowledge_graph_kwd": "entity",
"kb_id": kb_id,
"entity_kwd": reso.removed_entities,
},
search.index_name(tenant_id),
kb_id,
)
)
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
graphrag_task_lock.release()
now = trio.current_time()
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
async def extract_community(
graph,
doc_ids,
tenant_id: str,
kb_id: str,
doc_id: str,
@ -289,49 +261,34 @@ async def extract_community(
embed_bdl,
callback,
):
working_doc_id = graphrag_task_get(tenant_id, kb_id)
if doc_id != working_doc_id:
callback(
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
)
return
graphrag_task_lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600)
while True:
if graphrag_task_lock.acquire():
break
await trio.sleep(10)
start = trio.current_time()
ext = CommunityReportsExtractor(
llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
)
cr = await ext(graph, callback=callback)
community_structure = cr.structured_output
community_reports = cr.output
working_doc_id = graphrag_task_get(tenant_id, kb_id)
if doc_id != working_doc_id:
callback(
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
)
return
await set_graph(tenant_id, kb_id, graph, doc_ids)
doc_ids = graph.graph["source_id"]
now = trio.current_time()
callback(
msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
)
start = now
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
search.index_name(tenant_id),
kb_id,
)
)
chunks = []
for stru, rep in zip(community_structure, community_reports):
obj = {
"report": rep,
"evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
}
chunk = {
"id": get_uuid(),
"docnm_kwd": stru["title"],
"title_tks": rag_tokenizer.tokenize(stru["title"]),
"content_with_weight": json.dumps(obj, ensure_ascii=False),
@ -349,17 +306,23 @@ async def extract_community(
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
chunk["content_ltks"]
)
# try:
# ebd, _ = embed_bdl.encode([", ".join(community["entities"])])
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
# except Exception as e:
# logging.exception(f"Fail to embed entity relation: {e}")
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.insert(
[{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)
)
)
chunks.append(chunk)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
search.index_name(tenant_id),
kb_id,
)
)
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)
graphrag_task_lock.release()
now = trio.current_time()
callback(
msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."

View File

@ -100,7 +100,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
logging.debug(
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
)
if not graph.nodes():
nodes = set(graph.nodes())
if not nodes:
return {}
node_id_to_community_map = _compute_leiden_communities(
@ -120,7 +121,7 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
result = {}
results_by_level[level] = result
for node_id, raw_community_id in node_id_to_community_map[level].items():
if node_id not in graph.nodes:
if node_id not in nodes:
logging.warning(f"Node {node_id} not found in the graph.")
continue
community_id = str(raw_community_id)

View File

@ -5,7 +5,7 @@ Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import re
from typing import Any, Callable
from typing import Any
from dataclasses import dataclass
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
from graphrag.light.graph_prompt import PROMPTS
@ -33,14 +33,10 @@ class GraphExtractor(Extractor):
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
example_number: int = 2,
max_gleanings: int | None = None,
):
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
super().__init__(llm_invoker, language, entity_types)
"""Init method definition."""
self._max_gleanings = (
max_gleanings

View File

@ -1,7 +1,7 @@
# Licensed under the MIT License
"""
Reference:
- [LightRag](https://github.com/HKUDS/LightRAG)
- [LightRAG](https://github.com/HKUDS/LightRAG/blob/main/lightrag/prompt.py)
"""

View File

@ -12,26 +12,37 @@ import logging
import re
import time
from collections import defaultdict
from copy import deepcopy
from hashlib import md5
from typing import Any, Callable
import os
import trio
from typing import Set, Tuple
import networkx as nx
import numpy as np
import xxhash
from networkx.readwrite import json_graph
import dataclasses
from api import settings
from api.utils import get_uuid
from rag.nlp import search, rag_tokenizer
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
GRAPH_FIELD_SEP = "<SEP>"
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
@dataclasses.dataclass
class GraphChange:
removed_nodes: Set[str] = dataclasses.field(default_factory=set)
added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
def perform_variable_replacements(
input: str, history: list[dict] | None = None, variables: dict | None = None
) -> str:
@ -146,24 +157,74 @@ def set_tags_to_cache(kb_ids, tags):
k = hasher.hexdigest()
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
def tidy_graph(graph: nx.Graph, callback):
"""
Ensure all nodes and edges in the graph have some essential attribute.
"""
def is_valid_node(node_attrs: dict) -> bool:
valid_node = True
for attr in ["description", "source_id"]:
if attr not in node_attrs:
valid_node = False
break
return valid_node
purged_nodes = []
for node, node_attrs in graph.nodes(data=True):
if not is_valid_node(node_attrs):
purged_nodes.append(node)
for node in purged_nodes:
graph.remove_node(node)
if purged_nodes and callback:
callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")
def graph_merge(g1, g2):
g = g2.copy()
for n, attr in g1.nodes(data=True):
if n not in g2.nodes():
g.add_node(n, **attr)
purged_edges = []
for source, target, attr in graph.edges(data=True):
if not is_valid_node(attr):
purged_edges.append((source, target))
if "keywords" not in attr:
attr["keywords"] = []
for source, target in purged_edges:
graph.remove_edge(source, target)
if purged_edges and callback:
callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")
def get_from_to(node1, node2):
if node1 < node2:
return (node1, node2)
else:
return (node2, node1)
def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
"""Merge graph g2 into g1 in place."""
for node_name, attr in g2.nodes(data=True):
change.added_updated_nodes.add(node_name)
if not g1.has_node(node_name):
g1.add_node(node_name, **attr)
continue
node = g1.nodes[node_name]
node["description"] += GRAPH_FIELD_SEP + attr["description"]
# A node's source_id indicates which chunks it came from.
node["source_id"] += attr["source_id"]
for source, target, attr in g1.edges(data=True):
if g.has_edge(source, target):
g[source][target].update({"weight": attr.get("weight", 0)+1})
for source, target, attr in g2.edges(data=True):
change.added_updated_edges.add(get_from_to(source, target))
edge = g1.get_edge_data(source, target)
if edge is None:
g1.add_edge(source, target, **attr)
continue
g.add_edge(source, target)#, **attr)
for node_degree in g.degree:
g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
return g
edge["weight"] += attr.get("weight", 0)
edge["description"] += GRAPH_FIELD_SEP + attr["description"]
edge["keywords"] += attr["keywords"]
# A edge's source_id indicates which chunks it came from.
edge["source_id"] += attr["source_id"]
for node_degree in g1.degree:
g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
# A graph's source_id indicates which documents it came from.
if "source_id" not in g1.graph:
g1.graph["source_id"] = []
g1.graph["source_id"] += g2.graph.get("source_id", [])
return g1
def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()
@ -237,55 +298,10 @@ def is_float_regex(value):
def chunk_id(chunk):
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
def get_entity_cache(tenant_id, kb_id, ent_name) -> str | list[str]:
hasher = xxhash.xxh64()
hasher.update(str(tenant_id).encode("utf-8"))
hasher.update(str(kb_id).encode("utf-8"))
hasher.update(str(ent_name).encode("utf-8"))
k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
if not bin:
return
return json.loads(bin)
def set_entity_cache(tenant_id, kb_id, ent_name, content_with_weight):
hasher = xxhash.xxh64()
hasher.update(str(tenant_id).encode("utf-8"))
hasher.update(str(kb_id).encode("utf-8"))
hasher.update(str(ent_name).encode("utf-8"))
k = hasher.hexdigest()
REDIS_CONN.set(k, content_with_weight.encode("utf-8"), 3600)
def get_entity(tenant_id, kb_id, ent_name):
cache = get_entity_cache(tenant_id, kb_id, ent_name)
if cache:
return cache
conds = {
"fields": ["content_with_weight"],
"entity_kwd": ent_name,
"size": 10000,
"knowledge_graph_kwd": ["entity"]
}
res = []
es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
for id in es_res.ids:
try:
if isinstance(ent_name, str):
set_entity_cache(tenant_id, kb_id, ent_name, es_res.field[id]["content_with_weight"])
return json.loads(es_res.field[id]["content_with_weight"])
res.append(json.loads(es_res.field[id]["content_with_weight"]))
except Exception:
continue
return res
def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
chunk = {
"id": get_uuid(),
"important_kwd": [ent_name],
"title_tks": rag_tokenizer.tokenize(ent_name),
"entity_kwd": ent_name,
@ -293,28 +309,19 @@ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
"entity_type_kwd": meta["entity_type"],
"content_with_weight": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"source_id": list(set(meta["source_id"])),
"source_id": meta["source_id"],
"kb_id": kb_id,
"available_int": 0
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
set_entity_cache(tenant_id, kb_id, ent_name, chunk["content_with_weight"])
res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id])
if res.ids:
settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id)
else:
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None:
try:
ebd, _ = embd_mdl.encode([ent_name])
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
except Exception as e:
logging.exception(f"Fail to embed entity: {e}")
if ebd is not None:
chunk["q_%d_vec" % len(ebd)] = ebd
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None:
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)
def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
@ -344,40 +351,30 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
return res
def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
chunk = {
"id": get_uuid(),
"from_entity_kwd": from_ent_name,
"to_entity_kwd": to_ent_name,
"knowledge_graph_kwd": "relation",
"content_with_weight": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"important_kwd": meta["keywords"],
"source_id": list(set(meta["source_id"])),
"source_id": meta["source_id"],
"weight_int": int(meta["weight"]),
"kb_id": kb_id,
"available_int": 0
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
res = settings.retrievaler.search({"from_entity_kwd": to_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id])
if res.ids:
settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name},
chunk,
search.index_name(tenant_id), kb_id)
else:
txt = f"{from_ent_name}->{to_ent_name}"
ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None:
try:
ebd, _ = embd_mdl.encode([txt+f": {meta['description']}"])
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd)
except Exception as e:
logging.exception(f"Fail to embed entity relation: {e}")
if ebd is not None:
chunk["q_%d_vec" % len(ebd)] = ebd
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
txt = f"{from_ent_name}->{to_ent_name}"
ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None:
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)
async def does_graph_contains(tenant_id, kb_id, doc_id):
# Get doc_ids of graph
@ -418,33 +415,68 @@ async def get_graph(tenant_id, kb_id):
}
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
if res.total == 0:
return None, []
return None
for id in res.ids:
try:
return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
res.field[id]["source_id"]
g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
if "source_id" not in g.graph:
g.graph["source_id"] = res.field[id]["source_id"]
return g
except Exception:
continue
result = await rebuild_graph(tenant_id, kb_id)
return result
async def set_graph(tenant_id, kb_id, graph, docids):
chunk = {
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
indent=2),
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
start = trio.current_time()
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph"]}, search.index_name(tenant_id), kb_id))
if change.removed_nodes:
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id))
if change.removed_edges:
async with trio.open_nursery() as nursery:
for from_node, to_node in change.removed_edges:
nursery.start_soon(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id))
now = trio.current_time()
if callback:
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
start = now
chunks = [{
"id": get_uuid(),
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
"knowledge_graph_kwd": "graph",
"kb_id": kb_id,
"source_id": list(docids),
"source_id": graph.graph.get("source_id", []),
"available_int": 0,
"removed_kwd": "N"
}
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]))
if res.ids:
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
search.index_name(tenant_id), kb_id))
else:
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
}]
async with trio.open_nursery() as nursery:
for node in change.added_updated_nodes:
node_attrs = graph.nodes[node]
nursery.start_soon(lambda: graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks))
for from_node, to_node in change.added_updated_edges:
edge_attrs = graph.edges[from_node, to_node]
nursery.start_soon(lambda: graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks))
now = trio.current_time()
if callback:
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
start = now
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "entity", "relation"]}, search.index_name(tenant_id), kb_id))
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)
now = trio.current_time()
if callback:
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
def is_continuous_subsequence(subseq, seq):
@ -489,67 +521,6 @@ def merge_tuples(list1, list2):
return result
async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
def n_neighbor(id):
nonlocal graph, n_hop
count = 0
source_edge = list(graph.edges(id))
if not source_edge:
return []
count = count + 1
while count < n_hop:
count = count + 1
sc_edge = deepcopy(source_edge)
source_edge = []
for pair in sc_edge:
append_edge = list(graph.edges(pair[-1]))
for tuples in merge_tuples([pair], append_edge):
source_edge.append(tuples)
nbrs = []
for path in source_edge:
n = {"path": path, "weights": []}
wts = nx.get_edge_attributes(graph, 'weight')
for i in range(len(path)-1):
f, t = path[i], path[i+1]
n["weights"].append(wts.get((f, t), 0))
nbrs.append(n)
return nbrs
pr = nx.pagerank(graph)
try:
async with trio.open_nursery() as nursery:
for n, p in pr.items():
graph.nodes[n]["pagerank"] = p
nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
{"rank_flt": p,
"n_hop_with_weight": json.dumps((n), ensure_ascii=False)},
search.index_name(tenant_id), kb_id)))
except Exception as e:
logging.exception(e)
ty2ents = defaultdict(list)
for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
ty = graph.nodes[p].get("entity_type")
if not ty or len(ty2ents[ty]) > 12:
continue
ty2ents[ty].append(p)
chunk = {
"content_with_weight": json.dumps(ty2ents, ensure_ascii=False),
"kb_id": kb_id,
"knowledge_graph_kwd": "ty2ents",
"available_int": 0
}
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id]))
if res.ids:
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
chunk,
search.index_name(tenant_id), kb_id))
else:
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
async def get_entity_type2sampels(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
"size": 10000,
@ -584,33 +555,46 @@ def flat_uniq_list(arr, key):
async def rebuild_graph(tenant_id, kb_id):
graph = nx.Graph()
src_ids = []
flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
src_ids = set()
flds = ["entity_kwd", "from_entity_kwd", "to_entity_kwd", "knowledge_graph_kwd", "content_with_weight", "source_id"]
bs = 256
for i in range(0, 39*bs, bs):
for i in range(0, 1024*bs, bs):
es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
{"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
{"kb_id": kb_id, "knowledge_graph_kwd": ["entity"]},
[],
OrderByExpr(),
i, bs, search.index_name(tenant_id), [kb_id]
))
tot = settings.docStoreConn.getTotal(es_res)
if tot == 0:
return None, None
break
es_res = settings.docStoreConn.getFields(es_res, flds)
for id, d in es_res.items():
src_ids.extend(d.get("source_id", []))
if d["knowledge_graph_kwd"] == "entity":
graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
elif "from_entity_kwd" in d and "to_entity_kwd" in d:
graph.add_edge(
d["from_entity_kwd"],
d["to_entity_kwd"],
weight=int(d["weight_int"])
)
assert d["knowledge_graph_kwd"] == "relation"
src_ids.update(d.get("source_id", []))
attrs = json.load(d["content_with_weight"])
graph.add_node(d["entity_kwd"], **attrs)
if len(es_res.keys()) < 128:
return graph, list(set(src_ids))
for i in range(0, 1024*bs, bs):
es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
{"kb_id": kb_id, "knowledge_graph_kwd": ["relation"]},
[],
OrderByExpr(),
i, bs, search.index_name(tenant_id), [kb_id]
))
tot = settings.docStoreConn.getTotal(es_res)
if tot == 0:
return None
return graph, list(set(src_ids))
es_res = settings.docStoreConn.getFields(es_res, flds)
for id, d in es_res.items():
assert d["knowledge_graph_kwd"] == "relation"
src_ids.update(d.get("source_id", []))
if graph.has_node(d["from_entity_kwd"]) and graph.has_node(d["to_entity_kwd"]):
attrs = json.load(d["content_with_weight"])
graph.add_edge(d["from_entity_kwd"], d["to_entity_kwd"], **attrs)
src_ids = sorted(src_ids)
graph.graph["source_id"] = src_ids
return graph

View File

@ -125,6 +125,7 @@ dependencies = [
"xxhash>=3.5.0,<4.0.0",
"trio>=0.29.0",
"langfuse>=2.60.0",
"debugpy>=1.8.13",
]
[project.optional-dependencies]

View File

@ -517,6 +517,8 @@ async def do_handle_task(task):
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
# Either using graphrag or Standard chunking methods
elif task.get("task_type", "") == "graphrag":
global task_limiter
task_limiter = trio.CapacityLimiter(2)
graphrag_conf = task_parser_config.get("graphrag", {})
if not graphrag_conf.get("use_graphrag", False):
return

View File

@ -172,6 +172,12 @@ class InfinityConnection(DocStoreConnection):
ConflictType.Ignore,
)
def field_keyword(self, field_name: str):
# The "docnm_kwd" field is always a string, not list.
if field_name == "source_id" or (field_name.endswith("_kwd") and field_name != "docnm_kwd"):
return True
return False
"""
Database operations
"""
@ -480,9 +486,11 @@ class InfinityConnection(DocStoreConnection):
assert "_id" not in d
assert "id" in d
for k, v in d.items():
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
assert isinstance(v, list)
d[k] = "###".join(v)
if self.field_keyword(k):
if isinstance(v, list):
d[k] = "###".join(v)
else:
d[k] = v
elif re.search(r"_feas$", k):
d[k] = json.dumps(v)
elif k == 'kb_id':
@ -495,6 +503,8 @@ class InfinityConnection(DocStoreConnection):
elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list)
d[k] = "_".join(f"{num:08x}" for num in v)
else:
d[k] = v
for n, vs in embedding_clmns:
if n in d:
@ -525,13 +535,13 @@ class InfinityConnection(DocStoreConnection):
# del condition["exists"]
filter = equivalent_condition_to_str(condition, table_instance)
for k, v in list(newValue.items()):
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
assert isinstance(v, list)
newValue[k] = "###".join(v)
if self.field_keyword(k):
if isinstance(v, list):
newValue[k] = "###".join(v)
else:
newValue[k] = v
elif re.search(r"_feas$", k):
newValue[k] = json.dumps(v)
elif k.endswith("_kwd") and isinstance(v, list):
newValue[k] = " ".join(v)
elif k == 'kb_id':
if isinstance(newValue[k], list):
newValue[k] = newValue[k][0] # since d[k] is a list, but we need a str
@ -546,6 +556,8 @@ class InfinityConnection(DocStoreConnection):
del newValue[k]
if v in [PAGERANK_FLD]:
newValue[v] = 0
else:
newValue[k] = v
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
table_instance.update(filter, newValue)
@ -600,7 +612,7 @@ class InfinityConnection(DocStoreConnection):
for column in res2.columns:
k = column.lower()
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
if self.field_keyword(k):
res2[column] = res2[column].apply(lambda v:[kwd for kwd in v.split("###") if kwd])
elif k == "position_int":
def to_position_int(v):

View File

@ -319,9 +319,3 @@ class RedisDistributedLock:
def release(self):
return self.lock.release()
def __enter__(self):
self.acquire()
def __exit__(self, exception_type, exception_value, exception_traceback):
self.release()

57
uv.lock generated
View File

@ -1100,6 +1100,27 @@ version = "0.8.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" }
[[package]]
name = "debugpy"
version = "1.8.13"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/51/d4/f35f539e11c9344652f362c22413ec5078f677ac71229dc9b4f6f85ccaa3/debugpy-1.8.13.tar.gz", hash = "sha256:837e7bef95bdefba426ae38b9a94821ebdc5bea55627879cd48165c90b9e50ce" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/3f/32/901c7204cceb3262fdf38f4c25c9a46372c11661e8490e9ea702bc4ff448/debugpy-1.8.13-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:06859f68e817966723ffe046b896b1bd75c665996a77313370336ee9e1de3e90" },
{ url = "https://mirrors.aliyun.com/pypi/packages/95/10/77fe746851c8d84838a807da60c7bd0ac8627a6107d6917dd3293bf8628c/debugpy-1.8.13-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb56c2db69fb8df3168bc857d7b7d2494fed295dfdbde9a45f27b4b152f37520" },
{ url = "https://mirrors.aliyun.com/pypi/packages/a1/ef/28f8db2070e453dda0e49b356e339d0b4e1d38058d4c4ea9e88cdc8ee8e7/debugpy-1.8.13-cp310-cp310-win32.whl", hash = "sha256:46abe0b821cad751fc1fb9f860fb2e68d75e2c5d360986d0136cd1db8cad4428" },
{ url = "https://mirrors.aliyun.com/pypi/packages/89/16/1d53a80caf5862627d3eaffb217d4079d7e4a1df6729a2d5153733661efd/debugpy-1.8.13-cp310-cp310-win_amd64.whl", hash = "sha256:dc7b77f5d32674686a5f06955e4b18c0e41fb5a605f5b33cf225790f114cfeec" },
{ url = "https://mirrors.aliyun.com/pypi/packages/31/90/dd2fcad8364f0964f476537481985198ce6e879760281ad1cec289f1aa71/debugpy-1.8.13-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:eee02b2ed52a563126c97bf04194af48f2fe1f68bb522a312b05935798e922ff" },
{ url = "https://mirrors.aliyun.com/pypi/packages/5c/c9/06ff65f15eb30dbdafd45d1575770b842ce3869ad5580a77f4e5590f1be7/debugpy-1.8.13-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4caca674206e97c85c034c1efab4483f33971d4e02e73081265ecb612af65377" },
{ url = "https://mirrors.aliyun.com/pypi/packages/3b/49/798a4092bde16a4650f17ac5f2301d4d37e1972d65462fb25c80a83b4790/debugpy-1.8.13-cp311-cp311-win32.whl", hash = "sha256:7d9a05efc6973b5aaf076d779cf3a6bbb1199e059a17738a2aa9d27a53bcc888" },
{ url = "https://mirrors.aliyun.com/pypi/packages/cd/d5/3684d7561c8ba2797305cf8259619acccb8d6ebe2117bb33a6897c235eee/debugpy-1.8.13-cp311-cp311-win_amd64.whl", hash = "sha256:62f9b4a861c256f37e163ada8cf5a81f4c8d5148fc17ee31fb46813bd658cdcc" },
{ url = "https://mirrors.aliyun.com/pypi/packages/79/ad/dff929b6b5403feaab0af0e5bb460fd723f9c62538b718a9af819b8fff20/debugpy-1.8.13-cp312-cp312-macosx_14_0_universal2.whl", hash = "sha256:2b8de94c5c78aa0d0ed79023eb27c7c56a64c68217d881bee2ffbcb13951d0c1" },
{ url = "https://mirrors.aliyun.com/pypi/packages/d6/4f/b7d42e6679f0bb525888c278b0c0d2b6dff26ed42795230bb46eaae4f9b3/debugpy-1.8.13-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:887d54276cefbe7290a754424b077e41efa405a3e07122d8897de54709dbe522" },
{ url = "https://mirrors.aliyun.com/pypi/packages/ec/18/d9b3e88e85d41f68f77235112adc31012a784e45a3fcdbb039777d570a0f/debugpy-1.8.13-cp312-cp312-win32.whl", hash = "sha256:3872ce5453b17837ef47fb9f3edc25085ff998ce63543f45ba7af41e7f7d370f" },
{ url = "https://mirrors.aliyun.com/pypi/packages/c9/f7/0df18a4f530ed3cc06f0060f548efe9e3316102101e311739d906f5650be/debugpy-1.8.13-cp312-cp312-win_amd64.whl", hash = "sha256:63ca7670563c320503fea26ac688988d9d6b9c6a12abc8a8cf2e7dd8e5f6b6ea" },
{ url = "https://mirrors.aliyun.com/pypi/packages/37/4f/0b65410a08b6452bfd3f7ed6f3610f1a31fb127f46836e82d31797065dcb/debugpy-1.8.13-py2.py3-none-any.whl", hash = "sha256:d4ba115cdd0e3a70942bd562adba9ec8c651fe69ddde2298a1be296fc331906f" },
]
[[package]]
name = "decorator"
version = "5.2.1"
@ -1375,17 +1396,17 @@ name = "fastembed-gpu"
version = "0.3.6"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "loguru" },
{ name = "mmh3" },
{ name = "numpy" },
{ name = "onnxruntime-gpu" },
{ name = "pillow" },
{ name = "pystemmer" },
{ name = "requests" },
{ name = "snowballstemmer" },
{ name = "tokenizers" },
{ name = "tqdm" },
{ name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "loguru", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "mmh3", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "onnxruntime-gpu", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "pystemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "snowballstemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "tokenizers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "tqdm", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" }
wheels = [
@ -3531,12 +3552,12 @@ name = "onnxruntime-gpu"
version = "1.19.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
{ name = "coloredlogs" },
{ name = "flatbuffers" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "protobuf" },
{ name = "sympy" },
{ name = "coloredlogs", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" },
@ -4746,6 +4767,7 @@ dependencies = [
{ name = "crawl4ai" },
{ name = "dashscope" },
{ name = "datrie" },
{ name = "debugpy" },
{ name = "deepl" },
{ name = "demjson3" },
{ name = "discord-py" },
@ -4877,6 +4899,7 @@ requires-dist = [
{ name = "crawl4ai", specifier = "==0.3.8" },
{ name = "dashscope", specifier = "==1.20.11" },
{ name = "datrie", specifier = "==0.8.2" },
{ name = "debugpy", specifier = ">=1.8.13" },
{ name = "deepl", specifier = "==1.18.0" },
{ name = "demjson3", specifier = "==3.0.6" },
{ name = "discord-py", specifier = "==2.3.2" },