Fix: rerank switch and validation before run (#9416)

This commit is contained in:
Yi Xiao 2024-10-17 14:26:38 +08:00 committed by GitHub
parent 4ac99ffe0e
commit 8a1f106c72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 61 additions and 84 deletions

View File

@ -63,7 +63,7 @@ const ConfigContent: FC<Props> = ({
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const { const {
currentModel, currentModel: currentRerankModel,
} = useCurrentProviderAndModel( } = useCurrentProviderAndModel(
rerankModelList, rerankModelList,
rerankDefaultModel rerankDefaultModel
@ -74,11 +74,6 @@ const ConfigContent: FC<Props> = ({
: undefined, : undefined,
) )
const handleDisabledSwitchClick = useCallback(() => {
if (!currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentModel, rerankDefaultModel, t])
const rerankModel = (() => { const rerankModel = (() => {
if (datasetConfigs.reranking_model?.reranking_provider_name) { if (datasetConfigs.reranking_model?.reranking_provider_name) {
return { return {
@ -164,12 +159,33 @@ const ConfigContent: FC<Props> = ({
const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights
const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel
const canManuallyToggleRerank = useMemo(() => {
return !(
(selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|| selectedDatasetsMode.allExternal
)
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
const showRerankModel = useMemo(() => { const showRerankModel = useMemo(() => {
if (datasetConfigs.reranking_enable === false && selectedDatasetsMode.allEconomic) if (!canManuallyToggleRerank)
return false return false
return true return datasetConfigs.reranking_enable
}, [datasetConfigs.reranking_enable, selectedDatasetsMode.allEconomic]) }, [canManuallyToggleRerank, datasetConfigs.reranking_enable])
const handleDisabledSwitchClick = useCallback(() => {
if (!currentRerankModel && !showRerankModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentRerankModel, showRerankModel, t])
useEffect(() => {
if (!canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
onChange({
...datasetConfigs,
reranking_enable: showRerankModel,
})
}
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
return ( return (
<div> <div>
@ -256,13 +272,15 @@ const ConfigContent: FC<Props> = ({
> >
<Switch <Switch
size='md' size='md'
defaultValue={currentModel ? showRerankModel : false} defaultValue={showRerankModel}
disabled={!currentModel} disabled={!currentRerankModel || !canManuallyToggleRerank}
onChange={(v) => { onChange={(v) => {
onChange({ if (canManuallyToggleRerank) {
...datasetConfigs, onChange({
reranking_enable: v, ...datasetConfigs,
}) reranking_enable: v,
})
}
}} }}
/> />
</div> </div>

View File

@ -42,6 +42,7 @@ const ParamsConfig = ({
allHighQuality, allHighQuality,
allHighQualityFullTextSearch, allHighQualityFullTextSearch,
allHighQualityVectorSearch, allHighQualityVectorSearch,
allInternal,
allExternal, allExternal,
mixtureHighQualityAndEconomic, mixtureHighQualityAndEconomic,
inconsistentEmbeddingModel, inconsistentEmbeddingModel,
@ -50,7 +51,7 @@ const ParamsConfig = ({
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
let rerankEnable = restConfigs.reranking_enable let rerankEnable = restConfigs.reranking_enable
if ((allEconomic && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) || allExternal) if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
rerankEnable = false rerankEnable = false
if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1)) if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))

View File

@ -1,25 +1,17 @@
import { useCallback } from 'react' import { useCallback } from 'react'
import { useStoreApi } from 'reactflow' import { useStoreApi } from 'reactflow'
import { useTranslation } from 'react-i18next'
import { useWorkflowStore } from '../store' import { useWorkflowStore } from '../store'
import { import {
BlockEnum, BlockEnum,
WorkflowRunningStatus, WorkflowRunningStatus,
} from '../types' } from '../types'
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
import type { Node } from '../types'
import { useWorkflow } from './use-workflow'
import { import {
useIsChatMode, useIsChatMode,
useNodesSyncDraft, useNodesSyncDraft,
useWorkflowInteractions, useWorkflowInteractions,
useWorkflowRun, useWorkflowRun,
} from './index' } from './index'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useFeaturesStore } from '@/app/components/base/features/hooks' import { useFeaturesStore } from '@/app/components/base/features/hooks'
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
import Toast from '@/app/components/base/toast'
export const useWorkflowStartRun = () => { export const useWorkflowStartRun = () => {
const store = useStoreApi() const store = useStoreApi()
@ -28,26 +20,7 @@ export const useWorkflowStartRun = () => {
const isChatMode = useIsChatMode() const isChatMode = useIsChatMode()
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
const { handleRun } = useWorkflowRun() const { handleRun } = useWorkflowRun()
const { isFromStartNode } = useWorkflow()
const { doSyncWorkflowDraft } = useNodesSyncDraft() const { doSyncWorkflowDraft } = useNodesSyncDraft()
const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
const { t } = useTranslation()
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
)
const handleWorkflowStartRunInWorkflow = useCallback(async () => { const handleWorkflowStartRunInWorkflow = useCallback(async () => {
const { const {
@ -60,9 +33,6 @@ export const useWorkflowStartRun = () => {
const { getNodes } = store.getState() const { getNodes } = store.getState()
const nodes = getNodes() const nodes = getNodes()
const startNode = nodes.find(node => node.data.type === BlockEnum.Start) const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
node.data.type === BlockEnum.KnowledgeRetrieval,
)
const startVariables = startNode?.data.variables || [] const startVariables = startNode?.data.variables || []
const fileSettings = featuresStore!.getState().features.file const fileSettings = featuresStore!.getState().features.file
const { const {
@ -72,31 +42,6 @@ export const useWorkflowStartRun = () => {
setShowEnvPanel, setShowEnvPanel,
} = workflowStore.getState() } = workflowStore.getState()
if (knowledgeRetrievalNodes.length > 0) {
for (const node of knowledgeRetrievalNodes) {
if (isFromStartNode(node.id)) {
const res = checkKnowledgeRetrievalValid(node.data, t)
if (!res.isValid || !currentModel || !rerankDefaultModel) {
const errorMessage = res.errorMessage
if (errorMessage) {
Toast.notify({
type: 'error',
message: errorMessage,
})
return false
}
else {
Toast.notify({
type: 'error',
message: t('appDebug.datasetConfig.rerankModelRequired'),
})
return false
}
}
}
}
}
setShowEnvPanel(false) setShowEnvPanel(false)
if (showDebugAndPreviewPanel) { if (showDebugAndPreviewPanel) {

View File

@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets' import { fetchDatasets } from '@/service/datasets'
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const startNodeId = startNode?.id const startNodeId = startNode?.id
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload) const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
const inputRef = useRef(inputs)
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => { const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
const newInputs = produce(s, (draft) => { const newInputs = produce(s, (draft) => {
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay) if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
}) })
// not work in pass to draft... // not work in pass to draft...
doSetInputs(newInputs) doSetInputs(newInputs)
inputRef.current = newInputs
}, [doSetInputs]) }, [doSetInputs])
const inputRef = useRef(inputs)
useEffect(() => {
inputRef.current = inputs
}, [inputs])
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => { const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
draft.query_variable_selector = newVar as ValueSelector draft.query_variable_selector = newVar as ValueSelector
@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
const { const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel, defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel: currentRerankModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
)
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => { const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
const newInputs = produce(inputRef.current, (draft) => { const newInputs = produce(inputRef.current, (draft) => {
if (!draft.single_retrieval_config) { if (!draft.single_retrieval_config) {
@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
// set defaults models // set defaults models
useEffect(() => { useEffect(() => {
const inputs = inputRef.current const inputs = inputRef.current
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider) if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
return return
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider) if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} }
} }
} }
const multipleRetrievalConfig = draft.multiple_retrieval_config const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = { draft.multiple_retrieval_config = {
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k, top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
reranking_model: multipleRetrievalConfig?.reranking_model, reranking_model: multipleRetrievalConfig?.reranking_model,
reranking_mode: multipleRetrievalConfig?.reranking_mode, reranking_mode: multipleRetrievalConfig?.reranking_mode,
weights: multipleRetrievalConfig?.weights, weights: multipleRetrievalConfig?.weights,
reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
? multipleRetrievalConfig.reranking_enable
: Boolean(currentRerankModel && rerankDefaultModel),
} }
}) })
setInputs(newInput) setInputs(newInput)
@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
}, []) }, [])
useEffect(() => { useEffect(() => {
const inputs = inputRef.current
let query_variable_selector: ValueSelector = inputs.query_variable_selector let query_variable_selector: ValueSelector = inputs.query_variable_selector
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId) if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
query_variable_selector = [startNodeId, 'sys.query'] query_variable_selector = [startNodeId, 'sys.query']
setInputs({ setInputs(produce(inputs, (draft) => {
...inputs, draft.query_variable_selector = query_variable_selector
query_variable_selector, }))
})
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []) }, [])

View File

@ -113,7 +113,7 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr
reranking_mode, reranking_mode,
reranking_model, reranking_model,
weights, weights,
reranking_enable: allEconomic ? reranking_enable : true, reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
} }
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)