From 0eeacdc80cd70201e1aa09f301e14857ded99a7a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 7 Jan 2025 18:04:41 +0800 Subject: [PATCH] refactor: enhance API token validation with session locking and last used timestamp update (#12426) Signed-off-by: -LAN- --- api/controllers/service_api/wraps.py | 35 +++++++++++++++++----------- api/docker/entrypoint.sh | 1 + api/services/billing_service.py | 8 +++---- docker/.env.example | 7 ++++-- docker/docker-compose.yaml | 5 ++-- 5 files changed, 34 insertions(+), 22 deletions(-) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 740b92ef8e..976db1eb46 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from datetime import UTC, datetime +from datetime import UTC, datetime, timedelta from enum import Enum from functools import wraps from typing import Optional @@ -8,6 +8,8 @@ from flask import current_app, request from flask_login import user_logged_in # type: ignore from flask_restful import Resource # type: ignore from pydantic import BaseModel +from sqlalchemy import select, update +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, Unauthorized from extensions.ext_database import db @@ -174,7 +176,7 @@ def validate_dataset_token(view=None): return decorator -def validate_and_get_api_token(scope=None): +def validate_and_get_api_token(scope: str | None = None): """ Validate and get API token. """ @@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None): if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - api_token = ( - db.session.query(ApiToken) - .filter( - ApiToken.token == auth_token, - ApiToken.type == scope, + current_time = datetime.now(UTC).replace(tzinfo=None) + cutoff_time = current_time - timedelta(minutes=1) + with Session(db.engine, expire_on_commit=False) as session: + update_stmt = ( + update(ApiToken) + .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope) + .values(last_used_at=current_time) + .returning(ApiToken) ) - .first() - ) + result = session.execute(update_stmt) + api_token = result.scalar_one_or_none() - if not api_token: - raise Unauthorized("Access token is invalid") - - api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + if not api_token: + stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) + api_token = session.scalar(stmt) + if not api_token: + raise Unauthorized("Access token is invalid") + else: + session.commit() return api_token diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 881263171f..f0c6ca61d9 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -33,6 +33,7 @@ else --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --workers ${SERVER_WORKER_AMOUNT:-1} \ --worker-class ${SERVER_WORKER_CLASS:-gevent} \ + --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \ --timeout ${GUNICORN_TIMEOUT:-200} \ app:app fi diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 3a13c10102..0d50a2aa8c 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Literal, Optional import httpx from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed @@ -17,7 +17,6 @@ class BillingService: params = {"tenant_id": tenant_id} billing_info = cls._send_request("GET", "/subscription/info", params=params) - return billing_info @classmethod @@ -47,12 +46,13 @@ class BillingService: retry=retry_if_exception_type(httpx.RequestError), reraise=True, ) - def _send_request(cls, method, endpoint, json=None, params=None): + def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None): headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = httpx.request(method, url, json=json, params=params, headers=headers) - + if method == "GET" and response.status_code != httpx.codes.OK: + raise ValueError("Unable to retrieve billing information. Please try again later or contact support.") return response.json() @staticmethod diff --git a/docker/.env.example b/docker/.env.example index 05f7aba9bd..7c5447ef5b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -126,10 +126,13 @@ DIFY_PORT=5001 # The number of API server workers, i.e., the number of workers. # Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent # Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers -SERVER_WORKER_AMOUNT= +SERVER_WORKER_AMOUNT=1 # Defaults to gevent. If using windows, it can be switched to sync or solo. -SERVER_WORKER_CLASS= +SERVER_WORKER_CLASS=gevent + +# Default number of worker connections, the default is 10. +SERVER_WORKER_CONNECTIONS=10 # Similar to SERVER_WORKER_CLASS. # If using windows, it can be switched to sync or solo. diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 922e42fec5..554118a4a5 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -32,8 +32,9 @@ x-shared-env: &shared-api-worker-env APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} DIFY_PORT: ${DIFY_PORT:-5001} - SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-} - SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-} + SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1} + SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent} + SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10} CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-} GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360} CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}