diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2cd6dcda3b..9e62a54699 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -3,12 +3,14 @@ import io from flask import send_file from flask_login import current_user from flask_restful import Resource, reqparse +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required from services.tools.api_tools_manage_service import ApiToolManageService @@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource): args = parser.parse_args() - return BuiltinToolManageService.update_builtin_tool_provider( - user_id, - tenant_id, - provider, - args["credentials"], - ) + with Session(db.engine) as session: + result = BuiltinToolManageService.update_builtin_tool_provider( + session=session, + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=args["credentials"], + ) + session.commit() + return result class ToolBuiltinProviderGetCredentialsApi(Resource): @@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @login_required @account_initialization_required def get(self, provider): - user_id = current_user.id tenant_id = current_user.current_tenant_id return BuiltinToolManageService.get_builtin_tool_provider_credentials( - user_id, - tenant_id, - provider, + tenant_id=tenant_id, + provider_name=provider, ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index e2e49d017e..fada881fde 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,6 +2,9 @@ import json import logging from pathlib import Path +from sqlalchemy import select +from sqlalchemy.orm import Session + from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder @@ -32,7 +35,7 @@ class BuiltinToolManageService: tenant_id=tenant_id, provider_controller=provider_controller ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = ( + builtin_provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -71,19 +74,18 @@ class BuiltinToolManageService: return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): + def update_builtin_tool_provider( + session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + ): """ update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ) - .first() + stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, ) + provider = session.scalar(stmt) try: # get provider @@ -115,13 +117,10 @@ class BuiltinToolManageService: encrypted_credentials=json.dumps(credentials), ) - db.session.add(provider) - db.session.commit() + session.add(provider) else: provider.encrypted_credentials = json.dumps(credentials) - db.session.add(provider) - db.session.commit() # delete cache tool_configuration.delete_tool_credentials_cache() @@ -129,15 +128,15 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): + def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): """ get builtin tool provider credentials """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, + BuiltinToolProvider.provider == provider_name, ) .first() ) @@ -156,7 +155,7 @@ class BuiltinToolManageService: """ delete tool provider """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id,