feat: move audio and webscraper back to dify

This commit is contained in:
Yeuoly 2024-12-03 19:27:57 +08:00
parent 574a6c1ded
commit dcf19549cb
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
21 changed files with 741 additions and 34 deletions

View File

@ -309,13 +309,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
)
# publish files
for message_file_id, save_as in message_files:
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id.id)
message_file_ids.append(message_file_id)
return tool_invoke_response, tool_invoke_meta

View File

@ -246,13 +246,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
conversation_id=self.conversation.id,
)
# publish files
for message_file_id, save_as in message_files:
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id.id)
message_file_ids.append(message_file_id)
tool_response = {
"tool_call_id": tool_call_id,

View File

@ -172,7 +172,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"contexts": contextvars.copy_context(),
"context": contextvars.copy_context(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,

View File

@ -157,7 +157,10 @@ class Tool(ABC):
return parameters
def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage:
def create_image_message(
self,
image: str,
) -> ToolInvokeMessage:
"""
create an image message
@ -165,7 +168,7 @@ class Tool(ABC):
:return: the image message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image), save_as=save_as
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
@ -173,10 +176,9 @@ class Tool(ABC):
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(),
meta={"file": file},
save_as="",
)
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
def create_link_message(self, link: str) -> ToolInvokeMessage:
"""
create a link message
@ -184,10 +186,10 @@ class Tool(ABC):
:return: the link message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link), save_as=save_as
type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
)
def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage:
def create_text_message(self, text: str) -> ToolInvokeMessage:
"""
create a text message
@ -195,10 +197,11 @@ class Tool(ABC):
:return: the text message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT, message=ToolInvokeMessage.TextMessage(text=text), save_as=save_as
type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(text=text),
)
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage:
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
"""
create a blob message
@ -209,7 +212,6 @@ class Tool(ABC):
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=blob),
meta=meta,
save_as=save_as,
)
def create_json_message(self, object: dict) -> ToolInvokeMessage:

View File

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="200" height="200" viewBox="0 0 200 200" fill="none">
<path d="M167.358 102.395C167.358 117.174 157.246 129.18 144.61 131.027H137.861C125.225 129.18 115.113 117.174 115.113 102.395H100.792C100.792 123.637 115.118 142.106 133.653 145.801V164.276H147.139V145.801C165.674 142.106 180 124.558 180 102.4H167.358V102.395ZM154.717 62.677C154.717 53.4397 147.979 46.9765 140.396 46.9765C138.523 46.9446 136.663 47.3273 134.924 48.1024C133.185 48.8775 131.603 50.0294 130.27 51.4909C128.936 52.9524 127.878 54.6943 127.157 56.6148C126.436 58.5354 126.066 60.5962 126.07 62.677V78.3775H154.717V70.4478V62.677ZM126.07 102.395C126.07 111.632 132.813 118.095 140.396 118.095C142.269 118.127 144.13 117.744 145.868 116.969C147.607 116.194 149.189 115.042 150.523 113.581C151.856 112.119 152.914 110.377 153.635 108.457C154.356 106.536 154.726 104.475 154.722 102.395V86.694H126.07V102.395ZM92.1297 45.8938L70.4796 21.7595L69.4235 20.5865L59.604 20L68.3674 20.5865L67.3113 21.7654L64.1429 25.2961L63.6149 25.8826L64.1429 27.0614L66.2552 29.4133L77.8723 42.3631H54.1099C35.1 43.5361 20.3146 61.1896 20.3146 81.7874V83.5527H28.2354V81.7932C28.2354 65.8992 39.8525 52.3628 54.1099 51.1899H77.8723L66.2552 64.1338L64.671 65.8992L64.1429 67.0722L63.6149 67.6645L64.1429 68.251L68.3674 72.9606L68.8954 73.5471L69.4235 72.9606L74.1759 67.6645L92.1297 47.6591L92.6578 47.0727L92.1297 45.8938ZM20 95.8496V118.213H30.033V107.034H50.099V168.821H40.066V180H70.165V168.821H60.132V107.034H80.198V118.213H90.231V95.8496H20Z" fill="#FF0099"/>
</svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

