refactor(tool-engine): Improve tool provider handling with session ma… (#14291)

This commit is contained in:
Yeuoly 2025-02-25 12:33:29 +08:00 committed by GitHub
parent 490b6d092e
commit 9fb78ce827
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
from yarl import URL from yarl import URL
import contexts import contexts
from core.plugin.entities.plugin import GenericProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.manager.tool import PluginToolManager from core.plugin.manager.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
@ -188,7 +188,7 @@ class ToolManager:
) )
if isinstance(provider_controller, PluginToolProviderController): if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = GenericProviderID(provider_id) provider_id_entity = ToolProviderID(provider_id)
# get credentials # get credentials
builtin_provider: BuiltinToolProvider | None = ( builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
@ -572,95 +572,96 @@ class ToolManager:
else: else:
filters.append(typ) filters.append(typ)
if "builtin" in filters: with db.session.no_autoflush:
# get builtin providers if "builtin" in filters:
builtin_providers = cls.list_builtin_providers(tenant_id) # get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id)
# get db builtin providers # get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = ( db_builtin_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
)
# rewrite db_builtin_providers
for db_provider in db_builtin_providers:
tool_provider_id = GenericProviderID(db_provider.provider)
db_provider.provider = tool_provider_id.to_string()
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# append builtin providers
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
data=provider,
name_func=lambda x: x.identity.name,
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.entity.identity.name),
decrypt_credentials=False,
) )
if isinstance(provider, PluginToolProviderController): # rewrite db_builtin_providers
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider for db_provider in db_builtin_providers:
else: tool_provider_id = str(ToolProviderID(db_provider.provider))
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider db_provider.provider = tool_provider_id
# get db api providers def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
if "api" in filters: # append builtin providers
db_api_providers: list[ApiToolProvider] = ( for provider in builtin_providers:
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() # handle include, exclude
) if is_filtered(
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
data=provider,
name_func=lambda x: x.identity.name,
):
continue
api_provider_controllers: list[dict[str, Any]] = [ user_provider = ToolTransformService.builtin_provider_to_user_provider(
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} provider_controller=provider,
for provider in db_api_providers db_provider=find_db_builtin_provider(provider.entity.identity.name),
] decrypt_credentials=False,
# get labels
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
for api_provider_controller in api_provider_controllers:
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=api_provider_controller["controller"],
db_provider=api_provider_controller["provider"],
decrypt_credentials=False,
labels=labels.get(api_provider_controller["controller"].provider_id, []),
)
result_providers[f"api_provider.{user_provider.name}"] = user_provider
if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for provider in workflow_providers:
try:
workflow_provider_controllers.append(
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
) )
except Exception:
# app has been deleted
pass
labels = ToolLabelManager.get_tools_labels( if isinstance(provider, PluginToolProviderController):
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers] result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
) else:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
for provider_controller in workflow_provider_controllers: # get db api providers
user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=provider_controller, if "api" in filters:
labels=labels.get(provider_controller.provider_id, []), db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
) )
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
for provider in db_api_providers
]
# get labels
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
for api_provider_controller in api_provider_controllers:
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=api_provider_controller["controller"],
db_provider=api_provider_controller["provider"],
decrypt_credentials=False,
labels=labels.get(api_provider_controller["controller"].provider_id, []),
)
result_providers[f"api_provider.{user_provider.name}"] = user_provider
if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for provider in workflow_providers:
try:
workflow_provider_controllers.append(
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
)
except Exception:
# app has been deleted
pass
labels = ToolLabelManager.get_tools_labels(
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
)
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=provider_controller,
labels=labels.get(provider_controller.provider_id, []),
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
return BuiltinToolProviderSort.sort(list(result_providers.values())) return BuiltinToolProviderSort.sort(list(result_providers.values()))