mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-26 15:41:58 +08:00
Feat/api jwt (#1212)
This commit is contained in:
parent
c40ee7e629
commit
227f9fb77d
@ -50,24 +50,6 @@ S3_REGION=your-region
|
|||||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||||
|
|
||||||
# Cookie configuration
|
|
||||||
COOKIE_HTTPONLY=true
|
|
||||||
COOKIE_SAMESITE=None
|
|
||||||
COOKIE_SECURE=true
|
|
||||||
|
|
||||||
# Session configuration
|
|
||||||
SESSION_PERMANENT=true
|
|
||||||
SESSION_USE_SIGNER=true
|
|
||||||
|
|
||||||
## support redis, sqlalchemy
|
|
||||||
SESSION_TYPE=redis
|
|
||||||
|
|
||||||
# session redis configuration
|
|
||||||
SESSION_REDIS_HOST=localhost
|
|
||||||
SESSION_REDIS_PORT=6379
|
|
||||||
SESSION_REDIS_PASSWORD=difyai123456
|
|
||||||
SESSION_REDIS_DB=2
|
|
||||||
|
|
||||||
# Vector database configuration, support: weaviate, qdrant
|
# Vector database configuration, support: weaviate, qdrant
|
||||||
VECTOR_STORE=weaviate
|
VECTOR_STORE=weaviate
|
||||||
|
|
||||||
|
91
api/app.py
91
api/app.py
@ -1,8 +1,7 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||||
from gevent import monkey
|
from gevent import monkey
|
||||||
@ -12,12 +11,11 @@ import logging
|
|||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from flask import Flask, request, Response, session
|
from flask import Flask, request, Response
|
||||||
import flask_login
|
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
|
|
||||||
from core.model_providers.providers import hosted
|
from core.model_providers.providers import hosted
|
||||||
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||||
ext_database, ext_storage, ext_mail, ext_stripe
|
ext_database, ext_storage, ext_mail, ext_stripe
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_login import login_manager
|
from extensions.ext_login import login_manager
|
||||||
@ -27,12 +25,10 @@ from models import model, account, dataset, web, task, source, tool
|
|||||||
from events import event_handlers
|
from events import event_handlers
|
||||||
# DO NOT REMOVE ABOVE
|
# DO NOT REMOVE ABOVE
|
||||||
|
|
||||||
import core
|
|
||||||
from config import Config, CloudEditionConfig
|
from config import Config, CloudEditionConfig
|
||||||
from commands import register_commands
|
from commands import register_commands
|
||||||
from models.account import TenantAccountJoin, AccountStatus
|
from services.account_service import AccountService
|
||||||
from models.model import Account, EndUser, App
|
from libs.passport import PassportService
|
||||||
from services.account_service import TenantService
|
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.simplefilter("ignore", ResourceWarning)
|
warnings.simplefilter("ignore", ResourceWarning)
|
||||||
@ -85,81 +81,33 @@ def initialize_extensions(app):
|
|||||||
ext_redis.init_app(app)
|
ext_redis.init_app(app)
|
||||||
ext_storage.init_app(app)
|
ext_storage.init_app(app)
|
||||||
ext_celery.init_app(app)
|
ext_celery.init_app(app)
|
||||||
ext_session.init_app(app)
|
|
||||||
ext_login.init_app(app)
|
ext_login.init_app(app)
|
||||||
ext_mail.init_app(app)
|
ext_mail.init_app(app)
|
||||||
ext_sentry.init_app(app)
|
ext_sentry.init_app(app)
|
||||||
ext_stripe.init_app(app)
|
ext_stripe.init_app(app)
|
||||||
|
|
||||||
|
|
||||||
def _create_tenant_for_account(account):
|
|
||||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
|
||||||
|
|
||||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
|
||||||
account.current_tenant = tenant
|
|
||||||
|
|
||||||
return tenant
|
|
||||||
|
|
||||||
|
|
||||||
# Flask-Login configuration
|
# Flask-Login configuration
|
||||||
@login_manager.user_loader
|
@login_manager.request_loader
|
||||||
def load_user(user_id):
|
def load_user_from_request(request_from_flask_login):
|
||||||
"""Load user based on the user_id."""
|
"""Load user based on the request."""
|
||||||
if request.blueprint == 'console':
|
if request.blueprint == 'console':
|
||||||
# Check if the user_id contains a dot, indicating the old format
|
# Check if the user_id contains a dot, indicating the old format
|
||||||
if '.' in user_id:
|
auth_header = request.headers.get('Authorization', '')
|
||||||
tenant_id, account_id = user_id.split('.')
|
if ' ' not in auth_header:
|
||||||
else:
|
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||||
account_id = user_id
|
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||||
|
auth_scheme = auth_scheme.lower()
|
||||||
|
if auth_scheme != 'bearer':
|
||||||
|
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||||
|
|
||||||
|
decoded = PassportService().verify(auth_token)
|
||||||
|
user_id = decoded.get('user_id')
|
||||||
|
|
||||||
account = db.session.query(Account).filter(Account.id == account_id).first()
|
return AccountService.load_user(user_id)
|
||||||
|
|
||||||
if account:
|
|
||||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
|
||||||
raise Forbidden('Account is banned or closed.')
|
|
||||||
|
|
||||||
workspace_id = session.get('workspace_id')
|
|
||||||
if workspace_id:
|
|
||||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
|
||||||
TenantAccountJoin.account_id == account.id,
|
|
||||||
TenantAccountJoin.tenant_id == workspace_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not tenant_account_join:
|
|
||||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
|
||||||
TenantAccountJoin.account_id == account.id).first()
|
|
||||||
|
|
||||||
if tenant_account_join:
|
|
||||||
account.current_tenant_id = tenant_account_join.tenant_id
|
|
||||||
else:
|
|
||||||
_create_tenant_for_account(account)
|
|
||||||
session['workspace_id'] = account.current_tenant_id
|
|
||||||
else:
|
|
||||||
account.current_tenant_id = workspace_id
|
|
||||||
else:
|
|
||||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
|
||||||
TenantAccountJoin.account_id == account.id).first()
|
|
||||||
if tenant_account_join:
|
|
||||||
account.current_tenant_id = tenant_account_join.tenant_id
|
|
||||||
else:
|
|
||||||
_create_tenant_for_account(account)
|
|
||||||
session['workspace_id'] = account.current_tenant_id
|
|
||||||
|
|
||||||
current_time = datetime.utcnow()
|
|
||||||
|
|
||||||
# update last_active_at when last_active_at is more than 10 minutes ago
|
|
||||||
if current_time - account.last_active_at > timedelta(minutes=10):
|
|
||||||
account.last_active_at = current_time
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Log in the user with the updated user_id
|
|
||||||
flask_login.login_user(account, remember=True)
|
|
||||||
|
|
||||||
return account
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@login_manager.unauthorized_handler
|
@login_manager.unauthorized_handler
|
||||||
def unauthorized_handler():
|
def unauthorized_handler():
|
||||||
"""Handle unauthorized requests."""
|
"""Handle unauthorized requests."""
|
||||||
@ -216,6 +164,7 @@ if app.config['TESTING']:
|
|||||||
@app.after_request
|
@app.after_request
|
||||||
def after_request(response):
|
def after_request(response):
|
||||||
"""Add Version headers to the response."""
|
"""Add Version headers to the response."""
|
||||||
|
response.set_cookie('remember_token', '', expires=0)
|
||||||
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
|
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
|
||||||
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
|
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
|
||||||
return response
|
return response
|
||||||
|
@ -10,9 +10,6 @@ from extensions.ext_redis import redis_client
|
|||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
DEFAULTS = {
|
DEFAULTS = {
|
||||||
'COOKIE_HTTPONLY': 'True',
|
|
||||||
'COOKIE_SECURE': 'True',
|
|
||||||
'COOKIE_SAMESITE': 'None',
|
|
||||||
'DB_USERNAME': 'postgres',
|
'DB_USERNAME': 'postgres',
|
||||||
'DB_PASSWORD': '',
|
'DB_PASSWORD': '',
|
||||||
'DB_HOST': 'localhost',
|
'DB_HOST': 'localhost',
|
||||||
@ -22,10 +19,6 @@ DEFAULTS = {
|
|||||||
'REDIS_PORT': '6379',
|
'REDIS_PORT': '6379',
|
||||||
'REDIS_DB': '0',
|
'REDIS_DB': '0',
|
||||||
'REDIS_USE_SSL': 'False',
|
'REDIS_USE_SSL': 'False',
|
||||||
'SESSION_REDIS_HOST': 'localhost',
|
|
||||||
'SESSION_REDIS_PORT': '6379',
|
|
||||||
'SESSION_REDIS_DB': '2',
|
|
||||||
'SESSION_REDIS_USE_SSL': 'False',
|
|
||||||
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
|
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
|
||||||
'OAUTH_REDIRECT_INDEX_PATH': '/',
|
'OAUTH_REDIRECT_INDEX_PATH': '/',
|
||||||
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
|
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
|
||||||
@ -36,9 +29,6 @@ DEFAULTS = {
|
|||||||
'STORAGE_TYPE': 'local',
|
'STORAGE_TYPE': 'local',
|
||||||
'STORAGE_LOCAL_PATH': 'storage',
|
'STORAGE_LOCAL_PATH': 'storage',
|
||||||
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
|
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
|
||||||
'SESSION_TYPE': 'sqlalchemy',
|
|
||||||
'SESSION_PERMANENT': 'True',
|
|
||||||
'SESSION_USE_SIGNER': 'True',
|
|
||||||
'DEPLOY_ENV': 'PRODUCTION',
|
'DEPLOY_ENV': 'PRODUCTION',
|
||||||
'SQLALCHEMY_POOL_SIZE': 30,
|
'SQLALCHEMY_POOL_SIZE': 30,
|
||||||
'SQLALCHEMY_POOL_RECYCLE': 3600,
|
'SQLALCHEMY_POOL_RECYCLE': 3600,
|
||||||
@ -115,20 +105,6 @@ class Config:
|
|||||||
# Alternatively you can set it with `SECRET_KEY` environment variable.
|
# Alternatively you can set it with `SECRET_KEY` environment variable.
|
||||||
self.SECRET_KEY = get_env('SECRET_KEY')
|
self.SECRET_KEY = get_env('SECRET_KEY')
|
||||||
|
|
||||||
# cookie settings
|
|
||||||
self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
|
|
||||||
self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
|
|
||||||
self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
|
|
||||||
self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
|
|
||||||
self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
|
|
||||||
self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
|
|
||||||
self.PERMANENT_SESSION_LIFETIME = timedelta(days=7)
|
|
||||||
|
|
||||||
# session settings, only support sqlalchemy, redis
|
|
||||||
self.SESSION_TYPE = get_env('SESSION_TYPE')
|
|
||||||
self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT')
|
|
||||||
self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER')
|
|
||||||
|
|
||||||
# redis settings
|
# redis settings
|
||||||
self.REDIS_HOST = get_env('REDIS_HOST')
|
self.REDIS_HOST = get_env('REDIS_HOST')
|
||||||
self.REDIS_PORT = get_env('REDIS_PORT')
|
self.REDIS_PORT = get_env('REDIS_PORT')
|
||||||
@ -137,14 +113,6 @@ class Config:
|
|||||||
self.REDIS_DB = get_env('REDIS_DB')
|
self.REDIS_DB = get_env('REDIS_DB')
|
||||||
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
|
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
|
||||||
|
|
||||||
# session redis settings
|
|
||||||
self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST')
|
|
||||||
self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT')
|
|
||||||
self.SESSION_REDIS_USERNAME = get_env('SESSION_REDIS_USERNAME')
|
|
||||||
self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD')
|
|
||||||
self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB')
|
|
||||||
self.SESSION_REDIS_USE_SSL = get_bool_env('SESSION_REDIS_USE_SSL')
|
|
||||||
|
|
||||||
# storage settings
|
# storage settings
|
||||||
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
|
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
|
||||||
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
|
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
|
||||||
|
@ -6,7 +6,6 @@ from flask_restful import Resource, reqparse
|
|||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.error import AccountNotLinkTenantError
|
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from libs.helper import email
|
from libs.helper import email
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
@ -37,12 +36,12 @@ class LoginApi(Resource):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
flask_login.login_user(account, remember=args['remember_me'])
|
|
||||||
AccountService.update_last_login(account, request)
|
AccountService.update_last_login(account, request)
|
||||||
|
|
||||||
# todo: return the user info
|
# todo: return the user info
|
||||||
|
token = AccountService.get_account_jwt_token(account)
|
||||||
|
|
||||||
return {'result': 'success'}
|
return {'result': 'success', 'data': token}
|
||||||
|
|
||||||
|
|
||||||
class LogoutApi(Resource):
|
class LogoutApi(Resource):
|
||||||
|
@ -2,9 +2,8 @@ import logging
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import flask_login
|
|
||||||
import requests
|
import requests
|
||||||
from flask import request, redirect, current_app, session
|
from flask import request, redirect, current_app
|
||||||
from flask_restful import Resource
|
from flask_restful import Resource
|
||||||
|
|
||||||
from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth
|
from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth
|
||||||
@ -75,12 +74,11 @@ class OAuthCallback(Resource):
|
|||||||
account.initialized_at = datetime.utcnow()
|
account.initialized_at = datetime.utcnow()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# login user
|
|
||||||
session.clear()
|
|
||||||
flask_login.login_user(account, remember=True)
|
|
||||||
AccountService.update_last_login(account, request)
|
AccountService.update_last_login(account, request)
|
||||||
|
|
||||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
|
token = AccountService.get_account_jwt_token(account)
|
||||||
|
|
||||||
|
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
|
||||||
|
|
||||||
|
|
||||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
import flask_login
|
|
||||||
from flask import request, current_app
|
from flask import request, current_app
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
@ -58,9 +57,6 @@ class SetupApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
setup()
|
setup()
|
||||||
|
|
||||||
# Login
|
|
||||||
flask_login.login_user(account)
|
|
||||||
AccountService.update_last_login(account, request)
|
AccountService.update_last_login(account, request)
|
||||||
|
|
||||||
return {'result': 'success'}, 201
|
return {'result': 'success'}, 201
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
import flask_login
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask import g
|
from flask import g
|
||||||
from flask import has_request_context
|
from flask import has_request_context
|
||||||
from flask import request
|
from flask import request, session
|
||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
from flask_login.config import EXEMPT_METHODS
|
from flask_login.config import EXEMPT_METHODS
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
@ -1,174 +0,0 @@
|
|||||||
import redis
|
|
||||||
from redis.connection import SSLConnection, Connection
|
|
||||||
from flask import request
|
|
||||||
from flask_session import Session, SqlAlchemySessionInterface, RedisSessionInterface
|
|
||||||
from flask_session.sessions import total_seconds
|
|
||||||
from itsdangerous import want_bytes
|
|
||||||
|
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
sess = Session()
|
|
||||||
|
|
||||||
|
|
||||||
def init_app(app):
|
|
||||||
sqlalchemy_session_interface = CustomSqlAlchemySessionInterface(
|
|
||||||
app,
|
|
||||||
db,
|
|
||||||
app.config.get('SESSION_SQLALCHEMY_TABLE', 'sessions'),
|
|
||||||
app.config.get('SESSION_KEY_PREFIX', 'session:'),
|
|
||||||
app.config.get('SESSION_USE_SIGNER', False),
|
|
||||||
app.config.get('SESSION_PERMANENT', True)
|
|
||||||
)
|
|
||||||
|
|
||||||
session_type = app.config.get('SESSION_TYPE')
|
|
||||||
if session_type == 'sqlalchemy':
|
|
||||||
app.session_interface = sqlalchemy_session_interface
|
|
||||||
elif session_type == 'redis':
|
|
||||||
connection_class = Connection
|
|
||||||
if app.config.get('SESSION_REDIS_USE_SSL', False):
|
|
||||||
connection_class = SSLConnection
|
|
||||||
|
|
||||||
sess_redis_client = redis.Redis()
|
|
||||||
sess_redis_client.connection_pool = redis.ConnectionPool(**{
|
|
||||||
'host': app.config.get('SESSION_REDIS_HOST', 'localhost'),
|
|
||||||
'port': app.config.get('SESSION_REDIS_PORT', 6379),
|
|
||||||
'username': app.config.get('SESSION_REDIS_USERNAME', None),
|
|
||||||
'password': app.config.get('SESSION_REDIS_PASSWORD', None),
|
|
||||||
'db': app.config.get('SESSION_REDIS_DB', 2),
|
|
||||||
'encoding': 'utf-8',
|
|
||||||
'encoding_errors': 'strict',
|
|
||||||
'decode_responses': False
|
|
||||||
}, connection_class=connection_class)
|
|
||||||
|
|
||||||
app.extensions['session_redis'] = sess_redis_client
|
|
||||||
|
|
||||||
app.session_interface = CustomRedisSessionInterface(
|
|
||||||
sess_redis_client,
|
|
||||||
app.config.get('SESSION_KEY_PREFIX', 'session:'),
|
|
||||||
app.config.get('SESSION_USE_SIGNER', False),
|
|
||||||
app.config.get('SESSION_PERMANENT', True)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomSqlAlchemySessionInterface(SqlAlchemySessionInterface):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
app,
|
|
||||||
db,
|
|
||||||
table,
|
|
||||||
key_prefix,
|
|
||||||
use_signer=False,
|
|
||||||
permanent=True,
|
|
||||||
sequence=None,
|
|
||||||
autodelete=False,
|
|
||||||
):
|
|
||||||
if db is None:
|
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
|
||||||
|
|
||||||
db = SQLAlchemy(app)
|
|
||||||
self.db = db
|
|
||||||
self.key_prefix = key_prefix
|
|
||||||
self.use_signer = use_signer
|
|
||||||
self.permanent = permanent
|
|
||||||
self.autodelete = autodelete
|
|
||||||
self.sequence = sequence
|
|
||||||
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
|
|
||||||
|
|
||||||
class Session(self.db.Model):
|
|
||||||
__tablename__ = table
|
|
||||||
|
|
||||||
if sequence:
|
|
||||||
id = self.db.Column( # noqa: A003, VNE003, A001
|
|
||||||
self.db.Integer, self.db.Sequence(sequence), primary_key=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
id = self.db.Column( # noqa: A003, VNE003, A001
|
|
||||||
self.db.Integer, primary_key=True
|
|
||||||
)
|
|
||||||
|
|
||||||
session_id = self.db.Column(self.db.String(255), unique=True)
|
|
||||||
data = self.db.Column(self.db.LargeBinary)
|
|
||||||
expiry = self.db.Column(self.db.DateTime)
|
|
||||||
|
|
||||||
def __init__(self, session_id, data, expiry):
|
|
||||||
self.session_id = session_id
|
|
||||||
self.data = data
|
|
||||||
self.expiry = expiry
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<Session data {self.data}>"
|
|
||||||
|
|
||||||
self.sql_session_model = Session
|
|
||||||
|
|
||||||
def save_session(self, *args, **kwargs):
|
|
||||||
if request.blueprint == 'service_api':
|
|
||||||
return
|
|
||||||
elif request.method == 'OPTIONS':
|
|
||||||
return
|
|
||||||
elif request.endpoint and request.endpoint == 'health':
|
|
||||||
return
|
|
||||||
return super().save_session(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomRedisSessionInterface(RedisSessionInterface):
|
|
||||||
|
|
||||||
def save_session(self, app, session, response):
|
|
||||||
if request.blueprint == 'service_api':
|
|
||||||
return
|
|
||||||
elif request.method == 'OPTIONS':
|
|
||||||
return
|
|
||||||
elif request.endpoint and request.endpoint == 'health':
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.should_set_cookie(app, session):
|
|
||||||
return
|
|
||||||
domain = self.get_cookie_domain(app)
|
|
||||||
path = self.get_cookie_path(app)
|
|
||||||
if not session:
|
|
||||||
if session.modified:
|
|
||||||
self.redis.delete(self.key_prefix + session.sid)
|
|
||||||
response.delete_cookie(
|
|
||||||
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Modification case. There are upsides and downsides to
|
|
||||||
# emitting a set-cookie header each request. The behavior
|
|
||||||
# is controlled by the :meth:`should_set_cookie` method
|
|
||||||
# which performs a quick check to figure out if the cookie
|
|
||||||
# should be set or not. This is controlled by the
|
|
||||||
# SESSION_REFRESH_EACH_REQUEST config flag as well as
|
|
||||||
# the permanent flag on the session itself.
|
|
||||||
# if not self.should_set_cookie(app, session):
|
|
||||||
# return
|
|
||||||
conditional_cookie_kwargs = {}
|
|
||||||
httponly = self.get_cookie_httponly(app)
|
|
||||||
secure = self.get_cookie_secure(app)
|
|
||||||
if self.has_same_site_capability:
|
|
||||||
conditional_cookie_kwargs["samesite"] = self.get_cookie_samesite(app)
|
|
||||||
expires = self.get_expiration_time(app, session)
|
|
||||||
|
|
||||||
if session.permanent:
|
|
||||||
value = self.serializer.dumps(dict(session))
|
|
||||||
if value is not None:
|
|
||||||
self.redis.setex(
|
|
||||||
name=self.key_prefix + session.sid,
|
|
||||||
value=value,
|
|
||||||
time=total_seconds(app.permanent_session_lifetime),
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_signer:
|
|
||||||
session_id = self._get_signer(app).sign(want_bytes(session.sid)).decode("utf-8")
|
|
||||||
else:
|
|
||||||
session_id = session.sid
|
|
||||||
response.set_cookie(
|
|
||||||
app.config["SESSION_COOKIE_NAME"],
|
|
||||||
session_id,
|
|
||||||
expires=expires,
|
|
||||||
httponly=httponly,
|
|
||||||
domain=domain,
|
|
||||||
path=path,
|
|
||||||
secure=secure,
|
|
||||||
**conditional_cookie_kwargs,
|
|
||||||
)
|
|
@ -4,11 +4,12 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from flask import session
|
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||||
|
from flask import session, current_app
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
@ -19,16 +20,82 @@ from services.errors.account import AccountLoginError, CurrentPasswordIncorrectE
|
|||||||
from libs.helper import get_remote_ip
|
from libs.helper import get_remote_ip
|
||||||
from libs.password import compare_password, hash_password
|
from libs.password import compare_password, hash_password
|
||||||
from libs.rsa import generate_key_pair
|
from libs.rsa import generate_key_pair
|
||||||
|
from libs.passport import PassportService
|
||||||
from models.account import *
|
from models.account import *
|
||||||
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
||||||
|
|
||||||
|
def _create_tenant_for_account(account):
|
||||||
|
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||||
|
|
||||||
|
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||||
|
account.current_tenant = tenant
|
||||||
|
|
||||||
|
return tenant
|
||||||
|
|
||||||
|
|
||||||
class AccountService:
|
class AccountService:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_user(account_id: int) -> Account:
|
def load_user(user_id: str) -> Account:
|
||||||
# todo: used by flask_login
|
# todo: used by flask_login
|
||||||
pass
|
if '.' in user_id:
|
||||||
|
tenant_id, account_id = user_id.split('.')
|
||||||
|
else:
|
||||||
|
account_id = user_id
|
||||||
|
|
||||||
|
account = db.session.query(Account).filter(Account.id == account_id).first()
|
||||||
|
|
||||||
|
if account:
|
||||||
|
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||||
|
raise Forbidden('Account is banned or closed.')
|
||||||
|
|
||||||
|
workspace_id = session.get('workspace_id')
|
||||||
|
if workspace_id:
|
||||||
|
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||||
|
TenantAccountJoin.account_id == account.id,
|
||||||
|
TenantAccountJoin.tenant_id == workspace_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not tenant_account_join:
|
||||||
|
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||||
|
TenantAccountJoin.account_id == account.id).first()
|
||||||
|
|
||||||
|
if tenant_account_join:
|
||||||
|
account.current_tenant_id = tenant_account_join.tenant_id
|
||||||
|
else:
|
||||||
|
_create_tenant_for_account(account)
|
||||||
|
session['workspace_id'] = account.current_tenant_id
|
||||||
|
else:
|
||||||
|
account.current_tenant_id = workspace_id
|
||||||
|
else:
|
||||||
|
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||||
|
TenantAccountJoin.account_id == account.id).first()
|
||||||
|
if tenant_account_join:
|
||||||
|
account.current_tenant_id = tenant_account_join.tenant_id
|
||||||
|
else:
|
||||||
|
_create_tenant_for_account(account)
|
||||||
|
session['workspace_id'] = account.current_tenant_id
|
||||||
|
|
||||||
|
current_time = datetime.utcnow()
|
||||||
|
|
||||||
|
# update last_active_at when last_active_at is more than 10 minutes ago
|
||||||
|
if current_time - account.last_active_at > timedelta(minutes=10):
|
||||||
|
account.last_active_at = current_time
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return account
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_account_jwt_token(account):
|
||||||
|
payload = {
|
||||||
|
"user_id": account.id,
|
||||||
|
"exp": datetime.utcnow() + timedelta(days=30),
|
||||||
|
"iss": current_app.config['EDITION'],
|
||||||
|
"sub": 'Console API Passport',
|
||||||
|
}
|
||||||
|
|
||||||
|
token = PassportService().issue(payload)
|
||||||
|
return token
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def authenticate(email: str, password: str) -> Account:
|
def authenticate(email: str, password: str) -> Account:
|
||||||
|
@ -49,15 +49,6 @@ services:
|
|||||||
REDIS_USE_SSL: 'false'
|
REDIS_USE_SSL: 'false'
|
||||||
# use redis db 0 for redis cache
|
# use redis db 0 for redis cache
|
||||||
REDIS_DB: 0
|
REDIS_DB: 0
|
||||||
# The configurations of session, Supported values are `sqlalchemy`. `redis`
|
|
||||||
SESSION_TYPE: redis
|
|
||||||
SESSION_REDIS_HOST: redis
|
|
||||||
SESSION_REDIS_PORT: 6379
|
|
||||||
SESSION_REDIS_USERNAME: ''
|
|
||||||
SESSION_REDIS_PASSWORD: difyai123456
|
|
||||||
SESSION_REDIS_USE_SSL: 'false'
|
|
||||||
# use redis db 2 for session store
|
|
||||||
SESSION_REDIS_DB: 2
|
|
||||||
# The configurations of celery broker.
|
# The configurations of celery broker.
|
||||||
# Use redis as the broker, and redis db 1 for celery broker.
|
# Use redis as the broker, and redis db 1 for celery broker.
|
||||||
CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1
|
CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1
|
||||||
@ -76,10 +67,6 @@ services:
|
|||||||
# If you want to enable cross-origin support,
|
# If you want to enable cross-origin support,
|
||||||
# you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`.
|
# you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`.
|
||||||
#
|
#
|
||||||
# For **production** purposes, please set `SameSite=Lax, Secure=true, HttpOnly=true`.
|
|
||||||
COOKIE_HTTPONLY: 'true'
|
|
||||||
COOKIE_SAMESITE: 'Lax'
|
|
||||||
COOKIE_SECURE: 'false'
|
|
||||||
# The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local`
|
# The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local`
|
||||||
STORAGE_TYPE: local
|
STORAGE_TYPE: local
|
||||||
# The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`.
|
# The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`.
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { SWRConfig } from 'swr'
|
import { SWRConfig } from 'swr'
|
||||||
|
import { useEffect, useState } from 'react'
|
||||||
import type { ReactNode } from 'react'
|
import type { ReactNode } from 'react'
|
||||||
|
import { useRouter, useSearchParams } from 'next/navigation'
|
||||||
|
|
||||||
type SwrInitorProps = {
|
type SwrInitorProps = {
|
||||||
children: ReactNode
|
children: ReactNode
|
||||||
@ -9,13 +11,32 @@ type SwrInitorProps = {
|
|||||||
const SwrInitor = ({
|
const SwrInitor = ({
|
||||||
children,
|
children,
|
||||||
}: SwrInitorProps) => {
|
}: SwrInitorProps) => {
|
||||||
return (
|
const router = useRouter()
|
||||||
<SWRConfig value={{
|
const searchParams = useSearchParams()
|
||||||
shouldRetryOnError: false,
|
const consoleToken = searchParams.get('console_token')
|
||||||
}}>
|
const consoleTokenFromLocalStorage = localStorage?.getItem('console_token')
|
||||||
{children}
|
const [init, setInit] = useState(false)
|
||||||
</SWRConfig>
|
|
||||||
)
|
useEffect(() => {
|
||||||
|
if (!(consoleToken || consoleTokenFromLocalStorage))
|
||||||
|
router.replace('/signin')
|
||||||
|
|
||||||
|
if (consoleToken) {
|
||||||
|
localStorage?.setItem('console_token', consoleToken!)
|
||||||
|
router.replace('/apps', { forceOptimisticNavigation: false })
|
||||||
|
}
|
||||||
|
setInit(true)
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
return init
|
||||||
|
? (
|
||||||
|
<SWRConfig value={{
|
||||||
|
shouldRetryOnError: false,
|
||||||
|
}}>
|
||||||
|
{children}
|
||||||
|
</SWRConfig>
|
||||||
|
)
|
||||||
|
: null
|
||||||
}
|
}
|
||||||
|
|
||||||
export default SwrInitor
|
export default SwrInitor
|
||||||
|
@ -8,6 +8,10 @@ import I18n from '@/context/i18n'
|
|||||||
|
|
||||||
const Header = () => {
|
const Header = () => {
|
||||||
const { locale, setLocaleOnClient } = useContext(I18n)
|
const { locale, setLocaleOnClient } = useContext(I18n)
|
||||||
|
|
||||||
|
if (localStorage?.getItem('console_token'))
|
||||||
|
localStorage.removeItem('console_token')
|
||||||
|
|
||||||
return <div className='flex items-center justify-between p-6 w-full'>
|
return <div className='flex items-center justify-between p-6 w-full'>
|
||||||
<div className={style.logo}></div>
|
<div className={style.logo}></div>
|
||||||
<Select
|
<Select
|
||||||
|
@ -89,7 +89,7 @@ const NormalForm = () => {
|
|||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
setIsLoading(true)
|
setIsLoading(true)
|
||||||
await login({
|
const res = await login({
|
||||||
url: '/login',
|
url: '/login',
|
||||||
body: {
|
body: {
|
||||||
email,
|
email,
|
||||||
@ -97,7 +97,8 @@ const NormalForm = () => {
|
|||||||
remember_me: true,
|
remember_me: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
router.push('/apps')
|
localStorage.setItem('console_token', res.data)
|
||||||
|
router.replace('/apps')
|
||||||
}
|
}
|
||||||
finally {
|
finally {
|
||||||
setIsLoading(false)
|
setIsLoading(false)
|
||||||
|
@ -179,6 +179,10 @@ const baseFetch = <T>(
|
|||||||
}
|
}
|
||||||
options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
|
options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
const accessToken = localStorage.getItem('console_token') || ''
|
||||||
|
options.headers.set('Authorization', `Bearer ${accessToken}`)
|
||||||
|
}
|
||||||
|
|
||||||
if (deleteContentType) {
|
if (deleteContentType) {
|
||||||
options.headers.delete('Content-Type')
|
options.headers.delete('Content-Type')
|
||||||
@ -292,7 +296,9 @@ export const upload = (options: any): Promise<any> => {
|
|||||||
const defaultOptions = {
|
const defaultOptions = {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
url: `${API_PREFIX}/files/upload`,
|
url: `${API_PREFIX}/files/upload`,
|
||||||
headers: {},
|
headers: {
|
||||||
|
Authorization: `Bearer ${localStorage.getItem('console_token') || ''}`,
|
||||||
|
},
|
||||||
data: {},
|
data: {},
|
||||||
}
|
}
|
||||||
options = {
|
options = {
|
||||||
|
@ -15,8 +15,8 @@ import type {
|
|||||||
} from '@/models/app'
|
} from '@/models/app'
|
||||||
import type { BackendModel, ProviderMap } from '@/app/components/header/account-setting/model-page/declarations'
|
import type { BackendModel, ProviderMap } from '@/app/components/header/account-setting/model-page/declarations'
|
||||||
|
|
||||||
export const login: Fetcher<CommonResponse, { url: string; body: Record<string, any> }> = ({ url, body }) => {
|
export const login: Fetcher<CommonResponse & { data: string }, { url: string; body: Record<string, any> }> = ({ url, body }) => {
|
||||||
return post<CommonResponse>(url, { body })
|
return post(url, { body }) as Promise<CommonResponse & { data: string }>
|
||||||
}
|
}
|
||||||
|
|
||||||
export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => {
|
export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user