refactor(tool-engine): Improve tool provider handling with session ma… (#14291)

This commit is contained in:
Yeuoly 2025-02-25 12:33:29 +08:00 committed by GitHub
parent 490b6d092e
commit 9fb78ce827
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
from yarl import URL from yarl import URL
import contexts import contexts
from core.plugin.entities.plugin import GenericProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.manager.tool import PluginToolManager from core.plugin.manager.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
@ -188,7 +188,7 @@ class ToolManager:
) )
if isinstance(provider_controller, PluginToolProviderController): if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = GenericProviderID(provider_id) provider_id_entity = ToolProviderID(provider_id)
# get credentials # get credentials
builtin_provider: BuiltinToolProvider | None = ( builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
@ -572,6 +572,7 @@ class ToolManager:
else: else:
filters.append(typ) filters.append(typ)
with db.session.no_autoflush:
if "builtin" in filters: if "builtin" in filters:
# get builtin providers # get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id) builtin_providers = cls.list_builtin_providers(tenant_id)
@ -583,8 +584,8 @@ class ToolManager:
# rewrite db_builtin_providers # rewrite db_builtin_providers
for db_provider in db_builtin_providers: for db_provider in db_builtin_providers:
tool_provider_id = GenericProviderID(db_provider.provider) tool_provider_id = str(ToolProviderID(db_provider.provider))
db_provider.provider = tool_provider_id.to_string() db_provider.provider = tool_provider_id
def find_db_builtin_provider(provider): def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None) return next((x for x in db_builtin_providers if x.provider == provider), None)