diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 3efcc373b1..b4ffee1f13 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -75,7 +75,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): ) -def _extract_text(*, file_content: bytes, mime_type: str) -> str: +def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: """Extract text from a file based on its MIME type.""" if mime_type.startswith("text/plain") or mime_type in {"text/html", "text/htm", "text/markdown", "text/xml"}: return _extract_text_from_plain_text(file_content) @@ -107,6 +107,33 @@ def _extract_text(*, file_content: bytes, mime_type: str) -> str: raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") +def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: + """Extract text from a file based on its file extension.""" + match file_extension: + case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml": + return _extract_text_from_plain_text(file_content) + case ".pdf": + return _extract_text_from_pdf(file_content) + case ".doc" | ".docx": + return _extract_text_from_doc(file_content) + case ".csv": + return _extract_text_from_csv(file_content) + case ".xls" | ".xlsx": + return _extract_text_from_excel(file_content) + case ".ppt": + return _extract_text_from_ppt(file_content) + case ".pptx": + return _extract_text_from_pptx(file_content) + case ".epub": + return _extract_text_from_epub(file_content) + case ".eml": + return _extract_text_from_eml(file_content) + case ".msg": + return _extract_text_from_msg(file_content) + case _: + raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}") + + def _extract_text_from_plain_text(file_content: bytes) -> str: try: return file_content.decode("utf-8") @@ -159,7 +186,10 @@ def _extract_text_from_file(file: 
def _extract_text_from_file(file: File):
    """Download a file and extract its text.

    Files transferred by remote URL are identified by MIME type first;
    all other files are identified by file extension first. When the
    preferred identifier is missing, the other one is used as a fallback
    so that a file is only rejected when neither is available.

    Args:
        file: The file to download and extract text from.

    Returns:
        The extracted text.

    Raises:
        UnsupportedFileTypeError: If neither a MIME type nor a file
            extension is available, or the detected type is unsupported.
    """
    mime_type = file.mime_type
    extension = file.extension

    if file.transfer_method == FileTransferMethod.REMOTE_URL:
        # Prefer the MIME type; fall back to the extension only when the
        # MIME type is missing and an extension is present.
        use_mime_type = mime_type is not None or not extension
    else:
        # Prefer the extension; fall back to the MIME type when the
        # extension is missing or empty.
        use_mime_type = not extension

    # Fail before downloading when the file type cannot be determined.
    # The message is kept unchanged for callers that match on it.
    if use_mime_type and mime_type is None:
        raise UnsupportedFileTypeError("Unable to determine file type: MIME type is missing")

    file_content = _download_file_content(file)
    if use_mime_type:
        return _extract_text_by_mime_type(file_content=file_content, mime_type=mime_type)
    return _extract_text_by_file_extension(file_content=file_content, file_extension=extension)
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"), + ( + "application/pdf", + b"%PDF-1.5\n%Test PDF content", + ["Mocked PDF content"], + FileTransferMethod.LOCAL_FILE, + ".pdf", + ), ( "application/vnd.openxmlformats-officedocument.wordprocessingml.document", b"PK\x03\x04", ["Mocked DOCX content"], - FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + "", ), - ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL), + ("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None), ], ) def test_run_extract_text( @@ -83,6 +90,7 @@ def test_run_extract_text( file_content, expected_text, transfer_method, + extension, monkeypatch, ): document_extractor_node.graph_runtime_state = mock_graph_runtime_state @@ -92,6 +100,7 @@ def test_run_extract_text( mock_file.transfer_method = transfer_method mock_file.related_id = "test_file_id" if transfer_method == FileTransferMethod.LOCAL_FILE else None mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None + mock_file.extension = extension mock_array_file_segment = Mock(spec=ArrayFileSegment) mock_array_file_segment.value = [mock_file]