refactor: improve handling of leading punctuation removal (#10761)

This commit is contained in:
Zane 2024-11-18 21:32:33 +08:00 committed by GitHub
parent 0ba17ec116
commit 14f3d44c37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 42 additions and 15 deletions

View File

@@ -29,6 +29,7 @@ from core.rag.splitter.fixed_text_splitter import (
     FixedRecursiveCharacterTextSplitter,
 )
 from core.rag.splitter.text_splitter import TextSplitter
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -500,11 +501,7 @@ class IndexingRunner:
             document_node.metadata["doc_hash"] = hash
             # delete Splitter character
             page_content = document_node.page_content
-            if page_content.startswith(".") or page_content.startswith("。"):
-                page_content = page_content[1:]
-            else:
-                page_content = page_content
-            document_node.page_content = page_content
+            document_node.page_content = remove_leading_symbols(page_content)
             if document_node.page_content:
                 split_documents.append(document_node)

View File

@@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
@@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
                 document_node.metadata["doc_id"] = doc_id
                 document_node.metadata["doc_hash"] = hash
                 # delete Splitter character
-                page_content = document_node.page_content
-                if page_content.startswith(".") or page_content.startswith("。"):
-                    page_content = page_content[1:].strip()
-                else:
-                    page_content = page_content
+                page_content = remove_leading_symbols(document_node.page_content).strip()
                 if len(page_content) > 0:
                     document_node.page_content = page_content
                     split_documents.append(document_node)

View File

@@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
@@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
                 document_node.metadata["doc_hash"] = hash
                 # delete Splitter character
                 page_content = document_node.page_content
-                if page_content.startswith(".") or page_content.startswith("。"):
-                    page_content = page_content[1:]
-                else:
-                    page_content = page_content
-                document_node.page_content = page_content
+                document_node.page_content = remove_leading_symbols(page_content)
                 split_documents.append(document_node)
             all_documents.extend(split_documents)
         for i in range(0, len(all_documents), 10):

View File

@ -0,0 +1,16 @@
import re
def remove_leading_symbols(text: str) -> str:
    """
    Strip any run of punctuation/symbol characters at the start of *text*.

    Covered characters: ASCII punctuation plus the Unicode General
    Punctuation (U+2000–U+206F), Supplemental Punctuation (U+2E00–U+2E7F),
    and CJK Symbols and Punctuation (U+3000–U+303F) blocks. Leading
    whitespace is NOT removed.

    Args:
        text (str): The input text to process.

    Returns:
        str: The text with leading punctuation or symbols removed.
    """
    leading = re.match(
        r"[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+",
        text,
    )
    # No leading punctuation (or empty input): return the text unchanged.
    return text[leading.end():] if leading else text

View File

@ -0,0 +1,20 @@
from textwrap import dedent
import pytest
from core.tools.utils.text_processing_utils import remove_leading_symbols
# Table-driven cases: leading ASCII and CJK punctuation is stripped, while
# punctuation embedded later in the string, punctuation-free input, empty
# input, and whitespace-only input are all returned unchanged.
@pytest.mark.parametrize(
    ("input_text", "expected_output"),
    [
        ("...Hello, World!", "Hello, World!"),
        ("。测试中文标点", "测试中文标点"),
        ("!@#Test symbols", "Test symbols"),
        ("Hello, World!", "Hello, World!"),
        ("", ""),
        (" ", " "),
    ],
)
def test_remove_leading_symbols(input_text, expected_output):
    """Verify remove_leading_symbols strips only leading punctuation/symbols."""
    assert remove_leading_symbols(input_text) == expected_output