From 46d43e6758267e5b4337b30ead15bc73ac19f900 Mon Sep 17 00:00:00 2001 From: GareArc Date: Mon, 7 Apr 2025 17:03:26 -0400 Subject: [PATCH] feat: add web app auth --- api/controllers/web/error.py | 12 +- api/controllers/web/login.py | 118 ++++++++++++++++++ api/controllers/web/passport.py | 10 +- api/controllers/web/wraps.py | 58 +++++---- api/services/enterprise/enterprise_service.py | 28 ++++- api/services/feature_service.py | 14 +++ api/services/webapp_auth_service.py | 103 +++++++++++++++ 7 files changed, 311 insertions(+), 32 deletions(-) create mode 100644 api/controllers/web/login.py create mode 100644 api/services/webapp_auth_service.py diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 9fe5d08d54..4909694d26 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class WebSSOAuthRequiredError(BaseHTTPException): - error_code = "web_sso_auth_required" - description = "Web SSO authentication required." +class WebAppAuthRequiredError(BaseHTTPException): + error_code = "web_auth_required" + description = "Web app authentication required." + code = 401 + + +class WebAppAuthFailedError(BaseHTTPException): + error_code = "web_app_auth_failed" + description = "You do not have permission to access this web app." code = 401 diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py new file mode 100644 index 0000000000..235fcaf8cc --- /dev/null +++ b/api/controllers/web/login.py @@ -0,0 +1,118 @@ +from typing import cast + +import flask_login +from flask import request +from flask_restful import Resource, reqparse +from jwt import InvalidTokenError # type: ignore +from web import api + +import services +from controllers.console.auth.error import (EmailCodeError, + EmailOrPasswordMismatchError, + InvalidEmailError) +from controllers.console.error import AccountBannedError, AccountNotFound +from controllers.console.wraps import setup_required +from libs.helper import email +from libs.password import valid_password +from models.account import Account +from services.account_service import AccountService +from services.webapp_auth_service import Unauthorized, WebAppAuthService + + +class LoginApi(Resource): + """Resource for web app email/password login.""" + + def post(self): + """Authenticate user and login.""" + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") + args = parser.parse_args() + + app_code = request.headers.get("X-App-Code") + if app_code is None: + raise Unauthorized("X-App-Code header is missing.") + + try: + account = WebAppAuthService.authenticate(args["email"], args["password"]) + except services.errors.account.AccountLoginError: + raise AccountBannedError() + except services.errors.account.AccountPasswordError: + raise EmailOrPasswordMismatchError() + except services.errors.account.AccountNotFoundError: + raise AccountNotFound() + + token = WebAppAuthService.login(account=account, app_code=app_code) + return {"result": "success", "token": token} + + +class LogoutApi(Resource): + @setup_required + def get(self): + account = cast(Account, flask_login.current_user) + if isinstance(account, flask_login.AnonymousUserMixin): + return {"result": "success"} + flask_login.logout_user() + return {"result": "success"} + + +class EmailCodeLoginSendEmailApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + account = WebAppAuthService.get_user_through_email(args["email"]) + if account is None: + raise AccountNotFound() + else: + token = WebAppAuthService.send_email_code_login_email(account=account, language=language) + + return {"result": "success", "data": token} + + +class EmailCodeLoginApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, location="json") + args = parser.parse_args() + + user_email = args["email"] + app_code = request.headers.get("X-App-Code") + if app_code is None: + raise Unauthorized("X-App-Code header is missing.") + + token_data = WebAppAuthService.get_email_code_login_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if token_data["email"] != args["email"]: + raise InvalidEmailError() + + if token_data["code"] != args["code"]: + raise EmailCodeError() + + WebAppAuthService.revoke_email_code_login_token(args["token"]) + account = WebAppAuthService.get_user_through_email(user_email) + if not account: + raise AccountNotFound() + + token = WebAppAuthService.login(account=account, app_code=app_code) + AccountService.reset_login_error_rate_limit(args["email"]) + return {"result": "success", "token": token} + + +api.add_resource(LoginApi, "/login") +api.add_resource(LogoutApi, "/logout") +api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") +api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 4625c1f43d..3c1f0a415f 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -5,7 +5,7 @@ from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api -from controllers.web.error import WebSSOAuthRequiredError +from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -22,10 +22,10 @@ class PassportResource(Resource): if app_code is None: raise Unauthorized("X-App-Code header is missing.") - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - raise WebSSOAuthRequiredError() + if system_features.webapp_auth.enabled: + app_settings = EnterpriseService.get_web_app_settings(app_code=app_code) + if not app_settings or not app_settings.access_mode == "public": + raise WebAppAuthRequiredError() # get site from db and check if it is normal site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 1b4d263bee..482d7859fa 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -4,7 +4,8 @@ from flask import request from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import WebSSOAuthRequiredError +from controllers.web.error import (WebAppAuthFailedError, + WebAppAuthRequiredError) from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -57,35 +58,48 @@ def decode_jwt_token(): if not end_user: raise NotFound() - _validate_web_sso_token(decoded, system_features, app_code) + # for enterprise webapp auth + app_web_auth_enabled = False + if system_features.webapp_auth.enabled: + app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private" + + _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) + _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) return app_model, end_user except Unauthorized as e: - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - raise WebSSOAuthRequiredError() + if system_features.webapp_auth.enabled: + app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private" + if app_web_auth_enabled: + raise WebAppAuthRequiredError() raise Unauthorized(e.description) -def _validate_web_sso_token(decoded, system_features, app_code): - app_web_sso_enabled = False - - # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - source = decoded.get("token_source") - if not source or source != "sso": - raise WebSSOAuthRequiredError() - - # Check if SSO is not enforced for web, and if the token source is SSO, - # raise an error and redirect to normal passport login - if not system_features.sso_enforced_for_web or not app_web_sso_enabled: +def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): + # Check if authentication is enforced for web app, and if the token source is not webapp, raise an error and redirect to login + if system_webapp_auth_enabled and app_web_auth_enabled: source = decoded.get("token_source") - if source and source == "sso": - raise Unauthorized("sso token expired.") + if not source or source != "webapp": + raise WebAppAuthRequiredError() + + # Check if authentication is not enforced for web, and if the token source is webapp, + # raise an error and redirect to normal passport login + if not system_webapp_auth_enabled or not app_web_auth_enabled: + source = decoded.get("token_source") + if source and source == "webapp": + raise Unauthorized("webapp token expired.") + + +def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): + if system_webapp_auth_enabled and app_web_auth_enabled: + # Check if the user is allowed to access the web app + user_id = decoded.get("user_id") + if not user_id: + raise WebAppAuthRequiredError() + + if not EnterpriseService.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + raise WebAppAuthFailedError() class WebApiResource(Resource): diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index abc01ddf8f..2ff3b3348a 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,11 +1,35 @@ +from pydantic import BaseModel, Field + from services.enterprise.base import EnterpriseRequest +class WebAppSettings(BaseModel): + access_mode: str = Field( + description="Access mode for the web app. Can be 'public' or 'private'", + default="private", + alias="access_mode", + ) + + class EnterpriseService: @classmethod def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") @classmethod - def get_app_web_sso_enabled(cls, app_code): - return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") + def is_user_allowed_to_access_webapp(cls, user_id: str, app_id=None, app_code=None) -> bool: + if not app_id and not app_code: + raise ValueError("Either app_id or app_code must be provided.") + + return EnterpriseRequest.send_request( + "GET", f"/web-app/allowed?appId={app_id}&appCode={app_code}&userId={user_id}" + ) + + @classmethod + def get_web_app_settings(cls, app_code: str = None, app_id: str = None): + if not app_code and not app_id: + raise ValueError("Either app_code or app_id must be provided.") + data = EnterpriseRequest.send_request("GET", f"/web-app/settings?appCode={app_code}&appId={app_id}") + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 959e0221b5..f37f4e0d92 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -44,6 +44,13 @@ class BrandingModel(BaseModel): favicon: str = "" +class WebAppAuthModel(BaseModel): + enabled: bool = False + allow_sso: bool = False + allow_email_code_login: bool = False + allow_email_password_login: bool = False + + class FeatureModel(BaseModel): billing: BillingModel = BillingModel() members: LimitationModel = LimitationModel(size=0, limit=1) @@ -75,6 +82,7 @@ class SystemFeatureModel(BaseModel): is_email_setup: bool = False license: LicenseModel = LicenseModel() branding: BrandingModel = BrandingModel() + webapp_auth: WebAppAuthModel = WebAppAuthModel() class FeatureService: @@ -101,6 +109,7 @@ class FeatureService: if dify_config.ENTERPRISE_ENABLED: system_features.enable_web_sso_switch_component = True system_features.branding.enabled = True + system_features.webapp_auth.enabled = True cls._fulfill_params_from_enterprise(system_features) return system_features @@ -194,6 +203,11 @@ class FeatureService: features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "") features.branding.favicon = enterprise_info["Branding"].get("favicon", "") + if "WebAppAuth" in enterprise_info: + features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False) + features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get("allowEmailCodeLogin", False) + features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get("allowEmailPasswordLogin", False) + if "License" in enterprise_info: license_info = enterprise_info["License"] diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py new file mode 100644 index 0000000000..65501bbffa --- /dev/null +++ b/api/services/webapp_auth_service.py @@ -0,0 +1,103 @@ +import random +from datetime import UTC, datetime, timedelta +from typing import Any, Optional, cast + +from werkzeug.exceptions import NotFound, Unauthorized + +from configs import dify_config +from extensions.ext_database import db +from libs.helper import TokenManager +from libs.passport import PassportService +from libs.password import compare_password +from models.account import Account, AccountStatus +from models.model import Site +from services.errors.account import (AccountLoginError, AccountNotFoundError, + AccountPasswordError) +from tasks.mail_email_code_login import send_email_code_login_mail_task + + +class WebAppAuthService: + """Service for web app authentication.""" + + @staticmethod + def authenticate(email: str, password: str) -> Account: + """authenticate account with email and password""" + + account = Account.query.filter_by(email=email).first() + if not account: + raise AccountNotFoundError() + + if account.status == AccountStatus.BANNED.value: + raise AccountLoginError("Account is banned.") + + if account.password is None or not compare_password(password, account.password, account.password_salt): + raise AccountPasswordError("Invalid email or password.") + + return cast(Account, account) + + @staticmethod + def login(account: Account, app_code: str) -> str: + site = db.session.query(Site).filter(Site.code == app_code).first() + if not site: + raise NotFound("Site not found.") + + access_token = WebAppAuthService._get_account_jwt_token(account=account, site=site) + + return access_token + + @classmethod + def get_user_through_email(cls, email: str): + account = db.session.query(Account).filter(Account.email == email).first() + if not account: + return None + + if account.status == AccountStatus.BANNED.value: + raise Unauthorized("Account is banned.") + + return account + + @classmethod + def send_email_code_login_email( + cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + ): + email = account.email if account else email + if email is None: + raise ValueError("Email must be provided.") + + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code} + ) + send_email_code_login_mail_task.delay( + language=language, + to=account.email if account else email, + code=code, + ) + + return token + + @classmethod + def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "webapp_email_code_login") + + @classmethod + def revoke_email_code_login_token(cls, token: str): + TokenManager.revoke_token(token, "webapp_email_code_login") + + @staticmethod + def _get_account_jwt_token(account: Account, site: Site) -> str: + exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.WebAppSessionTimeoutInHours * 24) + exp = int(exp_dt.timestamp()) + + payload = { + "iss": site.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": site.code, + "user_id": account.id, + "token_source": "webapp", + "exp": exp, + } + + token: str = PassportService().issue(payload) + return token