Made task_executor async to speedup parsing (#5530)

### What problem does this PR solve?

Made task_executor async to speedup parsing

### Type of change

- [x] Performance Improvement
This commit is contained in:
Zhichang Yu 2025-03-03 18:59:49 +08:00 committed by GitHub
parent abac2ca2c5
commit c813c1ff4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 576 additions and 1005 deletions

View File

@ -17,6 +17,7 @@ import json
import re
import traceback
from copy import deepcopy
import trio
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
@ -386,7 +387,8 @@ def mindmap():
rank_feature=label_question(question, [kb])
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
mind_map = mind_map.output
if "error" in mind_map:
return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map)

View File

@ -22,6 +22,7 @@ from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO
import trio
from peewee import fn
@ -597,8 +598,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if parser_ids[doc_id] != ParserType.PICTURE.value:
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
ensure_ascii=False, indent=2)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append({

View File

@ -17,6 +17,8 @@ import base64
import json
import os
import re
import sys
import threading
from io import BytesIO
import pdfplumber
@ -30,6 +32,10 @@ from api.constants import IMG_BASE64_PREFIX
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
RAG_BASE = os.getenv("RAG_BASE")
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
def get_project_base_directory(*args):
global PROJECT_BASE
@ -175,19 +181,20 @@ def thumbnail_img(filename, blob):
"""
filename = filename.lower()
if re.match(r".*\.pdf$", filename):
pdf = pdfplumber.open(BytesIO(blob))
buffered = BytesIO()
resolution = 32
img = None
for _ in range(10):
# https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
img = buffered.getvalue()
if len(img) >= 64000 and resolution >= 2:
resolution = resolution / 2
buffered = BytesIO()
else:
break
with sys.modules[LOCK_KEY_pdfplumber]:
pdf = pdfplumber.open(BytesIO(blob))
buffered = BytesIO()
resolution = 32
img = None
for _ in range(10):
# https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
img = buffered.getvalue()
if len(img) >= 64000 and resolution >= 2:
resolution = resolution / 2
buffered = BytesIO()
else:
break
pdf.close()
return img

View File

@ -18,6 +18,8 @@ import os.path
import logging
from logging.handlers import RotatingFileHandler
initialized_root_logger = False
def get_project_base_directory():
PROJECT_BASE = os.path.abspath(
os.path.join(
@ -29,10 +31,13 @@ def get_project_base_directory():
return PROJECT_BASE
def initRootLogger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
logger = logging.getLogger()
if logger.hasHandlers():
global initialized_root_logger
if initialized_root_logger:
return
initialized_root_logger = True
logger = logging.getLogger()
logger.handlers.clear()
log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{logfile_basename}.log"))
os.makedirs(os.path.dirname(log_path), exist_ok=True)

View File

@ -18,6 +18,8 @@ import logging
import os
import random
from timeit import default_timer as timer
import sys
import threading
import xgboost as xgb
from io import BytesIO
@ -34,6 +36,10 @@ from rag.nlp import rag_tokenizer
from copy import deepcopy
from huggingface_hub import snapshot_download
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
class RAGFlowPdfParser:
def __init__(self):
self.ocr = OCR()
@ -948,8 +954,9 @@ class RAGFlowPdfParser:
@staticmethod
def total_page_number(fnm, binary=None):
try:
pdf = pdfplumber.open(
fnm) if not binary else pdfplumber.open(BytesIO(binary))
with sys.modules[LOCK_KEY_pdfplumber]:
pdf = pdfplumber.open(
fnm) if not binary else pdfplumber.open(BytesIO(binary))
total_page = len(pdf.pages)
pdf.close()
return total_page
@ -968,17 +975,18 @@ class RAGFlowPdfParser:
self.page_from = page_from
start = timer()
try:
self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm))
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(self.pdf.pages[page_from:page_to])]
try:
self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
except Exception as e:
logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead.
self.total_page = len(self.pdf.pages)
with sys.modules[LOCK_KEY_pdfplumber]:
self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm))
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(self.pdf.pages[page_from:page_to])]
try:
self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
except Exception as e:
logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead.
self.total_page = len(self.pdf.pages)
except Exception:
logging.exception("RAGFlowPdfParser __images__")
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")

View File

@ -14,7 +14,8 @@
# limitations under the License.
#
import io
import sys
import threading
import pdfplumber
from .ocr import OCR
@ -23,6 +24,11 @@ from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
def init_in_out(args):
from PIL import Image
import os
@ -36,9 +42,10 @@ def init_in_out(args):
def pdf_pages(fnm, zoomin=3):
nonlocal outputs, images
pdf = pdfplumber.open(fnm)
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(pdf.pages)]
with sys.modules[LOCK_KEY_pdfplumber]:
pdf = pdfplumber.open(fnm)
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(pdf.pages)]
for i, page in enumerate(images):
outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import itertools
import re
import time
@ -21,13 +20,14 @@ from dataclasses import dataclass
from typing import Any, Callable
import networkx as nx
import trio
from graphrag.general.extractor import Extractor
from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements
from graphrag.utils import perform_variable_replacements, chat_limiter
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@ -67,13 +67,13 @@ class EntityResolution(Extractor):
self._resolution_result_delimiter_key = "resolution_result_delimiter"
self._input_text_key = "input_text"
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
# Wire defaults into the prompt variables
prompt_variables = {
self.prompt_variables = {
**prompt_variables,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
@ -94,39 +94,12 @@ class EntityResolution(Extractor):
for k, v in node_clusters.items():
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
gen_conf = {"temperature": 0.5}
resolution_result = set()
for candidate_resolution_i in candidate_resolution.items():
if candidate_resolution_i[1]:
try:
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
for index, candidate in enumerate(candidate_resolution_i[1]):
pair_txt.append(
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
pair_txt.append(
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
pair_prompt = '\n'.join(pair_txt)
variables = {
**prompt_variables,
self._input_text_key: pair_prompt
}
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
result = self._process_results(len(candidate_resolution_i[1]), response,
prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),
prompt_variables.get(self._entity_index_dilimiter_key,
DEFAULT_ENTITY_INDEX_DELIMITER),
prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
except Exception:
logging.exception("error entity resolution")
async with trio.open_nursery() as nursery:
for candidate_resolution_i in candidate_resolution.items():
if not candidate_resolution_i[1]:
continue
nursery.start_soon(self._resolve_candidate(candidate_resolution_i, resolution_result))
connect_graph = nx.Graph()
removed_entities = []
@ -172,6 +145,34 @@ class EntityResolution(Extractor):
removed_entities=removed_entities
)
async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
gen_conf = {"temperature": 0.5}
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
for index, candidate in enumerate(candidate_resolution_i[1]):
pair_txt.append(
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
pair_txt.append(
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
pair_prompt = '\n'.join(pair_txt)
variables = {
**self.prompt_variables,
self._input_text_key: pair_prompt
}
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
result = self._process_results(len(candidate_resolution_i[1]), response,
self.prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),
self.prompt_variables.get(self._entity_index_dilimiter_key,
DEFAULT_ENTITY_INDEX_DELIMITER),
self.prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
def _process_results(
self,
records_length: int,

View File

@ -1,268 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import argparse
import json
import re
import traceback
from dataclasses import dataclass
from typing import Any
import tiktoken
from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.general.extractor import Extractor
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
CLAIM_MAX_GLEANINGS = 1
@dataclass
class ClaimExtractorResult:
"""Claim extractor result class definition."""
output: list[dict]
source_docs: dict[str, Any]
class ClaimExtractor(Extractor):
"""Claim extractor class definition."""
_extraction_prompt: str
_summary_prompt: str
_output_formatter_prompt: str
_input_text_key: str
_input_entity_spec_key: str
_input_claim_description_key: str
_tuple_delimiter_key: str
_record_delimiter_key: str
_completion_delimiter_key: str
_max_gleanings: int
_on_error: ErrorHandlerFn
def __init__(
self,
llm_invoker: CompletionLLM,
extraction_prompt: str | None = None,
input_text_key: str | None = None,
input_entity_spec_key: str | None = None,
input_claim_description_key: str | None = None,
input_resolved_entities_key: str | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
completion_delimiter_key: str | None = None,
encoding_model: str | None = None,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
self._llm = llm_invoker
self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
self._input_text_key = input_text_key or "input_text"
self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._input_claim_description_key = (
input_claim_description_key or "claim_description"
)
self._input_resolved_entities_key = (
input_resolved_entities_key or "resolved_entities"
)
self._max_gleanings = (
max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)
# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
def __call__(
self, inputs: dict[str, Any], prompt_variables: dict | None = None
) -> ClaimExtractorResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
texts = inputs[self._input_text_key]
entity_spec = str(inputs[self._input_entity_spec_key])
claim_description = inputs[self._input_claim_description_key]
resolved_entities = inputs.get(self._input_resolved_entities_key, {})
source_doc_map = {}
prompt_args = {
self._input_entity_spec_key: entity_spec,
self._input_claim_description_key: claim_description,
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
or DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: prompt_variables.get(
self._completion_delimiter_key
)
or DEFAULT_COMPLETION_DELIMITER,
}
all_claims: list[dict] = []
for doc_index, text in enumerate(texts):
document_id = f"d{doc_index}"
try:
claims = self._process_document(prompt_args, text, doc_index)
all_claims += [
self._clean_claim(c, document_id, resolved_entities) for c in claims
]
source_doc_map[document_id] = text
except Exception as e:
logging.exception("error extracting claim")
self._on_error(
e,
traceback.format_exc(),
{"doc_index": doc_index, "text": text},
)
continue
return ClaimExtractorResult(
output=all_claims,
source_docs=source_doc_map,
)
def _clean_claim(
self, claim: dict, document_id: str, resolved_entities: dict
) -> dict:
# clean the parsed claims to remove any claims with status = False
obj = claim.get("object_id", claim.get("object"))
subject = claim.get("subject_id", claim.get("subject"))
# If subject or object in resolved entities, then replace with resolved entity
obj = resolved_entities.get(obj, obj)
subject = resolved_entities.get(subject, subject)
claim["object_id"] = obj
claim["subject_id"] = subject
claim["doc_id"] = document_id
return claim
def _process_document(
self, prompt_args: dict, doc, doc_index: int
) -> list[dict]:
record_delimiter = prompt_args.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_args.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
variables = {
self._input_text_key: doc,
**prompt_args,
}
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
claims = results.strip().removesuffix(completion_delimiter)
history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
extension = self._chat("", history, gen_conf)
claims += record_delimiter + extension.strip().removesuffix(
completion_delimiter
)
# If this isn't the last loop, check to see if we should continue
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": extension})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._chat("", history, self._loop_args)
if continuation != "YES":
break
result = self._parse_claim_tuples(claims, prompt_args)
for r in result:
r["doc_id"] = f"{doc_index}"
return result
def _parse_claim_tuples(
self, claims: str, prompt_variables: dict
) -> list[dict[str, Any]]:
"""Parse claim tuples."""
record_delimiter = prompt_variables.get(
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
)
completion_delimiter = prompt_variables.get(
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
tuple_delimiter = prompt_variables.get(
self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
)
def pull_field(index: int, fields: list[str]) -> str | None:
return fields[index].strip() if len(fields) > index else None
result: list[dict[str, Any]] = []
claims_values = (
claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
)
for claim in claims_values:
claim = claim.strip().removeprefix("(").removesuffix(")")
claim = re.sub(r".*Output:", "", claim)
# Ignore the completion delimiter
if claim == completion_delimiter:
continue
claim_fields = claim.split(tuple_delimiter)
o = {
"subject_id": pull_field(0, claim_fields),
"object_id": pull_field(1, claim_fields),
"type": pull_field(2, claim_fields),
"status": pull_field(3, claim_fields),
"start_date": pull_field(4, claim_fields),
"end_date": pull_field(5, claim_fields),
"description": pull_field(6, claim_fields),
"source_text": pull_field(7, claim_fields),
"doc_id": pull_field(8, claim_fields),
}
if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]):
continue
result.append(o)
return result
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
args = parser.parse_args()
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
info = {
"input_text": docs,
"entity_specs": "organization, person",
"claim_description": ""
}
claim = ex(info)
logging.info(json.dumps(claim.output, ensure_ascii=False, indent=2))

View File

@ -1,71 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
CLAIM_EXTRACTION_PROMPT = """
################
-Target activity-
################
You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
################
-Goal-
################
Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
################
-Steps-
################
- 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
- 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
For each claim, extract the following information:
- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
- 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
- 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
- 5. If there's nothing satisfy the above requirements, just keep output empty.
- 6. When finished, output {completion_delimiter}
################
-Examples-
################
Example 1:
Entity specification: organization
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{completion_delimiter}
###########################
Example 2:
Entity specification: Company A, Person C
Claim description: red flags associated with an entity
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
Output:
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
{record_delimiter}
(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
{completion_delimiter}
################
-Real Data-
################
Use the following input for your answer.
Entity specification: {entity_specs}
Claim description: {claim_description}
Text: {input_text}
Output:"""
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: "
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n"

View File

@ -17,9 +17,10 @@ from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
from graphrag.general.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer
import trio
@dataclass
@ -52,7 +53,7 @@ class CommunityReportsExtractor(Extractor):
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
self._max_report_length = max_report_length or 1500
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
@ -86,28 +87,25 @@ class CommunityReportsExtractor(Extractor):
}
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
gen_conf = {"temperature": 0.3}
try:
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
response = re.sub(r"\{\{", "{", response)
response = re.sub(r"\}\}", "}", response)
logging.debug(response)
response = json.loads(response)
if not dict_has_keys_with_types(response, [
("title", str),
("summary", str),
("findings", list),
("rating", float),
("rating_explanation", str),
]):
continue
response["weight"] = weight
response["entities"] = ents
except Exception:
logging.exception("CommunityReportsExtractor got exception")
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
response = re.sub(r"\{\{", "{", response)
response = re.sub(r"\}\}", "}", response)
logging.debug(response)
response = json.loads(response)
if not dict_has_keys_with_types(response, [
("title", str),
("summary", str),
("findings", list),
("rating", float),
("rating_explanation", str),
]):
continue
response["weight"] = weight
response["entities"] = ents
add_community_info2graph(graph, ents, response["title"])
res_str.append(self._get_text_output(response))

View File

@ -14,16 +14,15 @@
# limitations under the License.
#
import logging
import os
import re
from collections import defaultdict, Counter
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Callable
import trio
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import truncate
@ -91,54 +90,50 @@ class Extractor:
)
return dict(maybe_nodes), dict(maybe_edges)
def __call__(
async def __call__(
self, chunks: list[tuple[str, str]],
callback: Callable | None = None
):
results = []
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))
with ThreadPoolExecutor(max_workers=max_workers) as exe:
threads = []
self.callback = callback
start_ts = trio.current_time()
out_results = []
async with trio.open_nursery() as nursery:
for i, (cid, ck) in enumerate(chunks):
ck = truncate(ck, int(self._llm.max_length*0.8))
threads.append(
exe.submit(self._process_single_content, (cid, ck)))
for i, _ in enumerate(threads):
n, r, tc = _.result()
if not isinstance(n, Exception):
results.append((n, r))
if callback:
callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
elif callback:
callback(msg="Knowledge graph extraction error:{}".format(str(n)))
nursery.start_soon(self._process_single_content, (cid, ck), i, len(chunks), out_results)
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for m_nodes, m_edges in results:
sum_token_count = 0
for m_nodes, m_edges, token_count in out_results:
for k, v in m_nodes.items():
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
logging.info("Inserting entities into storage...")
sum_token_count += token_count
now = trio.current_time()
if callback:
callback(msg = f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now-start_ts:.2f}s.")
start_ts = now
logging.info("Entities merging...")
all_entities_data = []
with ThreadPoolExecutor(max_workers=max_workers) as exe:
threads = []
async with trio.open_nursery() as nursery:
for en_nm, ents in maybe_nodes.items():
threads.append(
exe.submit(self._merge_nodes, en_nm, ents))
for t in threads:
n = t.result()
if not isinstance(n, Exception):
all_entities_data.append(n)
elif callback:
callback(msg="Knowledge graph nodes merging error: {}".format(str(n)))
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
now = trio.current_time()
if callback:
callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
logging.info("Inserting relationships into storage...")
start_ts = now
logging.info("Relationships merging...")
all_relationships_data = []
for (src, tgt), rels in maybe_edges.items():
all_relationships_data.append(self._merge_edges(src, tgt, rels))
async with trio.open_nursery() as nursery:
for (src, tgt), rels in maybe_edges.items():
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
now = trio.current_time()
if callback:
callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
if not len(all_entities_data) and not len(all_relationships_data):
logging.warning(
@ -152,7 +147,7 @@ class Extractor:
return all_entities_data, all_relationships_data
def _merge_nodes(self, entity_name: str, entities: list[dict]):
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
if not entities:
return
already_entity_types = []
@ -176,26 +171,22 @@ class Extractor:
sorted(set([dp["description"] for dp in entities] + already_description))
)
already_source_ids = flat_uniq_list(entities, "source_id")
try:
description = self._handle_entity_relation_summary(
entity_name, description
)
node_data = dict(
entity_type=entity_type,
description=description,
source_id=already_source_ids,
)
node_data["entity_name"] = entity_name
self._set_entity_(entity_name, node_data)
return node_data
except Exception as e:
return e
description = await self._handle_entity_relation_summary(entity_name, description)
node_data = dict(
entity_type=entity_type,
description=description,
source_id=already_source_ids,
)
node_data["entity_name"] = entity_name
self._set_entity_(entity_name, node_data)
all_relationships_data.append(node_data)
def _merge_edges(
async def _merge_edges(
self,
src_id: str,
tgt_id: str,
edges_data: list[dict]
edges_data: list[dict],
all_relationships_data
):
if not edges_data:
return
@ -226,7 +217,7 @@ class Extractor:
"description": description,
"entity_type": 'UNKNOWN'
})
description = self._handle_entity_relation_summary(
description = await self._handle_entity_relation_summary(
f"({src_id}, {tgt_id})", description
)
edge_data = dict(
@ -238,10 +229,9 @@ class Extractor:
source_id=source_id
)
self._set_relation_(src_id, tgt_id, edge_data)
all_relationships_data.append(edge_data)
return edge_data
def _handle_entity_relation_summary(
async def _handle_entity_relation_summary(
self,
entity_or_relation_name: str,
description: str
@ -256,5 +246,6 @@ class Extractor:
)
use_prompt = prompt_template.format(**context_base)
logging.info(f"Trigger summary: {entity_or_relation_name}")
summary = self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})
async with chat_limiter:
summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8}))
return summary

View File

@ -5,15 +5,15 @@ Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import re
from typing import Any, Callable
from dataclasses import dataclass
import tiktoken
import trio
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
@ -102,53 +102,47 @@ class GraphExtractor(Extractor):
self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
}
def _process_single_content(self,
chunk_key_dp: tuple[str, str]
):
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
token_count = 0
chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1]
variables = {
**self._prompt_variables,
self._input_text_key: content,
}
try:
gen_conf = {"temperature": 0.3}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
response = self._chat("", history, gen_conf)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._chat("", history, {"temperature": 0.8})
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "YES":
break
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
records = [r for r in records if r.strip()]
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
return maybe_nodes, maybe_edges, token_count
except Exception as e:
logging.exception("error extracting graph")
return e, None, None
gen_conf = {"temperature": 0.3}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat("", history, gen_conf))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter:
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history, {"temperature": 0.8}))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "YES":
break
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
records = [r for r in records if r.strip()]
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
out_results.append((maybe_nodes, maybe_edges, token_count))
if self.callback:
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")

View File

@ -17,6 +17,7 @@ import json
import logging
from functools import reduce, partial
import networkx as nx
import trio
from api import settings
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@ -41,18 +42,24 @@ class Dealer:
embed_bdl=None,
callback=None
):
docids = list(set([docid for docid,_ in chunks]))
self.tenant_id = tenant_id
self.kb_id = kb_id
self.chunks = chunks
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
ext = extractor(self.llm_bdl, language=language,
self.ext = extractor(self.llm_bdl, language=language,
entity_types=entity_types,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
)
ents, rels = ext(chunks, callback)
self.graph = nx.Graph()
self.callback = callback
async def __call__(self):
docids = list(set([docid for docid, _ in self.chunks]))
ents, rels = await self.ext(self.chunks, self.callback)
for en in ents:
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
@ -64,16 +71,16 @@ class Dealer:
#description=rel["description"]
)
with RedisDistributedLock(kb_id, 60*60):
old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
with RedisDistributedLock(self.kb_id, 60*60):
old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
if old_graph is not None:
logging.info("Merge with an exiting graph...................")
self.graph = reduce(graph_merge, [old_graph, self.graph])
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
if old_doc_ids:
docids.extend(old_doc_ids)
docids = list(set(docids))
set_graph(tenant_id, kb_id, self.graph, docids)
set_graph(self.tenant_id, self.kb_id, self.graph, docids)
class WithResolution(Dealer):
@ -84,47 +91,50 @@ class WithResolution(Dealer):
embed_bdl=None,
callback=None
):
self.tenant_id = tenant_id
self.kb_id = kb_id
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
with RedisDistributedLock(kb_id, 60*60):
self.graph, doc_ids = get_graph(tenant_id, kb_id)
self.callback = callback
async def __call__(self):
with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
if callback:
callback(-1, msg="Faild to fetch the graph.")
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if self.callback:
self.callback(-1, msg="Faild to fetch the graph.")
return
if callback:
callback(msg="Fetch the existing graph.")
if self.callback:
self.callback(msg="Fetch the existing graph.")
er = EntityResolution(self.llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
reso = er(self.graph)
get_entity=partial(get_entity, self.tenant_id, self.kb_id),
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
get_relation=partial(get_relation, self.tenant_id, self.kb_id),
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
reso = await er(self.graph)
self.graph = reso.graph
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
if callback:
callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
set_graph(tenant_id, kb_id, self.graph, doc_ids)
if self.callback:
self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
settings.docStoreConn.delete({
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"kb_id": self.kb_id,
"from_entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
settings.docStoreConn.delete({
}, search.index_name(self.tenant_id), self.kb_id))
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"kb_id": self.kb_id,
"to_entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
settings.docStoreConn.delete({
}, search.index_name(self.tenant_id), self.kb_id))
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "entity",
"kb_id": kb_id,
"kb_id": self.kb_id,
"entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
}, search.index_name(self.tenant_id), self.kb_id))
class WithCommunity(Dealer):
@ -136,38 +146,41 @@ class WithCommunity(Dealer):
callback=None
):
self.tenant_id = tenant_id
self.kb_id = kb_id
self.community_structure = None
self.community_reports = None
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
with RedisDistributedLock(kb_id, 60*60):
self.graph, doc_ids = get_graph(tenant_id, kb_id)
self.callback = callback
async def __call__(self):
with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
if callback:
callback(-1, msg="Faild to fetch the graph.")
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if self.callback:
self.callback(-1, msg="Faild to fetch the graph.")
return
if callback:
callback(msg="Fetch the existing graph.")
if self.callback:
self.callback(msg="Fetch the existing graph.")
cr = CommunityReportsExtractor(self.llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
cr = cr(self.graph, callback=callback)
get_entity=partial(get_entity, self.tenant_id, self.kb_id),
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
get_relation=partial(get_relation, self.tenant_id, self.kb_id),
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
cr = await cr(self.graph, callback=self.callback)
self.community_structure = cr.structured_output
self.community_reports = cr.output
set_graph(tenant_id, kb_id, self.graph, doc_ids)
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
if callback:
callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
if self.callback:
self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
settings.docStoreConn.delete({
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "community_report",
"kb_id": kb_id
}, search.index_name(tenant_id), kb_id)
"kb_id": self.kb_id
}, search.index_name(self.tenant_id), self.kb_id))
for stru, rep in zip(self.community_structure, self.community_reports):
obj = {
@ -183,7 +196,7 @@ class WithCommunity(Dealer):
"weight_flt": stru["weight"],
"entities_kwd": stru["entities"],
"important_kwd": stru["entities"],
"kb_id": kb_id,
"kb_id": self.kb_id,
"source_id": doc_ids,
"available_int": 0
}
@ -193,5 +206,5 @@ class WithCommunity(Dealer):
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
#except Exception as e:
# logging.exception(f"Fail to embed entity relation: {e}")
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id)))

View File

@ -16,16 +16,14 @@
import logging
import collections
import os
import re
import traceback
from typing import Any
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
import trio
from graphrag.general.extractor import Extractor
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM
import markdown_to_json
from functools import reduce
@ -80,63 +78,47 @@ class MindMapExtractor(Extractor):
)
return arr
def __call__(
async def __call__(
self, sections: list[str], prompt_variables: dict[str, Any] | None = None
) -> MindMapResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
try:
res = []
max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
with ThreadPoolExecutor(max_workers=max_workers) as exe:
threads = []
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
texts = []
cnt = 0
for i in range(len(sections)):
section_cnt = num_tokens_from_string(sections[i])
if cnt + section_cnt >= token_count and texts:
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
texts = []
cnt = 0
texts.append(sections[i])
cnt += section_cnt
if texts:
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
for i, _ in enumerate(threads):
res.append(_.result())
if not res:
return MindMapResult(output={"id": "root", "children": []})
merge_json = reduce(self._merge, res)
if len(merge_json) > 1:
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
keyset = set(i for i in keys if i)
merge_json = {
"id": "root",
"children": [
{
"id": self._key(k),
"children": self._be_children(v, keyset)
}
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
]
}
else:
k = self._key(list(merge_json.keys())[0])
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
except Exception as e:
logging.exception("error mind graph")
self._on_error(
e,
traceback.format_exc(), None
)
merge_json = {"error": str(e)}
res = []
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
texts = []
cnt = 0
async with trio.open_nursery() as nursery:
for i in range(len(sections)):
section_cnt = num_tokens_from_string(sections[i])
if cnt + section_cnt >= token_count and texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
texts = []
cnt = 0
texts.append(sections[i])
cnt += section_cnt
if texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
if not res:
return MindMapResult(output={"id": "root", "children": []})
merge_json = reduce(self._merge, res)
if len(merge_json) > 1:
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
keyset = set(i for i in keys if i)
merge_json = {
"id": "root",
"children": [
{
"id": self._key(k),
"children": self._be_children(v, keyset)
}
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
]
}
else:
k = self._key(list(merge_json.keys())[0])
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
return MindMapResult(output=merge_json)
@ -181,8 +163,8 @@ class MindMapExtractor(Extractor):
return self._list_to_kv(to_ret)
def _process_document(
self, text: str, prompt_variables: dict[str, str]
async def _process_document(
self, text: str, prompt_variables: dict[str, str], out_res
) -> str:
variables = {
**prompt_variables,
@ -190,8 +172,9 @@ class MindMapExtractor(Extractor):
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
response = re.sub(r"```[^\n]*", "", response)
logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response)))
return self._todict(markdown_to_json.dictify(response))
out_res.append(self._todict(markdown_to_json.dictify(response)))

View File

@ -18,6 +18,7 @@ import argparse
import json
import networkx as nx
import trio
from api import settings
from api.db import LLMType
@ -54,10 +55,13 @@ if __name__ == "__main__":
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
trio.run(dealer())
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
trio.run(dealer())
dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
trio.run(dealer())
print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))

View File

@ -4,16 +4,16 @@
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import re
from typing import Any, Callable
from dataclasses import dataclass
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
from graphrag.light.graph_prompt import PROMPTS
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
import trio
@dataclass
@ -82,7 +82,7 @@ class GraphExtractor(Extractor):
)
self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)
def _process_single_content(self, chunk_key_dp: tuple[str, str]):
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
token_count = 0
chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1]
@ -90,38 +90,39 @@ class GraphExtractor(Extractor):
**self._context_base, input_text="{input_text}"
).format(**self._context_base, input_text=content)
try:
gen_conf = {"temperature": 0.8}
final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
glean_result = self._chat(hint_prompt, history, gen_conf)
history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
if now_glean_index == self._max_gleanings - 1:
break
gen_conf = {"temperature": 0.8}
async with chat_limiter:
final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
async with chat_limiter:
glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
if now_glean_index == self._max_gleanings - 1:
break
if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
async with chat_limiter:
if_loop_result = await trio.to_thread.run_sync(lambda: self._chat(self._if_loop_prompt, history, gen_conf))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
records = split_string_by_multi_markers(
final_result,
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
)
rcds = []
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
rcds.append(record.group(1))
records = rcds
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
return maybe_nodes, maybe_edges, token_count
except Exception as e:
logging.exception("error extracting graph")
return e, None, None
records = split_string_by_multi_markers(
final_result,
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
)
rcds = []
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
rcds.append(record.group(1))
records = rcds
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
out_results.append((maybe_nodes, maybe_edges, token_count))
if self.callback:
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")

View File

@ -15,6 +15,8 @@ from collections import defaultdict
from copy import deepcopy
from hashlib import md5
from typing import Any, Callable
import os
import trio
import networkx as nx
import numpy as np
@ -28,6 +30,7 @@ from rag.utils.redis_conn import REDIS_CONN
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 100)))
def perform_variable_replacements(
input: str, history: list[dict] | None = None, variables: dict | None = None

View File

@ -122,7 +122,8 @@ dependencies = [
"pyodbc>=5.2.0,<6.0.0",
"pyicu>=2.13.1,<3.0.0",
"flasgger>=0.9.7.1,<0.10.0",
"xxhash>=3.5.0,<4.0.0"
"xxhash>=3.5.0,<4.0.0",
"trio>=0.29.0",
]
[project.optional-dependencies]
@ -133,4 +134,7 @@ full = [
"flagembedding==1.2.10",
"torch>=2.5.0,<3.0.0",
"transformers>=4.35.0,<5.0.0"
]
]
[[tool.uv.index]]
url = "https://mirrors.aliyun.com/pypi/simple"

View File

@ -14,15 +14,14 @@
# limitations under the License.
#
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
from threading import Lock
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
import trio
from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
from rag.utils import truncate
@ -68,24 +67,25 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters
def __call__(self, chunks, random_state, callback=None):
async def __call__(self, chunks, random_state, callback=None):
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
if len(chunks) <= 1:
return []
chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
def summarize(ck_idx, lock):
async def summarize(ck_idx, lock):
nonlocal chunks
try:
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
cnt = self._chat("You're a helpful assistant.",
[{"role": "user",
"content": self._prompt.format(cluster_content=cluster_content)}],
{"temperature": 0.3, "max_tokens": self._max_token}
)
async with chat_limiter:
cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
[{"role": "user",
"content": self._prompt.format(cluster_content=cluster_content)}],
{"temperature": 0.3, "max_tokens": self._max_token}
))
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
cnt)
logging.debug(f"SUM: {cnt}")
@ -97,10 +97,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
return e
labels = []
lock = Lock()
while end - start > 1:
embeddings = [embd for _, embd in chunks[start: end]]
if len(embeddings) == 2:
summarize([start, start + 1], Lock())
await summarize([start, start + 1], lock)
if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
labels.extend([0, 0])
@ -122,19 +123,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
probs = gm.predict_proba(reduced_embeddings)
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
lock = Lock()
with ThreadPoolExecutor(max_workers=int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))) as executor:
threads = []
async with trio.open_nursery() as nursery:
for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
if not ck_idx:
continue
threads.append(executor.submit(summarize, ck_idx, lock))
wait(threads, return_when=ALL_COMPLETED)
for th in threads:
if isinstance(th.result(), Exception):
raise th.result()
logging.debug(str([t.result() for t in threads]))
async with chat_limiter:
nursery.start_soon(lambda: summarize(ck_idx, lock))
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls)

View File

@ -30,7 +30,6 @@ CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
initRootLogger(CONSUMER_NAME)
import asyncio
import logging
import os
from datetime import datetime
@ -38,14 +37,14 @@ import json
import xxhash
import copy
import re
import time
import threading
from functools import partial
from io import BytesIO
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
import tracemalloc
import resource
import signal
import trio
import numpy as np
from peewee import DoesNotExist
@ -64,8 +63,9 @@ from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter
BATCH_SIZE = 64
@ -88,28 +88,28 @@ FACTORY = {
ParserType.TAG.value: tag
}
UNACKED_ITERATOR = None
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
PAYLOAD: Payload | None = None
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
PENDING_TASKS = 0
LAG_TASKS = 0
mt_lock = threading.Lock()
DONE_TASKS = 0
FAILED_TASKS = 0
CURRENT_TASK = None
tracemalloc_started = False
CURRENT_TASKS = {}
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
global tracemalloc_started
if not tracemalloc_started:
logging.info("got SIGUSR1, start tracemalloc")
if not tracemalloc.is_tracing():
logging.info("start tracemalloc")
tracemalloc.start()
tracemalloc_started = True
else:
logging.info("got SIGUSR1, tracemalloc is already running")
logging.info("tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
@ -117,17 +117,17 @@ def start_tracemalloc_and_snapshot(signum, frame):
snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
logging.info(f"taken snapshot {snapshot_file}")
current, peak = tracemalloc.get_traced_memory()
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
global tracemalloc_started
if tracemalloc_started:
logging.info("go SIGUSR2, stop tracemalloc")
if tracemalloc.is_tracing():
logging.info("stop tracemalloc")
tracemalloc.stop()
tracemalloc_started = False
else:
logging.info("got SIGUSR2, tracemalloc not running")
logging.info("tracemalloc not running")
class TaskCanceledException(Exception):
def __init__(self, msg):
@ -135,17 +135,9 @@ class TaskCanceledException(Exception):
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
msg = "[ERROR]" + msg
try:
cancel = TaskService.do_cancel(task_id)
except DoesNotExist:
logging.warning(f"set_progress task {task_id} is unknown")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
return
cancel = TaskService.do_cancel(task_id)
if cancel:
msg += " [Canceled]"
@ -162,66 +154,55 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
d["progress"] = prog
logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
try:
TaskService.update_progress(task_id, d)
except DoesNotExist:
logging.warning(f"set_progress task {task_id} is unknown")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
return
TaskService.update_progress(task_id, d)
close_connection()
if cancel and PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
if cancel:
raise TaskCanceledException(msg)
def collect():
global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
async def collect():
global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
global UNACKED_ITERATOR
try:
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
if not PAYLOAD:
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
if not PAYLOAD:
time.sleep(1)
return None
if not UNACKED_ITERATOR:
UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
try:
redis_msg = next(UNACKED_ITERATOR)
except StopIteration:
redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
if not redis_msg:
await trio.sleep(1)
return None, None
except Exception:
logging.exception("Get task event from queue exception")
return None
logging.exception("collect got exception")
return None, None
msg = PAYLOAD.get_message()
msg = redis_msg.get_message()
if not msg:
return None
logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
redis_msg.ack()
return None, None
task = None
canceled = False
try:
task = TaskService.get_task(msg["id"])
if task:
_, doc = DocumentService.get_by_id(task["doc_id"])
canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
except DoesNotExist:
pass
except Exception:
logging.exception("collect get_task exception")
task = TaskService.get_task(msg["id"])
if task:
_, doc = DocumentService.get_by_id(task["doc_id"])
canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
if not task or canceled:
state = "is unknown" if not task else "has been cancelled"
with mt_lock:
DONE_TASKS += 1
logging.info(f"collect task {msg['id']} {state}")
FAILED_TASKS += 1
logging.warning(f"collect task {msg['id']} {state}")
redis_msg.ack()
return None
task["task_type"] = msg.get("task_type", "")
return task
return redis_msg, task
def get_storage_binary(bucket, name):
return STORAGE_IMPL.get(bucket, name)
async def get_storage_binary(bucket, name):
return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))
def build_chunks(task, progress_callback):
async def build_chunks(task, progress_callback):
if task["size"] > DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
@ -231,7 +212,7 @@ def build_chunks(task, progress_callback):
try:
st = timer()
bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
binary = get_storage_binary(bucket, name)
binary = await get_storage_binary(bucket, name)
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
except TimeoutError:
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
@ -247,9 +228,10 @@ def build_chunks(task, progress_callback):
raise
try:
cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
async with chunk_limiter:
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
except TaskCanceledException:
raise
@ -286,7 +268,7 @@ def build_chunks(task, progress_callback):
d["image"].save(output_buffer, format='JPEG')
st = timer()
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
el += timer() - st
except Exception:
logging.exception(
@ -306,14 +288,16 @@ def build_chunks(task, progress_callback):
async def doc_keyword_extraction(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
if not cached:
cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn)
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
if cached:
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return
tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs]
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["parser_config"].get("auto_questions", 0):
@ -324,13 +308,15 @@ def build_chunks(task, progress_callback):
async def doc_question_proposal(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
if not cached:
cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn)
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
if cached:
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs]
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []):
@ -361,14 +347,16 @@ def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
if cached:
cached = json.dumps(cached)
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag]
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
async with trio.open_nursery() as nursery:
for d in docs_to_tag:
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
return docs
@ -379,7 +367,7 @@ def init_kb(row, vector_size: int):
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
def embedding(docs, mdl, parser_config=None, callback=None):
async def embedding(docs, mdl, parser_config=None, callback=None):
if parser_config is None:
parser_config = {}
batch_size = 16
@ -396,13 +384,13 @@ def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0
if len(tts) == len(cnts):
vts, c = mdl.encode(tts[0: 1])
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
tk_count += c
cnts_ = np.array([])
for i in range(0, len(cnts), batch_size):
vts, c = mdl.encode(cnts[i: i + batch_size])
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size]))
if len(cnts_) == 0:
cnts_ = vts
else:
@ -424,7 +412,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size
def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
@ -440,7 +428,7 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
row["parser_config"]["raptor"]["threshold"]
)
original_length = len(chunks)
chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
doc = {
"doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])],
@ -465,13 +453,13 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count
def run_graphrag(row, chat_model, language, embedding_model, callback=None):
async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
chunks = []
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", "doc_id"]):
chunks.append((d["doc_id"], d["content_with_weight"]))
Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
row["tenant_id"],
str(row["kb_id"]),
chat_model,
@ -480,9 +468,10 @@ def run_graphrag(row, chat_model, language, embedding_model, callback=None):
entity_types=row["parser_config"]["graphrag"]["entity_types"],
embed_bdl=embedding_model,
callback=callback)
await dealer()
def do_handle_task(task):
async def do_handle_task(task):
task_id = task["id"]
task_from_page = task["from_page"]
task_to_page = task["to_page"]
@ -494,6 +483,7 @@ def do_handle_task(task):
task_doc_id = task["doc_id"]
task_document_name = task["name"]
task_parser_config = task["parser_config"]
task_start_ts = timer()
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@ -505,11 +495,7 @@ def do_handle_task(task):
progress_callback(-1, msg=error_message)
raise Exception(error_message)
try:
task_canceled = TaskService.do_cancel(task_id)
except DoesNotExist:
logging.warning(f"task {task_id} is unknown")
return
task_canceled = TaskService.do_cancel(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
@ -529,71 +515,41 @@ def do_handle_task(task):
# Either using RAPTOR or Standard chunking methods
if task.get("task_type", "") == "raptor":
try:
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
# Either using graphrag or Standard chunking methods
elif task.get("task_type", "") == "graphrag":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
return
elif task.get("task_type", "") == "graph_resolution":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithResolution(
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
with_res = WithResolution(
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
progress_callback
)
await with_res()
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
return
elif task.get("task_type", "") == "graph_community":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
with_comm = WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
await with_comm()
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
return
else:
# Standard chunking methods
start_ts = timer()
chunks = build_chunks(task, progress_callback)
chunks = await build_chunks(task, progress_callback)
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
if chunks is None:
return
@ -605,7 +561,7 @@ def do_handle_task(task):
progress_callback(msg="Generate {} chunks".format(len(chunks)))
start_ts = timer()
try:
token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
except Exception as e:
error_message = "Generate embedding error:{}".format(str(e))
progress_callback(-1, error_message)
@ -621,8 +577,7 @@ def do_handle_task(task):
doc_store_result = ""
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
task_dataset_id)
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id))
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
@ -635,8 +590,7 @@ def do_handle_task(task):
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
task_dataset_id)
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
@ -645,51 +599,39 @@ def do_handle_task(task):
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
time_cost = timer() - start_ts
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
task_time_cost = timer() - task_start_ts
progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
token_count, time_cost))
token_count, task_time_cost))
def handle_task():
global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
task = collect()
if task:
async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:
return
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
await do_handle_task(task)
DONE_TASKS += 1
CURRENT_TASKS.pop(task["id"], None)
logging.info(f"handle_task done for task {json.dumps(task)}")
except Exception as e:
FAILED_TASKS += 1
CURRENT_TASKS.pop(task["id"], None)
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
with mt_lock:
CURRENT_TASK = copy.deepcopy(task)
do_handle_task(task)
with mt_lock:
DONE_TASKS += 1
CURRENT_TASK = None
logging.info(f"handle_task done for task {json.dumps(task)}")
except TaskCanceledException:
with mt_lock:
DONE_TASKS += 1
CURRENT_TASK = None
try:
set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
except Exception:
pass
logging.debug("handle_task got TaskCanceledException", exc_info=True)
except Exception as e:
with mt_lock:
FAILED_TASKS += 1
CURRENT_TASK = None
try:
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
redis_msg.ack()
def report_status():
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
async def report_status():
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
while True:
try:
@ -699,17 +641,17 @@ def report_status():
PENDING_TASKS = int(group_info.get("pending", 0))
LAG_TASKS = int(group_info.get("lag", 0))
with mt_lock:
heartbeat = json.dumps({
"name": CONSUMER_NAME,
"now": now.astimezone().isoformat(timespec="milliseconds"),
"boot_at": BOOT_AT,
"pending": PENDING_TASKS,
"lag": LAG_TASKS,
"done": DONE_TASKS,
"failed": FAILED_TASKS,
"current": CURRENT_TASK,
})
current = copy.deepcopy(CURRENT_TASKS)
heartbeat = json.dumps({
"name": CONSUMER_NAME,
"now": now.astimezone().isoformat(timespec="milliseconds"),
"boot_at": BOOT_AT,
"pending": PENDING_TASKS,
"lag": LAG_TASKS,
"done": DONE_TASKS,
"failed": FAILED_TASKS,
"current": current,
})
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
@ -718,27 +660,10 @@ def report_status():
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
except Exception:
logging.exception("report_status got exception")
time.sleep(30)
await trio.sleep(30)
def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
msg = ""
if dump_full:
stats2 = snapshot2.statistics('lineno')
msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
for stat in stats2[:10]:
msg += f"{stat}\n"
stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
for stat in stats1_vs_2[:10]:
msg += f"{stat}\n"
msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
for stat in stats1_vs_2[:3]:
msg += '\n'.join(stat.traceback.format())
logging.info(msg)
def main():
async def main():
logging.info(r"""
______ __ ______ __
/_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
@ -755,33 +680,12 @@ def main():
if TRACE_MALLOC_ENABLED:
start_tracemalloc_and_snapshot(None, None)
# Create an event to signal the background thread to exit
stop_event = threading.Event()
background_thread = threading.Thread(target=report_status)
background_thread.daemon = True
background_thread.start()
# Handle SIGINT (Ctrl+C)
def signal_handler(sig, frame):
logging.info("Received Ctrl+C, shutting down gracefully...")
stop_event.set()
# Give the background thread time to clean up
if background_thread.is_alive():
background_thread.join(timeout=5)
logging.info("Exiting...")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
try:
while not stop_event.is_set():
handle_task()
except KeyboardInterrupt:
logging.info("Interrupted by keyboard, shutting down...")
stop_event.set()
if background_thread.is_alive():
background_thread.join(timeout=5)
async with trio.open_nursery() as nursery:
nursery.start_soon(report_status)
while True:
async with task_limiter:
nursery.start_soon(handle_task)
logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__":
main()
trio.run(main)

View File

@ -24,7 +24,7 @@ from rag import settings
from rag.utils import singleton
class Payload:
class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message):
self.__consumer = consumer
self.__queue_name = queue_name
@ -43,6 +43,9 @@ class Payload:
def get_message(self):
return self.__message
def get_msg_id(self):
return self.__msg_id
@singleton
class RedisDB:
@ -206,9 +209,8 @@ class RedisDB:
)
return False
def queue_consumer(
self, queue_name, group_name, consumer_name, msg_id=b">"
) -> Payload:
def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
"""https://redis.io/docs/latest/commands/xreadgroup/"""
try:
group_info = self.REDIS.xinfo_groups(queue_name)
if not any(e["name"] == group_name for e in group_info):
@ -217,15 +219,17 @@ class RedisDB:
"groupname": group_name,
"consumername": consumer_name,
"count": 1,
"block": 10000,
"block": 5,
"streams": {queue_name: msg_id},
}
messages = self.REDIS.xreadgroup(**args)
if not messages:
return None
stream, element_list = messages[0]
if not element_list:
return None
msg_id, payload = element_list[0]
res = Payload(self.REDIS, queue_name, group_name, msg_id, payload)
res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
return res
except Exception as e:
if "key" in str(e):
@ -239,30 +243,24 @@ class RedisDB:
)
return None
def get_unacked_for(self, consumer_name, queue_name, group_name):
def get_unacked_iterator(self, queue_name, group_name, consumer_name):
try:
group_info = self.REDIS.xinfo_groups(queue_name)
if not any(e["name"] == group_name for e in group_info):
return
pendings = self.REDIS.xpending_range(
queue_name,
group_name,
min=0,
max=10000000000000,
count=1,
consumername=consumer_name,
)
if not pendings:
return
msg_id = pendings[0]["message_id"]
msg = self.REDIS.xrange(queue_name, min=msg_id, count=1)
_, payload = msg[0]
return Payload(self.REDIS, queue_name, group_name, msg_id, payload)
current_min = 0
while True:
payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
if not payload:
return
current_min = payload.get_msg_id()
logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
yield payload
except Exception as e:
if "key" in str(e):
return
logging.exception(
"RedisDB.get_unacked_for " + consumer_name + " got exception: " + str(e)
"RedisDB.get_unacked_iterator " + consumer_name + " got exception: "
)
self.__open__()

