mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-20 12:29:17 +08:00
fix: deleted_tools
This commit is contained in:
parent
ba3659a792
commit
c2ce8e638e
@ -139,6 +139,7 @@ class GenericProviderID:
|
|||||||
organization: str
|
organization: str
|
||||||
plugin_name: str
|
plugin_name: str
|
||||||
provider_name: str
|
provider_name: str
|
||||||
|
is_hardcoded: bool
|
||||||
|
|
||||||
def to_string(self) -> str:
|
def to_string(self) -> str:
|
||||||
return str(self)
|
return str(self)
|
||||||
@ -146,7 +147,7 @@ class GenericProviderID:
|
|||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
|
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
|
||||||
|
|
||||||
def __init__(self, value: str) -> None:
|
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||||
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
|
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
|
||||||
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
|
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
|
||||||
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
|
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
|
||||||
@ -156,6 +157,7 @@ class GenericProviderID:
|
|||||||
raise ValueError("Invalid plugin id")
|
raise ValueError("Invalid plugin id")
|
||||||
|
|
||||||
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
||||||
|
self.is_hardcoded = is_hardcoded
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def plugin_id(self) -> str:
|
def plugin_id(self) -> str:
|
||||||
|
@ -4,6 +4,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from core.plugin.entities.bundle import PluginBundleDependency
|
from core.plugin.entities.bundle import PluginBundleDependency
|
||||||
from core.plugin.entities.plugin import (
|
from core.plugin.entities.plugin import (
|
||||||
|
GenericProviderID,
|
||||||
PluginDeclaration,
|
PluginDeclaration,
|
||||||
PluginEntity,
|
PluginEntity,
|
||||||
PluginInstallation,
|
PluginInstallation,
|
||||||
@ -224,3 +225,23 @@ class PluginInstallationManager(BasePluginManager):
|
|||||||
},
|
},
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_tools_existence(self, tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||||
|
"""
|
||||||
|
Check if the tools exist
|
||||||
|
"""
|
||||||
|
return self._request_with_plugin_daemon_response(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/management/tools/check_existence",
|
||||||
|
list[bool],
|
||||||
|
data={
|
||||||
|
"provider_ids": [
|
||||||
|
{
|
||||||
|
"plugin_id": provider_id.plugin_id,
|
||||||
|
"provider_name": provider_id.provider_name,
|
||||||
|
}
|
||||||
|
for provider_id in provider_ids
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
@ -149,6 +149,12 @@ site_fields = {
|
|||||||
"updated_at": TimestampField,
|
"updated_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deleted_tool_fields = {
|
||||||
|
"type": fields.String,
|
||||||
|
"tool_name": fields.String,
|
||||||
|
"provider_id": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
app_detail_fields_with_site = {
|
app_detail_fields_with_site = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"name": fields.String,
|
"name": fields.String,
|
||||||
@ -169,9 +175,10 @@ app_detail_fields_with_site = {
|
|||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"updated_by": fields.String,
|
"updated_by": fields.String,
|
||||||
"updated_at": TimestampField,
|
"updated_at": TimestampField,
|
||||||
"deleted_tools": fields.List(fields.String),
|
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
app_site_fields = {
|
app_site_fields = {
|
||||||
"app_id": fields.String,
|
"app_id": fields.String,
|
||||||
"access_token": fields.String(attribute="code"),
|
"access_token": fields.String(attribute="code"),
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
@ -6,6 +7,10 @@ from datetime import datetime
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from core.plugin.entities.plugin import GenericProviderID
|
||||||
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
@ -16,7 +21,7 @@ import sqlalchemy as sa
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import UserMixin
|
from flask_login import UserMixin
|
||||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
|
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||||
@ -30,6 +35,8 @@ from models.enums import CreatedByRole
|
|||||||
from .account import Account, Tenant
|
from .account import Account, Tenant
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DifySetup(Base):
|
class DifySetup(Base):
|
||||||
__tablename__ = "dify_setups"
|
__tablename__ = "dify_setups"
|
||||||
@ -162,47 +169,114 @@ class App(Base):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def deleted_tools(self) -> list:
|
def deleted_tools(self) -> list:
|
||||||
|
from core.tools.tool_manager import ToolManager
|
||||||
|
|
||||||
# get agent mode tools
|
# get agent mode tools
|
||||||
app_model_config = self.app_model_config
|
app_model_config = self.app_model_config
|
||||||
if not app_model_config:
|
if not app_model_config:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if not app_model_config.agent_mode:
|
if not app_model_config.agent_mode:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
agent_mode = app_model_config.agent_mode_dict
|
agent_mode = app_model_config.agent_mode_dict
|
||||||
tools = agent_mode.get("tools", [])
|
tools = agent_mode.get("tools", [])
|
||||||
|
|
||||||
provider_ids = []
|
api_provider_ids: list[str] = []
|
||||||
|
builtin_provider_ids: list[GenericProviderID] = []
|
||||||
|
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
keys = list(tool.keys())
|
keys = list(tool.keys())
|
||||||
if len(keys) >= 4:
|
if len(keys) >= 4:
|
||||||
provider_type = tool.get("provider_type", "")
|
provider_type = tool.get("provider_type", "")
|
||||||
provider_id = tool.get("provider_id", "")
|
provider_id = tool.get("provider_id", "")
|
||||||
if provider_type == "api":
|
if provider_type == ToolProviderType.API.value:
|
||||||
# check if provider id is a uuid string, if not, skip
|
|
||||||
try:
|
try:
|
||||||
uuid.UUID(provider_id)
|
uuid.UUID(provider_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
provider_ids.append(provider_id)
|
api_provider_ids.append(provider_id)
|
||||||
|
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||||
|
try:
|
||||||
|
# check if it's hardcoded
|
||||||
|
try:
|
||||||
|
ToolManager.get_hardcoded_provider(provider_id)
|
||||||
|
is_hardcoded = True
|
||||||
|
except Exception:
|
||||||
|
is_hardcoded = False
|
||||||
|
|
||||||
if not provider_ids:
|
provider_id = GenericProviderID(provider_id, is_hardcoded)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"Invalid builtin provider id: {provider_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
builtin_provider_ids.append(provider_id)
|
||||||
|
|
||||||
|
if not api_provider_ids and not builtin_provider_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
api_providers = db.session.execute(
|
with Session(db.engine) as session:
|
||||||
text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)}
|
if api_provider_ids:
|
||||||
|
existing_api_providers = [
|
||||||
|
api_provider.id
|
||||||
|
for api_provider in session.execute(
|
||||||
|
text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"),
|
||||||
|
{"provider_ids": tuple(api_provider_ids)},
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
existing_api_providers = []
|
||||||
|
|
||||||
|
if builtin_provider_ids:
|
||||||
|
# get the non-hardcoded builtin providers
|
||||||
|
non_hardcoded_builtin_providers = [
|
||||||
|
provider_id for provider_id in builtin_provider_ids if not provider_id.is_hardcoded
|
||||||
|
]
|
||||||
|
if non_hardcoded_builtin_providers:
|
||||||
|
existence = list(PluginService.check_tools_existence(self.tenant_id, non_hardcoded_builtin_providers))
|
||||||
|
else:
|
||||||
|
existence = []
|
||||||
|
# add the hardcoded builtin providers
|
||||||
|
existence.extend([True] * (len(builtin_provider_ids) - len(non_hardcoded_builtin_providers)))
|
||||||
|
builtin_provider_ids = non_hardcoded_builtin_providers + [
|
||||||
|
provider_id for provider_id in builtin_provider_ids if provider_id.is_hardcoded
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
existence = []
|
||||||
|
|
||||||
|
existing_builtin_providers = {
|
||||||
|
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
|
||||||
|
}
|
||||||
|
|
||||||
deleted_tools = []
|
deleted_tools = []
|
||||||
current_api_provider_ids = [str(api_provider.id) for api_provider in api_providers]
|
|
||||||
|
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
keys = list(tool.keys())
|
keys = list(tool.keys())
|
||||||
if len(keys) >= 4:
|
if len(keys) >= 4:
|
||||||
provider_type = tool.get("provider_type", "")
|
provider_type = tool.get("provider_type", "")
|
||||||
provider_id = tool.get("provider_id", "")
|
provider_id = tool.get("provider_id", "")
|
||||||
if provider_type == "api" and provider_id not in current_api_provider_ids:
|
|
||||||
deleted_tools.append(tool["tool_name"])
|
if provider_type == ToolProviderType.API.value:
|
||||||
|
if provider_id not in existing_api_providers:
|
||||||
|
deleted_tools.append(
|
||||||
|
{
|
||||||
|
"type": ToolProviderType.API.value,
|
||||||
|
"tool_name": tool["tool_name"],
|
||||||
|
"provider_id": provider_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||||
|
generic_provider_id = GenericProviderID(provider_id)
|
||||||
|
|
||||||
|
if not existing_builtin_providers[generic_provider_id.provider_name]:
|
||||||
|
deleted_tools.append(
|
||||||
|
{
|
||||||
|
"type": ToolProviderType.BUILT_IN.value,
|
||||||
|
"tool_name": tool["tool_name"],
|
||||||
|
"provider_id": provider_id, # use the original one
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return deleted_tools
|
return deleted_tools
|
||||||
|
|
||||||
|
@ -7,7 +7,13 @@ from core.helper import marketplace
|
|||||||
from core.helper.download import download_with_size_limit
|
from core.helper.download import download_with_size_limit
|
||||||
from core.helper.marketplace import download_plugin_pkg
|
from core.helper.marketplace import download_plugin_pkg
|
||||||
from core.plugin.entities.bundle import PluginBundleDependency
|
from core.plugin.entities.bundle import PluginBundleDependency
|
||||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity, PluginInstallation, PluginInstallationSource
|
from core.plugin.entities.plugin import (
|
||||||
|
GenericProviderID,
|
||||||
|
PluginDeclaration,
|
||||||
|
PluginEntity,
|
||||||
|
PluginInstallation,
|
||||||
|
PluginInstallationSource,
|
||||||
|
)
|
||||||
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
|
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
|
||||||
from core.plugin.manager.asset import PluginAssetManager
|
from core.plugin.manager.asset import PluginAssetManager
|
||||||
from core.plugin.manager.debugging import PluginDebuggingManager
|
from core.plugin.manager.debugging import PluginDebuggingManager
|
||||||
@ -279,3 +285,11 @@ class PluginService:
|
|||||||
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
||||||
manager = PluginInstallationManager()
|
manager = PluginInstallationManager()
|
||||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||||
|
"""
|
||||||
|
Check if the tools exist
|
||||||
|
"""
|
||||||
|
manager = PluginInstallationManager()
|
||||||
|
return manager.check_tools_existence(tenant_id, provider_ids)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user