mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 19:25:54 +08:00
handle nits in task_executor (#2637)
### What problem does this PR solve?

- fix typo
- fix string format
- format import

### Type of change

- [x] Refactoring
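Worth calling out among these nits: `embedding()` previously used a mutable default argument (`parser_config={}`). Python evaluates a default like `{}` once, at function definition time, so every call shares the same dict; the diff below applies the standard `None`-default fix. A minimal, self-contained sketch of the pitfall (`embed_bad`/`embed_good` are illustrative names, not from this repo):

```python
# A mutable default is created once, at def time, and shared by every call.
def embed_bad(docs, parser_config={}):
    parser_config.setdefault("calls", 0)
    parser_config["calls"] += 1
    return parser_config["calls"]

# The fix used in this commit: default to None, build a fresh dict per call.
def embed_good(docs, parser_config=None):
    if parser_config is None:
        parser_config = {}
    parser_config.setdefault("calls", 0)
    parser_config["calls"] += 1
    return parser_config["calls"]

assert embed_bad([]) == 1 and embed_bad([]) == 2    # state leaks across calls
assert embed_good([]) == 1 and embed_good([]) == 1  # each call gets a fresh dict
```

The same reasoning applies anywhere a `{}` or `[]` default sneaks into a signature.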
This commit is contained in:
parent
ff9c11c970
commit
a44ed9626a
rag/svr/task_executor.py

@@ -25,34 +25,31 @@ import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from api.db.services.file2document_service import File2DocumentService
-from api.settings import retrievaler
-from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.utils.storage_factory import STORAGE_IMPL
-from api.db.db_models import close_connection
-from rag.settings import database_logger, SVR_QUEUE_NAME
-from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
-from multiprocessing import Pool
-import numpy as np
-from elasticsearch_dsl import Q, Search
-from multiprocessing.context import TimeoutError
-from api.db.services.task_service import TaskService
-from rag.utils.es_conn import ELASTICSEARCH
-from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
-
-from rag.nlp import search, rag_tokenizer
 from io import BytesIO
-import pandas as pd
 from multiprocessing.context import TimeoutError
 from timeit import default_timer as timer
 
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+import numpy as np
+import pandas as pd
+from elasticsearch_dsl import Q
+
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.task_service import TaskService
+from api.db.services.file2document_service import File2DocumentService
+from api.settings import retrievaler
 from api.utils.file_utils import get_project_base_directory
-from rag.utils.redis_conn import REDIS_CONN
+from api.db.db_models import close_connection
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from rag.nlp import search, rag_tokenizer
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
+from rag.settings import database_logger, SVR_QUEUE_NAME
+from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
+from rag.utils import rmSpace, num_tokens_from_string
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.redis_conn import REDIS_CONN, Payload
+from rag.utils.storage_factory import STORAGE_IMPL
 
 BATCH_SIZE = 64
 
@@ -74,11 +71,11 @@ FACTORY = {
     ParserType.KG.value: knowledge_graph
 }
 
-CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
-PAYLOAD = None
+CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+PAYLOAD: Payload | None = None
 
-def set_progress(task_id, from_page=0, to_page=-1,
-                 prog=None, msg="Processing..."):
+
+def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     global PAYLOAD
     if prog is not None and prog < 0:
         msg = "[ERROR]" + msg
@@ -107,11 +104,11 @@ def set_progress(task_id, from_page=0, to_page=-1,
 
 
 def collect():
-    global CONSUMEER_NAME, PAYLOAD
+    global CONSUMER_NAME, PAYLOAD
     try:
-        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
         if not PAYLOAD:
-            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
+            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
         if not PAYLOAD:
             time.sleep(1)
             return pd.DataFrame()
@@ -159,8 +156,8 @@ def build(row):
         binary = get_storage_binary(bucket, name)
         cron_logger.info(
             "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
-    except TimeoutError as e:
-        callback(-1, f"Internal server error: Fetch file from minio timeout. Could you try it again.")
+    except TimeoutError:
+        callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
         cron_logger.error(
             "Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
         return
@@ -168,8 +165,7 @@ def build(row):
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
         else:
-            callback(-1, f"Get file from minio: %s" %
-                     str(e).replace("'", ""))
+            callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
         traceback.print_exc()
         return
 
@@ -180,7 +176,7 @@ def build(row):
         cron_logger.info(
             "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except Exception as e:
-        callback(-1, f"Internal server error while chunking: %s" %
+        callback(-1, "Internal server error while chunking: %s" %
                  str(e).replace("'", ""))
         cron_logger.error(
             "Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
@@ -236,7 +232,9 @@ def init_kb(row):
             open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
 
 
-def embedding(docs, mdl, parser_config={}, callback=None):
+def embedding(docs, mdl, parser_config=None, callback=None):
+    if parser_config is None:
+        parser_config = {}
     batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
@@ -277,7 +275,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):
 
 def run_raptor(row, chat_mdl, embd_mdl, callback=None):
     vts, _ = embd_mdl.encode(["ok"])
-    vctr_nm = "q_%d_vec"%len(vts[0])
+    vctr_nm = "q_%d_vec" % len(vts[0])
     chunks = []
     for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
         chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -374,7 +372,7 @@ def main():
 
         cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
-            callback(-1, f"Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
+            callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
             ELASTICSEARCH.deleteByQuery(
                 Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
             cron_logger.error(str(es_r))
@@ -392,15 +390,15 @@ def main():
 
 
 def report_status():
-    global CONSUMEER_NAME
+    global CONSUMER_NAME
    while True:
        try:
            obj = REDIS_CONN.get("TASKEXE")
            if not obj: obj = {}
            else: obj = json.loads(obj)
-            if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
-            obj[CONSUMEER_NAME].append(timer())
-            obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
+            if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
+            obj[CONSUMER_NAME].append(timer())
+            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
             REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
        except Exception as e:
            print("[Exception]:", str(e))
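A note on the string-format fixes above: several messages were written as `f"... %s" % ...`, mixing an f-string prefix with %-formatting. The `f` was inert here because the literals contain no `{...}` replacement fields, which is why dropping it is behavior-preserving. A small sketch to confirm, using only the standard library:

```python
e = ValueError("boom")

# Before: an f-string prefix on a %-formatted literal. With no {...}
# fields, the f does nothing and the % operator supplies the value.
before = f"Get file from minio: %s" % str(e).replace("'", "")

# After (as in this commit): plain %-formatting, a single mechanism.
after = "Get file from minio: %s" % str(e).replace("'", "")

assert before == after  # identical output; the f prefix was pure noise
```

One more portability note: the new `PAYLOAD: Payload | None = None` annotation uses PEP 604 union syntax, which assumes Python 3.10 or newer; on older interpreters the equivalent spelling is `Optional[Payload]` from `typing`.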