mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-31 00:22:01 +08:00
fix: adjust enterprise api
This commit is contained in:
parent
5e50570739
commit
e9a207b38e
@ -17,15 +17,13 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (
|
||||
app_detail_fields,
|
||||
app_detail_fields_with_site,
|
||||
app_pagination_fields,
|
||||
)
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
@ -67,7 +65,12 @@ class AppListApi(Resource):
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields)
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
for app in app_pagination.items:
|
||||
app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app.id))
|
||||
app.access_mode = app_setting.access_mode
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -111,6 +114,10 @@ class AppApi(Resource):
|
||||
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode
|
||||
|
||||
return app_model
|
||||
|
||||
@setup_required
|
||||
|
@ -5,5 +5,5 @@ from libs.external_api import ExternalApi
|
||||
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from .workspace import workspace
|
||||
from . import mail
|
||||
from .workspace import workspace
|
||||
|
@ -1,5 +1,7 @@
|
||||
from flask_restful import Resource # type: ignore
|
||||
from flask_restful import reqparse
|
||||
from flask_restful import (
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
@ -12,7 +14,7 @@ class EnterpriseMail(Resource):
|
||||
@inner_api_only
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("to", type=str, action='append', required=True)
|
||||
parser.add_argument("to", type=str, action="append", required=True)
|
||||
parser.add_argument("subject", type=str, required=True)
|
||||
parser.add_argument("body", type=str, required=True)
|
||||
parser.add_argument("substitutions", type=dict, required=False)
|
||||
|
@ -1,6 +1,3 @@
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from jwt import InvalidTokenError # type: ignore
|
||||
@ -8,14 +5,11 @@ from web import api
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import (EmailCodeError,
|
||||
EmailOrPasswordMismatchError,
|
||||
InvalidEmailError)
|
||||
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
|
||||
from controllers.console.error import AccountBannedError, AccountNotFound
|
||||
from controllers.console.wraps import setup_required
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
@ -51,14 +45,14 @@ class LoginApi(Resource):
|
||||
return {"result": "success", "token": token}
|
||||
|
||||
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
return {"result": "success"}
|
||||
flask_login.logout_user()
|
||||
return {"result": "success"}
|
||||
# class LogoutApi(Resource):
|
||||
# @setup_required
|
||||
# def get(self):
|
||||
# account = cast(Account, flask_login.current_user)
|
||||
# if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
# return {"result": "success"}
|
||||
# flask_login.logout_user()
|
||||
# return {"result": "success"}
|
||||
|
||||
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@ -122,6 +116,6 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
api.add_resource(LogoutApi, "/logout")
|
||||
# api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
|
@ -4,8 +4,7 @@ from flask import request
|
||||
from flask_restful import Resource # type: ignore
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import (WebAppAuthFailedError,
|
||||
WebAppAuthRequiredError)
|
||||
from controllers.web.error import WebAppAuthFailedError, WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.model import App, EndUser, Site
|
||||
@ -61,7 +60,9 @@ def decode_jwt_token():
|
||||
# for enterprise webapp auth
|
||||
app_web_auth_enabled = False
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private"
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
|
||||
)
|
||||
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
@ -69,7 +70,9 @@ def decode_jwt_token():
|
||||
return app_model, end_user
|
||||
except Unauthorized as e:
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private"
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
|
||||
)
|
||||
if app_web_auth_enabled:
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
@ -77,7 +80,8 @@ def decode_jwt_token():
|
||||
|
||||
|
||||
def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
|
||||
# Check if authentication is enforced for web app, and if the token source is not webapp, raise an error and redirect to login
|
||||
# Check if authentication is enforced for web app, and if the token source is not webapp,
|
||||
# raise an error and redirect to login
|
||||
if system_webapp_auth_enabled and app_web_auth_enabled:
|
||||
source = decoded.get("token_source")
|
||||
if not source or source != "webapp":
|
||||
|
@ -63,6 +63,7 @@ app_detail_fields = {
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"access_mode": fields.String,
|
||||
}
|
||||
|
||||
prompt_config_fields = {
|
||||
@ -98,6 +99,7 @@ app_partial_fields = {
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
"access_mode": fields.String,
|
||||
}
|
||||
|
||||
|
||||
@ -170,6 +172,7 @@ app_detail_fields_with_site = {
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.String),
|
||||
"access_mode": fields.String,
|
||||
}
|
||||
|
||||
app_site_fields = {
|
||||
|
@ -7,7 +7,7 @@ class WebAppSettings(BaseModel):
|
||||
access_mode: str = Field(
|
||||
description="Access mode for the web app. Can be 'public' or 'private'",
|
||||
default="private",
|
||||
alias="access_mode",
|
||||
alias="accessMode",
|
||||
)
|
||||
|
||||
|
||||
@ -17,19 +17,28 @@ class EnterpriseService:
|
||||
return EnterpriseRequest.send_request("GET", "/info")
|
||||
|
||||
@classmethod
|
||||
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id=None, app_code=None) -> bool:
|
||||
if not app_id and not app_code:
|
||||
raise ValueError("Either app_id or app_code must be provided.")
|
||||
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool:
|
||||
params = {"userId": user_id, "appCode": app_code}
|
||||
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
|
||||
|
||||
return EnterpriseRequest.send_request(
|
||||
"GET", f"/web-app/allowed?appId={app_id}&appCode={app_code}&userId={user_id}"
|
||||
)
|
||||
return data.get("result", False)
|
||||
|
||||
@classmethod
|
||||
def get_web_app_settings(cls, app_code: str = None, app_id: str = None):
|
||||
if not app_code and not app_id:
|
||||
raise ValueError("Either app_code or app_id must be provided.")
|
||||
data = EnterpriseRequest.send_request("GET", f"/web-app/settings?appCode={app_code}&appId={app_id}")
|
||||
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
|
||||
if not app_id:
|
||||
raise ValueError("app_id must be provided.")
|
||||
params = {"appId": app_id}
|
||||
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
|
||||
if not data:
|
||||
raise ValueError("No data found.")
|
||||
return WebAppSettings(**data)
|
||||
|
||||
@classmethod
|
||||
def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings:
|
||||
if not app_code:
|
||||
raise ValueError("app_code must be provided.")
|
||||
params = {"appCode": app_code}
|
||||
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
|
||||
if not data:
|
||||
raise ValueError("No data found.")
|
||||
return WebAppSettings(**data)
|
||||
|
@ -1,26 +1,18 @@
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tasks.mail_enterprise_task import send_enterprise_email_task
|
||||
|
||||
|
||||
class DifyMail(BaseModel):
|
||||
to: List[str]
|
||||
to: list[str]
|
||||
subject: str
|
||||
body: str
|
||||
substitutions: Dict[str, str] = {}
|
||||
substitutions: dict[str, str] = {}
|
||||
|
||||
|
||||
class EnterpriseMailService:
|
||||
|
||||
@classmethod
|
||||
def send_mail(cls, mail: DifyMail):
|
||||
|
||||
send_enterprise_email_task.delay(
|
||||
to=mail.to,
|
||||
subject=mail.subject,
|
||||
body=mail.body,
|
||||
substitutions=mail.substitutions
|
||||
to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions
|
||||
)
|
||||
|
@ -205,8 +205,12 @@ class FeatureService:
|
||||
|
||||
if "WebAppAuth" in enterprise_info:
|
||||
features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False)
|
||||
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get("allowEmailCodeLogin", False)
|
||||
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get("allowEmailPasswordLogin", False)
|
||||
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get(
|
||||
"allowEmailCodeLogin", False
|
||||
)
|
||||
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get(
|
||||
"allowEmailPasswordLogin", False
|
||||
)
|
||||
|
||||
if "License" in enterprise_info:
|
||||
license_info = enterprise_info["License"]
|
||||
|
@ -5,8 +5,7 @@ from typing import Any, Optional, cast
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web.error import (WebAppAuthFailedError,
|
||||
WebAppAuthRequiredError)
|
||||
from controllers.web.error import WebAppAuthFailedError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TokenManager
|
||||
from libs.passport import PassportService
|
||||
@ -14,8 +13,7 @@ from libs.password import compare_password
|
||||
from models.account import Account, AccountStatus
|
||||
from models.model import App, EndUser, Site
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.errors.account import (AccountLoginError, AccountNotFoundError,
|
||||
AccountPasswordError)
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
||||
|
||||
@ -99,7 +97,7 @@ class WebAppAuthService:
|
||||
is_anonymous=False,
|
||||
session_id=email,
|
||||
name="enterpriseuser",
|
||||
external_user_id="enterpriseuser"
|
||||
external_user_id="enterpriseuser",
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
@ -112,9 +110,8 @@ class WebAppAuthService:
|
||||
system_features = FeatureService.get_system_features()
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_settings = EnterpriseService.get_web_app_settings(app_code=app_code)
|
||||
if not app_settings or not app_settings.access_mode == "public":
|
||||
raise WebAppAuthRequiredError()
|
||||
if app_settings.access_mode == "private" and not EnterpriseService.is_user_allowed_to_access_webapp(
|
||||
|
||||
if app_settings.access_mode != "public" and not EnterpriseService.is_user_allowed_to_access_webapp(
|
||||
account.id, app_code=app_code
|
||||
):
|
||||
raise WebAppAuthFailedError()
|
||||
|
@ -13,9 +13,7 @@ def send_enterprise_email_task(to, subject, body, substitutions):
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
logging.info(
|
||||
click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green")
|
||||
)
|
||||
logging.info(click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
@ -29,9 +27,7 @@ def send_enterprise_email_task(to, subject, body, substitutions):
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style(
|
||||
"Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green"
|
||||
)
|
||||
click.style("Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green")
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Send enterprise mail to {} failed".format(to))
|
||||
|
Loading…
x
Reference in New Issue
Block a user