From 741c548f3c0d348d9a3e252cffc7859a9fb4dd32 Mon Sep 17 00:00:00 2001 From: Joe <79627742+ZhouhaoJiang@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:47:02 +0800 Subject: [PATCH] feat: web sso (#7135) --- api/controllers/web/passport.py | 10 ++++---- api/controllers/web/wraps.py | 23 ++++++++++++------- api/services/enterprise/enterprise_service.py | 4 ++++ api/services/feature_service.py | 3 ++- 4 files changed, 27 insertions(+), 13 deletions(-) diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index ccc8683a79..cce8943ead 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -9,21 +9,23 @@ from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService class PassportResource(Resource): """Base resource for passport.""" def get(self): - system_features = FeatureService.get_system_features() - if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() - app_code = request.headers.get('X-App-Code') if app_code is None: raise Unauthorized('X-App-Code header is missing.') + if system_features.sso_enforced_for_web: + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() + # get site from db and check if it is normal site = db.session.query(Site).filter( Site.code == app_code, diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index f5ab49d7e1..ae363672c6 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -8,6 +8,7 @@ from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -26,7 +27,7 @@ def validate_jwt_token(view=None): def decode_jwt_token(): system_features = FeatureService.get_system_features() - + app_code = request.headers.get('X-App-Code') try: auth_header = request.headers.get('Authorization') if auth_header is None: @@ -54,25 +55,31 @@ def decode_jwt_token(): if not end_user: raise NotFound() - _validate_web_sso_token(decoded, system_features) + _validate_web_sso_token(decoded, system_features, app_code) return app_model, end_user except Unauthorized as e: if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() raise Unauthorized(e.description) -def _validate_web_sso_token(decoded, system_features): +def _validate_web_sso_token(decoded, system_features, app_code): + app_web_sso_enabled = False + # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login if system_features.sso_enforced_for_web: - source = decoded.get('token_source') - if not source or source != 'sso': - raise WebSSOAuthRequiredError() + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False) + if app_web_sso_enabled: + source = decoded.get('token_source') + if not source or source != 'sso': + raise WebSSOAuthRequiredError() # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login - if not system_features.sso_enforced_for_web: + if not system_features.sso_enforced_for_web or not app_web_sso_enabled: source = decoded.get('token_source') if source and source == 'sso': raise Unauthorized('sso token expired.') diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 115d0d5523..6fd72c2321 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -6,3 +6,7 @@ class EnterpriseService: @classmethod def get_info(cls): return EnterpriseRequest.send_request('GET', '/info') + + @classmethod + def get_app_web_sso_enabled(cls, app_code): + return EnterpriseRequest.send_request('GET', f'/app-sso-setting?appCode={app_code}') diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 83e675a9d2..57400e5951 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -41,7 +41,7 @@ class SystemFeatureModel(BaseModel): sso_enforced_for_signin_protocol: str = '' sso_enforced_for_web: bool = False sso_enforced_for_web_protocol: str = '' - + enable_web_sso_switch_component: bool = False class FeatureService: @@ -61,6 +61,7 @@ class FeatureService: system_features = SystemFeatureModel() if dify_config.ENTERPRISE_ENABLED: + system_features.enable_web_sso_switch_component = True cls._fulfill_params_from_enterprise(system_features) return system_features