diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index de4d22b3af..c0209a05cd 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -16,26 +16,25 @@ from services.account_service import RegisterService class ActivateCheckApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args') - parser.add_argument('email', type=email, required=True, nullable=False, location='args') + parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args') + parser.add_argument('email', type=email, required=False, nullable=True, location='args') parser.add_argument('token', type=str, required=True, nullable=False, location='args') args = parser.parse_args() - account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token']) + workspaceId = args['workspace_id'] + reg_email = args['email'] + token = args['token'] - tenant = db.session.query(Tenant).filter( - Tenant.id == args['workspace_id'], - Tenant.status == 'normal' - ).first() + invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) - return {'is_valid': account is not None, 'workspace_name': tenant.name} + return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None} class ActivateApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json') - parser.add_argument('email', type=email, required=True, nullable=False, location='json') + parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json') + parser.add_argument('email', type=email, required=False, nullable=True, location='json') parser.add_argument('token', type=str, required=True, nullable=False, location='json') parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json') parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json') @@ -44,12 +43,13 @@ class ActivateApi(Resource): parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json') args = parser.parse_args() - account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token']) - if account is None: + invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token']) + if invitation is None: raise AlreadyActivateError() RegisterService.revoke_token(args['workspace_id'], args['email'], args['token']) + account = invitation['account'] account.name = args['name'] # generate password salt diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 942ae91302..ed829a5dce 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -72,7 +72,7 @@ class MemberInviteEmailApi(Resource): invitation_results.append({ 'status': 'success', 'email': invitee_email, - 'url': f'{console_web_url}/activate?workspace_id={current_user.current_tenant_id}&email={invitee_email}&token={token}' + 'url': f'{console_web_url}/activate?email={invitee_email}&token={token}' }) account = marshal(account, account_fields) account['role'] = role diff --git a/api/services/account_service.py b/api/services/account_service.py index 530a6279d8..0f5dfcf0f9 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1,5 +1,6 @@ # -*- coding:utf-8 -*- import base64 +import json import logging import secrets import uuid @@ -346,6 +347,10 @@ class TenantService: class RegisterService: + @classmethod + def _get_invitation_token_key(cls, token: str) -> str: + return f'member_invite:token:{token}' + @classmethod def register(cls, email, name, password: str = None, open_id: str = None, provider: str = None) -> Account: db.session.begin_nested() @@ -401,7 +406,7 @@ class RegisterService: # send email send_invite_member_mail_task.delay( to=email, - token=cls.generate_invite_token(tenant, account), + token=token, inviter_name=inviter.name if inviter else 'Dify', workspace_id=tenant.id, workspace_name=tenant.name, @@ -412,21 +417,35 @@ class RegisterService: @classmethod def generate_invite_token(cls, tenant: Tenant, account: Account) -> str: token = str(uuid.uuid4()) - email_hash = sha256(account.email.encode()).hexdigest() - cache_key = 'member_invite_token:{}, {}:{}'.format(str(tenant.id), email_hash, token) - redis_client.setex(cache_key, 3600, str(account.id)) + invitation_data = { + 'account_id': account.id, + 'email': account.email, + 'workspace_id': tenant.id, + } + redis_client.setex( + cls._get_invitation_token_key(token), + 3600, + json.dumps(invitation_data) + ) return token @classmethod def revoke_token(cls, workspace_id: str, email: str, token: str): - email_hash = sha256(email.encode()).hexdigest() - cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) - redis_client.delete(cache_key) + if workspace_id and email: + email_hash = sha256(email.encode()).hexdigest() + cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) + redis_client.delete(cache_key) + else: + redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_account_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[Account]: + def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[Account]: + invitation_data = cls._get_invitation_by_token(token, workspace_id, email) + if not invitation_data: + return None + tenant = db.session.query(Tenant).filter( - Tenant.id == workspace_id, + Tenant.id == invitation_data['workspace_id'], Tenant.status == 'normal' ).first() @@ -435,30 +454,43 @@ class RegisterService: tenant_account = db.session.query(Account, TenantAccountJoin.role).join( TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ).filter(Account.email == email, TenantAccountJoin.tenant_id == tenant.id).first() + ).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first() if not tenant_account: return None - account_id = cls._get_account_id_by_invite_token(workspace_id, email, token) - if not account_id: - return None - account = tenant_account[0] if not account: return None - if account_id != str(account.id): + if invitation_data['account_id'] != str(account.id): return None - return account + return { + 'account': account, + 'data': invitation_data, + 'tenant': tenant, + } @classmethod - def _get_account_id_by_invite_token(cls, workspace_id: str, email: str, token: str) -> Optional[str]: - email_hash = sha256(email.encode()).hexdigest() - cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) - account_id = redis_client.get(cache_key) - if not account_id: - return None + def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[str]: + if workspace_id is not None and email is not None: + email_hash = sha256(email.encode()).hexdigest() + cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}' + account_id = redis_client.get(cache_key) - return account_id.decode('utf-8') + if not account_id: + return None + + return { + 'account_id': account_id.decode('utf-8'), + 'email': email, + 'workspace_id': workspace_id, + } + else: + data = redis_client.get(cls._get_invitation_token_key(token)) + if not data: + return None + + invitation = json.loads(data) + return invitation diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index 7a11faeea8..ad60b199c5 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -31,8 +31,8 @@ const ActivateForm = () => { const checkParams = { url: '/activate/check', params: { - workspace_id: workspaceID, - email, + ...workspaceID && { workspace_id: workspaceID }, + ...email && { email }, token, }, }