diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 44296d5a31..cc07084dea 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -101,7 +101,9 @@ class CustomizedPipelineTemplateApi(Resource): @enterprise_license_required def post(self, template_id: str): with Session(db.engine) as session: - template = session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + template = ( + session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + ) if not template: raise ValueError("Customized pipeline template not found.") pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index c76014d0a3..c67b897f81 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -43,7 +43,7 @@ from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.rag_pipeline import RagPipelineService -from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError logger = logging.getLogger(__name__) @@ -711,14 +711,7 @@ class DatasourceListApi(Resource): tenant_id = user.current_tenant_id - return jsonable_encoder( - [ - provider.to_dict() - for provider in BuiltinToolManageService.list_rag_pipeline_datasources( - tenant_id, - ) - ] - ) + return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) api.add_resource( diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 8fb89e1172..15d9e7d9ba 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,12 +1,9 @@ -from collections.abc import Generator -from typing import Any, Optional +from collections.abc import Mapping +from typing import Any from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, - DatasourceInvokeMessage, - DatasourceParameter, - DatasourceProviderType, ) from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -16,7 +13,6 @@ class DatasourcePlugin: tenant_id: str icon: str plugin_unique_identifier: str - runtime_parameters: Optional[list[DatasourceParameter]] entity: DatasourceEntity runtime: DatasourceRuntime @@ -33,49 +29,41 @@ class DatasourcePlugin: self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - self.runtime_parameters = None - - def datasource_provider_type(self) -> DatasourceProviderType: - return DatasourceProviderType.RAG_PIPELINE def _invoke_first_step( self, user_id: str, datasource_parameters: dict[str, Any], - rag_pipeline_id: Optional[str] = None, - ) -> 
Generator[DatasourceInvokeMessage, None, None]: + ) -> Mapping[str, Any]: manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - yield from manager.invoke_first_step( + return manager.invoke_first_step( tenant_id=self.tenant_id, user_id=user_id, datasource_provider=self.entity.identity.provider, datasource_name=self.entity.identity.name, credentials=self.runtime.credentials, datasource_parameters=datasource_parameters, - rag_pipeline_id=rag_pipeline_id, ) def _invoke_second_step( self, user_id: str, datasource_parameters: dict[str, Any], - rag_pipeline_id: Optional[str] = None, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Mapping[str, Any]: manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - yield from manager.invoke( + return manager.invoke_second_step( tenant_id=self.tenant_id, user_id=user_id, datasource_provider=self.entity.identity.provider, datasource_name=self.entity.identity.name, credentials=self.runtime.credentials, datasource_parameters=datasource_parameters, - rag_pipeline_id=rag_pipeline_id, ) def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": @@ -86,28 +74,3 @@ class DatasourcePlugin: icon=self.icon, plugin_unique_identifier=self.plugin_unique_identifier, ) - - def get_runtime_parameters( - self, - rag_pipeline_id: Optional[str] = None, - ) -> list[DatasourceParameter]: - """ - get the runtime parameters - """ - if not self.entity.has_runtime_parameters: - return self.entity.parameters - - if self.runtime_parameters is not None: - return self.runtime_parameters - - manager = PluginDatasourceManager() - self.runtime_parameters = manager.get_runtime_parameters( - tenant_id=self.tenant_id, - user_id="", - provider=self.entity.identity.provider, - datasource=self.entity.identity.name, - credentials=self.runtime.credentials, - rag_pipeline_id=rag_pipeline_id, - ) - - return self.runtime_parameters diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index ef3382b948..13804f53d9 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -2,7 +2,7 @@ from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin from core.entities.provider_entities import ProviderConfig from core.plugin.impl.tool import PluginToolManager from core.tools.errors import ToolProviderCredentialValidationError @@ -22,15 +22,6 @@ class DatasourcePluginProviderController: self.plugin_id = plugin_id self.plugin_unique_identifier = plugin_unique_identifier - @property - def provider_type(self) -> DatasourceProviderType: - """ - returns the type of the provider - - :return: type of the provider - """ - return DatasourceProviderType.RAG_PIPELINE - @property def need_credentials(self) -> bool: """ diff --git a/api/core/datasource/datasource_engine.py b/api/core/datasource/datasource_engine.py deleted file mode 100644 index c193c4c629..0000000000 --- a/api/core/datasource/datasource_engine.py +++ /dev/null @@ -1,224 +0,0 @@ -import json -from collections.abc import Generator, Iterable -from mimetypes 
import guess_type -from typing import Any, Optional, cast - -from yarl import URL - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.datasource.__base.datasource_plugin import DatasourcePlugin -from core.datasource.entities.datasource_entities import ( - DatasourceInvokeMessage, - DatasourceInvokeMessageBinary, -) -from core.file import FileType -from core.file.models import FileTransferMethod -from extensions.ext_database import db -from models.enums import CreatedByRole -from models.model import Message, MessageFile - - -class DatasourceEngine: - """ - Datasource runtime engine take care of the datasource executions. - """ - - @staticmethod - def invoke_first_step( - datasource: DatasourcePlugin, - datasource_parameters: dict[str, Any], - user_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> Generator[DatasourceInvokeMessage, None, None]: - """ - Workflow invokes the datasource with the given arguments. - """ - try: - # hit the callback handler - workflow_tool_callback.on_datasource_start( - datasource_name=datasource.entity.identity.name, datasource_inputs=datasource_parameters - ) - - if datasource.runtime and datasource.runtime.runtime_parameters: - datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters} - - response = datasource._invoke_first_step( - user_id=user_id, - datasource_parameters=datasource_parameters, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, - ) - - # hit the callback handler - response = workflow_tool_callback.on_datasource_end( - datasource_name=datasource.entity.identity.name, - datasource_inputs=datasource_parameters, - datasource_outputs=response, - ) - - return response - except Exception as e: - workflow_tool_callback.on_tool_error(e) - raise e - - @staticmethod - def invoke_second_step( - datasource: DatasourcePlugin, - datasource_parameters: dict[str, Any], - user_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler, - ) -> Generator[DatasourceInvokeMessage, None, None]: - """ - Workflow invokes the datasource with the given arguments. - """ - try: - response = datasource._invoke_second_step( - user_id=user_id, - datasource_parameters=datasource_parameters, - ) - - return response - except Exception as e: - workflow_tool_callback.on_tool_error(e) - raise e - - @staticmethod - def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str: - """ - Handle datasource response - """ - result = "" - for response in datasource_response: - if response.type == DatasourceInvokeMessage.MessageType.TEXT: - result += cast(DatasourceInvokeMessage.TextMessage, response.message).text - elif response.type == DatasourceInvokeMessage.MessageType.LINK: - result += ( - f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}." - + " please tell user to check it." - ) - elif response.type in { - DatasourceInvokeMessage.MessageType.IMAGE_LINK, - DatasourceInvokeMessage.MessageType.IMAGE, - }: - result += ( - "image has been created and sent to user already, " - + "you do not need to create it, just tell the user to check it now." 
- ) - elif response.type == DatasourceInvokeMessage.MessageType.JSON: - result = json.dumps( - cast(DatasourceInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False - ) - else: - result += str(response.message) - - return result - - @staticmethod - def _extract_datasource_response_binary_and_text( - datasource_response: list[DatasourceInvokeMessage], - ) -> Generator[DatasourceInvokeMessageBinary, None, None]: - """ - Extract datasource response binary - """ - for response in datasource_response: - if response.type in { - DatasourceInvokeMessage.MessageType.IMAGE_LINK, - DatasourceInvokeMessage.MessageType.IMAGE, - }: - mimetype = None - if not response.meta: - raise ValueError("missing meta data") - if response.meta.get("mime_type"): - mimetype = response.meta.get("mime_type") - else: - try: - url = URL(cast(DatasourceInvokeMessage.TextMessage, response.message).text) - extension = url.suffix - guess_type_result, _ = guess_type(f"a{extension}") - if guess_type_result: - mimetype = guess_type_result - except Exception: - pass - - if not mimetype: - mimetype = "image/jpeg" - - yield DatasourceInvokeMessageBinary( - mimetype=response.meta.get("mime_type", "image/jpeg"), - url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, - ) - elif response.type == DatasourceInvokeMessage.MessageType.BLOB: - if not response.meta: - raise ValueError("missing meta data") - - yield DatasourceInvokeMessageBinary( - mimetype=response.meta.get("mime_type", "application/octet-stream"), - url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, - ) - elif response.type == DatasourceInvokeMessage.MessageType.LINK: - # check if there is a mime type in meta - if response.meta and "mime_type" in response.meta: - yield DatasourceInvokeMessageBinary( - mimetype=response.meta.get("mime_type", "application/octet-stream") - if response.meta - else "application/octet-stream", - url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, - ) - - @staticmethod - def _create_message_files( - datasource_messages: Iterable[DatasourceInvokeMessageBinary], - agent_message: Message, - invoke_from: InvokeFrom, - user_id: str, - ) -> list[str]: - """ - Create message file - - :return: message file ids - """ - result = [] - - for message in datasource_messages: - if "image" in message.mimetype: - file_type = FileType.IMAGE - elif "video" in message.mimetype: - file_type = FileType.VIDEO - elif "audio" in message.mimetype: - file_type = FileType.AUDIO - elif "text" in message.mimetype or "pdf" in message.mimetype: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - - # extract tool file id from url - tool_file_id = message.url.split("/")[-1].split(".")[0] - message_file = MessageFile( - message_id=agent_message.id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - belongs_to="assistant", - url=message.url, - upload_file_id=tool_file_id, - created_by_role=( - CreatedByRole.ACCOUNT - if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatedByRole.END_USER - ), - created_by=user_id, - ) - - db.session.add(message_file) - db.session.commit() - db.session.refresh(message_file) - - result.append(message_file.id) - - db.session.close() - - return result diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index fa141a679a..c865b557f9 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -6,9 +6,8 @@ import contexts from 
core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.common_entities import I18nObject -from core.datasource.entities.datasource_entities import DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError -from core.plugin.impl.tool import PluginToolManager +from core.plugin.impl.datasource import PluginDatasourceManager logger = logging.getLogger(__name__) @@ -36,7 +35,7 @@ class DatasourceManager: if provider in datasource_plugin_providers: return datasource_plugin_providers[provider] - manager = PluginToolManager() + manager = PluginDatasourceManager() provider_entity = manager.fetch_datasource_provider(tenant_id, provider) if not provider_entity: raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") @@ -55,7 +54,6 @@ class DatasourceManager: @classmethod def get_datasource_runtime( cls, - provider_type: DatasourceProviderType, provider_id: str, datasource_name: str, tenant_id: str, @@ -70,18 +68,15 @@ class DatasourceManager: :return: the datasource plugin """ - if provider_type == DatasourceProviderType.RAG_PIPELINE: - return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name) - else: - raise DatasourceProviderNotFoundError(f"provider type {provider_type.value} not found") + return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name) @classmethod def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: """ list all the datasource providers """ - manager = PluginToolManager() - provider_entities = manager.fetch_datasources(tenant_id) + manager = PluginDatasourceManager() + provider_entities = manager.fetch_datasource_providers(tenant_id) return [ DatasourcePluginProviderController( entity=provider.declaration, diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 2d42484a30..8d6bed41fa 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -4,7 +4,6 @@ from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType @@ -14,7 +13,7 @@ class DatasourceApiEntity(BaseModel): name: str # identifier label: I18nObject # label description: I18nObject - parameters: Optional[list[ToolParameter]] = None + parameters: Optional[list[DatasourceParameter]] = None labels: list[str] = Field(default_factory=list) output_schema: Optional[dict] = None diff --git a/api/core/datasource/entities/constants.py b/api/core/datasource/entities/constants.py deleted file mode 100644 index a4dbf6f11f..0000000000 --- a/api/core/datasource/entities/constants.py +++ /dev/null @@ -1 +0,0 @@ -DATASOURCE_SELECTOR_MODEL_IDENTITY = "__dify__datasource_selector__" diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index aa31a7f86a..e1bcbc323b 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -1,13 +1,9 @@ -import base64 import enum -from collections.abc import Mapping from enum 
import Enum -from typing import Any, Optional, Union +from typing import Any, Optional -from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator, model_validator +from pydantic import BaseModel, Field, ValidationInfo, field_validator -from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY -from core.entities.provider_entities import ProviderConfig from core.plugin.entities.parameters import ( PluginParameter, PluginParameterOption, @@ -17,25 +13,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject - - -class ToolLabelEnum(Enum): - SEARCH = "search" - IMAGE = "image" - VIDEOS = "videos" - WEATHER = "weather" - FINANCE = "finance" - DESIGN = "design" - TRAVEL = "travel" - SOCIAL = "social" - NEWS = "news" - MEDICAL = "medical" - PRODUCTIVITY = "productivity" - EDUCATION = "education" - BUSINESS = "business" - ENTERTAINMENT = "entertainment" - UTILITIES = "utilities" - OTHER = "other" +from core.tools.entities.tool_entities import ToolProviderEntity class DatasourceProviderType(enum.StrEnum): @@ -43,7 +21,9 @@ class DatasourceProviderType(enum.StrEnum): Enum class for datasource provider """ - RAG_PIPELINE = "rag_pipeline" + ONLINE_DOCUMENT = "online_document" + LOCAL_FILE = "local_file" + WEBSITE = "website" @classmethod def value_of(cls, value: str) -> "DatasourceProviderType": @@ -59,153 +39,6 @@ class DatasourceProviderType(enum.StrEnum): raise ValueError(f"invalid mode value {value}") -class ApiProviderSchemaType(Enum): - """ - Enum class for api provider schema type. - """ - - OPENAPI = "openapi" - SWAGGER = "swagger" - OPENAI_PLUGIN = "openai_plugin" - OPENAI_ACTIONS = "openai_actions" - - @classmethod - def value_of(cls, value: str) -> "ApiProviderSchemaType": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid mode value {value}") - - -class ApiProviderAuthType(Enum): - """ - Enum class for api provider auth type. - """ - - NONE = "none" - API_KEY = "api_key" - - @classmethod - def value_of(cls, value: str) -> "ApiProviderAuthType": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid mode value {value}") - - -class DatasourceInvokeMessage(BaseModel): - class TextMessage(BaseModel): - text: str - - class JsonMessage(BaseModel): - json_object: dict - - class BlobMessage(BaseModel): - blob: bytes - - class FileMessage(BaseModel): - pass - - class VariableMessage(BaseModel): - variable_name: str = Field(..., description="The name of the variable") - variable_value: Any = Field(..., description="The value of the variable") - stream: bool = Field(default=False, description="Whether the variable is streamed") - - @model_validator(mode="before") - @classmethod - def transform_variable_value(cls, values) -> Any: - """ - Only basic types and lists are allowed. 
- """ - value = values.get("variable_value") - if not isinstance(value, dict | list | str | int | float | bool): - raise ValueError("Only basic types and lists are allowed.") - - # if stream is true, the value must be a string - if values.get("stream"): - if not isinstance(value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - - return values - - @field_validator("variable_name", mode="before") - @classmethod - def transform_variable_name(cls, value: str) -> str: - """ - The variable name must be a string. - """ - if value in {"json", "text", "files"}: - raise ValueError(f"The variable name '{value}' is reserved.") - return value - - class LogMessage(BaseModel): - class LogStatus(Enum): - START = "start" - ERROR = "error" - SUCCESS = "success" - - id: str - label: str = Field(..., description="The label of the log") - parent_id: Optional[str] = Field(default=None, description="Leave empty for root log") - error: Optional[str] = Field(default=None, description="The error message") - status: LogStatus = Field(..., description="The status of the log") - data: Mapping[str, Any] = Field(..., description="Detailed log data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") - - class MessageType(Enum): - TEXT = "text" - IMAGE = "image" - LINK = "link" - BLOB = "blob" - JSON = "json" - IMAGE_LINK = "image_link" - BINARY_LINK = "binary_link" - VARIABLE = "variable" - FILE = "file" - LOG = "log" - - type: MessageType = MessageType.TEXT - """ - plain text, image url or link url - """ - message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage - meta: dict[str, Any] | None = None - - @field_validator("message", mode="before") - @classmethod - def decode_blob_message(cls, v): - if isinstance(v, dict) and "blob" in v: - try: - v["blob"] = base64.b64decode(v["blob"]) - except Exception: - pass - return v - - @field_serializer("message") - def serialize_message(self, v): - if isinstance(v, self.BlobMessage): - return {"blob": base64.b64encode(v.blob).decode("utf-8")} - return v - - -class DatasourceInvokeMessageBinary(BaseModel): - mimetype: str = Field(..., description="The mimetype of the binary") - url: str = Field(..., description="The url of the binary") - file_var: Optional[dict[str, Any]] = None - - class DatasourceParameter(PluginParameter): """ Overrides type @@ -223,8 +56,6 @@ class DatasourceParameter(PluginParameter): SECRET_INPUT = PluginParameterType.SECRET_INPUT.value FILE = PluginParameterType.FILE.value FILES = PluginParameterType.FILES.value - APP_SELECTOR = PluginParameterType.APP_SELECTOR.value - MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value # deprecated, should not use. 
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value @@ -235,21 +66,13 @@ class DatasourceParameter(PluginParameter): def cast_value(self, value: Any): return cast_parameter_value(self, value) - class DatasourceParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM - type: DatasourceParameterType = Field(..., description="The type of the parameter") - human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") - form: DatasourceParameterForm = Field(..., description="The form of the parameter, schema/form/llm") - llm_description: Optional[str] = None + description: I18nObject = Field(..., description="The description of the parameter") @classmethod def get_simple_instance( cls, name: str, - llm_description: str, typ: DatasourceParameterType, required: bool, options: Optional[list[str]] = None, @@ -277,30 +100,16 @@ class DatasourceParameter(PluginParameter): name=name, label=I18nObject(en_US="", zh_Hans=""), placeholder=None, - human_description=I18nObject(en_US="", zh_Hans=""), type=typ, - form=cls.ToolParameterForm.LLM, - llm_description=llm_description, required=required, options=option_objs, + description=I18nObject(en_US="", zh_Hans=""), ) def init_frontend_parameter(self, value: Any): return init_frontend_parameter(self, self.type, value) -class ToolProviderIdentity(BaseModel): - author: str = Field(..., description="The author of the tool") - name: str = Field(..., description="The name of the tool") - description: I18nObject = Field(..., description="The description of the tool") - icon: str = Field(..., description="The icon of the tool") - label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field( - default=[], - description="The tags of the tool", - ) - - class DatasourceIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") @@ -327,26 +136,18 @@ class DatasourceEntity(BaseModel): return v or [] -class ToolProviderEntity(BaseModel): - identity: ToolProviderIdentity - plugin_id: Optional[str] = None - credentials_schema: list[ProviderConfig] = Field(default_factory=list) +class DatasourceProviderEntity(ToolProviderEntity): + """ + Datasource provider entity + """ + + provider_type: DatasourceProviderType -class DatasourceProviderEntityWithPlugin(ToolProviderEntity): +class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): datasources: list[DatasourceEntity] = Field(default_factory=list) -class WorkflowToolParameterConfiguration(BaseModel): - """ - Workflow tool configuration - """ - - name: str = Field(..., description="The name of the parameter") - description: str = Field(..., description="The description of the parameter") - form: DatasourceParameter.DatasourceParameterForm = Field(..., description="The form of the parameter") - - class DatasourceInvokeMeta(BaseModel): """ Datasource invoke meta @@ -394,24 +195,3 @@ class DatasourceInvokeFrom(Enum): """ RAG_PIPELINE = "rag_pipeline" - - -class DatasourceSelector(BaseModel): - dify_model_identity: str = DATASOURCE_SELECTOR_MODEL_IDENTITY - - class Parameter(BaseModel): - name: str = Field(..., description="The name of the parameter") - type: DatasourceParameter.DatasourceParameterType = Field(..., description="The type of the parameter") - required: bool = Field(..., description="Whether the parameter is 
required") - description: str = Field(..., description="The description of the parameter") - default: Optional[Union[int, float, str]] = None - options: Optional[list[PluginParameterOption]] = None - - provider_id: str = Field(..., description="The id of the provider") - datasource_name: str = Field(..., description="The name of the datasource") - datasource_description: str = Field(..., description="The description of the datasource") - datasource_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") - datasource_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm") - - def to_plugin_parameter(self) -> dict[str, Any]: - return self.model_dump() diff --git a/api/core/datasource/entities/values.py b/api/core/datasource/entities/values.py deleted file mode 100644 index f460df7e25..0000000000 --- a/api/core/datasource/entities/values.py +++ /dev/null @@ -1,111 +0,0 @@ -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum - -ICONS = { - ToolLabelEnum.SEARCH: """ - -""", # noqa: E501 - ToolLabelEnum.IMAGE: """ - -""", # noqa: E501 - ToolLabelEnum.VIDEOS: """ - -""", # noqa: E501 - ToolLabelEnum.WEATHER: """ - -""", # noqa: E501 - ToolLabelEnum.FINANCE: """ - -""", # noqa: E501 - ToolLabelEnum.DESIGN: """ - -""", # noqa: E501 - ToolLabelEnum.TRAVEL: """ - -""", # noqa: E501 - ToolLabelEnum.SOCIAL: """ - -""", # noqa: E501 - ToolLabelEnum.NEWS: """ - -""", # noqa: E501 - ToolLabelEnum.MEDICAL: """ - -""", # noqa: E501 - ToolLabelEnum.PRODUCTIVITY: """ - -""", # noqa: E501 - ToolLabelEnum.EDUCATION: """ - -""", # noqa: E501 - ToolLabelEnum.BUSINESS: """ - -""", # noqa: E501 - ToolLabelEnum.ENTERTAINMENT: """ - -""", # noqa: E501 - ToolLabelEnum.UTILITIES: """ - -""", # noqa: E501 - ToolLabelEnum.OTHER: """ - -""", # noqa: E501 -} - -default_tool_label_dict = { - ToolLabelEnum.SEARCH: ToolLabel( - name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH] - ), - ToolLabelEnum.IMAGE: ToolLabel( - name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE] - ), - ToolLabelEnum.VIDEOS: ToolLabel( - name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS] - ), - ToolLabelEnum.WEATHER: ToolLabel( - name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER] - ), - ToolLabelEnum.FINANCE: ToolLabel( - name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE] - ), - ToolLabelEnum.DESIGN: ToolLabel( - name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN] - ), - ToolLabelEnum.TRAVEL: ToolLabel( - name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL] - ), - ToolLabelEnum.SOCIAL: ToolLabel( - name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL] - ), - ToolLabelEnum.NEWS: ToolLabel( - name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS] - ), - ToolLabelEnum.MEDICAL: ToolLabel( - name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL] - ), - ToolLabelEnum.PRODUCTIVITY: ToolLabel( - name="productivity", - label=I18nObject(en_US="Productivity", zh_Hans="生产力"), - icon=ICONS[ToolLabelEnum.PRODUCTIVITY], - ), - ToolLabelEnum.EDUCATION: ToolLabel( - name="education", 
label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION] - ), - ToolLabelEnum.BUSINESS: ToolLabel( - name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS] - ), - ToolLabelEnum.ENTERTAINMENT: ToolLabel( - name="entertainment", - label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"), - icon=ICONS[ToolLabelEnum.ENTERTAINMENT], - ), - ToolLabelEnum.UTILITIES: ToolLabel( - name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES] - ), - ToolLabelEnum.OTHER: ToolLabel( - name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] - ), -} - -default_tool_labels = [v for k, v in default_tool_label_dict.items()] -default_tool_label_name_list = [label.name for label in default_tool_labels] diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index c69fa2fe32..922e65d725 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,16 +1,16 @@ -from collections.abc import Generator -from typing import Any, Optional - -from pydantic import BaseModel +from collections.abc import Mapping +from typing import Any from core.plugin.entities.plugin import GenericProviderID, ToolProviderID -from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity +from core.plugin.entities.plugin_daemon import ( + PluginBasicBooleanResponse, + PluginDatasourceProviderEntity, +) from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter class PluginDatasourceManager(BasePluginClient): - def fetch_datasource_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]: + def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: """ Fetch datasource providers for the given tenant. """ @@ -27,7 +27,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response( "GET", f"plugin/{tenant_id}/management/datasources", - list[PluginToolProviderEntity], + list[PluginDatasourceProviderEntity], params={"page": 1, "page_size": 256}, transformer=transformer, ) @@ -36,12 +36,12 @@ class PluginDatasourceManager(BasePluginClient): provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" # override the provider name for each tool to plugin_id/provider_name - for tool in provider.declaration.tools: - tool.identity.provider = provider.declaration.identity.name + for datasource in provider.declaration.datasources: + datasource.identity.provider = provider.declaration.identity.name return response - def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: + def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: """ Fetch datasource provider for the given tenant and plugin. 
""" @@ -58,7 +58,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response( "GET", f"plugin/{tenant_id}/management/datasources", - PluginToolProviderEntity, + PluginDatasourceProviderEntity, params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id}, transformer=transformer, ) @@ -66,8 +66,8 @@ class PluginDatasourceManager(BasePluginClient): response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" # override the provider name for each tool to plugin_id/provider_name - for tool in response.declaration.tools: - tool.identity.provider = response.declaration.identity.name + for datasource in response.declaration.datasources: + datasource.identity.provider = response.declaration.identity.name return response @@ -79,7 +79,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_name: str, credentials: dict[str, Any], datasource_parameters: dict[str, Any], - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Mapping[str, Any]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -88,8 +88,8 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages", - ToolInvokeMessage, + f"plugin/{tenant_id}/dispatch/datasource/first_step", + dict, data={ "user_id": user_id, "data": { @@ -104,7 +104,10 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - return response + for resp in response: + return resp + + raise Exception("No response from plugin daemon") def invoke_second_step( self, @@ -114,7 +117,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_name: str, credentials: dict[str, Any], datasource_parameters: dict[str, Any], - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Mapping[str, Any]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
""" @@ -123,8 +126,8 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/invoke_second_step", - ToolInvokeMessage, + f"plugin/{tenant_id}/dispatch/datasource/second_step", + dict, data={ "user_id": user_id, "data": { @@ -139,7 +142,10 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - return response + for resp in response: + return resp + + raise Exception("No response from plugin daemon") def validate_provider_credentials( self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] @@ -151,7 +157,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/tool/validate_credentials", + f"plugin/{tenant_id}/dispatch/datasource/validate_credentials", PluginBasicBooleanResponse, data={ "user_id": user_id, @@ -170,48 +176,3 @@ class PluginDatasourceManager(BasePluginClient): return resp.result return False - - def get_runtime_parameters( - self, - tenant_id: str, - user_id: str, - provider: str, - credentials: dict[str, Any], - datasource: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> list[ToolParameter]: - """ - get the runtime parameters of the datasource - """ - datasource_provider_id = GenericProviderID(provider) - - class RuntimeParametersResponse(BaseModel): - parameters: list[ToolParameter] - - response = self._request_with_plugin_daemon_response_stream( - "POST", - f"plugin/{tenant_id}/dispatch/datasource/get_runtime_parameters", - RuntimeParametersResponse, - data={ - "user_id": user_id, - "conversation_id": conversation_id, - "app_id": app_id, - "message_id": message_id, - "data": { - "provider": datasource_provider_id.provider_name, - "datasource": datasource, - "credentials": credentials, - }, - }, - headers={ - "X-Plugin-ID": datasource_provider_id.plugin_id, - "Content-Type": "application/json", - }, - ) - - for resp in response: - return resp.parameters - - return [] diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 54f5418bb4..bb9c00005c 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -3,10 +3,9 @@ from typing import Any, Optional from pydantic import BaseModel -from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID, ToolProviderID +from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, - PluginDatasourceProviderEntity, PluginToolProviderEntity, ) from core.plugin.impl.base import BasePluginClient @@ -45,67 +44,6 @@ class PluginToolManager(BasePluginClient): return response - def fetch_datasources(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: - """ - Fetch datasources for the given tenant. 
- """ - - def transformer(json_response: dict[str, Any]) -> dict: - for provider in json_response.get("data", []): - declaration = provider.get("declaration", {}) or {} - provider_name = declaration.get("identity", {}).get("name") - for tool in declaration.get("tools", []): - tool["identity"]["provider"] = provider_name - - return json_response - - response = self._request_with_plugin_daemon_response( - "GET", - f"plugin/{tenant_id}/management/datasources", - list[PluginToolProviderEntity], - params={"page": 1, "page_size": 256}, - transformer=transformer, - ) - - for provider in response: - provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - - # override the provider name for each tool to plugin_id/provider_name - for tool in provider.declaration.tools: - tool.identity.provider = provider.declaration.identity.name - - return response - - def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: - """ - Fetch datasource provider for the given tenant and plugin. - """ - datasource_provider_id = DatasourceProviderID(provider) - - def transformer(json_response: dict[str, Any]) -> dict: - data = json_response.get("data") - if data: - for tool in data.get("declaration", {}).get("tools", []): - tool["identity"]["provider"] = datasource_provider_id.provider_name - - return json_response - - response = self._request_with_plugin_daemon_response( - "GET", - f"plugin/{tenant_id}/management/datasource", - PluginDatasourceProviderEntity, - params={"provider": datasource_provider_id.provider_name, "plugin_id": datasource_provider_id.plugin_id}, - transformer=transformer, - ) - - response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" - - # override the provider name for each tool to plugin_id/provider_name - for tool in response.declaration.tools: - tool.identity.provider = response.declaration.identity.name - - return response - def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: """ Fetch tool provider for the given tenant and plugin. 
diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 49bf7c308a..6876285b31 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,6 +1,5 @@ from typing import Any -from core.datasource.entities.datasource_entities import DatasourceSelector from core.file.models import File from core.tools.entities.tool_entities import ToolSelector @@ -19,10 +18,4 @@ def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, parameters[parameter_name] = [] for p in parameter: parameters[parameter_name].append(p.to_plugin_parameter()) - elif isinstance(parameter, DatasourceSelector): - parameters[parameter_name] = parameter.to_plugin_parameter() - elif isinstance(parameter, list) and all(isinstance(p, DatasourceSelector) for p in parameter): - parameters[parameter_name] = [] - for p in parameter: - parameters[parameter_name].append(p.to_plugin_parameter()) return parameters diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 682a32d26f..aa2661fe63 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Union, cast from yarl import URL import contexts -from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController @@ -496,31 +495,6 @@ class ToolManager: # get plugin providers yield from cls.list_plugin_providers(tenant_id) - @classmethod - def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: - """ - list all the datasource providers - """ - manager = PluginToolManager() - provider_entities = manager.fetch_datasources(tenant_id) - return [ - DatasourcePluginProviderController( - entity=provider.declaration, - plugin_id=provider.plugin_id, - plugin_unique_identifier=provider.plugin_unique_identifier, - tenant_id=tenant_id, - ) - for provider in provider_entities - ] - - @classmethod - def list_builtin_datasources(cls, tenant_id: str) -> Generator[DatasourcePluginProviderController, None, None]: - """ - list all the builtin datasources - """ - # get builtin datasources - yield from cls.list_datasource_providers(tenant_id) - @classmethod def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 8ecf66c0d6..e7d4da8426 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,35 +1,24 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.datasource.datasource_engine import DatasourceEngine -from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter -from core.datasource.errors import DatasourceInvokeError -from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer -from core.file import File, FileTransferMethod -from core.plugin.manager.exc import PluginDaemonClientSideError -from core.plugin.manager.plugin import PluginInstallationManager +from 
core.datasource.entities.datasource_entities import (
+    DatasourceParameter,
+)
+from core.file import File
+from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment
 from core.variables.variables import ArrayAnyVariable
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
-from core.workflow.graph_engine.entities.event import AgentLogEvent
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
+from core.workflow.nodes.event import RunCompletedEvent
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
-from extensions.ext_database import db
-from factories import file_factory
-from models import ToolFile
 from models.workflow import WorkflowNodeExecutionStatus
-from services.tools.builtin_tools_manage_service import BuiltinToolManageService
 
 from .entities import DatasourceNodeData
-from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError
+from .exc import DatasourceNodeError, DatasourceParameterError
 
 
 class DatasourceNode(BaseNode[DatasourceNodeData]):
@@ -49,7 +38,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 
         # fetch datasource icon
         datasource_info = {
-            "provider_type": node_data.provider_type.value,
             "provider_id": node_data.provider_id,
             "plugin_unique_identifier": node_data.plugin_unique_identifier,
         }
@@ -58,8 +46,10 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         try:
             from core.datasource.datasource_manager import DatasourceManager
 
-            datasource_runtime = DatasourceManager.get_workflow_datasource_runtime(
-                self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id=node_data.provider_id,
+                datasource_name=node_data.datasource_name,
+                tenant_id=self.tenant_id,
             )
         except DatasourceNodeError as e:
             yield RunCompletedEvent(
@@ -74,7 +64,7 @@
             return
 
         # get parameters
-        datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or []
+        datasource_parameters = datasource_runtime.entity.parameters
         parameters = self._generate_parameters(
             datasource_parameters=datasource_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,
@@ -91,15 +81,20 @@
         conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
 
         try:
-            message_stream = DatasourceEngine.generic_invoke(
-                datasource=datasource_runtime,
-                datasource_parameters=parameters,
+            # TODO: handle result
+            result = datasource_runtime._invoke_second_step(
                 user_id=self.user_id,
-                workflow_tool_callback=DifyWorkflowCallbackHandler(),
-                workflow_call_depth=self.workflow_call_depth,
-                thread_pool_id=self.thread_pool_id,
-                app_id=self.app_id,
-                conversation_id=conversation_id.text if conversation_id else None,
+                datasource_parameters=parameters,
+            )
+        except PluginDaemonClientSideError as e:
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=parameters_for_log,
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to invoke datasource: {str(e)}",
+                    error_type=type(e).__name__,
+                )
             )
         except DatasourceNodeError as e:
             yield 
RunCompletedEvent( @@ -113,20 +108,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) return - try: - # convert datasource messages - yield from self._transform_message(message_stream, datasource_info, parameters_for_log) - except (PluginDaemonClientSideError, DatasourceInvokeError) as e: - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed to transform datasource message: {str(e)}", - error_type=type(e).__name__, - ) - ) - def _generate_parameters( self, *, @@ -175,200 +156,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _transform_message( - self, - messages: Generator[DatasourceInvokeMessage, None, None], - datasource_info: Mapping[str, Any], - parameters_for_log: dict[str, Any], - ) -> Generator: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( - messages=messages, - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=None, - ) - - text = "" - files: list[File] = [] - json: list[dict] = [] - - agent_logs: list[AgentLogEvent] = [] - agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {} - - variables: dict[str, Any] = {} - - for message in message_stream: - if message.type in { - DatasourceInvokeMessage.MessageType.IMAGE_LINK, - DatasourceInvokeMessage.MessageType.BINARY_LINK, - DatasourceInvokeMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - - url = message.message.text - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - else: - transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] - - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"Tool file {tool_file_id} does not exist") - - mapping = { - "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - elif message.type == DatasourceInvokeMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - assert message.meta - - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"tool file {tool_file_id} not exists") - - mapping = { - "tool_file_id": tool_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - ) - elif message.type == DatasourceInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - text += message.message.text - yield RunStreamChunkEvent( - chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] - ) - elif message.type 
== DatasourceInvokeMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage) - if self.node_type == NodeType.AGENT: - msg_metadata = message.message.json_object.pop("execution_metadata", {}) - agent_execution_metadata = { - key: value - for key, value in msg_metadata.items() - if key in NodeRunMetadataKey.__members__.values() - } - json.append(message.message.json_object) - elif message.type == DatasourceInvokeMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) - elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value - - yield RunStreamChunkEvent( - chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] - ) - else: - variables[variable_name] = variable_value - elif message.type == DatasourceInvokeMessage.MessageType.FILE: - assert message.meta is not None - files.append(message.meta["file"]) - elif message.type == DatasourceInvokeMessage.MessageType.LOG: - assert isinstance(message.message, DatasourceInvokeMessage.LogMessage) - if message.message.metadata: - icon = datasource_info.get("icon", "") - dict_metadata = dict(message.message.metadata) - if dict_metadata.get("provider"): - manager = PluginInstallationManager() - plugins = manager.list_plugins(self.tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - self.user_id, - self.tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - except StopIteration: - pass - - dict_metadata["icon"] = icon - message.message.metadata = dict_metadata - agent_log = AgentLogEvent( - id=message.message.id, - node_execution_id=self.id, - parent_id=message.message.parent_id, - error=message.message.error, - status=message.message.status.value, - data=message.message.data, - label=message.message.label, - metadata=message.message.metadata, - node_id=self.node_id, - ) - - # check if the agent log is already in the list - for log in agent_logs: - if log.id == agent_log.id: - # update the log - log.data = agent_log.data - log.status = agent_log.status - log.error = agent_log.error - log.label = agent_log.label - log.metadata = agent_log.metadata - break - else: - agent_logs.append(agent_log) - - yield agent_log - - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": files, "json": json, **variables}, - metadata={ - **agent_execution_metadata, - NodeRunMetadataKey.DATASOURCE_INFO: datasource_info, - NodeRunMetadataKey.AGENT_LOG: agent_logs, - }, - inputs=parameters_for_log, - ) - ) - @classmethod def 
_extract_variable_selector_to_variable_mapping(
         cls,
diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py
index 66e8adc431..68aa9fa34c 100644
--- a/api/core/workflow/nodes/datasource/entities.py
+++ b/api/core/workflow/nodes/datasource/entities.py
@@ -3,17 +3,15 @@ from typing import Any, Literal, Union
 from pydantic import BaseModel, field_validator
 from pydantic_core.core_schema import ValidationInfo
 
-from core.tools.entities.tool_entities import ToolProviderType
 from core.workflow.nodes.base.entities import BaseNodeData
 
 
 class DatasourceEntity(BaseModel):
     provider_id: str
-    provider_type: ToolProviderType
     provider_name: str  # redundancy
-    tool_name: str
+    datasource_name: str
     tool_label: str  # redundancy
-    tool_configurations: dict[str, Any]
+    datasource_configurations: dict[str, Any]
     plugin_unique_identifier: str | None = None  # redundancy
 
-    @field_validator("tool_configurations", mode="before")
+    @field_validator("datasource_configurations", mode="before")
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index f039b233a5..1fa6c20bf9 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -1,7 +1,8 @@
 import datetime
 import logging
 import time
-from typing import Any, cast, Mapping
+from collections.abc import Mapping
+from typing import Any, cast
 
 from flask_login import current_user
 
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 0f5069f052..2e3cb604de 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -245,16 +245,20 @@ class DatasetService:
         rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
     ):
         # check if dataset name already exists
-        if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first():
+        if (
+            db.session.query(Dataset)
+            .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
+            .first()
+        ):
             raise DatasetNameDuplicateError(
                 f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
             )
-
+
         pipeline = Pipeline(
             tenant_id=tenant_id,
             name=rag_pipeline_dataset_create_entity.name,
             description=rag_pipeline_dataset_create_entity.description,
-            created_by=current_user.id
+            created_by=current_user.id,
         )
         db.session.add(pipeline)
         db.session.flush()
@@ -268,7 +272,7 @@ class DatasetService:
             runtime_mode="rag_pipeline",
             icon_info=rag_pipeline_dataset_create_entity.icon_info,
             created_by=current_user.id,
-            pipeline_id=pipeline.id
+            pipeline_id=pipeline.id,
         )
         db.session.add(dataset)
         db.session.commit()
@@ -280,7 +284,11 @@ class DatasetService:
         rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
     ):
         # check if dataset name already exists
-        if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first():
+        if (
+            db.session.query(Dataset)
+            .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
+            .first()
+        ):
             raise DatasetNameDuplicateError(
                 f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
) @@ -299,7 +307,7 @@ class DatasetService: account=current_user, import_mode=ImportMode.YAML_CONTENT.value, yaml_content=rag_pipeline_dataset_create_entity.yaml_content, - dataset=dataset + dataset=dataset, ) return { "id": rag_pipeline_import_info.id, @@ -1254,281 +1262,282 @@ class DocumentService: return documents, batch - @staticmethod - def save_document_with_dataset_id( - dataset: Dataset, - knowledge_config: KnowledgeConfig, - account: Account | Any, - dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = "web", - ): - # check document limit - features = FeatureService.get_features(current_user.current_tenant_id) + # @staticmethod + # def save_document_with_dataset_id( + # dataset: Dataset, + # knowledge_config: KnowledgeConfig, + # account: Account | Any, + # dataset_process_rule: Optional[DatasetProcessRule] = None, + # created_from: str = "web", + # ): + # # check document limit + # features = FeatureService.get_features(current_user.current_tenant_id) - if features.billing.enabled: - if not knowledge_config.original_document_id: - count = 0 - if knowledge_config.data_source: - if knowledge_config.data_source.info_list.data_source_type == "upload_file": - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore - count = len(upload_file_list) - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": - notion_info_list = knowledge_config.data_source.info_list.notion_info_list - for notion_info in notion_info_list: # type: ignore - count = count + len(notion_info.pages) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": - website_info = knowledge_config.data_source.info_list.website_info_list - count = len(website_info.urls) # type: ignore - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + # if features.billing.enabled: + # if not knowledge_config.original_document_id: + # count = 0 + # if knowledge_config.data_source: + # if knowledge_config.data_source.info_list.data_source_type == "upload_file": + # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids + # # type: ignore + # count = len(upload_file_list) + # elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + # notion_info_list = knowledge_config.data_source.info_list.notion_info_list + # for notion_info in notion_info_list: # type: ignore + # count = count + len(notion_info.pages) + # elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + # website_info = knowledge_config.data_source.info_list.website_info_list + # count = len(website_info.urls) # type: ignore + # batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if features.billing.subscription.plan == "sandbox" and count > 1: - raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + # if features.billing.subscription.plan == "sandbox" and count > 1: + # raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + # if count > batch_upload_limit: + # raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - DocumentService.check_documents_upload_quota(count, features) + # DocumentService.check_documents_upload_quota(count, features) - # if dataset is empty, update dataset data_source_type - if not dataset.data_source_type: - 
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + # # if dataset is empty, update dataset data_source_type + # if not dataset.data_source_type: + # dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore - if not dataset.indexing_technique: - if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: - raise ValueError("Indexing technique is invalid") + # if not dataset.indexing_technique: + # if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + # raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique - if knowledge_config.indexing_technique == "high_quality": - model_manager = ModelManager() - if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: - dataset_embedding_model = knowledge_config.embedding_model - dataset_embedding_model_provider = knowledge_config.embedding_model_provider - else: - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - dataset_embedding_model = embedding_model.model - dataset_embedding_model_provider = embedding_model.provider - dataset.embedding_model = dataset_embedding_model - dataset.embedding_model_provider = dataset_embedding_model_provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - dataset_embedding_model_provider, dataset_embedding_model - ) - dataset.collection_binding_id = dataset_collection_binding.id - if not dataset.retrieval_model: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } + # dataset.indexing_technique = knowledge_config.indexing_technique + # if knowledge_config.indexing_technique == "high_quality": + # model_manager = ModelManager() + # if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: + # dataset_embedding_model = knowledge_config.embedding_model + # dataset_embedding_model_provider = knowledge_config.embedding_model_provider + # else: + # embedding_model = model_manager.get_default_model_instance( + # tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + # ) + # dataset_embedding_model = embedding_model.model + # dataset_embedding_model_provider = embedding_model.provider + # dataset.embedding_model = dataset_embedding_model + # dataset.embedding_model_provider = dataset_embedding_model_provider + # dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + # dataset_embedding_model_provider, dataset_embedding_model + # ) + # dataset.collection_binding_id = dataset_collection_binding.id + # if not dataset.retrieval_model: + # default_retrieval_model = { + # "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + # "reranking_enable": False, + # "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + # "top_k": 2, + # "score_threshold_enabled": False, + # } - dataset.retrieval_model = ( - knowledge_config.retrieval_model.model_dump() - if knowledge_config.retrieval_model - else default_retrieval_model - ) # type: ignore + # dataset.retrieval_model = ( + # knowledge_config.retrieval_model.model_dump() + # if knowledge_config.retrieval_model + # 
else default_retrieval_model + # ) # type: ignore - documents = [] - if knowledge_config.original_document_id: - document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) - documents.append(document) - batch = document.batch - else: - batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) - # save process rule - if not dataset_process_rule: - process_rule = knowledge_config.process_rule - if process_rule: - if process_rule.mode in ("custom", "hierarchical"): - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=process_rule.rules.model_dump_json() if process_rule.rules else None, - created_by=account.id, - ) - elif process_rule.mode == "automatic": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id, - ) - else: - logging.warn( - f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" - ) - return - db.session.add(dataset_process_rule) - db.session.commit() - lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) - with redis_client.lock(lock_name, timeout=600): - position = DocumentService.get_documents_position(dataset.id) - document_ids = [] - duplicate_document_ids = [] - if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore - for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() - ) + # documents = [] + # if knowledge_config.original_document_id: + # document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + # documents.append(document) + # batch = document.batch + # else: + # batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + # # save process rule + # if not dataset_process_rule: + # process_rule = knowledge_config.process_rule + # if process_rule: + # if process_rule.mode in ("custom", "hierarchical"): + # dataset_process_rule = DatasetProcessRule( + # dataset_id=dataset.id, + # mode=process_rule.mode, + # rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + # created_by=account.id, + # ) + # elif process_rule.mode == "automatic": + # dataset_process_rule = DatasetProcessRule( + # dataset_id=dataset.id, + # mode=process_rule.mode, + # rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + # created_by=account.id, + # ) + # else: + # logging.warn( + # f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + # ) + # return + # db.session.add(dataset_process_rule) + # db.session.commit() + # lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + # with redis_client.lock(lock_name, timeout=600): + # position = DocumentService.get_documents_position(dataset.id) + # document_ids = [] + # duplicate_document_ids = [] + # if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + # for file_id in upload_file_list: + # file = ( + # db.session.query(UploadFile) + # .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + # .first() + # ) - # raise error if file not found - if not file: - raise 
FileNotExistsError() + # # raise error if file not found + # if not file: + # raise FileNotExistsError() - file_name = file.name - data_source_info = { - "upload_file_id": file_id, - } - # check duplicate - if knowledge_config.duplicate: - document = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ).first() - if document: - document.dataset_process_rule_id = dataset_process_rule.id # type: ignore - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - document.created_from = created_from - document.doc_form = knowledge_config.doc_form - document.doc_language = knowledge_config.doc_language - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = "waiting" - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore - notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore - if not notion_info_list: - raise ValueError("No notion info list found.") - exist_page_ids = [] - exist_document = {} - documents = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", - enabled=True, - ).all() - if documents: - for document in documents: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info["notion_page_id"]) - exist_document[data_source_info["notion_page_id"]] = document.id - for notion_info in notion_info_list: - workspace_id = notion_info.workspace_id - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', - ) - ).first() - if not data_source_binding: - raise ValueError("Data source binding not found.") - for page in notion_info.pages: - if page.page_id not in exist_page_ids: - data_source_info = { - "notion_workspace_id": workspace_id, - "notion_page_id": page.page_id, - "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, - "type": page.type, - } - # Truncate page name to 255 characters to prevent DB field length errors - truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - truncated_page_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - else: - exist_document.pop(page.page_id) 
- # delete not selected documents - if len(exist_document) > 0: - clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore - if not website_info: - raise ValueError("No website info list found.") - urls = website_info.urls - for url in urls: - data_source_info = { - "url": url, - "provider": website_info.provider, - "job_id": website_info.job_id, - "only_main_content": website_info.only_main_content, - "mode": "crawl", - } - if len(url) > 255: - document_name = url[:200] + "..." - else: - document_name = url - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - document_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - db.session.commit() + # file_name = file.name + # data_source_info = { + # "upload_file_id": file_id, + # } + # # check duplicate + # if knowledge_config.duplicate: + # document = Document.query.filter_by( + # dataset_id=dataset.id, + # tenant_id=current_user.current_tenant_id, + # data_source_type="upload_file", + # enabled=True, + # name=file_name, + # ).first() + # if document: + # document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + # document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + # document.created_from = created_from + # document.doc_form = knowledge_config.doc_form + # document.doc_language = knowledge_config.doc_language + # document.data_source_info = json.dumps(data_source_info) + # document.batch = batch + # document.indexing_status = "waiting" + # db.session.add(document) + # documents.append(document) + # duplicate_document_ids.append(document.id) + # continue + # document = DocumentService.build_document( + # dataset, + # dataset_process_rule.id, # type: ignore + # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # knowledge_config.doc_form, + # knowledge_config.doc_language, + # data_source_info, + # created_from, + # position, + # account, + # file_name, + # batch, + # ) + # db.session.add(document) + # db.session.flush() + # document_ids.append(document.id) + # documents.append(document) + # position += 1 + # elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + # notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + # if not notion_info_list: + # raise ValueError("No notion info list found.") + # exist_page_ids = [] + # exist_document = {} + # documents = Document.query.filter_by( + # dataset_id=dataset.id, + # tenant_id=current_user.current_tenant_id, + # data_source_type="notion_import", + # enabled=True, + # ).all() + # if documents: + # for document in documents: + # data_source_info = json.loads(document.data_source_info) + # exist_page_ids.append(data_source_info["notion_page_id"]) + # exist_document[data_source_info["notion_page_id"]] = document.id + # for notion_info in notion_info_list: + # workspace_id = notion_info.workspace_id + # data_source_binding = DataSourceOauthBinding.query.filter( + # db.and_( + # DataSourceOauthBinding.tenant_id == 
current_user.current_tenant_id, + # DataSourceOauthBinding.provider == "notion", + # DataSourceOauthBinding.disabled == False, + # DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + # ) + # ).first() + # if not data_source_binding: + # raise ValueError("Data source binding not found.") + # for page in notion_info.pages: + # if page.page_id not in exist_page_ids: + # data_source_info = { + # "notion_workspace_id": workspace_id, + # "notion_page_id": page.page_id, + # "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + # "type": page.type, + # } + # # Truncate page name to 255 characters to prevent DB field length errors + # truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" + # document = DocumentService.build_document( + # dataset, + # dataset_process_rule.id, # type: ignore + # knowledge_config.data_source.info_list.data_source_type, # type: ignore + # knowledge_config.doc_form, + # knowledge_config.doc_language, + # data_source_info, + # created_from, + # position, + # account, + # truncated_page_name, + # batch, + # ) + # db.session.add(document) + # db.session.flush() + # document_ids.append(document.id) + # documents.append(document) + # position += 1 + # else: + # exist_document.pop(page.page_id) + # # delete not selected documents + # if len(exist_document) > 0: + # clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + # elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore + # website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + # if not website_info: + # raise ValueError("No website info list found.") + # urls = website_info.urls + # for url in urls: + # data_source_info = { + # "url": url, + # "provider": website_info.provider, + # "job_id": website_info.job_id, + # "only_main_content": website_info.only_main_content, + # "mode": "crawl", + # } + # if len(url) > 255: + # document_name = url[:200] + "..." 
+    #                     else:
+    #                         document_name = url
+    #                     document = DocumentService.build_document(
+    #                         dataset,
+    #                         dataset_process_rule.id,  # type: ignore
+    #                         knowledge_config.data_source.info_list.data_source_type,  # type: ignore
+    #                         knowledge_config.doc_form,
+    #                         knowledge_config.doc_language,
+    #                         data_source_info,
+    #                         created_from,
+    #                         position,
+    #                         account,
+    #                         document_name,
+    #                         batch,
+    #                     )
+    #                     db.session.add(document)
+    #                     db.session.flush()
+    #                     document_ids.append(document.id)
+    #                     documents.append(document)
+    #                     position += 1
+    #         db.session.commit()
 
-        # trigger async task
-        if document_ids:
-            document_indexing_task.delay(dataset.id, document_ids)
-        if duplicate_document_ids:
-            duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
+    #     # trigger async task
+    #     if document_ids:
+    #         document_indexing_task.delay(dataset.id, document_ids)
+    #     if duplicate_document_ids:
+    #         duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
 
-        return documents, batch
+    #     return documents, batch
 
     @staticmethod
     def check_documents_upload_quota(count: int, features: FeatureModel):
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index 3664c988e5..1c6dac55be 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -309,8 +309,10 @@ class RagPipelineDslService:
         dataset_collection_binding = (
             db.session.query(DatasetCollectionBinding)
             .filter(
-                DatasetCollectionBinding.provider_name == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
-                DatasetCollectionBinding.model_name == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                DatasetCollectionBinding.provider_name
+                == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                DatasetCollectionBinding.model_name
+                == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
                 DatasetCollectionBinding.type == "dataset",
             )
             .order_by(DatasetCollectionBinding.created_at)
diff --git a/api/services/rag_pipeline/rag_pipeline_manage_service.py b/api/services/rag_pipeline/rag_pipeline_manage_service.py
new file mode 100644
index 0000000000..4d8d69f913
--- /dev/null
+++ b/api/services/rag_pipeline/rag_pipeline_manage_service.py
@@ -0,0 +1,14 @@
+from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
+from core.plugin.impl.datasource import PluginDatasourceManager
+
+
+class RagPipelineManageService:
+    @staticmethod
+    def list_rag_pipeline_datasources(tenant_id: str) -> list[PluginDatasourceProviderEntity]:
+        """
+        list rag pipeline datasources
+        """
+
+        # get all builtin providers
+        manager = PluginDatasourceManager()
+        return manager.fetch_datasource_providers(tenant_id)
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index daf3773309..3ccd14415d 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -5,7 +5,6 @@ from pathlib import Path
 from sqlalchemy.orm import Session
 
 from configs import dify_config
-from core.datasource.entities.api_entities import DatasourceProviderApiEntity
 from core.helper.position_helper import is_filtered
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
@@ -17,7 +16,7 @@ from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.configuration import ProviderConfigEncrypter
 from extensions.ext_database import db
-from models.tools import BuiltinDatasourceProvider, BuiltinToolProvider
+from models.tools import BuiltinToolProvider
 from services.tools.tools_transform_service import ToolTransformService
 
 logger = logging.getLogger(__name__)
@@ -287,67 +286,6 @@
 
         return BuiltinToolProviderSort.sort(result)
 
-    @staticmethod
-    def list_rag_pipeline_datasources(tenant_id: str) -> list[DatasourceProviderApiEntity]:
-        """
-        list rag pipeline datasources
-        """
-        # get all builtin providers
-        datasource_provider_controllers = ToolManager.list_datasource_providers(tenant_id)
-
-        with db.session.no_autoflush:
-            # get all user added providers
-            db_providers: list[BuiltinDatasourceProvider] = (
-                db.session.query(BuiltinDatasourceProvider)
-                .filter(BuiltinDatasourceProvider.tenant_id == tenant_id)
-                .all()
-                or []
-            )
-
-            # find provider
-            def find_provider(provider):
-                return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
-
-            result: list[DatasourceProviderApiEntity] = []
-
-            for provider_controller in datasource_provider_controllers:
-                try:
-                    # handle include, exclude
-                    if is_filtered(
-                        include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
-                        exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
-                        data=provider_controller,
-                        name_func=lambda x: x.identity.name,
-                    ):
-                        continue
-
-                    # convert provider controller to user provider
-                    user_builtin_provider = ToolTransformService.builtin_datasource_provider_to_user_provider(
-                        provider_controller=provider_controller,
-                        db_provider=find_provider(provider_controller.entity.identity.name),
-                        decrypt_credentials=True,
-                    )
-
-                    # add icon
-                    ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
-
-                    datasources = provider_controller.get_datasources()
-                    for datasource in datasources or []:
-                        user_builtin_provider.datasources.append(
-                            ToolTransformService.convert_datasource_entity_to_api_entity(
-                                tenant_id=tenant_id,
-                                datasource=datasource,
-                                credentials=user_builtin_provider.original_credentials,
-                                labels=ToolLabelManager.get_tool_labels(provider_controller),
-                            )
-                        )
-
-                    result.append(user_builtin_provider)
-                except Exception as e:
-                    raise e
-
-            return result
-
     @staticmethod
     def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
         try:
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index e0c1ce7217..367121125b 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -5,11 +5,6 @@ from typing import Optional, Union, cast
 from yarl import URL
 
 from configs import dify_config
-from core.datasource.__base.datasource_plugin import DatasourcePlugin
-from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
-from core.datasource.__base.datasource_runtime import DatasourceRuntime
-from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity
-from core.datasource.entities.datasource_entities import DatasourceProviderType
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.builtin_tool.provider import BuiltinToolProviderController
@@ -26,7 +21,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
 from core.tools.utils.configuration import ProviderConfigEncrypter
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
 from core.tools.workflow_as_tool.tool import WorkflowTool
-from models.tools import ApiToolProvider, BuiltinDatasourceProvider, BuiltinToolProvider, WorkflowToolProvider
+from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
 
 logger = logging.getLogger(__name__)
 
@@ -145,64 +140,6 @@
 
         return result
 
-    @classmethod
-    def builtin_datasource_provider_to_user_provider(
-        cls,
-        provider_controller: DatasourcePluginProviderController,
-        db_provider: Optional[BuiltinDatasourceProvider],
-        decrypt_credentials: bool = True,
-    ) -> DatasourceProviderApiEntity:
-        """
-        convert provider controller to user provider
-        """
-        result = DatasourceProviderApiEntity(
-            id=provider_controller.entity.identity.name,
-            author=provider_controller.entity.identity.author,
-            name=provider_controller.entity.identity.name,
-            description=provider_controller.entity.identity.description,
-            icon=provider_controller.entity.identity.icon,
-            label=provider_controller.entity.identity.label,
-            type=DatasourceProviderType.RAG_PIPELINE,
-            masked_credentials={},
-            is_team_authorization=False,
-            plugin_id=provider_controller.plugin_id,
-            plugin_unique_identifier=provider_controller.plugin_unique_identifier,
-            datasources=[],
-        )
-
-        # get credentials schema
-        schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
-
-        for name, value in schema.items():
-            if result.masked_credentials:
-                result.masked_credentials[name] = ""
-
-        # check if the provider need credentials
-        if not provider_controller.need_credentials:
-            result.is_team_authorization = True
-            result.allow_delete = False
-        elif db_provider:
-            result.is_team_authorization = True
-
-            if decrypt_credentials:
-                credentials = db_provider.credentials
-
-                # init tool configuration
-                tool_configuration = ProviderConfigEncrypter(
-                    tenant_id=db_provider.tenant_id,
-                    config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
-                    provider_type=provider_controller.provider_type.value,
-                    provider_identity=provider_controller.entity.identity.name,
-                )
-                # decrypt the credentials and mask the credentials
-                decrypted_credentials = tool_configuration.decrypt(data=credentials)
-                masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
-
-                result.masked_credentials = masked_credentials
-                result.original_credentials = decrypted_credentials
-
-        return result
-
     @staticmethod
     def api_provider_to_controller(
         db_provider: ApiToolProvider,
@@ -367,48 +304,3 @@
             parameters=tool.parameters,
             labels=labels or [],
         )
-
-    @staticmethod
-    def convert_datasource_entity_to_api_entity(
-        datasource: DatasourcePlugin,
-        tenant_id: str,
-        credentials: dict | None = None,
-        labels: list[str] | None = None,
-    ) -> DatasourceApiEntity:
-        """
-        convert tool to user tool
-        """
-        # fork tool runtime
-        datasource = datasource.fork_datasource_runtime(
-            runtime=DatasourceRuntime(
-                credentials=credentials or {},
-                tenant_id=tenant_id,
-            )
-        )
-
-        # get datasource parameters
-        parameters = datasource.entity.parameters or []
-        # get datasource runtime parameters
-        runtime_parameters = datasource.get_runtime_parameters()
-        # override parameters
-        current_parameters = parameters.copy()
-        for runtime_parameter in runtime_parameters:
-            found = False
-            for index, parameter in enumerate(current_parameters):
-                if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
-                    current_parameters[index] = runtime_parameter
-                    found = True
-                    break
-
-            if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
-                current_parameters.append(runtime_parameter)
-
-        return DatasourceApiEntity(
-            author=datasource.entity.identity.author,
-            name=datasource.entity.identity.name,
-            label=datasource.entity.identity.label,
-            description=datasource.entity.description.human if datasource.entity.description else I18nObject(en_US=""),
-            output_schema=datasource.entity.output_schema,
-            parameters=current_parameters,
-            labels=labels or [],
-        )
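Note (not part of the patch): a minimal sketch of how the new RagPipelineManageService added in this changeset might be exercised from calling code. The helper name dump_datasource_providers and the use of model_dump() for serialization are assumptions for illustration; only list_rag_pipeline_datasources(tenant_id) is taken from the new file.

    # Illustrative sketch, assuming PluginDatasourceProviderEntity is a pydantic model
    # exposing model_dump(); the tenant_id value is supplied by the caller.
    from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService


    def dump_datasource_providers(tenant_id: str) -> list[dict]:
        # Delegates to PluginDatasourceManager.fetch_datasource_providers via the service.
        providers = RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)
        # Convert each provider entity into a plain dict suitable for an API response.
        return [provider.model_dump() for provider in providers]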