From 90975460af25d31b65bc5578b9124eaa136b1c32 Mon Sep 17 00:00:00 2001 From: Zhedong Cen Date: Fri, 14 Jun 2024 15:12:39 +0800 Subject: [PATCH] Add pdf support for QA parser (#1155) ### What problem does this PR solve? Support extracting questions and answers from PDF files ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- rag/app/qa.py | 99 ++++++++++++++++++++++++++++++++++++++++---- rag/nlp/__init__.py | 92 ++++++++++++++++++++++++++++++++++++++++ requirements.txt | 3 ++ requirements_arm.txt | 5 ++- requirements_dev.txt | 5 ++- 5 files changed, 194 insertions(+), 10 deletions(-) diff --git a/rag/app/qa.py b/rag/app/qa.py index 1ecf9b187..5eabe1dde 100644 --- a/rag/app/qa.py +++ b/rag/app/qa.py @@ -13,13 +13,13 @@ import re from copy import deepcopy from io import BytesIO +from timeit import default_timer as timer from nltk import word_tokenize from openpyxl import load_workbook -from rag.nlp import is_english, random_choices, find_codec -from rag.nlp import rag_tokenizer -from deepdoc.parser import ExcelParser - - +from rag.nlp import is_english, random_choices, find_codec, qbullets_category, add_positions, has_qbullet +from rag.nlp import rag_tokenizer, tokenize_table +from rag.settings import cron_logger +from deepdoc.parser import PdfParser, ExcelParser class Excel(ExcelParser): def __call__(self, fnm, binary=None, callback=None): if not binary: @@ -62,12 +62,80 @@ class Excel(ExcelParser): [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1]) return res - +class Pdf(PdfParser): + def __call__(self, filename, binary=None, from_page=0, + to_page=100000, zoomin=3, callback=None): + start = timer() + callback(msg="OCR is running...") + self.__images__( + filename if not binary else binary, + zoomin, + from_page, + to_page, + callback + ) + callback(msg="OCR finished") + cron_logger.info("OCR({}~{}): {}".format(from_page, to_page, timer() - start)) + start = timer() + self._layouts_rec(zoomin, drop=False) + callback(0.63, "Layout analysis finished.") + self._table_transformer_job(zoomin) + callback(0.65, "Table analysis finished.") + self._text_merge() + callback(0.67, "Text merging finished") + tbls = self._extract_table_figure(True, zoomin, True, True) + #self._naive_vertical_merge() + # self._concat_downward() + #self._filter_forpages() + cron_logger.info("layouts: {}".format(timer() - start)) + sections = [b["text"] for b in self.boxes] + bull_x0_list = [] + q_bull, reg = qbullets_category(sections) + if q_bull == -1: + raise ValueError("Unable to recognize Q&A structure.") + qai_list = [] + last_q, last_a, last_tag = '', '', '' + last_index = -1 + last_box = {'text':''} + last_bull = None + for box in self.boxes: + section, line_tag = box['text'], self._line_tag(box, zoomin) + has_bull, index = has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list) + last_box, last_index, last_bull = box, index, has_bull + if not has_bull: # No question bullet + if not last_q: + continue + else: + last_a = f'{last_a}{section}' + last_tag = f'{last_tag}{line_tag}' + else: + if last_q: + qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True))) + last_q, last_a, last_tag = '', '', '' + last_q = has_bull.group() + _, end = has_bull.span() + last_a = section[end:] + last_tag = line_tag + if last_q: + qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True))) + return qai_list, tbls + def rmPrefix(txt): return re.sub( r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE) +def beAdocPdf(d, q, a, eng, image, poss): + qprefix = "Question: " if eng else "问题:" + aprefix = "Answer: " if eng else "回答:" + d["content_with_weight"] = "\t".join( + [qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) + d["content_ltks"] = rag_tokenizer.tokenize(q) + d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) + d["image"] = image + add_positions(d, poss) + return d + def beAdoc(d, q, a, eng): qprefix = "Question: " if eng else "问题:" aprefix = "Answer: " if eng else "回答:" @@ -145,6 +213,19 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) return res + elif re.search(r"\.pdf$", filename, re.IGNORECASE): + pdf_parser = Pdf() + count = 0 + qai_list, tbls = pdf_parser(filename if not binary else binary, + from_page=0, to_page=10000, callback=callback) + + res = tokenize_table(tbls, doc, eng) + + for q, a, image, poss in qai_list: + count += 1 + res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss)) + return res + raise NotImplementedError( "Excel and csv(txt) format files are supported.") @@ -153,6 +234,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): if __name__ == "__main__": import sys - def dummy(a, b): + def dummy(prog=None, msg=""): pass - chunk(sys.argv[1], callback=dummy) + import json + + chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 9cc57992c..eff298686 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -21,6 +21,9 @@ from rag.utils import num_tokens_from_string from . import rag_tokenizer import re import copy +import roman_numbers as r +from word2number import w2n +from cn2an import cn2an all_codecs = [ 'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs', @@ -57,6 +60,95 @@ def find_codec(blob): return "utf-8" +QUESTION_PATTERN = [ + r"第([零一二三四五六七八九十百0-9]+)问", + r"第([零一二三四五六七八九十百0-9]+)条", + r"[\((]([零一二三四五六七八九十百]+)[\))]", + r"第([0-9]+)问", + r"第([0-9]+)条", + r"([0-9]{1,2})[\. 、]", + r"([零一二三四五六七八九十百]+)[ 、]", + r"[\((]([0-9]{1,2})[\))]", + r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)", + r"QUESTION (I+V?|VI*|XI|IX|X)", + r"QUESTION ([0-9]+)", +] + +def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list): + section, last_section = box['text'], last_box['text'] + q_reg = r'(\w|\W)*?(?:?|\?|\n|$)+' + full_reg = reg + q_reg + has_bull = re.match(full_reg, section) + index_str = None + if has_bull: + if 'x0' not in last_box: + last_box['x0'] = box['x0'] + if 'top' not in last_box: + last_box['top'] = box['top'] + if last_bull and box['x0']-last_box['x0']>10: + return None, last_index + if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20: + return None, last_index + avg_bull_x0 = 0 + if bull_x0_list: + avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list) + else: + avg_bull_x0 = box['x0'] + if box['x0'] - avg_bull_x0 > 10: + return None, last_index + index_str = has_bull.group(1) + index = index_int(index_str) + if last_section[-1] == ':' or last_section[-1] == ':': + return None, last_index + if not last_index or index >= last_index: + bull_x0_list.append(box['x0']) + return has_bull, index + if section[-1] == '?' or section[-1] == '?': + bull_x0_list.append(box['x0']) + return has_bull, index + if box['layout_type'] == 'title': + bull_x0_list.append(box['x0']) + return has_bull, index + pure_section = section.lstrip(re.match(reg, section).group()).lower() + ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)' + if re.match(ask_reg, pure_section): + bull_x0_list.append(box['x0']) + return has_bull, index + return None, last_index + +def index_int(index_str): + res = -1 + try: + res=int(index_str) + except ValueError: + try: + res=w2n.word_to_num(index_str) + except ValueError: + try: + res = cn2an(index_str) + except ValueError: + try: + res = r.number(index_str) + except ValueError: + return -1 + return res + +def qbullets_category(sections): + global QUESTION_PATTERN + hits = [0] * len(QUESTION_PATTERN) + for i, pro in enumerate(QUESTION_PATTERN): + for sec in sections: + if re.match(pro, sec) and not not_bullet(sec): + hits[i] += 1 + break + maxium = 0 + res = -1 + for i, h in enumerate(hits): + if h <= maxium: + continue + res = i + maxium = h + return res, QUESTION_PATTERN[res] BULLET_PATTERN = [[ r"第[零一二三四五六七八九十百0-9]+(分?编|部分)", diff --git a/requirements.txt b/requirements.txt index 3127ad72a..4cc6dac50 100644 --- a/requirements.txt +++ b/requirements.txt @@ -141,3 +141,6 @@ readability-lxml==0.8.1 html_text==0.6.2 selenium==4.21.0 webdriver-manager==4.0.1 +cn2an==0.5.22 +roman-numbers==1.0.2 +word2number==1.1 \ No newline at end of file diff --git a/requirements_arm.txt b/requirements_arm.txt index a48fc0c86..c52c08cd1 100644 --- a/requirements_arm.txt +++ b/requirements_arm.txt @@ -139,4 +139,7 @@ fasttext==0.9.2 volcengine==1.0.141 opencv-python-headless==4.9.0.80 readability-lxml==0.8.1 -html_text==0.6.2 \ No newline at end of file +html_text==0.6.2 +cn2an==0.5.22 +roman-numbers==1.0.2 +word2number==1.1 \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt index 4c7062afe..68ca5afd3 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -126,4 +126,7 @@ fasttext==0.9.2 umap-learn volcengine==1.0.141 readability-lxml==0.8.1 -html_text==0.6.2 \ No newline at end of file +html_text==0.6.2 +cn2an==0.5.22 +roman-numbers==1.0.2 +word2number==1.1 \ No newline at end of file