diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 498557cd51..72ec05f654 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -37,9 +37,6 @@ from .billing import billing # Import datasets controllers from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing -# Import enterprise controllers -from .enterprise import enterprise_sso - # Import explore controllers from .explore import ( audio, diff --git a/api/controllers/console/enterprise/__init__.py b/api/controllers/console/enterprise/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/controllers/console/enterprise/enterprise_sso.py b/api/controllers/console/enterprise/enterprise_sso.py deleted file mode 100644 index f6a2897d5a..0000000000 --- a/api/controllers/console/enterprise/enterprise_sso.py +++ /dev/null @@ -1,59 +0,0 @@ -from flask import current_app, redirect -from flask_restful import Resource, reqparse - -from controllers.console import api -from controllers.console.setup import setup_required -from services.enterprise.enterprise_sso_service import EnterpriseSSOService - - -class EnterpriseSSOSamlLogin(Resource): - - @setup_required - def get(self): - return EnterpriseSSOService.get_sso_saml_login() - - -class EnterpriseSSOSamlAcs(Resource): - - @setup_required - def post(self): - parser = reqparse.RequestParser() - parser.add_argument('SAMLResponse', type=str, required=True, location='form') - args = parser.parse_args() - saml_response = args['SAMLResponse'] - - try: - token = EnterpriseSSOService.post_sso_saml_acs(saml_response) - return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') - except Exception as e: - return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') - - -class EnterpriseSSOOidcLogin(Resource): - - @setup_required - def get(self): - return EnterpriseSSOService.get_sso_oidc_login() - - -class EnterpriseSSOOidcCallback(Resource): - - @setup_required - def get(self): - parser = reqparse.RequestParser() - parser.add_argument('state', type=str, required=True, location='args') - parser.add_argument('code', type=str, required=True, location='args') - parser.add_argument('oidc-state', type=str, required=True, location='cookies') - args = parser.parse_args() - - try: - token = EnterpriseSSOService.get_sso_oidc_callback(args) - return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') - except Exception as e: - return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') - - -api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login') -api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs') -api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login') -api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback') diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 325652a447..7334f85a57 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,7 +1,6 @@ from flask_login import current_user from flask_restful import Resource -from services.enterprise.enterprise_feature_service import EnterpriseFeatureService from services.feature_service import FeatureService from . import api @@ -15,10 +14,10 @@ class FeatureApi(Resource): return FeatureService.get_features(current_user.current_tenant_id).dict() -class EnterpriseFeatureApi(Resource): +class SystemFeatureApi(Resource): def get(self): - return EnterpriseFeatureService.get_enterprise_features().dict() + return FeatureService.get_system_features().dict() api.add_resource(FeatureApi, '/features') -api.add_resource(EnterpriseFeatureApi, '/enterprise-features') +api.add_resource(SystemFeatureApi, '/system-features') diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index b6d46d4081..aa19bdc034 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -6,4 +6,4 @@ bp = Blueprint('web', __name__, url_prefix='/api') api = ExternalApi(bp) -from . import app, audio, completion, conversation, file, message, passport, saved_message, site, workflow +from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 2586f2e6ec..91d9015c33 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,14 +1,10 @@ -import json - from flask import current_app from flask_restful import fields, marshal_with from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource -from extensions.ext_database import db -from models.model import App, AppMode, AppModelConfig -from models.tools import ApiToolProvider +from models.model import App, AppMode from services.app_service import AppService diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 390e3fe7d1..bc87f51051 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -115,3 +115,9 @@ class UnsupportedFileTypeError(BaseHTTPException): error_code = 'unsupported_file_type' description = "File type not allowed." code = 415 + + +class WebSSOAuthRequiredError(BaseHTTPException): + error_code = 'web_sso_auth_required' + description = "Web SSO authentication required." + code = 401 diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py new file mode 100644 index 0000000000..65842d78c6 --- /dev/null +++ b/api/controllers/web/feature.py @@ -0,0 +1,12 @@ +from flask_restful import Resource + +from controllers.web import api +from services.feature_service import FeatureService + + +class SystemFeatureApi(Resource): + def get(self): + return FeatureService.get_system_features().dict() + + +api.add_resource(SystemFeatureApi, '/system-features') diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 92b28d8125..ccc8683a79 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -5,14 +5,21 @@ from flask_restful import Resource from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api +from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.feature_service import FeatureService class PassportResource(Resource): """Base resource for passport.""" def get(self): + + system_features = FeatureService.get_system_features() + if system_features.sso_enforced_for_web: + raise WebSSOAuthRequiredError() + app_code = request.headers.get('X-App-Code') if app_code is None: raise Unauthorized('X-App-Code header is missing.') @@ -28,7 +35,7 @@ class PassportResource(Resource): app_model = db.session.query(App).filter(App.id == site.app_id).first() if not app_model or app_model.status != 'normal' or not app_model.enable_site: raise NotFound() - + end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, @@ -36,6 +43,7 @@ class PassportResource(Resource): is_anonymous=True, session_id=generate_session_id(), ) + db.session.add(end_user) db.session.commit() @@ -53,8 +61,10 @@ class PassportResource(Resource): 'access_token': tk, } + api.add_resource(PassportResource, '/passport') + def generate_session_id(): """ Generate a unique session ID. diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index bdaa476f34..f5ab49d7e1 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -2,11 +2,13 @@ from functools import wraps from flask import request from flask_restful import Resource -from werkzeug.exceptions import NotFound, Unauthorized +from werkzeug.exceptions import BadRequest, NotFound, Unauthorized +from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.feature_service import FeatureService def validate_jwt_token(view=None): @@ -21,34 +23,60 @@ def validate_jwt_token(view=None): return decorator(view) return decorator + def decode_jwt_token(): - auth_header = request.headers.get('Authorization') - if auth_header is None: - raise Unauthorized('Authorization header is missing.') + system_features = FeatureService.get_system_features() - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') - - auth_scheme, tk = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() + try: + auth_header = request.headers.get('Authorization') + if auth_header is None: + raise Unauthorized('Authorization header is missing.') - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') - decoded = PassportService().verify(tk) - app_code = decoded.get('app_code') - app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() - site = db.session.query(Site).filter(Site.code == app_code).first() - if not app_model: - raise NotFound() - if not app_code or not site: - raise Unauthorized('Site URL is no longer valid.') - if app_model.enable_site is False: - raise Unauthorized('Site is disabled.') - end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() - if not end_user: - raise NotFound() + if ' ' not in auth_header: + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + + auth_scheme, tk = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + + if auth_scheme != 'bearer': + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + decoded = PassportService().verify(tk) + app_code = decoded.get('app_code') + app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() + site = db.session.query(Site).filter(Site.code == app_code).first() + if not app_model: + raise NotFound() + if not app_code or not site: + raise BadRequest('Site URL is no longer valid.') + if app_model.enable_site is False: + raise BadRequest('Site is disabled.') + end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() + if not end_user: + raise NotFound() + + _validate_web_sso_token(decoded, system_features) + + return app_model, end_user + except Unauthorized as e: + if system_features.sso_enforced_for_web: + raise WebSSOAuthRequiredError() + + raise Unauthorized(e.description) + + +def _validate_web_sso_token(decoded, system_features): + # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login + if system_features.sso_enforced_for_web: + source = decoded.get('token_source') + if not source or source != 'sso': + raise WebSSOAuthRequiredError() + + # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login + if not system_features.sso_enforced_for_web: + source = decoded.get('token_source') + if source and source == 'sso': + raise Unauthorized('sso token expired.') - return app_model, end_user class WebApiResource(Resource): method_decorators = [validate_jwt_token] diff --git a/api/services/enterprise/enterprise_feature_service.py b/api/services/enterprise/enterprise_feature_service.py deleted file mode 100644 index fe33349aa8..0000000000 --- a/api/services/enterprise/enterprise_feature_service.py +++ /dev/null @@ -1,28 +0,0 @@ -from flask import current_app -from pydantic import BaseModel - -from services.enterprise.enterprise_service import EnterpriseService - - -class EnterpriseFeatureModel(BaseModel): - sso_enforced_for_signin: bool = False - sso_enforced_for_signin_protocol: str = '' - - -class EnterpriseFeatureService: - - @classmethod - def get_enterprise_features(cls) -> EnterpriseFeatureModel: - features = EnterpriseFeatureModel() - - if current_app.config['ENTERPRISE_ENABLED']: - cls._fulfill_params_from_enterprise(features) - - return features - - @classmethod - def _fulfill_params_from_enterprise(cls, features): - enterprise_info = EnterpriseService.get_info() - - features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] - features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] diff --git a/api/services/enterprise/enterprise_sso_service.py b/api/services/enterprise/enterprise_sso_service.py deleted file mode 100644 index d8e19f23bf..0000000000 --- a/api/services/enterprise/enterprise_sso_service.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging - -from models.account import Account, AccountStatus -from services.account_service import AccountService, TenantService -from services.enterprise.base import EnterpriseRequest - -logger = logging.getLogger(__name__) - - -class EnterpriseSSOService: - - @classmethod - def get_sso_saml_login(cls) -> str: - return EnterpriseRequest.send_request('GET', '/sso/saml/login') - - @classmethod - def post_sso_saml_acs(cls, saml_response: str) -> str: - response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response}) - if 'email' not in response or response['email'] is None: - logger.exception(response) - raise Exception('Saml response is invalid') - - return cls.login_with_email(response.get('email')) - - @classmethod - def get_sso_oidc_login(cls): - return EnterpriseRequest.send_request('GET', '/sso/oidc/login') - - @classmethod - def get_sso_oidc_callback(cls, args: dict): - state_from_query = args['state'] - code_from_query = args['code'] - state_from_cookies = args['oidc-state'] - - if state_from_cookies != state_from_query: - raise Exception('invalid state or code') - - response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query}) - if 'email' not in response or response['email'] is None: - logger.exception(response) - raise Exception('OIDC response is invalid') - - return cls.login_with_email(response.get('email')) - - @classmethod - def login_with_email(cls, email: str) -> str: - account = Account.query.filter_by(email=email).first() - if account is None: - raise Exception('account not found, please contact system admin to invite you to join in a workspace') - - if account.status == AccountStatus.BANNED: - raise Exception('account is banned, please contact system admin') - - tenants = TenantService.get_join_tenants(account) - if len(tenants) == 0: - raise Exception("workspace not found, please contact system admin to invite you to join in a workspace") - - token = AccountService.get_account_jwt_token(account) - - return token diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 3cf51d11a0..29842d68b7 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -2,6 +2,7 @@ from flask import current_app from pydantic import BaseModel from services.billing_service import BillingService +from services.enterprise.enterprise_service import EnterpriseService class SubscriptionModel(BaseModel): @@ -30,6 +31,13 @@ class FeatureModel(BaseModel): can_replace_logo: bool = False +class SystemFeatureModel(BaseModel): + sso_enforced_for_signin: bool = False + sso_enforced_for_signin_protocol: str = '' + sso_enforced_for_web: bool = False + sso_enforced_for_web_protocol: str = '' + + class FeatureService: @classmethod @@ -43,6 +51,15 @@ class FeatureService: return features + @classmethod + def get_system_features(cls) -> SystemFeatureModel: + system_features = SystemFeatureModel() + + if current_app.config['ENTERPRISE_ENABLED']: + cls._fulfill_params_from_enterprise(system_features) + + return system_features + @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO'] @@ -73,3 +90,11 @@ class FeatureService: features.docs_processing = billing_info['docs_processing'] features.can_replace_logo = billing_info['can_replace_logo'] + @classmethod + def _fulfill_params_from_enterprise(cls, features): + enterprise_info = EnterpriseService.get_info() + + features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] + features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] + features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] + features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] diff --git a/web/app/(shareLayout)/chat/[token]/page.tsx b/web/app/(shareLayout)/chat/[token]/page.tsx index 6c3fe2b4a4..56b2e0da7d 100644 --- a/web/app/(shareLayout)/chat/[token]/page.tsx +++ b/web/app/(shareLayout)/chat/[token]/page.tsx @@ -1,5 +1,4 @@ 'use client' - import type { FC } from 'react' import React from 'react' diff --git a/web/app/(shareLayout)/chatbot/[token]/page.tsx b/web/app/(shareLayout)/chatbot/[token]/page.tsx index 8aa182893a..0dc7b07169 100644 --- a/web/app/(shareLayout)/chatbot/[token]/page.tsx +++ b/web/app/(shareLayout)/chatbot/[token]/page.tsx @@ -1,12 +1,87 @@ +'use client' import type { FC } from 'react' -import React from 'react' - +import React, { useEffect } from 'react' +import cn from 'classnames' import type { IMainProps } from '@/app/components/share/chat' import Main from '@/app/components/share/chatbot' +import Loading from '@/app/components/base/loading' +import { fetchSystemFeatures } from '@/service/share' +import LogoSite from '@/app/components/base/logo/logo-site' const Chatbot: FC = () => { + const [isSSOEnforced, setIsSSOEnforced] = React.useState(true) + const [loading, setLoading] = React.useState(true) + + useEffect(() => { + fetchSystemFeatures().then((res) => { + setIsSSOEnforced(res.sso_enforced_for_web) + setLoading(false) + }) + }, []) + return ( -
+ <> + { + loading + ? ( +
+
+ +
+
+ ) + : ( + <> + {isSSOEnforced + ? ( +
+
+
+ +
+ +
+
+
+

+ Warning: Chatbot is not available +

+

+ Because SSO is enforced. Please contact your administrator. +

+
+
+
+
+
+ ) + :
+ } + + )} + ) } diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx new file mode 100644 index 0000000000..d0d05cdd0d --- /dev/null +++ b/web/app/(shareLayout)/webapp-signin/page.tsx @@ -0,0 +1,147 @@ +'use client' +import cn from 'classnames' +import { useRouter, useSearchParams } from 'next/navigation' +import type { FC } from 'react' +import React, { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Toast from '@/app/components/base/toast' +import Button from '@/app/components/base/button' +import { fetchSystemFeatures, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share' +import LogoSite from '@/app/components/base/logo/logo-site' +import { setAccessToken } from '@/app/components/share/utils' + +const WebSSOForm: FC = () => { + const searchParams = useSearchParams() + + const redirectUrl = searchParams.get('redirect_url') + const tokenFromUrl = searchParams.get('web_sso_token') + const message = searchParams.get('message') + + const router = useRouter() + const { t } = useTranslation() + + const [isLoading, setIsLoading] = useState(false) + const [protocal, setProtocal] = useState('') + + useEffect(() => { + const fetchFeaturesAndSetToken = async () => { + await fetchSystemFeatures().then((res) => { + setProtocal(res.sso_enforced_for_web_protocol) + }) + + // Callback from SSO, process token and redirect + if (tokenFromUrl && redirectUrl) { + const appCode = redirectUrl.split('/').pop() + if (!appCode) { + Toast.notify({ + type: 'error', + message: 'redirect url is invalid. App code is not found.', + }) + return + } + + await setAccessToken(appCode, tokenFromUrl) + router.push(redirectUrl) + } + } + + fetchFeaturesAndSetToken() + + if (message) { + Toast.notify({ + type: 'error', + message, + }) + } + }, []) + + const handleSSOLogin = () => { + setIsLoading(true) + + if (!redirectUrl) { + Toast.notify({ + type: 'error', + message: 'redirect url is not found.', + }) + setIsLoading(false) + return + } + + const appCode = redirectUrl.split('/').pop() + if (!appCode) { + Toast.notify({ + type: 'error', + message: 'redirect url is invalid. App code is not found.', + }) + return + } + + if (protocal === 'saml') { + fetchWebSAMLSSOUrl(appCode, redirectUrl).then((res) => { + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else if (protocal === 'oidc') { + fetchWebOIDCSSOUrl(appCode, redirectUrl).then((res) => { + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else { + Toast.notify({ + type: 'error', + message: 'sso protocal is not supported.', + }) + setIsLoading(false) + } + } + + return ( +
+
+
+ +
+ +
+
+
+

{t('login.pageTitle')}

+
+
+ +
+
+
+
+
+ ) +} + +export default React.memo(WebSSOForm) diff --git a/web/app/components/share/utils.ts b/web/app/components/share/utils.ts index 6362fd8330..5a41523404 100644 --- a/web/app/components/share/utils.ts +++ b/web/app/components/share/utils.ts @@ -1,4 +1,6 @@ +import { CONVERSATION_ID_INFO } from '../base/chat/constants' import { fetchAccessToken } from '@/service/share' + export const checkOrSetAccessToken = async () => { const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) @@ -15,3 +17,37 @@ export const checkOrSetAccessToken = async () => { localStorage.setItem('token', JSON.stringify(accessTokenJson)) } } + +export const setAccessToken = async (sharedToken: string, token: string) => { + const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) + let accessTokenJson = { [sharedToken]: '' } + try { + accessTokenJson = JSON.parse(accessToken) + } + catch (e) { + + } + + localStorage.removeItem(CONVERSATION_ID_INFO) + + accessTokenJson[sharedToken] = token + localStorage.setItem('token', JSON.stringify(accessTokenJson)) +} + +export const removeAccessToken = () => { + const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] + + const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) + let accessTokenJson = { [sharedToken]: '' } + try { + accessTokenJson = JSON.parse(accessToken) + } + catch (e) { + + } + + localStorage.removeItem(CONVERSATION_ID_INFO) + + delete accessTokenJson[sharedToken] + localStorage.setItem('token', JSON.stringify(accessTokenJson)) +} diff --git a/web/app/signin/page.tsx b/web/app/signin/page.tsx index 8abb656c2e..b0ee172a95 100644 --- a/web/app/signin/page.tsx +++ b/web/app/signin/page.tsx @@ -6,19 +6,20 @@ import Loading from '../components/base/loading' import Forms from './forms' import Header from './_header' import style from './page.module.css' -import EnterpriseSSOForm from './enterpriseSSOForm' +import UserSSOForm from './userSSOForm' import { IS_CE_EDITION } from '@/config' -import { getEnterpriseFeatures } from '@/service/enterprise' -import type { EnterpriseFeatures } from '@/types/enterprise' -import { defaultEnterpriseFeatures } from '@/types/enterprise' + +import type { SystemFeatures } from '@/types/feature' +import { defaultSystemFeatures } from '@/types/feature' +import { getSystemFeatures } from '@/service/common' const SignIn = () => { const [loading, setLoading] = useState(true) - const [enterpriseFeatures, setEnterpriseFeatures] = useState(defaultEnterpriseFeatures) + const [systemFeatures, setSystemFeatures] = useState(defaultSystemFeatures) useEffect(() => { - getEnterpriseFeatures().then((res) => { - setEnterpriseFeatures(res) + getSystemFeatures().then((res) => { + setSystemFeatures(res) }).finally(() => { setLoading(false) }) @@ -70,7 +71,7 @@ gtag('config', 'AW-11217955271"'); )} - {!loading && !enterpriseFeatures.sso_enforced_for_signin && ( + {!loading && !systemFeatures.sso_enforced_for_signin && ( <>
@@ -79,8 +80,8 @@ gtag('config', 'AW-11217955271"'); )} - {!loading && enterpriseFeatures.sso_enforced_for_signin && ( - + {!loading && systemFeatures.sso_enforced_for_signin && ( + )}
diff --git a/web/app/signin/enterpriseSSOForm.tsx b/web/app/signin/userSSOForm.tsx similarity index 85% rename from web/app/signin/enterpriseSSOForm.tsx rename to web/app/signin/userSSOForm.tsx index 747f2aa478..fe95be8c66 100644 --- a/web/app/signin/enterpriseSSOForm.tsx +++ b/web/app/signin/userSSOForm.tsx @@ -5,14 +5,14 @@ import type { FC } from 'react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Toast from '@/app/components/base/toast' -import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise' +import { getUserOIDCSSOUrl, getUserSAMLSSOUrl } from '@/service/sso' import Button from '@/app/components/base/button' -type EnterpriseSSOFormProps = { +type UserSSOFormProps = { protocol: string } -const EnterpriseSSOForm: FC = ({ +const UserSSOForm: FC = ({ protocol, }) => { const searchParams = useSearchParams() @@ -41,15 +41,15 @@ const EnterpriseSSOForm: FC = ({ const handleSSOLogin = () => { setIsLoading(true) if (protocol === 'saml') { - getSAMLSSOUrl().then((res) => { + getUserSAMLSSOUrl().then((res) => { router.push(res.url) }).finally(() => { setIsLoading(false) }) } else { - getOIDCSSOUrl().then((res) => { - document.cookie = `oidc-state=${res.state}` + getUserOIDCSSOUrl().then((res) => { + document.cookie = `user-oidc-state=${res.state}` router.push(res.url) }).finally(() => { setIsLoading(false) @@ -84,4 +84,4 @@ const EnterpriseSSOForm: FC = ({ ) } -export default EnterpriseSSOForm +export default UserSSOForm diff --git a/web/service/base.ts b/web/service/base.ts index 48baeaeb05..c500e31d7e 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -10,6 +10,7 @@ import type { WorkflowFinishedResponse, WorkflowStartedResponse, } from '@/types/workflow' +import { removeAccessToken } from '@/app/components/share/utils' const TIME_OUT = 100000 const ContentType = { @@ -97,6 +98,10 @@ function unicodeToChar(text: string) { }) } +function requiredWebSSOLogin() { + globalThis.location.href = `/webapp-signin?redirect_url=${globalThis.location.pathname}` +} + export function format(text: string) { let res = text.trim() if (res.startsWith('\n')) @@ -308,6 +313,15 @@ const baseFetch = ( return bodyJson.then((data: ResponseError) => { if (!silent) Toast.notify({ type: 'error', message: data.message }) + + if (data.code === 'web_sso_auth_required') + requiredWebSSOLogin() + + if (data.code === 'unauthorized') { + removeAccessToken() + globalThis.location.reload() + } + return Promise.reject(data) }) } @@ -467,6 +481,16 @@ export const ssePost = ( if (!/^(2|3)\d{2}$/.test(String(res.status))) { res.json().then((data: any) => { Toast.notify({ type: 'error', message: data.message || 'Server Error' }) + + if (isPublicAPI) { + if (data.code === 'web_sso_auth_required') + requiredWebSSOLogin() + + if (data.code === 'unauthorized') { + removeAccessToken() + globalThis.location.reload() + } + } }) onError?.('Server Error') return diff --git a/web/service/common.ts b/web/service/common.ts index 3a7d97af14..98fe50488c 100644 --- a/web/service/common.ts +++ b/web/service/common.ts @@ -34,6 +34,7 @@ import type { ModelProvider, } from '@/app/components/header/account-setting/model-provider-page/declarations' import type { RETRIEVE_METHOD } from '@/types/app' +import type { SystemFeatures } from '@/types/feature' export const login: Fetcher }> = ({ url, body }) => { return post(url, { body }) as Promise @@ -271,3 +272,7 @@ type RetrievalMethodsRes = { export const fetchSupportRetrievalMethods: Fetcher = (url) => { return get(url) } + +export const getSystemFeatures = () => { + return get('/system-features') +} diff --git a/web/service/enterprise.ts b/web/service/enterprise.ts deleted file mode 100644 index b7d9c8213d..0000000000 --- a/web/service/enterprise.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { get } from './base' -import type { EnterpriseFeatures } from '@/types/enterprise' - -export const getEnterpriseFeatures = () => { - return get('/enterprise-features') -} - -export const getSAMLSSOUrl = () => { - return get<{ url: string }>('/enterprise/sso/saml/login') -} - -export const getOIDCSSOUrl = () => { - return get<{ url: string; state: string }>('/enterprise/sso/oidc/login') -} diff --git a/web/service/share.ts b/web/service/share.ts index 48a99705a3..4b8ce6d3b3 100644 --- a/web/service/share.ts +++ b/web/service/share.ts @@ -11,6 +11,7 @@ import type { ConversationItem, } from '@/models/share' import type { ChatConfig } from '@/app/components/base/chat/types' +import type { SystemFeatures } from '@/types/feature' function getAction(action: 'get' | 'post' | 'del' | 'patch', isInstalledApp: boolean) { switch (action) { @@ -135,6 +136,29 @@ export const fetchAppParams = async (isInstalledApp: boolean, installedAppId = ' return (getAction('get', isInstalledApp))(getUrl('parameters', isInstalledApp, installedAppId)) as Promise } +export const fetchSystemFeatures = async () => { + return (getAction('get', false))(getUrl('system-features', false, '')) as Promise +} + +export const fetchWebSAMLSSOUrl = async (appCode: string, redirectUrl: string) => { + return (getAction('get', false))(getUrl('/enterprise/sso/saml/login', false, ''), { + params: { + app_code: appCode, + redirect_url: redirectUrl, + }, + }) as Promise<{ url: string }> +} + +export const fetchWebOIDCSSOUrl = async (appCode: string, redirectUrl: string) => { + return (getAction('get', false))(getUrl('/enterprise/sso/oidc/login', false, ''), { + params: { + app_code: appCode, + redirect_url: redirectUrl, + }, + + }) as Promise<{ url: string }> +} + export const fetchAppMeta = async (isInstalledApp: boolean, installedAppId = '') => { return (getAction('get', isInstalledApp))(getUrl('meta', isInstalledApp, installedAppId)) as Promise } diff --git a/web/service/sso.ts b/web/service/sso.ts new file mode 100644 index 0000000000..77b81fe4a6 --- /dev/null +++ b/web/service/sso.ts @@ -0,0 +1,9 @@ +import { get } from './base' + +export const getUserSAMLSSOUrl = () => { + return get<{ url: string }>('/enterprise/sso/saml/login') +} + +export const getUserOIDCSSOUrl = () => { + return get<{ url: string; state: string }>('/enterprise/sso/oidc/login') +} diff --git a/web/types/enterprise.ts b/web/types/enterprise.ts deleted file mode 100644 index 479c593c04..0000000000 --- a/web/types/enterprise.ts +++ /dev/null @@ -1,9 +0,0 @@ -export type EnterpriseFeatures = { - sso_enforced_for_signin: boolean - sso_enforced_for_signin_protocol: string -} - -export const defaultEnterpriseFeatures: EnterpriseFeatures = { - sso_enforced_for_signin: false, - sso_enforced_for_signin_protocol: '', -} diff --git a/web/types/feature.ts b/web/types/feature.ts new file mode 100644 index 0000000000..89af9d21ab --- /dev/null +++ b/web/types/feature.ts @@ -0,0 +1,13 @@ +export type SystemFeatures = { + sso_enforced_for_signin: boolean + sso_enforced_for_signin_protocol: string + sso_enforced_for_web: boolean + sso_enforced_for_web_protocol: string +} + +export const defaultSystemFeatures: SystemFeatures = { + sso_enforced_for_signin: false, + sso_enforced_for_signin_protocol: '', + sso_enforced_for_web: false, + sso_enforced_for_web_protocol: '', +}