View File

@ -0,0 +1,6 @@
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class AudioToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
pass

View File

@ -0,0 +1,11 @@
identity:
author: hjlarry
name: audio
label:
en_US: Audio
description:
en_US: A tool for tts and asr.
zh_Hans: 一个用于文本转语音和语音转文本的工具。
icon: icon.svg
tags:
- utilities

View File

@ -0,0 +1,71 @@
import io
from collections.abc import Generator
from typing import Any
from core.file.enums import FileType
from core.file.file_manager import download
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from services.model_provider_service import ModelProviderService
class ASRTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
file = tool_parameters.get("audio_file")
if file.type != FileType.AUDIO: # type: ignore
yield self.create_text_message("not a valid audio file")
return
audio_binary = io.BytesIO(download(file)) # type: ignore
audio_binary.name = "temp.mp3"
provider, model = tool_parameters.get("model").split("#") # type: ignore
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
provider=provider,
model_type=ModelType.SPEECH2TEXT,
model=model,
)
text = model_instance.invoke_speech2text(
file=audio_binary,
user=user_id,
)
yield self.create_text_message(text)
def get_available_models(self) -> list[tuple[str, str]]:
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(
tenant_id=self.runtime.tenant_id, model_type="speech2text"
)
items = []
for provider_model in models:
provider = provider_model.provider
for model in provider_model.models:
items.append((provider, model.model))
return items
def get_runtime_parameters(self) -> list[ToolParameter]:
parameters = []
options = []
for provider, model in self.get_available_models():
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
options.append(option)
parameters.append(
ToolParameter(
name="model",
label=I18nObject(en_US="Model", zh_Hans="Model"),
human_description=I18nObject(
en_US="All available ASR models. You can config model in the Model Provider of Settings.",
zh_Hans="所有可用的 ASR 模型。你可以在设置中的模型供应商里配置。",
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
options=options,
)
)
return parameters

View File

@ -0,0 +1,22 @@
identity:
name: asr
author: hjlarry
label:
en_US: Speech To Text
description:
human:
en_US: Convert audio file to text.
zh_Hans: 将音频文件转换为文本。
llm: Convert audio file to text.
parameters:
- name: audio_file
type: file
required: true
label:
en_US: Audio File
zh_Hans: 音频文件
human_description:
en_US: The audio file to be converted.
zh_Hans: 要转换的音频文件。
llm_description: The audio file to be converted.
form: llm

View File

@ -0,0 +1,87 @@
import io
from collections.abc import Generator
from typing import Any
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from services.model_provider_service import ModelProviderService
class TTSTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
provider, model = tool_parameters.get("model").split("#") # type: ignore
voice = tool_parameters.get(f"voice#{provider}#{model}")
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
provider=provider,
model_type=ModelType.TTS,
model=model,
)
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,
tenant_id=self.runtime.tenant_id,
voice=voice, # type: ignore
)
buffer = io.BytesIO()
for chunk in tts:
buffer.write(chunk)
wav_bytes = buffer.getvalue()
yield self.create_text_message("Audio generated successfully")
yield self.create_blob_message(
blob=wav_bytes,
meta={"mime_type": "audio/x-wav"},
)
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
items = []
for provider_model in models:
provider = provider_model.provider
for model in provider_model.models:
voices = model.model_properties.get(ModelPropertyKey.VOICES, [])
items.append((provider, model.model, voices))
return items
def get_runtime_parameters(self) -> list[ToolParameter]:
parameters = []
options = []
for provider, model, voices in self.get_available_models():
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
options.append(option)
parameters.append(
ToolParameter(
name=f"voice#{provider}#{model}",
label=I18nObject(en_US=f"Voice of {model}({provider})"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
options=[
ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
for voice in voices
],
)
)
parameters.insert(
0,
ToolParameter(
name="model",
label=I18nObject(en_US="Model", zh_Hans="Model"),
human_description=I18nObject(
en_US="All available TTS models. You can config model in the Model Provider of Settings.",
zh_Hans="所有可用的 TTS 模型。你可以在设置中的模型供应商里配置。",
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
options=options,
),
)
return parameters

View File

@ -0,0 +1,22 @@
identity:
name: tts
author: hjlarry
label:
en_US: Text To Speech
description:
human:
en_US: Convert text to audio file.
zh_Hans: 将文本转换为音频文件。
llm: Convert text to audio file.
parameters:
- name: text
type: string
required: true
label:
en_US: Text
zh_Hans: 文本
human_description:
en_US: The text to be converted.
zh_Hans: 要转换的文本。
llm_description: The text to be converted.
form: llm

View File

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="17" viewBox="0 0 16 17" fill="none">
<path fill-rule="evenodd" clip-rule="evenodd" d="M2.6665 1.16667C1.56193 1.16667 0.666504 2.0621 0.666504 3.16667C0.666504 4.27124 1.56193 5.16667 2.6665 5.16667C2.79161 5.16667 2.91403 5.15519 3.03277 5.13321C2.3808 6.09319 1.99984 7.25211 1.99984 8.5C1.99984 9.7479 2.3808 10.9068 3.03277 11.8668C2.91403 11.8448 2.79161 11.8333 2.6665 11.8333C1.56193 11.8333 0.666504 12.7288 0.666504 13.8333C0.666504 14.9379 1.56193 15.8333 2.6665 15.8333C3.77107 15.8333 4.6665 14.9379 4.6665 13.8333C4.6665 13.7082 4.65502 13.5858 4.63304 13.4671C5.59302 14.119 6.75194 14.5 7.99984 14.5C9.24773 14.5 10.4066 14.119 11.3666 13.4671C11.3447 13.5858 11.3332 13.7082 11.3332 13.8333C11.3332 14.9379 12.2286 15.8333 13.3332 15.8333C14.4377 15.8333 15.3332 14.9379 15.3332 13.8333C15.3332 12.7288 14.4377 11.8333 13.3332 11.8333C13.2081 11.8333 13.0856 11.8448 12.9669 11.8668C13.6189 10.9068 13.9998 9.7479 13.9998 8.5C13.9998 7.25211 13.6189 6.09319 12.9669 5.13321C13.0856 5.15519 13.2081 5.16667 13.3332 5.16667C14.4377 5.16667 15.3332 4.27124 15.3332 3.16667C15.3332 2.0621 14.4377 1.16667 13.3332 1.16667C12.2286 1.16667 11.3332 2.0621 11.3332 3.16667C11.3332 3.29177 11.3447 3.41419 11.3666 3.53293C10.4066 2.88097 9.24773 2.50001 7.99984 2.50001C6.75194 2.50001 5.59302 2.88097 4.63304 3.53293C4.65502 3.41419 4.6665 3.29177 4.6665 3.16667C4.6665 2.0621 3.77107 1.16667 2.6665 1.16667ZM3.38043 7.83334C3.63081 6.08287 4.85262 4.64578 6.48223 4.08565C5.79223 5.22099 5.36488 6.50185 5.23815 7.83334H3.38043ZM6.48228 12.9144C4.85264 12.3543 3.63082 10.9172 3.38043 9.16667H5.23815C5.3649 10.4982 5.79226 11.779 6.48228 12.9144ZM12.6192 9.16667C12.3689 10.9168 11.1475 12.3537 9.5183 12.9141C10.2082 11.7788 10.6355 10.498 10.7622 9.16667H12.6192ZM9.51834 4.08596C11.1475 4.64631 12.3689 6.0832 12.6192 7.83334H10.7622C10.6355 6.50197 10.2082 5.22123 9.51834 4.08596ZM9.4218 7.83334C9.27457 6.52262 8.78381 5.27411 8.00019 4.2145C7.21658 5.27411 6.72582 6.52262 6.57859 7.83334H9.4218ZM6.5786 9.16667C6.72583 10.4774 7.21659 11.7259 8.00019 12.7855C8.7838 11.7259 9.27456 10.4774 9.42179 9.16667H6.5786Z" fill="#DD2590"/>
</svg>

After

Width:  |  Height:  |  Size: 2.2 KiB

View File

@ -0,0 +1,36 @@
from collections.abc import Generator
from typing import Any
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError
from core.tools.utils.web_reader_tool import get_url
class WebscraperTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tools
"""
try:
url = tool_parameters.get("url", "")
user_agent = tool_parameters.get("user_agent", "")
if not url:
yield self.create_text_message("Please input url")
return
# get webpage
result = get_url(url, user_agent=user_agent)
if tool_parameters.get("generate_summary"):
# summarize and return
yield self.create_text_message(self.summary(user_id=user_id, content=result))
else:
# return full webpage
yield self.create_text_message(result)
except Exception as e:
raise ToolInvokeError(str(e))

View File

@ -0,0 +1,60 @@
identity:
name: webscraper
author: Dify
label:
en_US: Web Scraper
zh_Hans: 网页爬虫
pt_BR: Web Scraper
description:
human:
en_US: A tool for scraping webpages.
zh_Hans: 一个用于爬取网页的工具。
pt_BR: A tool for scraping webpages.
llm: A tool for scraping webpages. Input should be a URL.
parameters:
- name: url
type: string
required: true
label:
en_US: URL
zh_Hans: 网页链接
pt_BR: URL
human_description:
en_US: used for linking to webpages
zh_Hans: 用于链接到网页
pt_BR: used for linking to webpages
llm_description: url for scraping
form: llm
- name: user_agent
type: string
required: false
label:
en_US: User Agent
zh_Hans: User Agent
pt_BR: User Agent
human_description:
en_US: used for identifying the browser.
zh_Hans: 用于识别浏览器。
pt_BR: used for identifying the browser.
form: form
default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36
- name: generate_summary
type: boolean
required: false
label:
en_US: Whether to generate summary
zh_Hans: 是否生成摘要
human_description:
en_US: If true, the crawler will only return the page summary content.
zh_Hans: 如果启用,爬虫将仅返回页面摘要内容。
form: form
options:
- value: "true"
label:
en_US: "Yes"
zh_Hans:
- value: "false"
label:
en_US: "No"
zh_Hans:
default: "false"

View File

@ -0,0 +1,8 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
pass

View File

@ -0,0 +1,15 @@
identity:
author: Dify
name: webscraper
label:
en_US: WebScraper
zh_Hans: 网页抓取
pt_BR: WebScraper
description:
en_US: Web Scrapper tool kit is used to scrape web
zh_Hans: 一个用于抓取网页的工具。
pt_BR: Web Scrapper tool kit is used to scrape web
icon: icon.svg
tags:
- productivity
credentials_for_provider: []

View File

@ -166,7 +166,6 @@ class ToolInvokeMessage(BaseModel):
"""
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None
meta: dict[str, Any] | None = None
save_as: str = ""
@field_validator("message", mode="before")
@classmethod
@ -188,7 +187,6 @@ class ToolInvokeMessage(BaseModel):
class ToolInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")
save_as: str = ""
file_var: Optional[dict[str, Any]] = None

View File

@ -49,7 +49,7 @@ class ToolEngine:
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
) -> tuple[str, list[str], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
@ -279,7 +279,6 @@ class ToolEngine:
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "image/jpeg"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
save_as=response.save_as,
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
if not response.meta:
@ -288,7 +287,6 @@ class ToolEngine:
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "octet/stream"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
save_as=response.save_as,
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
@ -296,7 +294,6 @@ class ToolEngine:
yield ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream",
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
save_as=response.save_as,
)
@staticmethod
@ -305,12 +302,12 @@ class ToolEngine:
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str,
) -> list[tuple[MessageFile, str]]:
) -> list[str]:
"""
Create message file
:param messages: messages
:return: message files, should save as variable
:return: message file ids
"""
result = []
@ -347,7 +344,7 @@ class ToolEngine:
db.session.commit()
db.session.refresh(message_file)
result.append((message_file.id, message.save_as))
result.append(message_file.id)
db.session.close()

