mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-06-04 11:14:10 +08:00
Feature/use jwt in web (#533)
Co-authored-by: crazywoola <li.zheng@dentsplysirona.com> Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
parent
57de19a5ca
commit
d49ac1e4ac
1
.gitignore
vendored
1
.gitignore
vendored
@ -109,6 +109,7 @@ venv/
|
|||||||
ENV/
|
ENV/
|
||||||
env.bak/
|
env.bak/
|
||||||
venv.bak/
|
venv.bak/
|
||||||
|
.conda/
|
||||||
|
|
||||||
# Spyder project settings
|
# Spyder project settings
|
||||||
.spyderproject
|
.spyderproject
|
||||||
|
@ -155,7 +155,7 @@ def register_blueprints(app):
|
|||||||
resources={
|
resources={
|
||||||
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
|
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
|
||||||
supports_credentials=True,
|
supports_credentials=True,
|
||||||
allow_headers=['Content-Type', 'Authorization'],
|
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
||||||
expose_headers=['X-Version', 'X-Env']
|
expose_headers=['X-Version', 'X-Env']
|
||||||
)
|
)
|
||||||
|
@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
|
|||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
|
|
||||||
from . import completion, app, conversation, message, site, saved_message, audio
|
from . import completion, app, conversation, message, site, saved_message, audio, passport
|
||||||
|
64
api/controllers/web/passport.py
Normal file
64
api/controllers/web/passport.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
import uuid
|
||||||
|
from controllers.web import api
|
||||||
|
from flask_restful import Resource
|
||||||
|
from flask import request
|
||||||
|
from werkzeug.exceptions import Unauthorized, NotFound
|
||||||
|
from models.model import Site, EndUser, App
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.passport import PassportService
|
||||||
|
|
||||||
|
class PassportResource(Resource):
|
||||||
|
"""Base resource for passport."""
|
||||||
|
def get(self):
|
||||||
|
app_id = request.headers.get('X-App-Code')
|
||||||
|
if app_id is None:
|
||||||
|
raise Unauthorized('X-App-Code header is missing.')
|
||||||
|
|
||||||
|
# get site from db and check if it is normal
|
||||||
|
site = db.session.query(Site).filter(
|
||||||
|
Site.code == app_id,
|
||||||
|
Site.status == 'normal'
|
||||||
|
).first()
|
||||||
|
if not site:
|
||||||
|
raise NotFound()
|
||||||
|
# get app from db and check if it is normal and enable_site
|
||||||
|
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||||
|
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
|
||||||
|
raise NotFound()
|
||||||
|
|
||||||
|
end_user = EndUser(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
|
type='browser',
|
||||||
|
is_anonymous=True,
|
||||||
|
session_id=generate_session_id(),
|
||||||
|
)
|
||||||
|
db.session.add(end_user)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"iss": site.app_id,
|
||||||
|
'sub': 'Web API Passport',
|
||||||
|
'app_id': site.app_id,
|
||||||
|
'end_user_id': end_user.id,
|
||||||
|
}
|
||||||
|
|
||||||
|
tk = PassportService().issue(payload)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'access_token': tk,
|
||||||
|
}
|
||||||
|
|
||||||
|
api.add_resource(PassportResource, '/passport')
|
||||||
|
|
||||||
|
def generate_session_id():
|
||||||
|
"""
|
||||||
|
Generate a unique session ID.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
existing_count = db.session.query(EndUser) \
|
||||||
|
.filter(EndUser.session_id == session_id).count()
|
||||||
|
if existing_count == 0:
|
||||||
|
return session_id
|
@ -1,110 +1,48 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
import uuid
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from flask import request, session
|
from flask import request
|
||||||
from flask_restful import Resource
|
from flask_restful import Resource
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, Site, EndUser
|
from models.model import App, EndUser
|
||||||
|
from libs.passport import PassportService
|
||||||
|
|
||||||
|
def validate_jwt_token(view=None):
|
||||||
def validate_token(view=None):
|
|
||||||
def decorator(view):
|
def decorator(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
site = validate_and_get_site()
|
app_model, end_user = decode_jwt_token()
|
||||||
|
|
||||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
|
||||||
if not app_model:
|
|
||||||
raise NotFound()
|
|
||||||
|
|
||||||
if app_model.status != 'normal':
|
|
||||||
raise NotFound()
|
|
||||||
|
|
||||||
if not app_model.enable_site:
|
|
||||||
raise NotFound()
|
|
||||||
|
|
||||||
end_user = create_or_update_end_user_for_session(app_model)
|
|
||||||
|
|
||||||
return view(app_model, end_user, *args, **kwargs)
|
return view(app_model, end_user, *args, **kwargs)
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
if view:
|
if view:
|
||||||
return decorator(view)
|
return decorator(view)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def decode_jwt_token():
|
||||||
def validate_and_get_site():
|
|
||||||
"""
|
|
||||||
Validate and get API token.
|
|
||||||
"""
|
|
||||||
auth_header = request.headers.get('Authorization')
|
auth_header = request.headers.get('Authorization')
|
||||||
if auth_header is None:
|
if auth_header is None:
|
||||||
raise Unauthorized('Authorization header is missing.')
|
raise Unauthorized('Authorization header is missing.')
|
||||||
|
|
||||||
if ' ' not in auth_header:
|
if ' ' not in auth_header:
|
||||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||||
|
|
||||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
auth_scheme, tk = auth_header.split(None, 1)
|
||||||
auth_scheme = auth_scheme.lower()
|
auth_scheme = auth_scheme.lower()
|
||||||
|
|
||||||
if auth_scheme != 'bearer':
|
if auth_scheme != 'bearer':
|
||||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||||
|
decoded = PassportService().verify(tk)
|
||||||
site = db.session.query(Site).filter(
|
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
|
||||||
Site.code == auth_token,
|
if not app_model:
|
||||||
Site.status == 'normal'
|
raise NotFound()
|
||||||
).first()
|
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
|
||||||
|
if not end_user:
|
||||||
if not site:
|
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
return site
|
return app_model, end_user
|
||||||
|
|
||||||
|
|
||||||
def create_or_update_end_user_for_session(app_model):
|
|
||||||
"""
|
|
||||||
Create or update session terminal based on session ID.
|
|
||||||
"""
|
|
||||||
if 'session_id' not in session:
|
|
||||||
session['session_id'] = generate_session_id()
|
|
||||||
|
|
||||||
session_id = session.get('session_id')
|
|
||||||
end_user = db.session.query(EndUser) \
|
|
||||||
.filter(
|
|
||||||
EndUser.session_id == session_id,
|
|
||||||
EndUser.type == 'browser'
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if end_user is None:
|
|
||||||
end_user = EndUser(
|
|
||||||
tenant_id=app_model.tenant_id,
|
|
||||||
app_id=app_model.id,
|
|
||||||
type='browser',
|
|
||||||
is_anonymous=True,
|
|
||||||
session_id=session_id
|
|
||||||
)
|
|
||||||
db.session.add(end_user)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return end_user
|
|
||||||
|
|
||||||
|
|
||||||
def generate_session_id():
|
|
||||||
"""
|
|
||||||
Generate a unique session ID.
|
|
||||||
"""
|
|
||||||
count = 1
|
|
||||||
session_id = ''
|
|
||||||
while count != 0:
|
|
||||||
session_id = str(uuid.uuid4())
|
|
||||||
count = db.session.query(EndUser) \
|
|
||||||
.filter(EndUser.session_id == session_id).count()
|
|
||||||
|
|
||||||
return session_id
|
|
||||||
|
|
||||||
|
|
||||||
class WebApiResource(Resource):
|
class WebApiResource(Resource):
|
||||||
method_decorators = [validate_token]
|
method_decorators = [validate_jwt_token]
|
||||||
|
20
api/libs/passport.py
Normal file
20
api/libs/passport.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
import jwt
|
||||||
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
from flask import current_app
|
||||||
|
class PassportService:
|
||||||
|
def __init__(self):
|
||||||
|
self.sk = current_app.config.get('SECRET_KEY')
|
||||||
|
|
||||||
|
def issue(self, payload):
|
||||||
|
return jwt.encode(payload, self.sk, algorithm='HS256')
|
||||||
|
|
||||||
|
def verify(self, token):
|
||||||
|
try:
|
||||||
|
return jwt.decode(token, self.sk, algorithms=['HS256'])
|
||||||
|
except jwt.exceptions.InvalidSignatureError:
|
||||||
|
raise Unauthorized('Invalid token signature.')
|
||||||
|
except jwt.exceptions.DecodeError:
|
||||||
|
raise Unauthorized('Invalid token.')
|
||||||
|
except jwt.exceptions.ExpiredSignatureError:
|
||||||
|
raise Unauthorized('Token has expired.')
|
@ -32,4 +32,5 @@ redis~=4.5.4
|
|||||||
openpyxl==3.1.2
|
openpyxl==3.1.2
|
||||||
chardet~=5.1.0
|
chardet~=5.1.0
|
||||||
docx2txt==0.8
|
docx2txt==0.8
|
||||||
pypdfium2==4.16.0
|
pypdfium2==4.16.0
|
||||||
|
pyjwt~=2.6.0
|
@ -8,13 +8,26 @@ import { useContext } from 'use-context-selector'
|
|||||||
import produce from 'immer'
|
import produce from 'immer'
|
||||||
import { useBoolean, useGetState } from 'ahooks'
|
import { useBoolean, useGetState } from 'ahooks'
|
||||||
import AppUnavailable from '../../base/app-unavailable'
|
import AppUnavailable from '../../base/app-unavailable'
|
||||||
|
import { checkOrSetAccessToken } from '../utils'
|
||||||
import useConversation from './hooks/use-conversation'
|
import useConversation from './hooks/use-conversation'
|
||||||
import s from './style.module.css'
|
import s from './style.module.css'
|
||||||
import { ToastContext } from '@/app/components/base/toast'
|
import { ToastContext } from '@/app/components/base/toast'
|
||||||
import Sidebar from '@/app/components/share/chat/sidebar'
|
import Sidebar from '@/app/components/share/chat/sidebar'
|
||||||
import ConfigSence from '@/app/components/share/chat/config-scence'
|
import ConfigSence from '@/app/components/share/chat/config-scence'
|
||||||
import Header from '@/app/components/share/header'
|
import Header from '@/app/components/share/header'
|
||||||
import { delConversation, fetchAppInfo, fetchAppParams, fetchChatList, fetchConversations, fetchSuggestedQuestions, pinConversation, sendChatMessage, stopChatMessageResponding, unpinConversation, updateFeedback } from '@/service/share'
|
import {
|
||||||
|
delConversation,
|
||||||
|
fetchAppInfo,
|
||||||
|
fetchAppParams,
|
||||||
|
fetchChatList,
|
||||||
|
fetchConversations,
|
||||||
|
fetchSuggestedQuestions,
|
||||||
|
pinConversation,
|
||||||
|
sendChatMessage,
|
||||||
|
stopChatMessageResponding,
|
||||||
|
unpinConversation,
|
||||||
|
updateFeedback,
|
||||||
|
} from '@/service/share'
|
||||||
import type { ConversationItem, SiteInfo } from '@/models/share'
|
import type { ConversationItem, SiteInfo } from '@/models/share'
|
||||||
import type { PromptConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug'
|
import type { PromptConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug'
|
||||||
import type { Feedbacktype, IChatItem } from '@/app/components/app/chat'
|
import type { Feedbacktype, IChatItem } from '@/app/components/app/chat'
|
||||||
@ -296,7 +309,9 @@ const Main: FC<IMainProps> = ({
|
|||||||
return fetchConversations(isInstalledApp, installedAppInfo?.id, undefined, undefined, 100)
|
return fetchConversations(isInstalledApp, installedAppInfo?.id, undefined, undefined, 100)
|
||||||
}
|
}
|
||||||
|
|
||||||
const fetchInitData = () => {
|
const fetchInitData = async () => {
|
||||||
|
await checkOrSetAccessToken()
|
||||||
|
|
||||||
return Promise.all([isInstalledApp
|
return Promise.all([isInstalledApp
|
||||||
? {
|
? {
|
||||||
app_id: installedAppInfo?.id,
|
app_id: installedAppInfo?.id,
|
||||||
|
@ -7,6 +7,7 @@ import { useBoolean, useClickAway, useGetState } from 'ahooks'
|
|||||||
import { XMarkIcon } from '@heroicons/react/24/outline'
|
import { XMarkIcon } from '@heroicons/react/24/outline'
|
||||||
import TabHeader from '../../base/tab-header'
|
import TabHeader from '../../base/tab-header'
|
||||||
import Button from '../../base/button'
|
import Button from '../../base/button'
|
||||||
|
import { checkOrSetAccessToken } from '../utils'
|
||||||
import s from './style.module.css'
|
import s from './style.module.css'
|
||||||
import RunBatch from './run-batch'
|
import RunBatch from './run-batch'
|
||||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||||
@ -76,9 +77,6 @@ const TextGeneration: FC<IMainProps> = ({
|
|||||||
const res: any = await doFetchSavedMessage(isInstalledApp, installedAppInfo?.id)
|
const res: any = await doFetchSavedMessage(isInstalledApp, installedAppInfo?.id)
|
||||||
setSavedMessages(res.data)
|
setSavedMessages(res.data)
|
||||||
}
|
}
|
||||||
useEffect(() => {
|
|
||||||
fetchSavedMessage()
|
|
||||||
}, [])
|
|
||||||
const handleSaveMessage = async (messageId: string) => {
|
const handleSaveMessage = async (messageId: string) => {
|
||||||
await saveMessage(messageId, isInstalledApp, installedAppInfo?.id)
|
await saveMessage(messageId, isInstalledApp, installedAppInfo?.id)
|
||||||
notify({ type: 'success', message: t('common.api.saved') })
|
notify({ type: 'success', message: t('common.api.saved') })
|
||||||
@ -256,7 +254,9 @@ const TextGeneration: FC<IMainProps> = ({
|
|||||||
setAllTaskList(newAllTaskList)
|
setAllTaskList(newAllTaskList)
|
||||||
}
|
}
|
||||||
|
|
||||||
const fetchInitData = () => {
|
const fetchInitData = async () => {
|
||||||
|
await checkOrSetAccessToken()
|
||||||
|
|
||||||
return Promise.all([isInstalledApp
|
return Promise.all([isInstalledApp
|
||||||
? {
|
? {
|
||||||
app_id: installedAppInfo?.id,
|
app_id: installedAppInfo?.id,
|
||||||
@ -267,7 +267,7 @@ const TextGeneration: FC<IMainProps> = ({
|
|||||||
},
|
},
|
||||||
plan: 'basic',
|
plan: 'basic',
|
||||||
}
|
}
|
||||||
: fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id)])
|
: fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id), fetchSavedMessage()])
|
||||||
}
|
}
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
18
web/app/components/share/utils.ts
Normal file
18
web/app/components/share/utils.ts
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import { fetchAccessToken } from '@/service/share'
|
||||||
|
|
||||||
|
export const checkOrSetAccessToken = async () => {
|
||||||
|
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
|
||||||
|
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
|
||||||
|
let accessTokenJson = { [sharedToken]: '' }
|
||||||
|
try {
|
||||||
|
accessTokenJson = JSON.parse(accessToken)
|
||||||
|
}
|
||||||
|
catch (e) {
|
||||||
|
|
||||||
|
}
|
||||||
|
if (!accessTokenJson[sharedToken]) {
|
||||||
|
const res = await fetchAccessToken(sharedToken)
|
||||||
|
accessTokenJson[sharedToken] = res.access_token
|
||||||
|
localStorage.setItem('token', JSON.stringify(accessTokenJson))
|
||||||
|
}
|
||||||
|
}
|
@ -142,7 +142,15 @@ const baseFetch = (
|
|||||||
const options = Object.assign({}, baseOptions, fetchOptions)
|
const options = Object.assign({}, baseOptions, fetchOptions)
|
||||||
if (isPublicAPI) {
|
if (isPublicAPI) {
|
||||||
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
|
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
|
||||||
options.headers.set('Authorization', `bearer ${sharedToken}`)
|
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
|
||||||
|
let accessTokenJson = { [sharedToken]: '' }
|
||||||
|
try {
|
||||||
|
accessTokenJson = JSON.parse(accessToken)
|
||||||
|
}
|
||||||
|
catch (e) {
|
||||||
|
|
||||||
|
}
|
||||||
|
options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (deleteContentType) {
|
if (deleteContentType) {
|
||||||
@ -194,7 +202,7 @@ const baseFetch = (
|
|||||||
case 401: {
|
case 401: {
|
||||||
if (isPublicAPI) {
|
if (isPublicAPI) {
|
||||||
Toast.notify({ type: 'error', message: 'Invalid token' })
|
Toast.notify({ type: 'error', message: 'Invalid token' })
|
||||||
return
|
return bodyJson.then((data: any) => Promise.reject(data))
|
||||||
}
|
}
|
||||||
const loginUrl = `${globalThis.location.origin}/signin`
|
const loginUrl = `${globalThis.location.origin}/signin`
|
||||||
if (IS_CE_EDITION) {
|
if (IS_CE_EDITION) {
|
||||||
|
@ -118,3 +118,9 @@ export const fetchSuggestedQuestions = (messageId: string, isInstalledApp: boole
|
|||||||
export const audioToText = (url: string, isPublicAPI: boolean, body: FormData) => {
|
export const audioToText = (url: string, isPublicAPI: boolean, body: FormData) => {
|
||||||
return (getAction('post', !isPublicAPI))(url, { body }, { bodyStringify: false, deleteContentType: true }) as Promise<{ text: string }>
|
return (getAction('post', !isPublicAPI))(url, { body }, { bodyStringify: false, deleteContentType: true }) as Promise<{ text: string }>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const fetchAccessToken = async (appCode: string) => {
|
||||||
|
const headers = new Headers()
|
||||||
|
headers.append('X-App-Code', appCode)
|
||||||
|
return get('/passport', { headers }) as Promise<{ access_token: string }>
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user