refine page ranges (#147)

This commit is contained in:
KevinHuSh 2024-03-25 13:11:57 +08:00 committed by GitHub
parent 1d9a50b090
commit 71fe314955
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 169 additions and 72 deletions

View File

@ -477,7 +477,7 @@ class Knowledgebase(DataBaseModel):
vector_similarity_weight = FloatField(default=0.3) vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value) parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]}) parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
def __str__(self): def __str__(self):
@ -492,7 +492,7 @@ class Document(DataBaseModel):
thumbnail = TextField(null=True, help_text="thumbnail base64 string") thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True) kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID") parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]}) parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from") source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
type = CharField(max_length=32, null=False, help_text="file extension") type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(max_length=32, null=False, help_text="who created it") created_by = CharField(max_length=32, null=False, help_text="who created it")

View File

@ -1074,15 +1074,15 @@ class HuParser:
class PlainParser(object): class PlainParser(object):
def __call__(self, filename, **kwargs): def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
self.outlines = [] self.outlines = []
lines = [] lines = []
try: try:
self.pdf = pdf2_read(filename if isinstance(filename, str) else BytesIO(filename)) self.pdf = pdf2_read(filename if isinstance(filename, str) else BytesIO(filename))
outlines = self.pdf.outline for page in self.pdf.pages[from_page:to_page]:
for page in self.pdf.pages:
lines.extend([t for t in page.extract_text().split("\n")]) lines.extend([t for t in page.extract_text().split("\n")])
outlines = self.pdf.outline
def dfs(arr, depth): def dfs(arr, depth):
for a in arr: for a in arr:
if isinstance(a, dict): if isinstance(a, dict):

View File

@ -15,6 +15,7 @@ import re
from collections import Counter from collections import Counter
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
from huggingface_hub import snapshot_download
from api.db import ParserType from api.db import ParserType
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
@ -36,7 +37,8 @@ class LayoutRecognizer(Recognizer):
"Equation", "Equation",
] ]
def __init__(self, domain): def __init__(self, domain):
super().__init__(self.labels, domain, os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
super().__init__(self.labels, domain, model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
self.garbage_layouts = ["footer", "header", "reference"] self.garbage_layouts = ["footer", "header", "reference"]
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True): def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):

View File

@ -30,8 +30,6 @@ class Pdf(PdfParser):
# print(b) # print(b)
print("OCR:", timer()-start) print("OCR:", timer()-start)
self._layouts_rec(zoomin) self._layouts_rec(zoomin)
callback(0.65, "Layout analysis finished.") callback(0.65, "Layout analysis finished.")
print("paddle layouts:", timer() - start) print("paddle layouts:", timer() - start)
@ -47,53 +45,8 @@ class Pdf(PdfParser):
for b in self.boxes: for b in self.boxes:
b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip()) b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)] return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)], tbls
# set pivot using the most frequent type of title,
# then merge between 2 pivot
if len(self.boxes)>0 and len(self.outlines)/len(self.boxes) > 0.1:
max_lvl = max([lvl for _, lvl in self.outlines])
most_level = max(0, max_lvl-1)
levels = []
for b in self.boxes:
for t,lvl in self.outlines:
tks = set([t[i]+t[i+1] for i in range(len(t)-1)])
tks_ = set([b["text"][i]+b["text"][i+1] for i in range(min(len(t), len(b["text"])-1))])
if len(set(tks & tks_))/max([len(tks), len(tks_), 1]) > 0.8:
levels.append(lvl)
break
else:
levels.append(max_lvl + 1)
else:
bull = bullets_category([b["text"] for b in self.boxes])
most_level, levels = title_frequency(bull, [(b["text"], b.get("layout_no","")) for b in self.boxes])
assert len(self.boxes) == len(levels)
sec_ids = []
sid = 0
for i, lvl in enumerate(levels):
if lvl <= most_level and i > 0 and lvl != levels[i-1]: sid += 1
sec_ids.append(sid)
#print(lvl, self.boxes[i]["text"], most_level, sid)
sections = [(b["text"], sec_ids[i], self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)]
for (img, rows), poss in tbls:
sections.append((rows if isinstance(rows, str) else rows[0], -1, [(p[0]+1-from_page, p[1], p[2], p[3], p[4]) for p in poss]))
chunks = []
last_sid = -2
tk_cnt = 0
for txt, sec_id, poss in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1])):
poss = "\t".join([tag(*pos) for pos in poss])
if tk_cnt < 2048 and (sec_id == last_sid or sec_id == -1):
if chunks:
chunks[-1] += "\n" + txt + poss
tk_cnt += num_tokens_from_string(txt)
continue
chunks.append(txt + poss)
tk_cnt = num_tokens_from_string(txt)
if sec_id >-1: last_sid = sec_id
return chunks, tbls
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
@ -106,7 +59,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary, sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback) from_page=from_page, to_page=to_page, callback=callback)
if sections and len(sections[0])<3: cks = [(t, l, [0]*5) for t, l in sections] if sections and len(sections[0])<3: sections = [(t, l, [[0]*5]) for t, l in sections]
else: raise NotImplementedError("file type not supported yet(pdf supported)") else: raise NotImplementedError("file type not supported yet(pdf supported)")
doc = { doc = {
"docnm_kwd": filename "docnm_kwd": filename
@ -131,6 +85,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
break break
else: else:
levels.append(max_lvl + 1) levels.append(max_lvl + 1)
else: else:
bull = bullets_category([txt for txt,_,_ in sections]) bull = bullets_category([txt for txt,_,_ in sections])
most_level, levels = title_frequency(bull, [(txt, l) for txt, l, poss in sections]) most_level, levels = title_frequency(bull, [(txt, l) for txt, l, poss in sections])

View File

@ -45,7 +45,7 @@ class Pdf(PdfParser):
for (img, rows), poss in tbls: for (img, rows), poss in tbls:
sections.append((rows if isinstance(rows, str) else rows[0], sections.append((rows if isinstance(rows, str) else rows[0],
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1]))] return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
@ -56,7 +56,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
eng = lang.lower() == "english"#is_english(cks) eng = lang.lower() == "english"#is_english(cks)
sections = []
if re.search(r"\.docx?$", filename, re.IGNORECASE): if re.search(r"\.docx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
sections = [txt for txt in laws.Docx()(filename, binary) if txt] sections = [txt for txt in laws.Docx()(filename, binary) if txt]
@ -64,7 +63,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
sections = pdf_parser(filename if not binary else binary, to_page=to_page, callback=callback) sections, _ = pdf_parser(filename if not binary else binary, to_page=to_page, callback=callback)
sections = [s for s, _ in sections if s] sections = [s for s, _ in sections if s]
elif re.search(r"\.xlsx?$", filename, re.IGNORECASE): elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):

View File

@ -136,7 +136,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
"title": filename, "title": filename,
"authors": " ", "authors": " ",
"abstract": "", "abstract": "",
"sections": pdf_parser(filename if not binary else binary), "sections": pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page),
"tables": [] "tables": []
} }
else: else:

