mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 03:13:58 +08:00
Made task_executor async to speedup parsing (#5530)
### What problem does this PR solve? Made task_executor async to speedup parsing ### Type of change - [x] Performance Improvement
This commit is contained in:
parent
abac2ca2c5
commit
c813c1ff4c
@ -17,6 +17,7 @@ import json
|
||||
import re
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
import trio
|
||||
from api.db.db_models import APIToken
|
||||
|
||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||
@ -386,7 +387,8 @@ def mindmap():
|
||||
rank_feature=label_question(question, [kb])
|
||||
)
|
||||
mindmap = MindMapExtractor(chat_mdl)
|
||||
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
||||
mind_map = mind_map.output
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
|
@ -22,6 +22,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
import trio
|
||||
|
||||
from peewee import fn
|
||||
|
||||
@ -597,8 +598,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
if parser_ids[doc_id] != ParserType.PICTURE.value:
|
||||
mindmap = MindMapExtractor(llm_bdl)
|
||||
try:
|
||||
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
|
||||
ensure_ascii=False, indent=2)
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
|
||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
cks.append({
|
||||
|
@ -17,6 +17,8 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
from io import BytesIO
|
||||
|
||||
import pdfplumber
|
||||
@ -30,6 +32,10 @@ from api.constants import IMG_BASE64_PREFIX
|
||||
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
|
||||
RAG_BASE = os.getenv("RAG_BASE")
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
|
||||
def get_project_base_directory(*args):
|
||||
global PROJECT_BASE
|
||||
@ -175,19 +181,20 @@ def thumbnail_img(filename, blob):
|
||||
"""
|
||||
filename = filename.lower()
|
||||
if re.match(r".*\.pdf$", filename):
|
||||
pdf = pdfplumber.open(BytesIO(blob))
|
||||
buffered = BytesIO()
|
||||
resolution = 32
|
||||
img = None
|
||||
for _ in range(10):
|
||||
# https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
|
||||
pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
|
||||
img = buffered.getvalue()
|
||||
if len(img) >= 64000 and resolution >= 2:
|
||||
resolution = resolution / 2
|
||||
buffered = BytesIO()
|
||||
else:
|
||||
break
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
pdf = pdfplumber.open(BytesIO(blob))
|
||||
buffered = BytesIO()
|
||||
resolution = 32
|
||||
img = None
|
||||
for _ in range(10):
|
||||
# https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
|
||||
pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
|
||||
img = buffered.getvalue()
|
||||
if len(img) >= 64000 and resolution >= 2:
|
||||
resolution = resolution / 2
|
||||
buffered = BytesIO()
|
||||
else:
|
||||
break
|
||||
pdf.close()
|
||||
return img
|
||||
|
||||
|
@ -18,6 +18,8 @@ import os.path
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
initialized_root_logger = False
|
||||
|
||||
def get_project_base_directory():
|
||||
PROJECT_BASE = os.path.abspath(
|
||||
os.path.join(
|
||||
@ -29,10 +31,13 @@ def get_project_base_directory():
|
||||
return PROJECT_BASE
|
||||
|
||||
def initRootLogger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
|
||||
logger = logging.getLogger()
|
||||
if logger.hasHandlers():
|
||||
global initialized_root_logger
|
||||
if initialized_root_logger:
|
||||
return
|
||||
initialized_root_logger = True
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.handlers.clear()
|
||||
log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{logfile_basename}.log"))
|
||||
|
||||
os.makedirs(os.path.dirname(log_path), exist_ok=True)
|
||||
|
@ -18,6 +18,8 @@ import logging
|
||||
import os
|
||||
import random
|
||||
from timeit import default_timer as timer
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import xgboost as xgb
|
||||
from io import BytesIO
|
||||
@ -34,6 +36,10 @@ from rag.nlp import rag_tokenizer
|
||||
from copy import deepcopy
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
class RAGFlowPdfParser:
|
||||
def __init__(self):
|
||||
self.ocr = OCR()
|
||||
@ -948,8 +954,9 @@ class RAGFlowPdfParser:
|
||||
@staticmethod
|
||||
def total_page_number(fnm, binary=None):
|
||||
try:
|
||||
pdf = pdfplumber.open(
|
||||
fnm) if not binary else pdfplumber.open(BytesIO(binary))
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
pdf = pdfplumber.open(
|
||||
fnm) if not binary else pdfplumber.open(BytesIO(binary))
|
||||
total_page = len(pdf.pages)
|
||||
pdf.close()
|
||||
return total_page
|
||||
@ -968,17 +975,18 @@ class RAGFlowPdfParser:
|
||||
self.page_from = page_from
|
||||
start = timer()
|
||||
try:
|
||||
self.pdf = pdfplumber.open(fnm) if isinstance(
|
||||
fnm, str) else pdfplumber.open(BytesIO(fnm))
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
|
||||
enumerate(self.pdf.pages[page_from:page_to])]
|
||||
try:
|
||||
self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
|
||||
self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead.
|
||||
|
||||
self.total_page = len(self.pdf.pages)
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
self.pdf = pdfplumber.open(fnm) if isinstance(
|
||||
fnm, str) else pdfplumber.open(BytesIO(fnm))
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
|
||||
enumerate(self.pdf.pages[page_from:page_to])]
|
||||
try:
|
||||
self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
|
||||
self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead.
|
||||
|
||||
self.total_page = len(self.pdf.pages)
|
||||
except Exception:
|
||||
logging.exception("RAGFlowPdfParser __images__")
|
||||
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
|
||||
|
@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import io
|
||||
|
||||
import sys
|
||||
import threading
|
||||
import pdfplumber
|
||||
|
||||
from .ocr import OCR
|
||||
@ -23,6 +24,11 @@ from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
|
||||
from .table_structure_recognizer import TableStructureRecognizer
|
||||
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
|
||||
def init_in_out(args):
|
||||
from PIL import Image
|
||||
import os
|
||||
@ -36,9 +42,10 @@ def init_in_out(args):
|
||||
|
||||
def pdf_pages(fnm, zoomin=3):
|
||||
nonlocal outputs, images
|
||||
pdf = pdfplumber.open(fnm)
|
||||
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
|
||||
enumerate(pdf.pages)]
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
pdf = pdfplumber.open(fnm)
|
||||
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
|
||||
enumerate(pdf.pages)]
|
||||
|
||||
for i, page in enumerate(images):
|
||||
outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")
|
||||
|
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import itertools
|
||||
import re
|
||||
import time
|
||||
@ -21,13 +20,14 @@ from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from graphrag.general.extractor import Extractor
|
||||
from rag.nlp import is_english
|
||||
import editdistance
|
||||
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements
|
||||
from graphrag.utils import perform_variable_replacements, chat_limiter
|
||||
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
||||
@ -67,13 +67,13 @@ class EntityResolution(Extractor):
|
||||
self._resolution_result_delimiter_key = "resolution_result_delimiter"
|
||||
self._input_text_key = "input_text"
|
||||
|
||||
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
|
||||
async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
|
||||
# Wire defaults into the prompt variables
|
||||
prompt_variables = {
|
||||
self.prompt_variables = {
|
||||
**prompt_variables,
|
||||
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
|
||||
or DEFAULT_RECORD_DELIMITER,
|
||||
@ -94,39 +94,12 @@ class EntityResolution(Extractor):
|
||||
for k, v in node_clusters.items():
|
||||
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
|
||||
|
||||
gen_conf = {"temperature": 0.5}
|
||||
resolution_result = set()
|
||||
for candidate_resolution_i in candidate_resolution.items():
|
||||
if candidate_resolution_i[1]:
|
||||
try:
|
||||
pair_txt = [
|
||||
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
|
||||
for index, candidate in enumerate(candidate_resolution_i[1]):
|
||||
pair_txt.append(
|
||||
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
|
||||
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
|
||||
pair_txt.append(
|
||||
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
|
||||
pair_prompt = '\n'.join(pair_txt)
|
||||
|
||||
variables = {
|
||||
**prompt_variables,
|
||||
self._input_text_key: pair_prompt
|
||||
}
|
||||
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
|
||||
|
||||
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
result = self._process_results(len(candidate_resolution_i[1]), response,
|
||||
prompt_variables.get(self._record_delimiter_key,
|
||||
DEFAULT_RECORD_DELIMITER),
|
||||
prompt_variables.get(self._entity_index_dilimiter_key,
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER),
|
||||
prompt_variables.get(self._resolution_result_delimiter_key,
|
||||
DEFAULT_RESOLUTION_RESULT_DELIMITER))
|
||||
for result_i in result:
|
||||
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
|
||||
except Exception:
|
||||
logging.exception("error entity resolution")
|
||||
async with trio.open_nursery() as nursery:
|
||||
for candidate_resolution_i in candidate_resolution.items():
|
||||
if not candidate_resolution_i[1]:
|
||||
continue
|
||||
nursery.start_soon(self._resolve_candidate(candidate_resolution_i, resolution_result))
|
||||
|
||||
connect_graph = nx.Graph()
|
||||
removed_entities = []
|
||||
@ -172,6 +145,34 @@ class EntityResolution(Extractor):
|
||||
removed_entities=removed_entities
|
||||
)
|
||||
|
||||
async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
|
||||
gen_conf = {"temperature": 0.5}
|
||||
pair_txt = [
|
||||
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
|
||||
for index, candidate in enumerate(candidate_resolution_i[1]):
|
||||
pair_txt.append(
|
||||
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
|
||||
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
|
||||
pair_txt.append(
|
||||
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
|
||||
pair_prompt = '\n'.join(pair_txt)
|
||||
variables = {
|
||||
**self.prompt_variables,
|
||||
self._input_text_key: pair_prompt
|
||||
}
|
||||
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
result = self._process_results(len(candidate_resolution_i[1]), response,
|
||||
self.prompt_variables.get(self._record_delimiter_key,
|
||||
DEFAULT_RECORD_DELIMITER),
|
||||
self.prompt_variables.get(self._entity_index_dilimiter_key,
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER),
|
||||
self.prompt_variables.get(self._resolution_result_delimiter_key,
|
||||
DEFAULT_RESOLUTION_RESULT_DELIMITER))
|
||||
for result_i in result:
|
||||
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
records_length: int,
|
||||
|
@ -1,268 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||
from graphrag.general.extractor import Extractor
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
|
||||
DEFAULT_TUPLE_DELIMITER = "<|>"
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
|
||||
CLAIM_MAX_GLEANINGS = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClaimExtractorResult:
|
||||
"""Claim extractor result class definition."""
|
||||
|
||||
output: list[dict]
|
||||
source_docs: dict[str, Any]
|
||||
|
||||
|
||||
class ClaimExtractor(Extractor):
|
||||
"""Claim extractor class definition."""
|
||||
|
||||
_extraction_prompt: str
|
||||
_summary_prompt: str
|
||||
_output_formatter_prompt: str
|
||||
_input_text_key: str
|
||||
_input_entity_spec_key: str
|
||||
_input_claim_description_key: str
|
||||
_tuple_delimiter_key: str
|
||||
_record_delimiter_key: str
|
||||
_completion_delimiter_key: str
|
||||
_max_gleanings: int
|
||||
_on_error: ErrorHandlerFn
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
extraction_prompt: str | None = None,
|
||||
input_text_key: str | None = None,
|
||||
input_entity_spec_key: str | None = None,
|
||||
input_claim_description_key: str | None = None,
|
||||
input_resolved_entities_key: str | None = None,
|
||||
tuple_delimiter_key: str | None = None,
|
||||
record_delimiter_key: str | None = None,
|
||||
completion_delimiter_key: str | None = None,
|
||||
encoding_model: str | None = None,
|
||||
max_gleanings: int | None = None,
|
||||
on_error: ErrorHandlerFn | None = None,
|
||||
):
|
||||
"""Init method definition."""
|
||||
self._llm = llm_invoker
|
||||
self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
|
||||
self._input_text_key = input_text_key or "input_text"
|
||||
self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
|
||||
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
|
||||
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
|
||||
self._completion_delimiter_key = (
|
||||
completion_delimiter_key or "completion_delimiter"
|
||||
)
|
||||
self._input_claim_description_key = (
|
||||
input_claim_description_key or "claim_description"
|
||||
)
|
||||
self._input_resolved_entities_key = (
|
||||
input_resolved_entities_key or "resolved_entities"
|
||||
)
|
||||
self._max_gleanings = (
|
||||
max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS
|
||||
)
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
|
||||
# Construct the looping arguments
|
||||
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
|
||||
yes = encoding.encode("YES")
|
||||
no = encoding.encode("NO")
|
||||
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
|
||||
|
||||
def __call__(
|
||||
self, inputs: dict[str, Any], prompt_variables: dict | None = None
|
||||
) -> ClaimExtractorResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
texts = inputs[self._input_text_key]
|
||||
entity_spec = str(inputs[self._input_entity_spec_key])
|
||||
claim_description = inputs[self._input_claim_description_key]
|
||||
resolved_entities = inputs.get(self._input_resolved_entities_key, {})
|
||||
source_doc_map = {}
|
||||
|
||||
prompt_args = {
|
||||
self._input_entity_spec_key: entity_spec,
|
||||
self._input_claim_description_key: claim_description,
|
||||
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
|
||||
or DEFAULT_TUPLE_DELIMITER,
|
||||
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
|
||||
or DEFAULT_RECORD_DELIMITER,
|
||||
self._completion_delimiter_key: prompt_variables.get(
|
||||
self._completion_delimiter_key
|
||||
)
|
||||
or DEFAULT_COMPLETION_DELIMITER,
|
||||
}
|
||||
|
||||
all_claims: list[dict] = []
|
||||
for doc_index, text in enumerate(texts):
|
||||
document_id = f"d{doc_index}"
|
||||
try:
|
||||
claims = self._process_document(prompt_args, text, doc_index)
|
||||
all_claims += [
|
||||
self._clean_claim(c, document_id, resolved_entities) for c in claims
|
||||
]
|
||||
source_doc_map[document_id] = text
|
||||
except Exception as e:
|
||||
logging.exception("error extracting claim")
|
||||
self._on_error(
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
{"doc_index": doc_index, "text": text},
|
||||
)
|
||||
continue
|
||||
|
||||
return ClaimExtractorResult(
|
||||
output=all_claims,
|
||||
source_docs=source_doc_map,
|
||||
)
|
||||
|
||||
def _clean_claim(
|
||||
self, claim: dict, document_id: str, resolved_entities: dict
|
||||
) -> dict:
|
||||
# clean the parsed claims to remove any claims with status = False
|
||||
obj = claim.get("object_id", claim.get("object"))
|
||||
subject = claim.get("subject_id", claim.get("subject"))
|
||||
|
||||
# If subject or object in resolved entities, then replace with resolved entity
|
||||
obj = resolved_entities.get(obj, obj)
|
||||
subject = resolved_entities.get(subject, subject)
|
||||
claim["object_id"] = obj
|
||||
claim["subject_id"] = subject
|
||||
claim["doc_id"] = document_id
|
||||
return claim
|
||||
|
||||
def _process_document(
|
||||
self, prompt_args: dict, doc, doc_index: int
|
||||
) -> list[dict]:
|
||||
record_delimiter = prompt_args.get(
|
||||
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
|
||||
)
|
||||
completion_delimiter = prompt_args.get(
|
||||
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
|
||||
)
|
||||
variables = {
|
||||
self._input_text_key: doc,
|
||||
**prompt_args,
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
claims = results.strip().removesuffix(completion_delimiter)
|
||||
history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]
|
||||
|
||||
# Repeat to ensure we maximize entity count
|
||||
for i in range(self._max_gleanings):
|
||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||
history.append({"role": "user", "content": text})
|
||||
extension = self._chat("", history, gen_conf)
|
||||
claims += record_delimiter + extension.strip().removesuffix(
|
||||
completion_delimiter
|
||||
)
|
||||
|
||||
# If this isn't the last loop, check to see if we should continue
|
||||
if i >= self._max_gleanings - 1:
|
||||
break
|
||||
|
||||
history.append({"role": "assistant", "content": extension})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
continuation = self._chat("", history, self._loop_args)
|
||||
if continuation != "YES":
|
||||
break
|
||||
|
||||
result = self._parse_claim_tuples(claims, prompt_args)
|
||||
for r in result:
|
||||
r["doc_id"] = f"{doc_index}"
|
||||
return result
|
||||
|
||||
def _parse_claim_tuples(
|
||||
self, claims: str, prompt_variables: dict
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse claim tuples."""
|
||||
record_delimiter = prompt_variables.get(
|
||||
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
|
||||
)
|
||||
completion_delimiter = prompt_variables.get(
|
||||
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
|
||||
)
|
||||
tuple_delimiter = prompt_variables.get(
|
||||
self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
|
||||
)
|
||||
|
||||
def pull_field(index: int, fields: list[str]) -> str | None:
|
||||
return fields[index].strip() if len(fields) > index else None
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
claims_values = (
|
||||
claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
|
||||
)
|
||||
for claim in claims_values:
|
||||
claim = claim.strip().removeprefix("(").removesuffix(")")
|
||||
claim = re.sub(r".*Output:", "", claim)
|
||||
|
||||
# Ignore the completion delimiter
|
||||
if claim == completion_delimiter:
|
||||
continue
|
||||
|
||||
claim_fields = claim.split(tuple_delimiter)
|
||||
o = {
|
||||
"subject_id": pull_field(0, claim_fields),
|
||||
"object_id": pull_field(1, claim_fields),
|
||||
"type": pull_field(2, claim_fields),
|
||||
"status": pull_field(3, claim_fields),
|
||||
"start_date": pull_field(4, claim_fields),
|
||||
"end_date": pull_field(5, claim_fields),
|
||||
"description": pull_field(6, claim_fields),
|
||||
"source_text": pull_field(7, claim_fields),
|
||||
"doc_id": pull_field(8, claim_fields),
|
||||
}
|
||||
if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]):
|
||||
continue
|
||||
result.append(o)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
|
||||
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
||||
|
||||
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
||||
info = {
|
||||
"input_text": docs,
|
||||
"entity_specs": "organization, person",
|
||||
"claim_description": ""
|
||||
}
|
||||
claim = ex(info)
|
||||
logging.info(json.dumps(claim.output, ensure_ascii=False, indent=2))
|
@ -1,71 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
CLAIM_EXTRACTION_PROMPT = """
|
||||
################
|
||||
-Target activity-
|
||||
################
|
||||
You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
|
||||
|
||||
################
|
||||
-Goal-
|
||||
################
|
||||
Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
|
||||
|
||||
################
|
||||
-Steps-
|
||||
################
|
||||
- 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
|
||||
- 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
|
||||
For each claim, extract the following information:
|
||||
- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
|
||||
- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
|
||||
- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
|
||||
- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
|
||||
- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
|
||||
- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
|
||||
- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
|
||||
|
||||
- 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
|
||||
- 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
||||
- 5. If there's nothing satisfy the above requirements, just keep output empty.
|
||||
- 6. When finished, output {completion_delimiter}
|
||||
|
||||
################
|
||||
-Examples-
|
||||
################
|
||||
Example 1:
|
||||
Entity specification: organization
|
||||
Claim description: red flags associated with an entity
|
||||
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
|
||||
Output:
|
||||
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
|
||||
{completion_delimiter}
|
||||
|
||||
###########################
|
||||
Example 2:
|
||||
Entity specification: Company A, Person C
|
||||
Claim description: red flags associated with an entity
|
||||
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
|
||||
Output:
|
||||
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
|
||||
{record_delimiter}
|
||||
(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
|
||||
{completion_delimiter}
|
||||
|
||||
################
|
||||
-Real Data-
|
||||
################
|
||||
Use the following input for your answer.
|
||||
Entity specification: {entity_specs}
|
||||
Claim description: {claim_description}
|
||||
Text: {input_text}
|
||||
Output:"""
|
||||
|
||||
|
||||
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: "
|
||||
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n"
|
@ -17,9 +17,10 @@ from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
from graphrag.general.extractor import Extractor
|
||||
from graphrag.general.leiden import add_community_info2graph
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types
|
||||
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
||||
from rag.utils import num_tokens_from_string
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -52,7 +53,7 @@ class CommunityReportsExtractor(Extractor):
|
||||
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
|
||||
self._max_report_length = max_report_length or 1500
|
||||
|
||||
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
||||
async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
||||
for node_degree in graph.degree:
|
||||
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
|
||||
@ -86,28 +87,25 @@ class CommunityReportsExtractor(Extractor):
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
||||
gen_conf = {"temperature": 0.3}
|
||||
try:
|
||||
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
token_count += num_tokens_from_string(text + response)
|
||||
response = re.sub(r"^[^\{]*", "", response)
|
||||
response = re.sub(r"[^\}]*$", "", response)
|
||||
response = re.sub(r"\{\{", "{", response)
|
||||
response = re.sub(r"\}\}", "}", response)
|
||||
logging.debug(response)
|
||||
response = json.loads(response)
|
||||
if not dict_has_keys_with_types(response, [
|
||||
("title", str),
|
||||
("summary", str),
|
||||
("findings", list),
|
||||
("rating", float),
|
||||
("rating_explanation", str),
|
||||
]):
|
||||
continue
|
||||
response["weight"] = weight
|
||||
response["entities"] = ents
|
||||
except Exception:
|
||||
logging.exception("CommunityReportsExtractor got exception")
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
token_count += num_tokens_from_string(text + response)
|
||||
response = re.sub(r"^[^\{]*", "", response)
|
||||
response = re.sub(r"[^\}]*$", "", response)
|
||||
response = re.sub(r"\{\{", "{", response)
|
||||
response = re.sub(r"\}\}", "}", response)
|
||||
logging.debug(response)
|
||||
response = json.loads(response)
|
||||
if not dict_has_keys_with_types(response, [
|
||||
("title", str),
|
||||
("summary", str),
|
||||
("findings", list),
|
||||
("rating", float),
|
||||
("rating_explanation", str),
|
||||
]):
|
||||
continue
|
||||
response["weight"] = weight
|
||||
response["entities"] = ents
|
||||
|
||||
add_community_info2graph(graph, ents, response["title"])
|
||||
res_str.append(self._get_text_output(response))
|
||||
|
@ -14,16 +14,15 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict, Counter
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from typing import Callable
|
||||
import trio
|
||||
|
||||
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
|
||||
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
|
||||
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
|
||||
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.utils import truncate
|
||||
|
||||
@ -91,54 +90,50 @@ class Extractor:
|
||||
)
|
||||
return dict(maybe_nodes), dict(maybe_edges)
|
||||
|
||||
def __call__(
|
||||
async def __call__(
|
||||
self, chunks: list[tuple[str, str]],
|
||||
callback: Callable | None = None
|
||||
):
|
||||
|
||||
results = []
|
||||
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
||||
threads = []
|
||||
self.callback = callback
|
||||
start_ts = trio.current_time()
|
||||
out_results = []
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, (cid, ck) in enumerate(chunks):
|
||||
ck = truncate(ck, int(self._llm.max_length*0.8))
|
||||
threads.append(
|
||||
exe.submit(self._process_single_content, (cid, ck)))
|
||||
|
||||
for i, _ in enumerate(threads):
|
||||
n, r, tc = _.result()
|
||||
if not isinstance(n, Exception):
|
||||
results.append((n, r))
|
||||
if callback:
|
||||
callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
|
||||
elif callback:
|
||||
callback(msg="Knowledge graph extraction error:{}".format(str(n)))
|
||||
nursery.start_soon(self._process_single_content, (cid, ck), i, len(chunks), out_results)
|
||||
|
||||
maybe_nodes = defaultdict(list)
|
||||
maybe_edges = defaultdict(list)
|
||||
for m_nodes, m_edges in results:
|
||||
sum_token_count = 0
|
||||
for m_nodes, m_edges, token_count in out_results:
|
||||
for k, v in m_nodes.items():
|
||||
maybe_nodes[k].extend(v)
|
||||
for k, v in m_edges.items():
|
||||
maybe_edges[tuple(sorted(k))].extend(v)
|
||||
logging.info("Inserting entities into storage...")
|
||||
sum_token_count += token_count
|
||||
now = trio.current_time()
|
||||
if callback:
|
||||
callback(msg = f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now-start_ts:.2f}s.")
|
||||
start_ts = now
|
||||
logging.info("Entities merging...")
|
||||
all_entities_data = []
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
||||
threads = []
|
||||
async with trio.open_nursery() as nursery:
|
||||
for en_nm, ents in maybe_nodes.items():
|
||||
threads.append(
|
||||
exe.submit(self._merge_nodes, en_nm, ents))
|
||||
for t in threads:
|
||||
n = t.result()
|
||||
if not isinstance(n, Exception):
|
||||
all_entities_data.append(n)
|
||||
elif callback:
|
||||
callback(msg="Knowledge graph nodes merging error: {}".format(str(n)))
|
||||
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
|
||||
now = trio.current_time()
|
||||
if callback:
|
||||
callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
|
||||
|
||||
logging.info("Inserting relationships into storage...")
|
||||
start_ts = now
|
||||
logging.info("Relationships merging...")
|
||||
all_relationships_data = []
|
||||
for (src, tgt), rels in maybe_edges.items():
|
||||
all_relationships_data.append(self._merge_edges(src, tgt, rels))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for (src, tgt), rels in maybe_edges.items():
|
||||
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
|
||||
now = trio.current_time()
|
||||
if callback:
|
||||
callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
|
||||
|
||||
if not len(all_entities_data) and not len(all_relationships_data):
|
||||
logging.warning(
|
||||
@ -152,7 +147,7 @@ class Extractor:
|
||||
|
||||
return all_entities_data, all_relationships_data
|
||||
|
||||
def _merge_nodes(self, entity_name: str, entities: list[dict]):
|
||||
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
|
||||
if not entities:
|
||||
return
|
||||
already_entity_types = []
|
||||
@ -176,26 +171,22 @@ class Extractor:
|
||||
sorted(set([dp["description"] for dp in entities] + already_description))
|
||||
)
|
||||
already_source_ids = flat_uniq_list(entities, "source_id")
|
||||
try:
|
||||
description = self._handle_entity_relation_summary(
|
||||
entity_name, description
|
||||
)
|
||||
node_data = dict(
|
||||
entity_type=entity_type,
|
||||
description=description,
|
||||
source_id=already_source_ids,
|
||||
)
|
||||
node_data["entity_name"] = entity_name
|
||||
self._set_entity_(entity_name, node_data)
|
||||
return node_data
|
||||
except Exception as e:
|
||||
return e
|
||||
description = await self._handle_entity_relation_summary(entity_name, description)
|
||||
node_data = dict(
|
||||
entity_type=entity_type,
|
||||
description=description,
|
||||
source_id=already_source_ids,
|
||||
)
|
||||
node_data["entity_name"] = entity_name
|
||||
self._set_entity_(entity_name, node_data)
|
||||
all_relationships_data.append(node_data)
|
||||
|
||||
def _merge_edges(
|
||||
async def _merge_edges(
|
||||
self,
|
||||
src_id: str,
|
||||
tgt_id: str,
|
||||
edges_data: list[dict]
|
||||
edges_data: list[dict],
|
||||
all_relationships_data
|
||||
):
|
||||
if not edges_data:
|
||||
return
|
||||
@ -226,7 +217,7 @@ class Extractor:
|
||||
"description": description,
|
||||
"entity_type": 'UNKNOWN'
|
||||
})
|
||||
description = self._handle_entity_relation_summary(
|
||||
description = await self._handle_entity_relation_summary(
|
||||
f"({src_id}, {tgt_id})", description
|
||||
)
|
||||
edge_data = dict(
|
||||
@ -238,10 +229,9 @@ class Extractor:
|
||||
source_id=source_id
|
||||
)
|
||||
self._set_relation_(src_id, tgt_id, edge_data)
|
||||
all_relationships_data.append(edge_data)
|
||||
|
||||
return edge_data
|
||||
|
||||
def _handle_entity_relation_summary(
|
||||
async def _handle_entity_relation_summary(
|
||||
self,
|
||||
entity_or_relation_name: str,
|
||||
description: str
|
||||
@ -256,5 +246,6 @@ class Extractor:
|
||||
)
|
||||
use_prompt = prompt_template.format(**context_base)
|
||||
logging.info(f"Trigger summary: {entity_or_relation_name}")
|
||||
summary = self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})
|
||||
async with chat_limiter:
|
||||
summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8}))
|
||||
return summary
|
||||
|
@ -5,15 +5,15 @@ Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
from dataclasses import dataclass
|
||||
import tiktoken
|
||||
import trio
|
||||
|
||||
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
|
||||
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import networkx as nx
|
||||
from rag.utils import num_tokens_from_string
|
||||
@ -102,53 +102,47 @@ class GraphExtractor(Extractor):
|
||||
self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
|
||||
}
|
||||
|
||||
def _process_single_content(self,
|
||||
chunk_key_dp: tuple[str, str]
|
||||
):
|
||||
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
|
||||
token_count = 0
|
||||
|
||||
chunk_key = chunk_key_dp[0]
|
||||
content = chunk_key_dp[1]
|
||||
variables = {
|
||||
**self._prompt_variables,
|
||||
self._input_text_key: content,
|
||||
}
|
||||
try:
|
||||
gen_conf = {"temperature": 0.3}
|
||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
token_count += num_tokens_from_string(hint_prompt + response)
|
||||
|
||||
results = response or ""
|
||||
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
|
||||
|
||||
# Repeat to ensure we maximize entity count
|
||||
for i in range(self._max_gleanings):
|
||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||
history.append({"role": "user", "content": text})
|
||||
response = self._chat("", history, gen_conf)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
results += response or ""
|
||||
|
||||
# if this is the final glean, don't bother updating the continuation flag
|
||||
if i >= self._max_gleanings - 1:
|
||||
break
|
||||
history.append({"role": "assistant", "content": response})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
continuation = self._chat("", history, {"temperature": 0.8})
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
if continuation != "YES":
|
||||
break
|
||||
|
||||
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
|
||||
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
|
||||
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
|
||||
records = [r for r in records if r.strip()]
|
||||
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
|
||||
return maybe_nodes, maybe_edges, token_count
|
||||
except Exception as e:
|
||||
logging.exception("error extracting graph")
|
||||
return e, None, None
|
||||
gen_conf = {"temperature": 0.3}
|
||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
token_count += num_tokens_from_string(hint_prompt + response)
|
||||
|
||||
results = response or ""
|
||||
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
|
||||
|
||||
# Repeat to ensure we maximize entity count
|
||||
for i in range(self._max_gleanings):
|
||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||
history.append({"role": "user", "content": text})
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat("", history, gen_conf))
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
results += response or ""
|
||||
|
||||
# if this is the final glean, don't bother updating the continuation flag
|
||||
if i >= self._max_gleanings - 1:
|
||||
break
|
||||
history.append({"role": "assistant", "content": response})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
async with chat_limiter:
|
||||
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history, {"temperature": 0.8}))
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
if continuation != "YES":
|
||||
break
|
||||
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
|
||||
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
|
||||
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
|
||||
records = [r for r in records if r.strip()]
|
||||
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
|
||||
out_results.append((maybe_nodes, maybe_edges, token_count))
|
||||
if self.callback:
|
||||
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
|
||||
|
@ -17,6 +17,7 @@ import json
|
||||
import logging
|
||||
from functools import reduce, partial
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from api import settings
|
||||
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
|
||||
@ -41,18 +42,24 @@ class Dealer:
|
||||
embed_bdl=None,
|
||||
callback=None
|
||||
):
|
||||
docids = list(set([docid for docid,_ in chunks]))
|
||||
self.tenant_id = tenant_id
|
||||
self.kb_id = kb_id
|
||||
self.chunks = chunks
|
||||
self.llm_bdl = llm_bdl
|
||||
self.embed_bdl = embed_bdl
|
||||
ext = extractor(self.llm_bdl, language=language,
|
||||
self.ext = extractor(self.llm_bdl, language=language,
|
||||
entity_types=entity_types,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
|
||||
)
|
||||
ents, rels = ext(chunks, callback)
|
||||
self.graph = nx.Graph()
|
||||
self.callback = callback
|
||||
|
||||
async def __call__(self):
|
||||
docids = list(set([docid for docid, _ in self.chunks]))
|
||||
ents, rels = await self.ext(self.chunks, self.callback)
|
||||
for en in ents:
|
||||
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
|
||||
|
||||
@ -64,16 +71,16 @@ class Dealer:
|
||||
#description=rel["description"]
|
||||
)
|
||||
|
||||
with RedisDistributedLock(kb_id, 60*60):
|
||||
old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
|
||||
with RedisDistributedLock(self.kb_id, 60*60):
|
||||
old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
|
||||
if old_graph is not None:
|
||||
logging.info("Merge with an exiting graph...................")
|
||||
self.graph = reduce(graph_merge, [old_graph, self.graph])
|
||||
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
|
||||
update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
|
||||
if old_doc_ids:
|
||||
docids.extend(old_doc_ids)
|
||||
docids = list(set(docids))
|
||||
set_graph(tenant_id, kb_id, self.graph, docids)
|
||||
set_graph(self.tenant_id, self.kb_id, self.graph, docids)
|
||||
|
||||
|
||||
class WithResolution(Dealer):
|
||||
@ -84,47 +91,50 @@ class WithResolution(Dealer):
|
||||
embed_bdl=None,
|
||||
callback=None
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.kb_id = kb_id
|
||||
self.llm_bdl = llm_bdl
|
||||
self.embed_bdl = embed_bdl
|
||||
|
||||
with RedisDistributedLock(kb_id, 60*60):
|
||||
self.graph, doc_ids = get_graph(tenant_id, kb_id)
|
||||
self.callback = callback
|
||||
async def __call__(self):
|
||||
with RedisDistributedLock(self.kb_id, 60*60):
|
||||
self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
|
||||
if not self.graph:
|
||||
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
|
||||
if callback:
|
||||
callback(-1, msg="Faild to fetch the graph.")
|
||||
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
|
||||
if self.callback:
|
||||
self.callback(-1, msg="Faild to fetch the graph.")
|
||||
return
|
||||
|
||||
if callback:
|
||||
callback(msg="Fetch the existing graph.")
|
||||
if self.callback:
|
||||
self.callback(msg="Fetch the existing graph.")
|
||||
er = EntityResolution(self.llm_bdl,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
|
||||
reso = er(self.graph)
|
||||
get_entity=partial(get_entity, self.tenant_id, self.kb_id),
|
||||
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
|
||||
get_relation=partial(get_relation, self.tenant_id, self.kb_id),
|
||||
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
|
||||
reso = await er(self.graph)
|
||||
self.graph = reso.graph
|
||||
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
|
||||
if callback:
|
||||
callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
|
||||
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
|
||||
set_graph(tenant_id, kb_id, self.graph, doc_ids)
|
||||
if self.callback:
|
||||
self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
|
||||
await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
|
||||
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
|
||||
|
||||
settings.docStoreConn.delete({
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
|
||||
"knowledge_graph_kwd": "relation",
|
||||
"kb_id": kb_id,
|
||||
"kb_id": self.kb_id,
|
||||
"from_entity_kwd": reso.removed_entities
|
||||
}, search.index_name(tenant_id), kb_id)
|
||||
settings.docStoreConn.delete({
|
||||
}, search.index_name(self.tenant_id), self.kb_id))
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
|
||||
"knowledge_graph_kwd": "relation",
|
||||
"kb_id": kb_id,
|
||||
"kb_id": self.kb_id,
|
||||
"to_entity_kwd": reso.removed_entities
|
||||
}, search.index_name(tenant_id), kb_id)
|
||||
settings.docStoreConn.delete({
|
||||
}, search.index_name(self.tenant_id), self.kb_id))
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
|
||||
"knowledge_graph_kwd": "entity",
|
||||
"kb_id": kb_id,
|
||||
"kb_id": self.kb_id,
|
||||
"entity_kwd": reso.removed_entities
|
||||
}, search.index_name(tenant_id), kb_id)
|
||||
}, search.index_name(self.tenant_id), self.kb_id))
|
||||
|
||||
|
||||
class WithCommunity(Dealer):
|
||||
@ -136,38 +146,41 @@ class WithCommunity(Dealer):
|
||||
callback=None
|
||||
):
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
self.kb_id = kb_id
|
||||
self.community_structure = None
|
||||
self.community_reports = None
|
||||
self.llm_bdl = llm_bdl
|
||||
self.embed_bdl = embed_bdl
|
||||
|
||||
with RedisDistributedLock(kb_id, 60*60):
|
||||
self.graph, doc_ids = get_graph(tenant_id, kb_id)
|
||||
self.callback = callback
|
||||
async def __call__(self):
|
||||
with RedisDistributedLock(self.kb_id, 60*60):
|
||||
self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
|
||||
if not self.graph:
|
||||
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
|
||||
if callback:
|
||||
callback(-1, msg="Faild to fetch the graph.")
|
||||
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
|
||||
if self.callback:
|
||||
self.callback(-1, msg="Faild to fetch the graph.")
|
||||
return
|
||||
if callback:
|
||||
callback(msg="Fetch the existing graph.")
|
||||
if self.callback:
|
||||
self.callback(msg="Fetch the existing graph.")
|
||||
|
||||
cr = CommunityReportsExtractor(self.llm_bdl,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
|
||||
cr = cr(self.graph, callback=callback)
|
||||
get_entity=partial(get_entity, self.tenant_id, self.kb_id),
|
||||
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
|
||||
get_relation=partial(get_relation, self.tenant_id, self.kb_id),
|
||||
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
|
||||
cr = await cr(self.graph, callback=self.callback)
|
||||
self.community_structure = cr.structured_output
|
||||
self.community_reports = cr.output
|
||||
set_graph(tenant_id, kb_id, self.graph, doc_ids)
|
||||
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
|
||||
|
||||
if callback:
|
||||
callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
|
||||
if self.callback:
|
||||
self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
|
||||
|
||||
settings.docStoreConn.delete({
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
|
||||
"knowledge_graph_kwd": "community_report",
|
||||
"kb_id": kb_id
|
||||
}, search.index_name(tenant_id), kb_id)
|
||||
"kb_id": self.kb_id
|
||||
}, search.index_name(self.tenant_id), self.kb_id))
|
||||
|
||||
for stru, rep in zip(self.community_structure, self.community_reports):
|
||||
obj = {
|
||||
@ -183,7 +196,7 @@ class WithCommunity(Dealer):
|
||||
"weight_flt": stru["weight"],
|
||||
"entities_kwd": stru["entities"],
|
||||
"important_kwd": stru["entities"],
|
||||
"kb_id": kb_id,
|
||||
"kb_id": self.kb_id,
|
||||
"source_id": doc_ids,
|
||||
"available_int": 0
|
||||
}
|
||||
@ -193,5 +206,5 @@ class WithCommunity(Dealer):
|
||||
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
|
||||
#except Exception as e:
|
||||
# logging.exception(f"Fail to embed entity relation: {e}")
|
||||
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id)))
|
||||
|
||||
|
@ -16,16 +16,14 @@
|
||||
|
||||
import logging
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
import trio
|
||||
|
||||
from graphrag.general.extractor import Extractor
|
||||
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import markdown_to_json
|
||||
from functools import reduce
|
||||
@ -80,63 +78,47 @@ class MindMapExtractor(Extractor):
|
||||
)
|
||||
return arr
|
||||
|
||||
def __call__(
|
||||
async def __call__(
|
||||
self, sections: list[str], prompt_variables: dict[str, Any] | None = None
|
||||
) -> MindMapResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
|
||||
try:
|
||||
res = []
|
||||
max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
||||
threads = []
|
||||
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
||||
texts = []
|
||||
cnt = 0
|
||||
for i in range(len(sections)):
|
||||
section_cnt = num_tokens_from_string(sections[i])
|
||||
if cnt + section_cnt >= token_count and texts:
|
||||
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
||||
texts = []
|
||||
cnt = 0
|
||||
texts.append(sections[i])
|
||||
cnt += section_cnt
|
||||
if texts:
|
||||
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
||||
|
||||
for i, _ in enumerate(threads):
|
||||
res.append(_.result())
|
||||
|
||||
if not res:
|
||||
return MindMapResult(output={"id": "root", "children": []})
|
||||
|
||||
merge_json = reduce(self._merge, res)
|
||||
if len(merge_json) > 1:
|
||||
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
|
||||
keyset = set(i for i in keys if i)
|
||||
merge_json = {
|
||||
"id": "root",
|
||||
"children": [
|
||||
{
|
||||
"id": self._key(k),
|
||||
"children": self._be_children(v, keyset)
|
||||
}
|
||||
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
|
||||
]
|
||||
}
|
||||
else:
|
||||
k = self._key(list(merge_json.keys())[0])
|
||||
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("error mind graph")
|
||||
self._on_error(
|
||||
e,
|
||||
traceback.format_exc(), None
|
||||
)
|
||||
merge_json = {"error": str(e)}
|
||||
res = []
|
||||
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
||||
texts = []
|
||||
cnt = 0
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(len(sections)):
|
||||
section_cnt = num_tokens_from_string(sections[i])
|
||||
if cnt + section_cnt >= token_count and texts:
|
||||
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
|
||||
texts = []
|
||||
cnt = 0
|
||||
texts.append(sections[i])
|
||||
cnt += section_cnt
|
||||
if texts:
|
||||
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
|
||||
if not res:
|
||||
return MindMapResult(output={"id": "root", "children": []})
|
||||
merge_json = reduce(self._merge, res)
|
||||
if len(merge_json) > 1:
|
||||
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
|
||||
keyset = set(i for i in keys if i)
|
||||
merge_json = {
|
||||
"id": "root",
|
||||
"children": [
|
||||
{
|
||||
"id": self._key(k),
|
||||
"children": self._be_children(v, keyset)
|
||||
}
|
||||
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
|
||||
]
|
||||
}
|
||||
else:
|
||||
k = self._key(list(merge_json.keys())[0])
|
||||
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
|
||||
|
||||
return MindMapResult(output=merge_json)
|
||||
|
||||
@ -181,8 +163,8 @@ class MindMapExtractor(Extractor):
|
||||
|
||||
return self._list_to_kv(to_ret)
|
||||
|
||||
def _process_document(
|
||||
self, text: str, prompt_variables: dict[str, str]
|
||||
async def _process_document(
|
||||
self, text: str, prompt_variables: dict[str, str], out_res
|
||||
) -> str:
|
||||
variables = {
|
||||
**prompt_variables,
|
||||
@ -190,8 +172,9 @@ class MindMapExtractor(Extractor):
|
||||
}
|
||||
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
response = re.sub(r"```[^\n]*", "", response)
|
||||
logging.debug(response)
|
||||
logging.debug(self._todict(markdown_to_json.dictify(response)))
|
||||
return self._todict(markdown_to_json.dictify(response))
|
||||
out_res.append(self._todict(markdown_to_json.dictify(response)))
|
||||
|
@ -18,6 +18,7 @@ import argparse
|
||||
import json
|
||||
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
@ -54,10 +55,13 @@ if __name__ == "__main__":
|
||||
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
|
||||
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
|
||||
trio.run(dealer())
|
||||
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
|
||||
|
||||
dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
|
||||
trio.run(dealer())
|
||||
dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
|
||||
trio.run(dealer())
|
||||
|
||||
print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
|
||||
print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))
|
||||
|
@ -4,16 +4,16 @@
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
from dataclasses import dataclass
|
||||
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
|
||||
from graphrag.light.graph_prompt import PROMPTS
|
||||
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers
|
||||
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import networkx as nx
|
||||
from rag.utils import num_tokens_from_string
|
||||
import trio
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -82,7 +82,7 @@ class GraphExtractor(Extractor):
|
||||
)
|
||||
self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)
|
||||
|
||||
def _process_single_content(self, chunk_key_dp: tuple[str, str]):
|
||||
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
|
||||
token_count = 0
|
||||
chunk_key = chunk_key_dp[0]
|
||||
content = chunk_key_dp[1]
|
||||
@ -90,38 +90,39 @@ class GraphExtractor(Extractor):
|
||||
**self._context_base, input_text="{input_text}"
|
||||
).format(**self._context_base, input_text=content)
|
||||
|
||||
try:
|
||||
gen_conf = {"temperature": 0.8}
|
||||
final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
token_count += num_tokens_from_string(hint_prompt + final_result)
|
||||
history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
|
||||
for now_glean_index in range(self._max_gleanings):
|
||||
glean_result = self._chat(hint_prompt, history, gen_conf)
|
||||
history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
|
||||
final_result += glean_result
|
||||
if now_glean_index == self._max_gleanings - 1:
|
||||
break
|
||||
gen_conf = {"temperature": 0.8}
|
||||
async with chat_limiter:
|
||||
final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
token_count += num_tokens_from_string(hint_prompt + final_result)
|
||||
history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
|
||||
for now_glean_index in range(self._max_gleanings):
|
||||
async with chat_limiter:
|
||||
glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
|
||||
history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
|
||||
final_result += glean_result
|
||||
if now_glean_index == self._max_gleanings - 1:
|
||||
break
|
||||
|
||||
if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||
if if_loop_result != "yes":
|
||||
break
|
||||
async with chat_limiter:
|
||||
if_loop_result = await trio.to_thread.run_sync(lambda: self._chat(self._if_loop_prompt, history, gen_conf))
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||
if if_loop_result != "yes":
|
||||
break
|
||||
|
||||
records = split_string_by_multi_markers(
|
||||
final_result,
|
||||
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
|
||||
)
|
||||
rcds = []
|
||||
for record in records:
|
||||
record = re.search(r"\((.*)\)", record)
|
||||
if record is None:
|
||||
continue
|
||||
rcds.append(record.group(1))
|
||||
records = rcds
|
||||
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
|
||||
return maybe_nodes, maybe_edges, token_count
|
||||
except Exception as e:
|
||||
logging.exception("error extracting graph")
|
||||
return e, None, None
|
||||
records = split_string_by_multi_markers(
|
||||
final_result,
|
||||
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
|
||||
)
|
||||
rcds = []
|
||||
for record in records:
|
||||
record = re.search(r"\((.*)\)", record)
|
||||
if record is None:
|
||||
continue
|
||||
rcds.append(record.group(1))
|
||||
records = rcds
|
||||
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
|
||||
out_results.append((maybe_nodes, maybe_edges, token_count))
|
||||
if self.callback:
|
||||
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
|
||||
|
@ -15,6 +15,8 @@ from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable
|
||||
import os
|
||||
import trio
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@ -28,6 +30,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
||||
|
||||
chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 100)))
|
||||
|
||||
def perform_variable_replacements(
|
||||
input: str, history: list[dict] | None = None, variables: dict | None = None
|
||||
|
@ -122,7 +122,8 @@ dependencies = [
|
||||
"pyodbc>=5.2.0,<6.0.0",
|
||||
"pyicu>=2.13.1,<3.0.0",
|
||||
"flasgger>=0.9.7.1,<0.10.0",
|
||||
"xxhash>=3.5.0,<4.0.0"
|
||||
"xxhash>=3.5.0,<4.0.0",
|
||||
"trio>=0.29.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@ -133,4 +134,7 @@ full = [
|
||||
"flagembedding==1.2.10",
|
||||
"torch>=2.5.0,<3.0.0",
|
||||
"transformers>=4.35.0,<5.0.0"
|
||||
]
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://mirrors.aliyun.com/pypi/simple"
|
||||
|
@ -14,15 +14,14 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
|
||||
from threading import Lock
|
||||
import umap
|
||||
import numpy as np
|
||||
from sklearn.mixture import GaussianMixture
|
||||
import trio
|
||||
|
||||
from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache
|
||||
from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
|
||||
from rag.utils import truncate
|
||||
|
||||
|
||||
@ -68,24 +67,25 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
optimal_clusters = n_clusters[np.argmin(bics)]
|
||||
return optimal_clusters
|
||||
|
||||
def __call__(self, chunks, random_state, callback=None):
|
||||
async def __call__(self, chunks, random_state, callback=None):
|
||||
layers = [(0, len(chunks))]
|
||||
start, end = 0, len(chunks)
|
||||
if len(chunks) <= 1:
|
||||
return []
|
||||
chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
|
||||
|
||||
def summarize(ck_idx, lock):
|
||||
async def summarize(ck_idx, lock):
|
||||
nonlocal chunks
|
||||
try:
|
||||
texts = [chunks[i][0] for i in ck_idx]
|
||||
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
|
||||
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
|
||||
cnt = self._chat("You're a helpful assistant.",
|
||||
[{"role": "user",
|
||||
"content": self._prompt.format(cluster_content=cluster_content)}],
|
||||
{"temperature": 0.3, "max_tokens": self._max_token}
|
||||
)
|
||||
async with chat_limiter:
|
||||
cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
|
||||
[{"role": "user",
|
||||
"content": self._prompt.format(cluster_content=cluster_content)}],
|
||||
{"temperature": 0.3, "max_tokens": self._max_token}
|
||||
))
|
||||
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
|
||||
cnt)
|
||||
logging.debug(f"SUM: {cnt}")
|
||||
@ -97,10 +97,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
return e
|
||||
|
||||
labels = []
|
||||
lock = Lock()
|
||||
while end - start > 1:
|
||||
embeddings = [embd for _, embd in chunks[start: end]]
|
||||
if len(embeddings) == 2:
|
||||
summarize([start, start + 1], Lock())
|
||||
await summarize([start, start + 1], lock)
|
||||
if callback:
|
||||
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
|
||||
labels.extend([0, 0])
|
||||
@ -122,19 +123,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
probs = gm.predict_proba(reduced_embeddings)
|
||||
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
|
||||
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
|
||||
lock = Lock()
|
||||
with ThreadPoolExecutor(max_workers=int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))) as executor:
|
||||
threads = []
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for c in range(n_clusters):
|
||||
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
|
||||
if not ck_idx:
|
||||
continue
|
||||
threads.append(executor.submit(summarize, ck_idx, lock))
|
||||
wait(threads, return_when=ALL_COMPLETED)
|
||||
for th in threads:
|
||||
if isinstance(th.result(), Exception):
|
||||
raise th.result()
|
||||
logging.debug(str([t.result() for t in threads]))
|
||||
async with chat_limiter:
|
||||
nursery.start_soon(lambda: summarize(ck_idx, lock))
|
||||
|
||||
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
|
||||
labels.extend(lbls)
|
||||
|
@ -30,7 +30,6 @@ CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
||||
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
|
||||
initRootLogger(CONSUMER_NAME)
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
@ -38,14 +37,14 @@ import json
|
||||
import xxhash
|
||||
import copy
|
||||
import re
|
||||
import time
|
||||
import threading
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from multiprocessing.context import TimeoutError
|
||||
from timeit import default_timer as timer
|
||||
import tracemalloc
|
||||
import resource
|
||||
import signal
|
||||
import trio
|
||||
|
||||
import numpy as np
|
||||
from peewee import DoesNotExist
|
||||
@ -64,8 +63,9 @@ from rag.nlp import search, rag_tokenizer
|
||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
||||
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
|
||||
from rag.utils import num_tokens_from_string
|
||||
from rag.utils.redis_conn import REDIS_CONN, Payload
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from graphrag.utils import chat_limiter
|
||||
|
||||
BATCH_SIZE = 64
|
||||
|
||||
@ -88,28 +88,28 @@ FACTORY = {
|
||||
ParserType.TAG.value: tag
|
||||
}
|
||||
|
||||
UNACKED_ITERATOR = None
|
||||
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
|
||||
PAYLOAD: Payload | None = None
|
||||
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
|
||||
PENDING_TASKS = 0
|
||||
LAG_TASKS = 0
|
||||
|
||||
mt_lock = threading.Lock()
|
||||
DONE_TASKS = 0
|
||||
FAILED_TASKS = 0
|
||||
CURRENT_TASK = None
|
||||
|
||||
tracemalloc_started = False
|
||||
CURRENT_TASKS = {}
|
||||
|
||||
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
|
||||
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
|
||||
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
|
||||
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
|
||||
|
||||
# SIGUSR1 handler: start tracemalloc and take snapshot
|
||||
def start_tracemalloc_and_snapshot(signum, frame):
|
||||
global tracemalloc_started
|
||||
if not tracemalloc_started:
|
||||
logging.info("got SIGUSR1, start tracemalloc")
|
||||
if not tracemalloc.is_tracing():
|
||||
logging.info("start tracemalloc")
|
||||
tracemalloc.start()
|
||||
tracemalloc_started = True
|
||||
else:
|
||||
logging.info("got SIGUSR1, tracemalloc is already running")
|
||||
logging.info("tracemalloc is already running")
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
snapshot_file = f"snapshot_{timestamp}.trace"
|
||||
@ -117,17 +117,17 @@ def start_tracemalloc_and_snapshot(signum, frame):
|
||||
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
snapshot.dump(snapshot_file)
|
||||
logging.info(f"taken snapshot {snapshot_file}")
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
||||
|
||||
# SIGUSR2 handler: stop tracemalloc
|
||||
def stop_tracemalloc(signum, frame):
|
||||
global tracemalloc_started
|
||||
if tracemalloc_started:
|
||||
logging.info("go SIGUSR2, stop tracemalloc")
|
||||
if tracemalloc.is_tracing():
|
||||
logging.info("stop tracemalloc")
|
||||
tracemalloc.stop()
|
||||
tracemalloc_started = False
|
||||
else:
|
||||
logging.info("got SIGUSR2, tracemalloc not running")
|
||||
logging.info("tracemalloc not running")
|
||||
|
||||
class TaskCanceledException(Exception):
|
||||
def __init__(self, msg):
|
||||
@ -135,17 +135,9 @@ class TaskCanceledException(Exception):
|
||||
|
||||
|
||||
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
||||
global PAYLOAD
|
||||
if prog is not None and prog < 0:
|
||||
msg = "[ERROR]" + msg
|
||||
try:
|
||||
cancel = TaskService.do_cancel(task_id)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"set_progress task {task_id} is unknown")
|
||||
if PAYLOAD:
|
||||
PAYLOAD.ack()
|
||||
PAYLOAD = None
|
||||
return
|
||||
cancel = TaskService.do_cancel(task_id)
|
||||
|
||||
if cancel:
|
||||
msg += " [Canceled]"
|
||||
@ -162,66 +154,55 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
|
||||
d["progress"] = prog
|
||||
|
||||
logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
|
||||
try:
|
||||
TaskService.update_progress(task_id, d)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"set_progress task {task_id} is unknown")
|
||||
if PAYLOAD:
|
||||
PAYLOAD.ack()
|
||||
PAYLOAD = None
|
||||
return
|
||||
TaskService.update_progress(task_id, d)
|
||||
|
||||
close_connection()
|
||||
if cancel and PAYLOAD:
|
||||
PAYLOAD.ack()
|
||||
PAYLOAD = None
|
||||
if cancel:
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
|
||||
def collect():
|
||||
global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
|
||||
async def collect():
|
||||
global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
|
||||
global UNACKED_ITERATOR
|
||||
try:
|
||||
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
|
||||
if not PAYLOAD:
|
||||
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
|
||||
if not PAYLOAD:
|
||||
time.sleep(1)
|
||||
return None
|
||||
if not UNACKED_ITERATOR:
|
||||
UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
|
||||
try:
|
||||
redis_msg = next(UNACKED_ITERATOR)
|
||||
except StopIteration:
|
||||
redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
|
||||
if not redis_msg:
|
||||
await trio.sleep(1)
|
||||
return None, None
|
||||
except Exception:
|
||||
logging.exception("Get task event from queue exception")
|
||||
return None
|
||||
logging.exception("collect got exception")
|
||||
return None, None
|
||||
|
||||
msg = PAYLOAD.get_message()
|
||||
msg = redis_msg.get_message()
|
||||
if not msg:
|
||||
return None
|
||||
logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
|
||||
redis_msg.ack()
|
||||
return None, None
|
||||
|
||||
task = None
|
||||
canceled = False
|
||||
try:
|
||||
task = TaskService.get_task(msg["id"])
|
||||
if task:
|
||||
_, doc = DocumentService.get_by_id(task["doc_id"])
|
||||
canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
|
||||
except DoesNotExist:
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception("collect get_task exception")
|
||||
task = TaskService.get_task(msg["id"])
|
||||
if task:
|
||||
_, doc = DocumentService.get_by_id(task["doc_id"])
|
||||
canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
|
||||
if not task or canceled:
|
||||
state = "is unknown" if not task else "has been cancelled"
|
||||
with mt_lock:
|
||||
DONE_TASKS += 1
|
||||
logging.info(f"collect task {msg['id']} {state}")
|
||||
FAILED_TASKS += 1
|
||||
logging.warning(f"collect task {msg['id']} {state}")
|
||||
redis_msg.ack()
|
||||
return None
|
||||
|
||||
task["task_type"] = msg.get("task_type", "")
|
||||
return task
|
||||
return redis_msg, task
|
||||
|
||||
|
||||
def get_storage_binary(bucket, name):
|
||||
return STORAGE_IMPL.get(bucket, name)
|
||||
async def get_storage_binary(bucket, name):
|
||||
return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))
|
||||
|
||||
|
||||
def build_chunks(task, progress_callback):
|
||||
async def build_chunks(task, progress_callback):
|
||||
if task["size"] > DOC_MAXIMUM_SIZE:
|
||||
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
|
||||
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
||||
@ -231,7 +212,7 @@ def build_chunks(task, progress_callback):
|
||||
try:
|
||||
st = timer()
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
|
||||
binary = get_storage_binary(bucket, name)
|
||||
binary = await get_storage_binary(bucket, name)
|
||||
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
|
||||
except TimeoutError:
|
||||
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
||||
@ -247,9 +228,10 @@ def build_chunks(task, progress_callback):
|
||||
raise
|
||||
|
||||
try:
|
||||
cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
|
||||
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
|
||||
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
|
||||
async with chunk_limiter:
|
||||
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
|
||||
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
|
||||
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
|
||||
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
@ -286,7 +268,7 @@ def build_chunks(task, progress_callback):
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
|
||||
st = timer()
|
||||
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
|
||||
await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
|
||||
el += timer() - st
|
||||
except Exception:
|
||||
logging.exception(
|
||||
@ -306,14 +288,16 @@ def build_chunks(task, progress_callback):
|
||||
async def doc_keyword_extraction(chat_mdl, d, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
|
||||
if not cached:
|
||||
cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn)
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
|
||||
if cached:
|
||||
d["important_kwd"] = cached.split(",")
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
||||
return
|
||||
tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs]
|
||||
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in docs:
|
||||
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
|
||||
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
if task["parser_config"].get("auto_questions", 0):
|
||||
@ -324,13 +308,15 @@ def build_chunks(task, progress_callback):
|
||||
async def doc_question_proposal(chat_mdl, d, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
|
||||
if not cached:
|
||||
cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn)
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
|
||||
if cached:
|
||||
d["question_kwd"] = cached.split("\n")
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
||||
tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs]
|
||||
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in docs:
|
||||
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
|
||||
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
if task["kb_parser_config"].get("tag_kb_ids", []):
|
||||
@ -361,14 +347,16 @@ def build_chunks(task, progress_callback):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
|
||||
if not cached:
|
||||
picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
|
||||
cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
|
||||
if cached:
|
||||
cached = json.dumps(cached)
|
||||
if cached:
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
|
||||
d[TAG_FLD] = json.loads(cached)
|
||||
tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag]
|
||||
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in docs_to_tag:
|
||||
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
|
||||
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
return docs
|
||||
@ -379,7 +367,7 @@ def init_kb(row, vector_size: int):
|
||||
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
|
||||
|
||||
|
||||
def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
if parser_config is None:
|
||||
parser_config = {}
|
||||
batch_size = 16
|
||||
@ -396,13 +384,13 @@ def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
|
||||
tk_count = 0
|
||||
if len(tts) == len(cnts):
|
||||
vts, c = mdl.encode(tts[0: 1])
|
||||
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
|
||||
tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
|
||||
tk_count += c
|
||||
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = mdl.encode(cnts[i: i + batch_size])
|
||||
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size]))
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
@ -424,7 +412,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
return tk_count, vector_size
|
||||
|
||||
|
||||
def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
chunks = []
|
||||
vctr_nm = "q_%d_vec"%vector_size
|
||||
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
||||
@ -440,7 +428,7 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
row["parser_config"]["raptor"]["threshold"]
|
||||
)
|
||||
original_length = len(chunks)
|
||||
chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
|
||||
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
|
||||
doc = {
|
||||
"doc_id": row["doc_id"],
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
@ -465,13 +453,13 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
return res, tk_count
|
||||
|
||||
|
||||
def run_graphrag(row, chat_model, language, embedding_model, callback=None):
|
||||
async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
|
||||
chunks = []
|
||||
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
||||
fields=["content_with_weight", "doc_id"]):
|
||||
chunks.append((d["doc_id"], d["content_with_weight"]))
|
||||
|
||||
Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
|
||||
dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
|
||||
row["tenant_id"],
|
||||
str(row["kb_id"]),
|
||||
chat_model,
|
||||
@ -480,9 +468,10 @@ def run_graphrag(row, chat_model, language, embedding_model, callback=None):
|
||||
entity_types=row["parser_config"]["graphrag"]["entity_types"],
|
||||
embed_bdl=embedding_model,
|
||||
callback=callback)
|
||||
await dealer()
|
||||
|
||||
|
||||
def do_handle_task(task):
|
||||
async def do_handle_task(task):
|
||||
task_id = task["id"]
|
||||
task_from_page = task["from_page"]
|
||||
task_to_page = task["to_page"]
|
||||
@ -494,6 +483,7 @@ def do_handle_task(task):
|
||||
task_doc_id = task["doc_id"]
|
||||
task_document_name = task["name"]
|
||||
task_parser_config = task["parser_config"]
|
||||
task_start_ts = timer()
|
||||
|
||||
# prepare the progress callback function
|
||||
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
|
||||
@ -505,11 +495,7 @@ def do_handle_task(task):
|
||||
progress_callback(-1, msg=error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
try:
|
||||
task_canceled = TaskService.do_cancel(task_id)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"task {task_id} is unknown")
|
||||
return
|
||||
task_canceled = TaskService.do_cancel(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
@ -529,71 +515,41 @@ def do_handle_task(task):
|
||||
|
||||
# Either using RAPTOR or Standard chunking methods
|
||||
if task.get("task_type", "") == "raptor":
|
||||
try:
|
||||
# bind LLM for raptor
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
# run RAPTOR
|
||||
chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
|
||||
progress_callback(-1, msg=error_message)
|
||||
logging.exception(error_message)
|
||||
raise
|
||||
# bind LLM for raptor
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
# run RAPTOR
|
||||
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
|
||||
# Either using graphrag or Standard chunking methods
|
||||
elif task.get("task_type", "") == "graphrag":
|
||||
start_ts = timer()
|
||||
try:
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
|
||||
progress_callback(-1, msg=error_message)
|
||||
logging.exception(error_message)
|
||||
raise
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
elif task.get("task_type", "") == "graph_resolution":
|
||||
start_ts = timer()
|
||||
try:
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
WithResolution(
|
||||
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
|
||||
progress_callback(-1, msg=error_message)
|
||||
logging.exception(error_message)
|
||||
raise
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
with_res = WithResolution(
|
||||
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
await with_res()
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
elif task.get("task_type", "") == "graph_community":
|
||||
start_ts = timer()
|
||||
try:
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
WithCommunity(
|
||||
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
|
||||
progress_callback(-1, msg=error_message)
|
||||
logging.exception(error_message)
|
||||
raise
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
with_comm = WithCommunity(
|
||||
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
await with_comm()
|
||||
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
else:
|
||||
# Standard chunking methods
|
||||
start_ts = timer()
|
||||
chunks = build_chunks(task, progress_callback)
|
||||
chunks = await build_chunks(task, progress_callback)
|
||||
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
|
||||
if chunks is None:
|
||||
return
|
||||
@ -605,7 +561,7 @@ def do_handle_task(task):
|
||||
progress_callback(msg="Generate {} chunks".format(len(chunks)))
|
||||
start_ts = timer()
|
||||
try:
|
||||
token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
|
||||
token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
|
||||
except Exception as e:
|
||||
error_message = "Generate embedding error:{}".format(str(e))
|
||||
progress_callback(-1, error_message)
|
||||
@ -621,8 +577,7 @@ def do_handle_task(task):
|
||||
doc_store_result = ""
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
|
||||
task_dataset_id)
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id))
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
@ -635,8 +590,7 @@ def do_handle_task(task):
|
||||
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
|
||||
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
|
||||
task_dataset_id)
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
return
|
||||
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
|
||||
task_to_page, len(chunks),
|
||||
@ -645,51 +599,39 @@ def do_handle_task(task):
|
||||
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
||||
|
||||
time_cost = timer() - start_ts
|
||||
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
|
||||
task_time_cost = timer() - task_start_ts
|
||||
progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
|
||||
logging.info(
|
||||
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
|
||||
task_to_page, len(chunks),
|
||||
token_count, time_cost))
|
||||
token_count, task_time_cost))
|
||||
|
||||
|
||||
def handle_task():
|
||||
global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
|
||||
task = collect()
|
||||
if task:
|
||||
async def handle_task():
|
||||
global DONE_TASKS, FAILED_TASKS
|
||||
redis_msg, task = await collect()
|
||||
if not task:
|
||||
return
|
||||
try:
|
||||
logging.info(f"handle_task begin for task {json.dumps(task)}")
|
||||
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
|
||||
await do_handle_task(task)
|
||||
DONE_TASKS += 1
|
||||
CURRENT_TASKS.pop(task["id"], None)
|
||||
logging.info(f"handle_task done for task {json.dumps(task)}")
|
||||
except Exception as e:
|
||||
FAILED_TASKS += 1
|
||||
CURRENT_TASKS.pop(task["id"], None)
|
||||
try:
|
||||
logging.info(f"handle_task begin for task {json.dumps(task)}")
|
||||
with mt_lock:
|
||||
CURRENT_TASK = copy.deepcopy(task)
|
||||
do_handle_task(task)
|
||||
with mt_lock:
|
||||
DONE_TASKS += 1
|
||||
CURRENT_TASK = None
|
||||
logging.info(f"handle_task done for task {json.dumps(task)}")
|
||||
except TaskCanceledException:
|
||||
with mt_lock:
|
||||
DONE_TASKS += 1
|
||||
CURRENT_TASK = None
|
||||
try:
|
||||
set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
|
||||
except Exception:
|
||||
pass
|
||||
logging.debug("handle_task got TaskCanceledException", exc_info=True)
|
||||
except Exception as e:
|
||||
with mt_lock:
|
||||
FAILED_TASKS += 1
|
||||
CURRENT_TASK = None
|
||||
try:
|
||||
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
|
||||
if PAYLOAD:
|
||||
PAYLOAD.ack()
|
||||
PAYLOAD = None
|
||||
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
|
||||
redis_msg.ack()
|
||||
|
||||
|
||||
def report_status():
|
||||
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
|
||||
async def report_status():
|
||||
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
|
||||
REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
|
||||
while True:
|
||||
try:
|
||||
@ -699,17 +641,17 @@ def report_status():
|
||||
PENDING_TASKS = int(group_info.get("pending", 0))
|
||||
LAG_TASKS = int(group_info.get("lag", 0))
|
||||
|
||||
with mt_lock:
|
||||
heartbeat = json.dumps({
|
||||
"name": CONSUMER_NAME,
|
||||
"now": now.astimezone().isoformat(timespec="milliseconds"),
|
||||
"boot_at": BOOT_AT,
|
||||
"pending": PENDING_TASKS,
|
||||
"lag": LAG_TASKS,
|
||||
"done": DONE_TASKS,
|
||||
"failed": FAILED_TASKS,
|
||||
"current": CURRENT_TASK,
|
||||
})
|
||||
current = copy.deepcopy(CURRENT_TASKS)
|
||||
heartbeat = json.dumps({
|
||||
"name": CONSUMER_NAME,
|
||||
"now": now.astimezone().isoformat(timespec="milliseconds"),
|
||||
"boot_at": BOOT_AT,
|
||||
"pending": PENDING_TASKS,
|
||||
"lag": LAG_TASKS,
|
||||
"done": DONE_TASKS,
|
||||
"failed": FAILED_TASKS,
|
||||
"current": current,
|
||||
})
|
||||
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
|
||||
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
|
||||
|
||||
@ -718,27 +660,10 @@ def report_status():
|
||||
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
|
||||
except Exception:
|
||||
logging.exception("report_status got exception")
|
||||
time.sleep(30)
|
||||
await trio.sleep(30)
|
||||
|
||||
|
||||
def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
|
||||
msg = ""
|
||||
if dump_full:
|
||||
stats2 = snapshot2.statistics('lineno')
|
||||
msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
|
||||
for stat in stats2[:10]:
|
||||
msg += f"{stat}\n"
|
||||
stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
|
||||
msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
|
||||
for stat in stats1_vs_2[:10]:
|
||||
msg += f"{stat}\n"
|
||||
msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
|
||||
for stat in stats1_vs_2[:3]:
|
||||
msg += '\n'.join(stat.traceback.format())
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
def main():
|
||||
async def main():
|
||||
logging.info(r"""
|
||||
______ __ ______ __
|
||||
/_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
|
||||
@ -755,33 +680,12 @@ def main():
|
||||
if TRACE_MALLOC_ENABLED:
|
||||
start_tracemalloc_and_snapshot(None, None)
|
||||
|
||||
# Create an event to signal the background thread to exit
|
||||
stop_event = threading.Event()
|
||||
|
||||
background_thread = threading.Thread(target=report_status)
|
||||
background_thread.daemon = True
|
||||
background_thread.start()
|
||||
|
||||
# Handle SIGINT (Ctrl+C)
|
||||
def signal_handler(sig, frame):
|
||||
logging.info("Received Ctrl+C, shutting down gracefully...")
|
||||
stop_event.set()
|
||||
# Give the background thread time to clean up
|
||||
if background_thread.is_alive():
|
||||
background_thread.join(timeout=5)
|
||||
logging.info("Exiting...")
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
handle_task()
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Interrupted by keyboard, shutting down...")
|
||||
stop_event.set()
|
||||
if background_thread.is_alive():
|
||||
background_thread.join(timeout=5)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(report_status)
|
||||
while True:
|
||||
async with task_limiter:
|
||||
nursery.start_soon(handle_task)
|
||||
logging.error("BUG!!! You should not reach here!!!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
trio.run(main)
|
||||
|
@ -24,7 +24,7 @@ from rag import settings
|
||||
from rag.utils import singleton
|
||||
|
||||
|
||||
class Payload:
|
||||
class RedisMsg:
|
||||
def __init__(self, consumer, queue_name, group_name, msg_id, message):
|
||||
self.__consumer = consumer
|
||||
self.__queue_name = queue_name
|
||||
@ -43,6 +43,9 @@ class Payload:
|
||||
def get_message(self):
|
||||
return self.__message
|
||||
|
||||
def get_msg_id(self):
|
||||
return self.__msg_id
|
||||
|
||||
|
||||
@singleton
|
||||
class RedisDB:
|
||||
@ -206,9 +209,8 @@ class RedisDB:
|
||||
)
|
||||
return False
|
||||
|
||||
def queue_consumer(
|
||||
self, queue_name, group_name, consumer_name, msg_id=b">"
|
||||
) -> Payload:
|
||||
def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
|
||||
"""https://redis.io/docs/latest/commands/xreadgroup/"""
|
||||
try:
|
||||
group_info = self.REDIS.xinfo_groups(queue_name)
|
||||
if not any(e["name"] == group_name for e in group_info):
|
||||
@ -217,15 +219,17 @@ class RedisDB:
|
||||
"groupname": group_name,
|
||||
"consumername": consumer_name,
|
||||
"count": 1,
|
||||
"block": 10000,
|
||||
"block": 5,
|
||||
"streams": {queue_name: msg_id},
|
||||
}
|
||||
messages = self.REDIS.xreadgroup(**args)
|
||||
if not messages:
|
||||
return None
|
||||
stream, element_list = messages[0]
|
||||
if not element_list:
|
||||
return None
|
||||
msg_id, payload = element_list[0]
|
||||
res = Payload(self.REDIS, queue_name, group_name, msg_id, payload)
|
||||
res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
|
||||
return res
|
||||
except Exception as e:
|
||||
if "key" in str(e):
|
||||
@ -239,30 +243,24 @@ class RedisDB:
|
||||
)
|
||||
return None
|
||||
|
||||
def get_unacked_for(self, consumer_name, queue_name, group_name):
|
||||
def get_unacked_iterator(self, queue_name, group_name, consumer_name):
|
||||
try:
|
||||
group_info = self.REDIS.xinfo_groups(queue_name)
|
||||
if not any(e["name"] == group_name for e in group_info):
|
||||
return
|
||||
pendings = self.REDIS.xpending_range(
|
||||
queue_name,
|
||||
group_name,
|
||||
min=0,
|
||||
max=10000000000000,
|
||||
count=1,
|
||||
consumername=consumer_name,
|
||||
)
|
||||
if not pendings:
|
||||
return
|
||||
msg_id = pendings[0]["message_id"]
|
||||
msg = self.REDIS.xrange(queue_name, min=msg_id, count=1)
|
||||
_, payload = msg[0]
|
||||
return Payload(self.REDIS, queue_name, group_name, msg_id, payload)
|
||||
current_min = 0
|
||||
while True:
|
||||
payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
|
||||
if not payload:
|
||||
return
|
||||
current_min = payload.get_msg_id()
|
||||
logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
|
||||
yield payload
|
||||
except Exception as e:
|
||||
if "key" in str(e):
|
||||
return
|
||||
logging.exception(
|
||||
"RedisDB.get_unacked_for " + consumer_name + " got exception: " + str(e)
|
||||
"RedisDB.get_unacked_iterator " + consumer_name + " got exception: "
|
||||
)
|
||||
self.__open__()
|
||||
|
||||
|
52
uv.lock
generated
52
uv.lock
generated
@ -1,4 +1,5 @@
|
||||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.10, <3.13"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12' and sys_platform == 'darwin'",
|
||||
@ -1083,9 +1084,6 @@ name = "datrie"
|
||||
version = "0.8.2"
|
||||
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
|
||||
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/fe/db74bd405d515f06657f11ad529878fd389576dca4812bea6f98d9b31574/datrie-0.8.2.tar.gz", hash = "sha256:525b08f638d5cf6115df6ccd818e5a01298cd230b2dac91c8ff2e6499d18765d" }
|
||||
wheels = [
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/44/02/53f0cf0bf0cd629ba6c2cc13f2f9db24323459e9c19463783d890a540a96/datrie-0.8.2-pp273-pypy_73-win32.whl", hash = "sha256:b07bd5fdfc3399a6dab86d6e35c72b1dbd598e80c97509c7c7518ab8774d3fda" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "decorator"
|
||||
@ -1362,17 +1360,17 @@ name = "fastembed-gpu"
|
||||
version = "0.3.6"
|
||||
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
|
||||
dependencies = [
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "loguru" },
|
||||
{ name = "mmh3" },
|
||||
{ name = "numpy" },
|
||||
{ name = "onnxruntime-gpu" },
|
||||
{ name = "pillow" },
|
||||
{ name = "pystemmer" },
|
||||
{ name = "requests" },
|
||||
{ name = "snowballstemmer" },
|
||||
{ name = "tokenizers" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "loguru", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "mmh3", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "onnxruntime-gpu", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "pystemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "snowballstemmer", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "tokenizers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "tqdm", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/da/07/7336c7f3d7ee47f33b407eeb50f5eeb152889de538a52a8f1cc637192816/fastembed_gpu-0.3.6.tar.gz", hash = "sha256:ee2de8918b142adbbf48caaffec0c492f864d73c073eea5a3dcd0e8c1041c50d" }
|
||||
wheels = [
|
||||
@ -3485,12 +3483,12 @@ name = "onnxruntime-gpu"
|
||||
version = "1.19.2"
|
||||
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
|
||||
dependencies = [
|
||||
{ name = "coloredlogs" },
|
||||
{ name = "flatbuffers" },
|
||||
{ name = "numpy" },
|
||||
{ name = "packaging" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "sympy" },
|
||||
{ name = "coloredlogs", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/d0/9c/3fa310e0730643051eb88e884f19813a6c8b67d0fbafcda610d960e589db/onnxruntime_gpu-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a49740e079e7c5215830d30cde3df792e903df007aa0b0fd7aa797937061b27a" },
|
||||
@ -4164,15 +4162,6 @@ wheels = [
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pybind11"
|
||||
version = "2.13.6"
|
||||
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
|
||||
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d2/c1/72b9622fcb32ff98b054f724e213c7f70d6898baa714f4516288456ceaba/pybind11-2.13.6.tar.gz", hash = "sha256:ba6af10348c12b24e92fa086b39cfba0eff619b61ac77c406167d813b096d39a" }
|
||||
wheels = [
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/13/2f/0f24b288e2ce56f51c920137620b4434a38fd80583dbbe24fc2a1656c388/pybind11-2.13.6-py3-none-any.whl", hash = "sha256:237c41e29157b962835d356b370ededd57594a26d5894a795960f0047cb5caf5" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyclipper"
|
||||
version = "1.3.0.post5"
|
||||
@ -4230,8 +4219,6 @@ wheels = [
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/48/7d/0f2b09490b98cc6a902ac15dda8760c568b9c18cfe70e0ef7a16de64d53a/pycryptodomex-3.20.0-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7a7a8f33a1f1fb762ede6cc9cbab8f2a9ba13b196bfaf7bc6f0b39d2ba315a43" },
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/b0/1c/375adb14b71ee1c8d8232904e928b3e7af5bbbca7c04e4bec94fe8e90c3d/pycryptodomex-3.20.0-cp35-abi3-win32.whl", hash = "sha256:c39778fd0548d78917b61f03c1fa8bfda6cfcf98c767decf360945fe6f97461e" },
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/b2/e8/1b92184ab7e5595bf38000587e6f8cf9556ebd1bf0a583619bee2057afbd/pycryptodomex-3.20.0-cp35-abi3-win_amd64.whl", hash = "sha256:2a47bcc478741b71273b917232f521fd5704ab4b25d301669879e7273d3586cc" },
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/e7/c5/9140bb867141d948c8e242013ec8a8011172233c898dfdba0a2417c3169a/pycryptodomex-3.20.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:1be97461c439a6af4fe1cf8bf6ca5936d3db252737d2f379cc6b2e394e12a458" },
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/5e/6a/04acb4978ce08ab16890c70611ebc6efd251681341617bbb9e53356dee70/pycryptodomex-3.20.0-pp27-pypy_73-win32.whl", hash = "sha256:19764605feea0df966445d46533729b645033f134baeb3ea26ad518c9fdf212c" },
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/eb/df/3f1ea084e43b91e6d2b6b3493cc948864c17ea5d93ff1261a03812fbfd1a/pycryptodomex-3.20.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e497413560e03421484189a6b65e33fe800d3bd75590e6d78d4dfdb7accf3b" },
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/c9/f3/83ffbdfa0c8f9154bcd8866895f6cae5a3ec749da8b0840603cf936c4412/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48217c7901edd95f9f097feaa0388da215ed14ce2ece803d3f300b4e694abea" },
|
||||
{ url = "https://mirrors.aliyun.com/pypi/packages/c9/9d/c113e640aaf02af5631ae2686b742aac5cd0e1402b9d6512b1c7ec5ef05d/pycryptodomex-3.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d00fe8596e1cc46b44bf3907354e9377aa030ec4cd04afbbf6e899fc1e2a7781" },
|
||||
@ -4820,6 +4807,7 @@ dependencies = [
|
||||
{ name = "tencentcloud-sdk-python" },
|
||||
{ name = "tika" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "trio" },
|
||||
{ name = "umap-learn" },
|
||||
{ name = "valkey" },
|
||||
{ name = "vertexai" },
|
||||
@ -4954,6 +4942,7 @@ requires-dist = [
|
||||
{ name = "tiktoken", specifier = "==0.7.0" },
|
||||
{ name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" },
|
||||
{ name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" },
|
||||
{ name = "trio", specifier = ">=0.29.0" },
|
||||
{ name = "umap-learn", specifier = "==0.5.6" },
|
||||
{ name = "valkey", specifier = "==6.0.2" },
|
||||
{ name = "vertexai", specifier = "==1.64.0" },
|
||||
@ -4969,6 +4958,7 @@ requires-dist = [
|
||||
{ name = "yfinance", specifier = "==0.1.96" },
|
||||
{ name = "zhipuai", specifier = "==2.0.1" },
|
||||
]
|
||||
provides-extras = ["full"]
|
||||
|
||||
[[package]]
|
||||
name = "ranx"
|
||||
|
Loading…
x
Reference in New Issue
Block a user