diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index efcc89f555..8b510258e8 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -309,13 +309,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
)
# publish files
- for message_file_id, save_as in message_files:
+ for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
- QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
+ QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
- message_file_ids.append(message_file_id.id)
+ message_file_ids.append(message_file_id)
return tool_invoke_response, tool_invoke_meta
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 68042cc7ee..a63d92c1ae 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -246,13 +246,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
conversation_id=self.conversation.id,
)
# publish files
- for message_file_id, save_as in message_files:
+ for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
- QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
+ QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
- message_file_ids.append(message_file_id.id)
+ message_file_ids.append(message_file_id)
tool_response = {
"tool_call_id": tool_call_id,
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index 2bf696cbe0..33db4beb0e 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -172,7 +172,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
- "contexts": contextvars.copy_context(),
+ "context": contextvars.copy_context(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py
index e08f4f64cf..255060ef3c 100644
--- a/api/core/tools/__base/tool.py
+++ b/api/core/tools/__base/tool.py
@@ -157,7 +157,10 @@ class Tool(ABC):
return parameters
- def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage:
+ def create_image_message(
+ self,
+ image: str,
+ ) -> ToolInvokeMessage:
"""
create an image message
@@ -165,7 +168,7 @@ class Tool(ABC):
:return: the image message
"""
return ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image), save_as=save_as
+ type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
@@ -173,10 +176,9 @@ class Tool(ABC):
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(),
meta={"file": file},
- save_as="",
)
- def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
+ def create_link_message(self, link: str) -> ToolInvokeMessage:
"""
create a link message
@@ -184,10 +186,10 @@ class Tool(ABC):
:return: the link message
"""
return ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link), save_as=save_as
+ type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
)
- def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage:
+ def create_text_message(self, text: str) -> ToolInvokeMessage:
"""
create a text message
@@ -195,10 +197,11 @@ class Tool(ABC):
:return: the text message
"""
return ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.TEXT, message=ToolInvokeMessage.TextMessage(text=text), save_as=save_as
+ type=ToolInvokeMessage.MessageType.TEXT,
+ message=ToolInvokeMessage.TextMessage(text=text),
)
- def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage:
+ def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
"""
create a blob message
@@ -209,7 +212,6 @@ class Tool(ABC):
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=blob),
meta=meta,
- save_as=save_as,
)
def create_json_message(self, object: dict) -> ToolInvokeMessage:
diff --git a/api/core/tools/builtin_tool/providers/audio/_assets/icon.svg b/api/core/tools/builtin_tool/providers/audio/_assets/icon.svg
new file mode 100644
index 0000000000..08cc4ede66
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/audio/_assets/icon.svg
@@ -0,0 +1,3 @@
+
\ No newline at end of file
diff --git a/api/core/tools/builtin_tool/providers/audio/audio.py b/api/core/tools/builtin_tool/providers/audio/audio.py
new file mode 100644
index 0000000000..116279ad20
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/audio/audio.py
@@ -0,0 +1,6 @@
+from core.tools.builtin_tool.provider import BuiltinToolProviderController
+
+
+class AudioToolProvider(BuiltinToolProviderController):
+ def _validate_credentials(self, credentials: dict) -> None:
+ pass
diff --git a/api/core/tools/builtin_tool/providers/audio/audio.yaml b/api/core/tools/builtin_tool/providers/audio/audio.yaml
new file mode 100644
index 0000000000..07db268dac
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/audio/audio.yaml
@@ -0,0 +1,11 @@
+identity:
+ author: hjlarry
+ name: audio
+ label:
+ en_US: Audio
+ description:
+ en_US: A tool for tts and asr.
+ zh_Hans: 一个用于文本转语音和语音转文本的工具。
+ icon: icon.svg
+ tags:
+ - utilities
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
new file mode 100644
index 0000000000..6af0430d01
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
@@ -0,0 +1,71 @@
+import io
+from collections.abc import Generator
+from typing import Any
+
+from core.file.enums import FileType
+from core.file.file_manager import download
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.tools.builtin_tool.tool import BuiltinTool
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
+from services.model_provider_service import ModelProviderService
+
+
+class ASRTool(BuiltinTool):
+ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
+ file = tool_parameters.get("audio_file")
+ if file.type != FileType.AUDIO: # type: ignore
+ yield self.create_text_message("not a valid audio file")
+ return
+ audio_binary = io.BytesIO(download(file)) # type: ignore
+ audio_binary.name = "temp.mp3"
+ provider, model = tool_parameters.get("model").split("#") # type: ignore
+ model_manager = ModelManager()
+ model_instance = model_manager.get_model_instance(
+ tenant_id=self.runtime.tenant_id,
+ provider=provider,
+ model_type=ModelType.SPEECH2TEXT,
+ model=model,
+ )
+ text = model_instance.invoke_speech2text(
+ file=audio_binary,
+ user=user_id,
+ )
+ yield self.create_text_message(text)
+
+ def get_available_models(self) -> list[tuple[str, str]]:
+ model_provider_service = ModelProviderService()
+ models = model_provider_service.get_models_by_model_type(
+ tenant_id=self.runtime.tenant_id, model_type="speech2text"
+ )
+ items = []
+ for provider_model in models:
+ provider = provider_model.provider
+ for model in provider_model.models:
+ items.append((provider, model.model))
+ return items
+
+ def get_runtime_parameters(self) -> list[ToolParameter]:
+ parameters = []
+
+ options = []
+ for provider, model in self.get_available_models():
+ option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
+ options.append(option)
+
+ parameters.append(
+ ToolParameter(
+ name="model",
+ label=I18nObject(en_US="Model", zh_Hans="Model"),
+ human_description=I18nObject(
+ en_US="All available ASR models. You can config model in the Model Provider of Settings.",
+ zh_Hans="所有可用的 ASR 模型。你可以在设置中的模型供应商里配置。",
+ ),
+ type=ToolParameter.ToolParameterType.SELECT,
+ form=ToolParameter.ToolParameterForm.FORM,
+ required=True,
+ options=options,
+ )
+ )
+ return parameters
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.yaml b/api/core/tools/builtin_tool/providers/audio/tools/asr.yaml
new file mode 100644
index 0000000000..b2c82f8086
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.yaml
@@ -0,0 +1,22 @@
+identity:
+ name: asr
+ author: hjlarry
+ label:
+ en_US: Speech To Text
+description:
+ human:
+ en_US: Convert audio file to text.
+ zh_Hans: 将音频文件转换为文本。
+ llm: Convert audio file to text.
+parameters:
+ - name: audio_file
+ type: file
+ required: true
+ label:
+ en_US: Audio File
+ zh_Hans: 音频文件
+ human_description:
+ en_US: The audio file to be converted.
+ zh_Hans: 要转换的音频文件。
+ llm_description: The audio file to be converted.
+ form: llm
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
new file mode 100644
index 0000000000..9d083b35b3
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
@@ -0,0 +1,87 @@
+import io
+from collections.abc import Generator
+from typing import Any
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+from core.tools.builtin_tool.tool import BuiltinTool
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
+from services.model_provider_service import ModelProviderService
+
+
+class TTSTool(BuiltinTool):
+ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
+ provider, model = tool_parameters.get("model").split("#") # type: ignore
+ voice = tool_parameters.get(f"voice#{provider}#{model}")
+ model_manager = ModelManager()
+ model_instance = model_manager.get_model_instance(
+ tenant_id=self.runtime.tenant_id,
+ provider=provider,
+ model_type=ModelType.TTS,
+ model=model,
+ )
+ tts = model_instance.invoke_tts(
+ content_text=tool_parameters.get("text"), # type: ignore
+ user=user_id,
+ tenant_id=self.runtime.tenant_id,
+ voice=voice, # type: ignore
+ )
+ buffer = io.BytesIO()
+ for chunk in tts:
+ buffer.write(chunk)
+
+ wav_bytes = buffer.getvalue()
+ yield self.create_text_message("Audio generated successfully")
+ yield self.create_blob_message(
+ blob=wav_bytes,
+ meta={"mime_type": "audio/x-wav"},
+ )
+
+ def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
+ model_provider_service = ModelProviderService()
+ models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
+ items = []
+ for provider_model in models:
+ provider = provider_model.provider
+ for model in provider_model.models:
+ voices = model.model_properties.get(ModelPropertyKey.VOICES, [])
+ items.append((provider, model.model, voices))
+ return items
+
+ def get_runtime_parameters(self) -> list[ToolParameter]:
+ parameters = []
+
+ options = []
+ for provider, model, voices in self.get_available_models():
+ option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
+ options.append(option)
+ parameters.append(
+ ToolParameter(
+ name=f"voice#{provider}#{model}",
+ label=I18nObject(en_US=f"Voice of {model}({provider})"),
+ type=ToolParameter.ToolParameterType.SELECT,
+ form=ToolParameter.ToolParameterForm.FORM,
+ options=[
+ ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
+ for voice in voices
+ ],
+ )
+ )
+
+ parameters.insert(
+ 0,
+ ToolParameter(
+ name="model",
+ label=I18nObject(en_US="Model", zh_Hans="Model"),
+ human_description=I18nObject(
+ en_US="All available TTS models. You can config model in the Model Provider of Settings.",
+ zh_Hans="所有可用的 TTS 模型。你可以在设置中的模型供应商里配置。",
+ ),
+ type=ToolParameter.ToolParameterType.SELECT,
+ form=ToolParameter.ToolParameterForm.FORM,
+ required=True,
+ options=options,
+ ),
+ )
+ return parameters
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.yaml b/api/core/tools/builtin_tool/providers/audio/tools/tts.yaml
new file mode 100644
index 0000000000..36f42bd689
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.yaml
@@ -0,0 +1,22 @@
+identity:
+ name: tts
+ author: hjlarry
+ label:
+ en_US: Text To Speech
+description:
+ human:
+ en_US: Convert text to audio file.
+ zh_Hans: 将文本转换为音频文件。
+ llm: Convert text to audio file.
+parameters:
+ - name: text
+ type: string
+ required: true
+ label:
+ en_US: Text
+ zh_Hans: 文本
+ human_description:
+ en_US: The text to be converted.
+ zh_Hans: 要转换的文本。
+ llm_description: The text to be converted.
+ form: llm
diff --git a/api/core/tools/builtin_tool/providers/webscraper/_assets/icon.svg b/api/core/tools/builtin_tool/providers/webscraper/_assets/icon.svg
new file mode 100644
index 0000000000..8123199a38
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/webscraper/_assets/icon.svg
@@ -0,0 +1,3 @@
+
\ No newline at end of file
diff --git a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py
new file mode 100644
index 0000000000..f356eefd09
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py
@@ -0,0 +1,36 @@
+from collections.abc import Generator
+from typing import Any
+
+from core.tools.builtin_tool.tool import BuiltinTool
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.errors import ToolInvokeError
+from core.tools.utils.web_reader_tool import get_url
+
+
+class WebscraperTool(BuiltinTool):
+ def _invoke(
+ self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Generator[ToolInvokeMessage, None, None]:
+ """
+ invoke tools
+ """
+ try:
+ url = tool_parameters.get("url", "")
+ user_agent = tool_parameters.get("user_agent", "")
+ if not url:
+ yield self.create_text_message("Please input url")
+ return
+
+ # get webpage
+ result = get_url(url, user_agent=user_agent)
+
+ if tool_parameters.get("generate_summary"):
+ # summarize and return
+ yield self.create_text_message(self.summary(user_id=user_id, content=result))
+ else:
+ # return full webpage
+ yield self.create_text_message(result)
+ except Exception as e:
+ raise ToolInvokeError(str(e))
diff --git a/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.yaml b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.yaml
new file mode 100644
index 0000000000..291798c1f2
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.yaml
@@ -0,0 +1,60 @@
+identity:
+ name: webscraper
+ author: Dify
+ label:
+ en_US: Web Scraper
+ zh_Hans: 网页爬虫
+ pt_BR: Web Scraper
+description:
+ human:
+ en_US: A tool for scraping webpages.
+ zh_Hans: 一个用于爬取网页的工具。
+ pt_BR: A tool for scraping webpages.
+ llm: A tool for scraping webpages. Input should be a URL.
+parameters:
+ - name: url
+ type: string
+ required: true
+ label:
+ en_US: URL
+ zh_Hans: 网页链接
+ pt_BR: URL
+ human_description:
+ en_US: used for linking to webpages
+ zh_Hans: 用于链接到网页
+ pt_BR: used for linking to webpages
+ llm_description: url for scraping
+ form: llm
+ - name: user_agent
+ type: string
+ required: false
+ label:
+ en_US: User Agent
+ zh_Hans: User Agent
+ pt_BR: User Agent
+ human_description:
+ en_US: used for identifying the browser.
+ zh_Hans: 用于识别浏览器。
+ pt_BR: used for identifying the browser.
+ form: form
+ default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36
+ - name: generate_summary
+ type: boolean
+ required: false
+ label:
+ en_US: Whether to generate summary
+ zh_Hans: 是否生成摘要
+ human_description:
+ en_US: If true, the crawler will only return the page summary content.
+ zh_Hans: 如果启用,爬虫将仅返回页面摘要内容。
+ form: form
+ options:
+ - value: "true"
+ label:
+ en_US: "Yes"
+ zh_Hans: 是
+ - value: "false"
+ label:
+ en_US: "No"
+ zh_Hans: 否
+ default: "false"
diff --git a/api/core/tools/builtin_tool/providers/webscraper/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py
new file mode 100644
index 0000000000..9d62fb5fcb
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py
@@ -0,0 +1,8 @@
+from typing import Any
+
+from core.tools.builtin_tool.provider import BuiltinToolProviderController
+
+
+class WebscraperProvider(BuiltinToolProviderController):
+ def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+ pass
diff --git a/api/core/tools/builtin_tool/providers/webscraper/webscraper.yaml b/api/core/tools/builtin_tool/providers/webscraper/webscraper.yaml
new file mode 100644
index 0000000000..d6d0a0d610
--- /dev/null
+++ b/api/core/tools/builtin_tool/providers/webscraper/webscraper.yaml
@@ -0,0 +1,15 @@
+identity:
+ author: Dify
+ name: webscraper
+ label:
+ en_US: WebScraper
+ zh_Hans: 网页抓取
+ pt_BR: WebScraper
+ description:
+ en_US: Web Scrapper tool kit is used to scrape web
+ zh_Hans: 一个用于抓取网页的工具。
+ pt_BR: Web Scrapper tool kit is used to scrape web
+ icon: icon.svg
+ tags:
+ - productivity
+credentials_for_provider: []
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 7e0e251477..4cc0d4ae6e 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -166,7 +166,6 @@ class ToolInvokeMessage(BaseModel):
"""
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None
meta: dict[str, Any] | None = None
- save_as: str = ""
@field_validator("message", mode="before")
@classmethod
@@ -188,7 +187,6 @@ class ToolInvokeMessage(BaseModel):
class ToolInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")
- save_as: str = ""
file_var: Optional[dict[str, Any]] = None
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 6e8137a8e9..c27c149bfd 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -49,7 +49,7 @@ class ToolEngine:
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
- ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
+ ) -> tuple[str, list[str], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
@@ -279,7 +279,6 @@ class ToolEngine:
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "image/jpeg"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
- save_as=response.save_as,
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
if not response.meta:
@@ -288,7 +287,6 @@ class ToolEngine:
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "octet/stream"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
- save_as=response.save_as,
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
@@ -296,7 +294,6 @@ class ToolEngine:
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream",
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
- save_as=response.save_as,
)
@staticmethod
@@ -305,12 +302,12 @@ class ToolEngine:
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str,
- ) -> list[tuple[MessageFile, str]]:
+ ) -> list[str]:
"""
Create message file
:param messages: messages
- :return: message files, should save as variable
+ :return: message file ids
"""
result = []
@@ -347,7 +344,7 @@ class ToolEngine:
db.session.commit()
db.session.refresh(message_file)
- result.append((message_file.id, message.save_as))
+ result.append(message_file.id)
db.session.close()
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index 2385aa9b5b..09a7ef9d46 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -44,7 +44,6 @@ class ToolFileMessageTransformer:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
- save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
@@ -54,7 +53,6 @@ class ToolFileMessageTransformer:
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
- save_as=message.save_as,
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
@@ -83,14 +81,12 @@ class ToolFileMessageTransformer:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
- save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text=url),
- save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
@@ -104,14 +100,12 @@ class ToolFileMessageTransformer:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
- save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text=url),
- save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
else:
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
new file mode 100644
index 0000000000..3aae31e93a
--- /dev/null
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -0,0 +1,374 @@
+import hashlib
+import json
+import mimetypes
+import os
+import re
+import site
+import subprocess
+import tempfile
+import unicodedata
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Optional
+from urllib.parse import unquote
+
+import chardet
+import cloudscraper
+from bs4 import BeautifulSoup, CData, Comment, NavigableString
+from regex import regex
+
+from core.helper import ssrf_proxy
+from core.rag.extractor import extract_processor
+from core.rag.extractor.extract_processor import ExtractProcessor
+
+FULL_TEMPLATE = """
+TITLE: {title}
+AUTHORS: {authors}
+PUBLISH DATE: {publish_date}
+TOP_IMAGE_URL: {top_image}
+TEXT:
+
+{text}
+"""
+
+
+def page_result(text: str, cursor: int, max_length: int) -> str:
+ """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
+ return text[cursor : cursor + max_length]
+
+
+def get_url(url: str, user_agent: Optional[str] = None) -> str:
+ """Fetch URL and return the contents as a string."""
+ headers = {
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
+ " Chrome/91.0.4472.124 Safari/537.36"
+ }
+ if user_agent:
+ headers["User-Agent"] = user_agent
+
+ main_content_type = None
+ supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
+ response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
+
+ if response.status_code == 200:
+ # check content-type
+ content_type = response.headers.get("Content-Type")
+ if content_type:
+ main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
+ else:
+ content_disposition = response.headers.get("Content-Disposition", "")
+ filename_match = re.search(r'filename="([^"]+)"', content_disposition)
+ if filename_match:
+ filename = unquote(filename_match.group(1))
+ extension = re.search(r"\.(\w+)$", filename)
+ if extension:
+ main_content_type = mimetypes.guess_type(filename)[0]
+
+ if main_content_type not in supported_content_types:
+ return "Unsupported content-type [{}] of URL.".format(main_content_type)
+
+ if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
+ return ExtractProcessor.load_from_url(url, return_text=True)
+
+ response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
+ elif response.status_code == 403:
+ scraper = cloudscraper.create_scraper()
+ scraper.perform_request = ssrf_proxy.make_request
+ response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
+
+ if response.status_code != 200:
+ return "URL returned status code {}.".format(response.status_code)
+
+ # Detect encoding using chardet
+ detected_encoding = chardet.detect(response.content)
+ encoding = detected_encoding["encoding"]
+ if encoding:
+ try:
+ content = response.content.decode(encoding)
+ except (UnicodeDecodeError, TypeError):
+ content = response.text
+ else:
+ content = response.text
+
+ a = extract_using_readabilipy(content)
+
+ if not a["plain_text"] or not a["plain_text"].strip():
+ return ""
+
+ res = FULL_TEMPLATE.format(
+ title=a["title"],
+ authors=a["byline"],
+ publish_date=a["date"],
+ top_image="",
+ text=a["plain_text"] or "",
+ )
+
+ return res
+
+
+def extract_using_readabilipy(html):
+ with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
+ f_html.write(html)
+ f_html.close()
+ html_path = f_html.name
+
+ # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
+ article_json_path = html_path + ".json"
+ jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
+ with chdir(jsdir):
+ subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
+
+ # Read output of call to Readability.parse() from JSON file and return as Python dictionary
+ input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
+
+ # Deleting files after processing
+ os.unlink(article_json_path)
+ os.unlink(html_path)
+
+ article_json = {
+ "title": None,
+ "byline": None,
+ "date": None,
+ "content": None,
+ "plain_content": None,
+ "plain_text": None,
+ }
+ # Populate article fields from readability fields where present
+ if input_json:
+ if input_json.get("title"):
+ article_json["title"] = input_json["title"]
+ if input_json.get("byline"):
+ article_json["byline"] = input_json["byline"]
+ if input_json.get("date"):
+ article_json["date"] = input_json["date"]
+ if input_json.get("content"):
+ article_json["content"] = input_json["content"]
+ article_json["plain_content"] = plain_content(article_json["content"], False, False)
+ article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
+ if input_json.get("textContent"):
+ article_json["plain_text"] = input_json["textContent"]
+ article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
+
+ return article_json
+
+
+def find_module_path(module_name):
+ for package_path in site.getsitepackages():
+ potential_path = os.path.join(package_path, module_name)
+ if os.path.exists(potential_path):
+ return potential_path
+
+ return None
+
+
+@contextmanager
+def chdir(path):
+ """Change directory in context and return to original on exit"""
+ # From https://stackoverflow.com/a/37996581, couldn't find a built-in
+ original_path = os.getcwd()
+ os.chdir(path)
+ try:
+ yield
+ finally:
+ os.chdir(original_path)
+
+
+def extract_text_blocks_as_plain_text(paragraph_html):
+ # Load article as DOM
+ soup = BeautifulSoup(paragraph_html, "html.parser")
+ # Select all lists
+ list_elements = soup.find_all(["ul", "ol"])
+ # Prefix text in all list items with "* " and make lists paragraphs
+ for list_element in list_elements:
+ plain_items = "".join(
+ list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
+ )
+ list_element.string = plain_items
+ list_element.name = "p"
+ # Select all text blocks
+ text_blocks = [s.parent for s in soup.find_all(string=True)]
+ text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
+ # Drop empty paragraphs
+ text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
+ return text_blocks
+
+
+def plain_text_leaf_node(element):
+ # Extract all text, stripped of any child HTML elements and normalize it
+ plain_text = normalize_text(element.get_text())
+ if plain_text != "" and element.name == "li":
+ plain_text = "* {}, ".format(plain_text)
+ if plain_text == "":
+ plain_text = None
+ if "data-node-index" in element.attrs:
+ plain = {"node_index": element["data-node-index"], "text": plain_text}
+ else:
+ plain = {"text": plain_text}
+ return plain
+
+
+def plain_content(readability_content, content_digests, node_indexes):
+ # Load article as DOM
+ soup = BeautifulSoup(readability_content, "html.parser")
+ # Make all elements plain
+ elements = plain_elements(soup.contents, content_digests, node_indexes)
+ if node_indexes:
+ # Add node index attributes to nodes
+ elements = [add_node_indexes(element) for element in elements]
+ # Replace article contents with plain elements
+ soup.contents = elements
+ return str(soup)
+
+
+def plain_elements(elements, content_digests, node_indexes):
+ # Get plain content versions of all elements
+ elements = [plain_element(element, content_digests, node_indexes) for element in elements]
+ if content_digests:
+ # Add content digest attribute to nodes
+ elements = [add_content_digest(element) for element in elements]
+ return elements
+
+
+def plain_element(element, content_digests, node_indexes):
+ # For lists, we make each item plain text
+ if is_leaf(element):
+ # For leaf node elements, extract the text content, discarding any HTML tags
+ # 1. Get element contents as text
+ plain_text = element.get_text()
+ # 2. Normalize the extracted text string to a canonical representation
+ plain_text = normalize_text(plain_text)
+ # 3. Update element content to be plain text
+ element.string = plain_text
+ elif is_text(element):
+ if is_non_printing(element):
+ # The simplified HTML may have come from Readability.js so might
+ # have non-printing text (e.g. Comment or CData). In this case, we
+ # keep the structure, but ensure that the string is empty.
+ element = type(element)("")
+ else:
+ plain_text = element.string
+ plain_text = normalize_text(plain_text)
+ element = type(element)(plain_text)
+ else:
+ # If not a leaf node or leaf type call recursively on child nodes, replacing
+ element.contents = plain_elements(element.contents, content_digests, node_indexes)
+ return element
+
+
+def add_node_indexes(element, node_index="0"):
+ # Can't add attributes to string types
+ if is_text(element):
+ return element
+ # Add index to current element
+ element["data-node-index"] = node_index
+ # Add index to child elements
+ for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
+ # Can't add attributes to leaf string types
+ child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
+ add_node_indexes(child, node_index=child_index)
+ return element
+
+
+def normalize_text(text):
+ """Normalize unicode and whitespace."""
+ # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
+ text = strip_control_characters(text)
+ text = normalize_unicode(text)
+ text = normalize_whitespace(text)
+ return text
+
+
+def strip_control_characters(text):
+ """Strip out unicode control characters which might break the parsing."""
+ # Unicode control characters
+ # [Cc]: Other, Control [includes new lines]
+ # [Cf]: Other, Format
+ # [Cn]: Other, Not Assigned
+ # [Co]: Other, Private Use
+ # [Cs]: Other, Surrogate
+ control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
+ retained_chars = ["\t", "\n", "\r", "\f"]
+
+ # Remove non-printing control characters
+ return "".join(
+ [
+ "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
+ for char in text
+ ]
+ )
+
+
+def normalize_unicode(text):
+ """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
+ normal_form = "NFKC"
+ text = unicodedata.normalize(normal_form, text)
+ return text
+
+
+def normalize_whitespace(text):
+ """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
+ text = regex.sub(r"\s+", " ", text)
+ # Remove leading and trailing whitespace
+ text = text.strip()
+ return text
+
+
+def is_leaf(element):
+ return element.name in {"p", "li"}
+
+
+def is_text(element):
+ return isinstance(element, NavigableString)
+
+
+def is_non_printing(element):
+ return any(isinstance(element, _e) for _e in [Comment, CData])
+
+
+def add_content_digest(element):
+ if not is_text(element):
+ element["data-content-digest"] = content_digest(element)
+ return element
+
+
+def content_digest(element):
+ if is_text(element):
+ # Hash
+ trimmed_string = element.string.strip()
+ if trimmed_string == "":
+ digest = ""
+ else:
+ digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
+ else:
+ contents = element.contents
+ num_contents = len(contents)
+ if num_contents == 0:
+ # No hash when no child elements exist
+ digest = ""
+ elif num_contents == 1:
+ # If single child, use digest of child
+ digest = content_digest(contents[0])
+ else:
+ # Build content digest from the "non-empty" digests of child nodes
+ digest = hashlib.sha256()
+ child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
+ for child in child_digests:
+ digest.update(child.encode("utf-8"))
+ digest = digest.hexdigest()
+ return digest
+
+
+def get_image_upload_file_ids(content):
+ pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
+ matches = re.findall(pattern, content)
+ image_upload_file_ids = []
+ for match in matches:
+ if match[1] == "file-preview":
+ content_pattern = r"files/([^/]+)/file-preview"
+ else:
+ content_pattern = r"files/([^/]+)/image-preview"
+ content_match = re.search(content_pattern, match[0])
+ if content_match:
+ image_upload_file_id = content_match.group(1)
+ image_upload_file_ids.append(image_upload_file_id)
+ return image_upload_file_ids
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 7354086b03..63858323bd 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -1,5 +1,4 @@
from collections.abc import Generator, Mapping, Sequence
-from os import path
from typing import Any, cast
from sqlalchemy import select
@@ -236,8 +235,7 @@ class ToolNode(BaseNode[ToolNodeData]):
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
- filename=message.save_as,
- extension=path.splitext(message.save_as)[1],
+ extension=None,
mime_type=message.meta.get("mime_type", "application/octet-stream"),
)
)