mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-23 22:50:17 +08:00

### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
188 lines
6.5 KiB
Python
188 lines
6.5 KiB
Python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
|
|
|
|
import collections
|
|
import logging
|
|
import re
|
|
import logging
|
|
import traceback
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
|
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
|
from rag.llm.chat_model import Base as CompletionLLM
|
|
import markdown_to_json
|
|
from functools import reduce
|
|
from rag.utils import num_tokens_from_string
|
|
|
|
|
|
@dataclass
class MindMapResult:
    """Outcome of one mind-map extraction run.

    ``output`` is either the tree root, shaped ``{"id": ..., "children": [...]}``,
    or ``{"error": message}`` when extraction failed.
    """

    output: dict  # mind-map tree root, or an error payload
|
|
|
|
|
|
class MindMapExtractor:
    """Build a mind-map tree from document sections with an LLM.

    Sections are packed into token-bounded chunks. Each chunk is sent to the
    LLM together with the mind-map prompt, and each markdown answer is parsed
    into a nested dict. The per-chunk dicts are merged and then reshaped into
    an ``{"id": ..., "children": [...]}`` tree.
    """

    _llm: CompletionLLM
    _input_text_key: str
    _mind_map_prompt: str
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        input_text_key: str | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Configure the extractor.

        Args:
            llm_invoker: chat model used to generate the mind map.
            prompt: prompt template; defaults to ``MIND_MAP_EXTRACTION_PROMPT``.
            input_text_key: template variable that receives the chunk text;
                defaults to ``"input_text"``.
            on_error: callback invoked as ``on_error(exc, traceback_str, None)``
                when extraction fails; defaults to a no-op.
        """
        # TODO: streamline construction
        self._llm = llm_invoker
        self._input_text_key = input_text_key or "input_text"
        self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)

    def _key(self, k):
        """Strip runs of ``*`` (markdown emphasis markers) from a heading/key."""
        return re.sub(r"\*+", "", k)

    def _be_children(self, obj: dict | list | str, keyset: set):
        """Convert a parsed subtree into a list of ``{"id", "children"}`` nodes.

        Args:
            obj: a single leaf string, a list of leaves (or nested dicts), or a
                dict mapping headings to sub-trees. Any other value yields no
                children.
            keyset: cleaned ids already placed in the tree. Dict keys found
                here are skipped so a heading appears only once. This set is
                mutated in place.

        Returns:
            The list of child nodes.
        """
        if isinstance(obj, str):
            obj = [obj]

        if isinstance(obj, list):
            children = []
            for item in obj:
                if isinstance(item, dict):
                    # A dict inside a list is a sub-tree. Previously this
                    # crashed on keyset.add / re.sub.
                    children.extend(self._be_children(item, keyset))
                    continue
                if not isinstance(item, str):
                    continue
                node_id = self._key(item)
                if not node_id:
                    continue
                # Record the *cleaned* id: dict keys are compared in cleaned form.
                keyset.add(node_id)
                children.append({"id": node_id, "children": []})
            return children

        if not isinstance(obj, dict):
            return []

        arr = []
        for k, v in obj.items():
            k = self._key(k)
            if not k or k in keyset:
                continue
            keyset.add(k)
            arr.append({"id": k, "children": self._be_children(v, keyset)})
        return arr

    def __call__(
        self, sections: list[str], prompt_variables: dict[str, Any] | None = None
    ) -> MindMapResult:
        """Extract a mind map from ``sections``.

        Args:
            sections: text sections of the document, in order.
            prompt_variables: extra template variables for the prompt.

        Returns:
            A ``MindMapResult`` whose ``output`` is one of:
            - the tree root;
            - ``{"id": "root", "children": []}`` when nothing usable was
              extracted;
            - ``{"error": message}`` on failure, in which case ``on_error`` is
              also called.
        """
        if prompt_variables is None:
            prompt_variables = {}

        try:
            res = self._extract_parts(sections, prompt_variables)
            # Only non-empty mappings can be merged and turned into a tree.
            res = [r for r in res if isinstance(r, dict) and r]
            if not res:
                return MindMapResult(output={"id": "root", "children": []})
            merge_json = self._build_tree(reduce(self._merge, res))
        except Exception as e:
            logging.exception("error mind graph")
            self._on_error(
                e,
                traceback.format_exc(), None
            )
            merge_json = {"error": str(e)}

        return MindMapResult(output=merge_json)

    def _extract_parts(self, sections: list[str], prompt_variables: dict[str, Any]) -> list:
        """Pack sections into token-bounded chunks and run the LLM on each one in parallel."""
        # Per-chunk input budget: the larger of 80% of the model's max_length
        # or max_length minus 512 tokens.
        token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
        # `with` guarantees the pool is shut down (the original leaked it).
        with ThreadPoolExecutor(max_workers=12) as exe:
            futures = []
            texts = []
            cnt = 0
            for section in sections:
                section_cnt = num_tokens_from_string(section)
                if cnt + section_cnt >= token_count and texts:
                    futures.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
                    texts = []
                    cnt = 0
                texts.append(section)
                cnt += section_cnt
            if texts:
                futures.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
            # Collect in submission order so the merge is deterministic.
            return [f.result() for f in futures]

    def _build_tree(self, merge_json: dict) -> dict:
        """Reshape the merged, non-empty heading dict into the ``{"id", "children"}`` tree."""
        if len(merge_json) > 1:
            # Several top-level headings: hang the dict-valued ones under a
            # synthetic "root" node. All top-level ids are pre-seeded in the
            # keyset so they are not repeated deeper in the tree.
            keyset = {
                self._key(k) for k, v in merge_json.items()
                if isinstance(v, dict) and self._key(k)
            }
            return {
                "id": "root",
                "children": [
                    {"id": self._key(k), "children": self._be_children(v, keyset)}
                    for k, v in merge_json.items()
                    if isinstance(v, dict) and self._key(k)
                ],
            }

        k, v = next(iter(merge_json.items()))
        k = self._key(k)
        return {"id": k, "children": self._be_children(v, {k})}

    def _merge(self, d1, d2):
        """Recursively merge ``d1`` into ``d2``, mutating and returning ``d2``.

        Nested dicts are merged. Lists are concatenated (``d2`` items first,
        then ``d1``). Any other collision is resolved in favour of ``d1``.
        """
        for k in d1:
            if k in d2:
                if isinstance(d1[k], dict) and isinstance(d2[k], dict):
                    self._merge(d1[k], d2[k])
                elif isinstance(d1[k], list) and isinstance(d2[k], list):
                    d2[k].extend(d1[k])
                else:
                    d2[k] = d1[k]
            else:
                d2[k] = d1[k]

        return d2

    def _list_to_kv(self, data):
        """Rewrite list values as ``{previous_item: first_item_of_sublist}`` dicts.

        This is applied recursively through nested dicts, and ``data`` is
        mutated in place and returned. For each list value, every element that
        is a non-empty list is paired with the element just before it.
        NOTE(review): list values that contain no nested list become ``{}``,
        and only the first item of each sub-list is kept. This looks lossy;
        confirm it matches what the prompt asks the LLM to produce.
        """
        for key, value in data.items():
            if isinstance(value, dict):
                self._list_to_kv(value)
            elif isinstance(value, list):
                new_value = {}
                for i in range(1, len(value)):
                    # Start at 1: at i == 0, value[i - 1] would wrap around to
                    # the *last* element.
                    item, prev = value[i], value[i - 1]
                    if isinstance(item, list) and item and not isinstance(prev, (list, dict)):
                        new_value[prev] = item[0]
                # Rebinding an existing key does not change the dict's size,
                # so this is safe during iteration.
                data[key] = new_value
        return data

    def _todict(self, layer: collections.OrderedDict):
        """Recursively convert parsed markdown mappings into plain dicts.

        Each dict level is then passed through ``_list_to_kv``. Non-dict
        values (strings, lists) are returned unchanged.
        """
        if not isinstance(layer, dict):  # OrderedDict is a dict subclass
            return layer
        converted = {key: self._todict(value) for key, value in layer.items()}
        return self._list_to_kv(converted)

    def _process_document(
        self, text: str, prompt_variables: dict[str, str]
    ) -> dict:
        """Run the LLM on one chunk and parse its markdown answer.

        Args:
            text: concatenated section text for this chunk.
            prompt_variables: extra template variables. The chunk text is
                added under ``self._input_text_key`` and overrides any entry
                with the same name.

        Returns:
            The parsed answer after ``_todict``. This is normally a dict, but
            non-dict parse results are passed through unchanged.
        """
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        prompt = perform_variable_replacements(self._mind_map_prompt, variables=variables)
        gen_conf = {"temperature": 0.5}
        response = self._llm.chat(prompt, [{"role": "user", "content": "Output:"}], gen_conf)
        # Drop code-fence markers (``` plus any language tag) so the parser
        # sees bare markdown.
        response = re.sub(r"```[^\n]*", "", response)
        logging.debug(response)
        # Parse once; the original re-parsed the response just to print it.
        parsed = self._todict(markdown_to_json.dictify(response))
        logging.debug("parsed mind map chunk: %s", parsed)
        return parsed
|