52
uv.lock generated
View File

@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.10, <3.13"
resolution-markers = [
"python_full_version >= '3.12' and sys_platform == 'darwin'",
@ -1083,9 +1084,6 @@ name = "datrie"
version = "0.8.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/44/02/53f0cf0bf0cd629ba6c2cc13f2f9db24323459e9c19463783d890a540a96/datrie-0.8.2-pp273-pypy_73-win32.whl", hash = "sha256:b07bd5fdfc3399a6dab86d6e35c72b1dbd598e80c97509c7c7518ab8774d3fda" },
]
[[package]]
name = "decorator"
@ -1362,17 +1360,17 @@ name = "fastembed-gpu"
version = "0.3.6"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "loguru" },
{ name = "mmh3" },
{ name = "numpy" },
{ name = "onnxruntime-gpu" },
{ name = "pillow" },
{ name = "pystemmer" },
{ name = "requests" },
{ name = "snowballstemmer" },
{ name = "tokenizers" },
{ name = "tqdm" },
{ name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "loguru", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "mmh3", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "onnxruntime-gpu", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "pystemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "snowballstemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "tokenizers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "tqdm", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" }
wheels = [
@ -3485,12 +3483,12 @@ name = "onnxruntime-gpu"
version = "1.19.2"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
{ name = "coloredlogs" },
{ name = "flatbuffers" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "protobuf" },
{ name = "sympy" },
{ name = "coloredlogs", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" },
@ -4164,15 +4162,6 @@ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd" },
]
[[package]]
name = "pybind11"
version = "2.13.6"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d2/c1/72b9622fcb32ff98b054f724e213c7f70d6898baa714f4516288456ceaba/pybind11-2.13.6.tar.gz", hash = "sha256:ba6af10348c12b24e92fa086b39cfba0eff619b61ac77c406167d813b096d39a" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/13/2f/0f24b288e2ce56f51c920137620b4434a38fd80583dbbe24fc2a1656c388/pybind11-2.13.6-py3-none-any.whl", hash = "sha256:237c41e29157b962835d356b370ededd57594a26d5894a795960f0047cb5caf5" },
]
[[package]]
name = "pyclipper"
version = "1.3.0.post5"
@ -4230,8 +4219,6 @@ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/48/7d/0f2b09490b98cc6a902ac15dda8760c568b9c18cfe70e0ef7a16de64d53a/pycryptodomex-3.20.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7a7a8f33a1f1fb762ede6cc9cbab8f2a9ba13b196bfaf7bc6f0b39d2ba315a43" },
{ url = "https://mirrors.aliyun.com/pypi/packages/b0/1c/375adb14b71ee1c8d8232904e928b3e7af5bbbca7c04e4bec94fe8e90c3d/pycryptodomex-3.20.0-cp35-abi3-win32.whl", hash = "sha256:c39778fd0548d78917b61f03c1fa8bfda6cfcf98c767decf360945fe6f97461e" },
{ url = "https://mirrors.aliyun.com/pypi/packages/b2/e8/1b92184ab7e5595bf38000587e6f8cf9556ebd1bf0a583619bee2057afbd/pycryptodomex-3.20.0-cp35-abi3-win_amd64.whl", hash = "sha256:2a47bcc478741b71273b917232f521fd5704ab4b25d301669879e7273d3586cc" },
{ url = "https://mirrors.aliyun.com/pypi/packages/e7/c5/9140bb867141d948c8e242013ec8a8011172233c898dfdba0a2417c3169a/pycryptodomex-3.20.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:1be97461c439a6af4fe1cf8bf6ca5936d3db252737d2f379cc6b2e394e12a458" },
{ url = "https://mirrors.aliyun.com/pypi/packages/5e/6a/04acb4978ce08ab16890c70611ebc6efd251681341617bbb9e53356dee70/pycryptodomex-3.20.0-pp27-pypy_73-win32.whl", hash = "sha256:19764605feea0df966445d46533729b645033f134baeb3ea26ad518c9fdf212c" },
{ url = "https://mirrors.aliyun.com/pypi/packages/eb/df/3f1ea084e43b91e6d2b6b3493cc948864c17ea5d93ff1261a03812fbfd1a/pycryptodomex-3.20.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e497413560e03421484189a6b65e33fe800d3bd75590e6d78d4dfdb7accf3b" },
{ url = "https://mirrors.aliyun.com/pypi/packages/c9/f3/83ffbdfa0c8f9154bcd8866895f6cae5a3ec749da8b0840603cf936c4412/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48217c7901edd95f9f097feaa0388da215ed14ce2ece803d3f300b4e694abea" },
{ url = "https://mirrors.aliyun.com/pypi/packages/c9/9d/c113e640aaf02af5631ae2686b742aac5cd0e1402b9d6512b1c7ec5ef05d/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d00fe8596e1cc46b44bf3907354e9377aa030ec4cd04afbbf6e899fc1e2a7781" },
@ -4820,6 +4807,7 @@ dependencies = [
{ name = "tencentcloud-sdk-python" },
{ name = "tika" },
{ name = "tiktoken" },
{ name = "trio" },
{ name = "umap-learn" },
{ name = "valkey" },
{ name = "vertexai" },
@ -4954,6 +4942,7 @@ requires-dist = [
{ name = "tiktoken", specifier = "==0.7.0" },
{ name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" },
{ name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" },
{ name = "trio", specifier = ">=0.29.0" },
{ name = "umap-learn", specifier = "==0.5.6" },
{ name = "valkey", specifier = "==6.0.2" },
{ name = "vertexai", specifier = "==1.64.0" },
@ -4969,6 +4958,7 @@ requires-dist = [
{ name = "yfinance", specifier = "==0.1.96" },
{ name = "zhipuai", specifier = "==2.0.1" },
]
provides-extras = ["full"]
[[package]]
name = "ranx"