diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index df99c82d2b..e97f0ca157 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast from yarl import URL import contexts -from core.plugin.entities.plugin import GenericProviderID +from core.plugin.entities.plugin import ToolProviderID from core.plugin.manager.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime @@ -188,7 +188,7 @@ class ToolManager: ) if isinstance(provider_controller, PluginToolProviderController): - provider_id_entity = GenericProviderID(provider_id) + provider_id_entity = ToolProviderID(provider_id) # get credentials builtin_provider: BuiltinToolProvider | None = ( db.session.query(BuiltinToolProvider) @@ -572,95 +572,96 @@ class ToolManager: else: filters.append(typ) - if "builtin" in filters: - # get builtin providers - builtin_providers = cls.list_builtin_providers(tenant_id) + with db.session.no_autoflush: + if "builtin" in filters: + # get builtin providers + builtin_providers = cls.list_builtin_providers(tenant_id) - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() - ) - - # rewrite db_builtin_providers - for db_provider in db_builtin_providers: - tool_provider_id = GenericProviderID(db_provider.provider) - db_provider.provider = tool_provider_id.to_string() - - def find_db_builtin_provider(provider): - return next((x for x in db_builtin_providers if x.provider == provider), None) - - # append builtin providers - for provider in builtin_providers: - # handle include, exclude - if is_filtered( - include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), - exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), - data=provider, - name_func=lambda x: x.identity.name, - ): - continue - - user_provider = ToolTransformService.builtin_provider_to_user_provider( - provider_controller=provider, - db_provider=find_db_builtin_provider(provider.entity.identity.name), - decrypt_credentials=False, + # get db builtin providers + db_builtin_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() ) - if isinstance(provider, PluginToolProviderController): - result_providers[f"plugin_provider.{user_provider.name}"] = user_provider - else: - result_providers[f"builtin_provider.{user_provider.name}"] = user_provider + # rewrite db_builtin_providers + for db_provider in db_builtin_providers: + tool_provider_id = str(ToolProviderID(db_provider.provider)) + db_provider.provider = tool_provider_id - # get db api providers + def find_db_builtin_provider(provider): + return next((x for x in db_builtin_providers if x.provider == provider), None) - if "api" in filters: - db_api_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() - ) + # append builtin providers + for provider in builtin_providers: + # handle include, exclude + if is_filtered( + include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), + data=provider, + name_func=lambda x: x.identity.name, + ): + continue - api_provider_controllers: list[dict[str, Any]] = [ - {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} - for provider in db_api_providers - ] - - # get labels - labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) - - for api_provider_controller in api_provider_controllers: - user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller["controller"], - db_provider=api_provider_controller["provider"], - decrypt_credentials=False, - labels=labels.get(api_provider_controller["controller"].provider_id, []), - ) - result_providers[f"api_provider.{user_provider.name}"] = user_provider - - if "workflow" in filters: - # get workflow providers - workflow_providers: list[WorkflowToolProvider] = ( - db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() - ) - - workflow_provider_controllers: list[WorkflowToolProviderController] = [] - for provider in workflow_providers: - try: - workflow_provider_controllers.append( - ToolTransformService.workflow_provider_to_controller(db_provider=provider) + user_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider, + db_provider=find_db_builtin_provider(provider.entity.identity.name), + decrypt_credentials=False, ) - except Exception: - # app has been deleted - pass - labels = ToolLabelManager.get_tools_labels( - [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] - ) + if isinstance(provider, PluginToolProviderController): + result_providers[f"plugin_provider.{user_provider.name}"] = user_provider + else: + result_providers[f"builtin_provider.{user_provider.name}"] = user_provider - for provider_controller in workflow_provider_controllers: - user_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=provider_controller, - labels=labels.get(provider_controller.provider_id, []), + # get db api providers + + if "api" in filters: + db_api_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() ) - result_providers[f"workflow_provider.{user_provider.name}"] = user_provider + + api_provider_controllers: list[dict[str, Any]] = [ + {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} + for provider in db_api_providers + ] + + # get labels + labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) + + for api_provider_controller in api_provider_controllers: + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller=api_provider_controller["controller"], + db_provider=api_provider_controller["provider"], + decrypt_credentials=False, + labels=labels.get(api_provider_controller["controller"].provider_id, []), + ) + result_providers[f"api_provider.{user_provider.name}"] = user_provider + + if "workflow" in filters: + # get workflow providers + workflow_providers: list[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + ) + + workflow_provider_controllers: list[WorkflowToolProviderController] = [] + for provider in workflow_providers: + try: + workflow_provider_controllers.append( + ToolTransformService.workflow_provider_to_controller(db_provider=provider) + ) + except Exception: + # app has been deleted + pass + + labels = ToolLabelManager.get_tools_labels( + [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] + ) + + for provider_controller in workflow_provider_controllers: + user_provider = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=provider_controller, + labels=labels.get(provider_controller.provider_id, []), + ) + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values()))