mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-19 12:39:59 +08:00
### What problem does this PR solve?

EntityResolution batch

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
d2043ff9f2
commit
36b62e0fab
@ -63,7 +63,10 @@ class EntityResolution(Extractor):
|
||||
self._resolution_result_delimiter_key = "resolution_result_delimiter"
|
||||
self._input_text_key = "input_text"
|
||||
|
||||
async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
|
||||
async def __call__(self, graph: nx.Graph,
|
||||
subgraph_nodes: set[str],
|
||||
prompt_variables: dict[str, Any] | None = None,
|
||||
callback: Callable | None = None) -> EntityResolutionResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
@ -88,16 +91,19 @@ class EntityResolution(Extractor):
|
||||
|
||||
candidate_resolution = {entity_type: [] for entity_type in entity_types}
|
||||
for k, v in node_clusters.items():
|
||||
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
|
||||
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)]
|
||||
num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
|
||||
callback(msg=f"Identified {num_candidates} candidate pairs")
|
||||
|
||||
resolution_result = set()
|
||||
resolution_batch_size = 100
|
||||
async with trio.open_nursery() as nursery:
|
||||
for candidate_resolution_i in candidate_resolution.items():
|
||||
if not candidate_resolution_i[1]:
|
||||
continue
|
||||
nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
|
||||
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
|
||||
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
|
||||
nursery.start_soon(lambda: self._resolve_candidate(candidate_batch, resolution_result))
|
||||
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
|
||||
|
||||
change = GraphChange()
|
||||
@ -118,7 +124,7 @@ class EntityResolution(Extractor):
|
||||
change=change,
|
||||
)
|
||||
|
||||
async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
|
||||
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str]):
|
||||
gen_conf = {"temperature": 0.5}
|
||||
pair_txt = [
|
||||
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
|
||||
|
@ -69,26 +69,27 @@ async def run_graphrag(
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
new_graph = None
|
||||
if subgraph:
|
||||
new_graph = await merge_subgraph(
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
subgraph,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if not subgraph:
|
||||
return
|
||||
|
||||
subgraph_nodes = set(subgraph.nodes())
|
||||
new_graph = await merge_subgraph(
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
subgraph,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
assert new_graph is not None
|
||||
|
||||
if not with_resolution or not with_community:
|
||||
return
|
||||
|
||||
if new_graph is None:
|
||||
new_graph = await get_graph(tenant_id, kb_id)
|
||||
|
||||
if with_resolution and new_graph is not None:
|
||||
if with_resolution:
|
||||
await resolve_entities(
|
||||
new_graph,
|
||||
subgraph_nodes,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
@ -96,7 +97,7 @@ async def run_graphrag(
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if with_community and new_graph is not None:
|
||||
if with_community:
|
||||
await extract_community(
|
||||
new_graph,
|
||||
tenant_id,
|
||||
@ -223,6 +224,7 @@ async def merge_subgraph(
|
||||
|
||||
async def resolve_entities(
|
||||
graph,
|
||||
subgraph_nodes: set[str],
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
@ -241,7 +243,7 @@ async def resolve_entities(
|
||||
er = EntityResolution(
|
||||
llm_bdl,
|
||||
)
|
||||
reso = await er(graph, callback=callback)
|
||||
reso = await er(graph, subgraph_nodes, callback=callback)
|
||||
graph = reso.graph
|
||||
change = reso.change
|
||||
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user