Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-15 01:25:57 +08:00)
refactor: replace BuiltinToolManageService with RagPipelineManageService for datasource management and remove unused datasource engine and related code
This commit is contained in:
parent 8bea88c8cc
commit c5a2f43ceb
@@ -101,7 +101,9 @@ class CustomizedPipelineTemplateApi(Resource):
@enterprise_license_required
def post(self, template_id: str):
with Session(db.engine) as session:
template = session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
template = (
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()

@@ -43,7 +43,7 @@ from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError

logger = logging.getLogger(__name__)
@@ -711,14 +711,7 @@ class DatasourceListApi(Resource):

tenant_id = user.current_tenant_id

return jsonable_encoder(
[
provider.to_dict()
for provider in BuiltinToolManageService.list_rag_pipeline_datasources(
tenant_id,
)
]
)
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))


api.add_resource(

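The two hunks above amount to a single behavioral change in the datasource listing endpoint: the controller stops assembling the payload by hand and delegates to the new service. A minimal before/after sketch of the handler body (tenant_id taken from the current user, as in the surrounding code):

    # Before: build the payload by hand from the tool-management service.
    providers = BuiltinToolManageService.list_rag_pipeline_datasources(tenant_id)
    return jsonable_encoder([provider.to_dict() for provider in providers])

    # After: the RAG-pipeline service returns an already-encodable list.
    return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
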
@@ -1,12 +1,9 @@
from collections.abc import Generator
from typing import Any, Optional
from collections.abc import Mapping
from typing import Any

from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceInvokeMessage,
DatasourceParameter,
DatasourceProviderType,
)
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@@ -16,7 +13,6 @@ class DatasourcePlugin:
tenant_id: str
icon: str
plugin_unique_identifier: str
runtime_parameters: Optional[list[DatasourceParameter]]
entity: DatasourceEntity
runtime: DatasourceRuntime

@@ -33,49 +29,41 @@ class DatasourcePlugin:
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
self.runtime_parameters = None

def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.RAG_PIPELINE

def _invoke_first_step(
self,
user_id: str,
datasource_parameters: dict[str, Any],
rag_pipeline_id: Optional[str] = None,
) -> Generator[DatasourceInvokeMessage, None, None]:
) -> Mapping[str, Any]:
manager = PluginDatasourceManager()

datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

yield from manager.invoke_first_step(
return manager.invoke_first_step(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
rag_pipeline_id=rag_pipeline_id,
)

def _invoke_second_step(
self,
user_id: str,
datasource_parameters: dict[str, Any],
rag_pipeline_id: Optional[str] = None,
) -> Generator[DatasourceInvokeMessage, None, None]:
) -> Mapping[str, Any]:
manager = PluginDatasourceManager()

datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

yield from manager.invoke(
return manager.invoke_second_step(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
rag_pipeline_id=rag_pipeline_id,
)

def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
@@ -86,28 +74,3 @@ class DatasourcePlugin:
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

def get_runtime_parameters(
self,
rag_pipeline_id: Optional[str] = None,
) -> list[DatasourceParameter]:
"""
get the runtime parameters
"""
if not self.entity.has_runtime_parameters:
return self.entity.parameters

if self.runtime_parameters is not None:
return self.runtime_parameters

manager = PluginDatasourceManager()
self.runtime_parameters = manager.get_runtime_parameters(
tenant_id=self.tenant_id,
user_id="",
provider=self.entity.identity.provider,
datasource=self.entity.identity.name,
credentials=self.runtime.credentials,
rag_pipeline_id=rag_pipeline_id,
)

return self.runtime_parameters

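The signature changes in this file alter the datasource calling convention: _invoke_first_step and _invoke_second_step no longer yield a stream of DatasourceInvokeMessage items but return a single Mapping[str, Any]. A sketch of what that means for a caller, assuming a DatasourcePlugin instance named plugin and a params dict (both names hypothetical):

    # Before: consume a generator of DatasourceInvokeMessage items.
    # for message in plugin._invoke_first_step(user_id=user_id, datasource_parameters=params):
    #     ...handle each message as it streams in...

    # After: one mapping comes back and can be used directly.
    result = plugin._invoke_first_step(user_id=user_id, datasource_parameters=params)
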
@@ -2,7 +2,7 @@ from typing import Any

from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
from core.tools.errors import ToolProviderCredentialValidationError
@@ -22,15 +22,6 @@ class DatasourcePluginProviderController:
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier

@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

:return: type of the provider
"""
return DatasourceProviderType.RAG_PIPELINE

@property
def need_credentials(self) -> bool:
"""

@@ -1,224 +0,0 @@
import json
from collections.abc import Generator, Iterable
from mimetypes import guess_type
from typing import Any, Optional, cast

from yarl import URL

from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.entities.datasource_entities import (
DatasourceInvokeMessage,
DatasourceInvokeMessageBinary,
)
from core.file import FileType
from core.file.models import FileTransferMethod
from extensions.ext_database import db
from models.enums import CreatedByRole
from models.model import Message, MessageFile


class DatasourceEngine:
"""
Datasource runtime engine take care of the datasource executions.
"""

@staticmethod
def invoke_first_step(
datasource: DatasourcePlugin,
datasource_parameters: dict[str, Any],
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[DatasourceInvokeMessage, None, None]:
"""
Workflow invokes the datasource with the given arguments.
"""
try:
# hit the callback handler
workflow_tool_callback.on_datasource_start(
datasource_name=datasource.entity.identity.name, datasource_inputs=datasource_parameters
)

if datasource.runtime and datasource.runtime.runtime_parameters:
datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters}

response = datasource._invoke_first_step(
user_id=user_id,
datasource_parameters=datasource_parameters,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)

# hit the callback handler
response = workflow_tool_callback.on_datasource_end(
datasource_name=datasource.entity.identity.name,
datasource_inputs=datasource_parameters,
datasource_outputs=response,
)

return response
except Exception as e:
workflow_tool_callback.on_tool_error(e)
raise e

@staticmethod
def invoke_second_step(
datasource: DatasourcePlugin,
datasource_parameters: dict[str, Any],
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
) -> Generator[DatasourceInvokeMessage, None, None]:
"""
Workflow invokes the datasource with the given arguments.
"""
try:
response = datasource._invoke_second_step(
user_id=user_id,
datasource_parameters=datasource_parameters,
)

return response
except Exception as e:
workflow_tool_callback.on_tool_error(e)
raise e

@staticmethod
def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str:
"""
Handle datasource response
"""
result = ""
for response in datasource_response:
if response.type == DatasourceInvokeMessage.MessageType.TEXT:
result += cast(DatasourceInvokeMessage.TextMessage, response.message).text
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
result += (
f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}."
+ " please tell user to check it."
)
elif response.type in {
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
DatasourceInvokeMessage.MessageType.IMAGE,
}:
result += (
"image has been created and sent to user already, "
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == DatasourceInvokeMessage.MessageType.JSON:
result = json.dumps(
cast(DatasourceInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
)
else:
result += str(response.message)

return result

@staticmethod
def _extract_datasource_response_binary_and_text(
datasource_response: list[DatasourceInvokeMessage],
) -> Generator[DatasourceInvokeMessageBinary, None, None]:
"""
Extract datasource response binary
"""
for response in datasource_response:
if response.type in {
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
DatasourceInvokeMessage.MessageType.IMAGE,
}:
mimetype = None
if not response.meta:
raise ValueError("missing meta data")
if response.meta.get("mime_type"):
mimetype = response.meta.get("mime_type")
else:
try:
url = URL(cast(DatasourceInvokeMessage.TextMessage, response.message).text)
extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result:
mimetype = guess_type_result
except Exception:
pass

if not mimetype:
mimetype = "image/jpeg"

yield DatasourceInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "image/jpeg"),
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
)
elif response.type == DatasourceInvokeMessage.MessageType.BLOB:
if not response.meta:
raise ValueError("missing meta data")

yield DatasourceInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream"),
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
)
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and "mime_type" in response.meta:
yield DatasourceInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream",
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
)

@staticmethod
def _create_message_files(
datasource_messages: Iterable[DatasourceInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str,
) -> list[str]:
"""
Create message file

:return: message file ids
"""
result = []

for message in datasource_messages:
if "image" in message.mimetype:
file_type = FileType.IMAGE
elif "video" in message.mimetype:
file_type = FileType.VIDEO
elif "audio" in message.mimetype:
file_type = FileType.AUDIO
elif "text" in message.mimetype or "pdf" in message.mimetype:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM

# extract tool file id from url
tool_file_id = message.url.split("/")[-1].split(".")[0]
message_file = MessageFile(
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
url=message.url,
upload_file_id=tool_file_id,
created_by_role=(
CreatedByRole.ACCOUNT
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER
),
created_by=user_id,
)

db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)

result.append(message_file.id)

db.session.close()

return result

@@ -6,9 +6,8 @@ import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.plugin.impl.tool import PluginToolManager
from core.plugin.impl.datasource import PluginDatasourceManager

logger = logging.getLogger(__name__)

@@ -36,7 +35,7 @@ class DatasourceManager:
if provider in datasource_plugin_providers:
return datasource_plugin_providers[provider]

manager = PluginToolManager()
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(tenant_id, provider)
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
@@ -55,7 +54,6 @@ class DatasourceManager:
@classmethod
def get_datasource_runtime(
cls,
provider_type: DatasourceProviderType,
provider_id: str,
datasource_name: str,
tenant_id: str,
@@ -70,18 +68,15 @@ class DatasourceManager:

:return: the datasource plugin
"""
if provider_type == DatasourceProviderType.RAG_PIPELINE:
return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
else:
raise DatasourceProviderNotFoundError(f"provider type {provider_type.value} not found")

@classmethod
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
"""
list all the datasource providers
"""
manager = PluginToolManager()
provider_entities = manager.fetch_datasources(tenant_id)
manager = PluginDatasourceManager()
provider_entities = manager.fetch_datasource_providers(tenant_id)
return [
DatasourcePluginProviderController(
entity=provider.declaration,

@@ -4,7 +4,6 @@ from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType

@@ -14,7 +13,7 @@ class DatasourceApiEntity(BaseModel):
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]] = None
parameters: Optional[list[DatasourceParameter]] = None
labels: list[str] = Field(default_factory=list)
output_schema: Optional[dict] = None

@@ -1 +0,0 @@
DATASOURCE_SELECTOR_MODEL_IDENTITY = "__dify__datasource_selector__"

@@ -1,13 +1,9 @@
import base64
import enum
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional, Union
from typing import Any, Optional

from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator, model_validator
from pydantic import BaseModel, Field, ValidationInfo, field_validator

from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,
@@ -17,25 +13,7 @@ from core.plugin.entities.parameters import (
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject


class ToolLabelEnum(Enum):
SEARCH = "search"
IMAGE = "image"
VIDEOS = "videos"
WEATHER = "weather"
FINANCE = "finance"
DESIGN = "design"
TRAVEL = "travel"
SOCIAL = "social"
NEWS = "news"
MEDICAL = "medical"
PRODUCTIVITY = "productivity"
EDUCATION = "education"
BUSINESS = "business"
ENTERTAINMENT = "entertainment"
UTILITIES = "utilities"
OTHER = "other"
from core.tools.entities.tool_entities import ToolProviderEntity


class DatasourceProviderType(enum.StrEnum):
@@ -43,7 +21,9 @@ class DatasourceProviderType(enum.StrEnum):
Enum class for datasource provider
"""

RAG_PIPELINE = "rag_pipeline"
ONLINE_DOCUMENT = "online_document"
LOCAL_FILE = "local_file"
WEBSITE = "website"

@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
@@ -59,153 +39,6 @@ class DatasourceProviderType(enum.StrEnum):
raise ValueError(f"invalid mode value {value}")


class ApiProviderSchemaType(Enum):
"""
Enum class for api provider schema type.
"""

OPENAPI = "openapi"
SWAGGER = "swagger"
OPENAI_PLUGIN = "openai_plugin"
OPENAI_ACTIONS = "openai_actions"

@classmethod
def value_of(cls, value: str) -> "ApiProviderSchemaType":
"""
Get value of given mode.

:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")


class ApiProviderAuthType(Enum):
"""
Enum class for api provider auth type.
"""

NONE = "none"
API_KEY = "api_key"

@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
"""
Get value of given mode.

:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")


class DatasourceInvokeMessage(BaseModel):
class TextMessage(BaseModel):
text: str

class JsonMessage(BaseModel):
json_object: dict

class BlobMessage(BaseModel):
blob: bytes

class FileMessage(BaseModel):
pass

class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
variable_value: Any = Field(..., description="The value of the variable")
stream: bool = Field(default=False, description="Whether the variable is streamed")

@model_validator(mode="before")
@classmethod
def transform_variable_value(cls, values) -> Any:
"""
Only basic types and lists are allowed.
"""
value = values.get("variable_value")
if not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types and lists are allowed.")

# if stream is true, the value must be a string
if values.get("stream"):
if not isinstance(value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")

return values

@field_validator("variable_name", mode="before")
@classmethod
def transform_variable_name(cls, value: str) -> str:
"""
The variable name must be a string.
"""
if value in {"json", "text", "files"}:
raise ValueError(f"The variable name '{value}' is reserved.")
return value

class LogMessage(BaseModel):
class LogStatus(Enum):
START = "start"
ERROR = "error"
SUCCESS = "success"

id: str
label: str = Field(..., description="The label of the log")
parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
error: Optional[str] = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")

class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
LINK = "link"
BLOB = "blob"
JSON = "json"
IMAGE_LINK = "image_link"
BINARY_LINK = "binary_link"
VARIABLE = "variable"
FILE = "file"
LOG = "log"

type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
meta: dict[str, Any] | None = None

@field_validator("message", mode="before")
@classmethod
def decode_blob_message(cls, v):
if isinstance(v, dict) and "blob" in v:
try:
v["blob"] = base64.b64decode(v["blob"])
except Exception:
pass
return v

@field_serializer("message")
def serialize_message(self, v):
if isinstance(v, self.BlobMessage):
return {"blob": base64.b64encode(v.blob).decode("utf-8")}
return v


class DatasourceInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")
file_var: Optional[dict[str, Any]] = None


class DatasourceParameter(PluginParameter):
"""
Overrides type
@@ -223,8 +56,6 @@ class DatasourceParameter(PluginParameter):
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value

# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
@@ -235,21 +66,13 @@ class DatasourceParameter(PluginParameter):
def cast_value(self, value: Any):
return cast_parameter_value(self, value)

class DatasourceParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
FORM = "form" # should be set before invoking tool
LLM = "llm" # will be set by LLM

type: DatasourceParameterType = Field(..., description="The type of the parameter")
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
form: DatasourceParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: Optional[str] = None
description: I18nObject = Field(..., description="The description of the parameter")

@classmethod
def get_simple_instance(
cls,
name: str,
llm_description: str,
typ: DatasourceParameterType,
required: bool,
options: Optional[list[str]] = None,
@@ -277,30 +100,16 @@ class DatasourceParameter(PluginParameter):
name=name,
label=I18nObject(en_US="", zh_Hans=""),
placeholder=None,
human_description=I18nObject(en_US="", zh_Hans=""),
type=typ,
form=cls.ToolParameterForm.LLM,
llm_description=llm_description,
required=required,
options=option_objs,
description=I18nObject(en_US="", zh_Hans=""),
)

def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)


class ToolProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
default=[],
description="The tags of the tool",
)


class DatasourceIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
@@ -327,26 +136,18 @@ class DatasourceEntity(BaseModel):
return v or []


class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
class DatasourceProviderEntity(ToolProviderEntity):
"""
Datasource provider entity
"""

provider_type: DatasourceProviderType


class DatasourceProviderEntityWithPlugin(ToolProviderEntity):
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
datasources: list[DatasourceEntity] = Field(default_factory=list)


class WorkflowToolParameterConfiguration(BaseModel):
"""
Workflow tool configuration
"""

name: str = Field(..., description="The name of the parameter")
description: str = Field(..., description="The description of the parameter")
form: DatasourceParameter.DatasourceParameterForm = Field(..., description="The form of the parameter")


class DatasourceInvokeMeta(BaseModel):
"""
Datasource invoke meta
@@ -394,24 +195,3 @@ class DatasourceInvokeFrom(Enum):
"""

RAG_PIPELINE = "rag_pipeline"


class DatasourceSelector(BaseModel):
dify_model_identity: str = DATASOURCE_SELECTOR_MODEL_IDENTITY

class Parameter(BaseModel):
name: str = Field(..., description="The name of the parameter")
type: DatasourceParameter.DatasourceParameterType = Field(..., description="The type of the parameter")
required: bool = Field(..., description="Whether the parameter is required")
description: str = Field(..., description="The description of the parameter")
default: Optional[Union[int, float, str]] = None
options: Optional[list[PluginParameterOption]] = None

provider_id: str = Field(..., description="The id of the provider")
datasource_name: str = Field(..., description="The name of the datasource")
datasource_description: str = Field(..., description="The description of the datasource")
datasource_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
datasource_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")

def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()

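The net effect of the entity hunks above is a small hierarchy change: DatasourceProviderEntityWithPlugin now derives from a dedicated DatasourceProviderEntity that carries the provider type, rather than from ToolProviderEntity directly. Condensed from the diff:

    class DatasourceProviderEntity(ToolProviderEntity):
        provider_type: DatasourceProviderType

    class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
        datasources: list[DatasourceEntity] = Field(default_factory=list)
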
@@ -1,111 +0,0 @@
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum

ICONS = {
ToolLabelEnum.SEARCH: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.IMAGE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.VIDEOS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.WEATHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.FINANCE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.DESIGN: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.TRAVEL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.SOCIAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.NEWS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.MEDICAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.PRODUCTIVITY: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.EDUCATION: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.BUSINESS: """<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.ENTERTAINMENT: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.UTILITIES: """<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
</svg>""", # noqa: E501
ToolLabelEnum.OTHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
</svg>""", # noqa: E501
}

default_tool_label_dict = {
ToolLabelEnum.SEARCH: ToolLabel(
name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH]
),
ToolLabelEnum.IMAGE: ToolLabel(
name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE]
),
ToolLabelEnum.VIDEOS: ToolLabel(
name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS]
),
ToolLabelEnum.WEATHER: ToolLabel(
name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER]
),
ToolLabelEnum.FINANCE: ToolLabel(
name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE]
),
ToolLabelEnum.DESIGN: ToolLabel(
name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN]
),
ToolLabelEnum.TRAVEL: ToolLabel(
name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL]
),
ToolLabelEnum.SOCIAL: ToolLabel(
name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL]
),
ToolLabelEnum.NEWS: ToolLabel(
name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS]
),
ToolLabelEnum.MEDICAL: ToolLabel(
name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL]
),
ToolLabelEnum.PRODUCTIVITY: ToolLabel(
name="productivity",
label=I18nObject(en_US="Productivity", zh_Hans="生产力"),
icon=ICONS[ToolLabelEnum.PRODUCTIVITY],
),
ToolLabelEnum.EDUCATION: ToolLabel(
name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION]
),
ToolLabelEnum.BUSINESS: ToolLabel(
name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS]
),
ToolLabelEnum.ENTERTAINMENT: ToolLabel(
name="entertainment",
label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"),
icon=ICONS[ToolLabelEnum.ENTERTAINMENT],
),
ToolLabelEnum.UTILITIES: ToolLabel(
name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES]
),
ToolLabelEnum.OTHER: ToolLabel(
name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER]
),
}

default_tool_labels = [v for k, v in default_tool_label_dict.items()]
default_tool_label_name_list = [label.name for label in default_tool_labels]

@@ -1,16 +1,16 @@
from collections.abc import Generator
from typing import Any, Optional

from pydantic import BaseModel
from collections.abc import Mapping
from typing import Any

from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginDatasourceProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter


class PluginDatasourceManager(BasePluginClient):
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
"""
Fetch datasource providers for the given tenant.
"""
@@ -27,7 +27,7 @@ class PluginDatasourceManager(BasePluginClient):
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
list[PluginToolProviderEntity],
list[PluginDatasourceProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
@@ -36,12 +36,12 @@ class PluginDatasourceManager(BasePluginClient):
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"

# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.tools:
tool.identity.provider = provider.declaration.identity.name
for datasource in provider.declaration.datasources:
datasource.identity.provider = provider.declaration.identity.name

return response

def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
"""
Fetch datasource provider for the given tenant and plugin.
"""
@@ -58,7 +58,7 @@ class PluginDatasourceManager(BasePluginClient):
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
PluginToolProviderEntity,
PluginDatasourceProviderEntity,
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
transformer=transformer,
)
@@ -66,8 +66,8 @@ class PluginDatasourceManager(BasePluginClient):
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"

# override the provider name for each tool to plugin_id/provider_name
for tool in response.declaration.tools:
tool.identity.provider = response.declaration.identity.name
for datasource in response.declaration.datasources:
datasource.identity.provider = response.declaration.identity.name

return response

@@ -79,7 +79,7 @@ class PluginDatasourceManager(BasePluginClient):
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: dict[str, Any],
) -> Generator[ToolInvokeMessage, None, None]:
) -> Mapping[str, Any]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
@@ -88,8 +88,8 @@ class PluginDatasourceManager(BasePluginClient):

response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages",
ToolInvokeMessage,
f"plugin/{tenant_id}/dispatch/datasource/first_step",
dict,
data={
"user_id": user_id,
"data": {
@@ -104,7 +104,10 @@
"Content-Type": "application/json",
},
)
return response
for resp in response:
return resp

raise Exception("No response from plugin daemon")

def invoke_second_step(
self,
@@ -114,7 +117,7 @@ class PluginDatasourceManager(BasePluginClient):
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: dict[str, Any],
) -> Generator[ToolInvokeMessage, None, None]:
) -> Mapping[str, Any]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
@@ -123,8 +126,8 @@ class PluginDatasourceManager(BasePluginClient):

response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/invoke_second_step",
ToolInvokeMessage,
f"plugin/{tenant_id}/dispatch/datasource/second_step",
dict,
data={
"user_id": user_id,
"data": {
@@ -139,7 +142,10 @@
"Content-Type": "application/json",
},
)
return response
for resp in response:
return resp

raise Exception("No response from plugin daemon")

def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
@@ -151,7 +157,7 @@

response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/tool/validate_credentials",
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
@@ -170,48 +176,3 @@ class PluginDatasourceManager(BasePluginClient):
return resp.result

return False

def get_runtime_parameters(
self,
tenant_id: str,
user_id: str,
provider: str,
credentials: dict[str, Any],
datasource: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
"""
get the runtime parameters of the datasource
"""
datasource_provider_id = GenericProviderID(provider)

class RuntimeParametersResponse(BaseModel):
parameters: list[ToolParameter]

response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/get_runtime_parameters",
RuntimeParametersResponse,
data={
"user_id": user_id,
"conversation_id": conversation_id,
"app_id": app_id,
"message_id": message_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)

for resp in response:
return resp.parameters

return []

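The plugin daemon still replies over a streaming channel, so the rewritten invoke_first_step/invoke_second_step adapt it by returning the first item off the stream, as the hunks above show. The pattern in isolation (a sketch; stream stands for the generator returned by _request_with_plugin_daemon_response_stream):

    def first_or_raise(stream):
        # Take the daemon's first payload; an empty stream means it never answered.
        for resp in stream:
            return resp
        raise Exception("No response from plugin daemon")
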
@@ -3,10 +3,9 @@ from typing import Any, Optional

from pydantic import BaseModel

from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID, ToolProviderID
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginDatasourceProviderEntity,
PluginToolProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
@@ -45,67 +44,6 @@ class PluginToolManager(BasePluginClient):

return response

def fetch_datasources(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
"""
Fetch datasources for the given tenant.
"""

def transformer(json_response: dict[str, Any]) -> dict:
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []):
tool["identity"]["provider"] = provider_name

return json_response

response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
list[PluginToolProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)

for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"

# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.tools:
tool.identity.provider = provider.declaration.identity.name

return response

def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
"""
Fetch datasource provider for the given tenant and plugin.
"""
datasource_provider_id = DatasourceProviderID(provider)

def transformer(json_response: dict[str, Any]) -> dict:
data = json_response.get("data")
if data:
for tool in data.get("declaration", {}).get("tools", []):
tool["identity"]["provider"] = datasource_provider_id.provider_name

return json_response

response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasource",
PluginDatasourceProviderEntity,
params={"provider": datasource_provider_id.provider_name, "plugin_id": datasource_provider_id.plugin_id},
transformer=transformer,
)

response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"

# override the provider name for each tool to plugin_id/provider_name
for tool in response.declaration.tools:
tool.identity.provider = response.declaration.identity.name

return response

def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
"""
Fetch tool provider for the given tenant and plugin.


@ -1,6 +1,5 @@
from typing import Any

from core.datasource.entities.datasource_entities import DatasourceSelector
from core.file.models import File
from core.tools.entities.tool_entities import ToolSelector

@ -19,10 +18,4 @@ def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str,
            parameters[parameter_name] = []
            for p in parameter:
                parameters[parameter_name].append(p.to_plugin_parameter())
        elif isinstance(parameter, DatasourceSelector):
            parameters[parameter_name] = parameter.to_plugin_parameter()
        elif isinstance(parameter, list) and all(isinstance(p, DatasourceSelector) for p in parameter):
            parameters[parameter_name] = []
            for p in parameter:
                parameters[parameter_name].append(p.to_plugin_parameter())
    return parameters
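
Note: with the DatasourceSelector branches removed above, the converter keeps a single generic shape: serialize rich parameter objects, pass everything else through. A simplified, duck-typed sketch of that shape (hypothetical, for illustration only):

from typing import Any


def to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]:
    # Objects that know how to serialize themselves are converted;
    # lists of such objects are converted element-wise; the rest pass through.
    converted: dict[str, Any] = {}
    for name, value in parameters.items():
        if hasattr(value, "to_plugin_parameter"):
            converted[name] = value.to_plugin_parameter()
        elif isinstance(value, list) and value and all(hasattr(v, "to_plugin_parameter") for v in value):
            converted[name] = [v.to_plugin_parameter() for v in value]
        else:
            converted[name] = value
    return converted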

@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Union, cast
from yarl import URL

import contexts
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@ -496,31 +495,6 @@ class ToolManager:
        # get plugin providers
        yield from cls.list_plugin_providers(tenant_id)

    @classmethod
    def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
        """
        list all the datasource providers
        """
        manager = PluginToolManager()
        provider_entities = manager.fetch_datasources(tenant_id)
        return [
            DatasourcePluginProviderController(
                entity=provider.declaration,
                plugin_id=provider.plugin_id,
                plugin_unique_identifier=provider.plugin_unique_identifier,
                tenant_id=tenant_id,
            )
            for provider in provider_entities
        ]

    @classmethod
    def list_builtin_datasources(cls, tenant_id: str) -> Generator[DatasourcePluginProviderController, None, None]:
        """
        list all the builtin datasources
        """
        # get builtin datasources
        yield from cls.list_datasource_providers(tenant_id)

    @classmethod
    def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
        """
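
Note: the deleted listing helper followed a plain fetch-then-wrap shape. A generic sketch of that adapter pattern (hypothetical names, not this codebase's API):

from collections.abc import Callable, Iterable
from typing import TypeVar

E = TypeVar("E")
C = TypeVar("C")


def wrap_providers(entities: Iterable[E], make_controller: Callable[[E], C]) -> list[C]:
    # Fetch-then-wrap: each raw provider entity becomes one controller.
    return [make_controller(entity) for entity in entities]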

@ -1,35 +1,24 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.datasource.datasource_engine import DatasourceEngine
from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter
from core.datasource.errors import DatasourceInvokeError
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.datasource.entities.datasource_entities import (
    DatasourceParameter,
)
from core.file import File
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError
from .exc import DatasourceNodeError, DatasourceParameterError


class DatasourceNode(BaseNode[DatasourceNodeData]):
@ -49,7 +38,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):

        # fetch datasource icon
        datasource_info = {
            "provider_type": node_data.provider_type.value,
            "provider_id": node_data.provider_id,
            "plugin_unique_identifier": node_data.plugin_unique_identifier,
        }
@ -58,8 +46,10 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
        try:
            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_workflow_datasource_runtime(
                self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=node_data.provider_id,
                datasource_name=node_data.datasource_name,
                tenant_id=self.tenant_id,
            )
        except DatasourceNodeError as e:
            yield RunCompletedEvent(
@ -74,7 +64,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
            return

        # get parameters
        datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or []
        datasource_parameters = datasource_runtime.entity.parameters
        parameters = self._generate_parameters(
            datasource_parameters=datasource_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
@ -91,15 +81,20 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])

        try:
            message_stream = DatasourceEngine.generic_invoke(
                datasource=datasource_runtime,
                datasource_parameters=parameters,
            # TODO: handle result
            result = datasource_runtime._invoke_second_step(
                user_id=self.user_id,
                workflow_tool_callback=DifyWorkflowCallbackHandler(),
                workflow_call_depth=self.workflow_call_depth,
                thread_pool_id=self.thread_pool_id,
                app_id=self.app_id,
                conversation_id=conversation_id.text if conversation_id else None,
                datasource_parameters=parameters,
            )
        except PluginDaemonClientSideError as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to transform datasource message: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
        except DatasourceNodeError as e:
            yield RunCompletedEvent(
@ -113,20 +108,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
            )
            return

        try:
            # convert datasource messages
            yield from self._transform_message(message_stream, datasource_info, parameters_for_log)
        except (PluginDaemonClientSideError, DatasourceInvokeError) as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to transform datasource message: {str(e)}",
                    error_type=type(e).__name__,
                )
            )

    def _generate_parameters(
        self,
        *,
@ -175,200 +156,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
        assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
        return list(variable.value) if variable else []

    def _transform_message(
        self,
        messages: Generator[DatasourceInvokeMessage, None, None],
        datasource_info: Mapping[str, Any],
        parameters_for_log: dict[str, Any],
    ) -> Generator:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """
        # transform message and handle file storage
        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )

        text = ""
        files: list[File] = []
        json: list[dict] = []

        agent_logs: list[AgentLogEvent] = []
        agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}

        variables: dict[str, Any] = {}

        for message in message_stream:
            if message.type in {
                DatasourceInvokeMessage.MessageType.IMAGE_LINK,
                DatasourceInvokeMessage.MessageType.BINARY_LINK,
                DatasourceInvokeMessage.MessageType.IMAGE,
            }:
                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)

                url = message.message.text
                if message.meta:
                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
                else:
                    transfer_method = FileTransferMethod.TOOL_FILE

                tool_file_id = str(url).split("/")[-1].split(".")[0]

                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileError(f"Tool file {tool_file_id} does not exist")

                mapping = {
                    "tool_file_id": tool_file_id,
                    "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=self.tenant_id,
                )
                files.append(file)
            elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
                # get tool file id
                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                assert message.meta

                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileError(f"tool file {tool_file_id} not exists")

                mapping = {
                    "tool_file_id": tool_file_id,
                    "transfer_method": FileTransferMethod.TOOL_FILE,
                }

                files.append(
                    file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=self.tenant_id,
                    )
                )
            elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                text += message.message.text
                yield RunStreamChunkEvent(
                    chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
                )
            elif message.type == DatasourceInvokeMessage.MessageType.JSON:
                assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
                if self.node_type == NodeType.AGENT:
                    msg_metadata = message.message.json_object.pop("execution_metadata", {})
                    agent_execution_metadata = {
                        key: value
                        for key, value in msg_metadata.items()
                        if key in NodeRunMetadataKey.__members__.values()
                    }
                json.append(message.message.json_object)
            elif message.type == DatasourceInvokeMessage.MessageType.LINK:
                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
            elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
                assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
                    if not isinstance(variable_value, str):
                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                    if variable_name not in variables:
                        variables[variable_name] = ""
                    variables[variable_name] += variable_value

                    yield RunStreamChunkEvent(
                        chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
                    )
                else:
                    variables[variable_name] = variable_value
            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
                assert message.meta is not None
                files.append(message.meta["file"])
            elif message.type == DatasourceInvokeMessage.MessageType.LOG:
                assert isinstance(message.message, DatasourceInvokeMessage.LogMessage)
                if message.message.metadata:
                    icon = datasource_info.get("icon", "")
                    dict_metadata = dict(message.message.metadata)
                    if dict_metadata.get("provider"):
                        manager = PluginInstallationManager()
                        plugins = manager.list_plugins(self.tenant_id)
                        try:
                            current_plugin = next(
                                plugin
                                for plugin in plugins
                                if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
                            )
                            icon = current_plugin.declaration.icon
                        except StopIteration:
                            pass
                        try:
                            builtin_tool = next(
                                provider
                                for provider in BuiltinToolManageService.list_builtin_tools(
                                    self.user_id,
                                    self.tenant_id,
                                )
                                if provider.name == dict_metadata["provider"]
                            )
                            icon = builtin_tool.icon
                        except StopIteration:
                            pass

                        dict_metadata["icon"] = icon
                        message.message.metadata = dict_metadata
                agent_log = AgentLogEvent(
                    id=message.message.id,
                    node_execution_id=self.id,
                    parent_id=message.message.parent_id,
                    error=message.message.error,
                    status=message.message.status.value,
                    data=message.message.data,
                    label=message.message.label,
                    metadata=message.message.metadata,
                    node_id=self.node_id,
                )

                # check if the agent log is already in the list
                for log in agent_logs:
                    if log.id == agent_log.id:
                        # update the log
                        log.data = agent_log.data
                        log.status = agent_log.status
                        log.error = agent_log.error
                        log.label = agent_log.label
                        log.metadata = agent_log.metadata
                        break
                else:
                    agent_logs.append(agent_log)

                yield agent_log

        yield RunCompletedEvent(
            run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={"text": text, "files": files, "json": json, **variables},
                metadata={
                    **agent_execution_metadata,
                    NodeRunMetadataKey.DATASOURCE_INFO: datasource_info,
                    NodeRunMetadataKey.AGENT_LOG: agent_logs,
                },
                inputs=parameters_for_log,
            )
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
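
Note: the node now receives a Mapping from _invoke_second_step, and its handling is still marked TODO above. One plausible shape for folding such a result into the node's conventional output slots, assuming a JSON-able payload (hypothetical sketch):

from collections.abc import Mapping
from typing import Any


def result_to_outputs(result: Mapping[str, Any]) -> dict[str, Any]:
    # Split a flat datasource result into text / json / files output slots.
    return {
        "text": str(result.get("text", "")),
        "json": [dict(result)],
        "files": [],
    }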

@ -3,17 +3,15 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo

from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData


class DatasourceEntity(BaseModel):
    provider_id: str
    provider_type: ToolProviderType
    provider_name: str  # redundancy
    tool_name: str
    datasource_name: str
    tool_label: str  # redundancy
    tool_configurations: dict[str, Any]
    datasource_configurations: dict[str, Any]
    plugin_unique_identifier: str | None = None  # redundancy

    @field_validator("tool_configurations", mode="before")
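
Note: the @field_validator context line still targets "tool_configurations" even though the field itself is renamed to datasource_configurations above. A hypothetical construction of the renamed entity (illustrative values; BUILT_IN is an assumed enum variant):

from core.tools.entities.tool_entities import ToolProviderType

entity = DatasourceEntity(
    provider_id="example/plugin_datasource",  # illustrative value
    provider_type=ToolProviderType.BUILT_IN,
    provider_name="plugin_datasource",
    datasource_name="example_datasource",
    tool_label="Example",
    datasource_configurations={},
)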

@ -1,7 +1,8 @@
import datetime
import logging
import time
from typing import Any, cast, Mapping
from collections.abc import Mapping
from typing import Any, cast

from flask_login import current_user

@ -245,7 +245,11 @@ class DatasetService:
        rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
    ):
        # check if dataset name already exists
        if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first():
        if (
            db.session.query(Dataset)
            .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
            .first()
        ):
            raise DatasetNameDuplicateError(
                f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
            )
@ -254,7 +258,7 @@ class DatasetService:
            tenant_id=tenant_id,
            name=rag_pipeline_dataset_create_entity.name,
            description=rag_pipeline_dataset_create_entity.description,
            created_by=current_user.id
            created_by=current_user.id,
        )
        db.session.add(pipeline)
        db.session.flush()
@ -268,7 +272,7 @@ class DatasetService:
            runtime_mode="rag_pipeline",
            icon_info=rag_pipeline_dataset_create_entity.icon_info,
            created_by=current_user.id,
            pipeline_id=pipeline.id
            pipeline_id=pipeline.id,
        )
        db.session.add(dataset)
        db.session.commit()
@ -280,7 +284,11 @@ class DatasetService:
        rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
    ):
        # check if dataset name already exists
        if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first():
        if (
            db.session.query(Dataset)
            .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
            .first()
        ):
            raise DatasetNameDuplicateError(
                f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
            )
@ -299,7 +307,7 @@ class DatasetService:
            account=current_user,
            import_mode=ImportMode.YAML_CONTENT.value,
            yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
            dataset=dataset
            dataset=dataset,
        )
        return {
            "id": rag_pipeline_import_info.id,
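
Note: the duplicate-name guards above only reflow an overlong query. An equivalent presence check can also be phrased with EXISTS (a sketch, not what this code uses):

from sqlalchemy import and_, exists

name_taken = db.session.query(
    exists().where(
        and_(
            Dataset.name == rag_pipeline_dataset_create_entity.name,
            Dataset.tenant_id == tenant_id,
        )
    )
).scalar()
if name_taken:
    raise DatasetNameDuplicateError(
        f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
    )
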
@ -1254,281 +1262,282 @@ class DocumentService:

        return documents, batch

    @staticmethod
    def save_document_with_dataset_id(
        dataset: Dataset,
        knowledge_config: KnowledgeConfig,
        account: Account | Any,
        dataset_process_rule: Optional[DatasetProcessRule] = None,
        created_from: str = "web",
    ):
        # check document limit
        features = FeatureService.get_features(current_user.current_tenant_id)
    # @staticmethod
    # def save_document_with_dataset_id(
    #     dataset: Dataset,
    #     knowledge_config: KnowledgeConfig,
    #     account: Account | Any,
    #     dataset_process_rule: Optional[DatasetProcessRule] = None,
    #     created_from: str = "web",
    # ):
    #     # check document limit
    #     features = FeatureService.get_features(current_user.current_tenant_id)

        if features.billing.enabled:
            if not knowledge_config.original_document_id:
                count = 0
                if knowledge_config.data_source:
                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                        count = len(upload_file_list)
                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                        for notion_info in notion_info_list:  # type: ignore
                            count = count + len(notion_info.pages)
                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                        website_info = knowledge_config.data_source.info_list.website_info_list
                        count = len(website_info.urls)  # type: ignore
                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
    #     if features.billing.enabled:
    #         if not knowledge_config.original_document_id:
    #             count = 0
    #             if knowledge_config.data_source:
    #                 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
    #                     upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
    #                     # type: ignore
    #                     count = len(upload_file_list)
    #                 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
    #                     notion_info_list = knowledge_config.data_source.info_list.notion_info_list
    #                     for notion_info in notion_info_list:  # type: ignore
    #                         count = count + len(notion_info.pages)
    #                 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
    #                     website_info = knowledge_config.data_source.info_list.website_info_list
    #                     count = len(website_info.urls)  # type: ignore
    #             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

                if features.billing.subscription.plan == "sandbox" and count > 1:
                    raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
                if count > batch_upload_limit:
                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
    #             if features.billing.subscription.plan == "sandbox" and count > 1:
    #                 raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
    #             if count > batch_upload_limit:
    #                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

                DocumentService.check_documents_upload_quota(count, features)
    #             DocumentService.check_documents_upload_quota(count, features)

        # if dataset is empty, update dataset data_source_type
        if not dataset.data_source_type:
            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore
    #     # if dataset is empty, update dataset data_source_type
    #     if not dataset.data_source_type:
    #         dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

        if not dataset.indexing_technique:
            if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
                raise ValueError("Indexing technique is invalid")
    #     if not dataset.indexing_technique:
    #         if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
    #             raise ValueError("Indexing technique is invalid")

            dataset.indexing_technique = knowledge_config.indexing_technique
            if knowledge_config.indexing_technique == "high_quality":
                model_manager = ModelManager()
                if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
                    dataset_embedding_model = knowledge_config.embedding_model
                    dataset_embedding_model_provider = knowledge_config.embedding_model_provider
                else:
                    embedding_model = model_manager.get_default_model_instance(
                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
                    )
                    dataset_embedding_model = embedding_model.model
                    dataset_embedding_model_provider = embedding_model.provider
                dataset.embedding_model = dataset_embedding_model
                dataset.embedding_model_provider = dataset_embedding_model_provider
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    dataset_embedding_model_provider, dataset_embedding_model
                )
                dataset.collection_binding_id = dataset_collection_binding.id
            if not dataset.retrieval_model:
                default_retrieval_model = {
                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                    "reranking_enable": False,
                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                    "top_k": 2,
                    "score_threshold_enabled": False,
                }
    #         dataset.indexing_technique = knowledge_config.indexing_technique
    #         if knowledge_config.indexing_technique == "high_quality":
    #             model_manager = ModelManager()
    #             if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
    #                 dataset_embedding_model = knowledge_config.embedding_model
    #                 dataset_embedding_model_provider = knowledge_config.embedding_model_provider
    #             else:
    #                 embedding_model = model_manager.get_default_model_instance(
    #                     tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
    #                 )
    #                 dataset_embedding_model = embedding_model.model
    #                 dataset_embedding_model_provider = embedding_model.provider
    #             dataset.embedding_model = dataset_embedding_model
    #             dataset.embedding_model_provider = dataset_embedding_model_provider
    #             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    #                 dataset_embedding_model_provider, dataset_embedding_model
    #             )
    #             dataset.collection_binding_id = dataset_collection_binding.id
    #         if not dataset.retrieval_model:
    #             default_retrieval_model = {
    #                 "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    #                 "reranking_enable": False,
    #                 "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    #                 "top_k": 2,
    #                 "score_threshold_enabled": False,
    #             }

                dataset.retrieval_model = (
                    knowledge_config.retrieval_model.model_dump()
                    if knowledge_config.retrieval_model
                    else default_retrieval_model
                )  # type: ignore
    #             dataset.retrieval_model = (
    #                 knowledge_config.retrieval_model.model_dump()
    #                 if knowledge_config.retrieval_model
    #                 else default_retrieval_model
    #             )  # type: ignore

        documents = []
        if knowledge_config.original_document_id:
            document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
            documents.append(document)
            batch = document.batch
        else:
            batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
            # save process rule
            if not dataset_process_rule:
                process_rule = knowledge_config.process_rule
                if process_rule:
                    if process_rule.mode in ("custom", "hierarchical"):
                        dataset_process_rule = DatasetProcessRule(
                            dataset_id=dataset.id,
                            mode=process_rule.mode,
                            rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                            created_by=account.id,
                        )
                    elif process_rule.mode == "automatic":
                        dataset_process_rule = DatasetProcessRule(
                            dataset_id=dataset.id,
                            mode=process_rule.mode,
                            rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                            created_by=account.id,
                        )
                    else:
                        logging.warn(
                            f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
                        )
                        return
                db.session.add(dataset_process_rule)
                db.session.commit()
            lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
            with redis_client.lock(lock_name, timeout=600):
                position = DocumentService.get_documents_position(dataset.id)
                document_ids = []
                duplicate_document_ids = []
                if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                    for file_id in upload_file_list:
                        file = (
                            db.session.query(UploadFile)
                            .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                            .first()
                        )
    #     documents = []
    #     if knowledge_config.original_document_id:
    #         document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
    #         documents.append(document)
    #         batch = document.batch
    #     else:
    #         batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
    #         # save process rule
    #         if not dataset_process_rule:
    #             process_rule = knowledge_config.process_rule
    #             if process_rule:
    #                 if process_rule.mode in ("custom", "hierarchical"):
    #                     dataset_process_rule = DatasetProcessRule(
    #                         dataset_id=dataset.id,
    #                         mode=process_rule.mode,
    #                         rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
    #                         created_by=account.id,
    #                     )
    #                 elif process_rule.mode == "automatic":
    #                     dataset_process_rule = DatasetProcessRule(
    #                         dataset_id=dataset.id,
    #                         mode=process_rule.mode,
    #                         rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
    #                         created_by=account.id,
    #                     )
    #                 else:
    #                     logging.warn(
    #                         f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
    #                     )
    #                     return
    #             db.session.add(dataset_process_rule)
    #             db.session.commit()
    #         lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
    #         with redis_client.lock(lock_name, timeout=600):
    #             position = DocumentService.get_documents_position(dataset.id)
    #             document_ids = []
    #             duplicate_document_ids = []
    #             if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
    #                 upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
    #                 for file_id in upload_file_list:
    #                     file = (
    #                         db.session.query(UploadFile)
    #                         .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
    #                         .first()
    #                     )

                        # raise error if file not found
                        if not file:
                            raise FileNotExistsError()
    #                     # raise error if file not found
    #                     if not file:
    #                         raise FileNotExistsError()

                        file_name = file.name
                        data_source_info = {
                            "upload_file_id": file_id,
                        }
                        # check duplicate
                        if knowledge_config.duplicate:
                            document = Document.query.filter_by(
                                dataset_id=dataset.id,
                                tenant_id=current_user.current_tenant_id,
                                data_source_type="upload_file",
                                enabled=True,
                                name=file_name,
                            ).first()
                            if document:
                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                                document.created_from = created_from
                                document.doc_form = knowledge_config.doc_form
                                document.doc_language = knowledge_config.doc_language
                                document.data_source_info = json.dumps(data_source_info)
                                document.batch = batch
                                document.indexing_status = "waiting"
                                db.session.add(document)
                                documents.append(document)
                                duplicate_document_ids.append(document.id)
                                continue
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            file_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                    if not notion_info_list:
                        raise ValueError("No notion info list found.")
                    exist_page_ids = []
                    exist_document = {}
                    documents = Document.query.filter_by(
                        dataset_id=dataset.id,
                        tenant_id=current_user.current_tenant_id,
                        data_source_type="notion_import",
                        enabled=True,
                    ).all()
                    if documents:
                        for document in documents:
                            data_source_info = json.loads(document.data_source_info)
                            exist_page_ids.append(data_source_info["notion_page_id"])
                            exist_document[data_source_info["notion_page_id"]] = document.id
                    for notion_info in notion_info_list:
                        workspace_id = notion_info.workspace_id
                        data_source_binding = DataSourceOauthBinding.query.filter(
                            db.and_(
                                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                                DataSourceOauthBinding.provider == "notion",
                                DataSourceOauthBinding.disabled == False,
                                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
                            )
                        ).first()
                        if not data_source_binding:
                            raise ValueError("Data source binding not found.")
                        for page in notion_info.pages:
                            if page.page_id not in exist_page_ids:
                                data_source_info = {
                                    "notion_workspace_id": workspace_id,
                                    "notion_page_id": page.page_id,
                                    "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
                                    "type": page.type,
                                }
                                # Truncate page name to 255 characters to prevent DB field length errors
                                truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
                                document = DocumentService.build_document(
                                    dataset,
                                    dataset_process_rule.id,  # type: ignore
                                    knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                                    knowledge_config.doc_form,
                                    knowledge_config.doc_language,
                                    data_source_info,
                                    created_from,
                                    position,
                                    account,
                                    truncated_page_name,
                                    batch,
                                )
                                db.session.add(document)
                                db.session.flush()
                                document_ids.append(document.id)
                                documents.append(document)
                                position += 1
                            else:
                                exist_document.pop(page.page_id)
                    # delete not selected documents
                    if len(exist_document) > 0:
                        clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
                    website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
                    if not website_info:
                        raise ValueError("No website info list found.")
                    urls = website_info.urls
                    for url in urls:
                        data_source_info = {
                            "url": url,
                            "provider": website_info.provider,
                            "job_id": website_info.job_id,
                            "only_main_content": website_info.only_main_content,
                            "mode": "crawl",
                        }
                        if len(url) > 255:
                            document_name = url[:200] + "..."
                        else:
                            document_name = url
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            document_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                db.session.commit()
    #                     file_name = file.name
    #                     data_source_info = {
    #                         "upload_file_id": file_id,
    #                     }
    #                     # check duplicate
    #                     if knowledge_config.duplicate:
    #                         document = Document.query.filter_by(
    #                             dataset_id=dataset.id,
    #                             tenant_id=current_user.current_tenant_id,
    #                             data_source_type="upload_file",
    #                             enabled=True,
    #                             name=file_name,
    #                         ).first()
    #                         if document:
    #                             document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
    #                             document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    #                             document.created_from = created_from
    #                             document.doc_form = knowledge_config.doc_form
    #                             document.doc_language = knowledge_config.doc_language
    #                             document.data_source_info = json.dumps(data_source_info)
    #                             document.batch = batch
    #                             document.indexing_status = "waiting"
    #                             db.session.add(document)
    #                             documents.append(document)
    #                             duplicate_document_ids.append(document.id)
    #                             continue
    #                     document = DocumentService.build_document(
    #                         dataset,
    #                         dataset_process_rule.id,  # type: ignore
    #                         knowledge_config.data_source.info_list.data_source_type,  # type: ignore
    #                         knowledge_config.doc_form,
    #                         knowledge_config.doc_language,
    #                         data_source_info,
    #                         created_from,
    #                         position,
    #                         account,
    #                         file_name,
    #                         batch,
    #                     )
    #                     db.session.add(document)
    #                     db.session.flush()
    #                     document_ids.append(document.id)
    #                     documents.append(document)
    #                     position += 1
    #             elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
    #                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
    #                 if not notion_info_list:
    #                     raise ValueError("No notion info list found.")
    #                 exist_page_ids = []
    #                 exist_document = {}
    #                 documents = Document.query.filter_by(
    #                     dataset_id=dataset.id,
    #                     tenant_id=current_user.current_tenant_id,
    #                     data_source_type="notion_import",
    #                     enabled=True,
    #                 ).all()
    #                 if documents:
    #                     for document in documents:
    #                         data_source_info = json.loads(document.data_source_info)
    #                         exist_page_ids.append(data_source_info["notion_page_id"])
    #                         exist_document[data_source_info["notion_page_id"]] = document.id
    #                 for notion_info in notion_info_list:
    #                     workspace_id = notion_info.workspace_id
    #                     data_source_binding = DataSourceOauthBinding.query.filter(
    #                         db.and_(
    #                             DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
    #                             DataSourceOauthBinding.provider == "notion",
    #                             DataSourceOauthBinding.disabled == False,
    #                             DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
    #                         )
    #                     ).first()
    #                     if not data_source_binding:
    #                         raise ValueError("Data source binding not found.")
    #                     for page in notion_info.pages:
    #                         if page.page_id not in exist_page_ids:
    #                             data_source_info = {
    #                                 "notion_workspace_id": workspace_id,
    #                                 "notion_page_id": page.page_id,
    #                                 "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
    #                                 "type": page.type,
    #                             }
    #                             # Truncate page name to 255 characters to prevent DB field length errors
    #                             truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
    #                             document = DocumentService.build_document(
    #                                 dataset,
    #                                 dataset_process_rule.id,  # type: ignore
    #                                 knowledge_config.data_source.info_list.data_source_type,  # type: ignore
    #                                 knowledge_config.doc_form,
    #                                 knowledge_config.doc_language,
    #                                 data_source_info,
    #                                 created_from,
    #                                 position,
    #                                 account,
    #                                 truncated_page_name,
    #                                 batch,
    #                             )
    #                             db.session.add(document)
    #                             db.session.flush()
    #                             document_ids.append(document.id)
    #                             documents.append(document)
    #                             position += 1
    #                         else:
    #                             exist_document.pop(page.page_id)
    #                 # delete not selected documents
    #                 if len(exist_document) > 0:
    #                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
    #             elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
    #                 website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
    #                 if not website_info:
    #                     raise ValueError("No website info list found.")
    #                 urls = website_info.urls
    #                 for url in urls:
    #                     data_source_info = {
    #                         "url": url,
    #                         "provider": website_info.provider,
    #                         "job_id": website_info.job_id,
    #                         "only_main_content": website_info.only_main_content,
    #                         "mode": "crawl",
    #                     }
    #                     if len(url) > 255:
    #                         document_name = url[:200] + "..."
    #                     else:
    #                         document_name = url
    #                     document = DocumentService.build_document(
    #                         dataset,
    #                         dataset_process_rule.id,  # type: ignore
    #                         knowledge_config.data_source.info_list.data_source_type,  # type: ignore
    #                         knowledge_config.doc_form,
    #                         knowledge_config.doc_language,
    #                         data_source_info,
    #                         created_from,
    #                         position,
    #                         account,
    #                         document_name,
    #                         batch,
    #                     )
    #                     db.session.add(document)
    #                     db.session.flush()
    #                     document_ids.append(document.id)
    #                     documents.append(document)
    #                     position += 1
    #             db.session.commit()

        # trigger async task
        if document_ids:
            document_indexing_task.delay(dataset.id, document_ids)
        if duplicate_document_ids:
            duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
    #     # trigger async task
    #     if document_ids:
    #         document_indexing_task.delay(dataset.id, document_ids)
    #     if duplicate_document_ids:
    #         duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

        return documents, batch
    #     return documents, batch

    @staticmethod
    def check_documents_upload_quota(count: int, features: FeatureModel):
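
Note: one detail preserved verbatim in the disabled block above is logging.warn, which is a deprecated alias. If the code is ever revived, logging.warning with lazy %s formatting is the supported spelling:

import logging

mode = "automatic"  # placeholder for process_rule.mode
logging.warning("Invalid process rule mode: %s, can not find dataset process rule", mode)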

@ -309,8 +309,10 @@ class RagPipelineDslService:
            dataset_collection_binding = (
                db.session.query(DatasetCollectionBinding)
                .filter(
                    DatasetCollectionBinding.provider_name == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
                    DatasetCollectionBinding.model_name == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
                    DatasetCollectionBinding.provider_name
                    == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
                    DatasetCollectionBinding.model_name
                    == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
                    DatasetCollectionBinding.type == "dataset",
                )
                .order_by(DatasetCollectionBinding.created_at)

api/services/rag_pipeline/rag_pipeline_manage_service.py (new file, 14 lines)
@ -0,0 +1,14 @@
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
from core.plugin.impl.datasource import PluginDatasourceManager


class RagPipelineManageService:
    @staticmethod
    def list_rag_pipeline_datasources(tenant_id: str) -> list[PluginDatasourceProviderEntity]:
        """
        list rag pipeline datasources
        """

        # get all builtin providers
        manager = PluginDatasourceManager()
        return manager.fetch_datasource_providers(tenant_id)
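
Note: a minimal pytest-style sketch of exercising the new service with the plugin daemon stubbed out (hypothetical test; it assumes these import paths resolve in the api package):

from unittest.mock import patch

from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService


def test_list_rag_pipeline_datasources_delegates_to_manager():
    with patch(
        "services.rag_pipeline.rag_pipeline_manage_service.PluginDatasourceManager"
    ) as manager_cls:
        manager_cls.return_value.fetch_datasource_providers.return_value = []
        assert RagPipelineManageService.list_rag_pipeline_datasources("tenant-1") == []
        manager_cls.return_value.fetch_datasource_providers.assert_called_once_with("tenant-1")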

@ -5,7 +5,6 @@ from pathlib import Path
from sqlalchemy.orm import Session

from configs import dify_config
from core.datasource.entities.api_entities import DatasourceProviderApiEntity
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
@ -17,7 +16,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import BuiltinDatasourceProvider, BuiltinToolProvider
from models.tools import BuiltinToolProvider
from services.tools.tools_transform_service import ToolTransformService

logger = logging.getLogger(__name__)
@ -287,67 +286,6 @@ class BuiltinToolManageService:

        return BuiltinToolProviderSort.sort(result)

    @staticmethod
    def list_rag_pipeline_datasources(tenant_id: str) -> list[DatasourceProviderApiEntity]:
        """
        list rag pipeline datasources
        """
        # get all builtin providers
        datasource_provider_controllers = ToolManager.list_datasource_providers(tenant_id)

        with db.session.no_autoflush:
            # get all user added providers
            db_providers: list[BuiltinDatasourceProvider] = (
                db.session.query(BuiltinDatasourceProvider)
                .filter(BuiltinDatasourceProvider.tenant_id == tenant_id)
                .all()
                or []
            )

            # find provider
            def find_provider(provider):
                return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)

            result: list[DatasourceProviderApiEntity] = []

            for provider_controller in datasource_provider_controllers:
                try:
                    # handle include, exclude
                    if is_filtered(
                        include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
                        exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
                        data=provider_controller,
                        name_func=lambda x: x.identity.name,
                    ):
                        continue

                    # convert provider controller to user provider
                    user_builtin_provider = ToolTransformService.builtin_datasource_provider_to_user_provider(
                        provider_controller=provider_controller,
                        db_provider=find_provider(provider_controller.entity.identity.name),
                        decrypt_credentials=True,
                    )

                    # add icon
                    ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)

                    datasources = provider_controller.get_datasources()
                    for datasource in datasources or []:
                        user_builtin_provider.datasources.append(
                            ToolTransformService.convert_datasource_entity_to_api_entity(
                                tenant_id=tenant_id,
                                datasource=datasource,
                                credentials=user_builtin_provider.original_credentials,
                                labels=ToolLabelManager.get_tool_labels(provider_controller),
                            )
                        )

                    result.append(user_builtin_provider)
                except Exception as e:
                    raise e

        return result

    @staticmethod
    def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
        try:
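
Note: the removed converter above follows a decrypt-then-mask discipline: plaintext credentials stay internal while the API response only ever carries a masked copy. A standalone sketch of that idea (hypothetical helpers):

from collections.abc import Callable


def mask_value(value: str, keep: int = 2) -> str:
    # Keep a short prefix so a credential is recognizable without being exposed.
    if len(value) <= keep:
        return "*" * len(value)
    return value[:keep] + "*" * (len(value) - keep)


def decrypt_and_mask(
    encrypted: dict[str, str], decrypt: Callable[[str], str]
) -> tuple[dict[str, str], dict[str, str]]:
    decrypted = {name: decrypt(value) for name, value in encrypted.items()}
    masked = {name: mask_value(value) for name, value in decrypted.items()}
    return decrypted, masked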

@ -5,11 +5,6 @@ from typing import Optional, Union, cast
from yarl import URL

from configs import dify_config
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -26,7 +21,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinDatasourceProvider, BuiltinToolProvider, WorkflowToolProvider
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider

logger = logging.getLogger(__name__)

@ -145,64 +140,6 @@ class ToolTransformService:

        return result

    @classmethod
    def builtin_datasource_provider_to_user_provider(
        cls,
        provider_controller: DatasourcePluginProviderController,
        db_provider: Optional[BuiltinDatasourceProvider],
        decrypt_credentials: bool = True,
    ) -> DatasourceProviderApiEntity:
        """
        convert provider controller to user provider
        """
        result = DatasourceProviderApiEntity(
            id=provider_controller.entity.identity.name,
            author=provider_controller.entity.identity.author,
            name=provider_controller.entity.identity.name,
            description=provider_controller.entity.identity.description,
            icon=provider_controller.entity.identity.icon,
            label=provider_controller.entity.identity.label,
            type=DatasourceProviderType.RAG_PIPELINE,
            masked_credentials={},
            is_team_authorization=False,
            plugin_id=provider_controller.plugin_id,
            plugin_unique_identifier=provider_controller.plugin_unique_identifier,
            datasources=[],
        )

        # get credentials schema
        schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}

        for name, value in schema.items():
            if result.masked_credentials:
                result.masked_credentials[name] = ""

        # check if the provider need credentials
        if not provider_controller.need_credentials:
            result.is_team_authorization = True
            result.allow_delete = False
        elif db_provider:
            result.is_team_authorization = True

            if decrypt_credentials:
                credentials = db_provider.credentials

                # init tool configuration
                tool_configuration = ProviderConfigEncrypter(
                    tenant_id=db_provider.tenant_id,
                    config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
                    provider_type=provider_controller.provider_type.value,
                    provider_identity=provider_controller.entity.identity.name,
                )
                # decrypt the credentials and mask the credentials
                decrypted_credentials = tool_configuration.decrypt(data=credentials)
                masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)

                result.masked_credentials = masked_credentials
                result.original_credentials = decrypted_credentials

        return result

    @staticmethod
    def api_provider_to_controller(
        db_provider: ApiToolProvider,
@ -367,48 +304,3 @@ class ToolTransformService:
            parameters=tool.parameters,
            labels=labels or [],
        )

    @staticmethod
    def convert_datasource_entity_to_api_entity(
        datasource: DatasourcePlugin,
        tenant_id: str,
        credentials: dict | None = None,
        labels: list[str] | None = None,
    ) -> DatasourceApiEntity:
        """
        convert tool to user tool
        """
        # fork tool runtime
        datasource = datasource.fork_datasource_runtime(
            runtime=DatasourceRuntime(
                credentials=credentials or {},
                tenant_id=tenant_id,
            )
        )

        # get datasource parameters
        parameters = datasource.entity.parameters or []
        # get datasource runtime parameters
        runtime_parameters = datasource.get_runtime_parameters()
        # override parameters
        current_parameters = parameters.copy()
        for runtime_parameter in runtime_parameters:
            found = False
            for index, parameter in enumerate(current_parameters):
                if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
                    current_parameters[index] = runtime_parameter
                    found = True
                    break

            if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                current_parameters.append(runtime_parameter)

        return DatasourceApiEntity(
            author=datasource.entity.identity.author,
            name=datasource.entity.identity.name,
            label=datasource.entity.identity.label,
            description=datasource.entity.description.human if datasource.entity.description else I18nObject(en_US=""),
            output_schema=datasource.entity.output_schema,
            parameters=current_parameters,
            labels=labels or [],
        )