mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 19:49:04 +08:00
Feat/move tenant id into db (#2341)
This commit is contained in:
parent
ecf947258a
commit
a8f23ed712
@ -8,7 +8,7 @@ from flask import current_app, request
|
|||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
from libs.helper import email
|
from libs.helper import email
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService
|
||||||
|
|
||||||
|
|
||||||
class LoginApi(Resource):
|
class LoginApi(Resource):
|
||||||
@ -30,11 +30,6 @@ class LoginApi(Resource):
|
|||||||
except services.errors.account.AccountLoginError:
|
except services.errors.account.AccountLoginError:
|
||||||
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
|
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
|
||||||
|
|
||||||
try:
|
|
||||||
TenantService.switch_tenant(account)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
AccountService.update_last_login(account, request)
|
AccountService.update_last_login(account, request)
|
||||||
|
|
||||||
# todo: return the user info
|
# todo: return the user info
|
||||||
@ -47,7 +42,6 @@ class LogoutApi(Resource):
|
|||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
def get(self):
|
def get(self):
|
||||||
flask.session.pop('workspace_id', None)
|
|
||||||
flask_login.logout_user()
|
flask_login.logout_user()
|
||||||
return {'result': 'success'}
|
return {'result': 'success'}
|
||||||
|
|
||||||
|
32
api/migrations/versions/16830a790f0f_.py
Normal file
32
api/migrations/versions/16830a790f0f_.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
"""empty message
|
||||||
|
|
||||||
|
Revision ID: 16830a790f0f
|
||||||
|
Revises: 380c6aa5a70d
|
||||||
|
Create Date: 2024-02-01 08:21:31.111119
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '16830a790f0f'
|
||||||
|
down_revision = '380c6aa5a70d'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('current', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('current')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
@ -1,6 +1,5 @@
|
|||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
from math import e
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -155,6 +154,7 @@ class TenantAccountJoin(db.Model):
|
|||||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||||
tenant_id = db.Column(UUID, nullable=False)
|
tenant_id = db.Column(UUID, nullable=False)
|
||||||
account_id = db.Column(UUID, nullable=False)
|
account_id = db.Column(UUID, nullable=False)
|
||||||
|
current = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
|
||||||
role = db.Column(db.String(16), nullable=False, server_default='normal')
|
role = db.Column(db.String(16), nullable=False, server_default='normal')
|
||||||
invited_by = db.Column(UUID, nullable=True)
|
invited_by = db.Column(UUID, nullable=True)
|
||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||||
|
@ -11,7 +11,7 @@ from typing import Any, Dict, Optional
|
|||||||
from constants.languages import language_timezone_mapping, languages
|
from constants.languages import language_timezone_mapping, languages
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from flask import current_app, session
|
from flask import current_app
|
||||||
from libs.helper import get_remote_ip
|
from libs.helper import get_remote_ip
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
from libs.password import compare_password, hash_password
|
from libs.password import compare_password, hash_password
|
||||||
@ -23,7 +23,8 @@ from services.errors.account import (AccountAlreadyInTenantError, AccountLoginEr
|
|||||||
NoPermissionError, RoleAlreadyAssignedError, TenantNotFound)
|
NoPermissionError, RoleAlreadyAssignedError, TenantNotFound)
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
||||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
from werkzeug.exceptions import Forbidden
|
||||||
|
from sqlalchemy import exc
|
||||||
|
|
||||||
|
|
||||||
def _create_tenant_for_account(account) -> Tenant:
|
def _create_tenant_for_account(account) -> Tenant:
|
||||||
@ -39,54 +40,33 @@ class AccountService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_user(user_id: str) -> Account:
|
def load_user(user_id: str) -> Account:
|
||||||
# todo: used by flask_login
|
account = Account.query.filter_by(id=user_id).first()
|
||||||
if '.' in user_id:
|
if not account:
|
||||||
tenant_id, account_id = user_id.split('.')
|
return None
|
||||||
else:
|
|
||||||
account_id = user_id
|
|
||||||
|
|
||||||
account = db.session.query(Account).filter(Account.id == account_id).first()
|
if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
|
||||||
|
|
||||||
if account:
|
|
||||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
|
||||||
raise Forbidden('Account is banned or closed.')
|
raise Forbidden('Account is banned or closed.')
|
||||||
|
|
||||||
workspace_id = session.get('workspace_id')
|
# init owner's tenant
|
||||||
if workspace_id:
|
tenant_owner = TenantAccountJoin.query.filter_by(account_id=account.id, role='owner').first()
|
||||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
if not tenant_owner:
|
||||||
TenantAccountJoin.account_id == account.id,
|
|
||||||
TenantAccountJoin.tenant_id == workspace_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not tenant_account_join:
|
|
||||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
|
||||||
TenantAccountJoin.account_id == account.id).first()
|
|
||||||
|
|
||||||
if tenant_account_join:
|
|
||||||
account.current_tenant_id = tenant_account_join.tenant_id
|
|
||||||
else:
|
|
||||||
_create_tenant_for_account(account)
|
_create_tenant_for_account(account)
|
||||||
session['workspace_id'] = account.current_tenant_id
|
|
||||||
else:
|
|
||||||
account.current_tenant_id = workspace_id
|
|
||||||
else:
|
|
||||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
|
||||||
TenantAccountJoin.account_id == account.id).first()
|
|
||||||
if tenant_account_join:
|
|
||||||
account.current_tenant_id = tenant_account_join.tenant_id
|
|
||||||
else:
|
|
||||||
_create_tenant_for_account(account)
|
|
||||||
session['workspace_id'] = account.current_tenant_id
|
|
||||||
|
|
||||||
current_time = datetime.utcnow()
|
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
|
||||||
|
if current_tenant:
|
||||||
|
account.current_tenant_id = current_tenant.tenant_id
|
||||||
|
else:
|
||||||
|
account.current_tenant_id = tenant_owner.tenant_id
|
||||||
|
tenant_owner.current = True
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
# update last_active_at when last_active_at is more than 10 minutes ago
|
if datetime.utcnow() - account.last_active_at > timedelta(minutes=10):
|
||||||
if current_time - account.last_active_at > timedelta(minutes=10):
|
account.last_active_at = datetime.utcnow()
|
||||||
account.last_active_at = current_time
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_account_jwt_token(account):
|
def get_account_jwt_token(account):
|
||||||
payload = {
|
payload = {
|
||||||
@ -277,18 +257,21 @@ class TenantService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def switch_tenant(account: Account, tenant_id: int = None) -> None:
|
def switch_tenant(account: Account, tenant_id: int = None) -> None:
|
||||||
"""Switch the current workspace for the account"""
|
"""Switch the current workspace for the account"""
|
||||||
if not tenant_id:
|
|
||||||
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id).first()
|
|
||||||
else:
|
|
||||||
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
|
|
||||||
|
|
||||||
# Check if the tenant exists and the account is a member of the tenant
|
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
|
||||||
if not tenant_account_join:
|
if not tenant_account_join:
|
||||||
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
||||||
|
else:
|
||||||
|
with db.session.begin():
|
||||||
|
try:
|
||||||
|
TenantAccountJoin.query.filter_by(account_id=account.id).update({'current': False})
|
||||||
|
tenant_account_join.current = True
|
||||||
|
db.session.commit()
|
||||||
# Set the current tenant for the account
|
# Set the current tenant for the account
|
||||||
account.current_tenant_id = tenant_account_join.tenant_id
|
account.current_tenant_id = tenant_account_join.tenant_id
|
||||||
session['workspace_id'] = account.current_tenant.id
|
except exc.SQLAlchemyError:
|
||||||
|
db.session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tenant_members(tenant: Tenant) -> List[Account]:
|
def get_tenant_members(tenant: Tenant) -> List[Account]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user