diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 4aa10ac6e9..ce6da4af79 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -17,15 +17,13 @@ from controllers.console.wraps import ( ) from core.ops.ops_trace_manager import OpsTraceManager from extensions.ext_database import db -from fields.app_fields import ( - app_detail_fields, - app_detail_fields_with_site, - app_pagination_fields, -) +from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from libs.login import login_required from models import Account, App from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] @@ -67,7 +65,12 @@ class AppListApi(Resource): if not app_pagination: return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} - return marshal(app_pagination, app_pagination_fields) + if FeatureService.get_system_features().webapp_auth.enabled: + for app in app_pagination.items: + app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app.id)) + app.access_mode = app_setting.access_mode + + return marshal(app_pagination, app_pagination_fields), 200 @setup_required @login_required @@ -111,6 +114,10 @@ class AppApi(Resource): app_model = app_service.get_app(app_model) + if FeatureService.get_system_features().webapp_auth.enabled: + app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app_model.id)) + app_model.access_mode = app_setting.access_mode + return app_model @setup_required diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index dfedea582d..b1bb9d6545 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -5,5 +5,5 @@ from libs.external_api import ExternalApi bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) -from .workspace import workspace from . import mail +from .workspace import workspace diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py index 9f10356fb6..47cbcb713c 100644 --- a/api/controllers/inner_api/mail.py +++ b/api/controllers/inner_api/mail.py @@ -1,5 +1,7 @@ -from flask_restful import Resource # type: ignore -from flask_restful import reqparse +from flask_restful import ( + Resource, # type: ignore + reqparse, +) from controllers.console.wraps import setup_required from controllers.inner_api import api @@ -12,7 +14,7 @@ class EnterpriseMail(Resource): @inner_api_only def post(self): parser = reqparse.RequestParser() - parser.add_argument("to", type=str, action='append', required=True) + parser.add_argument("to", type=str, action="append", required=True) parser.add_argument("subject", type=str, required=True) parser.add_argument("body", type=str, required=True) parser.add_argument("substitutions", type=dict, required=False) diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 955c781989..4106e6a179 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,6 +1,3 @@ -from typing import cast - -import flask_login from flask import request from flask_restful import Resource, reqparse from jwt import InvalidTokenError # type: ignore @@ -8,14 +5,11 @@ from web import api from werkzeug.exceptions import BadRequest import services -from controllers.console.auth.error import (EmailCodeError, - EmailOrPasswordMismatchError, - InvalidEmailError) +from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError from controllers.console.error import AccountBannedError, AccountNotFound from controllers.console.wraps import setup_required from libs.helper import email from libs.password import valid_password -from models.account import Account from services.account_service import AccountService from services.webapp_auth_service import WebAppAuthService @@ -51,14 +45,14 @@ class LoginApi(Resource): return {"result": "success", "token": token} -class LogoutApi(Resource): - @setup_required - def get(self): - account = cast(Account, flask_login.current_user) - if isinstance(account, flask_login.AnonymousUserMixin): - return {"result": "success"} - flask_login.logout_user() - return {"result": "success"} +# class LogoutApi(Resource): +# @setup_required +# def get(self): +# account = cast(Account, flask_login.current_user) +# if isinstance(account, flask_login.AnonymousUserMixin): +# return {"result": "success"} +# flask_login.logout_user() +# return {"result": "success"} class EmailCodeLoginSendEmailApi(Resource): @@ -122,6 +116,6 @@ class EmailCodeLoginApi(Resource): api.add_resource(LoginApi, "/login") -api.add_resource(LogoutApi, "/logout") +# api.add_resource(LogoutApi, "/logout") api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 482d7859fa..a009cd3288 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -4,8 +4,7 @@ from flask import request from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import (WebAppAuthFailedError, - WebAppAuthRequiredError) +from controllers.web.error import WebAppAuthFailedError, WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -61,7 +60,9 @@ def decode_jwt_token(): # for enterprise webapp auth app_web_auth_enabled = False if system_features.webapp_auth.enabled: - app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private" + app_web_auth_enabled = ( + EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" + ) _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) @@ -69,7 +70,9 @@ def decode_jwt_token(): return app_model, end_user except Unauthorized as e: if system_features.webapp_auth.enabled: - app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private" + app_web_auth_enabled = ( + EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" + ) if app_web_auth_enabled: raise WebAppAuthRequiredError() @@ -77,7 +80,8 @@ def decode_jwt_token(): def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): - # Check if authentication is enforced for web app, and if the token source is not webapp, raise an error and redirect to login + # Check if authentication is enforced for web app, and if the token source is not webapp, + # raise an error and redirect to login if system_webapp_auth_enabled and app_web_auth_enabled: source = decoded.get("token_source") if not source or source != "webapp": diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 73800eab85..95eef8fed1 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -63,6 +63,7 @@ app_detail_fields = { "created_at": TimestampField, "updated_by": fields.String, "updated_at": TimestampField, + "access_mode": fields.String, } prompt_config_fields = { @@ -98,6 +99,7 @@ app_partial_fields = { "updated_by": fields.String, "updated_at": TimestampField, "tags": fields.List(fields.Nested(tag_fields)), + "access_mode": fields.String, } @@ -170,6 +172,7 @@ app_detail_fields_with_site = { "updated_by": fields.String, "updated_at": TimestampField, "deleted_tools": fields.List(fields.String), + "access_mode": fields.String, } app_site_fields = { diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 2ff3b3348a..21e4831715 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -7,7 +7,7 @@ class WebAppSettings(BaseModel): access_mode: str = Field( description="Access mode for the web app. Can be 'public' or 'private'", default="private", - alias="access_mode", + alias="accessMode", ) @@ -17,19 +17,28 @@ class EnterpriseService: return EnterpriseRequest.send_request("GET", "/info") @classmethod - def is_user_allowed_to_access_webapp(cls, user_id: str, app_id=None, app_code=None) -> bool: - if not app_id and not app_code: - raise ValueError("Either app_id or app_code must be provided.") + def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool: + params = {"userId": user_id, "appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) - return EnterpriseRequest.send_request( - "GET", f"/web-app/allowed?appId={app_id}&appCode={app_code}&userId={user_id}" - ) + return data.get("result", False) @classmethod - def get_web_app_settings(cls, app_code: str = None, app_id: str = None): - if not app_code and not app_id: - raise ValueError("Either app_code or app_id must be provided.") - data = EnterpriseRequest.send_request("GET", f"/web-app/settings?appCode={app_code}&appId={app_id}") + def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: + if not app_id: + raise ValueError("app_id must be provided.") + params = {"appId": app_id} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) + + @classmethod + def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings: + if not app_code: + raise ValueError("app_code must be provided.") + params = {"appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) if not data: raise ValueError("No data found.") return WebAppSettings(**data) diff --git a/api/services/enterprise/mail_service.py b/api/services/enterprise/mail_service.py index 24b22008a1..630e7679ac 100644 --- a/api/services/enterprise/mail_service.py +++ b/api/services/enterprise/mail_service.py @@ -1,26 +1,18 @@ - -from typing import Dict, List - from pydantic import BaseModel from tasks.mail_enterprise_task import send_enterprise_email_task class DifyMail(BaseModel): - to: List[str] + to: list[str] subject: str body: str - substitutions: Dict[str, str] = {} + substitutions: dict[str, str] = {} class EnterpriseMailService: - @classmethod def send_mail(cls, mail: DifyMail): - send_enterprise_email_task.delay( - to=mail.to, - subject=mail.subject, - body=mail.body, - substitutions=mail.substitutions + to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions ) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index f37f4e0d92..c38c9cb72b 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -205,8 +205,12 @@ class FeatureService: if "WebAppAuth" in enterprise_info: features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False) - features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get("allowEmailCodeLogin", False) - features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get("allowEmailPasswordLogin", False) + features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get( + "allowEmailCodeLogin", False + ) + features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get( + "allowEmailPasswordLogin", False + ) if "License" in enterprise_info: license_info = enterprise_info["License"] diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 24d1177d87..2f3ef5d97c 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -5,8 +5,7 @@ from typing import Any, Optional, cast from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config -from controllers.web.error import (WebAppAuthFailedError, - WebAppAuthRequiredError) +from controllers.web.error import WebAppAuthFailedError from extensions.ext_database import db from libs.helper import TokenManager from libs.passport import PassportService @@ -14,8 +13,7 @@ from libs.password import compare_password from models.account import Account, AccountStatus from models.model import App, EndUser, Site from services.enterprise.enterprise_service import EnterpriseService -from services.errors.account import (AccountLoginError, AccountNotFoundError, - AccountPasswordError) +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.feature_service import FeatureService from tasks.mail_email_code_login import send_email_code_login_mail_task @@ -99,7 +97,7 @@ class WebAppAuthService: is_anonymous=False, session_id=email, name="enterpriseuser", - external_user_id="enterpriseuser" + external_user_id="enterpriseuser", ) db.session.add(end_user) db.session.commit() @@ -112,9 +110,8 @@ class WebAppAuthService: system_features = FeatureService.get_system_features() if system_features.webapp_auth.enabled: app_settings = EnterpriseService.get_web_app_settings(app_code=app_code) - if not app_settings or not app_settings.access_mode == "public": - raise WebAppAuthRequiredError() - if app_settings.access_mode == "private" and not EnterpriseService.is_user_allowed_to_access_webapp( + + if app_settings.access_mode != "public" and not EnterpriseService.is_user_allowed_to_access_webapp( account.id, app_code=app_code ): raise WebAppAuthFailedError() diff --git a/api/tasks/mail_enterprise_task.py b/api/tasks/mail_enterprise_task.py index 67475185db..b9d8fd55df 100644 --- a/api/tasks/mail_enterprise_task.py +++ b/api/tasks/mail_enterprise_task.py @@ -13,9 +13,7 @@ def send_enterprise_email_task(to, subject, body, substitutions): if not mail.is_inited(): return - logging.info( - click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green") - ) + logging.info(click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green")) start_at = time.perf_counter() try: @@ -29,9 +27,7 @@ def send_enterprise_email_task(to, subject, body, substitutions): end_at = time.perf_counter() logging.info( - click.style( - "Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" - ) + click.style("Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green") ) except Exception: logging.exception("Send enterprise mail to {} failed".format(to))