
### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
211 lines · 10 KiB · Python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import itertools
import re
import time
from dataclasses import dataclass
from typing import Any, Callable

import networkx as nx

from graphrag.general.extractor import Extractor
from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements

DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"

@dataclass
class EntityResolutionResult:
    """Entity resolution result class definition."""
    graph: nx.Graph
    removed_entities: list

class EntityResolution(Extractor):
    """Entity resolution class definition."""

    _resolution_prompt: str
    _output_formatter_prompt: str
    _record_delimiter_key: str
    _entity_index_delimiter_key: str
    _resolution_result_delimiter_key: str

    def __init__(
            self,
            llm_invoker: CompletionLLM,
            get_entity: Callable | None = None,
            set_entity: Callable | None = None,
            get_relation: Callable | None = None,
            set_relation: Callable | None = None
    ):
        """Init method definition."""
        super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
        self._llm = llm_invoker
        self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
        self._record_delimiter_key = "record_delimiter"
        self._entity_index_delimiter_key = "entity_index_delimiter"
        self._resolution_result_delimiter_key = "resolution_result_delimiter"
        self._input_text_key = "input_text"

    def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}

        # Wire defaults into the prompt variables
        prompt_variables = {
            **prompt_variables,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
                                        or DEFAULT_RECORD_DELIMITER,
            self._entity_index_delimiter_key: prompt_variables.get(self._entity_index_delimiter_key)
                                              or DEFAULT_ENTITY_INDEX_DELIMITER,
            self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key)
                                                   or DEFAULT_RESOLUTION_RESULT_DELIMITER,
        }

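        # Group nodes by entity type so that only entities of the same type are compared.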
        nodes = graph.nodes
        entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
        node_clusters = {entity_type: [] for entity_type in entity_types}

        for node in nodes:
            node_clusters[graph.nodes[node].get('entity_type', '-')].append(node)

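        # Within each type, keep only the pairs that pass the cheap string-similarity pre-filter.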
        candidate_resolution = {entity_type: [] for entity_type in entity_types}
        for k, v in node_clusters.items():
            candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]

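        # For each entity type, batch the candidate pairs into a single prompt and ask the LLM which pairs refer to the same entity.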
        gen_conf = {"temperature": 0.5}
        resolution_result = set()
        for candidate_resolution_i in candidate_resolution.items():
            if candidate_resolution_i[1]:
                try:
                    pair_txt = [
                        f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
                    for index, candidate in enumerate(candidate_resolution_i[1]):
                        pair_txt.append(
                            f'Question {index + 1}: name of {candidate_resolution_i[0]} A is {candidate[0]}, name of {candidate_resolution_i[0]} B is {candidate[1]}')
                    sent = 'question above' if len(candidate_resolution_i[1]) == 1 else f'above {len(candidate_resolution_i[1])} questions'
                    pair_txt.append(
                        f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
                    pair_prompt = '\n'.join(pair_txt)

                    variables = {
                        **prompt_variables,
                        self._input_text_key: pair_prompt
                    }
                    text = perform_variable_replacements(self._resolution_prompt, variables=variables)

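                    # Query the LLM, then map the returned question indices back to the candidate pairs it confirmed as duplicates.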
                    response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
                    result = self._process_results(len(candidate_resolution_i[1]), response,
                                                   prompt_variables.get(self._record_delimiter_key,
                                                                        DEFAULT_RECORD_DELIMITER),
                                                   prompt_variables.get(self._entity_index_delimiter_key,
                                                                        DEFAULT_ENTITY_INDEX_DELIMITER),
                                                   prompt_variables.get(self._resolution_result_delimiter_key,
                                                                        DEFAULT_RESOLUTION_RESULT_DELIMITER))
                    for result_i in result:
                        resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
                except Exception:
                    logging.exception("error entity resolution")

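        # Treat each confirmed duplicate pair as an edge; every connected component then collapses into a single kept node.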
        connect_graph = nx.Graph()
        removed_entities = []
        connect_graph.add_edges_from(resolution_result)
        for sub_connect_graph in nx.connected_components(connect_graph):
            sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
            remove_nodes = list(sub_connect_graph.nodes)
            keep_node = remove_nodes.pop()
            self._merge_nodes(keep_node, self._get_entity_(remove_nodes))
            for remove_node in remove_nodes:
                removed_entities.append(remove_node)
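                # Re-point every edge of the removed node at the kept node, merging relation records where an edge already exists.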
                remove_node_neighbors = graph[remove_node]
                remove_node_neighbors = list(remove_node_neighbors)
                for remove_node_neighbor in remove_node_neighbors:
                    rel = self._get_relation_(remove_node, remove_node_neighbor)
                    if graph.has_edge(remove_node, remove_node_neighbor):
                        graph.remove_edge(remove_node, remove_node_neighbor)
                    if remove_node_neighbor == keep_node:
                        if graph.has_edge(keep_node, remove_node):
                            graph.remove_edge(keep_node, remove_node)
                        continue
                    if not rel:
                        continue
                    if graph.has_edge(keep_node, remove_node_neighbor):
                        self._merge_edges(keep_node, remove_node_neighbor, [rel])
                    else:
                        pair = sorted([keep_node, remove_node_neighbor])
                        graph.add_edge(pair[0], pair[1], weight=rel['weight'])
                        self._set_relation_(pair[0], pair[1],
                                            dict(
                                                src_id=pair[0],
                                                tgt_id=pair[1],
                                                weight=rel['weight'],
                                                description=rel['description'],
                                                keywords=[],
                                                source_id=rel.get("source_id", ""),
                                                metadata={"created_at": time.time()}
                                            ))
                graph.remove_node(remove_node)

        return EntityResolutionResult(
            graph=graph,
            removed_entities=removed_entities
        )

    def _process_results(
            self,
            records_length: int,
            results: str,
            record_delimiter: str,
            entity_index_delimiter: str,
            resolution_result_delimiter: str
    ) -> list:
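        # Split the LLM output into records and, in each one, look for the question index wrapped in the
        # entity-index delimiter and a yes/no answer wrapped in the resolution-result delimiter.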
        ans_list = []
        records = [r.strip() for r in results.split(record_delimiter)]
        for record in records:
            pattern_int = rf"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
            match_int = re.search(pattern_int, record)
            res_int = int(str(match_int.group(1) if match_int else '0'))
            if res_int > records_length:
                continue

            pattern_bool = rf"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}"
            match_bool = re.search(pattern_bool, record)
            res_bool = str(match_bool.group(1) if match_bool else '')

            if res_int and res_bool:
                if res_bool.lower() == 'yes':
                    ans_list.append((res_int, "yes"))

        return ans_list

    def is_similarity(self, a, b):
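        # Cheap pre-filter: for English names use edit distance; otherwise fall back to shared-character overlap.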
        if is_english(a) and is_english(b):
            if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
                return True

        if len(set(a) & set(b)) > 0:
            return True

        return False