feat: move audio and webscraper back to dify
This commit is contained in: parent 574a6c1ded, commit dcf19549cb
@@ -309,13 +309,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        )
        # publish files
-        for message_file_id, save_as in message_files:
+        for message_file_id in message_files:
            # publish message file
            self.queue_manager.publish(
-                QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
+                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
            )
            # add message file ids
-            message_file_ids.append(message_file_id.id)
+            message_file_ids.append(message_file_id)

        return tool_invoke_response, tool_invoke_meta
@@ -246,13 +246,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
            conversation_id=self.conversation.id,
        )
        # publish files
-        for message_file_id, save_as in message_files:
+        for message_file_id in message_files:
            # publish message file
            self.queue_manager.publish(
-                QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
+                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
            )
            # add message file ids
-            message_file_ids.append(message_file_id.id)
+            message_file_ids.append(message_file_id)

        tool_response = {
            "tool_call_id": tool_call_id,
@@ -172,7 +172,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
-                "contexts": contextvars.copy_context(),
+                "context": contextvars.copy_context(),
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
@@ -157,7 +157,10 @@ class Tool(ABC):
        return parameters

-    def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage:
+    def create_image_message(
+        self,
+        image: str,
+    ) -> ToolInvokeMessage:
        """
        create an image message

@@ -165,7 +168,7 @@ class Tool(ABC):
        :return: the image message
        """
        return ToolInvokeMessage(
-            type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image), save_as=save_as
+            type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
        )

    def create_file_message(self, file: "File") -> ToolInvokeMessage:
@@ -173,10 +176,9 @@ class Tool(ABC):
            type=ToolInvokeMessage.MessageType.FILE,
            message=ToolInvokeMessage.FileMessage(),
            meta={"file": file},
-            save_as="",
        )

-    def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
+    def create_link_message(self, link: str) -> ToolInvokeMessage:
        """
        create a link message

@@ -184,10 +186,10 @@ class Tool(ABC):
        :return: the link message
        """
        return ToolInvokeMessage(
-            type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link), save_as=save_as
+            type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
        )

-    def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage:
+    def create_text_message(self, text: str) -> ToolInvokeMessage:
        """
        create a text message

@@ -195,10 +197,11 @@ class Tool(ABC):
        :return: the text message
        """
        return ToolInvokeMessage(
-            type=ToolInvokeMessage.MessageType.TEXT, message=ToolInvokeMessage.TextMessage(text=text), save_as=save_as
+            type=ToolInvokeMessage.MessageType.TEXT,
+            message=ToolInvokeMessage.TextMessage(text=text),
        )

-    def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage:
+    def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
        """
        create a blob message

@@ -209,7 +212,6 @@ class Tool(ABC):
            type=ToolInvokeMessage.MessageType.BLOB,
            message=ToolInvokeMessage.BlobMessage(blob=blob),
            meta=meta,
-            save_as=save_as,
        )

    def create_json_message(self, object: dict) -> ToolInvokeMessage:
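Note: with save_as removed from the create_* helpers above, a tool message is now just a type plus a payload. A minimal sketch of the resulting constructor calls (illustration only, not part of the commit; the text values are made up):

from core.tools.entities.tool_entities import ToolInvokeMessage

# Build a text and a link message the way the simplified helpers do, without save_as.
text_msg = ToolInvokeMessage(
    type=ToolInvokeMessage.MessageType.TEXT,
    message=ToolInvokeMessage.TextMessage(text="hello"),
)
link_msg = ToolInvokeMessage(
    type=ToolInvokeMessage.MessageType.LINK,
    message=ToolInvokeMessage.TextMessage(text="https://example.com"),
)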
@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="200" height="200" viewBox="0 0 200 200" fill="none">
<path d="M167.358 102.395C167.358 117.174 157.246 129.18 144.61 131.027H137.861C125.225 129.18 115.113 117.174 115.113 102.395H100.792C100.792 123.637 115.118 142.106 133.653 145.801V164.276H147.139V145.801C165.674 142.106 180 124.558 180 102.4H167.358V102.395ZM154.717 62.677C154.717 53.4397 147.979 46.9765 140.396 46.9765C138.523 46.9446 136.663 47.3273 134.924 48.1024C133.185 48.8775 131.603 50.0294 130.27 51.4909C128.936 52.9524 127.878 54.6943 127.157 56.6148C126.436 58.5354 126.066 60.5962 126.07 62.677V78.3775H154.717V70.4478V62.677ZM126.07 102.395C126.07 111.632 132.813 118.095 140.396 118.095C142.269 118.127 144.13 117.744 145.868 116.969C147.607 116.194 149.189 115.042 150.523 113.581C151.856 112.119 152.914 110.377 153.635 108.457C154.356 106.536 154.726 104.475 154.722 102.395V86.694H126.07V102.395ZM92.1297 45.8938L70.4796 21.7595L69.4235 20.5865L59.604 20L68.3674 20.5865L67.3113 21.7654L64.1429 25.2961L63.6149 25.8826L64.1429 27.0614L66.2552 29.4133L77.8723 42.3631H54.1099C35.1 43.5361 20.3146 61.1896 20.3146 81.7874V83.5527H28.2354V81.7932C28.2354 65.8992 39.8525 52.3628 54.1099 51.1899H77.8723L66.2552 64.1338L64.671 65.8992L64.1429 67.0722L63.6149 67.6645L64.1429 68.251L68.3674 72.9606L68.8954 73.5471L69.4235 72.9606L74.1759 67.6645L92.1297 47.6591L92.6578 47.0727L92.1297 45.8938ZM20 95.8496V118.213H30.033V107.034H50.099V168.821H40.066V180H70.165V168.821H60.132V107.034H80.198V118.213H90.231V95.8496H20Z" fill="#FF0099"/>
</svg>
After: width 200, height 200, size 1.5 KiB
api/core/tools/builtin_tool/providers/audio/audio.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from core.tools.builtin_tool.provider import BuiltinToolProviderController


class AudioToolProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        pass
api/core/tools/builtin_tool/providers/audio/audio.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
identity:
  author: hjlarry
  name: audio
  label:
    en_US: Audio
  description:
    en_US: A tool for tts and asr.
    zh_Hans: 一个用于文本转语音和语音转文本的工具。
  icon: icon.svg
  tags:
    - utilities
api/core/tools/builtin_tool/providers/audio/tools/asr.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import io
from collections.abc import Generator
from typing import Any

from core.file.enums import FileType
from core.file.file_manager import download
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from services.model_provider_service import ModelProviderService


class ASRTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
        file = tool_parameters.get("audio_file")
        if file.type != FileType.AUDIO:  # type: ignore
            yield self.create_text_message("not a valid audio file")
            return
        audio_binary = io.BytesIO(download(file))  # type: ignore
        audio_binary.name = "temp.mp3"
        provider, model = tool_parameters.get("model").split("#")  # type: ignore
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.runtime.tenant_id,
            provider=provider,
            model_type=ModelType.SPEECH2TEXT,
            model=model,
        )
        text = model_instance.invoke_speech2text(
            file=audio_binary,
            user=user_id,
        )
        yield self.create_text_message(text)

    def get_available_models(self) -> list[tuple[str, str]]:
        model_provider_service = ModelProviderService()
        models = model_provider_service.get_models_by_model_type(
            tenant_id=self.runtime.tenant_id, model_type="speech2text"
        )
        items = []
        for provider_model in models:
            provider = provider_model.provider
            for model in provider_model.models:
                items.append((provider, model.model))
        return items

    def get_runtime_parameters(self) -> list[ToolParameter]:
        parameters = []

        options = []
        for provider, model in self.get_available_models():
            option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
            options.append(option)

        parameters.append(
            ToolParameter(
                name="model",
                label=I18nObject(en_US="Model", zh_Hans="Model"),
                human_description=I18nObject(
                    en_US="All available ASR models. You can config model in the Model Provider of Settings.",
                    zh_Hans="所有可用的 ASR 模型。你可以在设置中的模型供应商里配置。",
                ),
                type=ToolParameter.ToolParameterType.SELECT,
                form=ToolParameter.ToolParameterForm.FORM,
                required=True,
                options=options,
            )
        )
        return parameters
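The ASR tool above encodes each model option as "provider#model" and splits that value back apart in _invoke. A small sketch of the round trip (illustration only; the provider and model names are hypothetical):

# Hypothetical value of the "model" select parameter.
value = "openai#whisper-1"
provider, model = value.split("#")
assert (provider, model) == ("openai", "whisper-1")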
api/core/tools/builtin_tool/providers/audio/tools/asr.yaml (new file, 22 lines)
@@ -0,0 +1,22 @@
identity:
  name: asr
  author: hjlarry
  label:
    en_US: Speech To Text
description:
  human:
    en_US: Convert audio file to text.
    zh_Hans: 将音频文件转换为文本。
  llm: Convert audio file to text.
parameters:
  - name: audio_file
    type: file
    required: true
    label:
      en_US: Audio File
      zh_Hans: 音频文件
    human_description:
      en_US: The audio file to be converted.
      zh_Hans: 要转换的音频文件。
    llm_description: The audio file to be converted.
    form: llm
api/core/tools/builtin_tool/providers/audio/tools/tts.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import io
from collections.abc import Generator
from typing import Any

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from services.model_provider_service import ModelProviderService


class TTSTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
        provider, model = tool_parameters.get("model").split("#")  # type: ignore
        voice = tool_parameters.get(f"voice#{provider}#{model}")
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.runtime.tenant_id,
            provider=provider,
            model_type=ModelType.TTS,
            model=model,
        )
        tts = model_instance.invoke_tts(
            content_text=tool_parameters.get("text"),  # type: ignore
            user=user_id,
            tenant_id=self.runtime.tenant_id,
            voice=voice,  # type: ignore
        )
        buffer = io.BytesIO()
        for chunk in tts:
            buffer.write(chunk)

        wav_bytes = buffer.getvalue()
        yield self.create_text_message("Audio generated successfully")
        yield self.create_blob_message(
            blob=wav_bytes,
            meta={"mime_type": "audio/x-wav"},
        )

    def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
        model_provider_service = ModelProviderService()
        models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
        items = []
        for provider_model in models:
            provider = provider_model.provider
            for model in provider_model.models:
                voices = model.model_properties.get(ModelPropertyKey.VOICES, [])
                items.append((provider, model.model, voices))
        return items

    def get_runtime_parameters(self) -> list[ToolParameter]:
        parameters = []

        options = []
        for provider, model, voices in self.get_available_models():
            option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
            options.append(option)
            parameters.append(
                ToolParameter(
                    name=f"voice#{provider}#{model}",
                    label=I18nObject(en_US=f"Voice of {model}({provider})"),
                    type=ToolParameter.ToolParameterType.SELECT,
                    form=ToolParameter.ToolParameterForm.FORM,
                    options=[
                        ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
                        for voice in voices
                    ],
                )
            )

        parameters.insert(
            0,
            ToolParameter(
                name="model",
                label=I18nObject(en_US="Model", zh_Hans="Model"),
                human_description=I18nObject(
                    en_US="All available TTS models. You can config model in the Model Provider of Settings.",
                    zh_Hans="所有可用的 TTS 模型。你可以在设置中的模型供应商里配置。",
                ),
                type=ToolParameter.ToolParameterType.SELECT,
                form=ToolParameter.ToolParameterForm.FORM,
                required=True,
                options=options,
            ),
        )
        return parameters
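The TTS tool registers one voice selector per model, named "voice#{provider}#{model}", and reads the chosen voice back under that same key in _invoke. A sketch of the lookup (illustration only; the parameter values are hypothetical):

# Hypothetical runtime parameters as the tool would receive them.
tool_parameters = {"model": "openai#tts-1", "voice#openai#tts-1": "alloy", "text": "hello"}
provider, model = tool_parameters["model"].split("#")
voice = tool_parameters.get(f"voice#{provider}#{model}")
assert voice == "alloy"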
api/core/tools/builtin_tool/providers/audio/tools/tts.yaml (new file, 22 lines)
@@ -0,0 +1,22 @@
identity:
  name: tts
  author: hjlarry
  label:
    en_US: Text To Speech
description:
  human:
    en_US: Convert text to audio file.
    zh_Hans: 将文本转换为音频文件。
  llm: Convert text to audio file.
parameters:
  - name: text
    type: string
    required: true
    label:
      en_US: Text
      zh_Hans: 文本
    human_description:
      en_US: The text to be converted.
      zh_Hans: 要转换的文本。
    llm_description: The text to be converted.
    form: llm
@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="17" viewBox="0 0 16 17" fill="none">
<path fill-rule="evenodd" clip-rule="evenodd" d="M2.6665 1.16667C1.56193 1.16667 0.666504 2.0621 0.666504 3.16667C0.666504 4.27124 1.56193 5.16667 2.6665 5.16667C2.79161 5.16667 2.91403 5.15519 3.03277 5.13321C2.3808 6.09319 1.99984 7.25211 1.99984 8.5C1.99984 9.7479 2.3808 10.9068 3.03277 11.8668C2.91403 11.8448 2.79161 11.8333 2.6665 11.8333C1.56193 11.8333 0.666504 12.7288 0.666504 13.8333C0.666504 14.9379 1.56193 15.8333 2.6665 15.8333C3.77107 15.8333 4.6665 14.9379 4.6665 13.8333C4.6665 13.7082 4.65502 13.5858 4.63304 13.4671C5.59302 14.119 6.75194 14.5 7.99984 14.5C9.24773 14.5 10.4066 14.119 11.3666 13.4671C11.3447 13.5858 11.3332 13.7082 11.3332 13.8333C11.3332 14.9379 12.2286 15.8333 13.3332 15.8333C14.4377 15.8333 15.3332 14.9379 15.3332 13.8333C15.3332 12.7288 14.4377 11.8333 13.3332 11.8333C13.2081 11.8333 13.0856 11.8448 12.9669 11.8668C13.6189 10.9068 13.9998 9.7479 13.9998 8.5C13.9998 7.25211 13.6189 6.09319 12.9669 5.13321C13.0856 5.15519 13.2081 5.16667 13.3332 5.16667C14.4377 5.16667 15.3332 4.27124 15.3332 3.16667C15.3332 2.0621 14.4377 1.16667 13.3332 1.16667C12.2286 1.16667 11.3332 2.0621 11.3332 3.16667C11.3332 3.29177 11.3447 3.41419 11.3666 3.53293C10.4066 2.88097 9.24773 2.50001 7.99984 2.50001C6.75194 2.50001 5.59302 2.88097 4.63304 3.53293C4.65502 3.41419 4.6665 3.29177 4.6665 3.16667C4.6665 2.0621 3.77107 1.16667 2.6665 1.16667ZM3.38043 7.83334C3.63081 6.08287 4.85262 4.64578 6.48223 4.08565C5.79223 5.22099 5.36488 6.50185 5.23815 7.83334H3.38043ZM6.48228 12.9144C4.85264 12.3543 3.63082 10.9172 3.38043 9.16667H5.23815C5.3649 10.4982 5.79226 11.779 6.48228 12.9144ZM12.6192 9.16667C12.3689 10.9168 11.1475 12.3537 9.5183 12.9141C10.2082 11.7788 10.6355 10.498 10.7622 9.16667H12.6192ZM9.51834 4.08596C11.1475 4.64631 12.3689 6.0832 12.6192 7.83334H10.7622C10.6355 6.50197 10.2082 5.22123 9.51834 4.08596ZM9.4218 7.83334C9.27457 6.52262 8.78381 5.27411 8.00019 4.2145C7.21658 5.27411 6.72582 6.52262 6.57859 7.83334H9.4218ZM6.5786 9.16667C6.72583 10.4774 7.21659 11.7259 8.00019 12.7855C8.7838 11.7259 9.27456 10.4774 9.42179 9.16667H6.5786Z" fill="#DD2590"/>
</svg>
After: width 16, height 17, size 2.2 KiB
@@ -0,0 +1,36 @@
from collections.abc import Generator
from typing import Any

from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError
from core.tools.utils.web_reader_tool import get_url


class WebscraperTool(BuiltinTool):
    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Generator[ToolInvokeMessage, None, None]:
        """
        invoke tools
        """
        try:
            url = tool_parameters.get("url", "")
            user_agent = tool_parameters.get("user_agent", "")
            if not url:
                yield self.create_text_message("Please input url")
                return

            # get webpage
            result = get_url(url, user_agent=user_agent)

            if tool_parameters.get("generate_summary"):
                # summarize and return
                yield self.create_text_message(self.summary(user_id=user_id, content=result))
            else:
                # return full webpage
                yield self.create_text_message(result)
        except Exception as e:
            raise ToolInvokeError(str(e))
@@ -0,0 +1,60 @@
identity:
  name: webscraper
  author: Dify
  label:
    en_US: Web Scraper
    zh_Hans: 网页爬虫
    pt_BR: Web Scraper
description:
  human:
    en_US: A tool for scraping webpages.
    zh_Hans: 一个用于爬取网页的工具。
    pt_BR: A tool for scraping webpages.
  llm: A tool for scraping webpages. Input should be a URL.
parameters:
  - name: url
    type: string
    required: true
    label:
      en_US: URL
      zh_Hans: 网页链接
      pt_BR: URL
    human_description:
      en_US: used for linking to webpages
      zh_Hans: 用于链接到网页
      pt_BR: used for linking to webpages
    llm_description: url for scraping
    form: llm
  - name: user_agent
    type: string
    required: false
    label:
      en_US: User Agent
      zh_Hans: User Agent
      pt_BR: User Agent
    human_description:
      en_US: used for identifying the browser.
      zh_Hans: 用于识别浏览器。
      pt_BR: used for identifying the browser.
    form: form
    default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36
  - name: generate_summary
    type: boolean
    required: false
    label:
      en_US: Whether to generate summary
      zh_Hans: 是否生成摘要
    human_description:
      en_US: If true, the crawler will only return the page summary content.
      zh_Hans: 如果启用，爬虫将仅返回页面摘要内容。
    form: form
    options:
      - value: "true"
        label:
          en_US: "Yes"
          zh_Hans: 是
      - value: "false"
        label:
          en_US: "No"
          zh_Hans: 否
    default: "false"
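Together, the webscraper tool takes a url (filled in by the LLM), an optional user_agent, and a generate_summary flag. A sketch of the tool_parameters dict its _invoke receives (illustration only; values are made up):

tool_parameters = {
    "url": "https://example.com/article",
    "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",
    "generate_summary": False,  # return the full extracted page instead of a summary
}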
@@ -0,0 +1,8 @@
from typing import Any

from core.tools.builtin_tool.provider import BuiltinToolProviderController


class WebscraperProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        pass
@@ -0,0 +1,15 @@
identity:
  author: Dify
  name: webscraper
  label:
    en_US: WebScraper
    zh_Hans: 网页抓取
    pt_BR: WebScraper
  description:
    en_US: Web Scrapper tool kit is used to scrape web
    zh_Hans: 一个用于抓取网页的工具。
    pt_BR: Web Scrapper tool kit is used to scrape web
  icon: icon.svg
  tags:
    - productivity
credentials_for_provider: []
@@ -166,7 +166,6 @@ class ToolInvokeMessage(BaseModel):
    """
    message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None
    meta: dict[str, Any] | None = None
-    save_as: str = ""

    @field_validator("message", mode="before")
    @classmethod
@@ -188,7 +187,6 @@ class ToolInvokeMessage(BaseModel):
class ToolInvokeMessageBinary(BaseModel):
    mimetype: str = Field(..., description="The mimetype of the binary")
    url: str = Field(..., description="The url of the binary")
-    save_as: str = ""
    file_var: Optional[dict[str, Any]] = None

@@ -49,7 +49,7 @@ class ToolEngine:
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
-    ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
+    ) -> tuple[str, list[str], ToolInvokeMeta]:
        """
        Agent invokes the tool with the given arguments.
        """
@@ -279,7 +279,6 @@ class ToolEngine:
                yield ToolInvokeMessageBinary(
                    mimetype=response.meta.get("mime_type", "image/jpeg"),
                    url=cast(ToolInvokeMessage.TextMessage, response.message).text,
-                    save_as=response.save_as,
                )
            elif response.type == ToolInvokeMessage.MessageType.BLOB:
                if not response.meta:
@@ -288,7 +287,6 @@ class ToolEngine:
                yield ToolInvokeMessageBinary(
                    mimetype=response.meta.get("mime_type", "octet/stream"),
                    url=cast(ToolInvokeMessage.TextMessage, response.message).text,
-                    save_as=response.save_as,
                )
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                # check if there is a mime type in meta
@@ -296,7 +294,6 @@ class ToolEngine:
                yield ToolInvokeMessageBinary(
                    mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream",
                    url=cast(ToolInvokeMessage.TextMessage, response.message).text,
-                    save_as=response.save_as,
                )

    @staticmethod
@@ -305,12 +302,12 @@ class ToolEngine:
        agent_message: Message,
        invoke_from: InvokeFrom,
        user_id: str,
-    ) -> list[tuple[MessageFile, str]]:
+    ) -> list[str]:
        """
        Create message file

        :param messages: messages
-        :return: message files, should save as variable
+        :return: message file ids
        """
        result = []

@@ -347,7 +344,7 @@ class ToolEngine:
            db.session.commit()
            db.session.refresh(message_file)

-            result.append((message_file.id, message.save_as))
+            result.append(message_file.id)

        db.session.close()
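With agent_invoke and _create_message_files now returning plain message file ids (list[str]) instead of (MessageFile, save_as) tuples, callers such as the agent runners at the top of this diff iterate the ids directly. A minimal sketch of the new shape (illustration only; the id is made up):

message_file_ids: list[str] = ["9f8d3c1a-1234-4b5c-8def-000000000001"]
for message_file_id in message_file_ids:
    # previously each element was a (MessageFile, save_as) tuple
    print(message_file_id)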
@@ -44,7 +44,6 @@ class ToolFileMessageTransformer:
                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                    message=ToolInvokeMessage.TextMessage(text=url),
-                    save_as=message.save_as,
                    meta=message.meta.copy() if message.meta is not None else {},
                )
            except Exception as e:
@@ -54,7 +53,6 @@ class ToolFileMessageTransformer:
                        text=f"Failed to download image: {message.message.text}: {e}"
                    ),
                    meta=message.meta.copy() if message.meta is not None else {},
-                    save_as=message.save_as,
                )
        elif message.type == ToolInvokeMessage.MessageType.BLOB:
            # get mime type and save blob to storage
@@ -83,14 +81,12 @@ class ToolFileMessageTransformer:
                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                    message=ToolInvokeMessage.TextMessage(text=url),
-                    save_as=message.save_as,
                    meta=message.meta.copy() if message.meta is not None else {},
                )
            else:
                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.LINK,
                    message=ToolInvokeMessage.TextMessage(text=url),
-                    save_as=message.save_as,
                    meta=message.meta.copy() if message.meta is not None else {},
                )
        elif message.type == ToolInvokeMessage.MessageType.FILE:
@@ -104,14 +100,12 @@ class ToolFileMessageTransformer:
                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                    message=ToolInvokeMessage.TextMessage(text=url),
-                    save_as=message.save_as,
                    meta=message.meta.copy() if message.meta is not None else {},
                )
            else:
                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.LINK,
                    message=ToolInvokeMessage.TextMessage(text=url),
-                    save_as=message.save_as,
                    meta=message.meta.copy() if message.meta is not None else {},
                )
        else:
api/core/tools/utils/web_reader_tool.py (new file, 374 lines)
@@ -0,0 +1,374 @@
import hashlib
import json
import mimetypes
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
from urllib.parse import unquote

import chardet
import cloudscraper
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from regex import regex

from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor

FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:

{text}
"""


def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor : cursor + max_length]


def get_url(url: str, user_agent: Optional[str] = None) -> str:
    """Fetch URL and return the contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
        " Chrome/91.0.4472.124 Safari/537.36"
    }
    if user_agent:
        headers["User-Agent"] = user_agent

    main_content_type = None
    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
    response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))

    if response.status_code == 200:
        # check content-type
        content_type = response.headers.get("Content-Type")
        if content_type:
            main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
        else:
            content_disposition = response.headers.get("Content-Disposition", "")
            filename_match = re.search(r'filename="([^"]+)"', content_disposition)
            if filename_match:
                filename = unquote(filename_match.group(1))
                extension = re.search(r"\.(\w+)$", filename)
                if extension:
                    main_content_type = mimetypes.guess_type(filename)[0]

        if main_content_type not in supported_content_types:
            return "Unsupported content-type [{}] of URL.".format(main_content_type)

        if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
            return ExtractProcessor.load_from_url(url, return_text=True)

        response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
    elif response.status_code == 403:
        scraper = cloudscraper.create_scraper()
        scraper.perform_request = ssrf_proxy.make_request
        response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))

    if response.status_code != 200:
        return "URL returned status code {}.".format(response.status_code)

    # Detect encoding using chardet
    detected_encoding = chardet.detect(response.content)
    encoding = detected_encoding["encoding"]
    if encoding:
        try:
            content = response.content.decode(encoding)
        except (UnicodeDecodeError, TypeError):
            content = response.text
    else:
        content = response.text

    a = extract_using_readabilipy(content)

    if not a["plain_text"] or not a["plain_text"].strip():
        return ""

    res = FULL_TEMPLATE.format(
        title=a["title"],
        authors=a["byline"],
        publish_date=a["date"],
        top_image="",
        text=a["plain_text"] or "",
    )

    return res


def extract_using_readabilipy(html):
    with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
        f_html.write(html)
        f_html.close()
    html_path = f_html.name

    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
    article_json_path = html_path + ".json"
    jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
    with chdir(jsdir):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
    input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))

    # Deleting files after processing
    os.unlink(article_json_path)
    os.unlink(html_path)

    article_json = {
        "title": None,
        "byline": None,
        "date": None,
        "content": None,
        "plain_content": None,
        "plain_text": None,
    }
    # Populate article fields from readability fields where present
    if input_json:
        if input_json.get("title"):
            article_json["title"] = input_json["title"]
        if input_json.get("byline"):
            article_json["byline"] = input_json["byline"]
        if input_json.get("date"):
            article_json["date"] = input_json["date"]
        if input_json.get("content"):
            article_json["content"] = input_json["content"]
            article_json["plain_content"] = plain_content(article_json["content"], False, False)
            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
        if input_json.get("textContent"):
            article_json["plain_text"] = input_json["textContent"]
            article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])

    return article_json


def find_module_path(module_name):
    for package_path in site.getsitepackages():
        potential_path = os.path.join(package_path, module_name)
        if os.path.exists(potential_path):
            return potential_path

    return None


@contextmanager
def chdir(path):
    """Change directory in context and return to original on exit"""
    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
    original_path = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_path)


def extract_text_blocks_as_plain_text(paragraph_html):
    # Load article as DOM
    soup = BeautifulSoup(paragraph_html, "html.parser")
    # Select all lists
    list_elements = soup.find_all(["ul", "ol"])
    # Prefix text in all list items with "* " and make lists paragraphs
    for list_element in list_elements:
        plain_items = "".join(
            list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
        )
        list_element.string = plain_items
        list_element.name = "p"
    # Select all text blocks
    text_blocks = [s.parent for s in soup.find_all(string=True)]
    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
    # Drop empty paragraphs
    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
    return text_blocks


def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements and normalize it
    plain_text = normalize_text(element.get_text())
    if plain_text != "" and element.name == "li":
        plain_text = "* {}, ".format(plain_text)
    if plain_text == "":
        plain_text = None
    if "data-node-index" in element.attrs:
        plain = {"node_index": element["data-node-index"], "text": plain_text}
    else:
        plain = {"text": plain_text}
    return plain


def plain_content(readability_content, content_digests, node_indexes):
    # Load article as DOM
    soup = BeautifulSoup(readability_content, "html.parser")
    # Make all elements plain
    elements = plain_elements(soup.contents, content_digests, node_indexes)
    if node_indexes:
        # Add node index attributes to nodes
        elements = [add_node_indexes(element) for element in elements]
    # Replace article contents with plain elements
    soup.contents = elements
    return str(soup)


def plain_elements(elements, content_digests, node_indexes):
    # Get plain content versions of all elements
    elements = [plain_element(element, content_digests, node_indexes) for element in elements]
    if content_digests:
        # Add content digest attribute to nodes
        elements = [add_content_digest(element) for element in elements]
    return elements


def plain_element(element, content_digests, node_indexes):
    # For lists, we make each item plain text
    if is_leaf(element):
        # For leaf node elements, extract the text content, discarding any HTML tags
        # 1. Get element contents as text
        plain_text = element.get_text()
        # 2. Normalize the extracted text string to a canonical representation
        plain_text = normalize_text(plain_text)
        # 3. Update element content to be plain text
        element.string = plain_text
    elif is_text(element):
        if is_non_printing(element):
            # The simplified HTML may have come from Readability.js so might
            # have non-printing text (e.g. Comment or CData). In this case, we
            # keep the structure, but ensure that the string is empty.
            element = type(element)("")
        else:
            plain_text = element.string
            plain_text = normalize_text(plain_text)
            element = type(element)(plain_text)
    else:
        # If not a leaf node or leaf type call recursively on child nodes, replacing
        element.contents = plain_elements(element.contents, content_digests, node_indexes)
    return element


def add_node_indexes(element, node_index="0"):
    # Can't add attributes to string types
    if is_text(element):
        return element
    # Add index to current element
    element["data-node-index"] = node_index
    # Add index to child elements
    for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
        # Can't add attributes to leaf string types
        child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
        add_node_indexes(child, node_index=child_index)
    return element


def normalize_text(text):
    """Normalize unicode and whitespace."""
    # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
    text = strip_control_characters(text)
    text = normalize_unicode(text)
    text = normalize_whitespace(text)
    return text


def strip_control_characters(text):
    """Strip out unicode control characters which might break the parsing."""
    # Unicode control characters
    # [Cc]: Other, Control [includes new lines]
    # [Cf]: Other, Format
    # [Cn]: Other, Not Assigned
    # [Co]: Other, Private Use
    # [Cs]: Other, Surrogate
    control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
    retained_chars = ["\t", "\n", "\r", "\f"]

    # Remove non-printing control characters
    return "".join(
        [
            "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
            for char in text
        ]
    )


def normalize_unicode(text):
    """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
    normal_form = "NFKC"
    text = unicodedata.normalize(normal_form, text)
    return text


def normalize_whitespace(text):
    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
    text = regex.sub(r"\s+", " ", text)
    # Remove leading and trailing whitespace
    text = text.strip()
    return text


def is_leaf(element):
    return element.name in {"p", "li"}


def is_text(element):
    return isinstance(element, NavigableString)


def is_non_printing(element):
    return any(isinstance(element, _e) for _e in [Comment, CData])


def add_content_digest(element):
    if not is_text(element):
        element["data-content-digest"] = content_digest(element)
    return element


def content_digest(element):
    if is_text(element):
        # Hash
        trimmed_string = element.string.strip()
        if trimmed_string == "":
            digest = ""
        else:
            digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
    else:
        contents = element.contents
        num_contents = len(contents)
        if num_contents == 0:
            # No hash when no child elements exist
            digest = ""
        elif num_contents == 1:
            # If single child, use digest of child
            digest = content_digest(contents[0])
        else:
            # Build content digest from the "non-empty" digests of child nodes
            digest = hashlib.sha256()
            child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
            for child in child_digests:
                digest.update(child.encode("utf-8"))
            digest = digest.hexdigest()
    return digest


def get_image_upload_file_ids(content):
    pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
    matches = re.findall(pattern, content)
    image_upload_file_ids = []
    for match in matches:
        if match[1] == "file-preview":
            content_pattern = r"files/([^/]+)/file-preview"
        else:
            content_pattern = r"files/([^/]+)/image-preview"
        content_match = re.search(content_pattern, match[0])
        if content_match:
            image_upload_file_id = content_match.group(1)
            image_upload_file_ids.append(image_upload_file_id)
    return image_upload_file_ids
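For reference, a short sketch of how the restored helpers are used; this mirrors the WebscraperTool call above (the URL and user agent are placeholders):

from core.tools.utils.web_reader_tool import get_url, page_result

text = get_url("https://example.com", user_agent="Mozilla/5.0")
print(page_result(text, cursor=0, max_length=1000))  # first 1000 characters of the extracted page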
@@ -1,5 +1,4 @@
from collections.abc import Generator, Mapping, Sequence
-from os import path
from typing import Any, cast

from sqlalchemy import select
@@ -236,8 +235,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                        type=FileType.IMAGE,
                        transfer_method=FileTransferMethod.TOOL_FILE,
                        related_id=tool_file_id,
-                        filename=message.save_as,
-                        extension=path.splitext(message.save_as)[1],
+                        extension=None,
                        mime_type=message.meta.get("mime_type", "application/octet-stream"),
                    )
                )