
## Problem Description

Multiple files in the RAGFlow project contained closure-trap issues when lambda functions were passed to `nursery.start_soon()` inside loops under `trio.open_nursery()`. Concurrent tasks created in a loop all referenced the same loop variable, so every task ended up processing the same data (the data from the last iteration) instead of the data it was created for.

## Issue Details

When a `lambda` closure is created in a loop and handed to `nursery.start_soon()`, the lambda captures a reference to the loop variable rather than its value at that iteration. For example:

```python
# Problematic code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(lambda: doc_keyword_extraction(chat_mdl, d, topn))
```

By the time the concurrent tasks begin execution, `d` has already advanced to its final value (typically the last element), so all tasks operate on the same data.

## Fix Solution

Concurrent tasks are now created by passing the function and its arguments to `nursery.start_soon()` separately, which is how Trio's API is designed to be used:

```python
# Fixed code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(doc_keyword_extraction, chat_mdl, d, topn)
```

Each task then receives the argument values as they were at the time of the call, rather than references captured through a closure.

## Fixed Files

Closure traps were fixed in the following files:

1. `rag/svr/task_executor.py`: 3 fixes, involving document keyword extraction, question generation, and tag processing
2. `rag/raptor.py`: 1 fix, involving document summarization
3. `graphrag/utils.py`: 2 fixes, involving graph node and edge processing
4. `graphrag/entity_resolution.py`: 2 fixes, involving entity resolution and graph node merging
5. `graphrag/general/mind_map_extractor.py`: 2 fixes, involving document processing
6. `graphrag/general/extractor.py`: 3 fixes, involving content processing and graph node/edge merging
7. `graphrag/general/community_reports_extractor.py`: 1 fix, involving community report extraction

## Potential Impact

This fix resolves a serious concurrency issue that could have caused:

- Data processing errors (duplicate processing of the same data)
- Performance degradation (all tasks working on the same data)
- Inconsistent results (some data never being processed)

After the fix, every concurrent task processes its own data, improving the system's correctness and reliability.
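Where the task body is itself a wrapper (for example around `trio.to_thread.run_sync`), passing arguments positionally to `start_soon` is not enough, and the repository instead binds the loop variables as lambda default arguments (see `set_graph` in the file below, which appears to be `graphrag/utils.py`). Here is a minimal, self-contained sketch of that pattern; `blocking_delete` and `removed_edges` are illustrative names, not RAGFlow code:

```python
import trio


def blocking_delete(from_node, to_node):
    # Stand-in for a blocking call such as settings.docStoreConn.delete(...).
    print(f"deleted edge {from_node} -> {to_node}")


async def main():
    removed_edges = [("A", "B"), ("B", "C"), ("C", "D")]
    async with trio.open_nursery() as nursery:
        for from_node, to_node in removed_edges:
            # Default arguments are evaluated when the lambda is defined,
            # so each task gets the pair from its own iteration rather than
            # whatever the loop variables hold when the task finally runs.
            nursery.start_soon(
                lambda from_node=from_node, to_node=to_node: trio.to_thread.run_sync(
                    blocking_delete, from_node, to_node
                )
            )


trio.run(main)
```

`functools.partial(trio.to_thread.run_sync, blocking_delete, from_node, to_node)` would work equally well here, since `start_soon` only needs a callable whose result is awaitable.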
602 lines · 21 KiB · Python
```python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
 - [LightRag](https://github.com/HKUDS/LightRAG)
"""
import html
import json
import logging
import os
import re
import time
import dataclasses
from collections import defaultdict
from hashlib import md5
from typing import Any, Callable, Set, Tuple

import trio
import networkx as nx
import numpy as np
import xxhash
from networkx.readwrite import json_graph

from api import settings
from api.utils import get_uuid
from rag.nlp import search, rag_tokenizer
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN

GRAPH_FIELD_SEP = "<SEP>"

ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))


@dataclasses.dataclass
class GraphChange:
    removed_nodes: Set[str] = dataclasses.field(default_factory=set)
    added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
    removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
    added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)


def perform_variable_replacements(
    input: str, history: list[dict] | None = None, variables: dict | None = None
) -> str:
    """Perform variable replacements on the input string and in a chat log."""
    if history is None:
        history = []
    if variables is None:
        variables = {}
    result = input

    def replace_all(input: str) -> str:
        result = input
        for k, v in variables.items():
            result = result.replace(f"{{{k}}}", str(v))
        return result

    result = replace_all(result)
    for entry in history:
        if entry.get("role") == "system":
            entry["content"] = replace_all(entry.get("content") or "")

    return result


def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)


def dict_has_keys_with_types(
    data: dict, expected_fields: list[tuple[str, type]]
) -> bool:
    """Return True if the given dictionary has the given keys with the given types."""
    for field, field_type in expected_fields:
        if field not in data:
            return False

        value = data[field]
        if not isinstance(value, field_type):
            return False
    return True


def get_llm_cache(llmnm, txt, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return bin


def set_llm_cache(llmnm, txt, v, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    hasher.update(str(history).encode("utf-8"))
    hasher.update(str(genconf).encode("utf-8"))

    k = hasher.hexdigest()
    REDIS_CONN.set(k, v.encode("utf-8"), 24*3600)


def get_embed_cache(llmnm, txt):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return np.array(json.loads(bin))


def set_embed_cache(llmnm, txt, arr):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
    REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)


def get_tags_from_cache(kb_ids):
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return bin


def set_tags_to_cache(kb_ids, tags):
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)


def tidy_graph(graph: nx.Graph, callback):
    """
    Ensure all nodes and edges in the graph have the essential attributes.
    """
    def is_valid_node(node_attrs: dict) -> bool:
        valid_node = True
        for attr in ["description", "source_id"]:
            if attr not in node_attrs:
                valid_node = False
                break
        return valid_node

    purged_nodes = []
    for node, node_attrs in graph.nodes(data=True):
        if not is_valid_node(node_attrs):
            purged_nodes.append(node)
    for node in purged_nodes:
        graph.remove_node(node)
    if purged_nodes and callback:
        callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")

    purged_edges = []
    for source, target, attr in graph.edges(data=True):
        if not is_valid_node(attr):
            purged_edges.append((source, target))
        if "keywords" not in attr:
            attr["keywords"] = []
    for source, target in purged_edges:
        graph.remove_edge(source, target)
    if purged_edges and callback:
        callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")


def get_from_to(node1, node2):
    if node1 < node2:
        return (node1, node2)
    else:
        return (node2, node1)


def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph g2 into g1 in place."""
    for node_name, attr in g2.nodes(data=True):
        change.added_updated_nodes.add(node_name)
        if not g1.has_node(node_name):
            g1.add_node(node_name, **attr)
            continue
        node = g1.nodes[node_name]
        node["description"] += GRAPH_FIELD_SEP + attr["description"]
        # A node's source_id indicates which chunks it came from.
        node["source_id"] += attr["source_id"]

    for source, target, attr in g2.edges(data=True):
        change.added_updated_edges.add(get_from_to(source, target))
        edge = g1.get_edge_data(source, target)
        if edge is None:
            g1.add_edge(source, target, **attr)
            continue
        edge["weight"] += attr.get("weight", 0)
        edge["description"] += GRAPH_FIELD_SEP + attr["description"]
        edge["keywords"] += attr["keywords"]
        # An edge's source_id indicates which chunks it came from.
        edge["source_id"] += attr["source_id"]

    for node_degree in g1.degree:
        g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    # A graph's source_id indicates which documents it came from.
    if "source_id" not in g1.graph:
        g1.graph["source_id"] = []
    g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1


def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()
def handle_single_entity_extraction(
|
|
record_attributes: list[str],
|
|
chunk_key: str,
|
|
):
|
|
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
|
|
return None
|
|
# add this record as a node in the G
|
|
entity_name = clean_str(record_attributes[1].upper())
|
|
if not entity_name.strip():
|
|
return None
|
|
entity_type = clean_str(record_attributes[2].upper())
|
|
entity_description = clean_str(record_attributes[3])
|
|
entity_source_id = chunk_key
|
|
return dict(
|
|
entity_name=entity_name.upper(),
|
|
entity_type=entity_type.upper(),
|
|
description=entity_description,
|
|
source_id=entity_source_id,
|
|
)
|
|
|
|
|
|
def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
|
|
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
|
|
return None
|
|
# add this record as edge
|
|
source = clean_str(record_attributes[1].upper())
|
|
target = clean_str(record_attributes[2].upper())
|
|
edge_description = clean_str(record_attributes[3])
|
|
|
|
edge_keywords = clean_str(record_attributes[4])
|
|
edge_source_id = chunk_key
|
|
weight = (
|
|
float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
|
|
)
|
|
pair = sorted([source.upper(), target.upper()])
|
|
return dict(
|
|
src_id=pair[0],
|
|
tgt_id=pair[1],
|
|
weight=weight,
|
|
description=edge_description,
|
|
keywords=edge_keywords,
|
|
source_id=edge_source_id,
|
|
metadata={"created_at": time.time()},
|
|
)
|
|
|
|
|
|


def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers"""
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]


def is_float_regex(value):
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def chunk_id(chunk):
    return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()


async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
    chunk = {
        "id": get_uuid(),
        "important_kwd": [ent_name],
        "title_tks": rag_tokenizer.tokenize(ent_name),
        "entity_kwd": ent_name,
        "knowledge_graph_kwd": "entity",
        "entity_type_kwd": meta["entity_type"],
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "source_id": meta["source_id"],
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
    if ebd is None:
        ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
    ents = from_ent_name
    if isinstance(ents, str):
        ents = [from_ent_name]
    if isinstance(to_ent_name, str):
        to_ent_name = [to_ent_name]
    ents.extend(to_ent_name)
    ents = list(set(ents))
    conds = {
        "fields": ["content_with_weight"],
        "size": size,
        "from_entity_kwd": ents,
        "to_entity_kwd": ents,
        "knowledge_graph_kwd": ["relation"]
    }
    res = []
    es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
    for id in es_res.ids:
        try:
            if size == 1:
                return json.loads(es_res.field[id]["content_with_weight"])
            res.append(json.loads(es_res.field[id]["content_with_weight"]))
        except Exception:
            continue
    return res


async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
    chunk = {
        "id": get_uuid(),
        "from_entity_kwd": from_ent_name,
        "to_entity_kwd": to_ent_name,
        "knowledge_graph_kwd": "relation",
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "important_kwd": meta["keywords"],
        "source_id": meta["source_id"],
        "weight_int": int(meta["weight"]),
        "kb_id": kb_id,
        "available_int": 0
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    txt = f"{from_ent_name}->{to_ent_name}"
    ebd = get_embed_cache(embd_mdl.llm_name, txt)
    if ebd is None:
        ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"]))
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, txt, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


async def does_graph_contains(tenant_id, kb_id, doc_id):
    # Get doc_ids of graph
    fields = ["source_id"]
    condition = {
        "knowledge_graph_kwd": ["graph"],
        "removed_kwd": "N",
    }
    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
    fields2 = settings.docStoreConn.getFields(res, fields)
    graph_doc_ids = set()
    for chunk_id in fields2.keys():
        graph_doc_ids = set(fields2[chunk_id]["source_id"])
    return doc_id in graph_doc_ids


async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
    conds = {
        "fields": ["source_id"],
        "removed_kwd": "N",
        "size": 1,
        "knowledge_graph_kwd": ["graph"]
    }
    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
    doc_ids = []
    if res.total == 0:
        return doc_ids
    for id in res.ids:
        doc_ids = res.field[id]["source_id"]
    return doc_ids


async def get_graph(tenant_id, kb_id):
    conds = {
        "fields": ["content_with_weight", "source_id"],
        "removed_kwd": "N",
        "size": 1,
        "knowledge_graph_kwd": ["graph"]
    }
    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
    if res.total == 0:
        return None
    for id in res.ids:
        try:
            g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
            if "source_id" not in g.graph:
                g.graph["source_id"] = res.field[id]["source_id"]
            return g
        except Exception:
            continue
    result = await rebuild_graph(tenant_id, kb_id)
    return result


async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
    start = trio.current_time()

    await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph"]}, search.index_name(tenant_id), kb_id))

    if change.removed_nodes:
        await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id))

    if change.removed_edges:
        async with trio.open_nursery() as nursery:
            for from_node, to_node in change.removed_edges:
                nursery.start_soon(lambda from_node=from_node, to_node=to_node: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
    now = trio.current_time()
    if callback:
        callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
    start = now

    chunks = [{
        "id": get_uuid(),
        "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
        "knowledge_graph_kwd": "graph",
        "kb_id": kb_id,
        "source_id": graph.graph.get("source_id", []),
        "available_int": 0,
        "removed_kwd": "N"
    }]
    async with trio.open_nursery() as nursery:
        for node in change.added_updated_nodes:
            node_attrs = graph.nodes[node]
            nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
        for from_node, to_node in change.added_updated_edges:
            edge_attrs = graph.get_edge_data(from_node, to_node)
            if not edge_attrs:
                # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
                continue
            nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
    now = trio.current_time()
    if callback:
        callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
    start = now

    es_bulk_size = 4
    for b in range(0, len(chunks), es_bulk_size):
        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
        if doc_store_result:
            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
            raise Exception(error_message)
    now = trio.current_time()
    if callback:
        callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")


def is_continuous_subsequence(subseq, seq):
    def find_all_indexes(tup, value):
        indexes = []
        start = 0
        while True:
            try:
                index = tup.index(value, start)
                indexes.append(index)
                start = index + 1
            except ValueError:
                break
        return indexes

    index_list = find_all_indexes(seq, subseq[0])
    for idx in index_list:
        if idx != len(seq) - 1:
            if seq[idx + 1] == subseq[-1]:
                return True
    return False


def merge_tuples(list1, list2):
    result = []
    for tup in list1:
        last_element = tup[-1]
        if last_element in tup[:-1]:
            result.append(tup)
        else:
            matching_tuples = [t for t in list2 if t[0] == last_element]
            already_match_flag = 0
            for match in matching_tuples:
                matchh = (match[1], match[0])
                if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
                    continue
                already_match_flag = 1
                merged_tuple = tup + match[1:]
                result.append(merged_tuple)
            if not already_match_flag:
                result.append(tup)
    return result


async def get_entity_type2sampels(idxnms, kb_ids: list):
    es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(
        {"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
         "size": 10000,
         "fields": ["content_with_weight"]},
        idxnms, kb_ids))

    res = defaultdict(list)
    for id in es_res.ids:
        smp = es_res.field[id].get("content_with_weight")
        if not smp:
            continue
        try:
            smp = json.loads(smp)
        except Exception as e:
            logging.exception(e)
            # Skip entries that fail to parse instead of iterating over the raw string.
            continue

        for ty, ents in smp.items():
            res[ty].extend(ents)
    return res


def flat_uniq_list(arr, key):
    res = []
    for a in arr:
        a = a[key]
        if isinstance(a, list):
            res.extend(a)
        else:
            res.append(a)
    return list(set(res))


async def rebuild_graph(tenant_id, kb_id):
    graph = nx.Graph()
    src_ids = set()
    flds = ["entity_kwd", "from_entity_kwd", "to_entity_kwd", "knowledge_graph_kwd", "content_with_weight", "source_id"]
    bs = 256
    for i in range(0, 1024*bs, bs):
        es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(
            flds, [],
            {"kb_id": kb_id, "knowledge_graph_kwd": ["entity"]},
            [],
            OrderByExpr(),
            i, bs, search.index_name(tenant_id), [kb_id]
        ))
        tot = settings.docStoreConn.getTotal(es_res)
        if tot == 0:
            break

        es_res = settings.docStoreConn.getFields(es_res, flds)
        for id, d in es_res.items():
            assert d["knowledge_graph_kwd"] == "entity"
            src_ids.update(d.get("source_id", []))
            attrs = json.loads(d["content_with_weight"])
            graph.add_node(d["entity_kwd"], **attrs)

    for i in range(0, 1024*bs, bs):
        es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(
            flds, [],
            {"kb_id": kb_id, "knowledge_graph_kwd": ["relation"]},
            [],
            OrderByExpr(),
            i, bs, search.index_name(tenant_id), [kb_id]
        ))
        tot = settings.docStoreConn.getTotal(es_res)
        if tot == 0:
            return None

        es_res = settings.docStoreConn.getFields(es_res, flds)
        for id, d in es_res.items():
            assert d["knowledge_graph_kwd"] == "relation"
            src_ids.update(d.get("source_id", []))
            if graph.has_node(d["from_entity_kwd"]) and graph.has_node(d["to_entity_kwd"]):
                attrs = json.loads(d["content_with_weight"])
                graph.add_edge(d["from_entity_kwd"], d["to_entity_kwd"], **attrs)

    src_ids = sorted(src_ids)
    graph.graph["source_id"] = src_ids
    return graph
```
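For readers unfamiliar with how the pieces above fit together, here is a minimal usage sketch of `GraphChange` and `graph_merge`. It assumes the file above is importable as `graphrag.utils` (RAGFlow's package layout); the node and edge attributes are illustrative, not taken from the project:

```python
import networkx as nx

# Assumed import path for the module shown above.
from graphrag.utils import GraphChange, graph_merge

# An existing graph with one entity node.
g1 = nx.Graph()
g1.add_node("APPLE", description="A fruit company", source_id=["chunk-1"], entity_type="ORGANIZATION")

# A newly extracted graph that overlaps with g1 and adds an edge.
g2 = nx.Graph()
g2.add_node("APPLE", description="Maker of the iPhone", source_id=["chunk-2"], entity_type="ORGANIZATION")
g2.add_node("IPHONE", description="A smartphone", source_id=["chunk-2"], entity_type="PRODUCT")
g2.add_edge("APPLE", "IPHONE", weight=1.0, description="produces", keywords=["produces"], source_id=["chunk-2"])

change = GraphChange()
graph_merge(g1, g2, change)  # merges g2 into g1 in place and records what changed

print(sorted(change.added_updated_nodes))  # ['APPLE', 'IPHONE']
print(sorted(change.added_updated_edges))  # [('APPLE', 'IPHONE')]
```

The recorded `GraphChange` is what `set_graph` later uses to decide which node and edge chunks to re-embed and re-index.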