mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-20 08:09:24 +08:00
fix: deleted_tools
This commit is contained in:
parent
ba3659a792
commit
c2ce8e638e
@ -139,6 +139,7 @@ class GenericProviderID:
|
||||
organization: str
|
||||
plugin_name: str
|
||||
provider_name: str
|
||||
is_hardcoded: bool
|
||||
|
||||
def to_string(self) -> str:
|
||||
return str(self)
|
||||
@ -146,7 +147,7 @@ class GenericProviderID:
|
||||
def __str__(self) -> str:
|
||||
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
|
||||
|
||||
def __init__(self, value: str) -> None:
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
|
||||
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
|
||||
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
|
||||
@ -156,6 +157,7 @@ class GenericProviderID:
|
||||
raise ValueError("Invalid plugin id")
|
||||
|
||||
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
||||
self.is_hardcoded = is_hardcoded
|
||||
|
||||
@property
|
||||
def plugin_id(self) -> str:
|
||||
|
||||
@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
GenericProviderID,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
PluginInstallation,
|
||||
@ -224,3 +225,23 @@ class PluginInstallationManager(BasePluginManager):
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def check_tools_existence(self, tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||
"""
|
||||
Check if the tools exist
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/tools/check_existence",
|
||||
list[bool],
|
||||
data={
|
||||
"provider_ids": [
|
||||
{
|
||||
"plugin_id": provider_id.plugin_id,
|
||||
"provider_name": provider_id.provider_name,
|
||||
}
|
||||
for provider_id in provider_ids
|
||||
]
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
@ -149,6 +149,12 @@ site_fields = {
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
deleted_tool_fields = {
|
||||
"type": fields.String,
|
||||
"tool_name": fields.String,
|
||||
"provider_id": fields.String,
|
||||
}
|
||||
|
||||
app_detail_fields_with_site = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
@ -169,9 +175,10 @@ app_detail_fields_with_site = {
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.String),
|
||||
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
|
||||
}
|
||||
|
||||
|
||||
app_site_fields = {
|
||||
"app_id": fields.String,
|
||||
"access_token": fields.String(attribute="code"),
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
@ -6,6 +7,10 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.workflow import Workflow
|
||||
|
||||
@ -16,7 +21,7 @@ import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
@ -30,6 +35,8 @@ from models.enums import CreatedByRole
|
||||
from .account import Account, Tenant
|
||||
from .types import StringUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifySetup(Base):
|
||||
__tablename__ = "dify_setups"
|
||||
@ -162,47 +169,114 @@ class App(Base):
|
||||
|
||||
@property
|
||||
def deleted_tools(self) -> list:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
# get agent mode tools
|
||||
app_model_config = self.app_model_config
|
||||
if not app_model_config:
|
||||
return []
|
||||
|
||||
if not app_model_config.agent_mode:
|
||||
return []
|
||||
|
||||
agent_mode = app_model_config.agent_mode_dict
|
||||
tools = agent_mode.get("tools", [])
|
||||
|
||||
provider_ids = []
|
||||
api_provider_ids: list[str] = []
|
||||
builtin_provider_ids: list[GenericProviderID] = []
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
provider_type = tool.get("provider_type", "")
|
||||
provider_id = tool.get("provider_id", "")
|
||||
if provider_type == "api":
|
||||
# check if provider id is a uuid string, if not, skip
|
||||
if provider_type == ToolProviderType.API.value:
|
||||
try:
|
||||
uuid.UUID(provider_id)
|
||||
except Exception:
|
||||
continue
|
||||
provider_ids.append(provider_id)
|
||||
api_provider_ids.append(provider_id)
|
||||
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||
try:
|
||||
# check if it's hardcoded
|
||||
try:
|
||||
ToolManager.get_hardcoded_provider(provider_id)
|
||||
is_hardcoded = True
|
||||
except Exception:
|
||||
is_hardcoded = False
|
||||
|
||||
if not provider_ids:
|
||||
provider_id = GenericProviderID(provider_id, is_hardcoded)
|
||||
except Exception:
|
||||
logger.exception(f"Invalid builtin provider id: {provider_id}")
|
||||
continue
|
||||
|
||||
builtin_provider_ids.append(provider_id)
|
||||
|
||||
if not api_provider_ids and not builtin_provider_ids:
|
||||
return []
|
||||
|
||||
api_providers = db.session.execute(
|
||||
text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)}
|
||||
).fetchall()
|
||||
with Session(db.engine) as session:
|
||||
if api_provider_ids:
|
||||
existing_api_providers = [
|
||||
api_provider.id
|
||||
for api_provider in session.execute(
|
||||
text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"),
|
||||
{"provider_ids": tuple(api_provider_ids)},
|
||||
).fetchall()
|
||||
]
|
||||
else:
|
||||
existing_api_providers = []
|
||||
|
||||
if builtin_provider_ids:
|
||||
# get the non-hardcoded builtin providers
|
||||
non_hardcoded_builtin_providers = [
|
||||
provider_id for provider_id in builtin_provider_ids if not provider_id.is_hardcoded
|
||||
]
|
||||
if non_hardcoded_builtin_providers:
|
||||
existence = list(PluginService.check_tools_existence(self.tenant_id, non_hardcoded_builtin_providers))
|
||||
else:
|
||||
existence = []
|
||||
# add the hardcoded builtin providers
|
||||
existence.extend([True] * (len(builtin_provider_ids) - len(non_hardcoded_builtin_providers)))
|
||||
builtin_provider_ids = non_hardcoded_builtin_providers + [
|
||||
provider_id for provider_id in builtin_provider_ids if provider_id.is_hardcoded
|
||||
]
|
||||
else:
|
||||
existence = []
|
||||
|
||||
existing_builtin_providers = {
|
||||
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
|
||||
}
|
||||
|
||||
deleted_tools = []
|
||||
current_api_provider_ids = [str(api_provider.id) for api_provider in api_providers]
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
provider_type = tool.get("provider_type", "")
|
||||
provider_id = tool.get("provider_id", "")
|
||||
if provider_type == "api" and provider_id not in current_api_provider_ids:
|
||||
deleted_tools.append(tool["tool_name"])
|
||||
|
||||
if provider_type == ToolProviderType.API.value:
|
||||
if provider_id not in existing_api_providers:
|
||||
deleted_tools.append(
|
||||
{
|
||||
"type": ToolProviderType.API.value,
|
||||
"tool_name": tool["tool_name"],
|
||||
"provider_id": provider_id,
|
||||
}
|
||||
)
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||
generic_provider_id = GenericProviderID(provider_id)
|
||||
|
||||
if not existing_builtin_providers[generic_provider_id.provider_name]:
|
||||
deleted_tools.append(
|
||||
{
|
||||
"type": ToolProviderType.BUILT_IN.value,
|
||||
"tool_name": tool["tool_name"],
|
||||
"provider_id": provider_id, # use the original one
|
||||
}
|
||||
)
|
||||
|
||||
return deleted_tools
|
||||
|
||||
|
||||
@ -7,7 +7,13 @@ from core.helper import marketplace
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.helper.marketplace import download_plugin_pkg
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity, PluginInstallation, PluginInstallationSource
|
||||
from core.plugin.entities.plugin import (
|
||||
GenericProviderID,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
PluginInstallation,
|
||||
PluginInstallationSource,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
|
||||
from core.plugin.manager.asset import PluginAssetManager
|
||||
from core.plugin.manager.debugging import PluginDebuggingManager
|
||||
@ -279,3 +285,11 @@ class PluginService:
|
||||
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
||||
manager = PluginInstallationManager()
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
|
||||
@staticmethod
|
||||
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||
"""
|
||||
Check if the tools exist
|
||||
"""
|
||||
manager = PluginInstallationManager()
|
||||
return manager.check_tools_existence(tenant_id, provider_ids)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user