mirror of https://git-proxy.hk.martin98.com/https://github.com/infiniflow/ragflow.git (synced 2025-04-10 15:40:35 +08:00)

Optimize graphrag again (#6513)

### What problem does this PR solve?

Removed set_entity and set_relation to avoid accessing the doc engine during graph computation. Introduced GraphChange to avoid writing unchanged chunks.

### Type of change

- [x] Performance Improvement

This commit is contained in:
parent 7a677cb095
commit 6bf26e2a81
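To make the change-tracking idea concrete, here is a minimal, hedged sketch. Only the GraphChange fields mirror the dataclass this PR adds in graphrag/utils.py; merge_nodes_sketch and write_changed_chunks are illustrative assumptions, not functions from the repository. The point is that graph computation records which nodes and edges were added, updated, or removed, and only those items are later rewritten to the doc engine instead of re-indexing every chunk.

```python
import dataclasses
from typing import Set, Tuple

import networkx as nx


@dataclasses.dataclass
class GraphChange:
    # Mirrors the dataclass added by this PR in graphrag/utils.py.
    removed_nodes: Set[str] = dataclasses.field(default_factory=set)
    added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
    removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
    added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)


def merge_nodes_sketch(graph: nx.Graph, nodes: list, change: GraphChange) -> None:
    """Illustrative only: collapse nodes[1:] into nodes[0] and record what changed."""
    keep, drop = nodes[0], set(nodes[1:])
    change.added_updated_nodes.add(keep)
    change.removed_nodes.update(drop)
    for node in drop:
        for neighbor in list(graph.neighbors(node)):
            # Every edge touching a removed node disappears from the index ...
            change.removed_edges.add(tuple(sorted((node, neighbor))))
            # ... and edges to surviving outside neighbors are re-attached to the kept node.
            if neighbor != keep and neighbor not in drop:
                change.added_updated_edges.add(tuple(sorted((keep, neighbor))))
                graph.add_edge(keep, neighbor)
        graph.remove_node(node)


def write_changed_chunks(change: GraphChange) -> None:
    """Hypothetical sink: delete removed items, upsert only the touched ones."""
    print(f"delete {len(change.removed_nodes)} nodes, {len(change.removed_edges)} edges")
    print(f"upsert {len(change.added_updated_nodes)} nodes, {len(change.added_updated_edges)} edges")
```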
@ -47,6 +47,8 @@ from rag.utils.redis_conn import RedisDistributedLock

stop_event = threading.Event()

RAGFLOW_DEBUGPY_LISTEN = int(os.environ.get('RAGFLOW_DEBUGPY_LISTEN', "0"))

def update_progress():
lock_value = str(uuid.uuid4())
redis_lock = RedisDistributedLock("update_progress", lock_value=lock_value, timeout=60)
@ -85,6 +87,11 @@ if __name__ == '__main__':
settings.init_settings()
print_rag_settings()

if RAGFLOW_DEBUGPY_LISTEN > 0:
logging.info(f"debugpy listen on {RAGFLOW_DEBUGPY_LISTEN}")
import debugpy
debugpy.listen(("0.0.0.0", RAGFLOW_DEBUGPY_LISTEN))

# init db
init_web_db()
init_web_data()

@ -8,7 +8,7 @@
"docnm_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"title_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
@ -27,16 +27,16 @@
"rank_int": {"type": "integer", "default": 0},
"rank_flt": {"type": "float", "default": 0},
"available_int": {"type": "integer", "default": 1},
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"pagerank_fea": {"type": "integer", "default": 0},
"tag_feas": {"type": "varchar", "default": ""},

"from_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"to_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"source_id": {"type": "varchar", "default": ""},
"from_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"to_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"source_id": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"n_hop_with_weight": {"type": "varchar", "default": ""},
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}
}
@ -12,6 +12,8 @@ services:
- ${SVR_HTTP_PORT}:9380
- 80:80
- 443:443
- 5678:5678
- 5679:5679
volumes:
- ./ragflow-logs:/ragflow/logs
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
@ -16,7 +16,6 @@
|
||||
import logging
|
||||
import itertools
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
@ -28,7 +27,7 @@ from rag.nlp import is_english
|
||||
import editdistance
|
||||
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements, chat_limiter
|
||||
from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
|
||||
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
||||
@ -39,7 +38,7 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
|
||||
class EntityResolutionResult:
|
||||
"""Entity resolution result class definition."""
|
||||
graph: nx.Graph
|
||||
removed_entities: list
|
||||
change: GraphChange
|
||||
|
||||
|
||||
class EntityResolution(Extractor):
|
||||
@ -54,12 +53,8 @@ class EntityResolution(Extractor):
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
get_entity: Callable | None = None,
|
||||
set_entity: Callable | None = None,
|
||||
get_relation: Callable | None = None,
|
||||
set_relation: Callable | None = None
|
||||
):
|
||||
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
|
||||
super().__init__(llm_invoker)
|
||||
"""Init method definition."""
|
||||
self._llm = llm_invoker
|
||||
self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
|
||||
@ -84,8 +79,8 @@ class EntityResolution(Extractor):
|
||||
or DEFAULT_RESOLUTION_RESULT_DELIMITER,
|
||||
}
|
||||
|
||||
nodes = graph.nodes
|
||||
entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
|
||||
nodes = sorted(graph.nodes())
|
||||
entity_types = sorted(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
|
||||
node_clusters = {entity_type: [] for entity_type in entity_types}
|
||||
|
||||
for node in nodes:
|
||||
@ -105,54 +100,22 @@ class EntityResolution(Extractor):
|
||||
nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
|
||||
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
|
||||
|
||||
change = GraphChange()
|
||||
connect_graph = nx.Graph()
|
||||
removed_entities = []
|
||||
connect_graph.add_edges_from(resolution_result)
|
||||
all_entities_data = []
|
||||
all_relationships_data = []
|
||||
all_remove_nodes = []
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
||||
remove_nodes = list(sub_connect_graph.nodes)
|
||||
keep_node = remove_nodes.pop()
|
||||
all_remove_nodes.append(remove_nodes)
|
||||
nursery.start_soon(lambda: self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data))
|
||||
for remove_node in remove_nodes:
|
||||
removed_entities.append(remove_node)
|
||||
remove_node_neighbors = graph[remove_node]
|
||||
remove_node_neighbors = list(remove_node_neighbors)
|
||||
for remove_node_neighbor in remove_node_neighbors:
|
||||
rel = self._get_relation_(remove_node, remove_node_neighbor)
|
||||
if graph.has_edge(remove_node, remove_node_neighbor):
|
||||
graph.remove_edge(remove_node, remove_node_neighbor)
|
||||
if remove_node_neighbor == keep_node:
|
||||
if graph.has_edge(keep_node, remove_node):
|
||||
graph.remove_edge(keep_node, remove_node)
|
||||
continue
|
||||
if not rel:
|
||||
continue
|
||||
if graph.has_edge(keep_node, remove_node_neighbor):
|
||||
nursery.start_soon(lambda: self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data))
|
||||
else:
|
||||
pair = sorted([keep_node, remove_node_neighbor])
|
||||
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
|
||||
self._set_relation_(pair[0], pair[1],
|
||||
dict(
|
||||
src_id=pair[0],
|
||||
tgt_id=pair[1],
|
||||
weight=rel['weight'],
|
||||
description=rel['description'],
|
||||
keywords=[],
|
||||
source_id=rel.get("source_id", ""),
|
||||
metadata={"created_at": time.time()}
|
||||
))
|
||||
graph.remove_node(remove_node)
|
||||
merging_nodes = list(sub_connect_graph.nodes)
|
||||
nursery.start_soon(lambda: self._merge_graph_nodes(graph, merging_nodes, change))
|
||||
|
||||
# Update pagerank
|
||||
pr = nx.pagerank(graph)
|
||||
for node_name, pagerank in pr.items():
|
||||
graph.nodes[node_name]["pagerank"] = pagerank
|
||||
|
||||
return EntityResolutionResult(
|
||||
graph=graph,
|
||||
removed_entities=removed_entities
|
||||
change=change,
|
||||
)
|
||||
|
||||
async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
|
||||
|
@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
- [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/community_report.py)
|
||||
"""
|
||||
|
||||
COMMUNITY_REPORT_PROMPT = """
|
||||
|
@ -40,13 +40,9 @@ class CommunityReportsExtractor(Extractor):
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
get_entity: Callable | None = None,
|
||||
set_entity: Callable | None = None,
|
||||
get_relation: Callable | None = None,
|
||||
set_relation: Callable | None = None,
|
||||
max_report_length: int | None = None,
|
||||
):
|
||||
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
|
||||
super().__init__(llm_invoker)
|
||||
"""Init method definition."""
|
||||
self._llm = llm_invoker
|
||||
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
|
||||
@ -63,21 +59,28 @@ class CommunityReportsExtractor(Extractor):
|
||||
over, token_count = 0, 0
|
||||
async def extract_community_report(community):
|
||||
nonlocal res_str, res_dict, over, token_count
|
||||
cm_id, ents = community
|
||||
weight = ents["weight"]
|
||||
ents = ents["nodes"]
|
||||
ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()
|
||||
if ent_df.empty or "entity_name" not in ent_df.columns:
|
||||
cm_id, cm = community
|
||||
weight = cm["weight"]
|
||||
ents = cm["nodes"]
|
||||
if len(ents) < 2:
|
||||
return
|
||||
ent_df["entity"] = ent_df["entity_name"]
|
||||
del ent_df["entity_name"]
|
||||
rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
|
||||
if rela_df.empty:
|
||||
return
|
||||
rela_df["source"] = rela_df["src_id"]
|
||||
rela_df["target"] = rela_df["tgt_id"]
|
||||
del rela_df["src_id"]
|
||||
del rela_df["tgt_id"]
|
||||
ent_list = [{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents]
|
||||
ent_df = pd.DataFrame(ent_list)
|
||||
|
||||
rela_list = []
|
||||
k = 0
|
||||
for i in range(0, len(ents)):
|
||||
if k >= 10000:
|
||||
break
|
||||
for j in range(i + 1, len(ents)):
|
||||
if k >= 10000:
|
||||
break
|
||||
edge = graph.get_edge_data(ents[i], ents[j])
|
||||
if edge is None:
|
||||
continue
|
||||
rela_list.append({"source": ents[i], "target": ents[j], "description": edge["description"]})
|
||||
k += 1
|
||||
rela_df = pd.DataFrame(rela_list)
|
||||
|
||||
prompt_variables = {
|
||||
"entity_df": ent_df.to_csv(index_label="id"),
|
||||
|
@ -19,10 +19,11 @@ from collections import defaultdict, Counter
|
||||
from copy import deepcopy
|
||||
from typing import Callable
|
||||
import trio
|
||||
import networkx as nx
|
||||
|
||||
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
|
||||
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
|
||||
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter
|
||||
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.prompts import message_fit_in
|
||||
from rag.utils import truncate
|
||||
@ -40,18 +41,10 @@ class Extractor:
|
||||
llm_invoker: CompletionLLM,
|
||||
language: str | None = "English",
|
||||
entity_types: list[str] | None = None,
|
||||
get_entity: Callable | None = None,
|
||||
set_entity: Callable | None = None,
|
||||
get_relation: Callable | None = None,
|
||||
set_relation: Callable | None = None,
|
||||
):
|
||||
self._llm = llm_invoker
|
||||
self._language = language
|
||||
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
|
||||
self._get_entity_ = get_entity
|
||||
self._set_entity_ = set_entity
|
||||
self._get_relation_ = get_relation
|
||||
self._set_relation_ = set_relation
|
||||
|
||||
def _chat(self, system, history, gen_conf):
|
||||
hist = deepcopy(history)
|
||||
@ -152,25 +145,15 @@ class Extractor:
|
||||
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
|
||||
if not entities:
|
||||
return
|
||||
already_entity_types = []
|
||||
already_source_ids = []
|
||||
already_description = []
|
||||
|
||||
already_node = self._get_entity_(entity_name)
|
||||
if already_node:
|
||||
already_entity_types.append(already_node["entity_type"])
|
||||
already_source_ids.extend(already_node["source_id"])
|
||||
already_description.append(already_node["description"])
|
||||
|
||||
entity_type = sorted(
|
||||
Counter(
|
||||
[dp["entity_type"] for dp in entities] + already_entity_types
|
||||
[dp["entity_type"] for dp in entities]
|
||||
).items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)[0][0]
|
||||
description = GRAPH_FIELD_SEP.join(
|
||||
sorted(set([dp["description"] for dp in entities] + already_description))
|
||||
sorted(set([dp["description"] for dp in entities]))
|
||||
)
|
||||
already_source_ids = flat_uniq_list(entities, "source_id")
|
||||
description = await self._handle_entity_relation_summary(entity_name, description)
|
||||
@ -180,7 +163,6 @@ class Extractor:
|
||||
source_id=already_source_ids,
|
||||
)
|
||||
node_data["entity_name"] = entity_name
|
||||
self._set_entity_(entity_name, node_data)
|
||||
all_relationships_data.append(node_data)
|
||||
|
||||
async def _merge_edges(
|
||||
@ -192,36 +174,11 @@ class Extractor:
|
||||
):
|
||||
if not edges_data:
|
||||
return
|
||||
already_weights = []
|
||||
already_source_ids = []
|
||||
already_description = []
|
||||
already_keywords = []
|
||||
|
||||
relation = self._get_relation_(src_id, tgt_id)
|
||||
if relation:
|
||||
already_weights = [relation["weight"]]
|
||||
already_source_ids = relation["source_id"]
|
||||
already_description = [relation["description"]]
|
||||
already_keywords = relation["keywords"]
|
||||
|
||||
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
|
||||
description = GRAPH_FIELD_SEP.join(
|
||||
sorted(set([dp["description"] for dp in edges_data] + already_description))
|
||||
)
|
||||
keywords = flat_uniq_list(edges_data, "keywords") + already_keywords
|
||||
source_id = flat_uniq_list(edges_data, "source_id") + already_source_ids
|
||||
|
||||
for need_insert_id in [src_id, tgt_id]:
|
||||
if self._get_entity_(need_insert_id):
|
||||
continue
|
||||
self._set_entity_(need_insert_id, {
|
||||
"source_id": source_id,
|
||||
"description": description,
|
||||
"entity_type": 'UNKNOWN'
|
||||
})
|
||||
description = await self._handle_entity_relation_summary(
|
||||
f"({src_id}, {tgt_id})", description
|
||||
)
|
||||
weight = sum([edge["weight"] for edge in edges_data])
|
||||
description = GRAPH_FIELD_SEP.join(sorted(set([edge["description"] for edge in edges_data])))
|
||||
description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description)
|
||||
keywords = flat_uniq_list(edges_data, "keywords")
|
||||
source_id = flat_uniq_list(edges_data, "source_id")
|
||||
edge_data = dict(
|
||||
src_id=src_id,
|
||||
tgt_id=tgt_id,
|
||||
@ -230,9 +187,41 @@ class Extractor:
|
||||
weight=weight,
|
||||
source_id=source_id
|
||||
)
|
||||
self._set_relation_(src_id, tgt_id, edge_data)
|
||||
if all_relationships_data is not None:
|
||||
all_relationships_data.append(edge_data)
|
||||
all_relationships_data.append(edge_data)
|
||||
|
||||
async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange):
|
||||
if len(nodes) <= 1:
|
||||
return
|
||||
change.added_updated_nodes.add(nodes[0])
|
||||
change.removed_nodes.update(nodes[1:])
|
||||
nodes_set = set(nodes)
|
||||
node0_attrs = graph.nodes[nodes[0]]
|
||||
node0_neighbors = set(graph.neighbors(nodes[0]))
|
||||
for node1 in nodes[1:]:
|
||||
# Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged.
|
||||
node1_attrs = graph.nodes[node1]
|
||||
node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}"
|
||||
for attr in ["keywords", "source_id"]:
|
||||
node0_attrs[attr] = sorted(set(node0_attrs[attr] + node1_attrs[attr]))
|
||||
for neighbor in graph.neighbors(node1):
|
||||
change.removed_edges.add(get_from_to(node1, neighbor))
|
||||
if neighbor not in nodes_set:
|
||||
edge1_attrs = graph.get_edge_data(node1, neighbor)
|
||||
if neighbor in node0_neighbors:
|
||||
# Merge two edges
|
||||
change.added_updated_edges.add(get_from_to(nodes[0], neighbor))
|
||||
edge0_attrs = graph.get_edge_data(nodes[0], neighbor)
|
||||
edge0_attrs["weight"] += edge1_attrs["weight"]
|
||||
edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}"
|
||||
edge0_attrs["keywords"] = list(set(edge0_attrs["keywords"].extend(edge1_attrs["keywords"])))
|
||||
edge0_attrs["source_id"] = list(set(edge0_attrs["source_id"].extend(edge1_attrs["source_id"])))
|
||||
edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"])
|
||||
graph.add_edge(nodes[0], neighbor, **edge0_attrs)
|
||||
else:
|
||||
graph.add_edge(nodes[0], neighbor, **edge1_attrs)
|
||||
graph.remove_node(node1)
|
||||
node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"])
|
||||
graph.nodes[nodes[0]].update(node0_attrs)
|
||||
|
||||
async def _handle_entity_relation_summary(
|
||||
self,
|
||||
|
@ -6,7 +6,7 @@ Reference:
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
import tiktoken
|
||||
import trio
|
||||
@ -53,10 +53,6 @@ class GraphExtractor(Extractor):
|
||||
llm_invoker: CompletionLLM,
|
||||
language: str | None = "English",
|
||||
entity_types: list[str] | None = None,
|
||||
get_entity: Callable | None = None,
|
||||
set_entity: Callable | None = None,
|
||||
get_relation: Callable | None = None,
|
||||
set_relation: Callable | None = None,
|
||||
tuple_delimiter_key: str | None = None,
|
||||
record_delimiter_key: str | None = None,
|
||||
input_text_key: str | None = None,
|
||||
@ -66,7 +62,7 @@ class GraphExtractor(Extractor):
|
||||
max_gleanings: int | None = None,
|
||||
on_error: ErrorHandlerFn | None = None,
|
||||
):
|
||||
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
|
||||
super().__init__(llm_invoker, language, entity_types)
|
||||
"""Init method definition."""
|
||||
# TODO: streamline construction
|
||||
self._llm = llm_invoker
|
||||
|
@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
- [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/extract_graph.py)
|
||||
"""
|
||||
|
||||
GRAPH_EXTRACTION_PROMPT = """
|
||||
|
@ -15,11 +15,11 @@
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
from functools import partial
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from api import settings
|
||||
from api.utils import get_uuid
|
||||
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
|
||||
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
|
||||
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
|
||||
@ -27,32 +27,15 @@ from graphrag.entity_resolution import EntityResolution
|
||||
from graphrag.general.extractor import Extractor
|
||||
from graphrag.utils import (
|
||||
graph_merge,
|
||||
set_entity,
|
||||
get_relation,
|
||||
set_relation,
|
||||
get_entity,
|
||||
get_graph,
|
||||
set_graph,
|
||||
chunk_id,
|
||||
update_nodes_pagerank_nhop_neighbour,
|
||||
does_graph_contains,
|
||||
get_graph_doc_ids,
|
||||
tidy_graph,
|
||||
GraphChange,
|
||||
)
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool:
|
||||
key = f"graphrag:{tenant_id}:{kb_id}"
|
||||
ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
|
||||
if not ok:
|
||||
raise Exception(f"Faild to set the {key} to {doc_id}")
|
||||
|
||||
|
||||
def graphrag_task_get(tenant_id, kb_id) -> str | None:
|
||||
key = f"graphrag:{tenant_id}:{kb_id}"
|
||||
doc_id = REDIS_CONN.get(key)
|
||||
return doc_id
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
|
||||
async def run_graphrag(
|
||||
@ -72,7 +55,7 @@ async def run_graphrag(
|
||||
):
|
||||
chunks.append(d["content_with_weight"])
|
||||
|
||||
graph, doc_ids = await update_graph(
|
||||
subgraph = await generate_subgraph(
|
||||
LightKGExt
|
||||
if row["parser_config"]["graphrag"]["method"] != "general"
|
||||
else GeneralKGExt,
|
||||
@ -86,14 +69,26 @@ async def run_graphrag(
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if not graph:
|
||||
new_graph = None
|
||||
if subgraph:
|
||||
new_graph = await merge_subgraph(
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
subgraph,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
|
||||
if not with_resolution or not with_community:
|
||||
return
|
||||
if with_resolution or with_community:
|
||||
graphrag_task_set(tenant_id, kb_id, doc_id)
|
||||
if with_resolution:
|
||||
|
||||
if new_graph is None:
|
||||
new_graph = await get_graph(tenant_id, kb_id)
|
||||
|
||||
if with_resolution and new_graph is not None:
|
||||
await resolve_entities(
|
||||
graph,
|
||||
doc_ids,
|
||||
new_graph,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
@ -101,10 +96,9 @@ async def run_graphrag(
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if with_community:
|
||||
if with_community and new_graph is not None:
|
||||
await extract_community(
|
||||
graph,
|
||||
doc_ids,
|
||||
new_graph,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
@ -117,7 +111,7 @@ async def run_graphrag(
|
||||
return
|
||||
|
||||
|
||||
async def update_graph(
|
||||
async def generate_subgraph(
|
||||
extractor: Extractor,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
@ -131,34 +125,41 @@ async def update_graph(
|
||||
):
|
||||
contains = await does_graph_contains(tenant_id, kb_id, doc_id)
|
||||
if contains:
|
||||
callback(msg=f"Graph already contains {doc_id}, cancel myself")
|
||||
return None, None
|
||||
callback(msg=f"Graph already contains {doc_id}")
|
||||
return None
|
||||
start = trio.current_time()
|
||||
ext = extractor(
|
||||
llm_bdl,
|
||||
language=language,
|
||||
entity_types=entity_types,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
|
||||
)
|
||||
ents, rels = await ext(doc_id, chunks, callback)
|
||||
subgraph = nx.Graph()
|
||||
for en in ents:
|
||||
subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])
|
||||
for ent in ents:
|
||||
assert "description" in ent, f"entity {ent} does not have description"
|
||||
ent["source_id"] = [doc_id]
|
||||
subgraph.add_node(ent["entity_name"], **ent)
|
||||
|
||||
ignored_rels = 0
|
||||
for rel in rels:
|
||||
assert "description" in rel, f"relation {rel} does not have description"
|
||||
if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
|
||||
ignored_rels += 1
|
||||
continue
|
||||
rel["source_id"] = [doc_id]
|
||||
subgraph.add_edge(
|
||||
rel["src_id"],
|
||||
rel["tgt_id"],
|
||||
weight=rel["weight"],
|
||||
# description=rel["description"]
|
||||
**rel,
|
||||
)
|
||||
# TODO: infinity doesn't support array search
|
||||
if ignored_rels:
|
||||
callback(msg=f"ignored {ignored_rels} relations due to missing entities.")
|
||||
tidy_graph(subgraph, callback)
|
||||
|
||||
subgraph.graph["source_id"] = [doc_id]
|
||||
chunk = {
|
||||
"content_with_weight": json.dumps(
|
||||
nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
|
||||
nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False
|
||||
),
|
||||
"knowledge_graph_kwd": "subgraph",
|
||||
"kb_id": kb_id,
|
||||
@ -167,6 +168,11 @@ async def update_graph(
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
cid = chunk_id(chunk)
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id
|
||||
)
|
||||
)
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.insert(
|
||||
[{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
|
||||
@ -174,39 +180,49 @@ async def update_graph(
|
||||
)
|
||||
now = trio.current_time()
|
||||
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
|
||||
start = now
|
||||
return subgraph
|
||||
|
||||
async def merge_subgraph(
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
subgraph: nx.Graph,
|
||||
embedding_model,
|
||||
callback,
|
||||
):
|
||||
graphrag_task_lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600)
|
||||
while True:
|
||||
if graphrag_task_lock.acquire():
|
||||
break
|
||||
callback(msg=f"merge_subgraph {doc_id} is waiting graphrag_task_lock")
|
||||
await trio.sleep(10)
|
||||
|
||||
start = trio.current_time()
|
||||
change = GraphChange()
|
||||
old_graph = await get_graph(tenant_id, kb_id)
|
||||
if old_graph is not None:
|
||||
logging.info("Merge with an exiting graph...................")
|
||||
tidy_graph(old_graph, callback)
|
||||
new_graph = graph_merge(old_graph, subgraph, change)
|
||||
else:
|
||||
new_graph = subgraph
|
||||
now_docids = set([doc_id])
|
||||
old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
|
||||
if old_graph is not None:
|
||||
logging.info("Merge with an exiting graph...................")
|
||||
new_graph = graph_merge(old_graph, subgraph)
|
||||
await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
|
||||
if old_doc_ids:
|
||||
for old_doc_id in old_doc_ids:
|
||||
now_docids.add(old_doc_id)
|
||||
old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
|
||||
delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
|
||||
if delta_doc_ids:
|
||||
callback(
|
||||
msg="The global graph has changed during merging, try again"
|
||||
)
|
||||
await trio.sleep(1)
|
||||
continue
|
||||
break
|
||||
await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
|
||||
change.added_updated_nodes = set(new_graph.nodes())
|
||||
change.added_updated_edges = set(new_graph.edges())
|
||||
pr = nx.pagerank(new_graph)
|
||||
for node_name, pagerank in pr.items():
|
||||
new_graph.nodes[node_name]["pagerank"] = pagerank
|
||||
|
||||
await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
|
||||
graphrag_task_lock.release()
|
||||
now = trio.current_time()
|
||||
callback(
|
||||
msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
|
||||
)
|
||||
return new_graph, now_docids
|
||||
return new_graph
|
||||
|
||||
|
||||
async def resolve_entities(
|
||||
graph,
|
||||
doc_ids,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
@ -214,74 +230,30 @@ async def resolve_entities(
|
||||
embed_bdl,
|
||||
callback,
|
||||
):
|
||||
working_doc_id = graphrag_task_get(tenant_id, kb_id)
|
||||
if doc_id != working_doc_id:
|
||||
callback(
|
||||
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
|
||||
)
|
||||
return
|
||||
graphrag_task_lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600)
|
||||
while True:
|
||||
if graphrag_task_lock.acquire():
|
||||
break
|
||||
await trio.sleep(10)
|
||||
|
||||
start = trio.current_time()
|
||||
er = EntityResolution(
|
||||
llm_bdl,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
|
||||
)
|
||||
reso = await er(graph, callback=callback)
|
||||
graph = reso.graph
|
||||
callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
|
||||
await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
|
||||
change = reso.change
|
||||
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
|
||||
callback(msg="Graph resolution updated pagerank.")
|
||||
|
||||
working_doc_id = graphrag_task_get(tenant_id, kb_id)
|
||||
if doc_id != working_doc_id:
|
||||
callback(
|
||||
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
|
||||
)
|
||||
return
|
||||
await set_graph(tenant_id, kb_id, graph, doc_ids)
|
||||
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{
|
||||
"knowledge_graph_kwd": "relation",
|
||||
"kb_id": kb_id,
|
||||
"from_entity_kwd": reso.removed_entities,
|
||||
},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{
|
||||
"knowledge_graph_kwd": "relation",
|
||||
"kb_id": kb_id,
|
||||
"to_entity_kwd": reso.removed_entities,
|
||||
},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{
|
||||
"knowledge_graph_kwd": "entity",
|
||||
"kb_id": kb_id,
|
||||
"entity_kwd": reso.removed_entities,
|
||||
},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
|
||||
graphrag_task_lock.release()
|
||||
now = trio.current_time()
|
||||
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
|
||||
|
||||
|
||||
async def extract_community(
|
||||
graph,
|
||||
doc_ids,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
@ -289,49 +261,34 @@ async def extract_community(
|
||||
embed_bdl,
|
||||
callback,
|
||||
):
|
||||
working_doc_id = graphrag_task_get(tenant_id, kb_id)
|
||||
if doc_id != working_doc_id:
|
||||
callback(
|
||||
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
|
||||
)
|
||||
return
|
||||
graphrag_task_lock = RedisDistributedLock("graphrag_task", lock_value=doc_id, timeout=600)
|
||||
while True:
|
||||
if graphrag_task_lock.acquire():
|
||||
break
|
||||
await trio.sleep(10)
|
||||
|
||||
start = trio.current_time()
|
||||
ext = CommunityReportsExtractor(
|
||||
llm_bdl,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
|
||||
)
|
||||
cr = await ext(graph, callback=callback)
|
||||
community_structure = cr.structured_output
|
||||
community_reports = cr.output
|
||||
working_doc_id = graphrag_task_get(tenant_id, kb_id)
|
||||
if doc_id != working_doc_id:
|
||||
callback(
|
||||
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
|
||||
)
|
||||
return
|
||||
await set_graph(tenant_id, kb_id, graph, doc_ids)
|
||||
doc_ids = graph.graph["source_id"]
|
||||
|
||||
now = trio.current_time()
|
||||
callback(
|
||||
msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
|
||||
)
|
||||
start = now
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
chunks = []
|
||||
for stru, rep in zip(community_structure, community_reports):
|
||||
obj = {
|
||||
"report": rep,
|
||||
"evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
|
||||
}
|
||||
chunk = {
|
||||
"id": get_uuid(),
|
||||
"docnm_kwd": stru["title"],
|
||||
"title_tks": rag_tokenizer.tokenize(stru["title"]),
|
||||
"content_with_weight": json.dumps(obj, ensure_ascii=False),
|
||||
@ -349,17 +306,23 @@ async def extract_community(
|
||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
|
||||
chunk["content_ltks"]
|
||||
)
|
||||
# try:
|
||||
# ebd, _ = embed_bdl.encode([", ".join(community["entities"])])
|
||||
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
|
||||
# except Exception as e:
|
||||
# logging.exception(f"Fail to embed entity relation: {e}")
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.insert(
|
||||
[{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)
|
||||
)
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
raise Exception(error_message)
|
||||
|
||||
graphrag_task_lock.release()
|
||||
now = trio.current_time()
|
||||
callback(
|
||||
msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
|
||||
|
@ -100,7 +100,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
|
||||
logging.debug(
|
||||
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
|
||||
)
|
||||
if not graph.nodes():
|
||||
nodes = set(graph.nodes())
|
||||
if not nodes:
|
||||
return {}
|
||||
|
||||
node_id_to_community_map = _compute_leiden_communities(
|
||||
@ -120,7 +121,7 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
|
||||
result = {}
|
||||
results_by_level[level] = result
|
||||
for node_id, raw_community_id in node_id_to_community_map[level].items():
|
||||
if node_id not in graph.nodes:
|
||||
if node_id not in nodes:
|
||||
logging.warning(f"Node {node_id} not found in the graph.")
|
||||
continue
|
||||
community_id = str(raw_community_id)
|
||||
|
@ -5,7 +5,7 @@ Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
|
||||
from graphrag.light.graph_prompt import PROMPTS
|
||||
@ -33,14 +33,10 @@ class GraphExtractor(Extractor):
|
||||
llm_invoker: CompletionLLM,
|
||||
language: str | None = "English",
|
||||
entity_types: list[str] | None = None,
|
||||
get_entity: Callable | None = None,
|
||||
set_entity: Callable | None = None,
|
||||
get_relation: Callable | None = None,
|
||||
set_relation: Callable | None = None,
|
||||
example_number: int = 2,
|
||||
max_gleanings: int | None = None,
|
||||
):
|
||||
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
|
||||
super().__init__(llm_invoker, language, entity_types)
|
||||
"""Init method definition."""
|
||||
self._max_gleanings = (
|
||||
max_gleanings
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [LightRag](https://github.com/HKUDS/LightRAG)
|
||||
- [LightRAG](https://github.com/HKUDS/LightRAG/blob/main/lightrag/prompt.py)
|
||||
"""
|
||||
|
||||
|
||||
|
@ -12,26 +12,37 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable
|
||||
import os
|
||||
import trio
|
||||
from typing import Set, Tuple
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import xxhash
|
||||
from networkx.readwrite import json_graph
|
||||
import dataclasses
|
||||
|
||||
from api import settings
|
||||
from api.utils import get_uuid
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
|
||||
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
||||
|
||||
chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GraphChange:
|
||||
removed_nodes: Set[str] = dataclasses.field(default_factory=set)
|
||||
added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
|
||||
removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
|
||||
added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
|
||||
|
||||
def perform_variable_replacements(
|
||||
input: str, history: list[dict] | None = None, variables: dict | None = None
|
||||
) -> str:
|
||||
@ -146,24 +157,74 @@ def set_tags_to_cache(kb_ids, tags):
|
||||
k = hasher.hexdigest()
|
||||
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
|
||||
|
||||
def tidy_graph(graph: nx.Graph, callback):
|
||||
"""
|
||||
Ensure all nodes and edges in the graph have some essential attribute.
|
||||
"""
|
||||
def is_valid_node(node_attrs: dict) -> bool:
|
||||
valid_node = True
|
||||
for attr in ["description", "source_id"]:
|
||||
if attr not in node_attrs:
|
||||
valid_node = False
|
||||
break
|
||||
return valid_node
|
||||
purged_nodes = []
|
||||
for node, node_attrs in graph.nodes(data=True):
|
||||
if not is_valid_node(node_attrs):
|
||||
purged_nodes.append(node)
|
||||
for node in purged_nodes:
|
||||
graph.remove_node(node)
|
||||
if purged_nodes and callback:
|
||||
callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")
|
||||
|
||||
def graph_merge(g1, g2):
|
||||
g = g2.copy()
|
||||
for n, attr in g1.nodes(data=True):
|
||||
if n not in g2.nodes():
|
||||
g.add_node(n, **attr)
|
||||
purged_edges = []
|
||||
for source, target, attr in graph.edges(data=True):
|
||||
if not is_valid_node(attr):
|
||||
purged_edges.append((source, target))
|
||||
if "keywords" not in attr:
|
||||
attr["keywords"] = []
|
||||
for source, target in purged_edges:
|
||||
graph.remove_edge(source, target)
|
||||
if purged_edges and callback:
|
||||
callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")
|
||||
|
||||
def get_from_to(node1, node2):
|
||||
if node1 < node2:
|
||||
return (node1, node2)
|
||||
else:
|
||||
return (node2, node1)
|
||||
|
||||
def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
|
||||
"""Merge graph g2 into g1 in place."""
|
||||
for node_name, attr in g2.nodes(data=True):
|
||||
change.added_updated_nodes.add(node_name)
|
||||
if not g1.has_node(node_name):
|
||||
g1.add_node(node_name, **attr)
|
||||
continue
|
||||
node = g1.nodes[node_name]
|
||||
node["description"] += GRAPH_FIELD_SEP + attr["description"]
|
||||
# A node's source_id indicates which chunks it came from.
|
||||
node["source_id"] += attr["source_id"]
|
||||
|
||||
for source, target, attr in g1.edges(data=True):
|
||||
if g.has_edge(source, target):
|
||||
g[source][target].update({"weight": attr.get("weight", 0)+1})
|
||||
for source, target, attr in g2.edges(data=True):
|
||||
change.added_updated_edges.add(get_from_to(source, target))
|
||||
edge = g1.get_edge_data(source, target)
|
||||
if edge is None:
|
||||
g1.add_edge(source, target, **attr)
|
||||
continue
|
||||
g.add_edge(source, target)#, **attr)
|
||||
|
||||
for node_degree in g.degree:
|
||||
g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
return g
|
||||
edge["weight"] += attr.get("weight", 0)
|
||||
edge["description"] += GRAPH_FIELD_SEP + attr["description"]
|
||||
edge["keywords"] += attr["keywords"]
|
||||
# An edge's source_id indicates which chunks it came from.
|
||||
edge["source_id"] += attr["source_id"]
|
||||
|
||||
for node_degree in g1.degree:
|
||||
g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
# A graph's source_id indicates which documents it came from.
|
||||
if "source_id" not in g1.graph:
|
||||
g1.graph["source_id"] = []
|
||||
g1.graph["source_id"] += g2.graph.get("source_id", [])
|
||||
return g1
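# Assumed usage sketch (not from this PR): merge a per-document subgraph into the
# global graph in place and collect the change set that set_graph() persists later.
def _graph_merge_example():
    change = GraphChange()
    g_old = nx.Graph(source_id=["doc1"])
    g_old.add_node("A", description="alpha", source_id=["chunk1"])
    g_new = nx.Graph(source_id=["doc2"])
    g_new.add_node("A", description="alpha again", source_id=["chunk2"])
    g_new.add_node("B", description="beta", source_id=["chunk2"])
    g_new.add_edge("A", "B", weight=1, description="works with", keywords=["k"], source_id=["chunk2"])
    merged = graph_merge(g_old, g_new, change)
    # "A" was updated, "B" and the ("A", "B") edge were added; nothing was removed.
    assert merged.graph["source_id"] == ["doc1", "doc2"]
    assert change.added_updated_nodes == {"A", "B"}
    assert change.added_updated_edges == {("A", "B")}
    assert not change.removed_nodes and not change.removed_edges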
|
||||
|
||||
def compute_args_hash(*args):
|
||||
return md5(str(args).encode()).hexdigest()
|
||||
@ -237,55 +298,10 @@ def is_float_regex(value):
|
||||
def chunk_id(chunk):
|
||||
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
|
||||
|
||||
def get_entity_cache(tenant_id, kb_id, ent_name) -> str | list[str]:
|
||||
hasher = xxhash.xxh64()
|
||||
hasher.update(str(tenant_id).encode("utf-8"))
|
||||
hasher.update(str(kb_id).encode("utf-8"))
|
||||
hasher.update(str(ent_name).encode("utf-8"))
|
||||
|
||||
k = hasher.hexdigest()
|
||||
bin = REDIS_CONN.get(k)
|
||||
if not bin:
|
||||
return
|
||||
return json.loads(bin)
|
||||
|
||||
|
||||
def set_entity_cache(tenant_id, kb_id, ent_name, content_with_weight):
|
||||
hasher = xxhash.xxh64()
|
||||
hasher.update(str(tenant_id).encode("utf-8"))
|
||||
hasher.update(str(kb_id).encode("utf-8"))
|
||||
hasher.update(str(ent_name).encode("utf-8"))
|
||||
|
||||
k = hasher.hexdigest()
|
||||
REDIS_CONN.set(k, content_with_weight.encode("utf-8"), 3600)
|
||||
|
||||
|
||||
def get_entity(tenant_id, kb_id, ent_name):
|
||||
cache = get_entity_cache(tenant_id, kb_id, ent_name)
|
||||
if cache:
|
||||
return cache
|
||||
conds = {
|
||||
"fields": ["content_with_weight"],
|
||||
"entity_kwd": ent_name,
|
||||
"size": 10000,
|
||||
"knowledge_graph_kwd": ["entity"]
|
||||
}
|
||||
res = []
|
||||
es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
|
||||
for id in es_res.ids:
|
||||
try:
|
||||
if isinstance(ent_name, str):
|
||||
set_entity_cache(tenant_id, kb_id, ent_name, es_res.field[id]["content_with_weight"])
|
||||
return json.loads(es_res.field[id]["content_with_weight"])
|
||||
res.append(json.loads(es_res.field[id]["content_with_weight"]))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
|
||||
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
|
||||
chunk = {
|
||||
"id": get_uuid(),
|
||||
"important_kwd": [ent_name],
|
||||
"title_tks": rag_tokenizer.tokenize(ent_name),
|
||||
"entity_kwd": ent_name,
|
||||
@ -293,28 +309,19 @@ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
|
||||
"entity_type_kwd": meta["entity_type"],
|
||||
"content_with_weight": json.dumps(meta, ensure_ascii=False),
|
||||
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
|
||||
"source_id": list(set(meta["source_id"])),
|
||||
"source_id": meta["source_id"],
|
||||
"kb_id": kb_id,
|
||||
"available_int": 0
|
||||
}
|
||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
||||
set_entity_cache(tenant_id, kb_id, ent_name, chunk["content_with_weight"])
|
||||
res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
|
||||
search.index_name(tenant_id), [kb_id])
|
||||
if res.ids:
|
||||
settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id)
|
||||
else:
|
||||
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
|
||||
if ebd is None:
|
||||
try:
|
||||
ebd, _ = embd_mdl.encode([ent_name])
|
||||
ebd = ebd[0]
|
||||
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
|
||||
except Exception as e:
|
||||
logging.exception(f"Fail to embed entity: {e}")
|
||||
if ebd is not None:
|
||||
chunk["q_%d_vec" % len(ebd)] = ebd
|
||||
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
|
||||
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
|
||||
if ebd is None:
|
||||
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
|
||||
ebd = ebd[0]
|
||||
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
|
||||
assert ebd is not None
|
||||
chunk["q_%d_vec" % len(ebd)] = ebd
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
|
||||
@ -344,40 +351,30 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
|
||||
return res
|
||||
|
||||
|
||||
def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
|
||||
async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
|
||||
chunk = {
|
||||
"id": get_uuid(),
|
||||
"from_entity_kwd": from_ent_name,
|
||||
"to_entity_kwd": to_ent_name,
|
||||
"knowledge_graph_kwd": "relation",
|
||||
"content_with_weight": json.dumps(meta, ensure_ascii=False),
|
||||
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
|
||||
"important_kwd": meta["keywords"],
|
||||
"source_id": list(set(meta["source_id"])),
|
||||
"source_id": meta["source_id"],
|
||||
"weight_int": int(meta["weight"]),
|
||||
"kb_id": kb_id,
|
||||
"available_int": 0
|
||||
}
|
||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
||||
res = settings.retrievaler.search({"from_entity_kwd": to_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []},
|
||||
search.index_name(tenant_id), [kb_id])
|
||||
|
||||
if res.ids:
|
||||
settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name},
|
||||
chunk,
|
||||
search.index_name(tenant_id), kb_id)
|
||||
else:
|
||||
txt = f"{from_ent_name}->{to_ent_name}"
|
||||
ebd = get_embed_cache(embd_mdl.llm_name, txt)
|
||||
if ebd is None:
|
||||
try:
|
||||
ebd, _ = embd_mdl.encode([txt+f": {meta['description']}"])
|
||||
ebd = ebd[0]
|
||||
set_embed_cache(embd_mdl.llm_name, txt, ebd)
|
||||
except Exception as e:
|
||||
logging.exception(f"Fail to embed entity relation: {e}")
|
||||
if ebd is not None:
|
||||
chunk["q_%d_vec" % len(ebd)] = ebd
|
||||
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
|
||||
txt = f"{from_ent_name}->{to_ent_name}"
|
||||
ebd = get_embed_cache(embd_mdl.llm_name, txt)
|
||||
if ebd is None:
|
||||
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
|
||||
ebd = ebd[0]
|
||||
set_embed_cache(embd_mdl.llm_name, txt, ebd)
|
||||
assert ebd is not None
|
||||
chunk["q_%d_vec" % len(ebd)] = ebd
|
||||
chunks.append(chunk)
|
||||
|
||||
async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
# Get doc_ids of graph
|
||||
@ -418,33 +415,68 @@ async def get_graph(tenant_id, kb_id):
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
|
||||
if res.total == 0:
|
||||
return None, []
|
||||
return None
|
||||
for id in res.ids:
|
||||
try:
|
||||
return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
|
||||
res.field[id]["source_id"]
|
||||
g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
|
||||
if "source_id" not in g.graph:
|
||||
g.graph["source_id"] = res.field[id]["source_id"]
|
||||
return g
|
||||
except Exception:
|
||||
continue
|
||||
result = await rebuild_graph(tenant_id, kb_id)
|
||||
return result
|
||||
|
||||
|
||||
async def set_graph(tenant_id, kb_id, graph, docids):
|
||||
chunk = {
|
||||
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
|
||||
indent=2),
|
||||
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
|
||||
start = trio.current_time()
|
||||
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph"]}, search.index_name(tenant_id), kb_id))
|
||||
|
||||
if change.removed_nodes:
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id))
|
||||
|
||||
if change.removed_edges:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for from_node, to_node in change.removed_edges:
|
||||
nursery.start_soon(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id))
|
||||
now = trio.current_time()
|
||||
if callback:
|
||||
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
|
||||
start = now
|
||||
|
||||
chunks = [{
|
||||
"id": get_uuid(),
|
||||
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
|
||||
"knowledge_graph_kwd": "graph",
|
||||
"kb_id": kb_id,
|
||||
"source_id": list(docids),
|
||||
"source_id": graph.graph.get("source_id", []),
|
||||
"available_int": 0,
|
||||
"removed_kwd": "N"
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]))
|
||||
if res.ids:
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
|
||||
search.index_name(tenant_id), kb_id))
|
||||
else:
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
|
||||
}]
|
||||
async with trio.open_nursery() as nursery:
|
||||
for node in change.added_updated_nodes:
|
||||
node_attrs = graph.nodes[node]
|
||||
nursery.start_soon(lambda: graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks))
|
||||
for from_node, to_node in change.added_updated_edges:
|
||||
edge_attrs = graph.edges[from_node, to_node]
|
||||
nursery.start_soon(lambda: graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks))
|
||||
now = trio.current_time()
|
||||
if callback:
|
||||
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
|
||||
start = now
|
||||
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "entity", "relation"]}, search.index_name(tenant_id), kb_id))
|
||||
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
raise Exception(error_message)
|
||||
now = trio.current_time()
|
||||
if callback:
|
||||
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
|
||||
|
||||
|
||||
def is_continuous_subsequence(subseq, seq):
|
||||
@ -489,67 +521,6 @@ def merge_tuples(list1, list2):
|
||||
return result
|
||||
|
||||
|
||||
async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
|
||||
def n_neighbor(id):
|
||||
nonlocal graph, n_hop
|
||||
count = 0
|
||||
source_edge = list(graph.edges(id))
|
||||
if not source_edge:
|
||||
return []
|
||||
count = count + 1
|
||||
while count < n_hop:
|
||||
count = count + 1
|
||||
sc_edge = deepcopy(source_edge)
|
||||
source_edge = []
|
||||
for pair in sc_edge:
|
||||
append_edge = list(graph.edges(pair[-1]))
|
||||
for tuples in merge_tuples([pair], append_edge):
|
||||
source_edge.append(tuples)
|
||||
nbrs = []
|
||||
for path in source_edge:
|
||||
n = {"path": path, "weights": []}
|
||||
wts = nx.get_edge_attributes(graph, 'weight')
|
||||
for i in range(len(path)-1):
|
||||
f, t = path[i], path[i+1]
|
||||
n["weights"].append(wts.get((f, t), 0))
|
||||
nbrs.append(n)
|
||||
return nbrs
|
||||
|
||||
pr = nx.pagerank(graph)
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for n, p in pr.items():
|
||||
graph.nodes[n]["pagerank"] = p
|
||||
nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
|
||||
{"rank_flt": p,
|
||||
"n_hop_with_weight": json.dumps((n), ensure_ascii=False)},
|
||||
search.index_name(tenant_id), kb_id)))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
ty2ents = defaultdict(list)
|
||||
for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
|
||||
ty = graph.nodes[p].get("entity_type")
|
||||
if not ty or len(ty2ents[ty]) > 12:
|
||||
continue
|
||||
ty2ents[ty].append(p)
|
||||
|
||||
chunk = {
|
||||
"content_with_weight": json.dumps(ty2ents, ensure_ascii=False),
|
||||
"kb_id": kb_id,
|
||||
"knowledge_graph_kwd": "ty2ents",
|
||||
"available_int": 0
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
|
||||
search.index_name(tenant_id), [kb_id]))
|
||||
if res.ids:
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
|
||||
chunk,
|
||||
search.index_name(tenant_id), kb_id))
|
||||
else:
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
|
||||
|
||||
|
||||
async def get_entity_type2sampels(idxnms, kb_ids: list):
|
||||
es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
|
||||
"size": 10000,
|
||||
@ -584,33 +555,46 @@ def flat_uniq_list(arr, key):
|
||||
|
||||
async def rebuild_graph(tenant_id, kb_id):
|
||||
graph = nx.Graph()
|
||||
src_ids = []
|
||||
flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
|
||||
src_ids = set()
|
||||
flds = ["entity_kwd", "from_entity_kwd", "to_entity_kwd", "knowledge_graph_kwd", "content_with_weight", "source_id"]
|
||||
bs = 256
|
||||
for i in range(0, 39*bs, bs):
|
||||
for i in range(0, 1024*bs, bs):
|
||||
es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
|
||||
{"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
|
||||
{"kb_id": kb_id, "knowledge_graph_kwd": ["entity"]},
|
||||
[],
|
||||
OrderByExpr(),
|
||||
i, bs, search.index_name(tenant_id), [kb_id]
|
||||
))
|
||||
tot = settings.docStoreConn.getTotal(es_res)
|
||||
if tot == 0:
|
||||
return None, None
|
||||
break
|
||||
|
||||
es_res = settings.docStoreConn.getFields(es_res, flds)
|
||||
for id, d in es_res.items():
|
||||
src_ids.extend(d.get("source_id", []))
|
||||
if d["knowledge_graph_kwd"] == "entity":
|
||||
graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
|
||||
elif "from_entity_kwd" in d and "to_entity_kwd" in d:
|
||||
graph.add_edge(
|
||||
d["from_entity_kwd"],
|
||||
d["to_entity_kwd"],
|
||||
weight=int(d["weight_int"])
|
||||
)
|
||||
assert d["knowledge_graph_kwd"] == "relation"
|
||||
src_ids.update(d.get("source_id", []))
|
||||
attrs = json.loads(d["content_with_weight"])
|
||||
graph.add_node(d["entity_kwd"], **attrs)
|
||||
|
||||
if len(es_res.keys()) < 128:
|
||||
return graph, list(set(src_ids))
|
||||
for i in range(0, 1024*bs, bs):
|
||||
es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
|
||||
{"kb_id": kb_id, "knowledge_graph_kwd": ["relation"]},
|
||||
[],
|
||||
OrderByExpr(),
|
||||
i, bs, search.index_name(tenant_id), [kb_id]
|
||||
))
|
||||
tot = settings.docStoreConn.getTotal(es_res)
|
||||
if tot == 0:
|
||||
return None
|
||||
|
||||
return graph, list(set(src_ids))
|
||||
es_res = settings.docStoreConn.getFields(es_res, flds)
|
||||
for id, d in es_res.items():
|
||||
assert d["knowledge_graph_kwd"] == "relation"
|
||||
src_ids.update(d.get("source_id", []))
|
||||
if graph.has_node(d["from_entity_kwd"]) and graph.has_node(d["to_entity_kwd"]):
|
||||
attrs = json.load(d["content_with_weight"])
|
||||
graph.add_edge(d["from_entity_kwd"], d["to_entity_kwd"], **attrs)
|
||||
|
||||
src_ids = sorted(src_ids)
|
||||
graph.graph["source_id"] = src_ids
|
||||
return graph
|
||||
|
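The rebuilt graph now appears to be read back in two passes, entities first and then relations, each paged through the doc store in fixed-size batches until a page comes back empty. The paging pattern itself can be sketched in isolation; fetch_page below is a hypothetical stand-in for the docStoreConn.search/getFields pair:

def read_all(fetch_page, bs=256, max_pages=1024):
    # Page through a store in fixed-size batches until a page comes back empty.
    out = []
    for offset in range(0, max_pages * bs, bs):
        page = fetch_page(offset, bs)   # hypothetical: returns a list of dicts, [] when exhausted
        if not page:
            break
        out.extend(page)
    return out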
@ -125,6 +125,7 @@ dependencies = [
    "xxhash>=3.5.0,<4.0.0",
    "trio>=0.29.0",
    "langfuse>=2.60.0",
    "debugpy>=1.8.13",
]

[project.optional-dependencies]
@ -517,6 +517,8 @@ async def do_handle_task(task):
        chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
    # Either using graphrag or Standard chunking methods
    elif task.get("task_type", "") == "graphrag":
        global task_limiter
        task_limiter = trio.CapacityLimiter(2)
        graphrag_conf = task_parser_config.get("graphrag", {})
        if not graphrag_conf.get("use_graphrag", False):
            return
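trio.CapacityLimiter(2) above caps how many graphrag tasks may hold the limiter at once. A minimal, self-contained sketch of the same primitive; the worker names and sleep duration are made up for illustration:

import trio

limiter = trio.CapacityLimiter(2)       # at most two holders at a time

async def worker(name):
    async with limiter:                 # blocks here while two other workers hold the limiter
        await trio.sleep(0.1)
        print(f"{name} done")

async def main():
    async with trio.open_nursery() as nursery:
        for i in range(5):
            nursery.start_soon(worker, f"task-{i}")

trio.run(main)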
@ -172,6 +172,12 @@ class InfinityConnection(DocStoreConnection):
            ConflictType.Ignore,
        )

    def field_keyword(self, field_name: str):
        # The "docnm_kwd" field is always a string, not list.
        if field_name == "source_id" or (field_name.endswith("_kwd") and field_name != "docnm_kwd"):
            return True
        return False

    """
    Database operations
    """
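The new field_keyword helper decides which columns are treated as list-valued keyword fields. A few sanity checks against the rule above; the wrapper class is only for illustration:

class _Demo:
    def field_keyword(self, field_name: str):
        # Same rule as the hunk above: "docnm_kwd" is a plain string, every other *_kwd field
        # (plus "source_id") is treated as a list-valued keyword column.
        if field_name == "source_id" or (field_name.endswith("_kwd") and field_name != "docnm_kwd"):
            return True
        return False

d = _Demo()
assert d.field_keyword("important_kwd") is True
assert d.field_keyword("source_id") is True
assert d.field_keyword("docnm_kwd") is False
assert d.field_keyword("content_with_weight") is False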
@ -480,9 +486,11 @@ class InfinityConnection(DocStoreConnection):
            assert "_id" not in d
            assert "id" in d
            for k, v in d.items():
                if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
                    assert isinstance(v, list)
                    d[k] = "###".join(v)
                if self.field_keyword(k):
                    if isinstance(v, list):
                        d[k] = "###".join(v)
                    else:
                        d[k] = v
                elif re.search(r"_feas$", k):
                    d[k] = json.dumps(v)
                elif k == 'kb_id':
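With field_keyword in place, insert() serializes every list-valued keyword column to a single "###"-joined string, and getFields() later splits it back (see the hunk further down). A minimal round trip:

tags = ["finance", "Q3 report", "internal"]
stored = "###".join(tags)                            # what lands in the Infinity column
restored = [kwd for kwd in stored.split("###") if kwd]
assert restored == tags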
@ -495,6 +503,8 @@ class InfinityConnection(DocStoreConnection):
                elif k in ["page_num_int", "top_int"]:
                    assert isinstance(v, list)
                    d[k] = "_".join(f"{num:08x}" for num in v)
                else:
                    d[k] = v

            for n, vs in embedding_clmns:
                if n in d:
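page_num_int and top_int are lists of integers packed into fixed-width hex so they remain a single, string-sortable column. An encode/decode sketch; the decode direction is an assumption written only to mirror the encoding shown above:

nums = [3, 12, 255]
encoded = "_".join(f"{n:08x}" for n in nums)          # '00000003_0000000c_000000ff'
decoded = [int(part, 16) for part in encoded.split("_")]
assert decoded == nums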
@ -525,13 +535,13 @@ class InfinityConnection(DocStoreConnection):
        # del condition["exists"]
        filter = equivalent_condition_to_str(condition, table_instance)
        for k, v in list(newValue.items()):
            if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
                assert isinstance(v, list)
                newValue[k] = "###".join(v)
            if self.field_keyword(k):
                if isinstance(v, list):
                    newValue[k] = "###".join(v)
                else:
                    newValue[k] = v
            elif re.search(r"_feas$", k):
                newValue[k] = json.dumps(v)
            elif k.endswith("_kwd") and isinstance(v, list):
                newValue[k] = " ".join(v)
            elif k == 'kb_id':
                if isinstance(newValue[k], list):
                    newValue[k] = newValue[k][0]  # since d[k] is a list, but we need a str
@ -546,6 +556,8 @@ class InfinityConnection(DocStoreConnection):
                del newValue[k]
                if v in [PAGERANK_FLD]:
                    newValue[v] = 0
            else:
                newValue[k] = v

        logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
        table_instance.update(filter, newValue)
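When a caller asks update() to remove a field, the hunk above resets the pagerank field to 0 rather than deleting anything, seemingly because a column value cannot simply be dropped per row. A toy sketch of that fallback; PAGERANK_FLD's value and the helper function are assumptions for illustration:

PAGERANK_FLD = "pagerank_fea"                        # assumed value of the constant

def apply_remove(new_value: dict, field_to_remove: str) -> dict:
    # Rank-like numeric fields are zeroed rather than dropped (sketch of the behaviour above).
    if field_to_remove == PAGERANK_FLD:
        new_value[PAGERANK_FLD] = 0
    return new_value

assert apply_remove({}, "pagerank_fea") == {"pagerank_fea": 0}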
@ -600,7 +612,7 @@ class InfinityConnection(DocStoreConnection):
        for column in res2.columns:
            k = column.lower()
            if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
            if self.field_keyword(k):
                res2[column] = res2[column].apply(lambda v:[kwd for kwd in v.split("###") if kwd])
            elif k == "position_int":
                def to_position_int(v):
@ -319,9 +319,3 @@ class RedisDistributedLock:
    def release(self):
        return self.lock.release()

    def __enter__(self):
        self.acquire()

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.release()
57  uv.lock  (generated)
@ -1100,6 +1100,27 @@ version = "0.8.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" }

[[package]]
name = "debugpy"
version = "1.8.13"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/51/d4/f35f539e11c9344652f362c22413ec5078f677ac71229dc9b4f6f85ccaa3/debugpy-1.8.13.tar.gz", hash = "sha256:837e7bef95bdefba426ae38b9a94821ebdc5bea55627879cd48165c90b9e50ce" }
wheels = [
    { url = "https://mirrors.aliyun.com/pypi/packages/3f/32/901c7204cceb3262fdf38f4c25c9a46372c11661e8490e9ea702bc4ff448/debugpy-1.8.13-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:06859f68e817966723ffe046b896b1bd75c665996a77313370336ee9e1de3e90" },
    { url = "https://mirrors.aliyun.com/pypi/packages/95/10/77fe746851c8d84838a807da60c7bd0ac8627a6107d6917dd3293bf8628c/debugpy-1.8.13-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb56c2db69fb8df3168bc857d7b7d2494fed295dfdbde9a45f27b4b152f37520" },
    { url = "https://mirrors.aliyun.com/pypi/packages/a1/ef/28f8db2070e453dda0e49b356e339d0b4e1d38058d4c4ea9e88cdc8ee8e7/debugpy-1.8.13-cp310-cp310-win32.whl", hash = "sha256:46abe0b821cad751fc1fb9f860fb2e68d75e2c5d360986d0136cd1db8cad4428" },
    { url = "https://mirrors.aliyun.com/pypi/packages/89/16/1d53a80caf5862627d3eaffb217d4079d7e4a1df6729a2d5153733661efd/debugpy-1.8.13-cp310-cp310-win_amd64.whl", hash = "sha256:dc7b77f5d32674686a5f06955e4b18c0e41fb5a605f5b33cf225790f114cfeec" },
    { url = "https://mirrors.aliyun.com/pypi/packages/31/90/dd2fcad8364f0964f476537481985198ce6e879760281ad1cec289f1aa71/debugpy-1.8.13-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:eee02b2ed52a563126c97bf04194af48f2fe1f68bb522a312b05935798e922ff" },
    { url = "https://mirrors.aliyun.com/pypi/packages/5c/c9/06ff65f15eb30dbdafd45d1575770b842ce3869ad5580a77f4e5590f1be7/debugpy-1.8.13-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4caca674206e97c85c034c1efab4483f33971d4e02e73081265ecb612af65377" },
    { url = "https://mirrors.aliyun.com/pypi/packages/3b/49/798a4092bde16a4650f17ac5f2301d4d37e1972d65462fb25c80a83b4790/debugpy-1.8.13-cp311-cp311-win32.whl", hash = "sha256:7d9a05efc6973b5aaf076d779cf3a6bbb1199e059a17738a2aa9d27a53bcc888" },
    { url = "https://mirrors.aliyun.com/pypi/packages/cd/d5/3684d7561c8ba2797305cf8259619acccb8d6ebe2117bb33a6897c235eee/debugpy-1.8.13-cp311-cp311-win_amd64.whl", hash = "sha256:62f9b4a861c256f37e163ada8cf5a81f4c8d5148fc17ee31fb46813bd658cdcc" },
    { url = "https://mirrors.aliyun.com/pypi/packages/79/ad/dff929b6b5403feaab0af0e5bb460fd723f9c62538b718a9af819b8fff20/debugpy-1.8.13-cp312-cp312-macosx_14_0_universal2.whl", hash = "sha256:2b8de94c5c78aa0d0ed79023eb27c7c56a64c68217d881bee2ffbcb13951d0c1" },
    { url = "https://mirrors.aliyun.com/pypi/packages/d6/4f/b7d42e6679f0bb525888c278b0c0d2b6dff26ed42795230bb46eaae4f9b3/debugpy-1.8.13-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:887d54276cefbe7290a754424b077e41efa405a3e07122d8897de54709dbe522" },
    { url = "https://mirrors.aliyun.com/pypi/packages/ec/18/d9b3e88e85d41f68f77235112adc31012a784e45a3fcdbb039777d570a0f/debugpy-1.8.13-cp312-cp312-win32.whl", hash = "sha256:3872ce5453b17837ef47fb9f3edc25085ff998ce63543f45ba7af41e7f7d370f" },
    { url = "https://mirrors.aliyun.com/pypi/packages/c9/f7/0df18a4f530ed3cc06f0060f548efe9e3316102101e311739d906f5650be/debugpy-1.8.13-cp312-cp312-win_amd64.whl", hash = "sha256:63ca7670563c320503fea26ac688988d9d6b9c6a12abc8a8cf2e7dd8e5f6b6ea" },
    { url = "https://mirrors.aliyun.com/pypi/packages/37/4f/0b65410a08b6452bfd3f7ed6f3610f1a31fb127f46836e82d31797065dcb/debugpy-1.8.13-py2.py3-none-any.whl", hash = "sha256:d4ba115cdd0e3a70942bd562adba9ec8c651fe69ddde2298a1be296fc331906f" },
]

[[package]]
name = "decorator"
version = "5.2.1"
@ -1375,17 +1396,17 @@ name = "fastembed-gpu"
version = "0.3.6"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
    { name = "huggingface-hub" },
    { name = "loguru" },
    { name = "mmh3" },
    { name = "numpy" },
    { name = "onnxruntime-gpu" },
    { name = "pillow" },
    { name = "pystemmer" },
    { name = "requests" },
    { name = "snowballstemmer" },
    { name = "tokenizers" },
    { name = "tqdm" },
    { name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "loguru", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "mmh3", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "onnxruntime-gpu", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "pystemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "snowballstemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "tokenizers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "tqdm", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" }
wheels = [
@ -3531,12 +3552,12 @@ name = "onnxruntime-gpu"
version = "1.19.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
    { name = "coloredlogs" },
    { name = "flatbuffers" },
    { name = "numpy" },
    { name = "packaging" },
    { name = "protobuf" },
    { name = "sympy" },
    { name = "coloredlogs", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
    { name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
    { url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" },
@ -4746,6 +4767,7 @@ dependencies = [
    { name = "crawl4ai" },
    { name = "dashscope" },
    { name = "datrie" },
    { name = "debugpy" },
    { name = "deepl" },
    { name = "demjson3" },
    { name = "discord-py" },
@ -4877,6 +4899,7 @@ requires-dist = [
    { name = "crawl4ai", specifier = "==0.3.8" },
    { name = "dashscope", specifier = "==1.20.11" },
    { name = "datrie", specifier = "==0.8.2" },
    { name = "debugpy", specifier = ">=1.8.13" },
    { name = "deepl", specifier = "==1.18.0" },
    { name = "demjson3", specifier = "==3.0.6" },
    { name = "discord-py", specifier = "==2.3.2" },