Feat: rerank model verification in front end (#9271)

This commit is contained in:
Yi Xiao 2024-10-12 21:24:43 +08:00 committed by GitHub
parent c6b74daa0a
commit 793205afc5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 159 additions and 24 deletions

View File

@ -1,6 +1,6 @@
'use client' 'use client'
import { memo, useEffect, useMemo } from 'react' import { memo, useCallback, useEffect, useMemo } from 'react'
import type { FC } from 'react' import type { FC } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import WeightedScore from './weighted-score' import WeightedScore from './weighted-score'
@ -11,7 +11,7 @@ import type {
DatasetConfigs, DatasetConfigs,
} from '@/models/debug' } from '@/models/debug'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import type { ModelConfig } from '@/app/components/workflow/types' import type { ModelConfig } from '@/app/components/workflow/types'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
@ -23,6 +23,7 @@ import { RerankingModeEnum } from '@/models/datasets'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks' import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks'
import Switch from '@/app/components/base/switch' import Switch from '@/app/components/base/switch'
import Toast from '@/app/components/base/toast'
type Props = { type Props = {
datasetConfigs: DatasetConfigs datasetConfigs: DatasetConfigs
@ -60,6 +61,24 @@ const ConfigContent: FC<Props> = ({
modelList: rerankModelList, modelList: rerankModelList,
defaultModel: rerankDefaultModel, defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
)
const handleDisabledSwitchClick = useCallback(() => {
if (!currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentModel, rerankDefaultModel, t])
const rerankModel = (() => { const rerankModel = (() => {
if (datasetConfigs.reranking_model?.reranking_provider_name) { if (datasetConfigs.reranking_model?.reranking_provider_name) {
return { return {
@ -231,9 +250,14 @@ const ConfigContent: FC<Props> = ({
<div className='flex items-center'> <div className='flex items-center'>
{ {
selectedDatasetsMode.allEconomic && ( selectedDatasetsMode.allEconomic && (
<div
className='flex items-center'
onClick={handleDisabledSwitchClick}
>
<Switch <Switch
size='md' size='md'
defaultValue={showRerankModel} defaultValue={currentModel ? showRerankModel : false}
disabled={!currentModel}
onChange={(v) => { onChange={(v) => {
onChange({ onChange({
...datasetConfigs, ...datasetConfigs,
@ -241,6 +265,7 @@ const ConfigContent: FC<Props> = ({
}) })
}} }}
/> />
</div>
) )
} }
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div> <div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>

View File

@ -1,6 +1,6 @@
'use client' 'use client'
import type { FC } from 'react' import type { FC } from 'react'
import React from 'react' import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
@ -11,7 +11,7 @@ import Switch from '@/app/components/base/switch'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
import type { RetrievalConfig } from '@/types/app' import type { RetrievalConfig } from '@/types/app'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { import {
DEFAULT_WEIGHTED_SCORE, DEFAULT_WEIGHTED_SCORE,
@ -19,6 +19,7 @@ import {
WeightedScoreEnum, WeightedScoreEnum,
} from '@/models/datasets' } from '@/models/datasets'
import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score'
import Toast from '@/app/components/base/toast'
type Props = { type Props = {
type: RETRIEVE_METHOD type: RETRIEVE_METHOD
@ -38,6 +39,24 @@ const RetrievalParamConfig: FC<Props> = ({
defaultModel: rerankDefaultModel, defaultModel: rerankDefaultModel,
modelList: rerankModelList, modelList: rerankModelList,
} = useModelListAndDefaultModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModel(ModelTypeEnum.rerank)
const {
currentModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
)
const handleDisabledSwitchClick = useCallback(() => {
if (!currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentModel, rerankDefaultModel, t])
const isHybridSearch = type === RETRIEVE_METHOD.hybrid const isHybridSearch = type === RETRIEVE_METHOD.hybrid
const rerankModel = (() => { const rerankModel = (() => {
@ -99,16 +118,22 @@ const RetrievalParamConfig: FC<Props> = ({
<div> <div>
<div className='flex h-8 items-center text-[13px] font-medium text-gray-900 space-x-2'> <div className='flex h-8 items-center text-[13px] font-medium text-gray-900 space-x-2'>
{canToggleRerankModalEnable && ( {canToggleRerankModalEnable && (
<div
className='flex items-center'
onClick={handleDisabledSwitchClick}
>
<Switch <Switch
size='md' size='md'
defaultValue={value.reranking_enable} defaultValue={currentModel ? value.reranking_enable : false}
onChange={(v) => { onChange={(v) => {
onChange({ onChange({
...value, ...value,
reranking_enable: v, reranking_enable: v,
}) })
}} }}
disabled={!currentModel}
/> />
</div>
)} )}
<div className='flex items-center'> <div className='flex items-center'>
<span className='mr-0.5'>{t('common.modelProvider.rerankModel.key')}</span> <span className='mr-0.5'>{t('common.modelProvider.rerankModel.key')}</span>

View File

@ -1,17 +1,25 @@
import { useCallback } from 'react' import { useCallback } from 'react'
import { useStoreApi } from 'reactflow' import { useStoreApi } from 'reactflow'
import { useTranslation } from 'react-i18next'
import { useWorkflowStore } from '../store' import { useWorkflowStore } from '../store'
import { import {
BlockEnum, BlockEnum,
WorkflowRunningStatus, WorkflowRunningStatus,
} from '../types' } from '../types'
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
import type { Node } from '../types'
import { useWorkflow } from './use-workflow'
import { import {
useIsChatMode, useIsChatMode,
useNodesSyncDraft, useNodesSyncDraft,
useWorkflowInteractions, useWorkflowInteractions,
useWorkflowRun, useWorkflowRun,
} from './index' } from './index'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useFeaturesStore } from '@/app/components/base/features/hooks' import { useFeaturesStore } from '@/app/components/base/features/hooks'
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
import Toast from '@/app/components/base/toast'
export const useWorkflowStartRun = () => { export const useWorkflowStartRun = () => {
const store = useStoreApi() const store = useStoreApi()
@ -20,7 +28,26 @@ export const useWorkflowStartRun = () => {
const isChatMode = useIsChatMode() const isChatMode = useIsChatMode()
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
const { handleRun } = useWorkflowRun() const { handleRun } = useWorkflowRun()
const { isFromStartNode } = useWorkflow()
const { doSyncWorkflowDraft } = useNodesSyncDraft() const { doSyncWorkflowDraft } = useNodesSyncDraft()
const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
const { t } = useTranslation()
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
)
const handleWorkflowStartRunInWorkflow = useCallback(async () => { const handleWorkflowStartRunInWorkflow = useCallback(async () => {
const { const {
@ -33,6 +60,9 @@ export const useWorkflowStartRun = () => {
const { getNodes } = store.getState() const { getNodes } = store.getState()
const nodes = getNodes() const nodes = getNodes()
const startNode = nodes.find(node => node.data.type === BlockEnum.Start) const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
node.data.type === BlockEnum.KnowledgeRetrieval,
)
const startVariables = startNode?.data.variables || [] const startVariables = startNode?.data.variables || []
const fileSettings = featuresStore!.getState().features.file const fileSettings = featuresStore!.getState().features.file
const { const {
@ -42,6 +72,31 @@ export const useWorkflowStartRun = () => {
setShowEnvPanel, setShowEnvPanel,
} = workflowStore.getState() } = workflowStore.getState()
if (knowledgeRetrievalNodes.length > 0) {
for (const node of knowledgeRetrievalNodes) {
if (isFromStartNode(node.id)) {
const res = checkKnowledgeRetrievalValid(node.data, t)
if (!res.isValid || !currentModel || !rerankDefaultModel) {
const errorMessage = res.errorMessage
if (errorMessage) {
Toast.notify({
type: 'error',
message: errorMessage,
})
return false
}
else {
Toast.notify({
type: 'error',
message: t('appDebug.datasetConfig.rerankModelRequired'),
})
return false
}
}
}
}
}
setShowEnvPanel(false) setShowEnvPanel(false)
if (showDebugAndPreviewPanel) { if (showDebugAndPreviewPanel) {

View File

@ -235,6 +235,33 @@ export const useWorkflow = () => {
return nodes.filter(node => node.parentId === nodeId) return nodes.filter(node => node.parentId === nodeId)
}, [store]) }, [store])
const isFromStartNode = useCallback((nodeId: string) => {
const { getNodes } = store.getState()
const nodes = getNodes()
const currentNode = nodes.find(node => node.id === nodeId)
if (!currentNode)
return false
if (currentNode.data.type === BlockEnum.Start)
return true
const checkPreviousNodes = (node: Node) => {
const previousNodes = getBeforeNodeById(node.id)
for (const prevNode of previousNodes) {
if (prevNode.data.type === BlockEnum.Start)
return true
if (checkPreviousNodes(prevNode))
return true
}
return false
}
return checkPreviousNodes(currentNode)
}, [store, getBeforeNodeById])
const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => { const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => {
const { getNodes, setNodes } = store.getState() const { getNodes, setNodes } = store.getState()
const afterNodes = getAfterNodesInSameBranch(nodeId) const afterNodes = getAfterNodesInSameBranch(nodeId)
@ -389,6 +416,7 @@ export const useWorkflow = () => {
checkParallelLimit, checkParallelLimit,
checkNestedParallelLimit, checkNestedParallelLimit,
isValidConnection, isValidConnection,
isFromStartNode,
formatTimeFromNow, formatTimeFromNow,
getNode, getNode,
getBeforeNodeById, getBeforeNodeById,

View File

@ -172,6 +172,7 @@ const translation = {
}, },
errorMsg: { errorMsg: {
fieldRequired: '{{field}} is required', fieldRequired: '{{field}} is required',
rerankModelRequired: 'Before turning on the Rerank Model, please confirm that the model has been successfully configured in the settings.',
authRequired: 'Authorization is required', authRequired: 'Authorization is required',
invalidJson: '{{field}} is invalid JSON', invalidJson: '{{field}} is invalid JSON',
fields: { fields: {

View File

@ -172,6 +172,7 @@ const translation = {
}, },
errorMsg: { errorMsg: {
fieldRequired: '{{field}} 不能为空', fieldRequired: '{{field}} 不能为空',
rerankModelRequired: '开启 Rerank 模型前,请务必确认模型已在设置中成功配置。',
authRequired: '请先授权', authRequired: '请先授权',
invalidJson: '{{field}} 是非法的 JSON', invalidJson: '{{field}} 是非法的 JSON',
fields: { fields: {