From c2ce8e638e672a2312cf3bc81d3fa7a0c1ac810e Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 25 Nov 2024 23:22:17 +0800 Subject: [PATCH] fix: deleted_tools --- api/core/plugin/entities/plugin.py | 4 +- api/core/plugin/manager/plugin.py | 21 ++++++ api/fields/app_fields.py | 9 ++- api/models/model.py | 98 +++++++++++++++++++++++---- api/services/plugin/plugin_service.py | 16 ++++- 5 files changed, 133 insertions(+), 15 deletions(-) diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index d603f37b3d..16133d5427 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -139,6 +139,7 @@ class GenericProviderID: organization: str plugin_name: str provider_name: str + is_hardcoded: bool def to_string(self) -> str: return str(self) @@ -146,7 +147,7 @@ class GenericProviderID: def __str__(self) -> str: return f"{self.organization}/{self.plugin_name}/{self.provider_name}" - def __init__(self, value: str) -> None: + def __init__(self, value: str, is_hardcoded: bool = False) -> None: # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value): # check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value @@ -156,6 +157,7 @@ class GenericProviderID: raise ValueError("Invalid plugin id") self.organization, self.plugin_name, self.provider_name = value.split("/") + self.is_hardcoded = is_hardcoded @property def plugin_id(self) -> str: diff --git a/api/core/plugin/manager/plugin.py b/api/core/plugin/manager/plugin.py index 82073e34b3..14fba6c989 100644 --- a/api/core/plugin/manager/plugin.py +++ b/api/core/plugin/manager/plugin.py @@ -4,6 +4,7 @@ from pydantic import BaseModel from core.plugin.entities.bundle import PluginBundleDependency from core.plugin.entities.plugin import ( + GenericProviderID, PluginDeclaration, PluginEntity, PluginInstallation, @@ -224,3 +225,23 @@ class PluginInstallationManager(BasePluginManager): }, headers={"Content-Type": "application/json"}, ) + + def check_tools_existence(self, tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]: + """ + Check if the tools exist + """ + return self._request_with_plugin_daemon_response( + "POST", + f"plugin/{tenant_id}/management/tools/check_existence", + list[bool], + data={ + "provider_ids": [ + { + "plugin_id": provider_id.plugin_id, + "provider_name": provider_id.provider_name, + } + for provider_id in provider_ids + ] + }, + headers={"Content-Type": "application/json"}, + ) diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index abb27fdad1..ed5be866f6 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -149,6 +149,12 @@ site_fields = { "updated_at": TimestampField, } +deleted_tool_fields = { + "type": fields.String, + "tool_name": fields.String, + "provider_id": fields.String, +} + app_detail_fields_with_site = { "id": fields.String, "name": fields.String, @@ -169,9 +175,10 @@ app_detail_fields_with_site = { "created_at": TimestampField, "updated_by": fields.String, "updated_at": TimestampField, - "deleted_tools": fields.List(fields.String), + "deleted_tools": fields.List(fields.Nested(deleted_tool_fields)), } + app_site_fields = { "app_id": fields.String, "access_token": fields.String(attribute="code"), diff --git a/api/models/model.py b/api/models/model.py index c67c1051f2..2423513085 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,4 +1,5 @@ import json +import logging import re import uuid from collections.abc import Mapping @@ -6,6 +7,10 @@ from datetime import datetime from enum import Enum from typing import TYPE_CHECKING, Optional +from core.plugin.entities.plugin import GenericProviderID +from core.tools.entities.tool_entities import ToolProviderType +from services.plugin.plugin_service import PluginService + if TYPE_CHECKING: from models.workflow import Workflow @@ -16,7 +21,7 @@ import sqlalchemy as sa from flask import request from flask_login import UserMixin from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType @@ -30,6 +35,8 @@ from models.enums import CreatedByRole from .account import Account, Tenant from .types import StringUUID +logger = logging.getLogger(__name__) + class DifySetup(Base): __tablename__ = "dify_setups" @@ -162,47 +169,114 @@ class App(Base): @property def deleted_tools(self) -> list: + from core.tools.tool_manager import ToolManager + # get agent mode tools app_model_config = self.app_model_config if not app_model_config: return [] + if not app_model_config.agent_mode: return [] + agent_mode = app_model_config.agent_mode_dict tools = agent_mode.get("tools", []) - provider_ids = [] + api_provider_ids: list[str] = [] + builtin_provider_ids: list[GenericProviderID] = [] for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: provider_type = tool.get("provider_type", "") provider_id = tool.get("provider_id", "") - if provider_type == "api": - # check if provider id is a uuid string, if not, skip + if provider_type == ToolProviderType.API.value: try: uuid.UUID(provider_id) except Exception: continue - provider_ids.append(provider_id) + api_provider_ids.append(provider_id) + if provider_type == ToolProviderType.BUILT_IN.value: + try: + # check if it's hardcoded + try: + ToolManager.get_hardcoded_provider(provider_id) + is_hardcoded = True + except Exception: + is_hardcoded = False - if not provider_ids: + provider_id = GenericProviderID(provider_id, is_hardcoded) + except Exception: + logger.exception(f"Invalid builtin provider id: {provider_id}") + continue + + builtin_provider_ids.append(provider_id) + + if not api_provider_ids and not builtin_provider_ids: return [] - api_providers = db.session.execute( - text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)} - ).fetchall() + with Session(db.engine) as session: + if api_provider_ids: + existing_api_providers = [ + api_provider.id + for api_provider in session.execute( + text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), + {"provider_ids": tuple(api_provider_ids)}, + ).fetchall() + ] + else: + existing_api_providers = [] + + if builtin_provider_ids: + # get the non-hardcoded builtin providers + non_hardcoded_builtin_providers = [ + provider_id for provider_id in builtin_provider_ids if not provider_id.is_hardcoded + ] + if non_hardcoded_builtin_providers: + existence = list(PluginService.check_tools_existence(self.tenant_id, non_hardcoded_builtin_providers)) + else: + existence = [] + # add the hardcoded builtin providers + existence.extend([True] * (len(builtin_provider_ids) - len(non_hardcoded_builtin_providers))) + builtin_provider_ids = non_hardcoded_builtin_providers + [ + provider_id for provider_id in builtin_provider_ids if provider_id.is_hardcoded + ] + else: + existence = [] + + existing_builtin_providers = { + provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) + } deleted_tools = [] - current_api_provider_ids = [str(api_provider.id) for api_provider in api_providers] for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: provider_type = tool.get("provider_type", "") provider_id = tool.get("provider_id", "") - if provider_type == "api" and provider_id not in current_api_provider_ids: - deleted_tools.append(tool["tool_name"]) + + if provider_type == ToolProviderType.API.value: + if provider_id not in existing_api_providers: + deleted_tools.append( + { + "type": ToolProviderType.API.value, + "tool_name": tool["tool_name"], + "provider_id": provider_id, + } + ) + + if provider_type == ToolProviderType.BUILT_IN.value: + generic_provider_id = GenericProviderID(provider_id) + + if not existing_builtin_providers[generic_provider_id.provider_name]: + deleted_tools.append( + { + "type": ToolProviderType.BUILT_IN.value, + "tool_name": tool["tool_name"], + "provider_id": provider_id, # use the original one + } + ) return deleted_tools diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 5fb136e634..701b7c6171 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -7,7 +7,13 @@ from core.helper import marketplace from core.helper.download import download_with_size_limit from core.helper.marketplace import download_plugin_pkg from core.plugin.entities.bundle import PluginBundleDependency -from core.plugin.entities.plugin import PluginDeclaration, PluginEntity, PluginInstallation, PluginInstallationSource +from core.plugin.entities.plugin import ( + GenericProviderID, + PluginDeclaration, + PluginEntity, + PluginInstallation, + PluginInstallationSource, +) from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse from core.plugin.manager.asset import PluginAssetManager from core.plugin.manager.debugging import PluginDebuggingManager @@ -279,3 +285,11 @@ class PluginService: def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: manager = PluginInstallationManager() return manager.uninstall(tenant_id, plugin_installation_id) + + @staticmethod + def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]: + """ + Check if the tools exist + """ + manager = PluginInstallationManager() + return manager.check_tools_existence(tenant_id, provider_ids)