mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-31 21:51:59 +08:00
fix: adjust enterprise api
This commit is contained in:
parent
5e50570739
commit
e9a207b38e
@ -17,15 +17,13 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from core.ops.ops_trace_manager import OpsTraceManager
|
from core.ops.ops_trace_manager import OpsTraceManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import (
|
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||||
app_detail_fields,
|
|
||||||
app_detail_fields_with_site,
|
|
||||||
app_pagination_fields,
|
|
||||||
)
|
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models import Account, App
|
from models import Account, App
|
||||||
from services.app_dsl_service import AppDslService, ImportMode
|
from services.app_dsl_service import AppDslService, ImportMode
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||||
|
|
||||||
@ -67,7 +65,12 @@ class AppListApi(Resource):
|
|||||||
if not app_pagination:
|
if not app_pagination:
|
||||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||||
|
|
||||||
return marshal(app_pagination, app_pagination_fields)
|
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||||
|
for app in app_pagination.items:
|
||||||
|
app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app.id))
|
||||||
|
app.access_mode = app_setting.access_mode
|
||||||
|
|
||||||
|
return marshal(app_pagination, app_pagination_fields), 200
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -111,6 +114,10 @@ class AppApi(Resource):
|
|||||||
|
|
||||||
app_model = app_service.get_app(app_model)
|
app_model = app_service.get_app(app_model)
|
||||||
|
|
||||||
|
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||||
|
app_setting = EnterpriseService.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||||
|
app_model.access_mode = app_setting.access_mode
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@ -5,5 +5,5 @@ from libs.external_api import ExternalApi
|
|||||||
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
|
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
|
||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
from .workspace import workspace
|
|
||||||
from . import mail
|
from . import mail
|
||||||
|
from .workspace import workspace
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from flask_restful import Resource # type: ignore
|
from flask_restful import (
|
||||||
from flask_restful import reqparse
|
Resource, # type: ignore
|
||||||
|
reqparse,
|
||||||
|
)
|
||||||
|
|
||||||
from controllers.console.wraps import setup_required
|
from controllers.console.wraps import setup_required
|
||||||
from controllers.inner_api import api
|
from controllers.inner_api import api
|
||||||
@ -12,7 +14,7 @@ class EnterpriseMail(Resource):
|
|||||||
@inner_api_only
|
@inner_api_only
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("to", type=str, action='append', required=True)
|
parser.add_argument("to", type=str, action="append", required=True)
|
||||||
parser.add_argument("subject", type=str, required=True)
|
parser.add_argument("subject", type=str, required=True)
|
||||||
parser.add_argument("body", type=str, required=True)
|
parser.add_argument("body", type=str, required=True)
|
||||||
parser.add_argument("substitutions", type=dict, required=False)
|
parser.add_argument("substitutions", type=dict, required=False)
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
import flask_login
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
from jwt import InvalidTokenError # type: ignore
|
from jwt import InvalidTokenError # type: ignore
|
||||||
@ -8,14 +5,11 @@ from web import api
|
|||||||
from werkzeug.exceptions import BadRequest
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.console.auth.error import (EmailCodeError,
|
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
|
||||||
EmailOrPasswordMismatchError,
|
|
||||||
InvalidEmailError)
|
|
||||||
from controllers.console.error import AccountBannedError, AccountNotFound
|
from controllers.console.error import AccountBannedError, AccountNotFound
|
||||||
from controllers.console.wraps import setup_required
|
from controllers.console.wraps import setup_required
|
||||||
from libs.helper import email
|
from libs.helper import email
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.account import Account
|
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.webapp_auth_service import WebAppAuthService
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
@ -51,14 +45,14 @@ class LoginApi(Resource):
|
|||||||
return {"result": "success", "token": token}
|
return {"result": "success", "token": token}
|
||||||
|
|
||||||
|
|
||||||
class LogoutApi(Resource):
|
# class LogoutApi(Resource):
|
||||||
@setup_required
|
# @setup_required
|
||||||
def get(self):
|
# def get(self):
|
||||||
account = cast(Account, flask_login.current_user)
|
# account = cast(Account, flask_login.current_user)
|
||||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
# if isinstance(account, flask_login.AnonymousUserMixin):
|
||||||
return {"result": "success"}
|
# return {"result": "success"}
|
||||||
flask_login.logout_user()
|
# flask_login.logout_user()
|
||||||
return {"result": "success"}
|
# return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
class EmailCodeLoginSendEmailApi(Resource):
|
class EmailCodeLoginSendEmailApi(Resource):
|
||||||
@ -122,6 +116,6 @@ class EmailCodeLoginApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
api.add_resource(LoginApi, "/login")
|
api.add_resource(LoginApi, "/login")
|
||||||
api.add_resource(LogoutApi, "/logout")
|
# api.add_resource(LogoutApi, "/logout")
|
||||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||||
|
@ -4,8 +4,7 @@ from flask import request
|
|||||||
from flask_restful import Resource # type: ignore
|
from flask_restful import Resource # type: ignore
|
||||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||||
|
|
||||||
from controllers.web.error import (WebAppAuthFailedError,
|
from controllers.web.error import WebAppAuthFailedError, WebAppAuthRequiredError
|
||||||
WebAppAuthRequiredError)
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
from models.model import App, EndUser, Site
|
from models.model import App, EndUser, Site
|
||||||
@ -61,7 +60,9 @@ def decode_jwt_token():
|
|||||||
# for enterprise webapp auth
|
# for enterprise webapp auth
|
||||||
app_web_auth_enabled = False
|
app_web_auth_enabled = False
|
||||||
if system_features.webapp_auth.enabled:
|
if system_features.webapp_auth.enabled:
|
||||||
app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private"
|
app_web_auth_enabled = (
|
||||||
|
EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
|
||||||
|
)
|
||||||
|
|
||||||
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||||
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||||
@ -69,7 +70,9 @@ def decode_jwt_token():
|
|||||||
return app_model, end_user
|
return app_model, end_user
|
||||||
except Unauthorized as e:
|
except Unauthorized as e:
|
||||||
if system_features.webapp_auth.enabled:
|
if system_features.webapp_auth.enabled:
|
||||||
app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private"
|
app_web_auth_enabled = (
|
||||||
|
EnterpriseService.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
|
||||||
|
)
|
||||||
if app_web_auth_enabled:
|
if app_web_auth_enabled:
|
||||||
raise WebAppAuthRequiredError()
|
raise WebAppAuthRequiredError()
|
||||||
|
|
||||||
@ -77,7 +80,8 @@ def decode_jwt_token():
|
|||||||
|
|
||||||
|
|
||||||
def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
|
def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
|
||||||
# Check if authentication is enforced for web app, and if the token source is not webapp, raise an error and redirect to login
|
# Check if authentication is enforced for web app, and if the token source is not webapp,
|
||||||
|
# raise an error and redirect to login
|
||||||
if system_webapp_auth_enabled and app_web_auth_enabled:
|
if system_webapp_auth_enabled and app_web_auth_enabled:
|
||||||
source = decoded.get("token_source")
|
source = decoded.get("token_source")
|
||||||
if not source or source != "webapp":
|
if not source or source != "webapp":
|
||||||
|
@ -63,6 +63,7 @@ app_detail_fields = {
|
|||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"updated_by": fields.String,
|
"updated_by": fields.String,
|
||||||
"updated_at": TimestampField,
|
"updated_at": TimestampField,
|
||||||
|
"access_mode": fields.String,
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt_config_fields = {
|
prompt_config_fields = {
|
||||||
@ -98,6 +99,7 @@ app_partial_fields = {
|
|||||||
"updated_by": fields.String,
|
"updated_by": fields.String,
|
||||||
"updated_at": TimestampField,
|
"updated_at": TimestampField,
|
||||||
"tags": fields.List(fields.Nested(tag_fields)),
|
"tags": fields.List(fields.Nested(tag_fields)),
|
||||||
|
"access_mode": fields.String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -170,6 +172,7 @@ app_detail_fields_with_site = {
|
|||||||
"updated_by": fields.String,
|
"updated_by": fields.String,
|
||||||
"updated_at": TimestampField,
|
"updated_at": TimestampField,
|
||||||
"deleted_tools": fields.List(fields.String),
|
"deleted_tools": fields.List(fields.String),
|
||||||
|
"access_mode": fields.String,
|
||||||
}
|
}
|
||||||
|
|
||||||
app_site_fields = {
|
app_site_fields = {
|
||||||
|
@ -7,7 +7,7 @@ class WebAppSettings(BaseModel):
|
|||||||
access_mode: str = Field(
|
access_mode: str = Field(
|
||||||
description="Access mode for the web app. Can be 'public' or 'private'",
|
description="Access mode for the web app. Can be 'public' or 'private'",
|
||||||
default="private",
|
default="private",
|
||||||
alias="access_mode",
|
alias="accessMode",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -17,19 +17,28 @@ class EnterpriseService:
|
|||||||
return EnterpriseRequest.send_request("GET", "/info")
|
return EnterpriseRequest.send_request("GET", "/info")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id=None, app_code=None) -> bool:
|
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool:
|
||||||
if not app_id and not app_code:
|
params = {"userId": user_id, "appCode": app_code}
|
||||||
raise ValueError("Either app_id or app_code must be provided.")
|
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
|
||||||
|
|
||||||
return EnterpriseRequest.send_request(
|
return data.get("result", False)
|
||||||
"GET", f"/web-app/allowed?appId={app_id}&appCode={app_code}&userId={user_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_web_app_settings(cls, app_code: str = None, app_id: str = None):
|
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
|
||||||
if not app_code and not app_id:
|
if not app_id:
|
||||||
raise ValueError("Either app_code or app_id must be provided.")
|
raise ValueError("app_id must be provided.")
|
||||||
data = EnterpriseRequest.send_request("GET", f"/web-app/settings?appCode={app_code}&appId={app_id}")
|
params = {"appId": app_id}
|
||||||
|
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
|
||||||
|
if not data:
|
||||||
|
raise ValueError("No data found.")
|
||||||
|
return WebAppSettings(**data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings:
|
||||||
|
if not app_code:
|
||||||
|
raise ValueError("app_code must be provided.")
|
||||||
|
params = {"appCode": app_code}
|
||||||
|
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
|
||||||
if not data:
|
if not data:
|
||||||
raise ValueError("No data found.")
|
raise ValueError("No data found.")
|
||||||
return WebAppSettings(**data)
|
return WebAppSettings(**data)
|
||||||
|
@ -1,26 +1,18 @@
|
|||||||
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from tasks.mail_enterprise_task import send_enterprise_email_task
|
from tasks.mail_enterprise_task import send_enterprise_email_task
|
||||||
|
|
||||||
|
|
||||||
class DifyMail(BaseModel):
|
class DifyMail(BaseModel):
|
||||||
to: List[str]
|
to: list[str]
|
||||||
subject: str
|
subject: str
|
||||||
body: str
|
body: str
|
||||||
substitutions: Dict[str, str] = {}
|
substitutions: dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseMailService:
|
class EnterpriseMailService:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def send_mail(cls, mail: DifyMail):
|
def send_mail(cls, mail: DifyMail):
|
||||||
|
|
||||||
send_enterprise_email_task.delay(
|
send_enterprise_email_task.delay(
|
||||||
to=mail.to,
|
to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions
|
||||||
subject=mail.subject,
|
|
||||||
body=mail.body,
|
|
||||||
substitutions=mail.substitutions
|
|
||||||
)
|
)
|
||||||
|
@ -205,8 +205,12 @@ class FeatureService:
|
|||||||
|
|
||||||
if "WebAppAuth" in enterprise_info:
|
if "WebAppAuth" in enterprise_info:
|
||||||
features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False)
|
features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False)
|
||||||
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get("allowEmailCodeLogin", False)
|
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get(
|
||||||
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get("allowEmailPasswordLogin", False)
|
"allowEmailCodeLogin", False
|
||||||
|
)
|
||||||
|
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get(
|
||||||
|
"allowEmailPasswordLogin", False
|
||||||
|
)
|
||||||
|
|
||||||
if "License" in enterprise_info:
|
if "License" in enterprise_info:
|
||||||
license_info = enterprise_info["License"]
|
license_info = enterprise_info["License"]
|
||||||
|
@ -5,8 +5,7 @@ from typing import Any, Optional, cast
|
|||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.web.error import (WebAppAuthFailedError,
|
from controllers.web.error import WebAppAuthFailedError
|
||||||
WebAppAuthRequiredError)
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TokenManager
|
from libs.helper import TokenManager
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
@ -14,8 +13,7 @@ from libs.password import compare_password
|
|||||||
from models.account import Account, AccountStatus
|
from models.account import Account, AccountStatus
|
||||||
from models.model import App, EndUser, Site
|
from models.model import App, EndUser, Site
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.errors.account import (AccountLoginError, AccountNotFoundError,
|
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||||
AccountPasswordError)
|
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
||||||
|
|
||||||
@ -99,7 +97,7 @@ class WebAppAuthService:
|
|||||||
is_anonymous=False,
|
is_anonymous=False,
|
||||||
session_id=email,
|
session_id=email,
|
||||||
name="enterpriseuser",
|
name="enterpriseuser",
|
||||||
external_user_id="enterpriseuser"
|
external_user_id="enterpriseuser",
|
||||||
)
|
)
|
||||||
db.session.add(end_user)
|
db.session.add(end_user)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -112,9 +110,8 @@ class WebAppAuthService:
|
|||||||
system_features = FeatureService.get_system_features()
|
system_features = FeatureService.get_system_features()
|
||||||
if system_features.webapp_auth.enabled:
|
if system_features.webapp_auth.enabled:
|
||||||
app_settings = EnterpriseService.get_web_app_settings(app_code=app_code)
|
app_settings = EnterpriseService.get_web_app_settings(app_code=app_code)
|
||||||
if not app_settings or not app_settings.access_mode == "public":
|
|
||||||
raise WebAppAuthRequiredError()
|
if app_settings.access_mode != "public" and not EnterpriseService.is_user_allowed_to_access_webapp(
|
||||||
if app_settings.access_mode == "private" and not EnterpriseService.is_user_allowed_to_access_webapp(
|
|
||||||
account.id, app_code=app_code
|
account.id, app_code=app_code
|
||||||
):
|
):
|
||||||
raise WebAppAuthFailedError()
|
raise WebAppAuthFailedError()
|
||||||
|
@ -13,9 +13,7 @@ def send_enterprise_email_task(to, subject, body, substitutions):
|
|||||||
if not mail.is_inited():
|
if not mail.is_inited():
|
||||||
return
|
return
|
||||||
|
|
||||||
logging.info(
|
logging.info(click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green"))
|
||||||
click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green")
|
|
||||||
)
|
|
||||||
start_at = time.perf_counter()
|
start_at = time.perf_counter()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -29,9 +27,7 @@ def send_enterprise_email_task(to, subject, body, substitutions):
|
|||||||
|
|
||||||
end_at = time.perf_counter()
|
end_at = time.perf_counter()
|
||||||
logging.info(
|
logging.info(
|
||||||
click.style(
|
click.style("Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green")
|
||||||
"Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("Send enterprise mail to {} failed".format(to))
|
logging.exception("Send enterprise mail to {} failed".format(to))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user