diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 20e071c834..11db386b2e 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,12 +1,18 @@ -from flask_restful import marshal_with # type: ignore +import logging + +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource +from libs.passport import PassportService from models.model import App, AppMode from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService class AppParameterApi(WebApiResource): @@ -42,5 +48,55 @@ class AppMeta(WebApiResource): return AppService().get_app_meta(app_model) +class AppAccessMode(Resource): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("appId", type=str, required=True, location="args") + args = parser.parse_args() + + app_id = args["appId"] + res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) + + return {"accessMode": res.access_mode} + + +class AppWebAuthPermission(Resource): + def get(self): + user_id = "visitor" + try: + auth_header = request.headers.get("Authorization") + if auth_header is None: + raise + if " " not in auth_header: + raise + + auth_scheme, tk = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise + + decoded = PassportService().verify(tk) + user_id = decoded.get("user_id", "visitor") + except Exception as e: + pass + + parser = reqparse.RequestParser() + parser.add_argument("appId", type=str, required=True, location="args") + args = parser.parse_args() + + app_id = args["appId"] + user_id = current_user.id + logging.info(f"App ID: {app_id}, User ID: {user_id}") + + app_code = AppService.get_app_code_by_id(app_id) + logging.info(f"App code: {app_code}") + + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) + return {"result": res} + + api.add_resource(AppParameterApi, "/parameters") api.add_resource(AppMeta, "/meta") +# webapp auth apis +api.add_resource(AppAccessMode, "/webapp/access-mode") +api.add_resource(AppWebAuthPermission, "/webapp/permission") diff --git a/api/services/app_service.py b/api/services/app_service.py index 9359bb2844..e6a1ae32a9 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -19,7 +19,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Site from models.tools import ApiToolProvider from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -384,3 +384,15 @@ class AppService: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} return meta + + @staticmethod + def get_app_code_by_id(app_id: str) -> str: + """ + Get app code by app id + :param app_id: app id + :return: app code + """ + site = db.session.query(Site).filter(Site.app_id == app_id).first() + if not site: + raise ValueError(f"App with id {app_id} not found") + return str(site.code) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index e44e7f6658..dd0857745e 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -45,12 +45,12 @@ class EnterpriseService: if not data: raise ValueError("No data found.") - if not isinstance(data['accessModes'], dict): + if not isinstance(data["accessModes"], dict): logging.info(f"Batch get app access mode by id returns data: {data}") raise ValueError("Invalid data format.") ret = {} - for key, value in data['accessModes'].items(): + for key, value in data["accessModes"].items(): curr = WebAppSettings() curr.access_mode = value ret[key] = curr