fix: deleted_tools

This commit is contained in:
Yeuoly 2024-11-25 23:22:17 +08:00
parent ba3659a792
commit c2ce8e638e
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
5 changed files with 133 additions and 15 deletions

View File

@ -139,6 +139,7 @@ class GenericProviderID:
organization: str organization: str
plugin_name: str plugin_name: str
provider_name: str provider_name: str
is_hardcoded: bool
def to_string(self) -> str: def to_string(self) -> str:
return str(self) return str(self)
@ -146,7 +147,7 @@ class GenericProviderID:
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.organization}/{self.plugin_name}/{self.provider_name}" return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
def __init__(self, value: str) -> None: def __init__(self, value: str, is_hardcoded: bool = False) -> None:
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value): if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value # check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
@ -156,6 +157,7 @@ class GenericProviderID:
raise ValueError("Invalid plugin id") raise ValueError("Invalid plugin id")
self.organization, self.plugin_name, self.provider_name = value.split("/") self.organization, self.plugin_name, self.provider_name = value.split("/")
self.is_hardcoded = is_hardcoded
@property @property
def plugin_id(self) -> str: def plugin_id(self) -> str:

View File

@ -4,6 +4,7 @@ from pydantic import BaseModel
from core.plugin.entities.bundle import PluginBundleDependency from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import ( from core.plugin.entities.plugin import (
GenericProviderID,
PluginDeclaration, PluginDeclaration,
PluginEntity, PluginEntity,
PluginInstallation, PluginInstallation,
@ -224,3 +225,23 @@ class PluginInstallationManager(BasePluginManager):
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
def check_tools_existence(self, tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
"""
Check if the tools exist
"""
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/management/tools/check_existence",
list[bool],
data={
"provider_ids": [
{
"plugin_id": provider_id.plugin_id,
"provider_name": provider_id.provider_name,
}
for provider_id in provider_ids
]
},
headers={"Content-Type": "application/json"},
)

View File

@ -149,6 +149,12 @@ site_fields = {
"updated_at": TimestampField, "updated_at": TimestampField,
} }
deleted_tool_fields = {
"type": fields.String,
"tool_name": fields.String,
"provider_id": fields.String,
}
app_detail_fields_with_site = { app_detail_fields_with_site = {
"id": fields.String, "id": fields.String,
"name": fields.String, "name": fields.String,
@ -169,9 +175,10 @@ app_detail_fields_with_site = {
"created_at": TimestampField, "created_at": TimestampField,
"updated_by": fields.String, "updated_by": fields.String,
"updated_at": TimestampField, "updated_at": TimestampField,
"deleted_tools": fields.List(fields.String), "deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
} }
app_site_fields = { app_site_fields = {
"app_id": fields.String, "app_id": fields.String,
"access_token": fields.String(attribute="code"), "access_token": fields.String(attribute="code"),

View File

@ -1,4 +1,5 @@
import json import json
import logging
import re import re
import uuid import uuid
from collections.abc import Mapping from collections.abc import Mapping
@ -6,6 +7,10 @@ from datetime import datetime
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from core.plugin.entities.plugin import GenericProviderID
from core.tools.entities.tool_entities import ToolProviderType
from services.plugin.plugin_service import PluginService
if TYPE_CHECKING: if TYPE_CHECKING:
from models.workflow import Workflow from models.workflow import Workflow
@ -16,7 +21,7 @@ import sqlalchemy as sa
from flask import request from flask import request
from flask_login import UserMixin from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config from configs import dify_config
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
@ -30,6 +35,8 @@ from models.enums import CreatedByRole
from .account import Account, Tenant from .account import Account, Tenant
from .types import StringUUID from .types import StringUUID
logger = logging.getLogger(__name__)
class DifySetup(Base): class DifySetup(Base):
__tablename__ = "dify_setups" __tablename__ = "dify_setups"
@ -162,47 +169,114 @@ class App(Base):
@property @property
def deleted_tools(self) -> list: def deleted_tools(self) -> list:
from core.tools.tool_manager import ToolManager
# get agent mode tools # get agent mode tools
app_model_config = self.app_model_config app_model_config = self.app_model_config
if not app_model_config: if not app_model_config:
return [] return []
if not app_model_config.agent_mode: if not app_model_config.agent_mode:
return [] return []
agent_mode = app_model_config.agent_mode_dict agent_mode = app_model_config.agent_mode_dict
tools = agent_mode.get("tools", []) tools = agent_mode.get("tools", [])
provider_ids = [] api_provider_ids: list[str] = []
builtin_provider_ids: list[GenericProviderID] = []
for tool in tools: for tool in tools:
keys = list(tool.keys()) keys = list(tool.keys())
if len(keys) >= 4: if len(keys) >= 4:
provider_type = tool.get("provider_type", "") provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "") provider_id = tool.get("provider_id", "")
if provider_type == "api": if provider_type == ToolProviderType.API.value:
# check if provider id is a uuid string, if not, skip
try: try:
uuid.UUID(provider_id) uuid.UUID(provider_id)
except Exception: except Exception:
continue continue
provider_ids.append(provider_id) api_provider_ids.append(provider_id)
if provider_type == ToolProviderType.BUILT_IN.value:
try:
# check if it's hardcoded
try:
ToolManager.get_hardcoded_provider(provider_id)
is_hardcoded = True
except Exception:
is_hardcoded = False
if not provider_ids: provider_id = GenericProviderID(provider_id, is_hardcoded)
except Exception:
logger.exception(f"Invalid builtin provider id: {provider_id}")
continue
builtin_provider_ids.append(provider_id)
if not api_provider_ids and not builtin_provider_ids:
return [] return []
api_providers = db.session.execute( with Session(db.engine) as session:
text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)} if api_provider_ids:
existing_api_providers = [
api_provider.id
for api_provider in session.execute(
text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"),
{"provider_ids": tuple(api_provider_ids)},
).fetchall() ).fetchall()
]
else:
existing_api_providers = []
if builtin_provider_ids:
# get the non-hardcoded builtin providers
non_hardcoded_builtin_providers = [
provider_id for provider_id in builtin_provider_ids if not provider_id.is_hardcoded
]
if non_hardcoded_builtin_providers:
existence = list(PluginService.check_tools_existence(self.tenant_id, non_hardcoded_builtin_providers))
else:
existence = []
# add the hardcoded builtin providers
existence.extend([True] * (len(builtin_provider_ids) - len(non_hardcoded_builtin_providers)))
builtin_provider_ids = non_hardcoded_builtin_providers + [
provider_id for provider_id in builtin_provider_ids if provider_id.is_hardcoded
]
else:
existence = []
existing_builtin_providers = {
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
}
deleted_tools = [] deleted_tools = []
current_api_provider_ids = [str(api_provider.id) for api_provider in api_providers]
for tool in tools: for tool in tools:
keys = list(tool.keys()) keys = list(tool.keys())
if len(keys) >= 4: if len(keys) >= 4:
provider_type = tool.get("provider_type", "") provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "") provider_id = tool.get("provider_id", "")
if provider_type == "api" and provider_id not in current_api_provider_ids:
deleted_tools.append(tool["tool_name"]) if provider_type == ToolProviderType.API.value:
if provider_id not in existing_api_providers:
deleted_tools.append(
{
"type": ToolProviderType.API.value,
"tool_name": tool["tool_name"],
"provider_id": provider_id,
}
)
if provider_type == ToolProviderType.BUILT_IN.value:
generic_provider_id = GenericProviderID(provider_id)
if not existing_builtin_providers[generic_provider_id.provider_name]:
deleted_tools.append(
{
"type": ToolProviderType.BUILT_IN.value,
"tool_name": tool["tool_name"],
"provider_id": provider_id, # use the original one
}
)
return deleted_tools return deleted_tools

View File

@ -7,7 +7,13 @@ from core.helper import marketplace
from core.helper.download import download_with_size_limit from core.helper.download import download_with_size_limit
from core.helper.marketplace import download_plugin_pkg from core.helper.marketplace import download_plugin_pkg
from core.plugin.entities.bundle import PluginBundleDependency from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity, PluginInstallation, PluginInstallationSource from core.plugin.entities.plugin import (
GenericProviderID,
PluginDeclaration,
PluginEntity,
PluginInstallation,
PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
from core.plugin.manager.asset import PluginAssetManager from core.plugin.manager.asset import PluginAssetManager
from core.plugin.manager.debugging import PluginDebuggingManager from core.plugin.manager.debugging import PluginDebuggingManager
@ -279,3 +285,11 @@ class PluginService:
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstallationManager() manager = PluginInstallationManager()
return manager.uninstall(tenant_id, plugin_installation_id) return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
"""
Check if the tools exist
"""
manager = PluginInstallationManager()
return manager.check_tools_existence(tenant_id, provider_ids)