Optimized graphrag again (#5927)

### What problem does this PR solve?

Optimized graphrag again

### Type of change

- [x] Performance Improvement

Commit 939e668096 (parent 45318e7575)
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 import itertools
 import re
 import time
@@ -67,7 +68,7 @@ class EntityResolution(Extractor):
         self._resolution_result_delimiter_key = "resolution_result_delimiter"
         self._input_text_key = "input_text"
 
-    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
+    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
         """Call method definition."""
         if prompt_variables is None:
             prompt_variables = {}
@@ -93,6 +94,8 @@ class EntityResolution(Extractor):
         candidate_resolution = {entity_type: [] for entity_type in entity_types}
         for k, v in node_clusters.items():
             candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
+        num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
+        callback(msg=f"Identified {num_candidates} candidate pairs")
 
         resolution_result = set()
         async with trio.open_nursery() as nursery:
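For orientation, here is a minimal sketch of the candidate generation that the new counter summarizes: within each entity-type cluster, every unordered pair that passes a similarity gate becomes a resolution candidate. `is_similar` and the sample data are illustrative stand-ins, not code from this PR (the real gate is `EntityResolution.is_similarity`):

```python
import itertools

def is_similar(a: str, b: str) -> bool:
    # Placeholder heuristic; the PR relies on EntityResolution.is_similarity.
    return a.lower() == b.lower()

node_clusters = {"PERSON": ["Ada Lovelace", "ada lovelace", "Alan Turing"]}
candidate_resolution = {
    etype: [(a, b) for a, b in itertools.combinations(nodes, 2) if is_similar(a, b)]
    for etype, nodes in node_clusters.items()
}
num_candidates = sum(len(pairs) for pairs in candidate_resolution.values())
print(num_candidates, candidate_resolution)  # 1 {'PERSON': [('Ada Lovelace', 'ada lovelace')]}
```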
@@ -100,48 +103,52 @@ class EntityResolution(Extractor):
                 if not candidate_resolution_i[1]:
                     continue
                 nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
+        callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
 
         connect_graph = nx.Graph()
         removed_entities = []
         connect_graph.add_edges_from(resolution_result)
         all_entities_data = []
         all_relationships_data = []
+        all_remove_nodes = []
 
-        for sub_connect_graph in nx.connected_components(connect_graph):
-            sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
-            remove_nodes = list(sub_connect_graph.nodes)
-            keep_node = remove_nodes.pop()
-            await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)
-            for remove_node in remove_nodes:
-                removed_entities.append(remove_node)
-                remove_node_neighbors = graph[remove_node]
-                remove_node_neighbors = list(remove_node_neighbors)
-                for remove_node_neighbor in remove_node_neighbors:
-                    rel = self._get_relation_(remove_node, remove_node_neighbor)
-                    if graph.has_edge(remove_node, remove_node_neighbor):
-                        graph.remove_edge(remove_node, remove_node_neighbor)
-                    if remove_node_neighbor == keep_node:
-                        if graph.has_edge(keep_node, remove_node):
-                            graph.remove_edge(keep_node, remove_node)
-                        continue
-                    if not rel:
-                        continue
-                    if graph.has_edge(keep_node, remove_node_neighbor):
-                        await self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)
-                    else:
-                        pair = sorted([keep_node, remove_node_neighbor])
-                        graph.add_edge(pair[0], pair[1], weight=rel['weight'])
-                        self._set_relation_(pair[0], pair[1],
-                                            dict(
-                                                src_id=pair[0],
-                                                tgt_id=pair[1],
-                                                weight=rel['weight'],
-                                                description=rel['description'],
-                                                keywords=[],
-                                                source_id=rel.get("source_id", ""),
-                                                metadata={"created_at": time.time()}
-                                            ))
-                graph.remove_node(remove_node)
+        async with trio.open_nursery() as nursery:
+            for sub_connect_graph in nx.connected_components(connect_graph):
+                sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
+                remove_nodes = list(sub_connect_graph.nodes)
+                keep_node = remove_nodes.pop()
+                all_remove_nodes.append(remove_nodes)
+                nursery.start_soon(lambda: self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data))
+                for remove_node in remove_nodes:
+                    removed_entities.append(remove_node)
+                    remove_node_neighbors = graph[remove_node]
+                    remove_node_neighbors = list(remove_node_neighbors)
+                    for remove_node_neighbor in remove_node_neighbors:
+                        rel = self._get_relation_(remove_node, remove_node_neighbor)
+                        if graph.has_edge(remove_node, remove_node_neighbor):
+                            graph.remove_edge(remove_node, remove_node_neighbor)
+                        if remove_node_neighbor == keep_node:
+                            if graph.has_edge(keep_node, remove_node):
+                                graph.remove_edge(keep_node, remove_node)
+                            continue
+                        if not rel:
+                            continue
+                        if graph.has_edge(keep_node, remove_node_neighbor):
+                            nursery.start_soon(lambda: self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data))
+                        else:
+                            pair = sorted([keep_node, remove_node_neighbor])
+                            graph.add_edge(pair[0], pair[1], weight=rel['weight'])
+                            self._set_relation_(pair[0], pair[1],
+                                                dict(
+                                                    src_id=pair[0],
+                                                    tgt_id=pair[1],
+                                                    weight=rel['weight'],
+                                                    description=rel['description'],
+                                                    keywords=[],
+                                                    source_id=rel.get("source_id", ""),
+                                                    metadata={"created_at": time.time()}
+                                                ))
+                    graph.remove_node(remove_node)
 
         return EntityResolutionResult(
             graph=graph,
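The heart of this hunk is the switch from awaiting each merge serially to fanning tasks out on a trio nursery. A self-contained sketch of that pattern, with illustrative names (`merge_pair`, `results`) rather than the PR's own:

```python
import trio

async def merge_pair(pair, results):
    await trio.sleep(0)  # stand-in for the real async merge work
    results.add(pair)

async def main():
    pairs = [("A", "A'"), ("B", "B'")]
    results = set()
    async with trio.open_nursery() as nursery:
        for pair in pairs:
            # Passing args through start_soon binds `pair` eagerly; a bare
            # `lambda: merge_pair(pair, results)` would capture the loop
            # variable by reference instead.
            nursery.start_soon(merge_pair, pair, results)
    # The nursery block only exits once every spawned task has finished.
    print(sorted(results))

trio.run(main)
```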
@@ -164,8 +171,10 @@ class EntityResolution(Extractor):
             self._input_text_key: pair_prompt
         }
         text = perform_variable_replacements(self._resolution_prompt, variables=variables)
+        logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
         async with chat_limiter:
             response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+        logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
         result = self._process_results(len(candidate_resolution_i[1]), response,
                                        self.prompt_variables.get(self._record_delimiter_key,
                                                                  DEFAULT_RECORD_DELIMITER),
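The logging added here wraps the existing rate-limited, thread-offloaded chat call. A runnable sketch of that pattern, assuming a `trio.CapacityLimiter` in the role of `graphrag.utils.chat_limiter` (the limit of 10 is illustrative) and a hypothetical `blocking_chat` standing in for `self._chat`:

```python
import trio

chat_limiter = trio.CapacityLimiter(10)  # at most 10 chats in flight

def blocking_chat(prompt: str) -> str:
    # Placeholder for a synchronous LLM client call.
    return f"echo: {prompt}"

async def chat(prompt: str) -> str:
    async with chat_limiter:
        # run_sync moves the blocking call to a worker thread so the
        # trio event loop stays responsive.
        return await trio.to_thread.run_sync(blocking_chat, prompt)

async def main():
    print(await chat("hello"))

trio.run(main)
```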
@@ -19,7 +19,6 @@ from graphrag.general.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
 from rag.utils import num_tokens_from_string
-from timeit import default_timer as timer
 import trio
 
 
@@ -62,62 +61,69 @@ class CommunityReportsExtractor(Extractor):
         res_str = []
         res_dict = []
         over, token_count = 0, 0
-        st = timer()
-        for level, comm in communities.items():
-            logging.info(f"Level {level}: Community: {len(comm.keys())}")
-            for cm_id, ents in comm.items():
-                weight = ents["weight"]
-                ents = ents["nodes"]
-                ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()#[{"entity": n, **graph.nodes[n]} for n in ents])
-                if ent_df.empty or "entity_name" not in ent_df.columns:
-                    continue
-                ent_df["entity"] = ent_df["entity_name"]
-                del ent_df["entity_name"]
-                rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
-                if rela_df.empty:
-                    continue
-                rela_df["source"] = rela_df["src_id"]
-                rela_df["target"] = rela_df["tgt_id"]
-                del rela_df["src_id"]
-                del rela_df["tgt_id"]
+        async def extract_community_report(community):
+            nonlocal res_str, res_dict, over, token_count
+            cm_id, ents = community
+            weight = ents["weight"]
+            ents = ents["nodes"]
+            ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()
+            if ent_df.empty or "entity_name" not in ent_df.columns:
+                return
+            ent_df["entity"] = ent_df["entity_name"]
+            del ent_df["entity_name"]
+            rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
+            if rela_df.empty:
+                return
+            rela_df["source"] = rela_df["src_id"]
+            rela_df["target"] = rela_df["tgt_id"]
+            del rela_df["src_id"]
+            del rela_df["tgt_id"]
 
-                prompt_variables = {
-                    "entity_df": ent_df.to_csv(index_label="id"),
-                    "relation_df": rela_df.to_csv(index_label="id")
-                }
-                text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
-                gen_conf = {"temperature": 0.3}
-                async with chat_limiter:
-                    response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
-                token_count += num_tokens_from_string(text + response)
-                response = re.sub(r"^[^\{]*", "", response)
-                response = re.sub(r"[^\}]*$", "", response)
-                response = re.sub(r"\{\{", "{", response)
-                response = re.sub(r"\}\}", "}", response)
-                logging.debug(response)
-                try:
-                    response = json.loads(response)
-                except json.JSONDecodeError as e:
-                    logging.error(f"Failed to parse JSON response: {e}")
-                    logging.error(f"Response content: {response}")
-                    continue
-                if not dict_has_keys_with_types(response, [
-                    ("title", str),
-                    ("summary", str),
-                    ("findings", list),
-                    ("rating", float),
-                    ("rating_explanation", str),
-                ]):
-                    continue
-                response["weight"] = weight
-                response["entities"] = ents
-                add_community_info2graph(graph, ents, response["title"])
-                res_str.append(self._get_text_output(response))
-                res_dict.append(response)
-                over += 1
-                if callback:
-                    callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
+            prompt_variables = {
+                "entity_df": ent_df.to_csv(index_label="id"),
+                "relation_df": rela_df.to_csv(index_label="id")
+            }
+            text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
+            gen_conf = {"temperature": 0.3}
+            async with chat_limiter:
+                response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+            token_count += num_tokens_from_string(text + response)
+            response = re.sub(r"^[^\{]*", "", response)
+            response = re.sub(r"[^\}]*$", "", response)
+            response = re.sub(r"\{\{", "{", response)
+            response = re.sub(r"\}\}", "}", response)
+            logging.debug(response)
+            try:
+                response = json.loads(response)
+            except json.JSONDecodeError as e:
+                logging.error(f"Failed to parse JSON response: {e}")
+                logging.error(f"Response content: {response}")
+                return
+            if not dict_has_keys_with_types(response, [
+                ("title", str),
+                ("summary", str),
+                ("findings", list),
+                ("rating", float),
+                ("rating_explanation", str),
+            ]):
+                return
+            response["weight"] = weight
+            response["entities"] = ents
+            add_community_info2graph(graph, ents, response["title"])
+            res_str.append(self._get_text_output(response))
+            res_dict.append(response)
+            over += 1
+            if callback:
+                callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
+
+        st = trio.current_time()
+        async with trio.open_nursery() as nursery:
+            for level, comm in communities.items():
+                logging.info(f"Level {level}: Community: {len(comm.keys())}")
+                for community in comm.items():
+                    nursery.start_soon(lambda: extract_community_report(community))
+        if callback:
+            callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
 
         return CommunityReportsResult(
             structured_output=res_dict,
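The shape of this refactor — hoisting the per-community loop body into a nested async worker that updates shared tallies via `nonlocal`, then scheduling one task per community — in a minimal, runnable form. All names and the fake "work" are illustrative, not from the PR:

```python
import trio

async def process_all(communities, callback=None):
    over, token_count = 0, 0
    results = []

    async def worker(community):
        nonlocal over, token_count
        await trio.sleep(0)            # stand-in for the rate-limited LLM call
        token_count += len(community)  # stand-in for num_tokens_from_string(...)
        results.append(community.upper())
        over += 1
        if callback:
            callback(msg=f"Communities: {over}/{len(communities)}")

    st = trio.current_time()
    async with trio.open_nursery() as nursery:
        for community in communities:
            nursery.start_soon(worker, community)
    if callback:
        callback(msg=f"done in {trio.current_time() - st:.2f}s, tokens: {token_count}")
    return results

print(trio.run(process_all, ["alpha", "beta"], lambda msg: print(msg)))
```

Because trio tasks run cooperatively on one thread, the `nonlocal` counter updates between awaits need no extra locking; that is what makes this refactor cheap.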
@@ -228,7 +228,7 @@ async def resolve_entities(
         get_relation=partial(get_relation, tenant_id, kb_id),
         set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
     )
-    reso = await er(graph)
+    reso = await er(graph, callback=callback)
     graph = reso.graph
     callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
     await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
@@ -489,15 +489,16 @@ async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
         return nbrs
 
     pr = nx.pagerank(graph)
-    for n, p in pr.items():
-        graph.nodes[n]["pagerank"] = p
-        try:
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
-                                                                               {"rank_flt": p,
-                                                                                "n_hop_with_weight": json.dumps((n), ensure_ascii=False)},
-                                                                               search.index_name(tenant_id), kb_id))
-        except Exception as e:
-            logging.exception(e)
+    try:
+        async with trio.open_nursery() as nursery:
+            for n, p in pr.items():
+                graph.nodes[n]["pagerank"] = p
+                nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
+                                                                                                        {"rank_flt": p,
+                                                                                                         "n_hop_with_weight": json.dumps((n), ensure_ascii=False)},
+                                                                                                        search.index_name(tenant_id), kb_id)))
+    except Exception as e:
+        logging.exception(e)
 
     ty2ents = defaultdict(list)
     for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
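The same fan-out idea applied to the PageRank push: compute ranks once, then schedule one thread-offloaded store update per node instead of awaiting each in turn. A sketch under assumed names — `update_doc` is a hypothetical stand-in for `settings.docStoreConn.update`:

```python
import networkx as nx
import trio

def update_doc(node: str, rank: float):
    # Placeholder for the real document-store update.
    print(f"update {node}: rank={rank:.4f}")

async def push_pageranks(graph: nx.Graph):
    pr = nx.pagerank(graph)
    async with trio.open_nursery() as nursery:
        for n, p in pr.items():
            graph.nodes[n]["pagerank"] = p
            # Passing positional args instead of closing over the loop
            # variables binds n and p eagerly for each task.
            nursery.start_soon(trio.to_thread.run_sync, update_doc, n, p)

trio.run(push_pageranks, nx.Graph([("a", "b"), ("b", "c")]))
```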