diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b4aee2621a..7db8f54f70 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -29,6 +29,7 @@ from core.rag.splitter.fixed_text_splitter import ( FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter +from core.tools.utils.text_processing_utils import remove_leading_symbols from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -500,11 +501,7 @@ class IndexingRunner: document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content - if page_content.startswith(".") or page_content.startswith("。"): - page_content = page_content[1:] - else: - page_content = page_content - document_node.page_content = page_content + document_node.page_content = remove_leading_symbols(page_content) if document_node.page_content: split_documents.append(document_node) diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index ed5712220f..a631f953ce 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document +from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.dataset import Dataset @@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): document_node.metadata["doc_id"] = doc_id document_node.metadata["doc_hash"] = hash # delete Splitter character - page_content = document_node.page_content - if 
page_content.startswith(".") or page_content.startswith("。"): - page_content = page_content[1:].strip() - else: - page_content = page_content + page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: document_node.page_content = page_content split_documents.append(document_node) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 48e6cf7df7..320f0157a1 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document +from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.dataset import Dataset @@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor): document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content - if page_content.startswith(".") or page_content.startswith("。"): - page_content = page_content[1:] - else: - page_content = page_content - document_node.page_content = page_content + document_node.page_content = remove_leading_symbols(page_content) split_documents.append(document_node) all_documents.extend(split_documents) for i in range(0, len(all_documents), 10): diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py new file mode 100644 index 0000000000..6db9dfd0d9 --- /dev/null +++ b/api/core/tools/utils/text_processing_utils.py @@ -0,0 +1,16 @@ +import re + + +def remove_leading_symbols(text: str) -> str: + """ + Remove leading punctuation or symbols from the given text. 
+ + Args: + text (str): The input text to process. + + Returns: + str: The text with leading punctuation or symbols removed. + """ + # Match Unicode ranges for punctuation and symbols + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+" + return re.sub(pattern, "", text) diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py new file mode 100644 index 0000000000..f9d00d0b39 --- /dev/null +++ b/api/tests/unit_tests/utils/test_text_processing.py @@ -0,0 +1,18 @@ +import pytest + +from core.tools.utils.text_processing_utils import remove_leading_symbols + + +@pytest.mark.parametrize( + ("input_text", "expected_output"), + [ + ("...Hello, World!", "Hello, World!"), + ("。测试中文标点", "测试中文标点"), + ("!@#Test symbols", "Test symbols"), + ("Hello, World!", "Hello, World!"), + ("", ""), + (" ", " "), + ], +) +def test_remove_leading_symbols(input_text, expected_output): + assert remove_leading_symbols(input_text) == expected_output