mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 02:49:02 +08:00
refactor(tool-engine): Improve tool provider handling with session ma… (#14291)
This commit is contained in:
parent
490b6d092e
commit
9fb78ce827
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
|
|||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from core.plugin.entities.plugin import GenericProviderID
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
from core.plugin.manager.tool import PluginToolManager
|
from core.plugin.manager.tool import PluginToolManager
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
@ -188,7 +188,7 @@ class ToolManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(provider_controller, PluginToolProviderController):
|
if isinstance(provider_controller, PluginToolProviderController):
|
||||||
provider_id_entity = GenericProviderID(provider_id)
|
provider_id_entity = ToolProviderID(provider_id)
|
||||||
# get credentials
|
# get credentials
|
||||||
builtin_provider: BuiltinToolProvider | None = (
|
builtin_provider: BuiltinToolProvider | None = (
|
||||||
db.session.query(BuiltinToolProvider)
|
db.session.query(BuiltinToolProvider)
|
||||||
@ -572,95 +572,96 @@ class ToolManager:
|
|||||||
else:
|
else:
|
||||||
filters.append(typ)
|
filters.append(typ)
|
||||||
|
|
||||||
if "builtin" in filters:
|
with db.session.no_autoflush:
|
||||||
# get builtin providers
|
if "builtin" in filters:
|
||||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
# get builtin providers
|
||||||
|
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||||
|
|
||||||
# get db builtin providers
|
# get db builtin providers
|
||||||
db_builtin_providers: list[BuiltinToolProvider] = (
|
db_builtin_providers: list[BuiltinToolProvider] = (
|
||||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||||
)
|
|
||||||
|
|
||||||
# rewrite db_builtin_providers
|
|
||||||
for db_provider in db_builtin_providers:
|
|
||||||
tool_provider_id = GenericProviderID(db_provider.provider)
|
|
||||||
db_provider.provider = tool_provider_id.to_string()
|
|
||||||
|
|
||||||
def find_db_builtin_provider(provider):
|
|
||||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
|
||||||
|
|
||||||
# append builtin providers
|
|
||||||
for provider in builtin_providers:
|
|
||||||
# handle include, exclude
|
|
||||||
if is_filtered(
|
|
||||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
|
||||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
|
||||||
data=provider,
|
|
||||||
name_func=lambda x: x.identity.name,
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
|
||||||
provider_controller=provider,
|
|
||||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
|
||||||
decrypt_credentials=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(provider, PluginToolProviderController):
|
# rewrite db_builtin_providers
|
||||||
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
for db_provider in db_builtin_providers:
|
||||||
else:
|
tool_provider_id = str(ToolProviderID(db_provider.provider))
|
||||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
db_provider.provider = tool_provider_id
|
||||||
|
|
||||||
# get db api providers
|
def find_db_builtin_provider(provider):
|
||||||
|
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||||
|
|
||||||
if "api" in filters:
|
# append builtin providers
|
||||||
db_api_providers: list[ApiToolProvider] = (
|
for provider in builtin_providers:
|
||||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
# handle include, exclude
|
||||||
)
|
if is_filtered(
|
||||||
|
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||||
|
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||||
|
data=provider,
|
||||||
|
name_func=lambda x: x.identity.name,
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
api_provider_controllers: list[dict[str, Any]] = [
|
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
provider_controller=provider,
|
||||||
for provider in db_api_providers
|
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||||
]
|
decrypt_credentials=False,
|
||||||
|
|
||||||
# get labels
|
|
||||||
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
|
||||||
|
|
||||||
for api_provider_controller in api_provider_controllers:
|
|
||||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
|
||||||
provider_controller=api_provider_controller["controller"],
|
|
||||||
db_provider=api_provider_controller["provider"],
|
|
||||||
decrypt_credentials=False,
|
|
||||||
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
|
||||||
)
|
|
||||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
|
||||||
|
|
||||||
if "workflow" in filters:
|
|
||||||
# get workflow providers
|
|
||||||
workflow_providers: list[WorkflowToolProvider] = (
|
|
||||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
|
||||||
for provider in workflow_providers:
|
|
||||||
try:
|
|
||||||
workflow_provider_controllers.append(
|
|
||||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
|
||||||
)
|
)
|
||||||
except Exception:
|
|
||||||
# app has been deleted
|
|
||||||
pass
|
|
||||||
|
|
||||||
labels = ToolLabelManager.get_tools_labels(
|
if isinstance(provider, PluginToolProviderController):
|
||||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
||||||
)
|
else:
|
||||||
|
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||||
|
|
||||||
for provider_controller in workflow_provider_controllers:
|
# get db api providers
|
||||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
|
||||||
provider_controller=provider_controller,
|
if "api" in filters:
|
||||||
labels=labels.get(provider_controller.provider_id, []),
|
db_api_providers: list[ApiToolProvider] = (
|
||||||
|
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||||
)
|
)
|
||||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
|
||||||
|
api_provider_controllers: list[dict[str, Any]] = [
|
||||||
|
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||||
|
for provider in db_api_providers
|
||||||
|
]
|
||||||
|
|
||||||
|
# get labels
|
||||||
|
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
||||||
|
|
||||||
|
for api_provider_controller in api_provider_controllers:
|
||||||
|
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||||
|
provider_controller=api_provider_controller["controller"],
|
||||||
|
db_provider=api_provider_controller["provider"],
|
||||||
|
decrypt_credentials=False,
|
||||||
|
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
||||||
|
)
|
||||||
|
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||||
|
|
||||||
|
if "workflow" in filters:
|
||||||
|
# get workflow providers
|
||||||
|
workflow_providers: list[WorkflowToolProvider] = (
|
||||||
|
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||||
|
for provider in workflow_providers:
|
||||||
|
try:
|
||||||
|
workflow_provider_controllers.append(
|
||||||
|
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# app has been deleted
|
||||||
|
pass
|
||||||
|
|
||||||
|
labels = ToolLabelManager.get_tools_labels(
|
||||||
|
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||||
|
)
|
||||||
|
|
||||||
|
for provider_controller in workflow_provider_controllers:
|
||||||
|
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||||
|
provider_controller=provider_controller,
|
||||||
|
labels=labels.get(provider_controller.provider_id, []),
|
||||||
|
)
|
||||||
|
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||||
|
|
||||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user