diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index a7c06f481d..2c8fdeeaf5 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -8,7 +8,7 @@ from flask import current_app, request from flask_restful import Resource, reqparse from libs.helper import email from libs.password import valid_password -from services.account_service import AccountService, TenantService +from services.account_service import AccountService class LoginApi(Resource): @@ -30,11 +30,6 @@ class LoginApi(Resource): except services.errors.account.AccountLoginError: return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401 - try: - TenantService.switch_tenant(account) - except Exception: - pass - AccountService.update_last_login(account, request) # todo: return the user info @@ -47,7 +42,6 @@ class LogoutApi(Resource): @setup_required def get(self): - flask.session.pop('workspace_id', None) flask_login.logout_user() return {'result': 'success'} diff --git a/api/migrations/versions/16830a790f0f_.py b/api/migrations/versions/16830a790f0f_.py new file mode 100644 index 0000000000..fd1eaedf67 --- /dev/null +++ b/api/migrations/versions/16830a790f0f_.py @@ -0,0 +1,32 @@ +"""empty message + +Revision ID: 16830a790f0f +Revises: 380c6aa5a70d +Create Date: 2024-02-01 08:21:31.111119 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '16830a790f0f' +down_revision = '380c6aa5a70d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op: + batch_op.add_column(sa.Column('current', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op: + batch_op.drop_column('current') + + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index 322e2670a3..21fc998185 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,6 +1,5 @@ import enum import json -from math import e from typing import List from extensions.ext_database import db @@ -155,6 +154,7 @@ class TenantAccountJoin(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(UUID, nullable=False) account_id = db.Column(UUID, nullable=False) + current = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) role = db.Column(db.String(16), nullable=False, server_default='normal') invited_by = db.Column(UUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/services/account_service.py b/api/services/account_service.py index 9a97d19889..3d82af2966 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -11,7 +11,7 @@ from typing import Any, Dict, Optional from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created from extensions.ext_redis import redis_client -from flask import current_app, session +from flask import current_app from libs.helper import get_remote_ip from libs.passport import PassportService from libs.password import compare_password, hash_password @@ -23,7 +23,8 @@ from services.errors.account import (AccountAlreadyInTenantError, AccountLoginEr NoPermissionError, RoleAlreadyAssignedError, TenantNotFound) from sqlalchemy import func from tasks.mail_invite_member_task import send_invite_member_mail_task -from werkzeug.exceptions import Forbidden, Unauthorized +from werkzeug.exceptions import Forbidden +from sqlalchemy import exc def _create_tenant_for_account(account) -> Tenant: @@ -39,54 +40,33 @@ class AccountService: @staticmethod def load_user(user_id: str) -> Account: - # todo: used by flask_login - if '.' in user_id: - tenant_id, account_id = user_id.split('.') + account = Account.query.filter_by(id=user_id).first() + if not account: + return None + + if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: + raise Forbidden('Account is banned or closed.') + + # init owner's tenant + tenant_owner = TenantAccountJoin.query.filter_by(account_id=account.id, role='owner').first() + if not tenant_owner: + _create_tenant_for_account(account) + + current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() + if current_tenant: + account.current_tenant_id = current_tenant.tenant_id else: - account_id = user_id - - account = db.session.query(Account).filter(Account.id == account_id).first() - - if account: - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - raise Forbidden('Account is banned or closed.') - - workspace_id = session.get('workspace_id') - if workspace_id: - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.account_id == account.id, - TenantAccountJoin.tenant_id == workspace_id - ).first() - - if not tenant_account_join: - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.account_id == account.id).first() - - if tenant_account_join: - account.current_tenant_id = tenant_account_join.tenant_id - else: - _create_tenant_for_account(account) - session['workspace_id'] = account.current_tenant_id - else: - account.current_tenant_id = workspace_id - else: - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.account_id == account.id).first() - if tenant_account_join: - account.current_tenant_id = tenant_account_join.tenant_id - else: - _create_tenant_for_account(account) - session['workspace_id'] = account.current_tenant_id - - current_time = datetime.utcnow() - - # update last_active_at when last_active_at is more than 10 minutes ago - if current_time - account.last_active_at > timedelta(minutes=10): - account.last_active_at = current_time - db.session.commit() + account.current_tenant_id = tenant_owner.tenant_id + tenant_owner.current = True + db.session.commit() + + if datetime.utcnow() - account.last_active_at > timedelta(minutes=10): + account.last_active_at = datetime.utcnow() + db.session.commit() return account + @staticmethod def get_account_jwt_token(account): payload = { @@ -277,18 +257,21 @@ class TenantService: @staticmethod def switch_tenant(account: Account, tenant_id: int = None) -> None: """Switch the current workspace for the account""" - if not tenant_id: - tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id).first() - else: - tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() - # Check if the tenant exists and the account is a member of the tenant + tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") - - # Set the current tenant for the account - account.current_tenant_id = tenant_account_join.tenant_id - session['workspace_id'] = account.current_tenant.id + else: + with db.session.begin(): + try: + TenantAccountJoin.query.filter_by(account_id=account.id).update({'current': False}) + tenant_account_join.current = True + db.session.commit() + # Set the current tenant for the account + account.current_tenant_id = tenant_account_join.tenant_id + except exc.SQLAlchemyError: + db.session.rollback() + raise @staticmethod def get_tenant_members(tenant: Tenant) -> List[Account]: