diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 2bcca52fcb..56a442a223 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -85,4 +85,12 @@ class ToolConfiguration(BaseModel): pass cache.set(credentials) - return credentials \ No newline at end of file + return credentials + + def delete_tool_credentials_cache(self): + cache = ToolProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', + cache_type=ToolProviderCredentialsCacheType.PROVIDER + ) + cache.delete() diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index d4975e6cbb..9c8c0dbfac 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -355,10 +355,12 @@ class ToolManageService: else: provider.encrypted_credentials = json.dumps(credentials) - db.session.add(provider) db.session.commit() + # delete cache + tool_configuration.delete_tool_credentials_cache() + return { 'result': 'success' } @staticmethod @@ -393,7 +395,6 @@ class ToolManageService: provider.description = extra_info.get('description', '') provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value provider.tools_str = serialize_base_model_array(tool_bundles) - provider.credentials_str = json.dumps(credentials) provider.privacy_policy = privacy_policy if 'auth_type' not in credentials: @@ -403,33 +404,54 @@ class ToolManageService: auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) # create provider entity - provider_entity = ApiBasedToolProviderController.from_db(provider, auth_type) + provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type) # load tools into provider entity - provider_entity.load_bundled_tools(tool_bundles) + provider_controller.load_bundled_tools(tool_bundles) + + # get original credentials if exists + tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + + original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) + masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] + + credentials = tool_configuration.encrypt_tool_credentials(credentials) + provider.credentials_str = json.dumps(credentials) db.session.add(provider) db.session.commit() + # delete cache + tool_configuration.delete_tool_credentials_cache() + return { 'result': 'success' } @staticmethod def delete_builtin_tool_provider( - user_id: str, tenant_id: str, provider: str + user_id: str, tenant_id: str, provider_name: str ): """ delete tool provider """ provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, + BuiltinToolProvider.provider == provider_name, ).first() if provider is None: - raise ValueError(f'you have not added provider {provider}') + raise ValueError(f'you have not added provider {provider_name}') db.session.delete(provider) db.session.commit() + # delete cache + provider_controller = ToolManager.get_builtin_provider(provider_name) + tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration.delete_tool_credentials_cache() + return { 'result': 'success' } @staticmethod @@ -437,7 +459,7 @@ class ToolManageService: provider: str ): """ - get tool provider icon and it's minetype + get tool provider icon and it's mimetype """ icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) with open(icon_path, 'rb') as f: @@ -447,18 +469,18 @@ class ToolManageService: @staticmethod def delete_api_tool_provider( - user_id: str, tenant_id: str, provider: str + user_id: str, tenant_id: str, provider_name: str ): """ delete tool provider """ provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, + ApiToolProvider.name == provider_name, ).first() if provider is None: - raise ValueError(f'you have not added provider {provider}') + raise ValueError(f'you have not added provider {provider_name}') db.session.delete(provider) db.session.commit()