From 36686d7425e24828d6cdb4f7eb020801ad343b39 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:16:47 +0800 Subject: [PATCH] fix: test custom tool already exists without decrypting credentials (#2668) --- .../console/workspace/tool_providers.py | 2 + api/core/tools/tool/api_tool.py | 3 ++ api/services/tools_manage_service.py | 47 ++++++++++++++----- .../edit-custom-collection-modal/test-api.tsx | 1 + 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c2c5286d51..817c75765a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -259,6 +259,7 @@ class ToolApiProviderPreviousTestApi(Resource): parser = reqparse.RequestParser() parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json') + parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json') parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json') parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') @@ -268,6 +269,7 @@ class ToolApiProviderPreviousTestApi(Resource): return ToolManageService.test_api_tool_preview( current_user.current_tenant_id, + args['provider_name'] if args['provider_name'] else '', args['tool_name'], args['credentials'], args['parameters'], diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index d5a4bf20bd..31519734ed 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -1,6 +1,7 @@ import json from json import dumps from typing import Any, Union +from urllib.parse import urlencode import httpx import requests @@ -203,6 +204,8 @@ class ApiTool(Tool): if 'Content-Type' in headers: if headers['Content-Type'] == 'application/json': body = dumps(body) + elif headers['Content-Type'] == 'application/x-www-form-urlencoded': + body = urlencode(body) else: body = body diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index 7e305c3f7b..0e3d481640 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -498,12 +498,16 @@ class ToolManageService: @staticmethod def test_api_tool_preview( - tenant_id: str, tool_name: str, credentials: dict, parameters: dict, schema_type: str, schema: str + tenant_id: str, + provider_name: str, + tool_name: str, + credentials: dict, + parameters: dict, + schema_type: str, + schema: str ): """ test api tool before adding api tool provider - - 1. parse schema into tool bundle """ if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f'invalid schema type {schema_type}') @@ -518,15 +522,21 @@ class ToolManageService: if tool_bundle is None: raise ValueError(f'invalid tool name {tool_name}') - # create a fake db provider - db_provider = ApiToolProvider( - tenant_id='', user_id='', name='', icon='', - schema=schema, - description='', - schema_type_str=ApiProviderSchemaType.OPENAPI.value, - tools_str=serialize_base_model_array(tool_bundles), - credentials_str=json.dumps(credentials), - ) + db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ).first() + + if not db_provider: + # create a fake db provider + db_provider = ApiToolProvider( + tenant_id='', user_id='', name='', icon='', + schema=schema, + description='', + schema_type_str=ApiProviderSchemaType.OPENAPI.value, + tools_str=serialize_base_model_array(tool_bundles), + credentials_str=json.dumps(credentials), + ) if 'auth_type' not in credentials: raise ValueError('auth_type is required') @@ -539,6 +549,19 @@ class ToolManageService: # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) + # decrypt credentials + if db_provider.id: + tool_configuration = ToolConfiguration( + tenant_id=tenant_id, + provider_controller=provider_controller + ) + decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) + # check if the credential has changed, save the original credential + masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = decrypted_credentials[name] + try: provider_controller.validate_credentials_format(credentials) # get tool diff --git a/web/app/components/tools/edit-custom-collection-modal/test-api.tsx b/web/app/components/tools/edit-custom-collection-modal/test-api.tsx index 620ceb8f48..791ac5edbf 100644 --- a/web/app/components/tools/edit-custom-collection-modal/test-api.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/test-api.tsx @@ -42,6 +42,7 @@ const TestApi: FC = ({ delete credentials.api_key_value } const data = { + provider_name: customCollection.provider, tool_name: toolName, credentials, schema_type: customCollection.schema_type,