Mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git (synced 2025-08-15 15:25:52 +08:00)
handle nits in task_executor (#2637)
### What problem does this PR solve?

- Fix typo
- Fix string format
- Format imports

### Type of change

- [x] Refactoring
This commit is contained in:
parent ff9c11c970
commit a44ed9626a
@@ -25,34 +25,31 @@ import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-
-from api.db.services.file2document_service import File2DocumentService
-from api.settings import retrievaler
-from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.utils.storage_factory import STORAGE_IMPL
-from api.db.db_models import close_connection
-from rag.settings import database_logger, SVR_QUEUE_NAME
-from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
-from multiprocessing import Pool
-import numpy as np
-from elasticsearch_dsl import Q, Search
-from multiprocessing.context import TimeoutError
-from api.db.services.task_service import TaskService
-from rag.utils.es_conn import ELASTICSEARCH
-from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
-
-from rag.nlp import search, rag_tokenizer
 from io import BytesIO
-import pandas as pd
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from multiprocessing.context import TimeoutError
+from timeit import default_timer as timer
+
+import numpy as np
+import pandas as pd
+from elasticsearch_dsl import Q
+
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.task_service import TaskService
+from api.db.services.file2document_service import File2DocumentService
+from api.settings import retrievaler
 from api.utils.file_utils import get_project_base_directory
-from rag.utils.redis_conn import REDIS_CONN
+from api.db.db_models import close_connection
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from rag.nlp import search, rag_tokenizer
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
+from rag.settings import database_logger, SVR_QUEUE_NAME
+from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
+from rag.utils import rmSpace, num_tokens_from_string
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.redis_conn import REDIS_CONN, Payload
+from rag.utils.storage_factory import STORAGE_IMPL
 
 BATCH_SIZE = 64
 
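The reordered import block follows the usual stdlib, third-party, first-party grouping, drops the unused `Search`, `findMaxTm`, and `Pool` imports, and pulls in `Payload` alongside `REDIS_CONN`. A minimal sketch of reproducing that grouping with isort's Python API, assuming isort >= 5 is available (isort is not referenced by this PR; the snippet is illustrative only):

```python
# A minimal sketch, assuming isort >= 5 is installed; not part of this PR,
# only an illustration of the stdlib / third-party / first-party grouping
# the new import block follows.
import isort

messy = (
    "from rag.settings import cron_logger, DOC_MAXIMUM_SIZE\n"
    "import numpy as np\n"
    "import traceback\n"
    "from api.db import LLMType, ParserType\n"
)

# isort.code() returns the source with imports sorted into sections;
# known_first_party tells it to treat api/rag as project-local packages.
print(isort.code(messy, known_first_party=["api", "rag"]))
```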
@@ -74,11 +71,11 @@ FACTORY = {
     ParserType.KG.value: knowledge_graph
 }
 
-CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
-PAYLOAD = None
+CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+PAYLOAD: Payload | None = None
 
-def set_progress(task_id, from_page=0, to_page=-1,
-                 prog=None, msg="Processing..."):
+def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     global PAYLOAD
     if prog is not None and prog < 0:
         msg = "[ERROR]" + msg
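The annotated `PAYLOAD: Payload | None = None` uses PEP 604 union syntax; for a module-level variable the annotation is evaluated at import time, so this spelling needs Python 3.10 or newer. A minimal sketch of the older, portable equivalent, assuming pre-3.10 interpreters might matter (the `Payload` class below is a hypothetical stand-in for `rag.utils.redis_conn.Payload`):

```python
# A minimal sketch, assuming pre-3.10 compatibility matters; "Payload | None"
# in the diff is the Python 3.10+ spelling of the same optional type.
from typing import Optional


class Payload:  # hypothetical stand-in for rag.utils.redis_conn.Payload
    pass


PAYLOAD: Optional[Payload] = None  # equivalent to: PAYLOAD: Payload | None = None
```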
@@ -107,11 +104,11 @@ def set_progress(task_id, from_page=0, to_page=-1,
 
 
 def collect():
-    global CONSUMEER_NAME, PAYLOAD
+    global CONSUMER_NAME, PAYLOAD
     try:
-        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
         if not PAYLOAD:
-            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
+            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
         if not PAYLOAD:
             time.sleep(1)
             return pd.DataFrame()
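`collect()` first asks the broker group for a message this consumer already received but never acknowledged, and only if none is pending does it pull a fresh one, sleeping a second when the queue is empty. A rough sketch of that claim-pending-then-read-new loop with plain redis-py Streams, purely for illustration: `REDIS_CONN.get_unacked_for` and `queue_consumer` are ragflow wrappers whose internals are not shown here, and the stream name below is a made-up placeholder (only the group name `rag_flow_svr_task_broker` and the `task_consumer_N` naming appear in this diff):

```python
# Illustration only: the "re-claim unacked work, then read new work" pattern
# behind collect(), written against plain redis-py Streams. The stream name is
# a hypothetical placeholder; SVR_QUEUE_NAME's real value is not shown in the diff.
# Assumes the consumer group already exists on the stream.
import time
import redis

r = redis.Redis()
STREAM = "svr_queue"                      # placeholder for SVR_QUEUE_NAME
GROUP = "rag_flow_svr_task_broker"
CONSUMER = "task_consumer_0"


def fetch_one():
    # 1) entries previously delivered to this consumer but never XACK'ed
    pending = r.xreadgroup(GROUP, CONSUMER, {STREAM: "0"}, count=1)
    if pending and pending[0][1]:
        return pending[0][1][0]
    # 2) otherwise take a brand-new entry from the stream
    fresh = r.xreadgroup(GROUP, CONSUMER, {STREAM: ">"}, count=1)
    if fresh and fresh[0][1]:
        return fresh[0][1][0]
    time.sleep(1)                         # empty queue: back off, as collect() does
    return None
```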
@@ -159,8 +156,8 @@ def build(row):
         binary = get_storage_binary(bucket, name)
         cron_logger.info(
             "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
-    except TimeoutError as e:
-        callback(-1, f"Internal server error: Fetch file from minio timeout. Could you try it again.")
+    except TimeoutError:
+        callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
         cron_logger.error(
             "Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
         return
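The `f` prefixes removed in this and the following hunks sit on strings with no `{}` placeholders, so they do nothing except trip linters (pyflakes and ruff report this as F541); dropping the prefix is behavior-preserving. A tiny illustration:

```python
# Illustration only: an f-string with no placeholders is just a plain string.
msg_old = f"Internal server error: Fetch file from minio timeout. Could you try it again."  # flagged as F541
msg_new = "Internal server error: Fetch file from minio timeout. Could you try it again."
assert msg_old == msg_new

# Mixing an f prefix with %-style formatting, as a later hunk cleans up, is the same story:
err = "boom"
assert (f"Get file from minio: %s" % err) == ("Get file from minio: %s" % err)
```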
@@ -168,8 +165,7 @@ def build(row):
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
         else:
-            callback(-1, f"Get file from minio: %s" %
-                     str(e).replace("'", ""))
+            callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
         traceback.print_exc()
         return
 
@@ -180,7 +176,7 @@ def build(row):
         cron_logger.info(
             "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except Exception as e:
-        callback(-1, f"Internal server error while chunking: %s" %
+        callback(-1, "Internal server error while chunking: %s" %
                  str(e).replace("'", ""))
         cron_logger.error(
             "Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
@@ -236,7 +232,9 @@ def init_kb(row):
         open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
 
 
-def embedding(docs, mdl, parser_config={}, callback=None):
+def embedding(docs, mdl, parser_config=None, callback=None):
+    if parser_config is None:
+        parser_config = {}
     batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
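`parser_config={}` was a mutable default argument: Python evaluates the default once, at function definition time, so every call that omits `parser_config` shares the same dict and any mutation leaks into later calls. Replacing it with a `None` sentinel plus an in-body check, as this hunk does, gives each call a fresh dict. A small demonstration of the pitfall (the key name below is illustrative, not one ragflow necessarily uses):

```python
# Illustration of why parser_config={} was risky: the default dict is created
# once and shared by every call that omits the argument.
def embed_bad(docs, parser_config={}):
    parser_config.setdefault("weight", 0.1)   # "weight" is an illustrative key
    return parser_config

assert embed_bad(["a"]) is embed_bad(["b"])   # both calls return the same dict object


def embed_good(docs, parser_config=None):
    if parser_config is None:                 # as in the patched embedding()
        parser_config = {}
    parser_config.setdefault("weight", 0.1)
    return parser_config

assert embed_good(["a"]) is not embed_good(["b"])  # fresh dict per call
```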
@@ -374,7 +372,7 @@ def main():
 
         cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
-            callback(-1, f"Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
+            callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
             ELASTICSEARCH.deleteByQuery(
                 Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
             cron_logger.error(str(es_r))
@@ -392,15 +390,15 @@ def main():
 
 
 def report_status():
-    global CONSUMEER_NAME
+    global CONSUMER_NAME
     while True:
         try:
             obj = REDIS_CONN.get("TASKEXE")
             if not obj: obj = {}
             else: obj = json.loads(obj)
-            if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
-            obj[CONSUMEER_NAME].append(timer())
-            obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
+            if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
+            obj[CONSUMER_NAME].append(timer())
+            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
             REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
         except Exception as e:
             print("[Exception]:", str(e))
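`report_status()` appends a `timeit.default_timer()` reading for this consumer to a JSON object stored under the Redis key `TASKEXE`, keeps only the last 60 readings, and refreshes a roughly two-minute expiry, so an outside monitor can see which executors are still reporting. A rough sketch of reading that heartbeat back with plain redis-py, assuming the JSON encoding that the `json.loads` call in the diff implies (`REDIS_CONN.get` and `set_obj` themselves are ragflow wrappers not shown here):

```python
# Illustration only: read back the heartbeat that report_status() writes,
# assuming plain redis-py and JSON-encoded values (as json.loads in the diff implies).
import json
import redis

r = redis.Redis()
raw = r.get("TASKEXE")
status = json.loads(raw) if raw else {}

for consumer, beats in status.items():
    # "beats" holds up to the last 60 timeit.default_timer() readings from that
    # executor; unusually large gaps between consecutive readings hint at stalls.
    gaps = [b - a for a, b in zip(beats, beats[1:])]
    print(consumer, "reports:", len(beats), "largest gap:", round(max(gaps, default=0.0), 1))
```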