From 5f87bdbe3ac2abdd981060312b625bd87babc963 Mon Sep 17 00:00:00 2001 From: GareArc Date: Fri, 11 Apr 2025 15:24:32 -0400 Subject: [PATCH] fix: add batch get access mode api --- api/controllers/console/app/app.py | 11 +- api/controllers/web/passport.py | 2 +- api/controllers/web/wraps.py | 9 +- api/services/app_service.py | 13 +-- api/services/enterprise/enterprise_service.py | 105 +++++++++++------- api/services/feature_service.py | 4 +- api/services/webapp_auth_service.py | 10 +- 7 files changed, 86 insertions(+), 68 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index ce6da4af79..7ab594eb26 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -66,9 +66,14 @@ class AppListApi(Resource): return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} if FeatureService.get_system_features().webapp_auth.enabled: + app_ids = [str(app.id) for app in app_pagination.items] + res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids) + if len(res) != len(app_ids): + raise BadRequest("Invalid app id in webapp auth") + for app in app_pagination.items: - app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app.id)) - app.access_mode = app_setting.access_mode + if str(app.id) in res: + app.access_mode = res[str(app.id)].access_mode return marshal(app_pagination, app_pagination_fields), 200 @@ -115,7 +120,7 @@ class AppApi(Resource): app_model = app_service.get_app(app_model) if FeatureService.get_system_features().webapp_auth.enabled: - app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app_model.id)) + app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) app_model.access_mode = app_setting.access_mode return app_model diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 3c07b3e87d..8ab9b84574 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -23,7 +23,7 @@ class PassportResource(Resource): raise Unauthorized("X-App-Code header is missing.") if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.get_app_access_mode_by_code(app_code=app_code) + app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) if not app_settings or not app_settings.access_mode == "public": raise WebAppAuthRequiredError() diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 5a74296b82..8d35b8e4be 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -4,8 +4,7 @@ from flask import request from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import (WebAppAuthAccessDeniedError, - WebAppAuthRequiredError) +from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -62,7 +61,7 @@ def decode_jwt_token(): app_web_auth_enabled = False if system_features.webapp_auth.enabled: app_web_auth_enabled = ( - EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" + EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" ) _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) @@ -72,7 +71,7 @@ def decode_jwt_token(): except Unauthorized as e: if system_features.webapp_auth.enabled: app_web_auth_enabled = ( - EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" + EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" ) if app_web_auth_enabled: raise WebAppAuthRequiredError() @@ -103,7 +102,7 @@ def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, if not user_id: raise WebAppAuthRequiredError() - if not EnterpriseService.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): raise WebAppAuthAccessDeniedError() diff --git a/api/services/app_service.py b/api/services/app_service.py index 03393c00fa..9359bb2844 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -12,10 +12,8 @@ from core.agent.entities import AgentToolEntity from core.app.features.rate_limiting import RateLimit from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import (ModelPropertyKey, - ModelType) -from core.model_runtime.model_providers.__base.large_language_model import \ - LargeLanguageModel +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created @@ -26,8 +24,7 @@ from models.tools import ApiToolProvider from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.tag_service import TagService -from tasks.remove_app_and_related_data_task import \ - remove_app_and_related_data_task +from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task class AppService: @@ -159,7 +156,7 @@ class AppService: if FeatureService.get_system_features().webapp_auth.enabled: # update web app setting as private - EnterpriseService.update_app_access_mode(app.id, "private") + EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private") return app @@ -319,7 +316,7 @@ class AppService: # clean up web app settings if FeatureService.get_system_features().webapp_auth.enabled: - EnterpriseService.cleanup_webapp(app.id) + EnterpriseService.WebAppAuth.cleanup_webapp(app.id) # Trigger asynchronous deletion of app and related data remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 299764ffc4..a3e4d163c3 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,3 +1,5 @@ +import logging + from pydantic import BaseModel, Field from services.enterprise.base import EnterpriseRequest @@ -16,55 +18,72 @@ class EnterpriseService: def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") - @classmethod - def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool: - params = {"userId": user_id, "appCode": app_code} - data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) + class WebAppAuth: + @classmethod + def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool: + params = {"userId": user_id, "appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) - return data.get("result", False) + return data.get("result", False) - @classmethod - def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: - if not app_id: - raise ValueError("app_id must be provided.") - params = {"appId": app_id} - data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) - if not data: - raise ValueError("No data found.") - return WebAppSettings(**data) + @classmethod + def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: + if not app_id: + raise ValueError("app_id must be provided.") + params = {"appId": app_id} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) - @classmethod - def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings: - if not app_code: - raise ValueError("app_code must be provided.") - params = {"appCode": app_code} - data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) - if not data: - raise ValueError("No data found.") - return WebAppSettings(**data) + @classmethod + def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]: + if not app_ids: + raise ValueError("app_ids must be provided.") + params = {"appIds": ",".join(app_ids)} + data: dict[str, str] = EnterpriseRequest.send_request("GET", "/webapp/access-mode/batch/id", params=params) + if not data: + raise ValueError("No data found.") - @classmethod - def update_app_access_mode(cls, app_id: str, access_mode: str) -> bool: - if not app_id: - raise ValueError("app_id must be provided.") - if access_mode not in ["public", "private", "private_all"]: - raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") + logging.info(f"Batch get app access mode by id returns data: {data}") - data = { - "appId": app_id, - "accessMode": access_mode - } + if not isinstance(data, dict): + raise ValueError("Invalid data format.") - response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data) + for key, value in data.items(): + curr = WebAppSettings() + curr.access_mode = value + data[key] = curr - return response.get("result", False) + return data - @classmethod - def cleanup_webapp(cls, app_id: str): - if not app_id: - raise ValueError("app_id must be provided.") + @classmethod + def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings: + if not app_code: + raise ValueError("app_code must be provided.") + params = {"appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) - body = { - "appId": app_id - } - EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body) + @classmethod + def update_app_access_mode(cls, app_id: str, access_mode: str) -> bool: + if not app_id: + raise ValueError("app_id must be provided.") + if access_mode not in ["public", "private", "private_all"]: + raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") + + data = {"appId": app_id, "accessMode": access_mode} + + response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data) + + return response.get("result", False) + + @classmethod + def cleanup_webapp(cls, app_id: str): + if not app_id: + raise ValueError("app_id must be provided.") + + body = {"appId": app_id} + EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 7575d4101b..e62a94cc9d 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -203,9 +203,7 @@ class FeatureService: features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get( "allowEmailPasswordLogin", False ) - features.webapp_auth.sso_config.protocol = enterprise_info.get( - "SSOEnforcedForSigninProtocol", "" - ) + features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForSigninProtocol", "") if "License" in enterprise_info: license_info = enterprise_info["License"] diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 6a4a9c795e..506b7698e0 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -13,8 +13,7 @@ from libs.password import compare_password from models.account import Account, AccountStatus from models.model import App, EndUser, Site from services.enterprise.enterprise_service import EnterpriseService -from services.errors.account import (AccountLoginError, AccountNotFoundError, - AccountPasswordError) +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.feature_service import FeatureService from tasks.mail_email_code_login import send_email_code_login_mail_task @@ -110,10 +109,11 @@ class WebAppAuthService: """Check if the user is allowed to access the app.""" system_features = FeatureService.get_system_features() if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.get_app_access_mode_by_code(app_code=app_code) + app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) - if app_settings.access_mode != "public" and not EnterpriseService.is_user_allowed_to_access_webapp( - account.id, app_code=app_code + if ( + app_settings.access_mode != "public" + and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code) ): raise WebAppAuthAccessDeniedError()