mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 00:55:55 +08:00
feat: add web app auth
This commit is contained in:
parent
1045f6db7a
commit
46d43e6758
@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
|
|||||||
code = 415
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
class WebSSOAuthRequiredError(BaseHTTPException):
|
class WebAppAuthRequiredError(BaseHTTPException):
|
||||||
error_code = "web_sso_auth_required"
|
error_code = "web_auth_required"
|
||||||
description = "Web SSO authentication required."
|
description = "Web app authentication required."
|
||||||
|
code = 401
|
||||||
|
|
||||||
|
|
||||||
|
class WebAppAuthFailedError(BaseHTTPException):
|
||||||
|
error_code = "web_app_auth_failed"
|
||||||
|
description = "You do not have permission to access this web app."
|
||||||
code = 401
|
code = 401
|
||||||
|
|
||||||
|
|
||||||
|
118
api/controllers/web/login.py
Normal file
118
api/controllers/web/login.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import flask_login
|
||||||
|
from flask import request
|
||||||
|
from flask_restful import Resource, reqparse
|
||||||
|
from jwt import InvalidTokenError # type: ignore
|
||||||
|
from web import api
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.console.auth.error import (EmailCodeError,
|
||||||
|
EmailOrPasswordMismatchError,
|
||||||
|
InvalidEmailError)
|
||||||
|
from controllers.console.error import AccountBannedError, AccountNotFound
|
||||||
|
from controllers.console.wraps import setup_required
|
||||||
|
from libs.helper import email
|
||||||
|
from libs.password import valid_password
|
||||||
|
from models.account import Account
|
||||||
|
from services.account_service import AccountService
|
||||||
|
from services.webapp_auth_service import Unauthorized, WebAppAuthService
|
||||||
|
|
||||||
|
|
||||||
|
class LoginApi(Resource):
|
||||||
|
"""Resource for web app email/password login."""
|
||||||
|
|
||||||
|
def post(self):
|
||||||
|
"""Authenticate user and login."""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("email", type=email, required=True, location="json")
|
||||||
|
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
app_code = request.headers.get("X-App-Code")
|
||||||
|
if app_code is None:
|
||||||
|
raise Unauthorized("X-App-Code header is missing.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
account = WebAppAuthService.authenticate(args["email"], args["password"])
|
||||||
|
except services.errors.account.AccountLoginError:
|
||||||
|
raise AccountBannedError()
|
||||||
|
except services.errors.account.AccountPasswordError:
|
||||||
|
raise EmailOrPasswordMismatchError()
|
||||||
|
except services.errors.account.AccountNotFoundError:
|
||||||
|
raise AccountNotFound()
|
||||||
|
|
||||||
|
token = WebAppAuthService.login(account=account, app_code=app_code)
|
||||||
|
return {"result": "success", "token": token}
|
||||||
|
|
||||||
|
|
||||||
|
class LogoutApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
def get(self):
|
||||||
|
account = cast(Account, flask_login.current_user)
|
||||||
|
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||||
|
return {"result": "success"}
|
||||||
|
flask_login.logout_user()
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
class EmailCodeLoginSendEmailApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("email", type=email, required=True, location="json")
|
||||||
|
parser.add_argument("language", type=str, required=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||||
|
language = "zh-Hans"
|
||||||
|
else:
|
||||||
|
language = "en-US"
|
||||||
|
|
||||||
|
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||||
|
if account is None:
|
||||||
|
raise AccountNotFound()
|
||||||
|
else:
|
||||||
|
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
|
||||||
|
|
||||||
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
|
class EmailCodeLoginApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("email", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("code", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("token", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
user_email = args["email"]
|
||||||
|
app_code = request.headers.get("X-App-Code")
|
||||||
|
if app_code is None:
|
||||||
|
raise Unauthorized("X-App-Code header is missing.")
|
||||||
|
|
||||||
|
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||||
|
if token_data is None:
|
||||||
|
raise InvalidTokenError()
|
||||||
|
|
||||||
|
if token_data["email"] != args["email"]:
|
||||||
|
raise InvalidEmailError()
|
||||||
|
|
||||||
|
if token_data["code"] != args["code"]:
|
||||||
|
raise EmailCodeError()
|
||||||
|
|
||||||
|
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||||
|
account = WebAppAuthService.get_user_through_email(user_email)
|
||||||
|
if not account:
|
||||||
|
raise AccountNotFound()
|
||||||
|
|
||||||
|
token = WebAppAuthService.login(account=account, app_code=app_code)
|
||||||
|
AccountService.reset_login_error_rate_limit(args["email"])
|
||||||
|
return {"result": "success", "token": token}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(LoginApi, "/login")
|
||||||
|
api.add_resource(LogoutApi, "/logout")
|
||||||
|
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||||
|
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
@ -5,7 +5,7 @@ from flask_restful import Resource # type: ignore
|
|||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
from controllers.web import api
|
from controllers.web import api
|
||||||
from controllers.web.error import WebSSOAuthRequiredError
|
from controllers.web.error import WebAppAuthRequiredError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
from models.model import App, EndUser, Site
|
from models.model import App, EndUser, Site
|
||||||
@ -22,10 +22,10 @@ class PassportResource(Resource):
|
|||||||
if app_code is None:
|
if app_code is None:
|
||||||
raise Unauthorized("X-App-Code header is missing.")
|
raise Unauthorized("X-App-Code header is missing.")
|
||||||
|
|
||||||
if system_features.sso_enforced_for_web:
|
if system_features.webapp_auth.enabled:
|
||||||
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
|
app_settings = EnterpriseService.get_web_app_settings(app_code=app_code)
|
||||||
if app_web_sso_enabled:
|
if not app_settings or not app_settings.access_mode == "public":
|
||||||
raise WebSSOAuthRequiredError()
|
raise WebAppAuthRequiredError()
|
||||||
|
|
||||||
# get site from db and check if it is normal
|
# get site from db and check if it is normal
|
||||||
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
||||||
|
@ -4,7 +4,8 @@ from flask import request
|
|||||||
from flask_restful import Resource # type: ignore
|
from flask_restful import Resource # type: ignore
|
||||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||||
|
|
||||||
from controllers.web.error import WebSSOAuthRequiredError
|
from controllers.web.error import (WebAppAuthFailedError,
|
||||||
|
WebAppAuthRequiredError)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
from models.model import App, EndUser, Site
|
from models.model import App, EndUser, Site
|
||||||
@ -57,35 +58,48 @@ def decode_jwt_token():
|
|||||||
if not end_user:
|
if not end_user:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
_validate_web_sso_token(decoded, system_features, app_code)
|
# for enterprise webapp auth
|
||||||
|
app_web_auth_enabled = False
|
||||||
|
if system_features.webapp_auth.enabled:
|
||||||
|
app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private"
|
||||||
|
|
||||||
|
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||||
|
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||||
|
|
||||||
return app_model, end_user
|
return app_model, end_user
|
||||||
except Unauthorized as e:
|
except Unauthorized as e:
|
||||||
if system_features.sso_enforced_for_web:
|
if system_features.webapp_auth.enabled:
|
||||||
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
|
app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private"
|
||||||
if app_web_sso_enabled:
|
if app_web_auth_enabled:
|
||||||
raise WebSSOAuthRequiredError()
|
raise WebAppAuthRequiredError()
|
||||||
|
|
||||||
raise Unauthorized(e.description)
|
raise Unauthorized(e.description)
|
||||||
|
|
||||||
|
|
||||||
def _validate_web_sso_token(decoded, system_features, app_code):
|
def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
|
||||||
app_web_sso_enabled = False
|
# Check if authentication is enforced for web app, and if the token source is not webapp, raise an error and redirect to login
|
||||||
|
if system_webapp_auth_enabled and app_web_auth_enabled:
|
||||||
# Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login
|
|
||||||
if system_features.sso_enforced_for_web:
|
|
||||||
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
|
|
||||||
if app_web_sso_enabled:
|
|
||||||
source = decoded.get("token_source")
|
|
||||||
if not source or source != "sso":
|
|
||||||
raise WebSSOAuthRequiredError()
|
|
||||||
|
|
||||||
# Check if SSO is not enforced for web, and if the token source is SSO,
|
|
||||||
# raise an error and redirect to normal passport login
|
|
||||||
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
|
|
||||||
source = decoded.get("token_source")
|
source = decoded.get("token_source")
|
||||||
if source and source == "sso":
|
if not source or source != "webapp":
|
||||||
raise Unauthorized("sso token expired.")
|
raise WebAppAuthRequiredError()
|
||||||
|
|
||||||
|
# Check if authentication is not enforced for web, and if the token source is webapp,
|
||||||
|
# raise an error and redirect to normal passport login
|
||||||
|
if not system_webapp_auth_enabled or not app_web_auth_enabled:
|
||||||
|
source = decoded.get("token_source")
|
||||||
|
if source and source == "webapp":
|
||||||
|
raise Unauthorized("webapp token expired.")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
|
||||||
|
if system_webapp_auth_enabled and app_web_auth_enabled:
|
||||||
|
# Check if the user is allowed to access the web app
|
||||||
|
user_id = decoded.get("user_id")
|
||||||
|
if not user_id:
|
||||||
|
raise WebAppAuthRequiredError()
|
||||||
|
|
||||||
|
if not EnterpriseService.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
|
||||||
|
raise WebAppAuthFailedError()
|
||||||
|
|
||||||
|
|
||||||
class WebApiResource(Resource):
|
class WebApiResource(Resource):
|
||||||
|
@ -1,11 +1,35 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from services.enterprise.base import EnterpriseRequest
|
from services.enterprise.base import EnterpriseRequest
|
||||||
|
|
||||||
|
|
||||||
|
class WebAppSettings(BaseModel):
|
||||||
|
access_mode: str = Field(
|
||||||
|
description="Access mode for the web app. Can be 'public' or 'private'",
|
||||||
|
default="private",
|
||||||
|
alias="access_mode",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseService:
|
class EnterpriseService:
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_info(cls):
|
def get_info(cls):
|
||||||
return EnterpriseRequest.send_request("GET", "/info")
|
return EnterpriseRequest.send_request("GET", "/info")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_app_web_sso_enabled(cls, app_code):
|
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id=None, app_code=None) -> bool:
|
||||||
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")
|
if not app_id and not app_code:
|
||||||
|
raise ValueError("Either app_id or app_code must be provided.")
|
||||||
|
|
||||||
|
return EnterpriseRequest.send_request(
|
||||||
|
"GET", f"/web-app/allowed?appId={app_id}&appCode={app_code}&userId={user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_web_app_settings(cls, app_code: str = None, app_id: str = None):
|
||||||
|
if not app_code and not app_id:
|
||||||
|
raise ValueError("Either app_code or app_id must be provided.")
|
||||||
|
data = EnterpriseRequest.send_request("GET", f"/web-app/settings?appCode={app_code}&appId={app_id}")
|
||||||
|
if not data:
|
||||||
|
raise ValueError("No data found.")
|
||||||
|
return WebAppSettings(**data)
|
||||||
|
@ -44,6 +44,13 @@ class BrandingModel(BaseModel):
|
|||||||
favicon: str = ""
|
favicon: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class WebAppAuthModel(BaseModel):
|
||||||
|
enabled: bool = False
|
||||||
|
allow_sso: bool = False
|
||||||
|
allow_email_code_login: bool = False
|
||||||
|
allow_email_password_login: bool = False
|
||||||
|
|
||||||
|
|
||||||
class FeatureModel(BaseModel):
|
class FeatureModel(BaseModel):
|
||||||
billing: BillingModel = BillingModel()
|
billing: BillingModel = BillingModel()
|
||||||
members: LimitationModel = LimitationModel(size=0, limit=1)
|
members: LimitationModel = LimitationModel(size=0, limit=1)
|
||||||
@ -75,6 +82,7 @@ class SystemFeatureModel(BaseModel):
|
|||||||
is_email_setup: bool = False
|
is_email_setup: bool = False
|
||||||
license: LicenseModel = LicenseModel()
|
license: LicenseModel = LicenseModel()
|
||||||
branding: BrandingModel = BrandingModel()
|
branding: BrandingModel = BrandingModel()
|
||||||
|
webapp_auth: WebAppAuthModel = WebAppAuthModel()
|
||||||
|
|
||||||
|
|
||||||
class FeatureService:
|
class FeatureService:
|
||||||
@ -101,6 +109,7 @@ class FeatureService:
|
|||||||
if dify_config.ENTERPRISE_ENABLED:
|
if dify_config.ENTERPRISE_ENABLED:
|
||||||
system_features.enable_web_sso_switch_component = True
|
system_features.enable_web_sso_switch_component = True
|
||||||
system_features.branding.enabled = True
|
system_features.branding.enabled = True
|
||||||
|
system_features.webapp_auth.enabled = True
|
||||||
cls._fulfill_params_from_enterprise(system_features)
|
cls._fulfill_params_from_enterprise(system_features)
|
||||||
|
|
||||||
return system_features
|
return system_features
|
||||||
@ -194,6 +203,11 @@ class FeatureService:
|
|||||||
features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "")
|
features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "")
|
||||||
features.branding.favicon = enterprise_info["Branding"].get("favicon", "")
|
features.branding.favicon = enterprise_info["Branding"].get("favicon", "")
|
||||||
|
|
||||||
|
if "WebAppAuth" in enterprise_info:
|
||||||
|
features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False)
|
||||||
|
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get("allowEmailCodeLogin", False)
|
||||||
|
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get("allowEmailPasswordLogin", False)
|
||||||
|
|
||||||
if "License" in enterprise_info:
|
if "License" in enterprise_info:
|
||||||
license_info = enterprise_info["License"]
|
license_info = enterprise_info["License"]
|
||||||
|
|
||||||
|
103
api/services/webapp_auth_service.py
Normal file
103
api/services/webapp_auth_service.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import random
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.helper import TokenManager
|
||||||
|
from libs.passport import PassportService
|
||||||
|
from libs.password import compare_password
|
||||||
|
from models.account import Account, AccountStatus
|
||||||
|
from models.model import Site
|
||||||
|
from services.errors.account import (AccountLoginError, AccountNotFoundError,
|
||||||
|
AccountPasswordError)
|
||||||
|
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
||||||
|
|
||||||
|
|
||||||
|
class WebAppAuthService:
|
||||||
|
"""Service for web app authentication."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def authenticate(email: str, password: str) -> Account:
|
||||||
|
"""authenticate account with email and password"""
|
||||||
|
|
||||||
|
account = Account.query.filter_by(email=email).first()
|
||||||
|
if not account:
|
||||||
|
raise AccountNotFoundError()
|
||||||
|
|
||||||
|
if account.status == AccountStatus.BANNED.value:
|
||||||
|
raise AccountLoginError("Account is banned.")
|
||||||
|
|
||||||
|
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
||||||
|
raise AccountPasswordError("Invalid email or password.")
|
||||||
|
|
||||||
|
return cast(Account, account)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def login(account: Account, app_code: str) -> str:
|
||||||
|
site = db.session.query(Site).filter(Site.code == app_code).first()
|
||||||
|
if not site:
|
||||||
|
raise NotFound("Site not found.")
|
||||||
|
|
||||||
|
access_token = WebAppAuthService._get_account_jwt_token(account=account, site=site)
|
||||||
|
|
||||||
|
return access_token
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_user_through_email(cls, email: str):
|
||||||
|
account = db.session.query(Account).filter(Account.email == email).first()
|
||||||
|
if not account:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if account.status == AccountStatus.BANNED.value:
|
||||||
|
raise Unauthorized("Account is banned.")
|
||||||
|
|
||||||
|
return account
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def send_email_code_login_email(
|
||||||
|
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
|
||||||
|
):
|
||||||
|
email = account.email if account else email
|
||||||
|
if email is None:
|
||||||
|
raise ValueError("Email must be provided.")
|
||||||
|
|
||||||
|
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
||||||
|
token = TokenManager.generate_token(
|
||||||
|
account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
|
||||||
|
)
|
||||||
|
send_email_code_login_mail_task.delay(
|
||||||
|
language=language,
|
||||||
|
to=account.email if account else email,
|
||||||
|
code=code,
|
||||||
|
)
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||||
|
return TokenManager.get_token_data(token, "webapp_email_code_login")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def revoke_email_code_login_token(cls, token: str):
|
||||||
|
TokenManager.revoke_token(token, "webapp_email_code_login")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_account_jwt_token(account: Account, site: Site) -> str:
|
||||||
|
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.WebAppSessionTimeoutInHours * 24)
|
||||||
|
exp = int(exp_dt.timestamp())
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"iss": site.id,
|
||||||
|
"sub": "Web API Passport",
|
||||||
|
"app_id": site.app_id,
|
||||||
|
"app_code": site.code,
|
||||||
|
"user_id": account.id,
|
||||||
|
"token_source": "webapp",
|
||||||
|
"exp": exp,
|
||||||
|
}
|
||||||
|
|
||||||
|
token: str = PassportService().issue(payload)
|
||||||
|
return token
|
Loading…
x
Reference in New Issue
Block a user