mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 19:19:00 +08:00
refactor: update builtin tool provider methods to use session management (#11938)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
8f73670925
commit
606aadb891
@ -3,12 +3,14 @@ import io
|
|||||||
from flask import send_file
|
from flask import send_file
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from extensions.ext_database import db
|
||||||
from libs.helper import alphanumeric, uuid_value
|
from libs.helper import alphanumeric, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||||
@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return BuiltinToolManageService.update_builtin_tool_provider(
|
with Session(db.engine) as session:
|
||||||
user_id,
|
result = BuiltinToolManageService.update_builtin_tool_provider(
|
||||||
tenant_id,
|
session=session,
|
||||||
provider,
|
user_id=user_id,
|
||||||
args["credentials"],
|
tenant_id=tenant_id,
|
||||||
)
|
provider_name=provider,
|
||||||
|
credentials=args["credentials"],
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ToolBuiltinProviderGetCredentialsApi(Resource):
|
class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||||
@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
user_id = current_user.id
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||||
user_id,
|
tenant_id=tenant_id,
|
||||||
tenant_id,
|
provider_name=provider,
|
||||||
provider,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,6 +2,9 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
@ -32,7 +35,7 @@ class BuiltinToolManageService:
|
|||||||
tenant_id=tenant_id, provider_controller=provider_controller
|
tenant_id=tenant_id, provider_controller=provider_controller
|
||||||
)
|
)
|
||||||
# check if user has added the provider
|
# check if user has added the provider
|
||||||
builtin_provider: BuiltinToolProvider = (
|
builtin_provider = (
|
||||||
db.session.query(BuiltinToolProvider)
|
db.session.query(BuiltinToolProvider)
|
||||||
.filter(
|
.filter(
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
@ -71,19 +74,18 @@ class BuiltinToolManageService:
|
|||||||
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
|
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
|
def update_builtin_tool_provider(
|
||||||
|
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
update builtin tool provider
|
update builtin tool provider
|
||||||
"""
|
"""
|
||||||
# get if the provider exists
|
# get if the provider exists
|
||||||
provider: BuiltinToolProvider = (
|
stmt = select(BuiltinToolProvider).where(
|
||||||
db.session.query(BuiltinToolProvider)
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
.filter(
|
BuiltinToolProvider.provider == provider_name,
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
|
||||||
BuiltinToolProvider.provider == provider_name,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
provider = session.scalar(stmt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# get provider
|
# get provider
|
||||||
@ -115,13 +117,10 @@ class BuiltinToolManageService:
|
|||||||
encrypted_credentials=json.dumps(credentials),
|
encrypted_credentials=json.dumps(credentials),
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(provider)
|
session.add(provider)
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
provider.encrypted_credentials = json.dumps(credentials)
|
provider.encrypted_credentials = json.dumps(credentials)
|
||||||
db.session.add(provider)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# delete cache
|
# delete cache
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
tool_configuration.delete_tool_credentials_cache()
|
||||||
@ -129,15 +128,15 @@ class BuiltinToolManageService:
|
|||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
|
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
|
||||||
"""
|
"""
|
||||||
get builtin tool provider credentials
|
get builtin tool provider credentials
|
||||||
"""
|
"""
|
||||||
provider: BuiltinToolProvider = (
|
provider = (
|
||||||
db.session.query(BuiltinToolProvider)
|
db.session.query(BuiltinToolProvider)
|
||||||
.filter(
|
.filter(
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
BuiltinToolProvider.provider == provider,
|
BuiltinToolProvider.provider == provider_name,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@ -156,7 +155,7 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
delete tool provider
|
delete tool provider
|
||||||
"""
|
"""
|
||||||
provider: BuiltinToolProvider = (
|
provider = (
|
||||||
db.session.query(BuiltinToolProvider)
|
db.session.query(BuiltinToolProvider)
|
||||||
.filter(
|
.filter(
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user