mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-20 18:39:06 +08:00
fix: optimize DEFAULT-USER
This commit is contained in:
parent
2cb640de15
commit
97a3727962
@ -5,6 +5,7 @@ from typing import Optional
|
||||
from flask import request
|
||||
from flask_restful import reqparse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant
|
||||
@ -12,19 +13,29 @@ from models.model import EndUser
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
def get_user(user_id: str | None) -> Account | EndUser:
|
||||
def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
|
||||
try:
|
||||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
with Session(db.engine) as session:
|
||||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
if user_id == "DEFAULT-USER":
|
||||
user_model = db.session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
|
||||
else:
|
||||
user_model = AccountService.load_user(user_id)
|
||||
if not user_model:
|
||||
user_model = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
if not user_model:
|
||||
raise ValueError("user not found")
|
||||
if user_id == "DEFAULT-USER":
|
||||
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
type="service_api",
|
||||
is_anonymous=True if user_id == "DEFAULT-USER" else False,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
else:
|
||||
user_model = AccountService.load_user(user_id)
|
||||
if not user_model:
|
||||
user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
if not user_model:
|
||||
raise ValueError("user not found")
|
||||
except Exception:
|
||||
raise ValueError("user not found")
|
||||
|
||||
@ -45,6 +56,12 @@ def get_user_tenant(view: Optional[Callable] = None):
|
||||
user_id = kwargs.get("user_id")
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required")
|
||||
|
||||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
del kwargs["tenant_id"]
|
||||
del kwargs["user_id"]
|
||||
|
||||
@ -63,7 +80,7 @@ def get_user_tenant(view: Optional[Callable] = None):
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["user_model"] = get_user(user_id)
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
|
@ -1260,7 +1260,7 @@ class OperationLog(Base):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class EndUser(UserMixin, Base):
|
||||
class EndUser(Base, UserMixin):
|
||||
__tablename__ = "end_users"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
|
||||
|
Loading…
x
Reference in New Issue
Block a user