View File

@ -44,7 +44,6 @@ class ToolFileMessageTransformer:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
@ -54,7 +53,6 @@ class ToolFileMessageTransformer:
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
save_as=message.save_as,
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
@ -83,14 +81,12 @@ class ToolFileMessageTransformer:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
@ -104,14 +100,12 @@ class ToolFileMessageTransformer:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
else:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
else:

View File

@ -0,0 +1,374 @@
import hashlib
import json
import mimetypes
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
from urllib.parse import unquote
import chardet
import cloudscraper
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from regex import regex
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:
{text}
"""
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor : cursor + max_length]
def get_url(url: str, user_agent: Optional[str] = None) -> str:
"""Fetch URL and return the contents as a string."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
if user_agent:
headers["User-Agent"] = user_agent
main_content_type = None
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
if response.status_code == 200:
# check content-type
content_type = response.headers.get("Content-Type")
if content_type:
main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
else:
content_disposition = response.headers.get("Content-Disposition", "")
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
extension = re.search(r"\.(\w+)$", filename)
if extension:
main_content_type = mimetypes.guess_type(filename)[0]
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return ExtractProcessor.load_from_url(url, return_text=True)
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
if response.status_code != 200:
return "URL returned status code {}.".format(response.status_code)
# Detect encoding using chardet
detected_encoding = chardet.detect(response.content)
encoding = detected_encoding["encoding"]
if encoding:
try:
content = response.content.decode(encoding)
except (UnicodeDecodeError, TypeError):
content = response.text
else:
content = response.text
a = extract_using_readabilipy(content)
if not a["plain_text"] or not a["plain_text"].strip():
return ""
res = FULL_TEMPLATE.format(
title=a["title"],
authors=a["byline"],
publish_date=a["date"],
top_image="",
text=a["plain_text"] or "",
)
return res
def extract_using_readabilipy(html):
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
f_html.write(html)
f_html.close()
html_path = f_html.name
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
article_json_path = html_path + ".json"
jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
with chdir(jsdir):
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
# Deleting files after processing
os.unlink(article_json_path)
os.unlink(html_path)
article_json = {
"title": None,
"byline": None,
"date": None,
"content": None,
"plain_content": None,
"plain_text": None,
}
# Populate article fields from readability fields where present
if input_json:
if input_json.get("title"):
article_json["title"] = input_json["title"]
if input_json.get("byline"):
article_json["byline"] = input_json["byline"]
if input_json.get("date"):
article_json["date"] = input_json["date"]
if input_json.get("content"):
article_json["content"] = input_json["content"]
article_json["plain_content"] = plain_content(article_json["content"], False, False)
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
if input_json.get("textContent"):
article_json["plain_text"] = input_json["textContent"]
article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
return article_json
def find_module_path(module_name):
for package_path in site.getsitepackages():
potential_path = os.path.join(package_path, module_name)
if os.path.exists(potential_path):
return potential_path
return None
@contextmanager
def chdir(path):
"""Change directory in context and return to original on exit"""
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
original_path = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(original_path)
def extract_text_blocks_as_plain_text(paragraph_html):
# Load article as DOM
soup = BeautifulSoup(paragraph_html, "html.parser")
# Select all lists
list_elements = soup.find_all(["ul", "ol"])
# Prefix text in all list items with "* " and make lists paragraphs
for list_element in list_elements:
plain_items = "".join(
list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
)
list_element.string = plain_items
list_element.name = "p"
# Select all text blocks
text_blocks = [s.parent for s in soup.find_all(string=True)]
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
# Drop empty paragraphs
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
return text_blocks
def plain_text_leaf_node(element):
# Extract all text, stripped of any child HTML elements and normalize it
plain_text = normalize_text(element.get_text())
if plain_text != "" and element.name == "li":
plain_text = "* {}, ".format(plain_text)
if plain_text == "":
plain_text = None
if "data-node-index" in element.attrs:
plain = {"node_index": element["data-node-index"], "text": plain_text}
else:
plain = {"text": plain_text}
return plain
def plain_content(readability_content, content_digests, node_indexes):
# Load article as DOM
soup = BeautifulSoup(readability_content, "html.parser")
# Make all elements plain
elements = plain_elements(soup.contents, content_digests, node_indexes)
if node_indexes:
# Add node index attributes to nodes
elements = [add_node_indexes(element) for element in elements]
# Replace article contents with plain elements
soup.contents = elements
return str(soup)
def plain_elements(elements, content_digests, node_indexes):
# Get plain content versions of all elements
elements = [plain_element(element, content_digests, node_indexes) for element in elements]
if content_digests:
# Add content digest attribute to nodes
elements = [add_content_digest(element) for element in elements]
return elements
def plain_element(element, content_digests, node_indexes):
# For lists, we make each item plain text
if is_leaf(element):
# For leaf node elements, extract the text content, discarding any HTML tags
# 1. Get element contents as text
plain_text = element.get_text()
# 2. Normalize the extracted text string to a canonical representation
plain_text = normalize_text(plain_text)
# 3. Update element content to be plain text
element.string = plain_text
elif is_text(element):
if is_non_printing(element):
# The simplified HTML may have come from Readability.js so might
# have non-printing text (e.g. Comment or CData). In this case, we
# keep the structure, but ensure that the string is empty.
element = type(element)("")
else:
plain_text = element.string
plain_text = normalize_text(plain_text)
element = type(element)(plain_text)
else:
# If not a leaf node or leaf type call recursively on child nodes, replacing
element.contents = plain_elements(element.contents, content_digests, node_indexes)
return element
def add_node_indexes(element, node_index="0"):
# Can't add attributes to string types
if is_text(element):
return element
# Add index to current element
element["data-node-index"] = node_index
# Add index to child elements
for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
# Can't add attributes to leaf string types
child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
add_node_indexes(child, node_index=child_index)
return element
def normalize_text(text):
"""Normalize unicode and whitespace."""
# Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
text = strip_control_characters(text)
text = normalize_unicode(text)
text = normalize_whitespace(text)
return text
def strip_control_characters(text):
"""Strip out unicode control characters which might break the parsing."""
# Unicode control characters
# [Cc]: Other, Control [includes new lines]
# [Cf]: Other, Format
# [Cn]: Other, Not Assigned
# [Co]: Other, Private Use
# [Cs]: Other, Surrogate
control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
retained_chars = ["\t", "\n", "\r", "\f"]
# Remove non-printing control characters
return "".join(
[
"" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
for char in text
]
)
def normalize_unicode(text):
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
def normalize_whitespace(text):
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
text = regex.sub(r"\s+", " ", text)
# Remove leading and trailing whitespace
text = text.strip()
return text
def is_leaf(element):
return element.name in {"p", "li"}
def is_text(element):
return isinstance(element, NavigableString)
def is_non_printing(element):
return any(isinstance(element, _e) for _e in [Comment, CData])
def add_content_digest(element):
if not is_text(element):
element["data-content-digest"] = content_digest(element)
return element
def content_digest(element):
if is_text(element):
# Hash
trimmed_string = element.string.strip()
if trimmed_string == "":
digest = ""
else:
digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
else:
contents = element.contents
num_contents = len(contents)
if num_contents == 0:
# No hash when no child elements exist
digest = ""
elif num_contents == 1:
# If single child, use digest of child
digest = content_digest(contents[0])
else:
# Build content digest from the "non-empty" digests of child nodes
digest = hashlib.sha256()
child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
for child in child_digests:
digest.update(child.encode("utf-8"))
digest = digest.hexdigest()
return digest
def get_image_upload_file_ids(content):
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
matches = re.findall(pattern, content)
image_upload_file_ids = []
for match in matches:
if match[1] == "file-preview":
content_pattern = r"files/([^/]+)/file-preview"
else:
content_pattern = r"files/([^/]+)/image-preview"
content_match = re.search(content_pattern, match[0])
if content_match:
image_upload_file_id = content_match.group(1)
image_upload_file_ids.append(image_upload_file_id)
return image_upload_file_ids

View File

@ -1,5 +1,4 @@
from collections.abc import Generator, Mapping, Sequence
from os import path
from typing import Any, cast
from sqlalchemy import select
@ -236,8 +235,7 @@ class ToolNode(BaseNode[ToolNodeData]):
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=message.save_as,
extension=path.splitext(message.save_as)[1],
extension=None,
mime_type=message.meta.get("mime_type", "application/octet-stream"),
)
)