Add pdf support for QA parser (#1155)

### What problem does this PR solve?

Support extracting questions and answers from PDF files

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Zhedong Cen 2024-06-14 15:12:39 +08:00 committed by GitHub
parent 7dc39cbfa6
commit 90975460af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 194 additions and 10 deletions

View File

@ -13,13 +13,13 @@
import re import re
from copy import deepcopy from copy import deepcopy
from io import BytesIO from io import BytesIO
from timeit import default_timer as timer
from nltk import word_tokenize from nltk import word_tokenize
from openpyxl import load_workbook from openpyxl import load_workbook
from rag.nlp import is_english, random_choices, find_codec from rag.nlp import is_english, random_choices, find_codec, qbullets_category, add_positions, has_qbullet
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer, tokenize_table
from deepdoc.parser import ExcelParser from rag.settings import cron_logger
from deepdoc.parser import PdfParser, ExcelParser
class Excel(ExcelParser): class Excel(ExcelParser):
def __call__(self, fnm, binary=None, callback=None): def __call__(self, fnm, binary=None, callback=None):
if not binary: if not binary:
@ -62,12 +62,80 @@ class Excel(ExcelParser):
[rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1]) [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
return res return res
class Pdf(PdfParser):
    """PDF parser that groups the text boxes of a document into Q&A pairs."""

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        """Run OCR and layout analysis, then split the text into Q&A pairs.

        Args:
            filename: Path of the PDF; used when ``binary`` is falsy.
            binary: Optional PDF content that takes precedence over
                ``filename``.
            from_page: First page handed to the OCR step.
            to_page: Last page handed to the OCR step.
            zoomin: Zoom factor forwarded to the OCR/layout helpers.
            callback: Progress reporter called as ``callback(prog, msg)`` or
                ``callback(msg=...)``. ``None`` disables progress reporting.

        Returns:
            A ``(qai_list, tbls)`` tuple. Each ``qai_list`` entry is
            ``(question, answer, *self.crop(tag, need_position=True))``;
            ``tbls`` is the result of ``self._extract_table_figure``.

        Raises:
            ValueError: If no question-bullet pattern is recognised.
        """
        if callback is None:
            # The signature advertises ``callback=None``; without this guard
            # the very first progress report raised TypeError.
            def _noop(prog=None, msg=""):
                """Ignore progress updates."""

            callback = _noop

        start = timer()
        callback(msg="OCR is running...")
        self.__images__(
            filename if not binary else binary,
            zoomin,
            from_page,
            to_page,
            callback
        )
        callback(msg="OCR finished")
        cron_logger.info("OCR({}~{}): {}".format(from_page, to_page, timer() - start))

        start = timer()
        self._layouts_rec(zoomin, drop=False)
        callback(0.63, "Layout analysis finished.")
        self._table_transformer_job(zoomin)
        callback(0.65, "Table analysis finished.")
        self._text_merge()
        callback(0.67, "Text merging finished")
        tbls = self._extract_table_figure(True, zoomin, True, True)
        cron_logger.info("layouts: {}".format(timer() - start))

        sections = [b["text"] for b in self.boxes]
        # x0 of every accepted question bullet; has_qbullet appends to it and
        # uses the average to reject indented (nested) bullets.
        bull_x0_list = []
        q_bull, reg = qbullets_category(sections)
        if q_bull == -1:
            raise ValueError("Unable to recognize Q&A structure.")

        qai_list = []
        last_q, last_a, last_tag = '', '', ''
        last_index = -1
        last_box = {'text': ''}
        last_bull = None
        for box in self.boxes:
            section, line_tag = box['text'], self._line_tag(box, zoomin)
            has_bull, index = has_qbullet(
                reg, box, last_box, last_index, last_bull, bull_x0_list)
            last_box, last_index, last_bull = box, index, has_bull
            if has_bull:
                # A new question starts: flush the pair collected so far.
                if last_q:
                    qai_list.append(
                        (last_q, last_a, *self.crop(last_tag, need_position=True)))
                last_q = has_bull.group()
                # Whatever follows the matched question on the same box is
                # already part of the answer.
                last_a = section[has_bull.end():]
                last_tag = line_tag
            elif last_q:
                # Continuation of the current answer. Text that appears
                # before the first question is dropped.
                last_a = f'{last_a}{section}'
                last_tag = f'{last_tag}{line_tag}'
        if last_q:
            qai_list.append(
                (last_q, last_a, *self.crop(last_tag, need_position=True)))
        return qai_list, tbls
def rmPrefix(txt):
    """Strip a leading question/answer label from ``txt``.

    The text is trimmed first; then one leading label (``Q``, ``A``,
    ``Question``, ``Answer``, ``user``, ``assistant`` or a Chinese
    equivalent, case-insensitive) is removed together with the separators
    that follow it. The ASCII colon is accepted as a separator so that the
    ``"Question: "`` / ``"Answer: "`` labels produced elsewhere in this
    module are stripped as well, not only the full-width colon form.

    Args:
        txt: Raw question or answer text.

    Returns:
        The trimmed text without its leading label.
    """
    return re.sub(
        r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+",
        "", txt.strip(), flags=re.IGNORECASE)
def beAdocPdf(d, q, a, eng, image, poss):
    """Fill chunk document ``d`` with one Q&A pair extracted from a PDF.

    Args:
        d: Document dict to populate; it is modified in place.
        q: Question text.
        a: Answer text.
        eng: Truthy to use English labels, falsy to use Chinese labels.
        image: Value stored under ``d["image"]``.
        poss: Position data forwarded to ``add_positions``.

    Returns:
        The same dict ``d``.
    """
    if eng:
        question_label, answer_label = "Question: ", "Answer: "
    else:
        question_label, answer_label = "问题:", "回答:"

    question_part = question_label + rmPrefix(q)
    answer_part = answer_label + rmPrefix(a)
    d["content_with_weight"] = "\t".join((question_part, answer_part))

    # Only the question is tokenized; the fine-grained tokens are derived
    # from the coarse ones.
    d["content_ltks"] = rag_tokenizer.tokenize(q)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    d["image"] = image
    add_positions(d, poss)
    return d
def beAdoc(d, q, a, eng): def beAdoc(d, q, a, eng):
qprefix = "Question: " if eng else "问题:" qprefix = "Question: " if eng else "问题:"
aprefix = "Answer: " if eng else "回答:" aprefix = "Answer: " if eng else "回答:"
@ -145,6 +213,19 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
count = 0
qai_list, tbls = pdf_parser(filename if not binary else binary,
from_page=0, to_page=10000, callback=callback)
res = tokenize_table(tbls, doc, eng)
for q, a, image, poss in qai_list:
count += 1
res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss))
return res
raise NotImplementedError( raise NotImplementedError(
"Excel and csv(txt) format files are supported.") "Excel and csv(txt) format files are supported.")
@ -153,6 +234,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
if __name__ == "__main__":
    import sys

    def dummy(prog=None, msg=""):
        """No-op progress callback for command-line runs."""
        pass

    # Parse the file given on the command line, limited to the first pages.
    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@ -21,6 +21,9 @@ from rag.utils import num_tokens_from_string
