refactor: enhance API token validation with session locking and last used timestamp update (#12426)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-01-07 18:04:41 +08:00 committed by GitHub
parent 41f39bf3fc
commit 0eeacdc80c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 34 additions and 22 deletions

View File

@ -1,5 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import UTC, datetime from datetime import UTC, datetime, timedelta
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
@ -8,6 +8,8 @@ from flask import current_app, request
from flask_login import user_logged_in # type: ignore from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore from flask_restful import Resource # type: ignore
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db from extensions.ext_database import db
@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
return decorator return decorator
def validate_and_get_api_token(scope=None): def validate_and_get_api_token(scope: str | None = None):
""" """
Validate and get API token. Validate and get API token.
""" """
@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
if auth_scheme != "bearer": if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'") raise Unauthorized("Authorization scheme must be 'Bearer'")
api_token = ( current_time = datetime.now(UTC).replace(tzinfo=None)
db.session.query(ApiToken) cutoff_time = current_time - timedelta(minutes=1)
.filter( with Session(db.engine, expire_on_commit=False) as session:
ApiToken.token == auth_token, update_stmt = (
ApiToken.type == scope, update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.values(last_used_at=current_time)
.returning(ApiToken)
) )
.first() result = session.execute(update_stmt)
) api_token = result.scalar_one_or_none()
if not api_token: if not api_token:
raise Unauthorized("Access token is invalid") stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None) if not api_token:
db.session.commit() raise Unauthorized("Access token is invalid")
else:
session.commit()
return api_token return api_token

View File

@ -33,6 +33,7 @@ else
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \ --workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \ --worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \ --timeout ${GUNICORN_TIMEOUT:-200} \
app:app app:app
fi fi

View File

@ -1,5 +1,5 @@
import os import os
from typing import Optional from typing import Literal, Optional
import httpx import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@ -17,7 +17,6 @@ class BillingService:
params = {"tenant_id": tenant_id} params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params) billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info return billing_info
@classmethod @classmethod
@ -47,12 +46,13 @@ class BillingService:
retry=retry_if_exception_type(httpx.RequestError), retry=retry_if_exception_type(httpx.RequestError),
reraise=True, reraise=True,
) )
def _send_request(cls, method, endpoint, json=None, params=None): def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}" url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers) response = httpx.request(method, url, json=json, params=params, headers=headers)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json() return response.json()
@staticmethod @staticmethod

View File

@ -126,10 +126,13 @@ DIFY_PORT=5001
# The number of API server workers, i.e., the number of workers. # The number of API server workers, i.e., the number of workers.
# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent # Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent
# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers # Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers
SERVER_WORKER_AMOUNT= SERVER_WORKER_AMOUNT=1
# Defaults to gevent. If using windows, it can be switched to sync or solo. # Defaults to gevent. If using windows, it can be switched to sync or solo.
SERVER_WORKER_CLASS= SERVER_WORKER_CLASS=gevent
# Default number of worker connections, the default is 10.
SERVER_WORKER_CONNECTIONS=10
# Similar to SERVER_WORKER_CLASS. # Similar to SERVER_WORKER_CLASS.
# If using windows, it can be switched to sync or solo. # If using windows, it can be switched to sync or solo.

View File

@ -32,8 +32,9 @@ x-shared-env: &shared-api-worker-env
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
DIFY_PORT: ${DIFY_PORT:-5001} DIFY_PORT: ${DIFY_PORT:-5001}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-} SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-} SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent}
SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10}
CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-} CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-}
GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360} GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360}
CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-} CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}