refactor: update builtin tool provider methods to use session management (#11938)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-21 21:21:09 +08:00 committed by GitHub
parent 8f73670925
commit 606aadb891
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 27 deletions

View File

@ -3,12 +3,14 @@ import io
from flask import send_file from flask import send_file
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value from libs.helper import alphanumeric, uuid_value
from libs.login import login_required from libs.login import login_required
from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.api_tools_manage_service import ApiToolManageService
@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource):
args = parser.parse_args() args = parser.parse_args()
return BuiltinToolManageService.update_builtin_tool_provider( with Session(db.engine) as session:
user_id, result = BuiltinToolManageService.update_builtin_tool_provider(
tenant_id, session=session,
provider, user_id=user_id,
args["credentials"], tenant_id=tenant_id,
) provider_name=provider,
credentials=args["credentials"],
)
session.commit()
return result
class ToolBuiltinProviderGetCredentialsApi(Resource): class ToolBuiltinProviderGetCredentialsApi(Resource):
@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
user_id = current_user.id
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
return BuiltinToolManageService.get_builtin_tool_provider_credentials( return BuiltinToolManageService.get_builtin_tool_provider_credentials(
user_id, tenant_id=tenant_id,
tenant_id, provider_name=provider,
provider,
) )

View File

@ -2,6 +2,9 @@ import json
import logging import logging
from pathlib import Path from pathlib import Path
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from core.helper.position_helper import is_filtered from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -32,7 +35,7 @@ class BuiltinToolManageService:
tenant_id=tenant_id, provider_controller=provider_controller tenant_id=tenant_id, provider_controller=provider_controller
) )
# check if user has added the provider # check if user has added the provider
builtin_provider: BuiltinToolProvider = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
@ -71,19 +74,18 @@ class BuiltinToolManageService:
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
@staticmethod @staticmethod
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
):
""" """
update builtin tool provider update builtin tool provider
""" """
# get if the provider exists # get if the provider exists
provider: BuiltinToolProvider = ( stmt = select(BuiltinToolProvider).where(
db.session.query(BuiltinToolProvider) BuiltinToolProvider.tenant_id == tenant_id,
.filter( BuiltinToolProvider.provider == provider_name,
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
) )
provider = session.scalar(stmt)
try: try:
# get provider # get provider
@ -115,13 +117,10 @@ class BuiltinToolManageService:
encrypted_credentials=json.dumps(credentials), encrypted_credentials=json.dumps(credentials),
) )
db.session.add(provider) session.add(provider)
db.session.commit()
else: else:
provider.encrypted_credentials = json.dumps(credentials) provider.encrypted_credentials = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache # delete cache
tool_configuration.delete_tool_credentials_cache() tool_configuration.delete_tool_credentials_cache()
@ -129,15 +128,15 @@ class BuiltinToolManageService:
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
""" """
get builtin tool provider credentials get builtin tool provider credentials
""" """
provider: BuiltinToolProvider = ( provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider, BuiltinToolProvider.provider == provider_name,
) )
.first() .first()
) )
@ -156,7 +155,7 @@ class BuiltinToolManageService:
""" """
delete tool provider delete tool provider
""" """
provider: BuiltinToolProvider = ( provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,