from . import rag_tokenizer from . import rag_tokenizer
import re import re
import copy import copy
import roman_numbers as r
from word2number import w2n
from cn2an import cn2an
all_codecs = [ all_codecs = [
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs', 'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
@ -57,6 +60,95 @@ def find_codec(blob):
return "utf-8" return "utf-8"
# Candidate regexes for the bullet that introduces a numbered question.
# qbullets_category() tries them in list order, and has_qbullet() reads
# capture group 1 of the selected pattern as the ordinal, which index_int()
# then converts to an integer. Every entry therefore needs exactly that
# first capture group around the ordinal.
QUESTION_PATTERN = [
r"第([零一二三四五六七八九十百0-9]+)问",
r"第([零一二三四五六七八九十百0-9]+)条",
r"[\(]([零一二三四五六七八九十百]+)[\)]",
r"第([0-9]+)问",
r"第([0-9]+)条",
r"([0-9]{1,2})[\. 、]",
r"([零一二三四五六七八九十百]+)[ 、]",
r"[\(]([0-9]{1,2})[\)]",
r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
r"QUESTION (I+V?|VI*|XI|IX|X)",
r"QUESTION ([0-9]+)",
]
def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
    """Decide whether ``box`` starts a new numbered question.

    Args:
        reg: Bullet regex whose capture group 1 holds the ordinal.
        box: Current text box; ``'text'``, ``'x0'`` and ``'top'`` are read,
            and ``'layout_type'`` is read when the earlier checks are
            inconclusive.
        last_box: Previous text box. Missing ``'x0'``/``'top'`` keys are
            filled in from ``box`` (the dict is modified in place).
        last_index: Ordinal of the previously accepted bullet.
        last_bull: Result returned for the previous box (truthy when that
            box was accepted as a question bullet).
        bull_x0_list: ``x0`` values of the accepted bullets; appended to in
            place whenever this box is accepted.

    Returns:
        ``(match, index)`` when the box is accepted, where ``match`` spans
        the bullet plus the question text; otherwise ``(None, last_index)``.
    """
    section, last_section = box['text'], last_box['text']
    # After the bullet, lazily extend the match up to the first question
    # mark (full-width or ASCII), newline or end of text. The full-width
    # forms are spelled as \u escapes: an empty alternative in their place
    # makes the match stop right after the bullet, so the "question" would
    # be the bullet alone.
    q_reg = r'(\w|\W)*?(?:\uff1f|\?|\n|$)+'
    has_bull = re.match(reg + q_reg, section)
    if not has_bull:
        return None, last_index

    last_box.setdefault('x0', box['x0'])
    last_box.setdefault('top', box['top'])

    # Right after a question, a bullet indented further is a nested item.
    if last_bull and box['x0'] - last_box['x0'] > 10:
        return None, last_index
    # Otherwise a bullet that is not left of the previous box and sits
    # within 20 units below it is treated as part of the running text.
    if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
        return None, last_index

    if bull_x0_list:
        avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list)
    else:
        avg_bull_x0 = box['x0']
    # Question bullets are expected to share roughly the same left margin.
    if box['x0'] - avg_bull_x0 > 10:
        return None, last_index

    index = index_int(has_bull.group(1))

    # A preceding line ending in a colon introduces an enumeration, not a
    # question. endswith() is also safe when the previous text is empty,
    # where indexing [-1] raised IndexError.
    if last_section.endswith((':', '\uff1a')):
        return None, last_index

    accepted = (
        not last_index
        or index >= last_index
        or section.endswith(('?', '\uff1f'))
        or box['layout_type'] == 'title'
    )
    if not accepted:
        # Last resort: the text after the bullet opens with an interrogative.
        # Slice the bullet off by position (str.lstrip would strip a set of
        # characters, not the prefix) and drop the whitespace after it.
        bullet = re.match(reg, section)
        pure_section = section[bullet.end():].lstrip().lower()
        ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
        accepted = re.match(ask_reg, pure_section) is not None

    if accepted:
        bull_x0_list.append(box['x0'])
        return has_bull, index
    return None, last_index
def index_int(index_str):
    """Convert a bullet ordinal to an integer.

    The converters are tried in this order: ``int``, ``w2n.word_to_num``,
    ``cn2an`` and ``r.number``. The first one that does not raise
    ``ValueError`` supplies the result.

    Args:
        index_str: Ordinal text captured from a question bullet.

    Returns:
        The converted value, or ``-1`` when every converter raises
        ``ValueError``.
    """
    try:
        return int(index_str)
    except ValueError:
        pass
    try:
        return w2n.word_to_num(index_str)
    except ValueError:
        pass
    try:
        return cn2an(index_str)
    except ValueError:
        pass
    try:
        return r.number(index_str)
    except ValueError:
        return -1
def qbullets_category(sections):
    """Select the question-bullet pattern used by ``sections``.

    A pattern counts as hit when at least one section matches it at the
    start and ``not_bullet`` does not reject that section. The first hit
    pattern in ``QUESTION_PATTERN`` order is selected.

    Args:
        sections: Iterable of section texts.

    Returns:
        ``(index, pattern)``. ``index`` is ``-1`` when nothing matched; in
        that case ``pattern`` is still ``QUESTION_PATTERN[-1]`` because of
        the negative index, so callers must check ``index`` first.
    """
    # NOTE(review): each pattern scores at most one hit, so list order, not
    # match frequency, decides the winner - confirm this is intended.
    matched = [
        any(re.match(pattern, sec) and not not_bullet(sec) for sec in sections)
        for pattern in QUESTION_PATTERN
    ]
    chosen = next((pos for pos, hit in enumerate(matched) if hit), -1)
    return chosen, QUESTION_PATTERN[chosen]
BULLET_PATTERN = [[ BULLET_PATTERN = [[
r"第[零一二三四五六七八九十百0-9]+(分?编|部分)", r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",

View File

@ -141,3 +141,6 @@ readability-lxml==0.8.1
html_text==0.6.2 html_text==0.6.2
selenium==4.21.0 selenium==4.21.0
webdriver-manager==4.0.1 webdriver-manager==4.0.1
cn2an==0.5.22
roman-numbers==1.0.2
word2number==1.1

View File

@ -140,3 +140,6 @@ volcengine==1.0.141
opencv-python-headless==4.9.0.80 opencv-python-headless==4.9.0.80
readability-lxml==0.8.1 readability-lxml==0.8.1
html_text==0.6.2 html_text==0.6.2
cn2an==0.5.22
roman-numbers==1.0.2
word2number==1.1

View File

@ -127,3 +127,6 @@ umap-learn
volcengine==1.0.141 volcengine==1.0.141
readability-lxml==0.8.1 readability-lxml==0.8.1
html_text==0.6.2 html_text==0.6.2
cn2an==0.5.22
roman-numbers==1.0.2
word2number==1.1