mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 12:56:01 +08:00
Fix: rerank switch and validation before run (#9416)
This commit is contained in:
parent
4ac99ffe0e
commit
8a1f106c72
@ -63,7 +63,7 @@ const ConfigContent: FC<Props> = ({
|
|||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||||
|
|
||||||
const {
|
const {
|
||||||
currentModel,
|
currentModel: currentRerankModel,
|
||||||
} = useCurrentProviderAndModel(
|
} = useCurrentProviderAndModel(
|
||||||
rerankModelList,
|
rerankModelList,
|
||||||
rerankDefaultModel
|
rerankDefaultModel
|
||||||
@ -74,11 +74,6 @@ const ConfigContent: FC<Props> = ({
|
|||||||
: undefined,
|
: undefined,
|
||||||
)
|
)
|
||||||
|
|
||||||
const handleDisabledSwitchClick = useCallback(() => {
|
|
||||||
if (!currentModel)
|
|
||||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
|
||||||
}, [currentModel, rerankDefaultModel, t])
|
|
||||||
|
|
||||||
const rerankModel = (() => {
|
const rerankModel = (() => {
|
||||||
if (datasetConfigs.reranking_model?.reranking_provider_name) {
|
if (datasetConfigs.reranking_model?.reranking_provider_name) {
|
||||||
return {
|
return {
|
||||||
@ -164,12 +159,33 @@ const ConfigContent: FC<Props> = ({
|
|||||||
const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights
|
const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights
|
||||||
const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel
|
const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel
|
||||||
|
|
||||||
|
const canManuallyToggleRerank = useMemo(() => {
|
||||||
|
return !(
|
||||||
|
(selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|
||||||
|
|| selectedDatasetsMode.allExternal
|
||||||
|
)
|
||||||
|
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
|
||||||
|
|
||||||
const showRerankModel = useMemo(() => {
|
const showRerankModel = useMemo(() => {
|
||||||
if (datasetConfigs.reranking_enable === false && selectedDatasetsMode.allEconomic)
|
if (!canManuallyToggleRerank)
|
||||||
return false
|
return false
|
||||||
|
|
||||||
return true
|
return datasetConfigs.reranking_enable
|
||||||
}, [datasetConfigs.reranking_enable, selectedDatasetsMode.allEconomic])
|
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable])
|
||||||
|
|
||||||
|
const handleDisabledSwitchClick = useCallback(() => {
|
||||||
|
if (!currentRerankModel && !showRerankModel)
|
||||||
|
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||||
|
}, [currentRerankModel, showRerankModel, t])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
|
||||||
|
onChange({
|
||||||
|
...datasetConfigs,
|
||||||
|
reranking_enable: showRerankModel,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
@ -256,13 +272,15 @@ const ConfigContent: FC<Props> = ({
|
|||||||
>
|
>
|
||||||
<Switch
|
<Switch
|
||||||
size='md'
|
size='md'
|
||||||
defaultValue={currentModel ? showRerankModel : false}
|
defaultValue={showRerankModel}
|
||||||
disabled={!currentModel}
|
disabled={!currentRerankModel || !canManuallyToggleRerank}
|
||||||
onChange={(v) => {
|
onChange={(v) => {
|
||||||
onChange({
|
if (canManuallyToggleRerank) {
|
||||||
...datasetConfigs,
|
onChange({
|
||||||
reranking_enable: v,
|
...datasetConfigs,
|
||||||
})
|
reranking_enable: v,
|
||||||
|
})
|
||||||
|
}
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
@ -42,6 +42,7 @@ const ParamsConfig = ({
|
|||||||
allHighQuality,
|
allHighQuality,
|
||||||
allHighQualityFullTextSearch,
|
allHighQualityFullTextSearch,
|
||||||
allHighQualityVectorSearch,
|
allHighQualityVectorSearch,
|
||||||
|
allInternal,
|
||||||
allExternal,
|
allExternal,
|
||||||
mixtureHighQualityAndEconomic,
|
mixtureHighQualityAndEconomic,
|
||||||
inconsistentEmbeddingModel,
|
inconsistentEmbeddingModel,
|
||||||
@ -50,7 +51,7 @@ const ParamsConfig = ({
|
|||||||
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
|
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
|
||||||
let rerankEnable = restConfigs.reranking_enable
|
let rerankEnable = restConfigs.reranking_enable
|
||||||
|
|
||||||
if ((allEconomic && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) || allExternal)
|
if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
|
||||||
rerankEnable = false
|
rerankEnable = false
|
||||||
|
|
||||||
if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))
|
if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))
|
||||||
|
@ -1,25 +1,17 @@
|
|||||||
import { useCallback } from 'react'
|
import { useCallback } from 'react'
|
||||||
import { useStoreApi } from 'reactflow'
|
import { useStoreApi } from 'reactflow'
|
||||||
import { useTranslation } from 'react-i18next'
|
|
||||||
import { useWorkflowStore } from '../store'
|
import { useWorkflowStore } from '../store'
|
||||||
import {
|
import {
|
||||||
BlockEnum,
|
BlockEnum,
|
||||||
WorkflowRunningStatus,
|
WorkflowRunningStatus,
|
||||||
} from '../types'
|
} from '../types'
|
||||||
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
|
|
||||||
import type { Node } from '../types'
|
|
||||||
import { useWorkflow } from './use-workflow'
|
|
||||||
import {
|
import {
|
||||||
useIsChatMode,
|
useIsChatMode,
|
||||||
useNodesSyncDraft,
|
useNodesSyncDraft,
|
||||||
useWorkflowInteractions,
|
useWorkflowInteractions,
|
||||||
useWorkflowRun,
|
useWorkflowRun,
|
||||||
} from './index'
|
} from './index'
|
||||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
|
||||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
|
||||||
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
||||||
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
|
|
||||||
import Toast from '@/app/components/base/toast'
|
|
||||||
|
|
||||||
export const useWorkflowStartRun = () => {
|
export const useWorkflowStartRun = () => {
|
||||||
const store = useStoreApi()
|
const store = useStoreApi()
|
||||||
@ -28,26 +20,7 @@ export const useWorkflowStartRun = () => {
|
|||||||
const isChatMode = useIsChatMode()
|
const isChatMode = useIsChatMode()
|
||||||
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
|
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
|
||||||
const { handleRun } = useWorkflowRun()
|
const { handleRun } = useWorkflowRun()
|
||||||
const { isFromStartNode } = useWorkflow()
|
|
||||||
const { doSyncWorkflowDraft } = useNodesSyncDraft()
|
const { doSyncWorkflowDraft } = useNodesSyncDraft()
|
||||||
const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
|
|
||||||
const { t } = useTranslation()
|
|
||||||
const {
|
|
||||||
modelList: rerankModelList,
|
|
||||||
defaultModel: rerankDefaultModel,
|
|
||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
|
||||||
|
|
||||||
const {
|
|
||||||
currentModel,
|
|
||||||
} = useCurrentProviderAndModel(
|
|
||||||
rerankModelList,
|
|
||||||
rerankDefaultModel
|
|
||||||
? {
|
|
||||||
...rerankDefaultModel,
|
|
||||||
provider: rerankDefaultModel.provider.provider,
|
|
||||||
}
|
|
||||||
: undefined,
|
|
||||||
)
|
|
||||||
|
|
||||||
const handleWorkflowStartRunInWorkflow = useCallback(async () => {
|
const handleWorkflowStartRunInWorkflow = useCallback(async () => {
|
||||||
const {
|
const {
|
||||||
@ -60,9 +33,6 @@ export const useWorkflowStartRun = () => {
|
|||||||
const { getNodes } = store.getState()
|
const { getNodes } = store.getState()
|
||||||
const nodes = getNodes()
|
const nodes = getNodes()
|
||||||
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
|
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
|
||||||
const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
|
|
||||||
node.data.type === BlockEnum.KnowledgeRetrieval,
|
|
||||||
)
|
|
||||||
const startVariables = startNode?.data.variables || []
|
const startVariables = startNode?.data.variables || []
|
||||||
const fileSettings = featuresStore!.getState().features.file
|
const fileSettings = featuresStore!.getState().features.file
|
||||||
const {
|
const {
|
||||||
@ -72,31 +42,6 @@ export const useWorkflowStartRun = () => {
|
|||||||
setShowEnvPanel,
|
setShowEnvPanel,
|
||||||
} = workflowStore.getState()
|
} = workflowStore.getState()
|
||||||
|
|
||||||
if (knowledgeRetrievalNodes.length > 0) {
|
|
||||||
for (const node of knowledgeRetrievalNodes) {
|
|
||||||
if (isFromStartNode(node.id)) {
|
|
||||||
const res = checkKnowledgeRetrievalValid(node.data, t)
|
|
||||||
if (!res.isValid || !currentModel || !rerankDefaultModel) {
|
|
||||||
const errorMessage = res.errorMessage
|
|
||||||
if (errorMessage) {
|
|
||||||
Toast.notify({
|
|
||||||
type: 'error',
|
|
||||||
message: errorMessage,
|
|
||||||
})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
Toast.notify({
|
|
||||||
type: 'error',
|
|
||||||
message: t('appDebug.datasetConfig.rerankModelRequired'),
|
|
||||||
})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
setShowEnvPanel(false)
|
setShowEnvPanel(false)
|
||||||
|
|
||||||
if (showDebugAndPreviewPanel) {
|
if (showDebugAndPreviewPanel) {
|
||||||
|
@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
|
|||||||
import { fetchDatasets } from '@/service/datasets'
|
import { fetchDatasets } from '@/service/datasets'
|
||||||
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
||||||
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
||||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
|
|
||||||
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||||
@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||||||
const startNodeId = startNode?.id
|
const startNodeId = startNode?.id
|
||||||
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
||||||
|
|
||||||
|
const inputRef = useRef(inputs)
|
||||||
|
|
||||||
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
||||||
const newInputs = produce(s, (draft) => {
|
const newInputs = produce(s, (draft) => {
|
||||||
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
||||||
@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||||||
})
|
})
|
||||||
// not work in pass to draft...
|
// not work in pass to draft...
|
||||||
doSetInputs(newInputs)
|
doSetInputs(newInputs)
|
||||||
|
inputRef.current = newInputs
|
||||||
}, [doSetInputs])
|
}, [doSetInputs])
|
||||||
|
|
||||||
const inputRef = useRef(inputs)
|
|
||||||
useEffect(() => {
|
|
||||||
inputRef.current = inputs
|
|
||||||
}, [inputs])
|
|
||||||
|
|
||||||
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
||||||
const newInputs = produce(inputs, (draft) => {
|
const newInputs = produce(inputs, (draft) => {
|
||||||
draft.query_variable_selector = newVar as ValueSelector
|
draft.query_variable_selector = newVar as ValueSelector
|
||||||
@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
||||||
|
|
||||||
const {
|
const {
|
||||||
|
modelList: rerankModelList,
|
||||||
defaultModel: rerankDefaultModel,
|
defaultModel: rerankDefaultModel,
|
||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||||
|
|
||||||
|
const {
|
||||||
|
currentModel: currentRerankModel,
|
||||||
|
} = useCurrentProviderAndModel(
|
||||||
|
rerankModelList,
|
||||||
|
rerankDefaultModel
|
||||||
|
? {
|
||||||
|
...rerankDefaultModel,
|
||||||
|
provider: rerankDefaultModel.provider.provider,
|
||||||
|
}
|
||||||
|
: undefined,
|
||||||
|
)
|
||||||
|
|
||||||
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
||||||
const newInputs = produce(inputRef.current, (draft) => {
|
const newInputs = produce(inputRef.current, (draft) => {
|
||||||
if (!draft.single_retrieval_config) {
|
if (!draft.single_retrieval_config) {
|
||||||
@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||||||
// set defaults models
|
// set defaults models
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const inputs = inputRef.current
|
const inputs = inputRef.current
|
||||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
|
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
|
||||||
return
|
return
|
||||||
|
|
||||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
||||||
@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||||
draft.multiple_retrieval_config = {
|
draft.multiple_retrieval_config = {
|
||||||
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
||||||
@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||||||
reranking_model: multipleRetrievalConfig?.reranking_model,
|
reranking_model: multipleRetrievalConfig?.reranking_model,
|
||||||
reranking_mode: multipleRetrievalConfig?.reranking_mode,
|
reranking_mode: multipleRetrievalConfig?.reranking_mode,
|
||||||
weights: multipleRetrievalConfig?.weights,
|
weights: multipleRetrievalConfig?.weights,
|
||||||
|
reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
|
||||||
|
? multipleRetrievalConfig.reranking_enable
|
||||||
|
: Boolean(currentRerankModel && rerankDefaultModel),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
setInputs(newInput)
|
setInputs(newInput)
|
||||||
@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
const inputs = inputRef.current
|
||||||
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
||||||
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
||||||
query_variable_selector = [startNodeId, 'sys.query']
|
query_variable_selector = [startNodeId, 'sys.query']
|
||||||
|
|
||||||
setInputs({
|
setInputs(produce(inputs, (draft) => {
|
||||||
...inputs,
|
draft.query_variable_selector = query_variable_selector
|
||||||
query_variable_selector,
|
}))
|
||||||
})
|
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr
|
|||||||
reranking_mode,
|
reranking_mode,
|
||||||
reranking_model,
|
reranking_model,
|
||||||
weights,
|
weights,
|
||||||
reranking_enable: allEconomic ? reranking_enable : true,
|
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)
|
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user