From 92e932065714e7740df22f6d0d7b69886c166dd7 Mon Sep 17 00:00:00 2001
From: KevinHuSh
Date: Mon, 1 Jul 2024 15:50:24 +0800
Subject: [PATCH] upgrade laws parser of docx (#1332)

### What problem does this PR solve?

Rewrites `Docx.__call__` in rag/app/laws.py: instead of assembling a tree of
`DocxNode` objects, the parser now collects a flat list of
`(outline level, text)` pairs and merges each heading with the first
non-empty level of sub-items beneath it. `docx_question_level()` in rag/nlp
gains a `txt` local and now classifies text that matches no bullet pattern at
the deepest level (`len(BULLET_PATTERN[bull])`) instead of 0, so plain
paragraphs sort below every heading. The PR also adds a `keyword_extraction()`
helper to rag/nlp and wires it into `retrieval_test()` (api/apps/chunk_app.py)
and `chat()` (api/db/services/dialog_service.py) behind an optional `keyword`
flag: a chat model extracts the key phrase of the question, which is appended
to the query before retrieval. Runnable sketches of the new sectioning pass
and the keyword switch follow the patch.

### Type of change

- [x] Refactoring
---
 api/apps/chunk_app.py             |  6 ++-
 api/db/services/dialog_service.py |  4 +-
 rag/app/laws.py                   | 72 ++++++++++++------------------
 rag/nlp/__init__.py               | 30 ++++++++++----
 4 files changed, 59 insertions(+), 53 deletions(-)

diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index 1da5c92de..67d65b930 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -20,7 +20,7 @@ from flask_login import login_required, current_user
 from elasticsearch_dsl import Q
 
 from rag.app.qa import rmPrefix, beAdoc
-from rag.nlp import search, rag_tokenizer
+from rag.nlp import search, rag_tokenizer, keyword_extraction
 from rag.utils.es_conn import ELASTICSEARCH
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
@@ -268,6 +268,10 @@ def retrieval_test():
             rerank_mdl = TenantLLMService.model_instance(
                 kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
 
+        if req.get("keyword", False):
+            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
+            question += keyword_extraction(chat_mdl, question)
+
         ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
                                       similarity_threshold, vector_similarity_weight, top,
                                       doc_ids, rerank_mdl=rerank_mdl)
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index e417a5609..245447dbb 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -23,7 +23,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
 from api.settings import chat_logger, retrievaler
 from rag.app.resume import forbidden_select_fields4resume
-from rag.nlp.rag_tokenizer import is_chinese
+from rag.nlp import keyword_extraction
 from rag.nlp.search import index_name
 from rag.utils import rmSpace, num_tokens_from_string, encoder
 
@@ -121,6 +121,8 @@ def chat(dialog, messages, stream=True, **kwargs):
     if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
         kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
     else:
+        if prompt_config.get("keyword", False):
+            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
         kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids,
                                         1, dialog.top_n,
                                         dialog.similarity_threshold, dialog.vector_similarity_weight,
diff --git a/rag/app/laws.py b/rag/app/laws.py
index 21929d1c1..3465d5938 100644
--- a/rag/app/laws.py
+++ b/rag/app/laws.py
@@ -54,62 +54,44 @@ class Docx(DocxParser):
         self.doc = Document(
             filename) if not binary else Document(BytesIO(binary))
         pn = 0
-        last_question, last_answer, last_level = "", "", -1
         lines = []
-        root = DocxNode()
-        point = root
         bull = bullets_category([p.text for p in self.doc.paragraphs])
         for p in self.doc.paragraphs:
             if pn > to_page:
                 break
-            question_level, p_text = 0, ''
-            if from_page <= pn < to_page and p.text.strip():
-                question_level, p_text = docx_question_level(p, bull)
-            if not question_level or question_level > 6: # not a question
-                last_answer = f'{last_answer}\n{p_text}'
-            else: # is a question
-                if last_question:
-                    while last_level <= point.level:
-                        point = point.parent
-                    new_node = DocxNode(last_question, last_answer, last_level, [], point)
-                    point.childs.append(new_node)
-                    point = new_node
-                    last_question, last_answer, last_level = '', '', -1
-                last_level = question_level
-                last_answer = ''
-                last_question = p_text
-
+            question_level, p_text = docx_question_level(p, bull)
+            if not p_text.strip("\n"): continue
+            lines.append((question_level, p_text))
+
             for run in p.runs:
                 if 'lastRenderedPageBreak' in run._element.xml:
                     pn += 1
                     continue
                 if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
                     pn += 1
-        if last_question:
-            while last_level <= point.level:
-                point = point.parent
-            new_node = DocxNode(last_question, last_answer, last_level, [], point)
-            point.childs.append(new_node)
-            point = new_node
-            last_question, last_answer, last_level = '', '', -1
-        traversal_queue = [root]
-        while traversal_queue:
-            current_node: DocxNode = traversal_queue.pop()
-            sum_text = f'{self.__clean(current_node.question)}\n{self.__clean(current_node.answer)}'
-            if not current_node.childs and not current_node.answer.strip():
-                continue
-            for child in current_node.childs:
-                sum_text = f'{sum_text}\n{self.__clean(child.question)}'
-                traversal_queue.insert(0, child)
-            lines.append(self.__clean(sum_text))
-        return [l for l in lines if l]
-class DocxNode:
-    def __init__(self, question: str = '', answer: str = '', level: int = 0, childs: list = [], parent = None) -> None:
-        self.question = question
-        self.answer = answer
-        self.level = level
-        self.childs = childs
-        self.parent = parent
+
+        visit = [False for _ in range(len(lines))]
+        sections = []
+        for s in range(len(lines)):
+            e = s + 1
+            while e < len(lines):
+                if lines[e][0] <= lines[s][0]:
+                    break
+                e += 1
+            if e - s == 1 and visit[s]: continue
+            sec = []
+            next_level = lines[s][0] + 1
+            while not sec and next_level < 22:
+                for i in range(s + 1, e):
+                    if lines[i][0] != next_level: continue
+                    sec.append(lines[i][1])
+                    visit[i] = True
+                next_level += 1
+            sec.insert(0, lines[s][1])
+            sections.append("\n".join(sec))
+        return [l for l in sections if l]
 
     def __str__(self) -> str:
         return f'''
             question:{self.question},
diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py
index 8572a104a..3b5b03acc 100644
--- a/rag/nlp/__init__.py
+++ b/rag/nlp/__init__.py
@@ -514,16 +514,19 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
     return cks
 
+
 def docx_question_level(p, bull = -1):
+    txt = re.sub(r"\u3000", " ", p.text).strip()
     if p.style.name.startswith('Heading'):
-        return int(p.style.name.split(' ')[-1]), re.sub(r"\u3000", " ", p.text).strip()
+        return int(p.style.name.split(' ')[-1]), txt
     else:
         if bull < 0:
-            return 0, re.sub(r"\u3000", " ", p.text).strip()
+            return 0, txt
         for j, title in enumerate(BULLET_PATTERN[bull]):
-            if re.match(title, re.sub(r"\u3000", " ", p.text).strip()):
-                return j+1, re.sub(r"\u3000", " ", p.text).strip()
-        return 0, re.sub(r"\u3000", " ", p.text).strip()
+            if re.match(title, txt):
+                return j+1, txt
+        return len(BULLET_PATTERN[bull]), txt
 
+
 def concat_img(img1, img2):
     if img1 and not img2:
@@ -544,6 +547,7 @@ def concat_img(img1, img2):
     return new_image
 
+
 def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
     if not sections:
         return []
@@ -573,4 +577,18 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
     for sec, image in sections:
         add_chunk(sec, image, '')
 
-    return cks, images
\ No newline at end of file
+    return cks, images
+
+
+def keyword_extraction(chat_mdl, content):
+    prompt = """
+You're a question analyzer.
+1. Please give me the most important keyword/phrase of this question.
+Answer format: (in language of user's question)
+ - keyword:
+"""
+    kwd = chat_mdl.chat(prompt, [{"role": "user", "content": content}], {"temperature": 0.2})
+    # chat() returns (answer, token_count) from raw model instances but a bare
+    # answer string from LLMBundle; accept both call paths.
+    if isinstance(kwd, tuple):
+        kwd = kwd[0]
+    return kwd
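
### Notes: the new sectioning pass

A minimal standalone sketch of the merge loop this patch adds to
`Docx.__call__`, for review convenience. `build_sections` and the sample
document are invented for illustration; the loop body mirrors the added
code, which walks the flat `(outline level, text)` list and emits one
section per heading plus the first non-empty layer of sub-items under it.

```python
def build_sections(lines):
    """lines: (outline_level, text) pairs in document order."""
    visit = [False] * len(lines)
    sections = []
    for s in range(len(lines)):
        # e marks the end of lines[s]'s subtree: every following line at a
        # strictly deeper outline level belongs to it.
        e = s + 1
        while e < len(lines) and lines[e][0] > lines[s][0]:
            e += 1
        # A childless line already folded into a parent section is skipped.
        if e - s == 1 and visit[s]:
            continue
        # Attach the first non-empty layer of descendants: direct children
        # if any exist, otherwise grandchildren, and so on.
        sec, next_level = [], lines[s][0] + 1
        while not sec and next_level < 22:
            for i in range(s + 1, e):
                if lines[i][0] == next_level:
                    sec.append(lines[i][1])
                    visit[i] = True
            next_level += 1
        sec.insert(0, lines[s][1])
        sections.append("\n".join(sec))
    return [s for s in sections if s]


sample = [
    (1, "Chapter I General Provisions"),
    (2, "Article 1 ..."),
    (2, "Article 2 ..."),
    (1, "Chapter II Liability"),
    (2, "Article 3 ..."),
]
print("\n---\n".join(build_sections(sample)))
# Chapter I keeps Articles 1-2 and Chapter II keeps Article 3; the
# article-only iterations are skipped because visit[] marks them as used.
```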
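
### Notes: the `keyword` switch

With `keyword` enabled (a request field in `retrieval_test`, a prompt-config
field in `chat`), the extracted keywords are appended to the query before it
is embedded and searched. A sketch against a stub model; `StubChatModel` is
invented here so the example runs offline, while the real call sites obtain
the model from `TenantLLMService.model_instance` or `LLMBundle`:

```python
def keyword_extraction(chat_mdl, content):
    # Same helper this patch adds to rag/nlp/__init__.py.
    prompt = """
You're a question analyzer.
1. Please give me the most important keyword/phrase of this question.
Answer format: (in language of user's question)
 - keyword:
"""
    kwd = chat_mdl.chat(prompt, [{"role": "user", "content": content}],
                        {"temperature": 0.2})
    if isinstance(kwd, tuple):  # raw models return (answer, token_count)
        kwd = kwd[0]
    return kwd


class StubChatModel:
    """Offline stand-in for a tenant chat model (invented for this demo)."""
    def chat(self, system, history, gen_conf):
        return " - keyword: statutory liability", 7


question = "Which chapter covers statutory liability?"
question += keyword_extraction(StubChatModel(), question)
print(question)  # query now carries the keyword phrase fed to retrieval
```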