Made task_executor async to speedup parsing (#5530)
### What problem does this PR solve?

Made `task_executor` async to speed up parsing.

### Type of change

- [x] Performance Improvement
Parent: abac2ca2c5
Commit: c813c1ff4c
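The diff below touches three areas: API handlers that used to call the extractors synchronously now drive them with `trio.run`; pdfplumber usage is serialized behind a process-wide lock; and the GraphRAG extractors trade `ThreadPoolExecutor` for trio nurseries throttled by a `chat_limiter`. A minimal sketch of the sync-to-async bridge the first hunks apply, with stand-in classes rather than RAGFlow's actual implementation:

```python
# Minimal sketch of the sync-to-async bridge used at the API layer below.
# MindMapExtractor here is a simplified stand-in for RAGFlow's extractor
# classes, whose __call__ methods become `async def` in this commit; the
# Flask-style handler stays synchronous and drives them with trio.run.
import trio

class MindMapExtractor:
    async def __call__(self, chunks: list[str]):
        # The real extractors fan chat-completion calls out to worker threads;
        # sleep(0) just marks the await point in this sketch.
        await trio.sleep(0)
        return {"root": {"children": chunks}}

def mindmap_handler(chunks: list[str]):
    extractor = MindMapExtractor()
    # trio.run(async_fn, *args) starts an event loop, runs the coroutine to
    # completion, and returns its result to the synchronous caller.
    return trio.run(extractor, chunks)

print(mindmap_handler(["chunk one", "chunk two"]))
```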
```diff
@@ -17,6 +17,7 @@ import json
 import re
 import traceback
 from copy import deepcopy
+import trio
 from api.db.db_models import APIToken

 from api.db.services.conversation_service import ConversationService, structure_answer
@@ -386,7 +387,8 @@ def mindmap():
         rank_feature=label_question(question, [kb])
     )
     mindmap = MindMapExtractor(chat_mdl)
-    mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
+    mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
+    mind_map = mind_map.output
     if "error" in mind_map:
         return server_error_response(Exception(mind_map["error"]))
     return get_json_result(data=mind_map)
@@ -22,6 +22,7 @@ from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from datetime import datetime
 from io import BytesIO
+import trio

 from peewee import fn

@@ -597,8 +598,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
         if parser_ids[doc_id] != ParserType.PICTURE.value:
             mindmap = MindMapExtractor(llm_bdl)
             try:
-                mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
-                                      ensure_ascii=False, indent=2)
+                mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
+                mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
                 if len(mind_map) < 32:
                     raise Exception("Few content: " + mind_map)
                 cks.append({
@@ -17,6 +17,8 @@ import base64
 import json
 import os
 import re
+import sys
+import threading
 from io import BytesIO

 import pdfplumber
@@ -30,6 +32,10 @@ from api.constants import IMG_BASE64_PREFIX
 PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
 RAG_BASE = os.getenv("RAG_BASE")

+LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+if LOCK_KEY_pdfplumber not in sys.modules:
+    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
+

 def get_project_base_directory(*args):
     global PROJECT_BASE
@@ -175,19 +181,20 @@ def thumbnail_img(filename, blob):
     """
     filename = filename.lower()
     if re.match(r".*\.pdf$", filename):
-        pdf = pdfplumber.open(BytesIO(blob))
-        buffered = BytesIO()
-        resolution = 32
-        img = None
-        for _ in range(10):
-            # https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
-            pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
-            img = buffered.getvalue()
-            if len(img) >= 64000 and resolution >= 2:
-                resolution = resolution / 2
-                buffered = BytesIO()
-            else:
-                break
+        with sys.modules[LOCK_KEY_pdfplumber]:
+            pdf = pdfplumber.open(BytesIO(blob))
+            buffered = BytesIO()
+            resolution = 32
+            img = None
+            for _ in range(10):
+                # https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
+                pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
+                img = buffered.getvalue()
+                if len(img) >= 64000 and resolution >= 2:
+                    resolution = resolution / 2
+                    buffered = BytesIO()
+                else:
+                    break
         pdf.close()
         return img

```
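The `LOCK_KEY_pdfplumber` blocks above (and repeated in the parser and vision modules further down) rely on `sys.modules` being one interpreter-wide dict: stashing a `threading.Lock` there under a string key gives every importer the same lock object, even if the snippet is duplicated across modules or a defining module is reloaded. Storing a non-module in `sys.modules` is a hack, but it is exactly what the commit does. A sketch of the idiom, with a hypothetical `open_pdf_safely` helper:

```python
# Sketch of the process-wide lock idiom introduced above. sys.modules is a
# single interpreter-wide dict, so a threading.Lock stored under a well-known
# key is shared by every module that repeats this snippet.
import sys
import threading

LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()

def open_pdf_safely(path: str):
    # Every pdfplumber entry point in this commit takes the same lock, so PDF
    # opening/rendering is serialized across all threads in the process
    # (pdfplumber/pdfminer are not thread-safe).
    with sys.modules[LOCK_KEY_pdfplumber]:
        print(f"exclusive access while opening {path}")

open_pdf_safely("example.pdf")  # hypothetical usage
```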
```diff
@@ -18,6 +18,8 @@ import os.path
 import logging
 from logging.handlers import RotatingFileHandler

+initialized_root_logger = False
+
 def get_project_base_directory():
     PROJECT_BASE = os.path.abspath(
         os.path.join(
@@ -29,10 +31,13 @@ def get_project_base_directory():
     return PROJECT_BASE

 def initRootLogger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
-    logger = logging.getLogger()
-    if logger.hasHandlers():
+    global initialized_root_logger
+    if initialized_root_logger:
         return
+    initialized_root_logger = True
+
+    logger = logging.getLogger()
+    logger.handlers.clear()

     log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{logfile_basename}.log"))

     os.makedirs(os.path.dirname(log_path), exist_ok=True)
@@ -18,6 +18,8 @@ import logging
 import os
 import random
 from timeit import default_timer as timer
+import sys
+import threading

 import xgboost as xgb
 from io import BytesIO
@@ -34,6 +36,10 @@ from rag.nlp import rag_tokenizer
 from copy import deepcopy
 from huggingface_hub import snapshot_download

+LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+if LOCK_KEY_pdfplumber not in sys.modules:
+    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
+
 class RAGFlowPdfParser:
     def __init__(self):
         self.ocr = OCR()
@@ -948,8 +954,9 @@ class RAGFlowPdfParser:
     @staticmethod
     def total_page_number(fnm, binary=None):
         try:
-            pdf = pdfplumber.open(
-                fnm) if not binary else pdfplumber.open(BytesIO(binary))
+            with sys.modules[LOCK_KEY_pdfplumber]:
+                pdf = pdfplumber.open(
+                    fnm) if not binary else pdfplumber.open(BytesIO(binary))
             total_page = len(pdf.pages)
             pdf.close()
             return total_page
@@ -968,17 +975,18 @@ class RAGFlowPdfParser:
         self.page_from = page_from
         start = timer()
         try:
-            self.pdf = pdfplumber.open(fnm) if isinstance(
-                fnm, str) else pdfplumber.open(BytesIO(fnm))
-            self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
-                                enumerate(self.pdf.pages[page_from:page_to])]
-            try:
-                self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
-            except Exception as e:
-                logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
-                self.page_chars = [[] for _ in range(page_to - page_from)]  # If failed to extract, using empty list instead.
-            self.total_page = len(self.pdf.pages)
+            with sys.modules[LOCK_KEY_pdfplumber]:
+                self.pdf = pdfplumber.open(fnm) if isinstance(
+                    fnm, str) else pdfplumber.open(BytesIO(fnm))
+                self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
+                                    enumerate(self.pdf.pages[page_from:page_to])]
+                try:
+                    self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
+                except Exception as e:
+                    logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
+                    self.page_chars = [[] for _ in range(page_to - page_from)]  # If failed to extract, using empty list instead.
+
+            self.total_page = len(self.pdf.pages)
         except Exception:
             logging.exception("RAGFlowPdfParser __images__")
         logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
@@ -14,7 +14,8 @@
 # limitations under the License.
 #
 import io
+import sys
+import threading
 import pdfplumber

 from .ocr import OCR
@@ -23,6 +24,11 @@ from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
 from .table_structure_recognizer import TableStructureRecognizer

+
+LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+if LOCK_KEY_pdfplumber not in sys.modules:
+    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
+

 def init_in_out(args):
     from PIL import Image
     import os
@@ -36,9 +42,10 @@ def init_in_out(args):

     def pdf_pages(fnm, zoomin=3):
         nonlocal outputs, images
-        pdf = pdfplumber.open(fnm)
-        images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
-                  enumerate(pdf.pages)]
+        with sys.modules[LOCK_KEY_pdfplumber]:
+            pdf = pdfplumber.open(fnm)
+            images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
+                      enumerate(pdf.pages)]

         for i, page in enumerate(images):
             outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import logging
 import itertools
 import re
 import time
@@ -21,13 +20,14 @@ from dataclasses import dataclass
 from typing import Any, Callable

 import networkx as nx
+import trio

 from graphrag.general.extractor import Extractor
 from rag.nlp import is_english
 import editdistance
 from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
 from rag.llm.chat_model import Base as CompletionLLM
-from graphrag.utils import perform_variable_replacements
+from graphrag.utils import perform_variable_replacements, chat_limiter

 DEFAULT_RECORD_DELIMITER = "##"
 DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@@ -67,13 +67,13 @@ class EntityResolution(Extractor):
         self._resolution_result_delimiter_key = "resolution_result_delimiter"
         self._input_text_key = "input_text"

-    def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
+    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
         """Call method definition."""
         if prompt_variables is None:
             prompt_variables = {}

         # Wire defaults into the prompt variables
-        prompt_variables = {
+        self.prompt_variables = {
             **prompt_variables,
             self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
             or DEFAULT_RECORD_DELIMITER,
@@ -94,39 +94,12 @@ class EntityResolution(Extractor):
         for k, v in node_clusters.items():
             candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]

-        gen_conf = {"temperature": 0.5}
         resolution_result = set()
-        for candidate_resolution_i in candidate_resolution.items():
-            if candidate_resolution_i[1]:
-                try:
-                    pair_txt = [
-                        f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
-                    for index, candidate in enumerate(candidate_resolution_i[1]):
-                        pair_txt.append(
-                            f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
-                    sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
-                    pair_txt.append(
-                        f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
-                    pair_prompt = '\n'.join(pair_txt)
-
-                    variables = {
-                        **prompt_variables,
-                        self._input_text_key: pair_prompt
-                    }
-                    text = perform_variable_replacements(self._resolution_prompt, variables=variables)
-
-                    response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
-                    result = self._process_results(len(candidate_resolution_i[1]), response,
-                                                   prompt_variables.get(self._record_delimiter_key,
-                                                                        DEFAULT_RECORD_DELIMITER),
-                                                   prompt_variables.get(self._entity_index_dilimiter_key,
-                                                                        DEFAULT_ENTITY_INDEX_DELIMITER),
-                                                   prompt_variables.get(self._resolution_result_delimiter_key,
-                                                                        DEFAULT_RESOLUTION_RESULT_DELIMITER))
-                    for result_i in result:
-                        resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
-                except Exception:
-                    logging.exception("error entity resolution")
-
+        async with trio.open_nursery() as nursery:
+            for candidate_resolution_i in candidate_resolution.items():
+                if not candidate_resolution_i[1]:
+                    continue
+                nursery.start_soon(self._resolve_candidate(candidate_resolution_i, resolution_result))

         connect_graph = nx.Graph()
         removed_entities = []
@@ -172,6 +145,34 @@ class EntityResolution(Extractor):
             removed_entities=removed_entities
         )

+    async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
+        gen_conf = {"temperature": 0.5}
+        pair_txt = [
+            f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
+        for index, candidate in enumerate(candidate_resolution_i[1]):
+            pair_txt.append(
+                f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
+        sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
+        pair_txt.append(
+            f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
+        pair_prompt = '\n'.join(pair_txt)
+        variables = {
+            **self.prompt_variables,
+            self._input_text_key: pair_prompt
+        }
+        text = perform_variable_replacements(self._resolution_prompt, variables=variables)
+        async with chat_limiter:
+            response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+        result = self._process_results(len(candidate_resolution_i[1]), response,
+                                       self.prompt_variables.get(self._record_delimiter_key,
+                                                                 DEFAULT_RECORD_DELIMITER),
+                                       self.prompt_variables.get(self._entity_index_dilimiter_key,
+                                                                 DEFAULT_ENTITY_INDEX_DELIMITER),
+                                       self.prompt_variables.get(self._resolution_result_delimiter_key,
+                                                                 DEFAULT_RESOLUTION_RESULT_DELIMITER))
+        for result_i in result:
+            resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
+
     def _process_results(
         self,
         records_length: int,
```
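`_resolve_candidate` above shows the concurrency shape used throughout this PR: tasks are spawned into a nursery, each acquires `chat_limiter` (imported from `graphrag.utils`; presumably a `trio.CapacityLimiter`, since its definition is not part of this page), then runs the blocking `self._chat` call on a worker thread via `trio.to_thread.run_sync` and writes into a shared result set. A condensed, runnable sketch with stand-in names (`blocking_chat` stands in for the synchronous LLM call; note that `start_soon` takes the async function and its arguments separately):

```python
# Condensed sketch of the fan-out pattern used by _resolve_candidate.
# chat_limiter is assumed to be a trio.CapacityLimiter (its definition lives
# in graphrag/utils and is not shown in this diff).
import time
import trio

chat_limiter = trio.CapacityLimiter(10)  # at most 10 chats in flight

def blocking_chat(prompt: str) -> str:
    time.sleep(0.1)  # placeholder for a synchronous HTTP round trip
    return f"answer to {prompt!r}"

async def resolve_one(prompt: str, results: set):
    async with chat_limiter:  # bound concurrency before burning a thread
        response = await trio.to_thread.run_sync(lambda: blocking_chat(prompt))
    results.add(response)  # tasks share one result set, as in the diff

async def resolve_all(prompts: list[str]) -> set:
    results: set = set()
    async with trio.open_nursery() as nursery:
        for p in prompts:
            # start_soon takes the async function plus its arguments; the
            # nursery waits for every task before the async-with block exits.
            nursery.start_soon(resolve_one, p, results)
    return results

print(trio.run(resolve_all, ["a", "b", "c"]))
```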
```diff
@@ -1,268 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-"""
-Reference:
- - [graphrag](https://github.com/microsoft/graphrag)
-"""
-
-import logging
-import argparse
-import json
-import re
-import traceback
-from dataclasses import dataclass
-from typing import Any
-
-import tiktoken
-
-from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
-from graphrag.general.extractor import Extractor
-from rag.llm.chat_model import Base as CompletionLLM
-from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
-
-DEFAULT_TUPLE_DELIMITER = "<|>"
-DEFAULT_RECORD_DELIMITER = "##"
-DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
-CLAIM_MAX_GLEANINGS = 1
-
-
-@dataclass
-class ClaimExtractorResult:
-    """Claim extractor result class definition."""
-
-    output: list[dict]
-    source_docs: dict[str, Any]
-
-
-class ClaimExtractor(Extractor):
-    """Claim extractor class definition."""
-
-    _extraction_prompt: str
-    _summary_prompt: str
-    _output_formatter_prompt: str
-    _input_text_key: str
-    _input_entity_spec_key: str
-    _input_claim_description_key: str
-    _tuple_delimiter_key: str
-    _record_delimiter_key: str
-    _completion_delimiter_key: str
-    _max_gleanings: int
-    _on_error: ErrorHandlerFn
-
-    def __init__(
-        self,
-        llm_invoker: CompletionLLM,
-        extraction_prompt: str | None = None,
-        input_text_key: str | None = None,
-        input_entity_spec_key: str | None = None,
-        input_claim_description_key: str | None = None,
-        input_resolved_entities_key: str | None = None,
-        tuple_delimiter_key: str | None = None,
-        record_delimiter_key: str | None = None,
-        completion_delimiter_key: str | None = None,
-        encoding_model: str | None = None,
-        max_gleanings: int | None = None,
-        on_error: ErrorHandlerFn | None = None,
-    ):
-        """Init method definition."""
-        self._llm = llm_invoker
-        self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
-        self._input_text_key = input_text_key or "input_text"
-        self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
-        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
-        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
-        self._completion_delimiter_key = (
-            completion_delimiter_key or "completion_delimiter"
-        )
-        self._input_claim_description_key = (
-            input_claim_description_key or "claim_description"
-        )
-        self._input_resolved_entities_key = (
-            input_resolved_entities_key or "resolved_entities"
-        )
-        self._max_gleanings = (
-            max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS
-        )
-        self._on_error = on_error or (lambda _e, _s, _d: None)
-
-        # Construct the looping arguments
-        encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
-        yes = encoding.encode("YES")
-        no = encoding.encode("NO")
-        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
-
-    def __call__(
-        self, inputs: dict[str, Any], prompt_variables: dict | None = None
-    ) -> ClaimExtractorResult:
-        """Call method definition."""
-        if prompt_variables is None:
-            prompt_variables = {}
-        texts = inputs[self._input_text_key]
-        entity_spec = str(inputs[self._input_entity_spec_key])
-        claim_description = inputs[self._input_claim_description_key]
-        resolved_entities = inputs.get(self._input_resolved_entities_key, {})
-        source_doc_map = {}
-
-        prompt_args = {
-            self._input_entity_spec_key: entity_spec,
-            self._input_claim_description_key: claim_description,
-            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
-            or DEFAULT_TUPLE_DELIMITER,
-            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
-            or DEFAULT_RECORD_DELIMITER,
-            self._completion_delimiter_key: prompt_variables.get(
-                self._completion_delimiter_key
-            )
-            or DEFAULT_COMPLETION_DELIMITER,
-        }
-
-        all_claims: list[dict] = []
-        for doc_index, text in enumerate(texts):
-            document_id = f"d{doc_index}"
-            try:
-                claims = self._process_document(prompt_args, text, doc_index)
-                all_claims += [
-                    self._clean_claim(c, document_id, resolved_entities) for c in claims
-                ]
-                source_doc_map[document_id] = text
-            except Exception as e:
-                logging.exception("error extracting claim")
-                self._on_error(
-                    e,
-                    traceback.format_exc(),
-                    {"doc_index": doc_index, "text": text},
-                )
-                continue
-
-        return ClaimExtractorResult(
-            output=all_claims,
-            source_docs=source_doc_map,
-        )
-
-    def _clean_claim(
-        self, claim: dict, document_id: str, resolved_entities: dict
-    ) -> dict:
-        # clean the parsed claims to remove any claims with status = False
-        obj = claim.get("object_id", claim.get("object"))
-        subject = claim.get("subject_id", claim.get("subject"))
-
-        # If subject or object in resolved entities, then replace with resolved entity
-        obj = resolved_entities.get(obj, obj)
-        subject = resolved_entities.get(subject, subject)
-        claim["object_id"] = obj
-        claim["subject_id"] = subject
-        claim["doc_id"] = document_id
-        return claim
-
-    def _process_document(
-        self, prompt_args: dict, doc, doc_index: int
-    ) -> list[dict]:
-        record_delimiter = prompt_args.get(
-            self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
-        )
-        completion_delimiter = prompt_args.get(
-            self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
-        )
-        variables = {
-            self._input_text_key: doc,
-            **prompt_args,
-        }
-        text = perform_variable_replacements(self._extraction_prompt, variables=variables)
-        gen_conf = {"temperature": 0.5}
-        results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
-        claims = results.strip().removesuffix(completion_delimiter)
-        history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]
-
-        # Repeat to ensure we maximize entity count
-        for i in range(self._max_gleanings):
-            text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
-            history.append({"role": "user", "content": text})
-            extension = self._chat("", history, gen_conf)
-            claims += record_delimiter + extension.strip().removesuffix(
-                completion_delimiter
-            )
-
-            # If this isn't the last loop, check to see if we should continue
-            if i >= self._max_gleanings - 1:
-                break
-
-            history.append({"role": "assistant", "content": extension})
-            history.append({"role": "user", "content": LOOP_PROMPT})
-            continuation = self._chat("", history, self._loop_args)
-            if continuation != "YES":
-                break
-
-        result = self._parse_claim_tuples(claims, prompt_args)
-        for r in result:
-            r["doc_id"] = f"{doc_index}"
-        return result
-
-    def _parse_claim_tuples(
-        self, claims: str, prompt_variables: dict
-    ) -> list[dict[str, Any]]:
-        """Parse claim tuples."""
-        record_delimiter = prompt_variables.get(
-            self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
-        )
-        completion_delimiter = prompt_variables.get(
-            self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
-        )
-        tuple_delimiter = prompt_variables.get(
-            self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
-        )
-
-        def pull_field(index: int, fields: list[str]) -> str | None:
-            return fields[index].strip() if len(fields) > index else None
-
-        result: list[dict[str, Any]] = []
-        claims_values = (
-            claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
-        )
-        for claim in claims_values:
-            claim = claim.strip().removeprefix("(").removesuffix(")")
-            claim = re.sub(r".*Output:", "", claim)
-
-            # Ignore the completion delimiter
-            if claim == completion_delimiter:
-                continue
-
-            claim_fields = claim.split(tuple_delimiter)
-            o = {
-                "subject_id": pull_field(0, claim_fields),
-                "object_id": pull_field(1, claim_fields),
-                "type": pull_field(2, claim_fields),
-                "status": pull_field(3, claim_fields),
-                "start_date": pull_field(4, claim_fields),
-                "end_date": pull_field(5, claim_fields),
-                "description": pull_field(6, claim_fields),
-                "source_text": pull_field(7, claim_fields),
-                "doc_id": pull_field(8, claim_fields),
-            }
-            if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]):
-                continue
-            result.append(o)
-        return result
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
-    args = parser.parse_args()
-
-    from api.db import LLMType
-    from api.db.services.llm_service import LLMBundle
-    from api import settings
-    from api.db.services.knowledgebase_service import KnowledgebaseService
-
-    kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
-
-    ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
-    docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
-    info = {
-        "input_text": docs,
-        "entity_specs": "organization, person",
-        "claim_description": ""
-    }
-    claim = ex(info)
-    logging.info(json.dumps(claim.output, ensure_ascii=False, indent=2))
@@ -1,71 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-"""
-Reference:
- - [graphrag](https://github.com/microsoft/graphrag)
-"""
-
-CLAIM_EXTRACTION_PROMPT = """
-################
--Target activity-
-################
-You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
-
-################
--Goal-
-################
-Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
-
-################
--Steps-
-################
-- 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
-- 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
-For each claim, extract the following information:
-- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
-- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
-- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
-- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
-- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
-- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
-- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
-
-- 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
-- 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
-- 5. If there's nothing satisfy the above requirements, just keep output empty.
-- 6. When finished, output {completion_delimiter}
-
-################
--Examples-
-################
-Example 1:
-Entity specification: organization
-Claim description: red flags associated with an entity
-Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
-Output:
-(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
-{completion_delimiter}
-
-###########################
-Example 2:
-Entity specification: Company A, Person C
-Claim description: red flags associated with an entity
-Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
-Output:
-(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
-{record_delimiter}
-(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
-{completion_delimiter}
-
-################
--Real Data-
-################
-Use the following input for your answer.
-Entity specification: {entity_specs}
-Claim description: {claim_description}
-Text: {input_text}
-Output:"""
-
-
-CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: "
-LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n"
```
```diff
@@ -17,9 +17,10 @@ from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
 from graphrag.general.extractor import Extractor
 from graphrag.general.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
-from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types
+from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
 from rag.utils import num_tokens_from_string
 from timeit import default_timer as timer
+import trio


 @dataclass
@@ -52,7 +53,7 @@ class CommunityReportsExtractor(Extractor):
         self._extraction_prompt = COMMUNITY_REPORT_PROMPT
         self._max_report_length = max_report_length or 1500

-    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
+    async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
         for node_degree in graph.degree:
             graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])

@@ -86,28 +87,25 @@ class CommunityReportsExtractor(Extractor):
             }
             text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
             gen_conf = {"temperature": 0.3}
-            try:
-                response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
+            async with chat_limiter:
+                response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
             token_count += num_tokens_from_string(text + response)
             response = re.sub(r"^[^\{]*", "", response)
             response = re.sub(r"[^\}]*$", "", response)
             response = re.sub(r"\{\{", "{", response)
             response = re.sub(r"\}\}", "}", response)
             logging.debug(response)
             response = json.loads(response)
             if not dict_has_keys_with_types(response, [
                 ("title", str),
                 ("summary", str),
                 ("findings", list),
                 ("rating", float),
                 ("rating_explanation", str),
             ]):
-                    continue
-                response["weight"] = weight
-                response["entities"] = ents
-            except Exception:
-                logging.exception("CommunityReportsExtractor got exception")
                 continue
+            response["weight"] = weight
+            response["entities"] = ents

             add_community_info2graph(graph, ents, response["title"])
             res_str.append(self._get_text_output(response))
@@ -14,16 +14,15 @@
 # limitations under the License.
 #
 import logging
-import os
 import re
 from collections import defaultdict, Counter
-from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from typing import Callable
+import trio

 from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
 from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
-    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
+    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter
 from rag.llm.chat_model import Base as CompletionLLM
 from rag.utils import truncate

@@ -91,54 +90,50 @@ class Extractor:
         )
         return dict(maybe_nodes), dict(maybe_edges)

-    def __call__(
+    async def __call__(
         self, chunks: list[tuple[str, str]],
         callback: Callable | None = None
     ):

-        results = []
-        max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))
-        with ThreadPoolExecutor(max_workers=max_workers) as exe:
-            threads = []
+        self.callback = callback
+        start_ts = trio.current_time()
+        out_results = []
+        async with trio.open_nursery() as nursery:
             for i, (cid, ck) in enumerate(chunks):
                 ck = truncate(ck, int(self._llm.max_length*0.8))
-                threads.append(
-                    exe.submit(self._process_single_content, (cid, ck)))
-
-        for i, _ in enumerate(threads):
-            n, r, tc = _.result()
-            if not isinstance(n, Exception):
-                results.append((n, r))
-                if callback:
-                    callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
-            elif callback:
-                callback(msg="Knowledge graph extraction error:{}".format(str(n)))
+                nursery.start_soon(self._process_single_content, (cid, ck), i, len(chunks), out_results)

         maybe_nodes = defaultdict(list)
         maybe_edges = defaultdict(list)
-        for m_nodes, m_edges in results:
+        sum_token_count = 0
+        for m_nodes, m_edges, token_count in out_results:
             for k, v in m_nodes.items():
                 maybe_nodes[k].extend(v)
             for k, v in m_edges.items():
                 maybe_edges[tuple(sorted(k))].extend(v)
-        logging.info("Inserting entities into storage...")
+            sum_token_count += token_count
+        now = trio.current_time()
+        if callback:
+            callback(msg = f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now-start_ts:.2f}s.")
+        start_ts = now
+        logging.info("Entities merging...")
         all_entities_data = []
-        with ThreadPoolExecutor(max_workers=max_workers) as exe:
-            threads = []
+        async with trio.open_nursery() as nursery:
             for en_nm, ents in maybe_nodes.items():
-                threads.append(
-                    exe.submit(self._merge_nodes, en_nm, ents))
-            for t in threads:
-                n = t.result()
-                if not isinstance(n, Exception):
-                    all_entities_data.append(n)
-                elif callback:
-                    callback(msg="Knowledge graph nodes merging error: {}".format(str(n)))
+                nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
+        now = trio.current_time()
+        if callback:
+            callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")

-        logging.info("Inserting relationships into storage...")
+        start_ts = now
+        logging.info("Relationships merging...")
         all_relationships_data = []
-        for (src, tgt), rels in maybe_edges.items():
-            all_relationships_data.append(self._merge_edges(src, tgt, rels))
+        async with trio.open_nursery() as nursery:
+            for (src, tgt), rels in maybe_edges.items():
+                nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
+        now = trio.current_time()
+        if callback:
+            callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")

         if not len(all_entities_data) and not len(all_relationships_data):
             logging.warning(
@@ -152,7 +147,7 @@ class Extractor:

         return all_entities_data, all_relationships_data

-    def _merge_nodes(self, entity_name: str, entities: list[dict]):
+    async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
         if not entities:
             return
         already_entity_types = []
@@ -176,26 +171,22 @@ class Extractor:
             sorted(set([dp["description"] for dp in entities] + already_description))
         )
         already_source_ids = flat_uniq_list(entities, "source_id")
-        try:
-            description = self._handle_entity_relation_summary(
-                entity_name, description
-            )
-            node_data = dict(
-                entity_type=entity_type,
-                description=description,
-                source_id=already_source_ids,
-            )
-            node_data["entity_name"] = entity_name
-            self._set_entity_(entity_name, node_data)
-            return node_data
-        except Exception as e:
-            return e
+        description = await self._handle_entity_relation_summary(entity_name, description)
+        node_data = dict(
+            entity_type=entity_type,
+            description=description,
+            source_id=already_source_ids,
+        )
+        node_data["entity_name"] = entity_name
+        self._set_entity_(entity_name, node_data)
+        all_relationships_data.append(node_data)

-    def _merge_edges(
+    async def _merge_edges(
         self,
         src_id: str,
         tgt_id: str,
-        edges_data: list[dict]
+        edges_data: list[dict],
+        all_relationships_data
     ):
         if not edges_data:
             return
@@ -226,7 +217,7 @@ class Extractor:
             "description": description,
             "entity_type": 'UNKNOWN'
         })
-        description = self._handle_entity_relation_summary(
+        description = await self._handle_entity_relation_summary(
             f"({src_id}, {tgt_id})", description
         )
         edge_data = dict(
@@ -238,10 +229,9 @@ class Extractor:
             source_id=source_id
         )
         self._set_relation_(src_id, tgt_id, edge_data)
+        all_relationships_data.append(edge_data)

-        return edge_data
-
-    def _handle_entity_relation_summary(
+    async def _handle_entity_relation_summary(
             self,
             entity_or_relation_name: str,
             description: str
@@ -256,5 +246,6 @@ class Extractor:
         )
         use_prompt = prompt_template.format(**context_base)
         logging.info(f"Trigger summary: {entity_or_relation_name}")
-        summary = self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})
+        async with chat_limiter:
+            summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8}))
         return summary
```
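`Extractor.__call__` above also changes how results come back: `ThreadPoolExecutor` futures returned values (or exceptions), while trio nurseries do not return task results, so the rewrite threads shared output lists (`out_results`, `all_entities_data`, `all_relationships_data`) into each task and reads them once the nursery block exits. A small sketch of that aggregation pattern, with stand-in names:

```python
# Sketch of the aggregation change in Extractor.__call__: each task appends
# into a shared list passed as an argument, and the caller reads the list
# after the nursery block, when all tasks are guaranteed finished.
# process_chunk is a stand-in for _process_single_content.
import trio

async def process_chunk(chunk: str, idx: int, total: int, out_results: list):
    await trio.sleep(0)  # stands in for the thread-offloaded LLM exchange
    out_results.append((idx, f"nodes+edges for {chunk!r}"))

async def extract(chunks: list[str]):
    start_ts = trio.current_time()  # trio's monotonic clock, used for the
    out_results = []                # progress messages in the diff
    async with trio.open_nursery() as nursery:
        for i, ck in enumerate(chunks):
            nursery.start_soon(process_chunk, ck, i, len(chunks), out_results)
    # All tasks have completed here, so out_results is complete.
    elapsed = trio.current_time() - start_ts
    return out_results, elapsed

print(trio.run(extract, ["c1", "c2"]))
```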
@ -5,15 +5,15 @@ Reference:
|
|||||||
- [graphrag](https://github.com/microsoft/graphrag)
|
- [graphrag](https://github.com/microsoft/graphrag)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
import trio
|
||||||
|
|
||||||
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
|
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
|
||||||
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
|
||||||
from rag.llm.chat_model import Base as CompletionLLM
|
from rag.llm.chat_model import Base as CompletionLLM
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from rag.utils import num_tokens_from_string
|
from rag.utils import num_tokens_from_string
|
||||||
@ -102,53 +102,47 @@ class GraphExtractor(Extractor):
|
|||||||
self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
|
self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _process_single_content(self,
|
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
|
||||||
chunk_key_dp: tuple[str, str]
|
|
||||||
):
|
|
||||||
token_count = 0
|
token_count = 0
|
||||||
|
|
||||||
chunk_key = chunk_key_dp[0]
|
chunk_key = chunk_key_dp[0]
|
||||||
content = chunk_key_dp[1]
|
content = chunk_key_dp[1]
|
||||||
variables = {
|
variables = {
|
||||||
**self._prompt_variables,
|
**self._prompt_variables,
|
||||||
self._input_text_key: content,
|
self._input_text_key: content,
|
||||||
}
|
}
|
||||||
try:
|
gen_conf = {"temperature": 0.3}
|
||||||
gen_conf = {"temperature": 0.3}
|
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
async with chat_limiter:
|
||||||
response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
|
response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||||
token_count += num_tokens_from_string(hint_prompt + response)
|
token_count += num_tokens_from_string(hint_prompt + response)
|
||||||
|
|
||||||
results = response or ""
|
|
||||||
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
|
|
||||||
|
|
||||||
# Repeat to ensure we maximize entity count
|
|
||||||
for i in range(self._max_gleanings):
|
|
||||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
|
||||||
history.append({"role": "user", "content": text})
|
|
||||||
response = self._chat("", history, gen_conf)
|
|
||||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
|
||||||
results += response or ""
|
|
||||||
|
|
||||||
# if this is the final glean, don't bother updating the continuation flag
|
|
||||||
if i >= self._max_gleanings - 1:
|
|
||||||
break
|
|
||||||
history.append({"role": "assistant", "content": response})
|
|
||||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
|
||||||
continuation = self._chat("", history, {"temperature": 0.8})
|
|
||||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
|
||||||
if continuation != "YES":
|
|
||||||
break
|
|
||||||
|
|
||||||
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
|
|
||||||
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
|
|
||||||
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
|
|
||||||
records = [r for r in records if r.strip()]
|
|
||||||
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
|
|
||||||
return maybe_nodes, maybe_edges, token_count
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception("error extracting graph")
|
|
||||||
return e, None, None
|
|
||||||
|
|
||||||
|
results = response or ""
|
||||||
|
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
|
||||||
|
|
||||||
|
# Repeat to ensure we maximize entity count
|
||||||
|
for i in range(self._max_gleanings):
|
||||||
|
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||||
|
history.append({"role": "user", "content": text})
|
||||||
|
async with chat_limiter:
|
||||||
|
response = await trio.to_thread.run_sync(lambda: self._chat("", history, gen_conf))
|
||||||
|
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||||
|
results += response or ""
|
||||||
|
|
||||||
|
# if this is the final glean, don't bother updating the continuation flag
|
||||||
|
if i >= self._max_gleanings - 1:
|
||||||
|
break
|
||||||
|
history.append({"role": "assistant", "content": response})
|
||||||
|
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||||
|
async with chat_limiter:
|
||||||
|
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history, {"temperature": 0.8}))
|
||||||
|
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||||
|
if continuation != "YES":
|
||||||
|
break
|
||||||
|
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
|
||||||
|
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
|
||||||
|
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
|
||||||
|
records = [r for r in records if r.strip()]
|
||||||
|
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
|
||||||
|
out_results.append((maybe_nodes, maybe_edges, token_count))
|
||||||
|
if self.callback:
|
||||||
|
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
|
||||||
|
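Note: the hunk above is the template for the rest of this commit: every synchronous `self._chat(...)` call moves onto a worker thread via `trio.to_thread.run_sync`, with `chat_limiter` bounding how many such calls are in flight at once. A minimal runnable sketch of just that pattern, with `blocking_chat` and `ask` as illustrative stand-ins for the real client code:

    import trio

    chat_limiter = trio.CapacityLimiter(100)   # cap concurrent in-flight chats

    def blocking_chat(prompt: str) -> str:
        # stand-in for a slow, synchronous LLM/HTTP call
        return "answer to: " + prompt

    async def ask(prompt: str) -> None:
        async with chat_limiter:               # hold a slot for the duration
            # the blocking call runs on a worker thread, so the trio
            # scheduler keeps servicing other tasks meanwhile
            answer = await trio.to_thread.run_sync(lambda: blocking_chat(prompt))
        print(answer)

    async def main():
        async with trio.open_nursery() as nursery:
            for p in ("a", "b", "c"):
                nursery.start_soon(ask, p)

    trio.run(main)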
@@ -17,6 +17,7 @@ import json
 import logging
 from functools import reduce, partial
 import networkx as nx
+import trio

 from api import settings
 from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@@ -41,18 +42,24 @@ class Dealer:
                  embed_bdl=None,
                  callback=None
                  ):
-        docids = list(set([docid for docid,_ in chunks]))
+        self.tenant_id = tenant_id
+        self.kb_id = kb_id
+        self.chunks = chunks
         self.llm_bdl = llm_bdl
         self.embed_bdl = embed_bdl
-        ext = extractor(self.llm_bdl, language=language,
+        self.ext = extractor(self.llm_bdl, language=language,
                         entity_types=entity_types,
                         get_entity=partial(get_entity, tenant_id, kb_id),
                         set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
                         get_relation=partial(get_relation, tenant_id, kb_id),
                         set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
         )
-        ents, rels = ext(chunks, callback)
         self.graph = nx.Graph()
+        self.callback = callback
+
+    async def __call__(self):
+        docids = list(set([docid for docid, _ in self.chunks]))
+        ents, rels = await self.ext(self.chunks, self.callback)
         for en in ents:
             self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
@@ -64,16 +71,16 @@ class Dealer:
                 #description=rel["description"]
             )

-        with RedisDistributedLock(kb_id, 60*60):
-            old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
+        with RedisDistributedLock(self.kb_id, 60*60):
+            old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
             if old_graph is not None:
                 logging.info("Merge with an exiting graph...................")
                 self.graph = reduce(graph_merge, [old_graph, self.graph])
-            update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
+            update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
             if old_doc_ids:
                 docids.extend(old_doc_ids)
                 docids = list(set(docids))
-            set_graph(tenant_id, kb_id, self.graph, docids)
+            set_graph(self.tenant_id, self.kb_id, self.graph, docids)


 class WithResolution(Dealer):
@@ -84,47 +91,50 @@ class WithResolution(Dealer):
                  embed_bdl=None,
                  callback=None
                  ):
+        self.tenant_id = tenant_id
+        self.kb_id = kb_id
         self.llm_bdl = llm_bdl
         self.embed_bdl = embed_bdl
+        self.callback = callback

-        with RedisDistributedLock(kb_id, 60*60):
-            self.graph, doc_ids = get_graph(tenant_id, kb_id)
+    async def __call__(self):
+        with RedisDistributedLock(self.kb_id, 60*60):
+            self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
             if not self.graph:
-                logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
-                if callback:
-                    callback(-1, msg="Faild to fetch the graph.")
+                logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
+                if self.callback:
+                    self.callback(-1, msg="Faild to fetch the graph.")
                 return

-            if callback:
-                callback(msg="Fetch the existing graph.")
+            if self.callback:
+                self.callback(msg="Fetch the existing graph.")
             er = EntityResolution(self.llm_bdl,
-                                  get_entity=partial(get_entity, tenant_id, kb_id),
-                                  set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
-                                  get_relation=partial(get_relation, tenant_id, kb_id),
-                                  set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
-            reso = er(self.graph)
+                                  get_entity=partial(get_entity, self.tenant_id, self.kb_id),
+                                  set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
+                                  get_relation=partial(get_relation, self.tenant_id, self.kb_id),
+                                  set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
+            reso = await er(self.graph)
             self.graph = reso.graph
             logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-            if callback:
-                callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-            update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
-            set_graph(tenant_id, kb_id, self.graph, doc_ids)
+            if self.callback:
+                self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
+            await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
+            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))

-            settings.docStoreConn.delete({
+            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
                 "knowledge_graph_kwd": "relation",
-                "kb_id": kb_id,
+                "kb_id": self.kb_id,
                 "from_entity_kwd": reso.removed_entities
-            }, search.index_name(tenant_id), kb_id)
-            settings.docStoreConn.delete({
+            }, search.index_name(self.tenant_id), self.kb_id))
+            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
                 "knowledge_graph_kwd": "relation",
-                "kb_id": kb_id,
+                "kb_id": self.kb_id,
                 "to_entity_kwd": reso.removed_entities
-            }, search.index_name(tenant_id), kb_id)
-            settings.docStoreConn.delete({
+            }, search.index_name(self.tenant_id), self.kb_id))
+            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
                 "knowledge_graph_kwd": "entity",
-                "kb_id": kb_id,
+                "kb_id": self.kb_id,
                 "entity_kwd": reso.removed_entities
-            }, search.index_name(tenant_id), kb_id)
+            }, search.index_name(self.tenant_id), self.kb_id))


 class WithCommunity(Dealer):
@@ -136,38 +146,41 @@ class WithCommunity(Dealer):
                  callback=None
                  ):

+        self.tenant_id = tenant_id
+        self.kb_id = kb_id
         self.community_structure = None
         self.community_reports = None
         self.llm_bdl = llm_bdl
         self.embed_bdl = embed_bdl
+        self.callback = callback

-        with RedisDistributedLock(kb_id, 60*60):
-            self.graph, doc_ids = get_graph(tenant_id, kb_id)
+    async def __call__(self):
+        with RedisDistributedLock(self.kb_id, 60*60):
+            self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
             if not self.graph:
-                logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
-                if callback:
-                    callback(-1, msg="Faild to fetch the graph.")
+                logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
+                if self.callback:
+                    self.callback(-1, msg="Faild to fetch the graph.")
                 return
-            if callback:
-                callback(msg="Fetch the existing graph.")
+            if self.callback:
+                self.callback(msg="Fetch the existing graph.")

             cr = CommunityReportsExtractor(self.llm_bdl,
-                                           get_entity=partial(get_entity, tenant_id, kb_id),
-                                           set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
-                                           get_relation=partial(get_relation, tenant_id, kb_id),
-                                           set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
-            cr = cr(self.graph, callback=callback)
+                                           get_entity=partial(get_entity, self.tenant_id, self.kb_id),
+                                           set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
+                                           get_relation=partial(get_relation, self.tenant_id, self.kb_id),
+                                           set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
+            cr = await cr(self.graph, callback=self.callback)
             self.community_structure = cr.structured_output
             self.community_reports = cr.output
-            set_graph(tenant_id, kb_id, self.graph, doc_ids)
+            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))

-            if callback:
-                callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
+            if self.callback:
+                self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))

-            settings.docStoreConn.delete({
+            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
                 "knowledge_graph_kwd": "community_report",
-                "kb_id": kb_id
-            }, search.index_name(tenant_id), kb_id)
+                "kb_id": self.kb_id
+            }, search.index_name(self.tenant_id), self.kb_id))

             for stru, rep in zip(self.community_structure, self.community_reports):
                 obj = {
@@ -183,7 +196,7 @@ class WithCommunity(Dealer):
                     "weight_flt": stru["weight"],
                     "entities_kwd": stru["entities"],
                     "important_kwd": stru["entities"],
-                    "kb_id": kb_id,
+                    "kb_id": self.kb_id,
                     "source_id": doc_ids,
                     "available_int": 0
                 }
@@ -193,5 +206,5 @@ class WithCommunity(Dealer):
                 # chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
                 #except Exception as e:
                 #    logging.exception(f"Fail to embed entity relation: {e}")
-                settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
+                await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id)))
@@ -16,16 +16,14 @@

 import logging
 import collections
-import os
 import re
-import traceback
 from typing import Any
-from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
+import trio

 from graphrag.general.extractor import Extractor
 from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
-from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
+from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
 from rag.llm.chat_model import Base as CompletionLLM
 import markdown_to_json
 from functools import reduce
@@ -80,63 +78,47 @@ class MindMapExtractor(Extractor):
             )
         return arr

-    def __call__(
+    async def __call__(
         self, sections: list[str], prompt_variables: dict[str, Any] | None = None
     ) -> MindMapResult:
         """Call method definition."""
         if prompt_variables is None:
             prompt_variables = {}

-        try:
-            res = []
-            max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
-            with ThreadPoolExecutor(max_workers=max_workers) as exe:
-                threads = []
-                token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
-                texts = []
-                cnt = 0
-                for i in range(len(sections)):
-                    section_cnt = num_tokens_from_string(sections[i])
-                    if cnt + section_cnt >= token_count and texts:
-                        threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
-                        texts = []
-                        cnt = 0
-                    texts.append(sections[i])
-                    cnt += section_cnt
-                if texts:
-                    threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
-
-                for i, _ in enumerate(threads):
-                    res.append(_.result())
-
-            if not res:
-                return MindMapResult(output={"id": "root", "children": []})
-
-            merge_json = reduce(self._merge, res)
-            if len(merge_json) > 1:
-                keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
-                keyset = set(i for i in keys if i)
-                merge_json = {
-                    "id": "root",
-                    "children": [
-                        {
-                            "id": self._key(k),
-                            "children": self._be_children(v, keyset)
-                        }
-                        for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
-                    ]
-                }
-            else:
-                k = self._key(list(merge_json.keys())[0])
-                merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
-
-        except Exception as e:
-            logging.exception("error mind graph")
-            self._on_error(
-                e,
-                traceback.format_exc(), None
-            )
-            merge_json = {"error": str(e)}
+        res = []
+        token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
+        texts = []
+        cnt = 0
+        async with trio.open_nursery() as nursery:
+            for i in range(len(sections)):
+                section_cnt = num_tokens_from_string(sections[i])
+                if cnt + section_cnt >= token_count and texts:
+                    nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
+                    texts = []
+                    cnt = 0
+                texts.append(sections[i])
+                cnt += section_cnt
+            if texts:
+                nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
+        if not res:
+            return MindMapResult(output={"id": "root", "children": []})
+        merge_json = reduce(self._merge, res)
+        if len(merge_json) > 1:
+            keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
+            keyset = set(i for i in keys if i)
+            merge_json = {
+                "id": "root",
+                "children": [
+                    {
+                        "id": self._key(k),
+                        "children": self._be_children(v, keyset)
+                    }
+                    for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
+                ]
+            }
+        else:
+            k = self._key(list(merge_json.keys())[0])
+            merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}

         return MindMapResult(output=merge_json)
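Note: trio tasks started with `nursery.start_soon` cannot return values, which is why `_process_document` now receives the shared `res` list instead of returning its result. A runnable sketch of that convention (all names illustrative):

    import trio

    async def worker(n: int, out_res: list) -> None:
        await trio.sleep(0)        # stand-in for real work
        out_res.append(n * n)      # deliver the result through the shared list

    async def main() -> list:
        res = []
        async with trio.open_nursery() as nursery:
            for n in range(4):
                nursery.start_soon(worker, n, res)
        # the nursery exits only after every worker finishes, so `res` is
        # complete here, though not necessarily in submission order
        return res

    print(trio.run(main))

Appending from several trio tasks needs no lock, since all tasks run on one thread.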
@@ -181,8 +163,8 @@ class MindMapExtractor(Extractor):

         return self._list_to_kv(to_ret)

-    def _process_document(
-        self, text: str, prompt_variables: dict[str, str]
+    async def _process_document(
+        self, text: str, prompt_variables: dict[str, str], out_res
     ) -> str:
         variables = {
             **prompt_variables,
@@ -190,8 +172,9 @@ class MindMapExtractor(Extractor):
         }
         text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
         gen_conf = {"temperature": 0.5}
-        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
+        async with chat_limiter:
+            response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
         response = re.sub(r"```[^\n]*", "", response)
         logging.debug(response)
         logging.debug(self._todict(markdown_to_json.dictify(response)))
-        return self._todict(markdown_to_json.dictify(response))
+        out_res.append(self._todict(markdown_to_json.dictify(response)))
@@ -18,6 +18,7 @@ import argparse
 import json

 import networkx as nx
+import trio

 from api import settings
 from api.db import LLMType
@@ -54,10 +55,13 @@ if __name__ == "__main__":
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

     dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
+    trio.run(dealer())
     print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))

     dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
+    trio.run(dealer())
     dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
+    trio.run(dealer())

     print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
     print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))
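Note on the three `trio.run(dealer())` calls above: per trio's documented signature, `trio.run(async_fn, *args)` takes the async callable itself and constructs the coroutine internally; handing it an already-created coroutine object raises a `TypeError`. A sketch of the invocation that matches that signature, assuming `Dealer.__call__` is the async entry point introduced in the earlier hunks:

    trio.run(dealer.__call__)
    # free async functions take their arguments positionally, e.g.:
    # trio.run(run_graphrag, task, chat_model, language, embedding_model)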
@@ -4,16 +4,16 @@
 Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
 """
-import logging
 import re
 from typing import Any, Callable
 from dataclasses import dataclass

 from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
 from graphrag.light.graph_prompt import PROMPTS
-from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers
+from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers, chat_limiter
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
 from rag.utils import num_tokens_from_string
+import trio


 @dataclass
@@ -82,7 +82,7 @@ class GraphExtractor(Extractor):
         )
         self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)

-    def _process_single_content(self, chunk_key_dp: tuple[str, str]):
+    async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
         token_count = 0
         chunk_key = chunk_key_dp[0]
         content = chunk_key_dp[1]
@@ -90,38 +90,39 @@ class GraphExtractor(Extractor):
             **self._context_base, input_text="{input_text}"
         ).format(**self._context_base, input_text=content)

-        try:
-            gen_conf = {"temperature": 0.8}
-            final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
-            token_count += num_tokens_from_string(hint_prompt + final_result)
-            history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
-            for now_glean_index in range(self._max_gleanings):
-                glean_result = self._chat(hint_prompt, history, gen_conf)
-                history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
-                token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
-                final_result += glean_result
-                if now_glean_index == self._max_gleanings - 1:
-                    break
-
-                if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf)
-                token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
-                if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
-                if if_loop_result != "yes":
-                    break
-
-            records = split_string_by_multi_markers(
-                final_result,
-                [self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
-            )
-            rcds = []
-            for record in records:
-                record = re.search(r"\((.*)\)", record)
-                if record is None:
-                    continue
-                rcds.append(record.group(1))
-            records = rcds
-            maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
-            return maybe_nodes, maybe_edges, token_count
-        except Exception as e:
-            logging.exception("error extracting graph")
-            return e, None, None
+        gen_conf = {"temperature": 0.8}
+        async with chat_limiter:
+            final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
+        token_count += num_tokens_from_string(hint_prompt + final_result)
+        history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
+        for now_glean_index in range(self._max_gleanings):
+            async with chat_limiter:
+                glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
+            history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
+            token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
+            final_result += glean_result
+            if now_glean_index == self._max_gleanings - 1:
+                break
+
+            async with chat_limiter:
+                if_loop_result = await trio.to_thread.run_sync(lambda: self._chat(self._if_loop_prompt, history, gen_conf))
+            token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
+            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
+            if if_loop_result != "yes":
+                break
+
+        records = split_string_by_multi_markers(
+            final_result,
+            [self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
+        )
+        rcds = []
+        for record in records:
+            record = re.search(r"\((.*)\)", record)
+            if record is None:
+                continue
+            rcds.append(record.group(1))
+        records = rcds
+        maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
+        out_results.append((maybe_nodes, maybe_edges, token_count))
+        if self.callback:
+            self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
@@ -15,6 +15,8 @@ from collections import defaultdict
 from copy import deepcopy
 from hashlib import md5
 from typing import Any, Callable
+import os
+import trio

 import networkx as nx
 import numpy as np
@@ -28,6 +30,7 @@ from rag.utils.redis_conn import REDIS_CONN

 ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

+chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 100)))

 def perform_variable_replacements(
     input: str, history: list[dict] | None = None, variables: dict | None = None
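Note: `chat_limiter` is the single global throttle shared by all of the wrapped `_chat` calls in this commit, sized by the new MAX_CONCURRENT_CHATS environment variable. Its live occupancy can be observed through trio's statistics API; a small runnable sketch (`probe` is illustrative):

    import os
    import trio

    chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 100)))

    async def probe():
        async with chat_limiter:
            st = chat_limiter.statistics()
            print(f"slots in use: {st.borrowed_tokens}/{st.total_tokens}, "
                  f"tasks waiting: {st.tasks_waiting}")

    trio.run(probe)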
|
@ -122,7 +122,8 @@ dependencies = [
|
|||||||
"pyodbc>=5.2.0,<6.0.0",
|
"pyodbc>=5.2.0,<6.0.0",
|
||||||
"pyicu>=2.13.1,<3.0.0",
|
"pyicu>=2.13.1,<3.0.0",
|
||||||
"flasgger>=0.9.7.1,<0.10.0",
|
"flasgger>=0.9.7.1,<0.10.0",
|
||||||
"xxhash>=3.5.0,<4.0.0"
|
"xxhash>=3.5.0,<4.0.0",
|
||||||
|
"trio>=0.29.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -133,4 +134,7 @@ full = [
     "flagembedding==1.2.10",
     "torch>=2.5.0,<3.0.0",
     "transformers>=4.35.0,<5.0.0"
 ]
+
+[[tool.uv.index]]
+url = "https://mirrors.aliyun.com/pypi/simple"
@@ -14,15 +14,14 @@
 # limitations under the License.
 #
 import logging
-import os
 import re
-from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
 from threading import Lock
 import umap
 import numpy as np
 from sklearn.mixture import GaussianMixture
+import trio

-from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
+from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
 from rag.utils import truncate
@@ -68,24 +67,25 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         optimal_clusters = n_clusters[np.argmin(bics)]
         return optimal_clusters

-    def __call__(self, chunks, random_state, callback=None):
+    async def __call__(self, chunks, random_state, callback=None):
         layers = [(0, len(chunks))]
         start, end = 0, len(chunks)
         if len(chunks) <= 1:
             return []
         chunks = [(s, a) for s, a in chunks if s and len(a) > 0]

-        def summarize(ck_idx, lock):
+        async def summarize(ck_idx, lock):
             nonlocal chunks
             try:
                 texts = [chunks[i][0] for i in ck_idx]
                 len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
                 cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
-                cnt = self._chat("You're a helpful assistant.",
-                                 [{"role": "user",
-                                   "content": self._prompt.format(cluster_content=cluster_content)}],
-                                 {"temperature": 0.3, "max_tokens": self._max_token}
-                                 )
+                async with chat_limiter:
+                    cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
+                                 [{"role": "user",
+                                   "content": self._prompt.format(cluster_content=cluster_content)}],
+                                 {"temperature": 0.3, "max_tokens": self._max_token}
+                                 ))
                 cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
                              cnt)
                 logging.debug(f"SUM: {cnt}")
@@ -97,10 +97,11 @@ ...
                 return e

         labels = []
+        lock = Lock()
         while end - start > 1:
             embeddings = [embd for _, embd in chunks[start: end]]
             if len(embeddings) == 2:
-                summarize([start, start + 1], Lock())
+                await summarize([start, start + 1], lock)
                 if callback:
                     callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
                 labels.extend([0, 0])
@@ -122,19 +123,14 @@ ...
             probs = gm.predict_proba(reduced_embeddings)
             lbls = [np.where(prob > self._threshold)[0] for prob in probs]
             lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
-            lock = Lock()
-            with ThreadPoolExecutor(max_workers=int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))) as executor:
-                threads = []
+            async with trio.open_nursery() as nursery:
                 for c in range(n_clusters):
                     ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                     if not ck_idx:
                         continue
-                    threads.append(executor.submit(summarize, ck_idx, lock))
-                wait(threads, return_when=ALL_COMPLETED)
-                for th in threads:
-                    if isinstance(th.result(), Exception):
-                        raise th.result()
-                logging.debug(str([t.result() for t in threads]))
+                    async with chat_limiter:
+                        nursery.start_soon(lambda: summarize(ck_idx, lock))

             assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
             labels.extend(lbls)
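Note: `nursery.start_soon(lambda: summarize(ck_idx, lock))` captures `ck_idx` by reference, and a lambda looks the name up only when the task actually runs, so a task that starts after the loop has advanced can observe a later cluster's indices. Passing the arguments positionally pins them at spawn time; a runnable sketch of the safe form (names illustrative):

    import trio

    async def work(idx: int) -> None:
        await trio.sleep(0)
        print("processing cluster", idx)

    async def main():
        async with trio.open_nursery() as nursery:
            for i in range(3):
                # hazardous: nursery.start_soon(lambda: work(i)) re-reads `i`
                # when the task runs; a default argument (lambda i=i: ...)
                # or positional passing freezes the current value:
                nursery.start_soon(work, i)

    trio.run(main)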
@@ -30,7 +30,6 @@ CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
 CONSUMER_NAME = "task_executor_" + CONSUMER_NO
 initRootLogger(CONSUMER_NAME)

-import asyncio
 import logging
 import os
 from datetime import datetime
@@ -38,14 +37,14 @@ import json
 import xxhash
 import copy
 import re
-import time
-import threading
 from functools import partial
 from io import BytesIO
 from multiprocessing.context import TimeoutError
 from timeit import default_timer as timer
 import tracemalloc
+import resource
 import signal
+import trio

 import numpy as np
 from peewee import DoesNotExist
@@ -64,8 +63,9 @@ from rag.nlp import search, rag_tokenizer
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
 from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
 from rag.utils import num_tokens_from_string
-from rag.utils.redis_conn import REDIS_CONN, Payload
+from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
+from graphrag.utils import chat_limiter

 BATCH_SIZE = 64
@@ -88,28 +88,28 @@ FACTORY = {
     ParserType.TAG.value: tag
 }

+UNACKED_ITERATOR = None
 CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
-PAYLOAD: Payload | None = None
 BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
 PENDING_TASKS = 0
 LAG_TASKS = 0

-mt_lock = threading.Lock()
 DONE_TASKS = 0
 FAILED_TASKS = 0
-CURRENT_TASK = None

-tracemalloc_started = False
+CURRENT_TASKS = {}

+MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
+MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
+task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
+chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)

 # SIGUSR1 handler: start tracemalloc and take snapshot
 def start_tracemalloc_and_snapshot(signum, frame):
-    global tracemalloc_started
-    if not tracemalloc_started:
-        logging.info("got SIGUSR1, start tracemalloc")
+    if not tracemalloc.is_tracing():
+        logging.info("start tracemalloc")
         tracemalloc.start()
-        tracemalloc_started = True
     else:
-        logging.info("got SIGUSR1, tracemalloc is already running")
+        logging.info("tracemalloc is already running")

     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     snapshot_file = f"snapshot_{timestamp}.trace"
@@ -117,17 +117,17 @@ def start_tracemalloc_and_snapshot(signum, frame):

     snapshot = tracemalloc.take_snapshot()
     snapshot.dump(snapshot_file)
-    logging.info(f"taken snapshot {snapshot_file}")
+    current, peak = tracemalloc.get_traced_memory()
+    max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")

 # SIGUSR2 handler: stop tracemalloc
 def stop_tracemalloc(signum, frame):
-    global tracemalloc_started
-    if tracemalloc_started:
-        logging.info("go SIGUSR2, stop tracemalloc")
+    if tracemalloc.is_tracing():
+        logging.info("stop tracemalloc")
         tracemalloc.stop()
-        tracemalloc_started = False
     else:
-        logging.info("got SIGUSR2, tracemalloc not running")
+        logging.info("tracemalloc not running")

 class TaskCanceledException(Exception):
     def __init__(self, msg):
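Note: once the SIGUSR1 handler above has dumped a `snapshot_<timestamp>.trace` file, it can be inspected offline with the standard library. A short sketch, with an illustrative filename:

    import tracemalloc

    snap = tracemalloc.Snapshot.load("snapshot_20250301_120000.trace")
    for stat in snap.statistics("lineno")[:10]:   # ten largest allocation sites
        print(stat)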
@@ -135,17 +135,9 @@ class TaskCanceledException(Exception):


 def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
-    global PAYLOAD
     if prog is not None and prog < 0:
         msg = "[ERROR]" + msg
-    try:
-        cancel = TaskService.do_cancel(task_id)
-    except DoesNotExist:
-        logging.warning(f"set_progress task {task_id} is unknown")
-        if PAYLOAD:
-            PAYLOAD.ack()
-            PAYLOAD = None
-        return
+    cancel = TaskService.do_cancel(task_id)

     if cancel:
         msg += " [Canceled]"
@@ -162,66 +154,55 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
         d["progress"] = prog

     logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
-    try:
-        TaskService.update_progress(task_id, d)
-    except DoesNotExist:
-        logging.warning(f"set_progress task {task_id} is unknown")
-        if PAYLOAD:
-            PAYLOAD.ack()
-            PAYLOAD = None
-        return
+    TaskService.update_progress(task_id, d)

     close_connection()
-    if cancel and PAYLOAD:
-        PAYLOAD.ack()
-        PAYLOAD = None
+    if cancel:
         raise TaskCanceledException(msg)


-def collect():
-    global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
+async def collect():
+    global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
+    global UNACKED_ITERATOR
     try:
-        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
-        if not PAYLOAD:
-            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
-        if not PAYLOAD:
-            time.sleep(1)
-            return None
+        if not UNACKED_ITERATOR:
+            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
+        try:
+            redis_msg = next(UNACKED_ITERATOR)
+        except StopIteration:
+            redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
+            if not redis_msg:
+                await trio.sleep(1)
+                return None, None
     except Exception:
-        logging.exception("Get task event from queue exception")
-        return None
+        logging.exception("collect got exception")
+        return None, None

-    msg = PAYLOAD.get_message()
+    msg = redis_msg.get_message()
     if not msg:
-        return None
+        logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
+        redis_msg.ack()
+        return None, None

-    task = None
     canceled = False
-    try:
-        task = TaskService.get_task(msg["id"])
-        if task:
-            _, doc = DocumentService.get_by_id(task["doc_id"])
-            canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
-    except DoesNotExist:
-        pass
-    except Exception:
-        logging.exception("collect get_task exception")
+    task = TaskService.get_task(msg["id"])
+    if task:
+        _, doc = DocumentService.get_by_id(task["doc_id"])
+        canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
     if not task or canceled:
         state = "is unknown" if not task else "has been cancelled"
-        with mt_lock:
-            DONE_TASKS += 1
-        logging.info(f"collect task {msg['id']} {state}")
+        FAILED_TASKS += 1
+        logging.warning(f"collect task {msg['id']} {state}")
+        redis_msg.ack()
         return None

     task["task_type"] = msg.get("task_type", "")
-    return task
+    return redis_msg, task


-def get_storage_binary(bucket, name):
-    return STORAGE_IMPL.get(bucket, name)
+async def get_storage_binary(bucket, name):
+    return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))


-def build_chunks(task, progress_callback):
+async def build_chunks(task, progress_callback):
     if task["size"] > DOC_MAXIMUM_SIZE:
         set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
                      (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
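Note: `collect()` now hands back a `(redis_msg, task)` pair instead of stashing a global `PAYLOAD`, which moves acknowledgement into the caller. A sketch of a consuming loop under that contract; `handle_task_loop` is illustrative, as the real dispatch loop sits outside this hunk:

    async def handle_task_loop():
        while True:
            redis_msg, task = await collect()
            if not task:
                continue            # queue empty or message unusable
            try:
                await do_handle_task(task)
            finally:
                redis_msg.ack()     # acknowledge only after handling completes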
@@ -231,7 +212,7 @@ def build_chunks(task, progress_callback):
     try:
         st = timer()
         bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
-        binary = get_storage_binary(bucket, name)
+        binary = await get_storage_binary(bucket, name)
         logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
     except TimeoutError:
         progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
@@ -247,9 +228,10 @@ ...
         raise

     try:
-        cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
-                            to_page=task["to_page"], lang=task["language"], callback=progress_callback,
-                            kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
+        async with chunk_limiter:
+            cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
+                                to_page=task["to_page"], lang=task["language"], callback=progress_callback,
+                                kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
         logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
     except TaskCanceledException:
         raise
@@ -286,7 +268,7 @@ ...
                 d["image"].save(output_buffer, format='JPEG')

                 st = timer()
-                STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
+                await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
                 el += timer() - st
             except Exception:
                 logging.exception(
@@ -306,14 +288,16 @@ ...
         async def doc_keyword_extraction(chat_mdl, d, topn):
             cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
             if not cached:
-                cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn)
+                async with chat_limiter:
+                    cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
                 set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
             if cached:
                 d["important_kwd"] = cached.split(",")
                 d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
             return
-        tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs]
-        asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
+        async with trio.open_nursery() as nursery:
+            for d in docs:
+                nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
         progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

         if task["parser_config"].get("auto_questions", 0):
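Note: `asyncio.gather` and a trio nursery differ on failure. With the old `gather(*tasks)`, the first exception propagates but sibling tasks keep running; a nursery cancels the remaining children and re-raises when the `async with` block exits, wrapped in an exception group on recent trio with Python 3.11+. A runnable sketch of the nursery behavior:

    import trio

    async def steady():
        await trio.sleep(1)
        print("finished")              # never printed: cancelled by the crash

    async def crash():
        await trio.sleep(0)
        raise RuntimeError("one worker failed")

    async def main():
        async with trio.open_nursery() as nursery:
            nursery.start_soon(steady)
            nursery.start_soon(crash)

    try:
        trio.run(main)
    except BaseExceptionGroup as eg:   # requires Python 3.11+
        print("nursery re-raised:", [repr(e) for e in eg.exceptions])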
@@ -324,13 +308,15 @@ ...
         async def doc_question_proposal(chat_mdl, d, topn):
             cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
             if not cached:
-                cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn)
+                async with chat_limiter:
+                    cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
                 set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
             if cached:
                 d["question_kwd"] = cached.split("\n")
                 d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
-        tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs]
-        asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
+        async with trio.open_nursery() as nursery:
+            for d in docs:
+                nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
         progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

         if task["kb_parser_config"].get("tag_kb_ids", []):
@@ -361,14 +347,16 @@ ...
             cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
             if not cached:
                 picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
-                cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)
+                async with chat_limiter:
+                    cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
                 if cached:
                     cached = json.dumps(cached)
             if cached:
                 set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
                 d[TAG_FLD] = json.loads(cached)
-        tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag]
-        asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
+        async with trio.open_nursery() as nursery:
+            for d in docs_to_tag:
+                nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
         progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

         return docs
@@ -379,7 +367,7 @@ def init_kb(row, vector_size: int):
     return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


-def embedding(docs, mdl, parser_config=None, callback=None):
+async def embedding(docs, mdl, parser_config=None, callback=None):
     if parser_config is None:
         parser_config = {}
     batch_size = 16
@@ -396,13 +384,13 @@ ...

     tk_count = 0
     if len(tts) == len(cnts):
-        vts, c = mdl.encode(tts[0: 1])
+        vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
         tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
         tk_count += c

     cnts_ = np.array([])
     for i in range(0, len(cnts), batch_size):
-        vts, c = mdl.encode(cnts[i: i + batch_size])
+        vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size]))
         if len(cnts_) == 0:
             cnts_ = vts
         else:
@@ -424,7 +412,7 @@ ...
     return tk_count, vector_size


-def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
+async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
@@ -440,7 +428,7 @@ ...
         row["parser_config"]["raptor"]["threshold"]
     )
     original_length = len(chunks)
-    chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
+    chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
     doc = {
         "doc_id": row["doc_id"],
        "kb_id": [str(row["kb_id"])],
@@ -465,13 +453,13 @@ ...
     return res, tk_count


-def run_graphrag(row, chat_model, language, embedding_model, callback=None):
+async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
     chunks = []
     for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
                                              fields=["content_with_weight", "doc_id"]):
         chunks.append((d["doc_id"], d["content_with_weight"]))

-    Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
+    dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
            row["tenant_id"],
            str(row["kb_id"]),
            chat_model,
@@ -480,9 +468,10 @@ ...
            entity_types=row["parser_config"]["graphrag"]["entity_types"],
            embed_bdl=embedding_model,
            callback=callback)
+    await dealer()


-def do_handle_task(task):
+async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
     task_to_page = task["to_page"]
@@ -494,6 +483,7 @@ def do_handle_task(task):
     task_doc_id = task["doc_id"]
     task_document_name = task["name"]
     task_parser_config = task["parser_config"]
+    task_start_ts = timer()

     # prepare the progress callback function
     progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@ -505,11 +495,7 @@ def do_handle_task(task):
|
|||||||
progress_callback(-1, msg=error_message)
|
progress_callback(-1, msg=error_message)
|
||||||
raise Exception(error_message)
|
raise Exception(error_message)
|
||||||
|
|
||||||
try:
|
task_canceled = TaskService.do_cancel(task_id)
|
||||||
task_canceled = TaskService.do_cancel(task_id)
|
|
||||||
except DoesNotExist:
|
|
||||||
logging.warning(f"task {task_id} is unknown")
|
|
||||||
return
|
|
||||||
if task_canceled:
|
if task_canceled:
|
||||||
progress_callback(-1, msg="Task has been canceled.")
|
progress_callback(-1, msg="Task has been canceled.")
|
||||||
return
|
return
|
||||||
@ -529,71 +515,41 @@ def do_handle_task(task):

     # Either using RAPTOR or Standard chunking methods
     if task.get("task_type", "") == "raptor":
-        try:
-            # bind LLM for raptor
-            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-            # run RAPTOR
-            chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
-        except TaskCanceledException:
-            raise
-        except Exception as e:
-            error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
-            progress_callback(-1, msg=error_message)
-            logging.exception(error_message)
-            raise
+        # bind LLM for raptor
+        chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+        # run RAPTOR
+        chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
     # Either using graphrag or Standard chunking methods
     elif task.get("task_type", "") == "graphrag":
         start_ts = timer()
-        try:
-            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-            run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
-            progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
-        except TaskCanceledException:
-            raise
-        except Exception as e:
-            error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
-            progress_callback(-1, msg=error_message)
-            logging.exception(error_message)
-            raise
+        chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+        await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
+        progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
         return
     elif task.get("task_type", "") == "graph_resolution":
         start_ts = timer()
-        try:
-            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-            WithResolution(
-                task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
-                progress_callback
-            )
-            progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
-        except TaskCanceledException:
-            raise
-        except Exception as e:
-            error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
-            progress_callback(-1, msg=error_message)
-            logging.exception(error_message)
-            raise
+        chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+        with_res = WithResolution(
+            task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
+            progress_callback
+        )
+        await with_res()
+        progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
         return
     elif task.get("task_type", "") == "graph_community":
         start_ts = timer()
-        try:
-            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-            WithCommunity(
-                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-                progress_callback
-            )
-            progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
-        except TaskCanceledException:
-            raise
-        except Exception as e:
-            error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
-            progress_callback(-1, msg=error_message)
-            logging.exception(error_message)
-            raise
+        chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+        with_comm = WithCommunity(
+            task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
+            progress_callback
+        )
+        await with_comm()
+        progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
         return
     else:
         # Standard chunking methods
         start_ts = timer()
-        chunks = build_chunks(task, progress_callback)
+        chunks = await build_chunks(task, progress_callback)
         logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
         if chunks is None:
             return
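
The pattern running through these hunks: each pipeline stage (raptor, dealer, with_res, with_comm) is constructed and then awaited as an async callable, and the per-branch try/except scaffolding is gone because errors now propagate to the single handler in the rewritten handle_task further down. A minimal, self-contained sketch of that calling convention; AsyncStage and its body are illustrative stand-ins, not the repo's classes:

    import trio

    class AsyncStage:
        """Illustrative stand-in for Dealer / WithResolution / WithCommunity."""
        def __init__(self, name):
            self.name = name

        async def __call__(self):
            # a real stage would await LLM and doc-store calls here;
            # sleep(0) just yields control back to the trio scheduler
            await trio.sleep(0)
            return f"{self.name} finished"

    async def run_stage():
        stage = AsyncStage("graph_resolution")
        # construct first, then await the instance, mirroring `await with_res()`
        return await stage()

    print(trio.run(run_stage))
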
@ -605,7 +561,7 @@ def do_handle_task(task):
     progress_callback(msg="Generate {} chunks".format(len(chunks)))
     start_ts = timer()
     try:
-        token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
+        token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
     except Exception as e:
         error_message = "Generate embedding error:{}".format(str(e))
         progress_callback(-1, error_message)
@ -621,8 +577,7 @@ def do_handle_task(task):
     doc_store_result = ""
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
-                                                        task_dataset_id)
+        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id))
         if b % 128 == 0:
             progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
     if doc_store_result:
@ -635,8 +590,7 @@ def do_handle_task(task):
             TaskService.update_chunk_ids(task["id"], chunk_ids_str)
         except DoesNotExist:
             logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
-            doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
-                                                            task_dataset_id)
+            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
             return
     logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
                                                                                      task_to_page, len(chunks),
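
The docStoreConn.insert and docStoreConn.delete calls above are synchronous client operations; wrapping each one in trio.to_thread.run_sync pushes the blocking I/O onto a worker thread so the event loop can keep scheduling other parsing tasks. A sketch of the pattern under that assumption, with a hypothetical slow_insert standing in for the real store:

    import time
    import trio

    def slow_insert(batch):
        """Hypothetical stand-in for a blocking doc-store bulk insert."""
        time.sleep(0.05)
        return ""  # empty string signals success, as in the code above

    async def index_all(chunks, bulk_size=4):
        for b in range(0, len(chunks), bulk_size):
            batch = chunks[b:b + bulk_size]
            # the lambda defers the call; run_sync executes it on a worker thread
            result = await trio.to_thread.run_sync(lambda: slow_insert(batch))
            if result:
                raise RuntimeError(result)

    trio.run(index_all, list(range(12)))

trio.to_thread.run_sync also accepts positional arguments directly; the lambda here (and in the hunk above) is just one way to bind extra arguments to the call.
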
@ -645,51 +599,39 @@ def do_handle_task(task):
     DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)

     time_cost = timer() - start_ts
-    progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
+    task_time_cost = timer() - task_start_ts
+    progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
     logging.info(
         "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
                                                                                    task_to_page, len(chunks),
-                                                                                   token_count, time_cost))
+                                                                                   token_count, task_time_cost))


-def handle_task():
-    global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
-    task = collect()
-    if task:
-        try:
-            logging.info(f"handle_task begin for task {json.dumps(task)}")
-            with mt_lock:
-                CURRENT_TASK = copy.deepcopy(task)
-            do_handle_task(task)
-            with mt_lock:
-                DONE_TASKS += 1
-                CURRENT_TASK = None
-            logging.info(f"handle_task done for task {json.dumps(task)}")
-        except TaskCanceledException:
-            with mt_lock:
-                DONE_TASKS += 1
-                CURRENT_TASK = None
-            try:
-                set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
-            except Exception:
-                pass
-            logging.debug("handle_task got TaskCanceledException", exc_info=True)
-        except Exception as e:
-            with mt_lock:
-                FAILED_TASKS += 1
-                CURRENT_TASK = None
-            try:
-                set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
-            except Exception:
-                pass
-            logging.exception(f"handle_task got exception for task {json.dumps(task)}")
-        if PAYLOAD:
-            PAYLOAD.ack()
-            PAYLOAD = None
+async def handle_task():
+    global DONE_TASKS, FAILED_TASKS
+    redis_msg, task = await collect()
+    if not task:
+        return
+    try:
+        logging.info(f"handle_task begin for task {json.dumps(task)}")
+        CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
+        await do_handle_task(task)
+        DONE_TASKS += 1
+        CURRENT_TASKS.pop(task["id"], None)
+        logging.info(f"handle_task done for task {json.dumps(task)}")
+    except Exception as e:
+        FAILED_TASKS += 1
+        CURRENT_TASKS.pop(task["id"], None)
+        try:
+            set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
+        except Exception:
+            pass
+        logging.exception(f"handle_task got exception for task {json.dumps(task)}")
+    redis_msg.ack()


-def report_status():
-    global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
+async def report_status():
+    global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
     REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
     while True:
         try:
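
Where the old loop acked through the global PAYLOAD, the new handle_task acknowledges the stream entry only after do_handle_task has either finished or had its failure recorded, which yields at-least-once delivery: a worker that dies mid-task leaves the entry pending for redelivery. A self-contained sketch of that ordering; FakeMsg is a stand-in for the RedisMsg class that appears later in this diff:

    import trio

    class FakeMsg:
        """Stand-in for RedisMsg; just records whether ack() was called."""
        def __init__(self):
            self.acked = False
        def ack(self):
            self.acked = True

    async def do_task(task):
        await trio.sleep(0)
        if task.get("boom"):
            raise RuntimeError("parse failed")

    async def handle_one(msg, task, stats):
        try:
            await do_task(task)
            stats["done"] += 1
        except Exception:
            stats["failed"] += 1
        msg.ack()  # ack last: an unacked entry stays pending and can be reclaimed

    async def demo():
        stats = {"done": 0, "failed": 0}
        for task in ({"id": "a"}, {"id": "b", "boom": True}):
            msg = FakeMsg()
            await handle_one(msg, task, stats)
            assert msg.acked
        print(stats)  # {'done': 1, 'failed': 1}

    trio.run(demo)
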
@ -699,17 +641,17 @@ def report_status():
             PENDING_TASKS = int(group_info.get("pending", 0))
             LAG_TASKS = int(group_info.get("lag", 0))

-            with mt_lock:
+            current = copy.deepcopy(CURRENT_TASKS)
             heartbeat = json.dumps({
                 "name": CONSUMER_NAME,
                 "now": now.astimezone().isoformat(timespec="milliseconds"),
                 "boot_at": BOOT_AT,
                 "pending": PENDING_TASKS,
                 "lag": LAG_TASKS,
                 "done": DONE_TASKS,
                 "failed": FAILED_TASKS,
-                "current": CURRENT_TASK,
+                "current": current,
             })
             REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
             logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")

@ -718,27 +660,10 @@ def report_status():
                 REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
         except Exception:
             logging.exception("report_status got exception")
-        time.sleep(30)
+        await trio.sleep(30)


-def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
-    msg = ""
-    if dump_full:
-        stats2 = snapshot2.statistics('lineno')
-        msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
-        for stat in stats2[:10]:
-            msg += f"{stat}\n"
-    stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
-    msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
-    for stat in stats1_vs_2[:10]:
-        msg += f"{stat}\n"
-    msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
-    for stat in stats1_vs_2[:3]:
-        msg += '\n'.join(stat.traceback.format())
-    logging.info(msg)
-
-
-def main():
+async def main():
     logging.info(r"""
     ______ __ ______ __
     /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
@ -755,33 +680,12 @@ def main():
     if TRACE_MALLOC_ENABLED:
         start_tracemalloc_and_snapshot(None, None)

-    # Create an event to signal the background thread to exit
-    stop_event = threading.Event()
-
-    background_thread = threading.Thread(target=report_status)
-    background_thread.daemon = True
-    background_thread.start()
-
-    # Handle SIGINT (Ctrl+C)
-    def signal_handler(sig, frame):
-        logging.info("Received Ctrl+C, shutting down gracefully...")
-        stop_event.set()
-        # Give the background thread time to clean up
-        if background_thread.is_alive():
-            background_thread.join(timeout=5)
-        logging.info("Exiting...")
-        sys.exit(0)
-
-    signal.signal(signal.SIGINT, signal_handler)
-
-    try:
-        while not stop_event.is_set():
-            handle_task()
-    except KeyboardInterrupt:
-        logging.info("Interrupted by keyboard, shutting down...")
-        stop_event.set()
-        if background_thread.is_alive():
-            background_thread.join(timeout=5)
+    async with trio.open_nursery() as nursery:
+        nursery.start_soon(report_status)
+        while True:
+            async with task_limiter:
+                nursery.start_soon(handle_task)
+    logging.error("BUG!!! You should not reach here!!!")


 if __name__ == "__main__":
-    main()
+    trio.run(main)
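
The new main() runs report_status and a stream of handle_task coroutines in one nursery, gated by task_limiter, which is defined elsewhere in the file and not shown in this diff. Assuming it behaves like a trio.CapacityLimiter, the concurrency cap works roughly as in this sketch, where each worker holds a limiter slot for its whole lifetime:

    import trio

    task_limiter = trio.CapacityLimiter(4)  # allow at most 4 tasks in flight

    async def one_task(n):
        async with task_limiter:   # hold a slot for the task's whole lifetime
            await trio.sleep(0.01) # stands in for parsing/embedding work

    async def run_all():
        async with trio.open_nursery() as nursery:
            for n in range(16):
                nursery.start_soon(one_task, n)
        # the nursery only exits once every spawned task has finished

    trio.run(run_all)
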
@ -24,7 +24,7 @@ from rag import settings
 from rag.utils import singleton


-class Payload:
+class RedisMsg:
     def __init__(self, consumer, queue_name, group_name, msg_id, message):
         self.__consumer = consumer
         self.__queue_name = queue_name
@ -43,6 +43,9 @@ class Payload:
     def get_message(self):
         return self.__message

+    def get_msg_id(self):
+        return self.__msg_id
+

 @singleton
 class RedisDB:
@ -206,9 +209,8 @@ class RedisDB:
             )
             return False

-    def queue_consumer(
-        self, queue_name, group_name, consumer_name, msg_id=b">"
-    ) -> Payload:
+    def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
+        """https://redis.io/docs/latest/commands/xreadgroup/"""
         try:
             group_info = self.REDIS.xinfo_groups(queue_name)
             if not any(e["name"] == group_name for e in group_info):
@ -217,15 +219,17 @@ class RedisDB:
                 "groupname": group_name,
                 "consumername": consumer_name,
                 "count": 1,
-                "block": 10000,
+                "block": 5,
                 "streams": {queue_name: msg_id},
             }
             messages = self.REDIS.xreadgroup(**args)
             if not messages:
                 return None
             stream, element_list = messages[0]
+            if not element_list:
+                return None
             msg_id, payload = element_list[0]
-            res = Payload(self.REDIS, queue_name, group_name, msg_id, payload)
+            res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
             return res
         except Exception as e:
             if "key" in str(e):
@ -239,30 +243,24 @@ class RedisDB:
             )
             return None

-    def get_unacked_for(self, consumer_name, queue_name, group_name):
+    def get_unacked_iterator(self, queue_name, group_name, consumer_name):
         try:
             group_info = self.REDIS.xinfo_groups(queue_name)
             if not any(e["name"] == group_name for e in group_info):
                 return
-            pendings = self.REDIS.xpending_range(
-                queue_name,
-                group_name,
-                min=0,
-                max=10000000000000,
-                count=1,
-                consumername=consumer_name,
-            )
-            if not pendings:
-                return
-            msg_id = pendings[0]["message_id"]
-            msg = self.REDIS.xrange(queue_name, min=msg_id, count=1)
-            _, payload = msg[0]
-            return Payload(self.REDIS, queue_name, group_name, msg_id, payload)
+            current_min = 0
+            while True:
+                payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
+                if not payload:
+                    return
+                current_min = payload.get_msg_id()
+                logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
+                yield payload
         except Exception as e:
             if "key" in str(e):
                 return
             logging.exception(
-                "RedisDB.get_unacked_for " + consumer_name + " got exception: " + str(e)
+                "RedisDB.get_unacked_iterator " + consumer_name + " got exception: "
             )
             self.__open__()
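
get_unacked_iterator leans on XREADGROUP semantics: passing a concrete entry id instead of b">" replays this consumer's own pending (delivered but unacknowledged) entries after that id, so looping while advancing current_min walks the whole pending list. A hedged redis-py sketch of the same idea; the connection and the stream/group/consumer names are placeholders:

    import redis

    r = redis.Redis()  # placeholder connection; adjust host/port as needed

    def iter_unacked(queue_name, group_name, consumer_name):
        last_id = 0  # 0 replays this consumer's pending entries from the start
        while True:
            resp = r.xreadgroup(group_name, consumer_name,
                                {queue_name: last_id}, count=1)
            if not resp or not resp[0][1]:
                return  # no more pending entries for this consumer
            msg_id, payload = resp[0][1][0]
            last_id = msg_id  # advance past the entry just delivered
            yield msg_id, payload

    # usage sketch:
    # for msg_id, payload in iter_unacked("task_queue", "broker", "consumer_0"):
    #     ...
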
uv.lock (generated, 52 changes)
@ -1,4 +1,5 @@
 version = 1
+revision = 1
 requires-python = ">=3.10, <3.13"
 resolution-markers = [
     "python_full_version >= '3.12' and sys_platform == 'darwin'",
@ -1083,9 +1084,6 @@ name = "datrie"
 version = "0.8.2"
 source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
 sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" }
-wheels = [
-    { url = "https://mirrors.aliyun.com/pypi/packages/44/02/53f0cf0bf0cd629ba6c2cc13f2f9db24323459e9c19463783d890a540a96/datrie-0.8.2-pp273-pypy_73-win32.whl", hash = "sha256:b07bd5fdfc3399a6dab86d6e35c72b1dbd598e80c97509c7c7518ab8774d3fda" },
-]

 [[package]]
 name = "decorator"
@ -1362,17 +1360,17 @@ name = "fastembed-gpu"
 version = "0.3.6"
 source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
 dependencies = [
-    { name = "huggingface-hub" },
-    { name = "loguru" },
-    { name = "mmh3" },
-    { name = "numpy" },
-    { name = "onnxruntime-gpu" },
-    { name = "pillow" },
-    { name = "pystemmer" },
-    { name = "requests" },
-    { name = "snowballstemmer" },
-    { name = "tokenizers" },
-    { name = "tqdm" },
+    { name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "loguru", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "mmh3", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "onnxruntime-gpu", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "pystemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "snowballstemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "tokenizers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "tqdm", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
 ]
 sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" }
 wheels = [
@ -3485,12 +3483,12 @@ name = "onnxruntime-gpu"
 version = "1.19.2"
 source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
 dependencies = [
-    { name = "coloredlogs" },
-    { name = "flatbuffers" },
-    { name = "numpy" },
-    { name = "packaging" },
-    { name = "protobuf" },
-    { name = "sympy" },
+    { name = "coloredlogs", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
 ]
 wheels = [
     { url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" },
@ -4164,15 +4162,6 @@ wheels = [
     { url = "https://mirrors.aliyun.com/pypi/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd" },
 ]

-[[package]]
-name = "pybind11"
-version = "2.13.6"
-source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
-sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d2/c1/72b9622fcb32ff98b054f724e213c7f70d6898baa714f4516288456ceaba/pybind11-2.13.6.tar.gz", hash = "sha256:ba6af10348c12b24e92fa086b39cfba0eff619b61ac77c406167d813b096d39a" }
-wheels = [
-    { url = "https://mirrors.aliyun.com/pypi/packages/13/2f/0f24b288e2ce56f51c920137620b4434a38fd80583dbbe24fc2a1656c388/pybind11-2.13.6-py3-none-any.whl", hash = "sha256:237c41e29157b962835d356b370ededd57594a26d5894a795960f0047cb5caf5" },
-]
-
 [[package]]
 name = "pyclipper"
 version = "1.3.0.post5"
@ -4230,8 +4219,6 @@ wheels = [
     { url = "https://mirrors.aliyun.com/pypi/packages/48/7d/0f2b09490b98cc6a902ac15dda8760c568b9c18cfe70e0ef7a16de64d53a/pycryptodomex-3.20.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7a7a8f33a1f1fb762ede6cc9cbab8f2a9ba13b196bfaf7bc6f0b39d2ba315a43" },
     { url = "https://mirrors.aliyun.com/pypi/packages/b0/1c/375adb14b71ee1c8d8232904e928b3e7af5bbbca7c04e4bec94fe8e90c3d/pycryptodomex-3.20.0-cp35-abi3-win32.whl", hash = "sha256:c39778fd0548d78917b61f03c1fa8bfda6cfcf98c767decf360945fe6f97461e" },
     { url = "https://mirrors.aliyun.com/pypi/packages/b2/e8/1b92184ab7e5595bf38000587e6f8cf9556ebd1bf0a583619bee2057afbd/pycryptodomex-3.20.0-cp35-abi3-win_amd64.whl", hash = "sha256:2a47bcc478741b71273b917232f521fd5704ab4b25d301669879e7273d3586cc" },
-    { url = "https://mirrors.aliyun.com/pypi/packages/e7/c5/9140bb867141d948c8e242013ec8a8011172233c898dfdba0a2417c3169a/pycryptodomex-3.20.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:1be97461c439a6af4fe1cf8bf6ca5936d3db252737d2f379cc6b2e394e12a458" },
-    { url = "https://mirrors.aliyun.com/pypi/packages/5e/6a/04acb4978ce08ab16890c70611ebc6efd251681341617bbb9e53356dee70/pycryptodomex-3.20.0-pp27-pypy_73-win32.whl", hash = "sha256:19764605feea0df966445d46533729b645033f134baeb3ea26ad518c9fdf212c" },
     { url = "https://mirrors.aliyun.com/pypi/packages/eb/df/3f1ea084e43b91e6d2b6b3493cc948864c17ea5d93ff1261a03812fbfd1a/pycryptodomex-3.20.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e497413560e03421484189a6b65e33fe800d3bd75590e6d78d4dfdb7accf3b" },
     { url = "https://mirrors.aliyun.com/pypi/packages/c9/f3/83ffbdfa0c8f9154bcd8866895f6cae5a3ec749da8b0840603cf936c4412/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48217c7901edd95f9f097feaa0388da215ed14ce2ece803d3f300b4e694abea" },
     { url = "https://mirrors.aliyun.com/pypi/packages/c9/9d/c113e640aaf02af5631ae2686b742aac5cd0e1402b9d6512b1c7ec5ef05d/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d00fe8596e1cc46b44bf3907354e9377aa030ec4cd04afbbf6e899fc1e2a7781" },
@ -4820,6 +4807,7 @@ dependencies = [
     { name = "tencentcloud-sdk-python" },
     { name = "tika" },
     { name = "tiktoken" },
+    { name = "trio" },
     { name = "umap-learn" },
     { name = "valkey" },
     { name = "vertexai" },
@ -4954,6 +4942,7 @@ requires-dist = [
     { name = "tiktoken", specifier = "==0.7.0" },
     { name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" },
     { name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" },
+    { name = "trio", specifier = ">=0.29.0" },
     { name = "umap-learn", specifier = "==0.5.6" },
     { name = "valkey", specifier = "==6.0.2" },
     { name = "vertexai", specifier = "==1.64.0" },
@ -4969,6 +4958,7 @@ requires-dist = [
     { name = "yfinance", specifier = "==0.1.96" },
     { name = "zhipuai", specifier = "==2.0.1" },
 ]
+provides-extras = ["full"]

 [[package]]
 name = "ranx"