Optimized graphrag again (#5927)

### What problem does this PR solve?

Optimized graphrag again: entity resolution, node/edge merging, community report extraction, and the per-node PageRank updates to the document store now fan out as concurrent trio tasks instead of running sequentially; progress callbacks report candidate-pair and community counts; and elapsed time is taken from `trio.current_time()` instead of `timeit`.

### Type of change

- [x] Performance Improvement
Zhichang Yu 2025-03-11 18:36:10 +08:00 committed by GitHub
parent 45318e7575
commit 939e668096
4 changed files with 117 additions and 101 deletions
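
The common thread in all four files is trio structured concurrency: a nursery (`trio.open_nursery()`) fans work out as child tasks and waits for every one of them before the `async with` block exits. A minimal, self-contained sketch of the pattern, with illustrative names that are not from this PR:

```python
# Fan-out sketch with trio structured concurrency (illustrative names only).
import trio

async def resolve_pair(pair, results: set):
    await trio.sleep(0.01)  # stand-in for an LLM call or doc-store write
    results.add(pair)

async def main():
    candidate_pairs = [("NYC", "New York"), ("LA", "Los Angeles")]
    results = set()
    async with trio.open_nursery() as nursery:
        for pair in candidate_pairs:
            # Pass `pair` as a start_soon argument; a bare `lambda:` closing
            # over the loop variable would late-bind to its final value.
            nursery.start_soon(resolve_pair, pair, results)
    # The nursery has waited for all child tasks by this point.
    print(results)

trio.run(main)
```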

File 1: entity resolution (`class EntityResolution`)

```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 import itertools
 import re
 import time
```
```diff
@@ -67,7 +68,7 @@ class EntityResolution(Extractor):
         self._resolution_result_delimiter_key = "resolution_result_delimiter"
         self._input_text_key = "input_text"
 
-    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
+    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
         """Call method definition."""
         if prompt_variables is None:
             prompt_variables = {}
```
```diff
@@ -93,6 +94,8 @@ class EntityResolution(Extractor):
         candidate_resolution = {entity_type: [] for entity_type in entity_types}
         for k, v in node_clusters.items():
             candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
+        num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
+        callback(msg=f"Identified {num_candidates} candidate pairs")
 
         resolution_result = set()
         async with trio.open_nursery() as nursery:
```
```diff
@@ -100,48 +103,52 @@ class EntityResolution(Extractor):
                 if not candidate_resolution_i[1]:
                     continue
                 nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
+        callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
 
         connect_graph = nx.Graph()
         removed_entities = []
         connect_graph.add_edges_from(resolution_result)
         all_entities_data = []
         all_relationships_data = []
+        all_remove_nodes = []
 
-        for sub_connect_graph in nx.connected_components(connect_graph):
-            sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
-            remove_nodes = list(sub_connect_graph.nodes)
-            keep_node = remove_nodes.pop()
-            await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)
-            for remove_node in remove_nodes:
-                removed_entities.append(remove_node)
-                remove_node_neighbors = graph[remove_node]
-                remove_node_neighbors = list(remove_node_neighbors)
-                for remove_node_neighbor in remove_node_neighbors:
-                    rel = self._get_relation_(remove_node, remove_node_neighbor)
-                    if graph.has_edge(remove_node, remove_node_neighbor):
-                        graph.remove_edge(remove_node, remove_node_neighbor)
-                    if remove_node_neighbor == keep_node:
-                        if graph.has_edge(keep_node, remove_node):
-                            graph.remove_edge(keep_node, remove_node)
-                        continue
-                    if not rel:
-                        continue
-                    if graph.has_edge(keep_node, remove_node_neighbor):
-                        await self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)
-                    else:
-                        pair = sorted([keep_node, remove_node_neighbor])
-                        graph.add_edge(pair[0], pair[1], weight=rel['weight'])
-                        self._set_relation_(pair[0], pair[1],
-                                            dict(
-                                                src_id=pair[0],
-                                                tgt_id=pair[1],
-                                                weight=rel['weight'],
-                                                description=rel['description'],
-                                                keywords=[],
-                                                source_id=rel.get("source_id", ""),
-                                                metadata={"created_at": time.time()}
-                                            ))
-                graph.remove_node(remove_node)
+        async with trio.open_nursery() as nursery:
+            for sub_connect_graph in nx.connected_components(connect_graph):
+                sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
+                remove_nodes = list(sub_connect_graph.nodes)
+                keep_node = remove_nodes.pop()
+                all_remove_nodes.append(remove_nodes)
+                nursery.start_soon(lambda: self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data))
+                for remove_node in remove_nodes:
+                    removed_entities.append(remove_node)
+                    remove_node_neighbors = graph[remove_node]
+                    remove_node_neighbors = list(remove_node_neighbors)
+                    for remove_node_neighbor in remove_node_neighbors:
+                        rel = self._get_relation_(remove_node, remove_node_neighbor)
+                        if graph.has_edge(remove_node, remove_node_neighbor):
+                            graph.remove_edge(remove_node, remove_node_neighbor)
+                        if remove_node_neighbor == keep_node:
+                            if graph.has_edge(keep_node, remove_node):
+                                graph.remove_edge(keep_node, remove_node)
+                            continue
+                        if not rel:
+                            continue
+                        if graph.has_edge(keep_node, remove_node_neighbor):
+                            nursery.start_soon(lambda: self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data))
+                        else:
+                            pair = sorted([keep_node, remove_node_neighbor])
+                            graph.add_edge(pair[0], pair[1], weight=rel['weight'])
+                            self._set_relation_(pair[0], pair[1],
+                                                dict(
+                                                    src_id=pair[0],
+                                                    tgt_id=pair[1],
+                                                    weight=rel['weight'],
+                                                    description=rel['description'],
+                                                    keywords=[],
+                                                    source_id=rel.get("source_id", ""),
+                                                    metadata={"created_at": time.time()}
+                                                ))
+                    graph.remove_node(remove_node)
 
         return EntityResolutionResult(
             graph=graph,
```
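
The merge step treats the resolution result as a "same entity" graph: every connected component collapses into one kept node, and the duplicates' edges are re-attached to it before the duplicates are dropped. A stripped-down sketch of that idea in plain networkx, with toy data and no persistence layer:

```python
# Collapse duplicate entities via connected components (toy example).
import networkx as nx

g = nx.Graph()  # the knowledge graph
g.add_edges_from([("NYC", "Boston"), ("New York", "LA")])

same_as = nx.Graph()  # edges = pairs the resolver judged to be duplicates
same_as.add_edges_from([("NYC", "New York")])

for component in nx.connected_components(same_as):
    nodes = list(component)
    keep = nodes.pop()  # an arbitrary survivor, as in the diff
    for dup in nodes:
        # Re-attach each duplicate's edges to the kept node, then drop it.
        for neighbor in list(g[dup]):
            if neighbor != keep:
                g.add_edge(keep, neighbor)
            g.remove_edge(dup, neighbor)
        g.remove_node(dup)

print(sorted(g.edges))  # the merged node now carries both Boston and LA edges
```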
```diff
@@ -164,8 +171,10 @@ class EntityResolution(Extractor):
             self._input_text_key: pair_prompt
         }
         text = perform_variable_replacements(self._resolution_prompt, variables=variables)
+        logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
         async with chat_limiter:
             response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+        logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
         result = self._process_results(len(candidate_resolution_i[1]), response,
                                        self.prompt_variables.get(self._record_delimiter_key,
                                                                  DEFAULT_RECORD_DELIMITER),
```
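
`chat_limiter`, imported from `graphrag.utils`, is used as an async context manager around every model call, which is consistent with it being a `trio.CapacityLimiter`. Assuming that, the back-pressure pattern looks like this sketch, where `blocking_chat` is a stand-in for a synchronous LLM client:

```python
# Bound concurrent blocking LLM calls (assumes a trio.CapacityLimiter).
import trio

chat_limiter = trio.CapacityLimiter(10)  # at most 10 chats in flight

def blocking_chat(prompt: str) -> str:
    return f"echo: {prompt}"  # stand-in for a synchronous LLM client

async def chat(prompt: str) -> str:
    async with chat_limiter:  # take one of the 10 slots
        # Run the blocking call in a worker thread so the trio loop stays free.
        return await trio.to_thread.run_sync(blocking_chat, prompt)

async def main():
    async with trio.open_nursery() as nursery:
        for i in range(25):  # 25 requests, but never more than 10 at once
            nursery.start_soon(chat, f"prompt {i}")

trio.run(main)
```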

File 2: community report extraction (`class CommunityReportsExtractor`)

```diff
@@ -19,7 +19,6 @@ from graphrag.general.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
 from rag.utils import num_tokens_from_string
-from timeit import default_timer as timer
 import trio
```
```diff
@@ -62,62 +61,69 @@ class CommunityReportsExtractor(Extractor):
         res_str = []
         res_dict = []
         over, token_count = 0, 0
-        st = timer()
-        for level, comm in communities.items():
-            logging.info(f"Level {level}: Community: {len(comm.keys())}")
-            for cm_id, ents in comm.items():
-                weight = ents["weight"]
-                ents = ents["nodes"]
-                ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()#[{"entity": n, **graph.nodes[n]} for n in ents])
-                if ent_df.empty or "entity_name" not in ent_df.columns:
-                    continue
-                ent_df["entity"] = ent_df["entity_name"]
-                del ent_df["entity_name"]
-                rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
-                if rela_df.empty:
-                    continue
-                rela_df["source"] = rela_df["src_id"]
-                rela_df["target"] = rela_df["tgt_id"]
-                del rela_df["src_id"]
-                del rela_df["tgt_id"]
-
-                prompt_variables = {
-                    "entity_df": ent_df.to_csv(index_label="id"),
-                    "relation_df": rela_df.to_csv(index_label="id")
-                }
-                text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
-                gen_conf = {"temperature": 0.3}
-                async with chat_limiter:
-                    response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
-                token_count += num_tokens_from_string(text + response)
-                response = re.sub(r"^[^\{]*", "", response)
-                response = re.sub(r"[^\}]*$", "", response)
-                response = re.sub(r"\{\{", "{", response)
-                response = re.sub(r"\}\}", "}", response)
-                logging.debug(response)
-                try:
-                    response = json.loads(response)
-                except json.JSONDecodeError as e:
-                    logging.error(f"Failed to parse JSON response: {e}")
-                    logging.error(f"Response content: {response}")
-                    continue
-                if not dict_has_keys_with_types(response, [
-                    ("title", str),
-                    ("summary", str),
-                    ("findings", list),
-                    ("rating", float),
-                    ("rating_explanation", str),
-                ]):
-                    continue
-                response["weight"] = weight
-                response["entities"] = ents
-                add_community_info2graph(graph, ents, response["title"])
-                res_str.append(self._get_text_output(response))
-                res_dict.append(response)
-                over += 1
-                if callback:
-                    callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
+        async def extract_community_report(community):
+            nonlocal res_str, res_dict, over, token_count
+            cm_id, ents = community
+            weight = ents["weight"]
+            ents = ents["nodes"]
+            ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()
+            if ent_df.empty or "entity_name" not in ent_df.columns:
+                return
+            ent_df["entity"] = ent_df["entity_name"]
+            del ent_df["entity_name"]
+            rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
+            if rela_df.empty:
+                return
+            rela_df["source"] = rela_df["src_id"]
+            rela_df["target"] = rela_df["tgt_id"]
+            del rela_df["src_id"]
+            del rela_df["tgt_id"]
+
+            prompt_variables = {
+                "entity_df": ent_df.to_csv(index_label="id"),
+                "relation_df": rela_df.to_csv(index_label="id")
+            }
+            text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
+            gen_conf = {"temperature": 0.3}
+            async with chat_limiter:
+                response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+            token_count += num_tokens_from_string(text + response)
+            response = re.sub(r"^[^\{]*", "", response)
+            response = re.sub(r"[^\}]*$", "", response)
+            response = re.sub(r"\{\{", "{", response)
+            response = re.sub(r"\}\}", "}", response)
+            logging.debug(response)
+            try:
+                response = json.loads(response)
+            except json.JSONDecodeError as e:
+                logging.error(f"Failed to parse JSON response: {e}")
+                logging.error(f"Response content: {response}")
+                return
+            if not dict_has_keys_with_types(response, [
+                ("title", str),
+                ("summary", str),
+                ("findings", list),
+                ("rating", float),
+                ("rating_explanation", str),
+            ]):
+                return
+            response["weight"] = weight
+            response["entities"] = ents
+            add_community_info2graph(graph, ents, response["title"])
+            res_str.append(self._get_text_output(response))
+            res_dict.append(response)
+            over += 1
+            if callback:
+                callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
+
+        st = trio.current_time()
+        async with trio.open_nursery() as nursery:
+            for level, comm in communities.items():
+                logging.info(f"Level {level}: Community: {len(comm.keys())}")
+                for community in comm.items():
+                    nursery.start_soon(lambda: extract_community_report(community))
+        if callback:
+            callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
 
         return CommunityReportsResult(
             structured_output=res_dict,
```
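
This refactor hoists the old loop body into a nested async worker that updates shared accumulators through `nonlocal`, then fans all communities out under a single nursery; because trio's event loop is single-threaded, the shared counters need no lock. A compact sketch of the same shape, with illustrative names:

```python
# Sequential loop body turned into a concurrent worker (illustrative names).
import trio

async def summarize_all(items):
    results, done = [], 0

    async def worker(item):
        nonlocal done  # rebinding an outer-scope int needs nonlocal
        await trio.sleep(0.01)  # stand-in for the per-community LLM call
        results.append(f"report for {item}")
        done += 1

    start = trio.current_time()  # trio's monotonic clock, as in the diff
    async with trio.open_nursery() as nursery:
        for item in items:
            nursery.start_soon(worker, item)
    print(f"{done} reports in {trio.current_time() - start:.2f}s")
    return results

trio.run(summarize_all, ["c1", "c2", "c3"])
```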

File 3: the `resolve_entities` pipeline step

```diff
@@ -228,7 +228,7 @@ async def resolve_entities(
         get_relation=partial(get_relation, tenant_id, kb_id),
         set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
     )
-    reso = await er(graph)
+    reso = await er(graph, callback=callback)
     graph = reso.graph
     callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
     await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
```
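
`resolve_entities` now threads its progress `callback` through to the resolver, which invokes it keyword-style as `callback(msg=...)`. A minimal compatible consumer, assuming nothing beyond the calls visible in this diff (the helper name is hypothetical):

```python
# Hypothetical consumer matching the callback(msg=...) calls in the diff.
def make_progress_callback(task_id: str):
    def callback(msg: str = "", **kwargs):
        print(f"[{task_id}] {msg}")
    return callback
```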

File 4: `update_nodes_pagerank_nhop_neighbour`

```diff
@@ -489,15 +489,16 @@ async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
         return nbrs
 
     pr = nx.pagerank(graph)
-    for n, p in pr.items():
-        graph.nodes[n]["pagerank"] = p
-        try:
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
-                                                                               {"rank_flt": p,
-                                                                                "n_hop_with_weight": json.dumps(nhop(n), ensure_ascii=False)},
-                                                                               search.index_name(tenant_id), kb_id))
-        except Exception as e:
-            logging.exception(e)
+    try:
+        async with trio.open_nursery() as nursery:
+            for n, p in pr.items():
+                graph.nodes[n]["pagerank"] = p
+                nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
+                                                                                                        {"rank_flt": p,
+                                                                                                         "n_hop_with_weight": json.dumps(nhop(n), ensure_ascii=False)},
+                                                                                                        search.index_name(tenant_id), kb_id)))
+    except Exception as e:
+        logging.exception(e)
 
     ty2ents = defaultdict(list)
     for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
```
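
The PageRank hunk applies the same recipe to a blocking document store: compute the ranks synchronously with networkx, then let each per-node write run in its own worker thread under the nursery so the updates overlap instead of serializing. A runnable sketch where `fake_store_update` stands in for `settings.docStoreConn.update`:

```python
# Concurrent per-node store updates after PageRank (fake store stand-in).
import networkx as nx
import trio

def fake_store_update(node: str, rank: float) -> None:
    print(f"update {node}: rank={rank:.3f}")  # blocking I/O stand-in

async def push_pagerank(graph: nx.Graph):
    pr = nx.pagerank(graph)
    async with trio.open_nursery() as nursery:
        for n, p in pr.items():
            graph.nodes[n]["pagerank"] = p
            # Each blocking write gets its own worker thread; add a
            # CapacityLimiter here if the store needs back-pressure.
            nursery.start_soon(trio.to_thread.run_sync, fake_store_update, n, p)

g = nx.path_graph(["a", "b", "c"])
trio.run(push_pagerank, g)
```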