Made task_executor async to speedup parsing (#5530)

### What problem does this PR solve?

Made task_executor async to speed up parsing.

### Type of change

- [x] Performance Improvement
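
The diffs below repeat one pattern: extractor entry points become `async def`, `ThreadPoolExecutor` fan-out is replaced by trio nurseries, and each blocking LLM call is pushed onto a worker thread while holding the shared `chat_limiter` from `graphrag.utils`. A minimal sketch of that pattern follows (the limiter value and the `blocking_chat` helper are illustrative stand-ins, not code from this repository):

# Minimal sketch (not code from this PR) of the concurrency pattern the commit applies.
import trio

chat_limiter = trio.CapacityLimiter(10)   # graphrag.utils exposes a shared limiter of this kind

def blocking_chat(prompt):
    # stands in for the synchronous chat-model client used by the extractors
    return "answer: " + prompt

async def ask(prompt, out):
    async with chat_limiter:              # bound how many LLM calls run at once
        out.append(await trio.to_thread.run_sync(lambda: blocking_chat(prompt)))

async def ask_all(prompts):
    out = []
    async with trio.open_nursery() as nursery:   # fan out, one task per prompt
        for p in prompts:
            nursery.start_soon(ask, p, out)
    return out                            # the nursery exits only after every task has finished

if __name__ == "__main__":
    print(trio.run(ask_all, ["q1", "q2", "q3"]))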
Zhichang Yu 2025-03-03 18:59:49 +08:00 committed by GitHub
parent abac2ca2c5
commit c813c1ff4c
22 changed files with 576 additions and 1005 deletions

View File

@@ -17,6 +17,7 @@ import json
 import re
 import traceback
 from copy import deepcopy
+import trio
 from api.db.db_models import APIToken
 from api.db.services.conversation_service import ConversationService, structure_answer
@@ -386,7 +387,8 @@ def mindmap():
         rank_feature=label_question(question, [kb])
     )
     mindmap = MindMapExtractor(chat_mdl)
-    mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
+    mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
+    mind_map = mind_map.output
     if "error" in mind_map:
         return server_error_response(Exception(mind_map["error"]))
     return get_json_result(data=mind_map)

View File

@@ -22,6 +22,7 @@ from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from datetime import datetime
 from io import BytesIO
+import trio
 from peewee import fn
@@ -597,8 +598,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
         if parser_ids[doc_id] != ParserType.PICTURE.value:
             mindmap = MindMapExtractor(llm_bdl)
             try:
-                mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
-                                      ensure_ascii=False, indent=2)
+                mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
+                mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
                 if len(mind_map) < 32:
                     raise Exception("Few content: " + mind_map)
                 cks.append({

View File

@@ -17,6 +17,8 @@ import base64
 import json
 import os
 import re
+import sys
+import threading
 from io import BytesIO
 import pdfplumber
@@ -30,6 +32,10 @@ from api.constants import IMG_BASE64_PREFIX
 PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
 RAG_BASE = os.getenv("RAG_BASE")
+LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+if LOCK_KEY_pdfplumber not in sys.modules:
+    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
 def get_project_base_directory(*args):
     global PROJECT_BASE
@@ -175,19 +181,20 @@ def thumbnail_img(filename, blob):
     """
     filename = filename.lower()
     if re.match(r".*\.pdf$", filename):
-        pdf = pdfplumber.open(BytesIO(blob))
-        buffered = BytesIO()
-        resolution = 32
-        img = None
-        for _ in range(10):
-            # https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
-            pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
-            img = buffered.getvalue()
-            if len(img) >= 64000 and resolution >= 2:
-                resolution = resolution / 2
-                buffered = BytesIO()
-            else:
-                break
+        with sys.modules[LOCK_KEY_pdfplumber]:
+            pdf = pdfplumber.open(BytesIO(blob))
+            buffered = BytesIO()
+            resolution = 32
+            img = None
+            for _ in range(10):
+                # https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
+                pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
+                img = buffered.getvalue()
+                if len(img) >= 64000 and resolution >= 2:
+                    resolution = resolution / 2
+                    buffered = BytesIO()
+                else:
+                    break
         pdf.close()
         return img

View File

@@ -18,6 +18,8 @@ import os.path
 import logging
 from logging.handlers import RotatingFileHandler
+initialized_root_logger = False
 def get_project_base_directory():
     PROJECT_BASE = os.path.abspath(
         os.path.join(
@@ -29,10 +31,13 @@ def get_project_base_directory():
     return PROJECT_BASE
 def initRootLogger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
-    logger = logging.getLogger()
-    if logger.hasHandlers():
+    global initialized_root_logger
+    if initialized_root_logger:
         return
+    initialized_root_logger = True
+    logger = logging.getLogger()
+    logger.handlers.clear()
     log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{logfile_basename}.log"))
     os.makedirs(os.path.dirname(log_path), exist_ok=True)

View File

@@ -18,6 +18,8 @@ import logging
 import os
 import random
 from timeit import default_timer as timer
+import sys
+import threading
 import xgboost as xgb
 from io import BytesIO
@@ -34,6 +36,10 @@ from rag.nlp import rag_tokenizer
 from copy import deepcopy
 from huggingface_hub import snapshot_download
+LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+if LOCK_KEY_pdfplumber not in sys.modules:
+    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
 class RAGFlowPdfParser:
     def __init__(self):
         self.ocr = OCR()
@@ -948,8 +954,9 @@ class RAGFlowPdfParser:
     @staticmethod
     def total_page_number(fnm, binary=None):
         try:
-            pdf = pdfplumber.open(
-                fnm) if not binary else pdfplumber.open(BytesIO(binary))
+            with sys.modules[LOCK_KEY_pdfplumber]:
+                pdf = pdfplumber.open(
+                    fnm) if not binary else pdfplumber.open(BytesIO(binary))
             total_page = len(pdf.pages)
             pdf.close()
             return total_page
@@ -968,17 +975,18 @@
         self.page_from = page_from
         start = timer()
         try:
-            self.pdf = pdfplumber.open(fnm) if isinstance(
-                fnm, str) else pdfplumber.open(BytesIO(fnm))
-            self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
-                                enumerate(self.pdf.pages[page_from:page_to])]
-            try:
-                self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
-            except Exception as e:
-                logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
-                self.page_chars = [[] for _ in range(page_to - page_from)]  # If failed to extract, using empty list instead.
-            self.total_page = len(self.pdf.pages)
+            with sys.modules[LOCK_KEY_pdfplumber]:
+                self.pdf = pdfplumber.open(fnm) if isinstance(
+                    fnm, str) else pdfplumber.open(BytesIO(fnm))
+                self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
+                                    enumerate(self.pdf.pages[page_from:page_to])]
+                try:
+                    self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
+                except Exception as e:
+                    logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
+                    self.page_chars = [[] for _ in range(page_to - page_from)]  # If failed to extract, using empty list instead.
+                self.total_page = len(self.pdf.pages)
         except Exception:
             logging.exception("RAGFlowPdfParser __images__")
         logging.info(f"__images__ dedupe_chars cost {timer() - start}s")

View File

@@ -14,7 +14,8 @@
 # limitations under the License.
 #
 import io
+import sys
+import threading
 import pdfplumber
 from .ocr import OCR
@@ -23,6 +24,11 @@ from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
 from .table_structure_recognizer import TableStructureRecognizer
+LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
+if LOCK_KEY_pdfplumber not in sys.modules:
+    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
 def init_in_out(args):
     from PIL import Image
     import os
@@ -36,9 +42,10 @@ def init_in_out(args):
     def pdf_pages(fnm, zoomin=3):
         nonlocal outputs, images
-        pdf = pdfplumber.open(fnm)
-        images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
-                  enumerate(pdf.pages)]
+        with sys.modules[LOCK_KEY_pdfplumber]:
+            pdf = pdfplumber.open(fnm)
+            images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
+                      enumerate(pdf.pages)]
         for i, page in enumerate(images):
             outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")

View File

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import logging
 import itertools
 import re
 import time
@@ -21,13 +20,14 @@ from dataclasses import dataclass
 from typing import Any, Callable
 import networkx as nx
+import trio
 from graphrag.general.extractor import Extractor
 from rag.nlp import is_english
 import editdistance
 from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
 from rag.llm.chat_model import Base as CompletionLLM
-from graphrag.utils import perform_variable_replacements
+from graphrag.utils import perform_variable_replacements, chat_limiter
 DEFAULT_RECORD_DELIMITER = "##"
 DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@@ -67,13 +67,13 @@ class EntityResolution(Extractor):
         self._resolution_result_delimiter_key = "resolution_result_delimiter"
         self._input_text_key = "input_text"
-    def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
+    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
         """Call method definition."""
         if prompt_variables is None:
             prompt_variables = {}
         # Wire defaults into the prompt variables
-        prompt_variables = {
+        self.prompt_variables = {
             **prompt_variables,
             self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
             or DEFAULT_RECORD_DELIMITER,
@@ -94,39 +94,12 @@
         for k, v in node_clusters.items():
             candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
-        gen_conf = {"temperature": 0.5}
         resolution_result = set()
-        for candidate_resolution_i in candidate_resolution.items():
-            if candidate_resolution_i[1]:
-                try:
-                    pair_txt = [
-                        f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
-                    for index, candidate in enumerate(candidate_resolution_i[1]):
-                        pair_txt.append(
-                            f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
-                    sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
-                    pair_txt.append(
-                        f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
-                    pair_prompt = '\n'.join(pair_txt)
-                    variables = {
-                        **prompt_variables,
-                        self._input_text_key: pair_prompt
-                    }
-                    text = perform_variable_replacements(self._resolution_prompt, variables=variables)
-                    response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
-                    result = self._process_results(len(candidate_resolution_i[1]), response,
-                                                   prompt_variables.get(self._record_delimiter_key,
-                                                                        DEFAULT_RECORD_DELIMITER),
-                                                   prompt_variables.get(self._entity_index_dilimiter_key,
-                                                                        DEFAULT_ENTITY_INDEX_DELIMITER),
-                                                   prompt_variables.get(self._resolution_result_delimiter_key,
-                                                                        DEFAULT_RESOLUTION_RESULT_DELIMITER))
-                    for result_i in result:
-                        resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
-                except Exception:
-                    logging.exception("error entity resolution")
+        async with trio.open_nursery() as nursery:
+            for candidate_resolution_i in candidate_resolution.items():
+                if not candidate_resolution_i[1]:
+                    continue
+                nursery.start_soon(self._resolve_candidate(candidate_resolution_i, resolution_result))
         connect_graph = nx.Graph()
         removed_entities = []
@@ -172,6 +145,34 @@
             removed_entities=removed_entities
         )
+    async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
+        gen_conf = {"temperature": 0.5}
+        pair_txt = [
+            f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
+        for index, candidate in enumerate(candidate_resolution_i[1]):
+            pair_txt.append(
+                f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
+        sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
+        pair_txt.append(
+            f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
+        pair_prompt = '\n'.join(pair_txt)
+        variables = {
+            **self.prompt_variables,
+            self._input_text_key: pair_prompt
+        }
+        text = perform_variable_replacements(self._resolution_prompt, variables=variables)
+        async with chat_limiter:
+            response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+        result = self._process_results(len(candidate_resolution_i[1]), response,
+                                       self.prompt_variables.get(self._record_delimiter_key,
+                                                                 DEFAULT_RECORD_DELIMITER),
+                                       self.prompt_variables.get(self._entity_index_dilimiter_key,
+                                                                 DEFAULT_ENTITY_INDEX_DELIMITER),
+                                       self.prompt_variables.get(self._resolution_result_delimiter_key,
+                                                                 DEFAULT_RESOLUTION_RESULT_DELIMITER))
+        for result_i in result:
+            resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
     def _process_results(
             self,
             records_length: int,

View File

@@ -1,268 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import argparse
import json
import re
import traceback
from dataclasses import dataclass
from typing import Any
import tiktoken
from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.general.extractor import Extractor
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
CLAIM_MAX_GLEANINGS = 1
@dataclass
class ClaimExtractorResult:
"""Claim extractor result class definition."""
output: list[dict]
source_docs: dict[str, Any]
class ClaimExtractor(Extractor):
"""Claim extractor class definition."""
_extraction_prompt: str
_summary_prompt: str
_output_formatter_prompt: str
_input_text_key: str
_input_entity_spec_key: str
_input_claim_description_key: str
_tuple_delimiter_key: str
_record_delimiter_key: str
_completion_delimiter_key: str
_max_gleanings: int
_on_error: ErrorHandlerFn
def __init__(
self,
llm_invoker: CompletionLLM,
extraction_prompt: str | None = None,
input_text_key: str | None = None,
input_entity_spec_key: str | None = None,
input_claim_description_key: str | None = None,
input_resolved_entities_key: str | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
completion_delimiter_key: str | None = None,
encoding_model: str | None = None,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
self._llm = llm_invoker
self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
self._input_text_key = input_text_key or "input_text"
self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._input_claim_description_key = (
input_claim_description_key or "claim_description"
)
self._input_resolved_entities_key = (
input_resolved_entities_key or "resolved_entities"
)
self._max_gleanings = (
max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)
# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
def __call__(
self, inputs: dict[str, Any], prompt_variables: dict | None = None
) -> ClaimExtractorResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
texts = inputs[self._input_text_key]
entity_spec = str(inputs[self._input_entity_spec_key])
claim_description = inputs[self._input_claim_description_key]
resolved_entities = inputs.get(self._input_resolved_entities_key, {})
source_doc_map = {}
prompt_args = {
self._input_entity_spec_key: entity_spec,
self._input_claim_description_key: claim_description,
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
or DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: prompt_variables.get(
self._completion_delimiter_key
)
or DEFAULT_COMPLETION_DELIMITER,
}
all_claims: list[dict] = []
for doc_index, text in enumerate(texts):
document_id = f"d{doc_index}"
try:
claims = self._process_document(prompt_args, text, doc_index)
all_claims += [
self._clean_claim(c, document_id, resolved_entities) for c in claims
]
source_doc_map[document_id] = text
except Exception as e:
logging.exception("error extracting claim")
self._on_error(
e,
traceback.format_exc(),
{"doc_index": doc_index, "text": text},
)
continue
return ClaimExtractorResult(
output=all_claims,
source_docs=source_doc_map,
)
def _clean_claim(
self, claim: dict, document_id: str, resolved_entities: dict
) -> dict:
# clean the parsed claims to remove any claims with status = False
obj = claim.get("object_id", claim.get("object"))
subject = claim.get("subject_id", claim.get("subject"))
# If subject or object in resolved entities, then replace with resolved entity
obj = resolved_entities.get(obj, obj)
subject = resolved_entities.get(subject, subject)
claim["object_id"] = obj
claim["subject_id"] = subject
claim["doc_id"] = document_id
return claim
def _process_document(
self, prompt_args: dict, doc, doc_index: int
) -> list[dict]:
record_delimiter = prompt_args.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_args.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
variables = {
self._input_text_key: doc,
**prompt_args,
}
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
claims = results.strip().removesuffix(completion_delimiter)
history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
extension = self._chat("", history, gen_conf)
claims += record_delimiter + extension.strip().removesuffix(
completion_delimiter
)
# If this isn't the last loop, check to see if we should continue
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": extension})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._chat("", history, self._loop_args)
if continuation != "YES":
break
result = self._parse_claim_tuples(claims, prompt_args)
for r in result:
r["doc_id"] = f"{doc_index}"
return result
def _parse_claim_tuples(
self, claims: str, prompt_variables: dict
) -> list[dict[str, Any]]:
"""Parse claim tuples."""
record_delimiter = prompt_variables.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_variables.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
tuple_delimiter = prompt_variables.get(
self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
)
def pull_field(index: int, fields: list[str]) -> str | None:
return fields[index].strip() if len(fields) > index else None
result: list[dict[str, Any]] = []
claims_values = (
claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
)
for claim in claims_values:
claim = claim.strip().removeprefix("(").removesuffix(")")
claim = re.sub(r".*Output:", "", claim)
# Ignore the completion delimiter
if claim == completion_delimiter:
continue
claim_fields = claim.split(tuple_delimiter)
o = {
"subject_id": pull_field(0, claim_fields),
"object_id": pull_field(1, claim_fields),
"type": pull_field(2, claim_fields),
"status": pull_field(3, claim_fields),
"start_date": pull_field(4, claim_fields),
"end_date": pull_field(5, claim_fields),
"description": pull_field(6, claim_fields),
"source_text": pull_field(7, claim_fields),
"doc_id": pull_field(8, claim_fields),
}
if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]):
continue
result.append(o)
return result
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
args = parser.parse_args()
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
info = {
"input_text": docs,
"entity_specs": "organization, person",
"claim_description": ""
}
claim = ex(info)
logging.info(json.dumps(claim.output, ensure_ascii=False, indent=2))

View File

@@ -1,71 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
CLAIM_EXTRACTION_PROMPT = """
################
-Target activity-
################
You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
################
-Goal-
################
Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
################
-Steps-
################
- 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
- 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
For each claim, extract the following information:
- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
- 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
- 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
- 5. If there's nothing satisfy the above requirements, just keep output empty.
- 6. When finished, output {completion_delimiter}
################
-Examples-
################
Example 1:
Entity specification: organization
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{completion_delimiter}
###########################
Example 2:
Entity specification: Company A, Person C
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{record_delimiter}
(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
{completion_delimiter}
################
-Real Data-
################
Use the following input for your answer.
Entity specification: {entity_specs}
Claim description: {claim_description}
Text: {input_text}
Output:"""
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: "
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n"

View File

@@ -17,9 +17,10 @@ from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
 from graphrag.general.extractor import Extractor
 from graphrag.general.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
-from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types
+from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
 from rag.utils import num_tokens_from_string
 from timeit import default_timer as timer
+import trio
 @dataclass
@@ -52,7 +53,7 @@ class CommunityReportsExtractor(Extractor):
         self._extraction_prompt = COMMUNITY_REPORT_PROMPT
         self._max_report_length = max_report_length or 1500
-    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
+    async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
         for node_degree in graph.degree:
             graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
@@ -86,28 +87,25 @@
             }
             text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
             gen_conf = {"temperature": 0.3}
-            try:
-                response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
-                token_count += num_tokens_from_string(text + response)
-                response = re.sub(r"^[^\{]*", "", response)
-                response = re.sub(r"[^\}]*$", "", response)
-                response = re.sub(r"\{\{", "{", response)
-                response = re.sub(r"\}\}", "}", response)
-                logging.debug(response)
-                response = json.loads(response)
-                if not dict_has_keys_with_types(response, [
-                        ("title", str),
-                        ("summary", str),
-                        ("findings", list),
-                        ("rating", float),
-                        ("rating_explanation", str),
-                ]):
-                    continue
-                response["weight"] = weight
-                response["entities"] = ents
-            except Exception:
-                logging.exception("CommunityReportsExtractor got exception")
-                continue
+            async with chat_limiter:
+                response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+            token_count += num_tokens_from_string(text + response)
+            response = re.sub(r"^[^\{]*", "", response)
+            response = re.sub(r"[^\}]*$", "", response)
+            response = re.sub(r"\{\{", "{", response)
+            response = re.sub(r"\}\}", "}", response)
+            logging.debug(response)
+            response = json.loads(response)
+            if not dict_has_keys_with_types(response, [
+                    ("title", str),
+                    ("summary", str),
+                    ("findings", list),
+                    ("rating", float),
+                    ("rating_explanation", str),
+            ]):
+                continue
+            response["weight"] = weight
+            response["entities"] = ents
             add_community_info2graph(graph, ents, response["title"])
             res_str.append(self._get_text_output(response))

View File

@@ -14,16 +14,15 @@
 # limitations under the License.
 #
 import logging
-import os
 import re
 from collections import defaultdict, Counter
-from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from typing import Callable
+import trio
 from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
 from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
-    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
+    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter
 from rag.llm.chat_model import Base as CompletionLLM
 from rag.utils import truncate
@@ -91,54 +90,50 @@
         )
         return dict(maybe_nodes), dict(maybe_edges)
-    def __call__(
+    async def __call__(
         self, chunks: list[tuple[str, str]],
             callback: Callable | None = None
     ):
-        results = []
-        max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))
-        with ThreadPoolExecutor(max_workers=max_workers) as exe:
-            threads = []
-            for i, (cid, ck) in enumerate(chunks):
-                ck = truncate(ck, int(self._llm.max_length*0.8))
-                threads.append(
-                    exe.submit(self._process_single_content, (cid, ck)))
-            for i, _ in enumerate(threads):
-                n, r, tc = _.result()
-                if not isinstance(n, Exception):
-                    results.append((n, r))
-                    if callback:
-                        callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
-                elif callback:
-                    callback(msg="Knowledge graph extraction error:{}".format(str(n)))
+        self.callback = callback
+        start_ts = trio.current_time()
+        out_results = []
+        async with trio.open_nursery() as nursery:
+            for i, (cid, ck) in enumerate(chunks):
+                ck = truncate(ck, int(self._llm.max_length*0.8))
+                nursery.start_soon(self._process_single_content, (cid, ck), i, len(chunks), out_results)
         maybe_nodes = defaultdict(list)
         maybe_edges = defaultdict(list)
-        for m_nodes, m_edges in results:
+        sum_token_count = 0
+        for m_nodes, m_edges, token_count in out_results:
             for k, v in m_nodes.items():
                 maybe_nodes[k].extend(v)
             for k, v in m_edges.items():
                 maybe_edges[tuple(sorted(k))].extend(v)
-        logging.info("Inserting entities into storage...")
+            sum_token_count += token_count
+        now = trio.current_time()
+        if callback:
+            callback(msg = f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now-start_ts:.2f}s.")
+        start_ts = now
+        logging.info("Entities merging...")
         all_entities_data = []
-        with ThreadPoolExecutor(max_workers=max_workers) as exe:
-            threads = []
-            for en_nm, ents in maybe_nodes.items():
-                threads.append(
-                    exe.submit(self._merge_nodes, en_nm, ents))
-            for t in threads:
-                n = t.result()
-                if not isinstance(n, Exception):
-                    all_entities_data.append(n)
-                elif callback:
-                    callback(msg="Knowledge graph nodes merging error: {}".format(str(n)))
-        logging.info("Inserting relationships into storage...")
+        async with trio.open_nursery() as nursery:
+            for en_nm, ents in maybe_nodes.items():
+                nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
+        now = trio.current_time()
+        if callback:
+            callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
+        start_ts = now
+        logging.info("Relationships merging...")
         all_relationships_data = []
-        for (src, tgt), rels in maybe_edges.items():
-            all_relationships_data.append(self._merge_edges(src, tgt, rels))
+        async with trio.open_nursery() as nursery:
+            for (src, tgt), rels in maybe_edges.items():
+                nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
+        now = trio.current_time()
+        if callback:
+            callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
         if not len(all_entities_data) and not len(all_relationships_data):
             logging.warning(
@@ -152,7 +147,7 @@
         return all_entities_data, all_relationships_data
-    def _merge_nodes(self, entity_name: str, entities: list[dict]):
+    async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
         if not entities:
             return
         already_entity_types = []
@@ -176,26 +171,22 @@
             sorted(set([dp["description"] for dp in entities] + already_description))
         )
         already_source_ids = flat_uniq_list(entities, "source_id")
-        try:
-            description = self._handle_entity_relation_summary(
-                entity_name, description
-            )
-            node_data = dict(
-                entity_type=entity_type,
-                description=description,
-                source_id=already_source_ids,
-            )
-            node_data["entity_name"] = entity_name
-            self._set_entity_(entity_name, node_data)
-            return node_data
-        except Exception as e:
-            return e
+        description = await self._handle_entity_relation_summary(entity_name, description)
+        node_data = dict(
+            entity_type=entity_type,
+            description=description,
+            source_id=already_source_ids,
+        )
+        node_data["entity_name"] = entity_name
+        self._set_entity_(entity_name, node_data)
+        all_relationships_data.append(node_data)
-    def _merge_edges(
+    async def _merge_edges(
         self,
         src_id: str,
         tgt_id: str,
-        edges_data: list[dict]
+        edges_data: list[dict],
+        all_relationships_data
     ):
         if not edges_data:
             return
@@ -226,7 +217,7 @@
                 "description": description,
                 "entity_type": 'UNKNOWN'
             })
-        description = self._handle_entity_relation_summary(
+        description = await self._handle_entity_relation_summary(
             f"({src_id}, {tgt_id})", description
         )
         edge_data = dict(
@@ -238,10 +229,9 @@
             source_id=source_id
         )
         self._set_relation_(src_id, tgt_id, edge_data)
+        all_relationships_data.append(edge_data)
-        return edge_data
-    def _handle_entity_relation_summary(
+    async def _handle_entity_relation_summary(
         self,
         entity_or_relation_name: str,
         description: str
@@ -256,5 +246,6 @@
         )
         use_prompt = prompt_template.format(**context_base)
         logging.info(f"Trigger summary: {entity_or_relation_name}")
-        summary = self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})
+        async with chat_limiter:
+            summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8}))
         return summary

View File

@@ -5,15 +5,15 @@ Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
 """
-import logging
 import re
 from typing import Any, Callable
 from dataclasses import dataclass
 import tiktoken
+import trio
 from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
 from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
-from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
+from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
 from rag.utils import num_tokens_from_string
@@ -102,53 +102,47 @@
             self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
         }
-    def _process_single_content(self,
-                                chunk_key_dp: tuple[str, str]
-                                ):
+    async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
         token_count = 0
         chunk_key = chunk_key_dp[0]
         content = chunk_key_dp[1]
         variables = {
             **self._prompt_variables,
             self._input_text_key: content,
         }
-        try:
-            gen_conf = {"temperature": 0.3}
-            hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
-            response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
-            token_count += num_tokens_from_string(hint_prompt + response)
-            results = response or ""
-            history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
-            # Repeat to ensure we maximize entity count
-            for i in range(self._max_gleanings):
-                text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
-                history.append({"role": "user", "content": text})
-                response = self._chat("", history, gen_conf)
-                token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
-                results += response or ""
-                # if this is the final glean, don't bother updating the continuation flag
-                if i >= self._max_gleanings - 1:
-                    break
-                history.append({"role": "assistant", "content": response})
-                history.append({"role": "user", "content": LOOP_PROMPT})
-                continuation = self._chat("", history, {"temperature": 0.8})
-                token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
-                if continuation != "YES":
-                    break
-            record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
-            tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
-            records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
-            records = [r for r in records if r.strip()]
-            maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
-            return maybe_nodes, maybe_edges, token_count
-        except Exception as e:
-            logging.exception("error extracting graph")
-            return e, None, None
+        gen_conf = {"temperature": 0.3}
+        hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
+        async with chat_limiter:
+            response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
+        token_count += num_tokens_from_string(hint_prompt + response)
+        results = response or ""
+        history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
+        # Repeat to ensure we maximize entity count
+        for i in range(self._max_gleanings):
+            text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
+            history.append({"role": "user", "content": text})
+            async with chat_limiter:
+                response = await trio.to_thread.run_sync(lambda: self._chat("", history, gen_conf))
+            token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
+            results += response or ""
+            # if this is the final glean, don't bother updating the continuation flag
+            if i >= self._max_gleanings - 1:
+                break
+            history.append({"role": "assistant", "content": response})
+            history.append({"role": "user", "content": LOOP_PROMPT})
+            async with chat_limiter:
+                continuation = await trio.to_thread.run_sync(lambda: self._chat("", history, {"temperature": 0.8}))
+            token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
+            if continuation != "YES":
+                break
+        record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
+        tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
+        records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
+        records = [r for r in records if r.strip()]
+        maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
+        out_results.append((maybe_nodes, maybe_edges, token_count))
+        if self.callback:
+            self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")

View File

@ -17,6 +17,7 @@ import json
import logging import logging
from functools import reduce, partial from functools import reduce, partial
import networkx as nx import networkx as nx
import trio
from api import settings from api import settings
from graphrag.general.community_reports_extractor import CommunityReportsExtractor from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@ -41,18 +42,24 @@ class Dealer:
embed_bdl=None, embed_bdl=None,
callback=None callback=None
): ):
docids = list(set([docid for docid,_ in chunks])) self.tenant_id = tenant_id
self.kb_id = kb_id
self.chunks = chunks
self.llm_bdl = llm_bdl self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl self.embed_bdl = embed_bdl
ext = extractor(self.llm_bdl, language=language, self.ext = extractor(self.llm_bdl, language=language,
entity_types=entity_types, entity_types=entity_types,
get_entity=partial(get_entity, tenant_id, kb_id), get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl), set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id), get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl) set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
) )
ents, rels = ext(chunks, callback)
self.graph = nx.Graph() self.graph = nx.Graph()
self.callback = callback
async def __call__(self):
docids = list(set([docid for docid, _ in self.chunks]))
ents, rels = await self.ext(self.chunks, self.callback)
for en in ents: for en in ents:
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"]) self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
@ -64,16 +71,16 @@ class Dealer:
#description=rel["description"] #description=rel["description"]
) )
with RedisDistributedLock(kb_id, 60*60): with RedisDistributedLock(self.kb_id, 60*60):
old_graph, old_doc_ids = get_graph(tenant_id, kb_id) old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
if old_graph is not None: if old_graph is not None:
logging.info("Merge with an exiting graph...................") logging.info("Merge with an exiting graph...................")
self.graph = reduce(graph_merge, [old_graph, self.graph]) self.graph = reduce(graph_merge, [old_graph, self.graph])
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2) update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
if old_doc_ids: if old_doc_ids:
docids.extend(old_doc_ids) docids.extend(old_doc_ids)
docids = list(set(docids)) docids = list(set(docids))
set_graph(tenant_id, kb_id, self.graph, docids) set_graph(self.tenant_id, self.kb_id, self.graph, docids)
class WithResolution(Dealer): class WithResolution(Dealer):
@ -84,47 +91,50 @@ class WithResolution(Dealer):
embed_bdl=None, embed_bdl=None,
callback=None callback=None
): ):
self.tenant_id = tenant_id
self.kb_id = kb_id
self.llm_bdl = llm_bdl self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl self.embed_bdl = embed_bdl
self.callback = callback
with RedisDistributedLock(kb_id, 60*60): async def __call__(self):
self.graph, doc_ids = get_graph(tenant_id, kb_id) with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
if not self.graph: if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}") logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if callback: if self.callback:
callback(-1, msg="Faild to fetch the graph.") self.callback(-1, msg="Faild to fetch the graph.")
return return
if callback: if self.callback:
callback(msg="Fetch the existing graph.") self.callback(msg="Fetch the existing graph.")
er = EntityResolution(self.llm_bdl, er = EntityResolution(self.llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id), get_entity=partial(get_entity, self.tenant_id, self.kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl), set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id), get_relation=partial(get_relation, self.tenant_id, self.kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)) set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
reso = er(self.graph) reso = await er(self.graph)
self.graph = reso.graph self.graph = reso.graph
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities))) logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
if callback: if self.callback:
callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities))) self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2) await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
set_graph(tenant_id, kb_id, self.graph, doc_ids) await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
settings.docStoreConn.delete({ await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation", "knowledge_graph_kwd": "relation",
"kb_id": kb_id, "kb_id": self.kb_id,
"from_entity_kwd": reso.removed_entities "from_entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id) }, search.index_name(self.tenant_id), self.kb_id))
settings.docStoreConn.delete({ await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation", "knowledge_graph_kwd": "relation",
"kb_id": kb_id, "kb_id": self.kb_id,
"to_entity_kwd": reso.removed_entities "to_entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id) }, search.index_name(self.tenant_id), self.kb_id))
settings.docStoreConn.delete({ await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "entity", "knowledge_graph_kwd": "entity",
"kb_id": kb_id, "kb_id": self.kb_id,
"entity_kwd": reso.removed_entities "entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id) }, search.index_name(self.tenant_id), self.kb_id))
class WithCommunity(Dealer): class WithCommunity(Dealer):
@ -136,38 +146,41 @@ class WithCommunity(Dealer):
callback=None callback=None
): ):
self.tenant_id = tenant_id
self.kb_id = kb_id
self.community_structure = None self.community_structure = None
self.community_reports = None self.community_reports = None
self.llm_bdl = llm_bdl self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl self.embed_bdl = embed_bdl
self.callback = callback
with RedisDistributedLock(kb_id, 60*60): async def __call__(self):
self.graph, doc_ids = get_graph(tenant_id, kb_id) with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
if not self.graph: if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}") logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if callback: if self.callback:
callback(-1, msg="Faild to fetch the graph.") self.callback(-1, msg="Faild to fetch the graph.")
return return
if callback: if self.callback:
callback(msg="Fetch the existing graph.") self.callback(msg="Fetch the existing graph.")
cr = CommunityReportsExtractor(self.llm_bdl, cr = CommunityReportsExtractor(self.llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id), get_entity=partial(get_entity, self.tenant_id, self.kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl), set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id), get_relation=partial(get_relation, self.tenant_id, self.kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)) set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
cr = cr(self.graph, callback=callback) cr = await cr(self.graph, callback=self.callback)
self.community_structure = cr.structured_output self.community_structure = cr.structured_output
self.community_reports = cr.output self.community_reports = cr.output
set_graph(tenant_id, kb_id, self.graph, doc_ids) await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
if callback: if self.callback:
callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output))) self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
settings.docStoreConn.delete({ await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "community_report", "knowledge_graph_kwd": "community_report",
"kb_id": kb_id "kb_id": self.kb_id
}, search.index_name(tenant_id), kb_id) }, search.index_name(self.tenant_id), self.kb_id))
for stru, rep in zip(self.community_structure, self.community_reports): for stru, rep in zip(self.community_structure, self.community_reports):
obj = { obj = {
@ -183,7 +196,7 @@ class WithCommunity(Dealer):
"weight_flt": stru["weight"], "weight_flt": stru["weight"],
"entities_kwd": stru["entities"], "entities_kwd": stru["entities"],
"important_kwd": stru["entities"], "important_kwd": stru["entities"],
"kb_id": kb_id, "kb_id": self.kb_id,
"source_id": doc_ids, "source_id": doc_ids,
"available_int": 0 "available_int": 0
} }
@ -193,5 +206,5 @@ class WithCommunity(Dealer):
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0] # chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
#except Exception as e: #except Exception as e:
# logging.exception(f"Fail to embed entity relation: {e}") # logging.exception(f"Fail to embed entity relation: {e}")
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)) await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id)))
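The recurring pattern above wraps each blocking docStoreConn call in trio.to_thread.run_sync so the event loop stays free while the store call runs on a worker thread. A minimal sketch of that pattern (slow_insert and rows are hypothetical stand-ins for the blocking store calls):

import time
import trio

def slow_insert(rows):
    # stand-in for a blocking docStoreConn.insert/delete call
    time.sleep(0.1)
    return len(rows)

async def index_rows(rows):
    # the blocking call runs on a worker thread; other trio tasks keep running
    return await trio.to_thread.run_sync(lambda: slow_insert(rows))

if __name__ == "__main__":
    print(trio.run(index_rows, ["a", "b", "c"]))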

View File

@ -16,16 +16,14 @@
import logging import logging
import collections import collections
import os
import re import re
import traceback
from typing import Any from typing import Any
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
import trio
from graphrag.general.extractor import Extractor from graphrag.general.extractor import Extractor
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM from rag.llm.chat_model import Base as CompletionLLM
import markdown_to_json import markdown_to_json
from functools import reduce from functools import reduce
@ -80,63 +78,47 @@ class MindMapExtractor(Extractor):
) )
return arr return arr
def __call__( async def __call__(
self, sections: list[str], prompt_variables: dict[str, Any] | None = None self, sections: list[str], prompt_variables: dict[str, Any] | None = None
) -> MindMapResult: ) -> MindMapResult:
"""Call method definition.""" """Call method definition."""
if prompt_variables is None: if prompt_variables is None:
prompt_variables = {} prompt_variables = {}
try: res = []
res = [] token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12)) texts = []
with ThreadPoolExecutor(max_workers=max_workers) as exe: cnt = 0
threads = [] async with trio.open_nursery() as nursery:
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512) for i in range(len(sections)):
texts = [] section_cnt = num_tokens_from_string(sections[i])
cnt = 0 if cnt + section_cnt >= token_count and texts:
for i in range(len(sections)): nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
section_cnt = num_tokens_from_string(sections[i]) texts = []
if cnt + section_cnt >= token_count and texts: cnt = 0
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) texts.append(sections[i])
texts = [] cnt += section_cnt
cnt = 0 if texts:
texts.append(sections[i]) nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
cnt += section_cnt if not res:
if texts: return MindMapResult(output={"id": "root", "children": []})
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables)) merge_json = reduce(self._merge, res)
if len(merge_json) > 1:
for i, _ in enumerate(threads): keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
res.append(_.result()) keyset = set(i for i in keys if i)
merge_json = {
if not res: "id": "root",
return MindMapResult(output={"id": "root", "children": []}) "children": [
{
merge_json = reduce(self._merge, res) "id": self._key(k),
if len(merge_json) > 1: "children": self._be_children(v, keyset)
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)] }
keyset = set(i for i in keys if i) for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
merge_json = { ]
"id": "root", }
"children": [ else:
{ k = self._key(list(merge_json.keys())[0])
"id": self._key(k), merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
"children": self._be_children(v, keyset)
}
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
]
}
else:
k = self._key(list(merge_json.keys())[0])
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
except Exception as e:
logging.exception("error mind graph")
self._on_error(
e,
traceback.format_exc(), None
)
merge_json = {"error": str(e)}
return MindMapResult(output=merge_json) return MindMapResult(output=merge_json)
@ -181,8 +163,8 @@ class MindMapExtractor(Extractor):
return self._list_to_kv(to_ret) return self._list_to_kv(to_ret)
def _process_document( async def _process_document(
self, text: str, prompt_variables: dict[str, str] self, text: str, prompt_variables: dict[str, str], out_res
) -> str: ) -> str:
variables = { variables = {
**prompt_variables, **prompt_variables,
@ -190,8 +172,9 @@ class MindMapExtractor(Extractor):
} }
text = perform_variable_replacements(self._mind_map_prompt, variables=variables) text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
gen_conf = {"temperature": 0.5} gen_conf = {"temperature": 0.5}
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
response = re.sub(r"```[^\n]*", "", response) response = re.sub(r"```[^\n]*", "", response)
logging.debug(response) logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response))) logging.debug(self._todict(markdown_to_json.dictify(response)))
return self._todict(markdown_to_json.dictify(response)) out_res.append(self._todict(markdown_to_json.dictify(response)))
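The async rewrite fans sections out with a nursery and, because nursery.start_soon returns nothing, collects results through a shared list that each worker appends to and the caller reads only after the nursery block exits. A minimal sketch of that pattern, with process and fan_out as hypothetical names:

import trio

async def process(text, out_res):
    await trio.sleep(0)          # stand-in for an LLM call
    out_res.append(text.upper())

async def fan_out(sections):
    res = []
    async with trio.open_nursery() as nursery:
        for s in sections:
            nursery.start_soon(process, s, res)
    # reaching here means every child task has finished
    return res

if __name__ == "__main__":
    print(trio.run(fan_out, ["alpha", "beta"]))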

View File

@ -18,6 +18,7 @@ import argparse
import json import json
import networkx as nx import networkx as nx
import trio
from api import settings from api import settings
from api.db import LLMType from api.db import LLMType
@ -54,10 +55,13 @@ if __name__ == "__main__":
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl) dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
trio.run(dealer)
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2)) print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl) dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
trio.run(dealer)
dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl) dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
trio.run(dealer)
print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports) print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2)) print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))

View File

@ -4,16 +4,16 @@
Reference: Reference:
- [graphrag](https://github.com/microsoft/graphrag) - [graphrag](https://github.com/microsoft/graphrag)
""" """
import logging
import re import re
from typing import Any, Callable from typing import Any, Callable
from dataclasses import dataclass from dataclasses import dataclass
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
from graphrag.light.graph_prompt import PROMPTS from graphrag.light.graph_prompt import PROMPTS
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx import networkx as nx
from rag.utils import num_tokens_from_string from rag.utils import num_tokens_from_string
import trio
@dataclass @dataclass
@ -82,7 +82,7 @@ class GraphExtractor(Extractor):
) )
self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count) self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)
def _process_single_content(self, chunk_key_dp: tuple[str, str]): async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
token_count = 0 token_count = 0
chunk_key = chunk_key_dp[0] chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1] content = chunk_key_dp[1]
@ -90,38 +90,39 @@ class GraphExtractor(Extractor):
**self._context_base, input_text="{input_text}" **self._context_base, input_text="{input_text}"
).format(**self._context_base, input_text=content) ).format(**self._context_base, input_text=content)
try: gen_conf = {"temperature": 0.8}
gen_conf = {"temperature": 0.8} async with chat_limiter:
final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf) final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
token_count += num_tokens_from_string(hint_prompt + final_result) token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt) history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings): for now_glean_index in range(self._max_gleanings):
glean_result = self._chat(hint_prompt, history, gen_conf) async with chat_limiter:
history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}]) glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt) history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
final_result += glean_result token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
if now_glean_index == self._max_gleanings - 1: final_result += glean_result
break if now_glean_index == self._max_gleanings - 1:
break
if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf) async with chat_limiter:
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt) if_loop_result = await trio.to_thread.run_sync(lambda: self._chat(self._if_loop_prompt, history, gen_conf))
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if if_loop_result != "yes": if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
break if if_loop_result != "yes":
break
records = split_string_by_multi_markers( records = split_string_by_multi_markers(
final_result, final_result,
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]], [self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
) )
rcds = [] rcds = []
for record in records: for record in records:
record = re.search(r"\((.*)\)", record) record = re.search(r"\((.*)\)", record)
if record is None: if record is None:
continue continue
rcds.append(record.group(1)) rcds.append(record.group(1))
records = rcds records = rcds
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"]) maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
return maybe_nodes, maybe_edges, token_count out_results.append((maybe_nodes, maybe_edges, token_count))
except Exception as e: if self.callback:
logging.exception("error extracting graph") self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
return e, None, None

View File

@ -15,6 +15,8 @@ from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from hashlib import md5 from hashlib import md5
from typing import Any, Callable from typing import Any, Callable
import os
import trio
import networkx as nx import networkx as nx
import numpy as np import numpy as np
@ -28,6 +30,7 @@ from rag.utils.redis_conn import REDIS_CONN
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 100)))
def perform_variable_replacements( def perform_variable_replacements(
input: str, history: list[dict] | None = None, variables: dict | None = None input: str, history: list[dict] | None = None, variables: dict | None = None
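chat_limiter is a shared trio.CapacityLimiter: callers enter "async with chat_limiter:" around each chat call, so no more than MAX_CONCURRENT_CHATS calls are in flight at once, however many tasks the nurseries spawn. A rough usage sketch (fake_chat is a hypothetical stand-in for the blocking LLM request):

import trio

chat_limiter = trio.CapacityLimiter(2)     # e.g. MAX_CONCURRENT_CHATS=2

def fake_chat(prompt):
    return prompt[::-1]                    # stand-in for a blocking LLM request

async def limited_chat(prompt, out):
    async with chat_limiter:               # at most 2 tasks inside at any moment
        out.append(await trio.to_thread.run_sync(lambda: fake_chat(prompt)))

async def main():
    out = []
    async with trio.open_nursery() as nursery:
        for p in ["alpha", "beta", "gamma"]:
            nursery.start_soon(limited_chat, p, out)
    print(out)

if __name__ == "__main__":
    trio.run(main)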

View File

@ -122,7 +122,8 @@ dependencies = [
"pyodbc>=5.2.0,<6.0.0", "pyodbc>=5.2.0,<6.0.0",
"pyicu>=2.13.1,<3.0.0", "pyicu>=2.13.1,<3.0.0",
"flasgger>=0.9.7.1,<0.10.0", "flasgger>=0.9.7.1,<0.10.0",
"xxhash>=3.5.0,<4.0.0" "xxhash>=3.5.0,<4.0.0",
"trio>=0.29.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@ -133,4 +134,7 @@ full = [
"flagembedding==1.2.10", "flagembedding==1.2.10",
"torch>=2.5.0,<3.0.0", "torch>=2.5.0,<3.0.0",
"transformers>=4.35.0,<5.0.0" "transformers>=4.35.0,<5.0.0"
] ]
[[tool.uv.index]]
url = "https://mirrors.aliyun.com/pypi/simple"

View File

@ -14,15 +14,14 @@
# limitations under the License. # limitations under the License.
# #
import logging import logging
import os
import re import re
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
from threading import Lock from threading import Lock
import umap import umap
import numpy as np import numpy as np
from sklearn.mixture import GaussianMixture from sklearn.mixture import GaussianMixture
import trio
from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
from rag.utils import truncate from rag.utils import truncate
@ -68,24 +67,25 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
optimal_clusters = n_clusters[np.argmin(bics)] optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters return optimal_clusters
def __call__(self, chunks, random_state, callback=None): async def __call__(self, chunks, random_state, callback=None):
layers = [(0, len(chunks))] layers = [(0, len(chunks))]
start, end = 0, len(chunks) start, end = 0, len(chunks)
if len(chunks) <= 1: if len(chunks) <= 1:
return [] return []
chunks = [(s, a) for s, a in chunks if s and len(a) > 0] chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
def summarize(ck_idx, lock): async def summarize(ck_idx, lock):
nonlocal chunks nonlocal chunks
try: try:
texts = [chunks[i][0] for i in ck_idx] texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
cnt = self._chat("You're a helpful assistant.", async with chat_limiter:
[{"role": "user", cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
"content": self._prompt.format(cluster_content=cluster_content)}], [{"role": "user",
{"temperature": 0.3, "max_tokens": self._max_token} "content": self._prompt.format(cluster_content=cluster_content)}],
) {"temperature": 0.3, "max_tokens": self._max_token}
))
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
cnt) cnt)
logging.debug(f"SUM: {cnt}") logging.debug(f"SUM: {cnt}")
@ -97,10 +97,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
return e return e
labels = [] labels = []
lock = Lock()
while end - start > 1: while end - start > 1:
embeddings = [embd for _, embd in chunks[start: end]] embeddings = [embd for _, embd in chunks[start: end]]
if len(embeddings) == 2: if len(embeddings) == 2:
summarize([start, start + 1], Lock()) await summarize([start, start + 1], lock)
if callback: if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
labels.extend([0, 0]) labels.extend([0, 0])
@ -122,19 +123,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
probs = gm.predict_proba(reduced_embeddings) probs = gm.predict_proba(reduced_embeddings)
lbls = [np.where(prob > self._threshold)[0] for prob in probs] lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
lock = Lock()
with ThreadPoolExecutor(max_workers=int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))) as executor: async with trio.open_nursery() as nursery:
threads = []
for c in range(n_clusters): for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
if not ck_idx: if not ck_idx:
continue continue
threads.append(executor.submit(summarize, ck_idx, lock)) async with chat_limiter:
wait(threads, return_when=ALL_COMPLETED) nursery.start_soon(lambda: summarize(ck_idx, lock))
for th in threads:
if isinstance(th.result(), Exception):
raise th.result()
logging.debug(str([t.result() for t in threads]))
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls) labels.extend(lbls)
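One subtlety in starting tasks from a loop: arguments passed positionally to start_soon are bound at call time, while a lambda closure sees whatever the loop variable holds when the child task actually starts running, typically after the loop has finished. A small sketch under that assumption, with summarize reduced to an append:

import trio

async def summarize(ck_idx, results):
    await trio.sleep(0)
    results.append(tuple(ck_idx))

async def main():
    results = []
    async with trio.open_nursery() as nursery:
        for c in range(3):
            ck_idx = [c, c + 10]
            # binds the current ck_idx; a `lambda: summarize(ck_idx, results)`
            # could instead run with the last value assigned in the loop
            nursery.start_soon(summarize, ck_idx, results)
    print(sorted(results))     # [(0, 10), (1, 11), (2, 12)]

if __name__ == "__main__":
    trio.run(main)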

View File

@ -30,7 +30,6 @@ CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO CONSUMER_NAME = "task_executor_" + CONSUMER_NO
initRootLogger(CONSUMER_NAME) initRootLogger(CONSUMER_NAME)
import asyncio
import logging import logging
import os import os
from datetime import datetime from datetime import datetime
@ -38,14 +37,14 @@ import json
import xxhash import xxhash
import copy import copy
import re import re
import time
import threading
from functools import partial from functools import partial
from io import BytesIO from io import BytesIO
from multiprocessing.context import TimeoutError from multiprocessing.context import TimeoutError
from timeit import default_timer as timer from timeit import default_timer as timer
import tracemalloc import tracemalloc
import resource
import signal import signal
import trio
import numpy as np import numpy as np
from peewee import DoesNotExist from peewee import DoesNotExist
@ -64,8 +63,9 @@ from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter
BATCH_SIZE = 64 BATCH_SIZE = 64
@ -88,28 +88,28 @@ FACTORY = {
ParserType.TAG.value: tag ParserType.TAG.value: tag
} }
UNACKED_ITERATOR = None
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
PAYLOAD: Payload | None = None
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds") BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
PENDING_TASKS = 0 PENDING_TASKS = 0
LAG_TASKS = 0 LAG_TASKS = 0
mt_lock = threading.Lock()
DONE_TASKS = 0 DONE_TASKS = 0
FAILED_TASKS = 0 FAILED_TASKS = 0
CURRENT_TASK = None
tracemalloc_started = False CURRENT_TASKS = {}
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
# SIGUSR1 handler: start tracemalloc and take snapshot # SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame): def start_tracemalloc_and_snapshot(signum, frame):
global tracemalloc_started if not tracemalloc.is_tracing():
if not tracemalloc_started: logging.info("start tracemalloc")
logging.info("got SIGUSR1, start tracemalloc")
tracemalloc.start() tracemalloc.start()
tracemalloc_started = True
else: else:
logging.info("got SIGUSR1, tracemalloc is already running") logging.info("tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace" snapshot_file = f"snapshot_{timestamp}.trace"
@ -117,17 +117,17 @@ def start_tracemalloc_and_snapshot(signum, frame):
snapshot = tracemalloc.take_snapshot() snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file) snapshot.dump(snapshot_file)
logging.info(f"taken snapshot {snapshot_file}") current, peak = tracemalloc.get_traced_memory()
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc # SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame): def stop_tracemalloc(signum, frame):
global tracemalloc_started if tracemalloc.is_tracing():
if tracemalloc_started: logging.info("stop tracemalloc")
logging.info("go SIGUSR2, stop tracemalloc")
tracemalloc.stop() tracemalloc.stop()
tracemalloc_started = False
else: else:
logging.info("got SIGUSR2, tracemalloc not running") logging.info("tracemalloc not running")
class TaskCanceledException(Exception): class TaskCanceledException(Exception):
def __init__(self, msg): def __init__(self, msg):
@ -135,17 +135,9 @@ class TaskCanceledException(Exception):
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0: if prog is not None and prog < 0:
msg = "[ERROR]" + msg msg = "[ERROR]" + msg
try: cancel = TaskService.do_cancel(task_id)
cancel = TaskService.do_cancel(task_id)
except DoesNotExist:
logging.warning(f"set_progress task {task_id} is unknown")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
return
if cancel: if cancel:
msg += " [Canceled]" msg += " [Canceled]"
@ -162,66 +154,55 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
d["progress"] = prog d["progress"] = prog
logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}") logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
try: TaskService.update_progress(task_id, d)
TaskService.update_progress(task_id, d)
except DoesNotExist:
logging.warning(f"set_progress task {task_id} is unknown")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
return
close_connection() close_connection()
if cancel and PAYLOAD: if cancel:
PAYLOAD.ack()
PAYLOAD = None
raise TaskCanceledException(msg) raise TaskCanceledException(msg)
async def collect():
def collect(): global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS global UNACKED_ITERATOR
try: try:
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker") if not UNACKED_ITERATOR:
if not PAYLOAD: UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME) try:
if not PAYLOAD: redis_msg = next(UNACKED_ITERATOR)
time.sleep(1) except StopIteration:
return None redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
if not redis_msg:
await trio.sleep(1)
return None, None
except Exception: except Exception:
logging.exception("Get task event from queue exception") logging.exception("collect got exception")
return None return None, None
msg = PAYLOAD.get_message() msg = redis_msg.get_message()
if not msg: if not msg:
return None logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
redis_msg.ack()
return None, None
task = None
canceled = False canceled = False
try: task = TaskService.get_task(msg["id"])
task = TaskService.get_task(msg["id"]) if task:
if task: _, doc = DocumentService.get_by_id(task["doc_id"])
_, doc = DocumentService.get_by_id(task["doc_id"]) canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
except DoesNotExist:
pass
except Exception:
logging.exception("collect get_task exception")
if not task or canceled: if not task or canceled:
state = "is unknown" if not task else "has been cancelled" state = "is unknown" if not task else "has been cancelled"
with mt_lock: FAILED_TASKS += 1
DONE_TASKS += 1 logging.warning(f"collect task {msg['id']} {state}")
logging.info(f"collect task {msg['id']} {state}") redis_msg.ack()
return None return None, None
task["task_type"] = msg.get("task_type", "") task["task_type"] = msg.get("task_type", "")
return task return redis_msg, task
def get_storage_binary(bucket, name): async def get_storage_binary(bucket, name):
return STORAGE_IMPL.get(bucket, name) return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))
def build_chunks(task, progress_callback): async def build_chunks(task, progress_callback):
if task["size"] > DOC_MAXIMUM_SIZE: if task["size"] > DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" % set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024))) (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
@ -231,7 +212,7 @@ def build_chunks(task, progress_callback):
try: try:
st = timer() st = timer()
bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"]) bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
binary = get_storage_binary(bucket, name) binary = await get_storage_binary(bucket, name)
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"])) logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
except TimeoutError: except TimeoutError:
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
@ -247,9 +228,10 @@ def build_chunks(task, progress_callback):
raise raise
try: try:
cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"], async with chunk_limiter:
to_page=task["to_page"], lang=task["language"], callback=progress_callback, cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]) to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
except TaskCanceledException: except TaskCanceledException:
raise raise
@ -286,7 +268,7 @@ def build_chunks(task, progress_callback):
d["image"].save(output_buffer, format='JPEG') d["image"].save(output_buffer, format='JPEG')
st = timer() st = timer()
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()) await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
el += timer() - st el += timer() - st
except Exception: except Exception:
logging.exception( logging.exception(
@ -306,14 +288,16 @@ def build_chunks(task, progress_callback):
async def doc_keyword_extraction(chat_mdl, d, topn): async def doc_keyword_extraction(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
if not cached: if not cached:
cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn) async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
if cached: if cached:
d["important_kwd"] = cached.split(",") d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return return
tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs] async with trio.open_nursery() as nursery:
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) for d in docs:
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["parser_config"].get("auto_questions", 0): if task["parser_config"].get("auto_questions", 0):
@ -324,13 +308,15 @@ def build_chunks(task, progress_callback):
async def doc_question_proposal(chat_mdl, d, topn): async def doc_question_proposal(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
if not cached: if not cached:
cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn) async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
if cached: if cached:
d["question_kwd"] = cached.split("\n") d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs] async with trio.open_nursery() as nursery:
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) for d in docs:
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []): if task["kb_parser_config"].get("tag_kb_ids", []):
@ -361,14 +347,16 @@ def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags}) cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached: if not cached:
picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags) async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
if cached: if cached:
cached = json.dumps(cached) cached = json.dumps(cached)
if cached: if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached) d[TAG_FLD] = json.loads(cached)
tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag] async with trio.open_nursery() as nursery:
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) for d in docs_to_tag:
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
return docs return docs
@ -379,7 +367,7 @@ def init_kb(row, vector_size: int):
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size) return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
def embedding(docs, mdl, parser_config=None, callback=None): async def embedding(docs, mdl, parser_config=None, callback=None):
if parser_config is None: if parser_config is None:
parser_config = {} parser_config = {}
batch_size = 16 batch_size = 16
@ -396,13 +384,13 @@ def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0 tk_count = 0
if len(tts) == len(cnts): if len(tts) == len(cnts):
vts, c = mdl.encode(tts[0: 1]) vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
tts = np.concatenate([vts for _ in range(len(tts))], axis=0) tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
tk_count += c tk_count += c
cnts_ = np.array([]) cnts_ = np.array([])
for i in range(0, len(cnts), batch_size): for i in range(0, len(cnts), batch_size):
vts, c = mdl.encode(cnts[i: i + batch_size]) vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size]))
if len(cnts_) == 0: if len(cnts_) == 0:
cnts_ = vts cnts_ = vts
else: else:
@ -424,7 +412,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size return tk_count, vector_size
def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = [] chunks = []
vctr_nm = "q_%d_vec"%vector_size vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
@ -440,7 +428,7 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
row["parser_config"]["raptor"]["threshold"] row["parser_config"]["raptor"]["threshold"]
) )
original_length = len(chunks) original_length = len(chunks)
chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback) chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
doc = { doc = {
"doc_id": row["doc_id"], "doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])], "kb_id": [str(row["kb_id"])],
@ -465,13 +453,13 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count return res, tk_count
def run_graphrag(row, chat_model, language, embedding_model, callback=None): async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
chunks = [] chunks = []
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", "doc_id"]): fields=["content_with_weight", "doc_id"]):
chunks.append((d["doc_id"], d["content_with_weight"])) chunks.append((d["doc_id"], d["content_with_weight"]))
Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt, dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
row["tenant_id"], row["tenant_id"],
str(row["kb_id"]), str(row["kb_id"]),
chat_model, chat_model,
@ -480,9 +468,10 @@ def run_graphrag(row, chat_model, language, embedding_model, callback=None):
entity_types=row["parser_config"]["graphrag"]["entity_types"], entity_types=row["parser_config"]["graphrag"]["entity_types"],
embed_bdl=embedding_model, embed_bdl=embedding_model,
callback=callback) callback=callback)
await dealer()
def do_handle_task(task): async def do_handle_task(task):
task_id = task["id"] task_id = task["id"]
task_from_page = task["from_page"] task_from_page = task["from_page"]
task_to_page = task["to_page"] task_to_page = task["to_page"]
@ -494,6 +483,7 @@ def do_handle_task(task):
task_doc_id = task["doc_id"] task_doc_id = task["doc_id"]
task_document_name = task["name"] task_document_name = task["name"]
task_parser_config = task["parser_config"] task_parser_config = task["parser_config"]
task_start_ts = timer()
# prepare the progress callback function # prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@ -505,11 +495,7 @@ def do_handle_task(task):
progress_callback(-1, msg=error_message) progress_callback(-1, msg=error_message)
raise Exception(error_message) raise Exception(error_message)
try: task_canceled = TaskService.do_cancel(task_id)
task_canceled = TaskService.do_cancel(task_id)
except DoesNotExist:
logging.warning(f"task {task_id} is unknown")
return
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return return
@ -529,71 +515,41 @@ def do_handle_task(task):
# Either using RAPTOR or Standard chunking methods # Either using RAPTOR or Standard chunking methods
if task.get("task_type", "") == "raptor": if task.get("task_type", "") == "raptor":
try: # bind LLM for raptor
# bind LLM for raptor chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) # run RAPTOR
# run RAPTOR chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
# Either using graphrag or Standard chunking methods # Either using graphrag or Standard chunking methods
elif task.get("task_type", "") == "graphrag": elif task.get("task_type", "") == "graphrag":
start_ts = timer() start_ts = timer()
try: chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
run_graphrag(task, chat_model, task_language, embedding_model, progress_callback) progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
return return
elif task.get("task_type", "") == "graph_resolution": elif task.get("task_type", "") == "graph_resolution":
start_ts = timer() start_ts = timer()
try: chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) with_res = WithResolution(
WithResolution( task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model, progress_callback
progress_callback )
) await with_res()
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts)) progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
return return
elif task.get("task_type", "") == "graph_community": elif task.get("task_type", "") == "graph_community":
start_ts = timer() start_ts = timer()
try: chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) with_comm = WithCommunity(
WithCommunity( task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model, progress_callback
progress_callback )
) await with_comm()
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts)) progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
return return
else: else:
# Standard chunking methods # Standard chunking methods
start_ts = timer() start_ts = timer()
chunks = build_chunks(task, progress_callback) chunks = await build_chunks(task, progress_callback)
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts)) logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
if chunks is None: if chunks is None:
return return
@ -605,7 +561,7 @@ def do_handle_task(task):
progress_callback(msg="Generate {} chunks".format(len(chunks))) progress_callback(msg="Generate {} chunks".format(len(chunks)))
start_ts = timer() start_ts = timer()
try: try:
token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback) token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
except Exception as e: except Exception as e:
error_message = "Generate embedding error:{}".format(str(e)) error_message = "Generate embedding error:{}".format(str(e))
progress_callback(-1, error_message) progress_callback(-1, error_message)
@ -621,8 +577,7 @@ def do_handle_task(task):
doc_store_result = "" doc_store_result = ""
es_bulk_size = 4 es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size): for b in range(0, len(chunks), es_bulk_size):
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id))
task_dataset_id)
if b % 128 == 0: if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result: if doc_store_result:
@ -635,8 +590,7 @@ def do_handle_task(task):
TaskService.update_chunk_ids(task["id"], chunk_ids_str) TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist: except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.") logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
task_dataset_id)
return return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks), task_to_page, len(chunks),
@ -645,51 +599,39 @@ def do_handle_task(task):
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
time_cost = timer() - start_ts time_cost = timer() - start_ts
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost)) task_time_cost = timer() - task_start_ts
progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
logging.info( logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks), task_to_page, len(chunks),
token_count, time_cost)) token_count, task_time_cost))
def handle_task(): async def handle_task():
global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK global DONE_TASKS, FAILED_TASKS
task = collect() redis_msg, task = await collect()
if task: if not task:
return
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
await do_handle_task(task)
DONE_TASKS += 1
CURRENT_TASKS.pop(task["id"], None)
logging.info(f"handle_task done for task {json.dumps(task)}")
except Exception as e:
FAILED_TASKS += 1
CURRENT_TASKS.pop(task["id"], None)
try: try:
logging.info(f"handle_task begin for task {json.dumps(task)}") set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
with mt_lock: except Exception:
CURRENT_TASK = copy.deepcopy(task) pass
do_handle_task(task) logging.exception(f"handle_task got exception for task {json.dumps(task)}")
with mt_lock: redis_msg.ack()
DONE_TASKS += 1
CURRENT_TASK = None
logging.info(f"handle_task done for task {json.dumps(task)}")
except TaskCanceledException:
with mt_lock:
DONE_TASKS += 1
CURRENT_TASK = None
try:
set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
except Exception:
pass
logging.debug("handle_task got TaskCanceledException", exc_info=True)
except Exception as e:
with mt_lock:
FAILED_TASKS += 1
CURRENT_TASK = None
try:
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
def report_status(): async def report_status():
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME) REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
while True: while True:
try: try:
@ -699,17 +641,17 @@ def report_status():
PENDING_TASKS = int(group_info.get("pending", 0)) PENDING_TASKS = int(group_info.get("pending", 0))
LAG_TASKS = int(group_info.get("lag", 0)) LAG_TASKS = int(group_info.get("lag", 0))
with mt_lock: current = copy.deepcopy(CURRENT_TASKS)
heartbeat = json.dumps({ heartbeat = json.dumps({
"name": CONSUMER_NAME, "name": CONSUMER_NAME,
"now": now.astimezone().isoformat(timespec="milliseconds"), "now": now.astimezone().isoformat(timespec="milliseconds"),
"boot_at": BOOT_AT, "boot_at": BOOT_AT,
"pending": PENDING_TASKS, "pending": PENDING_TASKS,
"lag": LAG_TASKS, "lag": LAG_TASKS,
"done": DONE_TASKS, "done": DONE_TASKS,
"failed": FAILED_TASKS, "failed": FAILED_TASKS,
"current": CURRENT_TASK, "current": current,
}) })
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp()) REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}") logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
@ -718,27 +660,10 @@ def report_status():
REDIS_CONN.zpopmin(CONSUMER_NAME, expired) REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
except Exception: except Exception:
logging.exception("report_status got exception") logging.exception("report_status got exception")
time.sleep(30) await trio.sleep(30)
def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool): async def main():
msg = ""
if dump_full:
stats2 = snapshot2.statistics('lineno')
msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
for stat in stats2[:10]:
msg += f"{stat}\n"
stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
for stat in stats1_vs_2[:10]:
msg += f"{stat}\n"
msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
for stat in stats1_vs_2[:3]:
msg += '\n'.join(stat.traceback.format())
logging.info(msg)
def main():
logging.info(r""" logging.info(r"""
______ __ ______ __ ______ __ ______ __
/_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____ /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
@ -755,33 +680,12 @@ def main():
if TRACE_MALLOC_ENABLED: if TRACE_MALLOC_ENABLED:
start_tracemalloc_and_snapshot(None, None) start_tracemalloc_and_snapshot(None, None)
# Create an event to signal the background thread to exit async with trio.open_nursery() as nursery:
stop_event = threading.Event() nursery.start_soon(report_status)
while True:
background_thread = threading.Thread(target=report_status) async with task_limiter:
background_thread.daemon = True nursery.start_soon(handle_task)
background_thread.start() logging.error("BUG!!! You should not reach here!!!")
# Handle SIGINT (Ctrl+C)
def signal_handler(sig, frame):
logging.info("Received Ctrl+C, shutting down gracefully...")
stop_event.set()
# Give the background thread time to clean up
if background_thread.is_alive():
background_thread.join(timeout=5)
logging.info("Exiting...")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
try:
while not stop_event.is_set():
handle_task()
except KeyboardInterrupt:
logging.info("Interrupted by keyboard, shutting down...")
stop_event.set()
if background_thread.is_alive():
background_thread.join(timeout=5)
if __name__ == "__main__": if __name__ == "__main__":
main() trio.run(main)
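task_limiter caps how many tasks are processed at once. A minimal sketch of how such a cap behaves when each worker holds the limiter for its whole run (handle_one and the range of ids are made up):

import trio

task_limiter = trio.CapacityLimiter(5)     # mirrors MAX_CONCURRENT_TASKS

async def handle_one(task_id):
    async with task_limiter:               # held until the task is finished
        await trio.sleep(0.1)              # stand-in for do_handle_task()
        print("done", task_id)

async def main():
    async with trio.open_nursery() as nursery:
        for task_id in range(20):
            nursery.start_soon(handle_one, task_id)   # at most 5 run concurrently

if __name__ == "__main__":
    trio.run(main)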

View File

@ -24,7 +24,7 @@ from rag import settings
from rag.utils import singleton from rag.utils import singleton
class Payload: class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message): def __init__(self, consumer, queue_name, group_name, msg_id, message):
self.__consumer = consumer self.__consumer = consumer
self.__queue_name = queue_name self.__queue_name = queue_name
@ -43,6 +43,9 @@ class Payload:
def get_message(self): def get_message(self):
return self.__message return self.__message
def get_msg_id(self):
return self.__msg_id
@singleton @singleton
class RedisDB: class RedisDB:
@ -206,9 +209,8 @@ class RedisDB:
) )
return False return False
def queue_consumer( def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
self, queue_name, group_name, consumer_name, msg_id=b">" """https://redis.io/docs/latest/commands/xreadgroup/"""
) -> Payload:
try: try:
group_info = self.REDIS.xinfo_groups(queue_name) group_info = self.REDIS.xinfo_groups(queue_name)
if not any(e["name"] == group_name for e in group_info): if not any(e["name"] == group_name for e in group_info):
@ -217,15 +219,17 @@ class RedisDB:
"groupname": group_name, "groupname": group_name,
"consumername": consumer_name, "consumername": consumer_name,
"count": 1, "count": 1,
"block": 10000, "block": 5,
"streams": {queue_name: msg_id}, "streams": {queue_name: msg_id},
} }
messages = self.REDIS.xreadgroup(**args) messages = self.REDIS.xreadgroup(**args)
if not messages: if not messages:
return None return None
stream, element_list = messages[0] stream, element_list = messages[0]
if not element_list:
return None
msg_id, payload = element_list[0] msg_id, payload = element_list[0]
res = Payload(self.REDIS, queue_name, group_name, msg_id, payload) res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
return res return res
except Exception as e: except Exception as e:
if "key" in str(e): if "key" in str(e):
@ -239,30 +243,24 @@ class RedisDB:
) )
return None return None
def get_unacked_for(self, consumer_name, queue_name, group_name): def get_unacked_iterator(self, queue_name, group_name, consumer_name):
try: try:
group_info = self.REDIS.xinfo_groups(queue_name) group_info = self.REDIS.xinfo_groups(queue_name)
if not any(e["name"] == group_name for e in group_info): if not any(e["name"] == group_name for e in group_info):
return return
pendings = self.REDIS.xpending_range( current_min = 0
queue_name, while True:
group_name, payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
min=0, if not payload:
max=10000000000000, return
count=1, current_min = payload.get_msg_id()
consumername=consumer_name, logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
) yield payload
if not pendings:
return
msg_id = pendings[0]["message_id"]
msg = self.REDIS.xrange(queue_name, min=msg_id, count=1)
_, payload = msg[0]
return Payload(self.REDIS, queue_name, group_name, msg_id, payload)
except Exception as e: except Exception as e:
if "key" in str(e): if "key" in str(e):
return return
logging.exception( logging.exception(
"RedisDB.get_unacked_for " + consumer_name + " got exception: " + str(e) "RedisDB.get_unacked_iterator " + consumer_name + " got exception: "
) )
self.__open__() self.__open__()
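get_unacked_iterator lets a restarted consumer replay messages it read but never acknowledged before falling back to new messages, which is what collect() relies on. The flow, reduced to plain iterators (unacked and the fetch callable are hypothetical stand-ins for the Redis calls):

def collect(unacked_iterator, fetch_new):
    try:
        return next(unacked_iterator)   # replay work left over from a crash
    except StopIteration:
        return fetch_new()              # otherwise ask the stream for a new message

unacked = iter(["old-1", "old-2"])      # messages delivered but never acked
new_messages = iter(["new-1", "new-2"])

for _ in range(4):
    print(collect(unacked, lambda: next(new_messages, None)))
# old-1, old-2, new-1, new-2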

52
uv.lock generated
View File

@ -1,4 +1,5 @@
version = 1 version = 1
revision = 1
requires-python = ">=3.10, <3.13" requires-python = ">=3.10, <3.13"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'darwin'", "python_full_version >= '3.12' and sys_platform == 'darwin'",
@ -1083,9 +1084,6 @@ name = "datrie"
version = "0.8.2" version = "0.8.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" } source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/44/02/53f0cf0bf0cd629ba6c2cc13f2f9db24323459e9c19463783d890a540a96/datrie-0.8.2-pp273-pypy_73-win32.whl", hash = "sha256:b07bd5fdfc3399a6dab86d6e35c72b1dbd598e80c97509c7c7518ab8774d3fda" },
]
[[package]] [[package]]
name = "decorator" name = "decorator"
@ -1362,17 +1360,17 @@ name = "fastembed-gpu"
version = "0.3.6" version = "0.3.6"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" } source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [ dependencies = [
{ name = "huggingface-hub" }, { name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "loguru" }, { name = "loguru", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "mmh3" }, { name = "mmh3", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "numpy" }, { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "onnxruntime-gpu" }, { name = "onnxruntime-gpu", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "pillow" }, { name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "pystemmer" }, { name = "pystemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "requests" }, { name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "snowballstemmer" }, { name = "snowballstemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "tokenizers" }, { name = "tokenizers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "tqdm" }, { name = "tqdm", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
] ]
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" }
wheels = [ wheels = [
@ -3485,12 +3483,12 @@ name = "onnxruntime-gpu"
version = "1.19.2" version = "1.19.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" } source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [ dependencies = [
{ name = "coloredlogs" }, { name = "coloredlogs", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "flatbuffers" }, { name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "numpy" }, { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "packaging" }, { name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "protobuf" }, { name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "sympy" }, { name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
] ]
wheels = [ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" }, { url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" },
@ -4164,15 +4162,6 @@ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd" }, { url = "https://mirrors.aliyun.com/pypi/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd" },
] ]
[[package]]
name = "pybind11"
version = "2.13.6"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d2/c1/72b9622fcb32ff98b054f724e213c7f70d6898baa714f4516288456ceaba/pybind11-2.13.6.tar.gz", hash = "sha256:ba6af10348c12b24e92fa086b39cfba0eff619b61ac77c406167d813b096d39a" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/13/2f/0f24b288e2ce56f51c920137620b4434a38fd80583dbbe24fc2a1656c388/pybind11-2.13.6-py3-none-any.whl", hash = "sha256:237c41e29157b962835d356b370ededd57594a26d5894a795960f0047cb5caf5" },
]
[[package]] [[package]]
name = "pyclipper" name = "pyclipper"
version = "1.3.0.post5" version = "1.3.0.post5"
@ -4230,8 +4219,6 @@ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/48/7d/0f2b09490b98cc6a902ac15dda8760c568b9c18cfe70e0ef7a16de64d53a/pycryptodomex-3.20.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7a7a8f33a1f1fb762ede6cc9cbab8f2a9ba13b196bfaf7bc6f0b39d2ba315a43" }, { url = "https://mirrors.aliyun.com/pypi/packages/48/7d/0f2b09490b98cc6a902ac15dda8760c568b9c18cfe70e0ef7a16de64d53a/pycryptodomex-3.20.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7a7a8f33a1f1fb762ede6cc9cbab8f2a9ba13b196bfaf7bc6f0b39d2ba315a43" },
{ url = "https://mirrors.aliyun.com/pypi/packages/b0/1c/375adb14b71ee1c8d8232904e928b3e7af5bbbca7c04e4bec94fe8e90c3d/pycryptodomex-3.20.0-cp35-abi3-win32.whl", hash = "sha256:c39778fd0548d78917b61f03c1fa8bfda6cfcf98c767decf360945fe6f97461e" }, { url = "https://mirrors.aliyun.com/pypi/packages/b0/1c/375adb14b71ee1c8d8232904e928b3e7af5bbbca7c04e4bec94fe8e90c3d/pycryptodomex-3.20.0-cp35-abi3-win32.whl", hash = "sha256:c39778fd0548d78917b61f03c1fa8bfda6cfcf98c767decf360945fe6f97461e" },
{ url = "https://mirrors.aliyun.com/pypi/packages/b2/e8/1b92184ab7e5595bf38000587e6f8cf9556ebd1bf0a583619bee2057afbd/pycryptodomex-3.20.0-cp35-abi3-win_amd64.whl", hash = "sha256:2a47bcc478741b71273b917232f521fd5704ab4b25d301669879e7273d3586cc" }, { url = "https://mirrors.aliyun.com/pypi/packages/b2/e8/1b92184ab7e5595bf38000587e6f8cf9556ebd1bf0a583619bee2057afbd/pycryptodomex-3.20.0-cp35-abi3-win_amd64.whl", hash = "sha256:2a47bcc478741b71273b917232f521fd5704ab4b25d301669879e7273d3586cc" },
{ url = "https://mirrors.aliyun.com/pypi/packages/e7/c5/9140bb867141d948c8e242013ec8a8011172233c898dfdba0a2417c3169a/pycryptodomex-3.20.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:1be97461c439a6af4fe1cf8bf6ca5936d3db252737d2f379cc6b2e394e12a458" },
{ url = "https://mirrors.aliyun.com/pypi/packages/5e/6a/04acb4978ce08ab16890c70611ebc6efd251681341617bbb9e53356dee70/pycryptodomex-3.20.0-pp27-pypy_73-win32.whl", hash = "sha256:19764605feea0df966445d46533729b645033f134baeb3ea26ad518c9fdf212c" },
{ url = "https://mirrors.aliyun.com/pypi/packages/eb/df/3f1ea084e43b91e6d2b6b3493cc948864c17ea5d93ff1261a03812fbfd1a/pycryptodomex-3.20.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e497413560e03421484189a6b65e33fe800d3bd75590e6d78d4dfdb7accf3b" }, { url = "https://mirrors.aliyun.com/pypi/packages/eb/df/3f1ea084e43b91e6d2b6b3493cc948864c17ea5d93ff1261a03812fbfd1a/pycryptodomex-3.20.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e497413560e03421484189a6b65e33fe800d3bd75590e6d78d4dfdb7accf3b" },
{ url = "https://mirrors.aliyun.com/pypi/packages/c9/f3/83ffbdfa0c8f9154bcd8866895f6cae5a3ec749da8b0840603cf936c4412/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48217c7901edd95f9f097feaa0388da215ed14ce2ece803d3f300b4e694abea" }, { url = "https://mirrors.aliyun.com/pypi/packages/c9/f3/83ffbdfa0c8f9154bcd8866895f6cae5a3ec749da8b0840603cf936c4412/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48217c7901edd95f9f097feaa0388da215ed14ce2ece803d3f300b4e694abea" },
{ url = "https://mirrors.aliyun.com/pypi/packages/c9/9d/c113e640aaf02af5631ae2686b742aac5cd0e1402b9d6512b1c7ec5ef05d/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d00fe8596e1cc46b44bf3907354e9377aa030ec4cd04afbbf6e899fc1e2a7781" }, { url = "https://mirrors.aliyun.com/pypi/packages/c9/9d/c113e640aaf02af5631ae2686b742aac5cd0e1402b9d6512b1c7ec5ef05d/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d00fe8596e1cc46b44bf3907354e9377aa030ec4cd04afbbf6e899fc1e2a7781" },
@ -4820,6 +4807,7 @@ dependencies = [
{ name = "tencentcloud-sdk-python" }, { name = "tencentcloud-sdk-python" },
{ name = "tika" }, { name = "tika" },
{ name = "tiktoken" }, { name = "tiktoken" },
{ name = "trio" },
{ name = "umap-learn" }, { name = "umap-learn" },
{ name = "valkey" }, { name = "valkey" },
{ name = "vertexai" }, { name = "vertexai" },
@ -4954,6 +4942,7 @@ requires-dist = [
{ name = "tiktoken", specifier = "==0.7.0" }, { name = "tiktoken", specifier = "==0.7.0" },
{ name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" }, { name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" },
{ name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" }, { name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" },
{ name = "trio", specifier = ">=0.29.0" },
{ name = "umap-learn", specifier = "==0.5.6" }, { name = "umap-learn", specifier = "==0.5.6" },
{ name = "valkey", specifier = "==6.0.2" }, { name = "valkey", specifier = "==6.0.2" },
{ name = "vertexai", specifier = "==1.64.0" }, { name = "vertexai", specifier = "==1.64.0" },
@ -4969,6 +4958,7 @@ requires-dist = [
{ name = "yfinance", specifier = "==0.1.96" }, { name = "yfinance", specifier = "==0.1.96" },
{ name = "zhipuai", specifier = "==2.0.1" }, { name = "zhipuai", specifier = "==2.0.1" },
] ]
provides-extras = ["full"]
[[package]] [[package]]
name = "ranx" name = "ranx"