change task execution logic (#103)

This commit is contained in:
KevinHuSh 2024-03-06 19:16:31 +08:00 committed by GitHub
parent 16eade4c48
commit b89ac3c4be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 25 additions and 16 deletions

View File

@ -15,6 +15,8 @@ import re
from collections import Counter from collections import Counter
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
from api.db import ParserType
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer from deepdoc.vision import Recognizer
@ -35,6 +37,7 @@ class LayoutRecognizer(Recognizer):
] ]
def __init__(self, domain): def __init__(self, domain):
super().__init__(self.labels, domain, os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) super().__init__(self.labels, domain, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
self.garbage_layouts = ["footer", "header", "reference"]
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16): def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
def __is_garbage(b): def __is_garbage(b):
@ -85,7 +88,7 @@ class LayoutRecognizer(Recognizer):
i += 1 i += 1
continue continue
lts_[ii]["visited"] = True lts_[ii]["visited"] = True
if lts_[ii]["type"] in ["footer", "header", "reference"]: if lts_[ii]["type"] in self.garbage_layouts:
if lts_[ii]["type"] not in garbages: if lts_[ii]["type"] not in garbages:
garbages[lts_[ii]["type"]] = [] garbages[lts_[ii]["type"]] = []
garbages[lts_[ii]["type"]].append(bxs[i]["text"]) garbages[lts_[ii]["type"]].append(bxs[i]["text"])

View File

@ -6,11 +6,10 @@ export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
PY=/root/miniconda3/envs/py11/bin/python PY=/root/miniconda3/envs/py11/bin/python
function task_exe(){ function task_exe(){
sleep 60; while [ 1 -eq 1 ];do
while [ 1 -eq 1 ];do mpirun -n 4 --allow-run-as-root $PY rag/svr/task_executor.py ; done $PY rag/svr/task_executor.py $1 $2;
done
} }
function watch_broker(){ function watch_broker(){
@ -29,7 +28,12 @@ function task_bro(){
} }
task_bro & task_bro &
task_exe &
WS=8
for ((i=0;i<WS;i++))
do
task_exe $i $WS &
done
$PY api/ragflow_server.py $PY api/ragflow_server.py

View File

@ -119,7 +119,6 @@ def add_positions(d, poss):
d["page_num_int"].append(pn + 1) d["page_num_int"].append(pn + 1)
d["top_int"].append(top) d["top_int"].append(top)
d["position_int"].append((pn + 1, left, right, top, bottom)) d["position_int"].append((pn + 1, left, right, top, bottom))
d["top_int"] = d["top_int"][:1]
def remove_contents_table(sections, eng=False): def remove_contents_table(sections, eng=False):

View File

@ -157,11 +157,11 @@ class EsQueryer:
s = 1e-9 s = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
if k in dtwt: if k in dtwt:
s += v * dtwt[k] s += v# * dtwt[k]
q = 1e-9 q = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
q += v * v q += v * v
d = 1e-9 d = 1e-9
for k, v in dtwt.items(): for k, v in dtwt.items():
d += v * v d += v * v
return s / math.sqrt(q) / math.sqrt(d) return s / q#math.sqrt(q) / math.sqrt(d)

View File

@ -192,7 +192,7 @@ class Dealer:
return [float(t) for t in txt.split("\t")] return [float(t) for t in txt.split("\t")]
def insert_citations(self, answer, chunks, chunk_v, def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.3, vtweight=0.7): embd_mdl, tkweight=0.7, vtweight=0.3):
assert len(chunks) == len(chunk_v) assert len(chunks) == len(chunk_v)
pieces = re.split(r"([;。?!\n]|[a-z][.?;!][ \n])", answer) pieces = re.split(r"([;。?!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)): for i in range(1, len(pieces)):
@ -224,7 +224,7 @@ class Dealer:
chunks_tks, chunks_tks,
tkweight, vtweight) tkweight, vtweight)
mx = np.max(sim) * 0.99 mx = np.max(sim) * 0.99
if mx < 0.55: if mx < 0.35:
continue continue
cites[idx[i]] = list( cites[idx[i]] = list(
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4] set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
@ -237,7 +237,7 @@ class Dealer:
if i not in cites: if i not in cites:
continue continue
for c in cites[i]: assert int(c) < len(chunk_v) for c in cites[i]: assert int(c) < len(chunk_v)
res += "##%s$$" % "$".join(cites[i]) for c in cites[i]: res += f" ##{c}$$"
return res return res

View File

@ -152,6 +152,7 @@ class Dealer:
def ner(t): def ner(t):
if not self.ne or t not in self.ne: if not self.ne or t not in self.ne:
return 1 return 1
if re.match(r"[0-9,.]+$", t): return 2
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
"firstnm": 1} "firstnm": 1}
return m[self.ne[t]] return m[self.ne[t]]

View File

@ -36,3 +36,5 @@ es_logger = getLogger("es")
minio_logger = getLogger("minio") minio_logger = getLogger("minio")
cron_logger = getLogger("cron_logger") cron_logger = getLogger("cron_logger")
chunk_logger = getLogger("chunk_logger") chunk_logger = getLogger("chunk_logger")
database_logger = getLogger("database")

View File

@ -23,13 +23,14 @@ import re
import sys import sys
import traceback import traceback
from functools import partial from functools import partial
from timeit import default_timer as timer
from rag.settings import database_logger
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
import numpy as np import numpy as np
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from api.db.services.task_service import TaskService from api.db.services.task_service import TaskService
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH from rag.utils import ELASTICSEARCH
from rag.utils import MINIO from rag.utils import MINIO
from rag.utils import rmSpace, findMaxTm from rag.utils import rmSpace, findMaxTm
@ -43,7 +44,6 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import database_logger
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
BATCH_SIZE = 64 BATCH_SIZE = 64
@ -267,4 +267,4 @@ if __name__ == "__main__":
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank()) main(int(sys.argv[2]), int(sys.argv[1]))