View File

@ -65,10 +65,10 @@ class Pdf(PdfParser):
class PlainPdf(PlainParser): class PlainPdf(PlainParser):
def __call__(self, filename, binary=None, callback=None, **kwargs): def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
self.pdf = pdf2_read(filename if not binary else BytesIO(filename)) self.pdf = pdf2_read(filename if not binary else BytesIO(filename))
page_txt = [] page_txt = []
for page in self.pdf.pages: for page in self.pdf.pages[from_page: to_page]:
page_txt.append(page.extract_text()) page_txt.append(page.extract_text())
callback(0.9, "Parsing finished") callback(0.9, "Parsing finished")
return [(txt, None) for txt in page_txt] return [(txt, None) for txt in page_txt]

View File

@ -16,8 +16,8 @@ BULLET_PATTERN = [[
], [ ], [
r"第[0-9]+章", r"第[0-9]+章",
r"第[0-9]+节", r"第[0-9]+节",
r"[0-9]{,3}[\. 、]", r"[0-9]{,2}[\. 、]",
r"[0-9]{,2}\.[0-9]{,2}", r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [ ], [
@ -40,13 +40,20 @@ def random_choices(arr, k):
return random.choices(arr, k=k) return random.choices(arr, k=k)
def not_bullet(line):
patt = [
r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}"
]
return any([re.match(r, line) for r in patt])
def bullets_category(sections): def bullets_category(sections):
global BULLET_PATTERN global BULLET_PATTERN
hits = [0] * len(BULLET_PATTERN) hits = [0] * len(BULLET_PATTERN)
for i, pro in enumerate(BULLET_PATTERN): for i, pro in enumerate(BULLET_PATTERN):
for sec in sections: for sec in sections:
for p in pro: for p in pro:
if re.match(p, sec): if re.match(p, sec) and not not_bullet(sec):
hits[i] += 1 hits[i] += 1
break break
maxium = 0 maxium = 0
@ -194,7 +201,7 @@ def title_frequency(bull, sections):
for i, (txt, layout) in enumerate(sections): for i, (txt, layout) in enumerate(sections):
for j, p in enumerate(BULLET_PATTERN[bull]): for j, p in enumerate(BULLET_PATTERN[bull]):
if re.match(p, txt.strip()): if re.match(p, txt.strip()) and not not_bullet(txt):
levels[i] = j levels[i] = j
break break
else: else:

View File

@ -81,21 +81,22 @@ def dispatch():
tsks = [] tsks = []
if r["type"] == FileType.PDF.value: if r["type"] == FileType.PDF.value:
if not r["parser_config"].get("layout_recognize", True): do_layout = r["parser_config"].get("layout_recognize", True)
tsks.append(new_task())
continue
pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"])) pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
page_size = r["parser_config"].get("task_page_size", 12) page_size = r["parser_config"].get("task_page_size", 12)
if r["parser_id"] == "paper": page_size = r["parser_config"].get("task_page_size", 22) if r["parser_id"] == "paper": page_size = r["parser_config"].get("task_page_size", 22)
if r["parser_id"] == "one": page_size = 1000000000 if r["parser_id"] == "one": page_size = 1000000000
if not do_layout: page_size = 1000000000
for s,e in r["parser_config"].get("pages", [(1, 100000)]): for s,e in r["parser_config"].get("pages", [(1, 100000)]):
s -= 1 s -= 1
e = min(e, pages) s = max(0, s)
e = min(e-1, pages)
for p in range(s, e, page_size): for p in range(s, e, page_size):
task = new_task() task = new_task()
task["from_page"] = p task["from_page"] = p
task["to_page"] = min(p + page_size, e) task["to_page"] = min(p + page_size, e)
tsks.append(task) tsks.append(task)
elif r["parser_id"] == "table": elif r["parser_id"] == "table":
rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"])) rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
for i in range(0, rn, 3000): for i in range(0, rn, 3000):

View File

@ -75,7 +75,7 @@ def set_progress(task_id, from_page=0, to_page=-1,
if to_page > 0: if to_page > 0:
if msg: if msg:
msg = f"Page({from_page}~{to_page}): " + msg msg = f"Page({from_page+1}~{to_page+1}): " + msg
d = {"progress_msg": msg} d = {"progress_msg": msg}
if prog is not None: if prog is not None:
d["progress"] = prog d["progress"] = prog

133
requirements.txt Normal file
View File

@ -0,0 +1,133 @@
accelerate==0.27.2
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.3.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
Aspose.Slides==24.2.0
attrs==23.2.0
blinker==1.7.0
cachelib==0.12.0
cachetools==5.3.3
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
coloredlogs==15.0.1
cryptography==42.0.5
dashscope==1.14.1
datasets==2.17.1
datrie==0.8.2
demjson==2.2.4
dill==0.3.8
distro==1.9.0
elastic-transport==8.12.0
elasticsearch==8.12.1
elasticsearch-dsl==8.12.0
et-xmlfile==1.1.0
filelock==3.13.1
FlagEmbedding==1.2.5
Flask==3.0.2
Flask-Cors==4.0.0
Flask-Login==0.6.3
Flask-Session==0.6.0
flatbuffers==23.5.26
frozenlist==1.4.1
fsspec==2023.10.0
h11==0.14.0
hanziconv==0.3.2
httpcore==1.0.4
httpx==0.27.0
huggingface-hub==0.20.3
humanfriendly==10.0
idna==3.6
install==1.3.5
itsdangerous==2.1.2
Jinja2==3.1.3
joblib==1.3.2
lxml==5.1.0
MarkupSafe==2.1.5
minio==7.2.4
mpi4py==3.1.5
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.2.1
nltk==3.8.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
onnxruntime-gpu==1.17.1
openai==1.12.0
opencv-python==4.9.0.80
openpyxl==3.1.2
packaging==23.2
pandas==2.2.1
pdfminer.six==20221105
pdfplumber==0.10.4
peewee==3.17.1
pillow==10.2.0
protobuf==4.25.3
psutil==5.9.8
pyarrow==15.0.0
pyarrow-hotfix==0.6
pyclipper==1.3.0.post5
pycparser==2.21
pycryptodome==3.20.0
pycryptodome-test-vectors==1.0.14
pycryptodomex==3.20.0
pydantic==2.6.2
pydantic_core==2.16.3
PyJWT==2.8.0
PyMuPDF==1.23.25
PyMuPDFb==1.23.22
PyMySQL==1.1.0
PyPDF2==3.0.1
pypdfium2==4.27.0
python-dateutil==2.8.2
python-docx==1.1.0
python-dotenv==1.0.1
python-pptx==0.6.23
pytz==2024.1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
ruamel.yaml==0.18.6
ruamel.yaml.clib==0.2.8
safetensors==0.4.2
scikit-learn==1.4.1.post1
scipy==1.12.0
sentence-transformers==2.4.0
shapely==2.0.3
six==1.16.0
sniffio==1.3.1
StrEnum==0.4.15
sympy==1.12
threadpoolctl==3.3.0
tiktoken==0.6.0
tokenizers==0.15.2
torch==2.2.1
tqdm==4.66.2
transformers==4.38.1
triton==2.2.0
typing_extensions==4.10.0
tzdata==2024.1
urllib3==2.2.1
Werkzeug==3.0.1
xgboost==2.0.3
XlsxWriter==3.2.0
xpinyin==0.7.6
xxhash==3.4.1
yarl==1.9.4
zhipuai==2.0.1

View File

@ -193,7 +193,7 @@ const ChunkMethodModal: React.FC<IProps> = ({
rules={[ rules={[
{ {
required: true, required: true,
message: 'Missing end page number(excluding)', message: 'Missing end page number(excluded)',
}, },
({ getFieldValue }) => ({ ({ getFieldValue }) => ({
validator(_, value) { validator(_, value) {

View File

@ -120,7 +120,7 @@ export const TextMap = {
</p><p> </p><p>
For a document, it will be treated as an entire chunk, no split at all. For a document, it will be treated as an entire chunk, no split at all.
</p><p> </p><p>
If you don't trust any chunk method and the selected LLM's context length covers the document length, you can try this method. If you want to summarize something that needs all the context of an article and the selected LLM's context length covers the document length, you can try this method.
</p>`, </p>`,
}, },
}; };