From dcf19549cb4b3e8b8c7e82e14d2afb04ae6fcf83 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Tue, 3 Dec 2024 19:27:57 +0800 Subject: [PATCH] feat: move audio and webscraper back to dify --- api/core/agent/cot_agent_runner.py | 6 +- api/core/agent/fc_agent_runner.py | 6 +- api/core/app/apps/agent_chat/app_generator.py | 2 +- api/core/tools/__base/tool.py | 20 +- .../providers/audio/_assets/icon.svg | 3 + .../builtin_tool/providers/audio/audio.py | 6 + .../builtin_tool/providers/audio/audio.yaml | 11 + .../builtin_tool/providers/audio/tools/asr.py | 71 ++++ .../providers/audio/tools/asr.yaml | 22 ++ .../builtin_tool/providers/audio/tools/tts.py | 87 ++++ .../providers/audio/tools/tts.yaml | 22 ++ .../providers/webscraper/_assets/icon.svg | 3 + .../providers/webscraper/tools/webscraper.py | 36 ++ .../webscraper/tools/webscraper.yaml | 60 +++ .../providers/webscraper/webscraper.py | 8 + .../providers/webscraper/webscraper.yaml | 15 + api/core/tools/entities/tool_entities.py | 2 - api/core/tools/tool_engine.py | 11 +- api/core/tools/utils/message_transformer.py | 6 - api/core/tools/utils/web_reader_tool.py | 374 ++++++++++++++++++ api/core/workflow/nodes/tool/tool_node.py | 4 +- 21 files changed, 741 insertions(+), 34 deletions(-) create mode 100644 api/core/tools/builtin_tool/providers/audio/_assets/icon.svg create mode 100644 api/core/tools/builtin_tool/providers/audio/audio.py create mode 100644 api/core/tools/builtin_tool/providers/audio/audio.yaml create mode 100644 api/core/tools/builtin_tool/providers/audio/tools/asr.py create mode 100644 api/core/tools/builtin_tool/providers/audio/tools/asr.yaml create mode 100644 api/core/tools/builtin_tool/providers/audio/tools/tts.py create mode 100644 api/core/tools/builtin_tool/providers/audio/tools/tts.yaml create mode 100644 api/core/tools/builtin_tool/providers/webscraper/_assets/icon.svg create mode 100644 api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py create mode 100644 api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.yaml create mode 100644 api/core/tools/builtin_tool/providers/webscraper/webscraper.py create mode 100644 api/core/tools/builtin_tool/providers/webscraper/webscraper.yaml create mode 100644 api/core/tools/utils/web_reader_tool.py diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index efcc89f555..8b510258e8 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -309,13 +309,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): ) # publish files - for message_file_id, save_as in message_files: + for message_file_id in message_files: # publish message file self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER ) # add message file ids - message_file_ids.append(message_file_id.id) + message_file_ids.append(message_file_id) return tool_invoke_response, tool_invoke_meta diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 68042cc7ee..a63d92c1ae 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -246,13 +246,13 @@ class FunctionCallAgentRunner(BaseAgentRunner): conversation_id=self.conversation.id, ) # publish files - for message_file_id, save_as in message_files: + for message_file_id in message_files: # publish message file self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER ) # add message file ids - message_file_ids.append(message_file_id.id) + message_file_ids.append(message_file_id) tool_response = { "tool_call_id": tool_call_id, diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 2bf696cbe0..33db4beb0e 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -172,7 +172,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): target=self._generate_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore - "contexts": contextvars.copy_context(), + "context": contextvars.copy_context(), "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index e08f4f64cf..255060ef3c 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -157,7 +157,10 @@ class Tool(ABC): return parameters - def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage: + def create_image_message( + self, + image: str, + ) -> ToolInvokeMessage: """ create an image message @@ -165,7 +168,7 @@ class Tool(ABC): :return: the image message """ return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image), save_as=save_as + type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image) ) def create_file_message(self, file: "File") -> ToolInvokeMessage: @@ -173,10 +176,9 @@ class Tool(ABC): type=ToolInvokeMessage.MessageType.FILE, message=ToolInvokeMessage.FileMessage(), meta={"file": file}, - save_as="", ) - def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: + def create_link_message(self, link: str) -> ToolInvokeMessage: """ create a link message @@ -184,10 +186,10 @@ class Tool(ABC): :return: the link message """ return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link), save_as=save_as + type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link) ) - def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage: + def create_text_message(self, text: str) -> ToolInvokeMessage: """ create a text message @@ -195,10 +197,11 @@ class Tool(ABC): :return: the text message """ return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, message=ToolInvokeMessage.TextMessage(text=text), save_as=save_as + type=ToolInvokeMessage.MessageType.TEXT, + message=ToolInvokeMessage.TextMessage(text=text), ) - def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage: + def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage: """ create a blob message @@ -209,7 +212,6 @@ class Tool(ABC): type=ToolInvokeMessage.MessageType.BLOB, message=ToolInvokeMessage.BlobMessage(blob=blob), meta=meta, - save_as=save_as, ) def create_json_message(self, object: dict) -> ToolInvokeMessage: diff --git a/api/core/tools/builtin_tool/providers/audio/_assets/icon.svg b/api/core/tools/builtin_tool/providers/audio/_assets/icon.svg new file mode 100644 index 0000000000..08cc4ede66 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/audio/_assets/icon.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/api/core/tools/builtin_tool/providers/audio/audio.py b/api/core/tools/builtin_tool/providers/audio/audio.py new file mode 100644 index 0000000000..116279ad20 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/audio/audio.py @@ -0,0 +1,6 @@ +from core.tools.builtin_tool.provider import BuiltinToolProviderController + + +class AudioToolProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/tools/builtin_tool/providers/audio/audio.yaml b/api/core/tools/builtin_tool/providers/audio/audio.yaml new file mode 100644 index 0000000000..07db268dac --- /dev/null +++ b/api/core/tools/builtin_tool/providers/audio/audio.yaml @@ -0,0 +1,11 @@ +identity: + author: hjlarry + name: audio + label: + en_US: Audio + description: + en_US: A tool for tts and asr. + zh_Hans: 一个用于文本转语音和语音转文本的工具。 + icon: icon.svg + tags: + - utilities diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py new file mode 100644 index 0000000000..6af0430d01 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -0,0 +1,71 @@ +import io +from collections.abc import Generator +from typing import Any + +from core.file.enums import FileType +from core.file.file_manager import download +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption +from services.model_provider_service import ModelProviderService + + +class ASRTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: + file = tool_parameters.get("audio_file") + if file.type != FileType.AUDIO: # type: ignore + yield self.create_text_message("not a valid audio file") + return + audio_binary = io.BytesIO(download(file)) # type: ignore + audio_binary.name = "temp.mp3" + provider, model = tool_parameters.get("model").split("#") # type: ignore + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.runtime.tenant_id, + provider=provider, + model_type=ModelType.SPEECH2TEXT, + model=model, + ) + text = model_instance.invoke_speech2text( + file=audio_binary, + user=user_id, + ) + yield self.create_text_message(text) + + def get_available_models(self) -> list[tuple[str, str]]: + model_provider_service = ModelProviderService() + models = model_provider_service.get_models_by_model_type( + tenant_id=self.runtime.tenant_id, model_type="speech2text" + ) + items = [] + for provider_model in models: + provider = provider_model.provider + for model in provider_model.models: + items.append((provider, model.model)) + return items + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [] + + options = [] + for provider, model in self.get_available_models(): + option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) + options.append(option) + + parameters.append( + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="All available ASR models. You can config model in the Model Provider of Settings.", + zh_Hans="所有可用的 ASR 模型。你可以在设置中的模型供应商里配置。", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + required=True, + options=options, + ) + ) + return parameters diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.yaml b/api/core/tools/builtin_tool/providers/audio/tools/asr.yaml new file mode 100644 index 0000000000..b2c82f8086 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.yaml @@ -0,0 +1,22 @@ +identity: + name: asr + author: hjlarry + label: + en_US: Speech To Text +description: + human: + en_US: Convert audio file to text. + zh_Hans: 将音频文件转换为文本。 + llm: Convert audio file to text. +parameters: + - name: audio_file + type: file + required: true + label: + en_US: Audio File + zh_Hans: 音频文件 + human_description: + en_US: The audio file to be converted. + zh_Hans: 要转换的音频文件。 + llm_description: The audio file to be converted. + form: llm diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py new file mode 100644 index 0000000000..9d083b35b3 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -0,0 +1,87 @@ +import io +from collections.abc import Generator +from typing import Any + +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption +from services.model_provider_service import ModelProviderService + + +class TTSTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: + provider, model = tool_parameters.get("model").split("#") # type: ignore + voice = tool_parameters.get(f"voice#{provider}#{model}") + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.runtime.tenant_id, + provider=provider, + model_type=ModelType.TTS, + model=model, + ) + tts = model_instance.invoke_tts( + content_text=tool_parameters.get("text"), # type: ignore + user=user_id, + tenant_id=self.runtime.tenant_id, + voice=voice, # type: ignore + ) + buffer = io.BytesIO() + for chunk in tts: + buffer.write(chunk) + + wav_bytes = buffer.getvalue() + yield self.create_text_message("Audio generated successfully") + yield self.create_blob_message( + blob=wav_bytes, + meta={"mime_type": "audio/x-wav"}, + ) + + def get_available_models(self) -> list[tuple[str, str, list[Any]]]: + model_provider_service = ModelProviderService() + models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts") + items = [] + for provider_model in models: + provider = provider_model.provider + for model in provider_model.models: + voices = model.model_properties.get(ModelPropertyKey.VOICES, []) + items.append((provider, model.model, voices)) + return items + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [] + + options = [] + for provider, model, voices in self.get_available_models(): + option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) + options.append(option) + parameters.append( + ToolParameter( + name=f"voice#{provider}#{model}", + label=I18nObject(en_US=f"Voice of {model}({provider})"), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + options=[ + ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name"))) + for voice in voices + ], + ) + ) + + parameters.insert( + 0, + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="All available TTS models. You can config model in the Model Provider of Settings.", + zh_Hans="所有可用的 TTS 模型。你可以在设置中的模型供应商里配置。", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + required=True, + options=options, + ), + ) + return parameters diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.yaml b/api/core/tools/builtin_tool/providers/audio/tools/tts.yaml new file mode 100644 index 0000000000..36f42bd689 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.yaml @@ -0,0 +1,22 @@ +identity: + name: tts + author: hjlarry + label: + en_US: Text To Speech +description: + human: + en_US: Convert text to audio file. + zh_Hans: 将文本转换为音频文件。 + llm: Convert text to audio file. +parameters: + - name: text + type: string + required: true + label: + en_US: Text + zh_Hans: 文本 + human_description: + en_US: The text to be converted. + zh_Hans: 要转换的文本。 + llm_description: The text to be converted. + form: llm diff --git a/api/core/tools/builtin_tool/providers/webscraper/_assets/icon.svg b/api/core/tools/builtin_tool/providers/webscraper/_assets/icon.svg new file mode 100644 index 0000000000..8123199a38 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/webscraper/_assets/icon.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py new file mode 100644 index 0000000000..f356eefd09 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py @@ -0,0 +1,36 @@ +from collections.abc import Generator +from typing import Any + +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolInvokeError +from core.tools.utils.web_reader_tool import get_url + + +class WebscraperTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Generator[ToolInvokeMessage, None, None]: + """ + invoke tools + """ + try: + url = tool_parameters.get("url", "") + user_agent = tool_parameters.get("user_agent", "") + if not url: + yield self.create_text_message("Please input url") + return + + # get webpage + result = get_url(url, user_agent=user_agent) + + if tool_parameters.get("generate_summary"): + # summarize and return + yield self.create_text_message(self.summary(user_id=user_id, content=result)) + else: + # return full webpage + yield self.create_text_message(result) + except Exception as e: + raise ToolInvokeError(str(e)) diff --git a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.yaml b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.yaml new file mode 100644 index 0000000000..291798c1f2 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.yaml @@ -0,0 +1,60 @@ +identity: + name: webscraper + author: Dify + label: + en_US: Web Scraper + zh_Hans: 网页爬虫 + pt_BR: Web Scraper +description: + human: + en_US: A tool for scraping webpages. + zh_Hans: 一个用于爬取网页的工具。 + pt_BR: A tool for scraping webpages. + llm: A tool for scraping webpages. Input should be a URL. +parameters: + - name: url + type: string + required: true + label: + en_US: URL + zh_Hans: 网页链接 + pt_BR: URL + human_description: + en_US: used for linking to webpages + zh_Hans: 用于链接到网页 + pt_BR: used for linking to webpages + llm_description: url for scraping + form: llm + - name: user_agent + type: string + required: false + label: + en_US: User Agent + zh_Hans: User Agent + pt_BR: User Agent + human_description: + en_US: used for identifying the browser. + zh_Hans: 用于识别浏览器。 + pt_BR: used for identifying the browser. + form: form + default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36 + - name: generate_summary + type: boolean + required: false + label: + en_US: Whether to generate summary + zh_Hans: 是否生成摘要 + human_description: + en_US: If true, the crawler will only return the page summary content. + zh_Hans: 如果启用,爬虫将仅返回页面摘要内容。 + form: form + options: + - value: "true" + label: + en_US: "Yes" + zh_Hans: 是 + - value: "false" + label: + en_US: "No" + zh_Hans: 否 + default: "false" diff --git a/api/core/tools/builtin_tool/providers/webscraper/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py new file mode 100644 index 0000000000..9d62fb5fcb --- /dev/null +++ b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py @@ -0,0 +1,8 @@ +from typing import Any + +from core.tools.builtin_tool.provider import BuiltinToolProviderController + + +class WebscraperProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + pass diff --git a/api/core/tools/builtin_tool/providers/webscraper/webscraper.yaml b/api/core/tools/builtin_tool/providers/webscraper/webscraper.yaml new file mode 100644 index 0000000000..d6d0a0d610 --- /dev/null +++ b/api/core/tools/builtin_tool/providers/webscraper/webscraper.yaml @@ -0,0 +1,15 @@ +identity: + author: Dify + name: webscraper + label: + en_US: WebScraper + zh_Hans: 网页抓取 + pt_BR: WebScraper + description: + en_US: Web Scrapper tool kit is used to scrape web + zh_Hans: 一个用于抓取网页的工具。 + pt_BR: Web Scrapper tool kit is used to scrape web + icon: icon.svg + tags: + - productivity +credentials_for_provider: [] diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 7e0e251477..4cc0d4ae6e 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -166,7 +166,6 @@ class ToolInvokeMessage(BaseModel): """ message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None meta: dict[str, Any] | None = None - save_as: str = "" @field_validator("message", mode="before") @classmethod @@ -188,7 +187,6 @@ class ToolInvokeMessage(BaseModel): class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") - save_as: str = "" file_var: Optional[dict[str, Any]] = None diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 6e8137a8e9..c27c149bfd 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -49,7 +49,7 @@ class ToolEngine: conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, - ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]: + ) -> tuple[str, list[str], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. """ @@ -279,7 +279,6 @@ class ToolEngine: yield ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "image/jpeg"), url=cast(ToolInvokeMessage.TextMessage, response.message).text, - save_as=response.save_as, ) elif response.type == ToolInvokeMessage.MessageType.BLOB: if not response.meta: @@ -288,7 +287,6 @@ class ToolEngine: yield ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "octet/stream"), url=cast(ToolInvokeMessage.TextMessage, response.message).text, - save_as=response.save_as, ) elif response.type == ToolInvokeMessage.MessageType.LINK: # check if there is a mime type in meta @@ -296,7 +294,6 @@ class ToolEngine: yield ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream", url=cast(ToolInvokeMessage.TextMessage, response.message).text, - save_as=response.save_as, ) @staticmethod @@ -305,12 +302,12 @@ class ToolEngine: agent_message: Message, invoke_from: InvokeFrom, user_id: str, - ) -> list[tuple[MessageFile, str]]: + ) -> list[str]: """ Create message file :param messages: messages - :return: message files, should save as variable + :return: message file ids """ result = [] @@ -347,7 +344,7 @@ class ToolEngine: db.session.commit() db.session.refresh(message_file) - result.append((message_file.id, message.save_as)) + result.append(message_file.id) db.session.close() diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 2385aa9b5b..09a7ef9d46 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -44,7 +44,6 @@ class ToolFileMessageTransformer: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, ) except Exception as e: @@ -54,7 +53,6 @@ class ToolFileMessageTransformer: text=f"Failed to download image: {message.message.text}: {e}" ), meta=message.meta.copy() if message.meta is not None else {}, - save_as=message.save_as, ) elif message.type == ToolInvokeMessage.MessageType.BLOB: # get mime type and save blob to storage @@ -83,14 +81,12 @@ class ToolFileMessageTransformer: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=url), - save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, ) elif message.type == ToolInvokeMessage.MessageType.FILE: @@ -104,14 +100,12 @@ class ToolFileMessageTransformer: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=url), - save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, ) else: diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py new file mode 100644 index 0000000000..3aae31e93a --- /dev/null +++ b/api/core/tools/utils/web_reader_tool.py @@ -0,0 +1,374 @@ +import hashlib +import json +import mimetypes +import os +import re +import site +import subprocess +import tempfile +import unicodedata +from contextlib import contextmanager +from pathlib import Path +from typing import Optional +from urllib.parse import unquote + +import chardet +import cloudscraper +from bs4 import BeautifulSoup, CData, Comment, NavigableString +from regex import regex + +from core.helper import ssrf_proxy +from core.rag.extractor import extract_processor +from core.rag.extractor.extract_processor import ExtractProcessor + +FULL_TEMPLATE = """ +TITLE: {title} +AUTHORS: {authors} +PUBLISH DATE: {publish_date} +TOP_IMAGE_URL: {top_image} +TEXT: + +{text} +""" + + +def page_result(text: str, cursor: int, max_length: int) -> str: + """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" + return text[cursor : cursor + max_length] + + +def get_url(url: str, user_agent: Optional[str] = None) -> str: + """Fetch URL and return the contents as a string.""" + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/91.0.4472.124 Safari/537.36" + } + if user_agent: + headers["User-Agent"] = user_agent + + main_content_type = None + supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] + response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) + + if response.status_code == 200: + # check content-type + content_type = response.headers.get("Content-Type") + if content_type: + main_content_type = response.headers.get("Content-Type").split(";")[0].strip() + else: + content_disposition = response.headers.get("Content-Disposition", "") + filename_match = re.search(r'filename="([^"]+)"', content_disposition) + if filename_match: + filename = unquote(filename_match.group(1)) + extension = re.search(r"\.(\w+)$", filename) + if extension: + main_content_type = mimetypes.guess_type(filename)[0] + + if main_content_type not in supported_content_types: + return "Unsupported content-type [{}] of URL.".format(main_content_type) + + if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: + return ExtractProcessor.load_from_url(url, return_text=True) + + response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) + elif response.status_code == 403: + scraper = cloudscraper.create_scraper() + scraper.perform_request = ssrf_proxy.make_request + response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) + + if response.status_code != 200: + return "URL returned status code {}.".format(response.status_code) + + # Detect encoding using chardet + detected_encoding = chardet.detect(response.content) + encoding = detected_encoding["encoding"] + if encoding: + try: + content = response.content.decode(encoding) + except (UnicodeDecodeError, TypeError): + content = response.text + else: + content = response.text + + a = extract_using_readabilipy(content) + + if not a["plain_text"] or not a["plain_text"].strip(): + return "" + + res = FULL_TEMPLATE.format( + title=a["title"], + authors=a["byline"], + publish_date=a["date"], + top_image="", + text=a["plain_text"] or "", + ) + + return res + + +def extract_using_readabilipy(html): + with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: + f_html.write(html) + f_html.close() + html_path = f_html.name + + # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file + article_json_path = html_path + ".json" + jsdir = os.path.join(find_module_path("readabilipy"), "javascript") + with chdir(jsdir): + subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) + + # Read output of call to Readability.parse() from JSON file and return as Python dictionary + input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8")) + + # Deleting files after processing + os.unlink(article_json_path) + os.unlink(html_path) + + article_json = { + "title": None, + "byline": None, + "date": None, + "content": None, + "plain_content": None, + "plain_text": None, + } + # Populate article fields from readability fields where present + if input_json: + if input_json.get("title"): + article_json["title"] = input_json["title"] + if input_json.get("byline"): + article_json["byline"] = input_json["byline"] + if input_json.get("date"): + article_json["date"] = input_json["date"] + if input_json.get("content"): + article_json["content"] = input_json["content"] + article_json["plain_content"] = plain_content(article_json["content"], False, False) + article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) + if input_json.get("textContent"): + article_json["plain_text"] = input_json["textContent"] + article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) + + return article_json + + +def find_module_path(module_name): + for package_path in site.getsitepackages(): + potential_path = os.path.join(package_path, module_name) + if os.path.exists(potential_path): + return potential_path + + return None + + +@contextmanager +def chdir(path): + """Change directory in context and return to original on exit""" + # From https://stackoverflow.com/a/37996581, couldn't find a built-in + original_path = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(original_path) + + +def extract_text_blocks_as_plain_text(paragraph_html): + # Load article as DOM + soup = BeautifulSoup(paragraph_html, "html.parser") + # Select all lists + list_elements = soup.find_all(["ul", "ol"]) + # Prefix text in all list items with "* " and make lists paragraphs + for list_element in list_elements: + plain_items = "".join( + list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) + ) + list_element.string = plain_items + list_element.name = "p" + # Select all text blocks + text_blocks = [s.parent for s in soup.find_all(string=True)] + text_blocks = [plain_text_leaf_node(block) for block in text_blocks] + # Drop empty paragraphs + text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) + return text_blocks + + +def plain_text_leaf_node(element): + # Extract all text, stripped of any child HTML elements and normalize it + plain_text = normalize_text(element.get_text()) + if plain_text != "" and element.name == "li": + plain_text = "* {}, ".format(plain_text) + if plain_text == "": + plain_text = None + if "data-node-index" in element.attrs: + plain = {"node_index": element["data-node-index"], "text": plain_text} + else: + plain = {"text": plain_text} + return plain + + +def plain_content(readability_content, content_digests, node_indexes): + # Load article as DOM + soup = BeautifulSoup(readability_content, "html.parser") + # Make all elements plain + elements = plain_elements(soup.contents, content_digests, node_indexes) + if node_indexes: + # Add node index attributes to nodes + elements = [add_node_indexes(element) for element in elements] + # Replace article contents with plain elements + soup.contents = elements + return str(soup) + + +def plain_elements(elements, content_digests, node_indexes): + # Get plain content versions of all elements + elements = [plain_element(element, content_digests, node_indexes) for element in elements] + if content_digests: + # Add content digest attribute to nodes + elements = [add_content_digest(element) for element in elements] + return elements + + +def plain_element(element, content_digests, node_indexes): + # For lists, we make each item plain text + if is_leaf(element): + # For leaf node elements, extract the text content, discarding any HTML tags + # 1. Get element contents as text + plain_text = element.get_text() + # 2. Normalize the extracted text string to a canonical representation + plain_text = normalize_text(plain_text) + # 3. Update element content to be plain text + element.string = plain_text + elif is_text(element): + if is_non_printing(element): + # The simplified HTML may have come from Readability.js so might + # have non-printing text (e.g. Comment or CData). In this case, we + # keep the structure, but ensure that the string is empty. + element = type(element)("") + else: + plain_text = element.string + plain_text = normalize_text(plain_text) + element = type(element)(plain_text) + else: + # If not a leaf node or leaf type call recursively on child nodes, replacing + element.contents = plain_elements(element.contents, content_digests, node_indexes) + return element + + +def add_node_indexes(element, node_index="0"): + # Can't add attributes to string types + if is_text(element): + return element + # Add index to current element + element["data-node-index"] = node_index + # Add index to child elements + for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): + # Can't add attributes to leaf string types + child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) + add_node_indexes(child, node_index=child_index) + return element + + +def normalize_text(text): + """Normalize unicode and whitespace.""" + # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them + text = strip_control_characters(text) + text = normalize_unicode(text) + text = normalize_whitespace(text) + return text + + +def strip_control_characters(text): + """Strip out unicode control characters which might break the parsing.""" + # Unicode control characters + # [Cc]: Other, Control [includes new lines] + # [Cf]: Other, Format + # [Cn]: Other, Not Assigned + # [Co]: Other, Private Use + # [Cs]: Other, Surrogate + control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} + retained_chars = ["\t", "\n", "\r", "\f"] + + # Remove non-printing control characters + return "".join( + [ + "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char + for char in text + ] + ) + + +def normalize_unicode(text): + """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" + normal_form = "NFKC" + text = unicodedata.normalize(normal_form, text) + return text + + +def normalize_whitespace(text): + """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" + text = regex.sub(r"\s+", " ", text) + # Remove leading and trailing whitespace + text = text.strip() + return text + + +def is_leaf(element): + return element.name in {"p", "li"} + + +def is_text(element): + return isinstance(element, NavigableString) + + +def is_non_printing(element): + return any(isinstance(element, _e) for _e in [Comment, CData]) + + +def add_content_digest(element): + if not is_text(element): + element["data-content-digest"] = content_digest(element) + return element + + +def content_digest(element): + if is_text(element): + # Hash + trimmed_string = element.string.strip() + if trimmed_string == "": + digest = "" + else: + digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() + else: + contents = element.contents + num_contents = len(contents) + if num_contents == 0: + # No hash when no child elements exist + digest = "" + elif num_contents == 1: + # If single child, use digest of child + digest = content_digest(contents[0]) + else: + # Build content digest from the "non-empty" digests of child nodes + digest = hashlib.sha256() + child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) + for child in child_digests: + digest.update(child.encode("utf-8")) + digest = digest.hexdigest() + return digest + + +def get_image_upload_file_ids(content): + pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" + matches = re.findall(pattern, content) + image_upload_file_ids = [] + for match in matches: + if match[1] == "file-preview": + content_pattern = r"files/([^/]+)/file-preview" + else: + content_pattern = r"files/([^/]+)/image-preview" + content_match = re.search(content_pattern, match[0]) + if content_match: + image_upload_file_id = content_match.group(1) + image_upload_file_ids.append(image_upload_file_id) + return image_upload_file_ids diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 7354086b03..63858323bd 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,4 @@ from collections.abc import Generator, Mapping, Sequence -from os import path from typing import Any, cast from sqlalchemy import select @@ -236,8 +235,7 @@ class ToolNode(BaseNode[ToolNodeData]): type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=tool_file_id, - filename=message.save_as, - extension=path.splitext(message.save_as)[1], + extension=None, mime_type=message.meta.get("mime_type", "application/octet-stream"), ) )