diff --git a/api/services/plugin/dependencies_analysis.py b/api/services/plugin/dependencies_analysis.py index 77da3d09ea..778f05a0cd 100644 --- a/api/services/plugin/dependencies_analysis.py +++ b/api/services/plugin/dependencies_analysis.py @@ -1,5 +1,5 @@ from core.helper import marketplace -from core.plugin.entities.plugin import GenericProviderID, PluginDependency, PluginInstallationSource +from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID from core.plugin.manager.plugin import PluginInstallationManager @@ -12,10 +12,7 @@ class DependenciesAnalysisService: Convert the tool id to the plugin_id """ try: - tool_provider_id = GenericProviderID(tool_id) - if tool_id in ["jina", "siliconflow"]: - tool_provider_id.plugin_name = tool_provider_id.plugin_name + "_tool" - return tool_provider_id.plugin_id + return ToolProviderID(tool_id).plugin_id except Exception as e: raise e @@ -27,11 +24,7 @@ class DependenciesAnalysisService: Convert the model provider id to the plugin_id """ try: - generic_provider_id = GenericProviderID(model_provider_id) - if model_provider_id == "google": - generic_provider_id.plugin_name = "gemini" - - return generic_provider_id.plugin_id + return ModelProviderID(model_provider_id).plugin_id except Exception as e: raise e diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 6b91183573..ec9e0aa8dc 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -14,9 +14,8 @@ from flask import Flask, current_app from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity -from core.entities import DEFAULT_PLUGIN_ID from core.helper import marketplace -from core.plugin.entities.plugin import PluginInstallationSource +from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus from core.plugin.manager.plugin import PluginInstallationManager from core.tools.entities.tool_entities import ToolProviderType @@ -203,13 +202,7 @@ class PluginMigration: result = [] for row in rs: provider_name = str(row[0]) - if provider_name and "/" not in provider_name: - if provider_name == "google": - provider_name = "gemini" - - result.append(DEFAULT_PLUGIN_ID + "/" + provider_name) - elif provider_name: - result.append(provider_name) + result.append(ModelProviderID(provider_name).plugin_id) return result @@ -222,30 +215,10 @@ class PluginMigration: rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() result = [] for row in rs: - if "/" not in row.provider: - result.append(DEFAULT_PLUGIN_ID + "/" + row.provider) - else: - result.append(row.provider) + result.append(ToolProviderID(row.provider).plugin_id) return result - @classmethod - def _handle_builtin_tool_provider(cls, provider_name: str) -> str: - """ - Handle builtin tool provider. - """ - if provider_name == "jina": - provider_name = "jina_tool" - elif provider_name == "siliconflow": - provider_name = "siliconflow_tool" - elif provider_name == "stepfun": - provider_name = "stepfun_tool" - - if "/" not in provider_name: - return DEFAULT_PLUGIN_ID + "/" + provider_name - else: - return provider_name - @classmethod def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]: """ @@ -266,8 +239,7 @@ class PluginMigration: provider_name = data.get("provider_name") provider_type = data.get("provider_type") if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value: - provider_name = cls._handle_builtin_tool_provider(provider_name) - result.append(provider_name) + result.append(ToolProviderID(provider_name).plugin_id) return result @@ -298,7 +270,7 @@ class PluginMigration: tool_entity.provider_type == ToolProviderType.BUILT_IN.value and tool_entity.provider_id not in excluded_providers ): - result.append(cls._handle_builtin_tool_provider(tool_entity.provider_id)) + result.append(ToolProviderID(tool_entity.provider_id).plugin_id) except Exception: logger.exception(f"Failed to process tool {tool}")