Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-14 23:06:15 +08:00)
refactor: improve handling of leading punctuation removal (#10761)
This commit is contained in:
parent 0ba17ec116
commit 14f3d44c37
api/core/indexing_runner.py
@@ -29,6 +29,7 @@ from core.rag.splitter.fixed_text_splitter import (
     FixedRecursiveCharacterTextSplitter,
 )
 from core.rag.splitter.text_splitter import TextSplitter
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -500,11 +501,7 @@ class IndexingRunner:
                     document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:]
-                    else:
-                        page_content = page_content
-                    document_node.page_content = page_content
+                    document_node.page_content = remove_leading_symbols(page_content)
 
                     if document_node.page_content:
                         split_documents.append(document_node)
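Review note: the removed branch stripped at most one leading "." or "。", while the new helper removes the whole leading run of punctuation and symbols. A minimal sketch of the behavioural difference, using an illustrative chunk of my own (the helper itself is added later in this diff):

from core.tools.utils.text_processing_utils import remove_leading_symbols

chunk = "。。、第一章 概述"                     # illustrative splitter residue, not from the repo

# Old behaviour: only one leading "." or "。" was dropped.
old_style = chunk[1:] if chunk.startswith((".", "。")) else chunk   # -> "。、第一章 概述"

# New behaviour: the entire leading run of punctuation is dropped.
new_style = remove_leading_symbols(chunk)                           # -> "第一章 概述"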
api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -11,6 +11,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
 
@@ -43,11 +44,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
                 document_node.metadata["doc_id"] = doc_id
                 document_node.metadata["doc_hash"] = hash
                 # delete Splitter character
-                page_content = document_node.page_content
-                if page_content.startswith(".") or page_content.startswith("。"):
-                    page_content = page_content[1:].strip()
-                else:
-                    page_content = page_content
+                page_content = remove_leading_symbols(document_node.page_content).strip()
                 if len(page_content) > 0:
                     document_node.page_content = page_content
                     split_documents.append(document_node)
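Review note: this call site differs slightly from the one above; it also applies .strip() and only keeps the chunk when something is left, so whitespace-only or punctuation-only chunks are dropped here. A small check under that reading (sample strings are mine):

from core.tools.utils.text_processing_utils import remove_leading_symbols

# A punctuation-only chunk becomes empty after cleanup and is skipped.
assert remove_leading_symbols("。。。").strip() == ""

# Leading punctuation and surrounding whitespace are both removed on this path.
assert remove_leading_symbols("...  intro  ").strip() == "intro"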
api/core/rag/index_processor/processor/qa_index_processor.py
@@ -18,6 +18,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
+from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
 
@@ -53,11 +54,7 @@ class QAIndexProcessor(BaseIndexProcessor):
                     document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:]
-                    else:
-                        page_content = page_content
-                    document_node.page_content = page_content
+                    document_node.page_content = remove_leading_symbols(page_content)
                     split_documents.append(document_node)
             all_documents.extend(split_documents)
         for i in range(0, len(all_documents), 10):
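Review note: within the lines shown, this path (like the IndexingRunner one) replaces page_content in place, but unlike the paragraph processor there is no len(...) > 0 guard after cleanup, so a chunk that ends up empty would still be appended. Illustrative only, with a made-up chunk:

from core.tools.utils.text_processing_utils import remove_leading_symbols

raw_chunk = "!!!"                                  # hypothetical punctuation-only chunk
cleaned = remove_leading_symbols(raw_chunk)        # -> ""

# Paragraph processor: dropped, because len(cleaned.strip()) == 0.
# QA processor, as shown in this hunk: the node keeps the empty page_content and is appended.
print(repr(cleaned), len(cleaned.strip()) > 0)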
api/core/tools/utils/text_processing_utils.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+import re
+
+
+def remove_leading_symbols(text: str) -> str:
+    """
+    Remove leading punctuation or symbols from the given text.
+
+    Args:
+        text (str): The input text to process.
+
+    Returns:
+        str: The text with leading punctuation or symbols removed.
+    """
+    # Match Unicode ranges for punctuation and symbols
+    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+"
+    return re.sub(pattern, "", text)
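For readers unfamiliar with the Unicode ranges: the character class covers ASCII punctuation plus the General Punctuation (U+2000-U+206F), Supplemental Punctuation (U+2E00-U+2E7F), and CJK Symbols and Punctuation (U+3000-U+303F) blocks; plain whitespace and characters outside those blocks (e.g. "§" or "¿") are left alone. A standalone illustration using the same pattern, with sample strings of my own:

import re

# Same pattern as remove_leading_symbols in api/core/tools/utils/text_processing_utils.py.
PATTERN = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+"

assert re.sub(PATTERN, "", "…see below") == "see below"      # U+2026 HORIZONTAL ELLIPSIS
assert re.sub(PATTERN, "", "、然后继续") == "然后继续"          # U+3001 IDEOGRAPHIC COMMA
assert re.sub(PATTERN, "", "#!/readme") == "readme"           # "#", "!", "/" are all stripped
assert re.sub(PATTERN, "", "正文。结尾") == "正文。结尾"         # only leading symbols are touched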
api/tests/unit_tests/utils/test_text_processing.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from textwrap import dedent
+
+import pytest
+
+from core.tools.utils.text_processing_utils import remove_leading_symbols
+
+
+@pytest.mark.parametrize(
+    ("input_text", "expected_output"),
+    [
+        ("...Hello, World!", "Hello, World!"),
+        ("。测试中文标点", "测试中文标点"),
+        ("!@#Test symbols", "Test symbols"),
+        ("Hello, World!", "Hello, World!"),
+        ("", ""),
+        (" ", " "),
+    ],
+)
+def test_remove_leading_symbols(input_text, expected_output):
+    assert remove_leading_symbols(input_text) == expected_output
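The last two parametrize cases pin down the edge behaviour: the empty string passes through, and a lone space survives because whitespace is not in the helper's character class. The same point stated directly (my examples, not additional repository tests):

from core.tools.utils.text_processing_utils import remove_leading_symbols

assert remove_leading_symbols(" ") == " "            # whitespace is not stripped
assert remove_leading_symbols(" .text") == " .text"  # a leading space blocks the anchored match, so the "." stays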