Fix/rerank validation issue (#10131)

Co-authored-by: Yi <yxiaoisme@gmail.com>
This commit is contained in:
zxhlyh 2024-10-31 20:20:46 +08:00 committed by GitHub
parent ce260f79d2
commit 2ecdc54b0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 34 additions and 6 deletions

View File

@ -15,6 +15,7 @@ import { AppType } from '@/types/app'
import type { DataSet } from '@/models/datasets' import type { DataSet } from '@/models/datasets'
import { import {
getMultipleRetrievalConfig, getMultipleRetrievalConfig,
getSelectedDatasetsMode,
} from '@/app/components/workflow/nodes/knowledge-retrieval/utils' } from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
@ -38,6 +39,7 @@ const DatasetConfig: FC = () => {
isAgent, isAgent,
datasetConfigs, datasetConfigs,
setDatasetConfigs, setDatasetConfigs,
setRerankSettingModalOpen,
} = useContext(ConfigContext) } = useContext(ConfigContext)
const formattingChangedDispatcher = useFormattingChangedDispatcher() const formattingChangedDispatcher = useFormattingChangedDispatcher()
@ -55,6 +57,20 @@ const DatasetConfig: FC = () => {
...(datasetConfigs as any), ...(datasetConfigs as any),
...retrievalConfig, ...retrievalConfig,
}) })
const {
allExternal,
allInternal,
mixtureInternalAndExternal,
mixtureHighQualityAndEconomic,
inconsistentEmbeddingModel,
} = getSelectedDatasetsMode(filteredDataSets)
if (
(allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
|| mixtureInternalAndExternal
|| allExternal
)
setRerankSettingModalOpen(true)
formattingChangedDispatcher() formattingChangedDispatcher()
} }

View File

@ -266,7 +266,7 @@ const ConfigContent: FC<Props> = ({
<div className='mt-2'> <div className='mt-2'>
<div className='flex items-center'> <div className='flex items-center'>
{ {
selectedDatasetsMode.allEconomic && ( selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
<div <div
className='flex items-center' className='flex items-center'
onClick={handleDisabledSwitchClick} onClick={handleDisabledSwitchClick}

View File

@ -12,6 +12,7 @@ import { RETRIEVE_TYPE } from '@/types/app'
import Toast from '@/app/components/base/toast' import Toast from '@/app/components/base/toast'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { RerankingModeEnum } from '@/models/datasets'
import type { DataSet } from '@/models/datasets' import type { DataSet } from '@/models/datasets'
import type { DatasetConfigs } from '@/models/debug' import type { DatasetConfigs } from '@/models/debug'
import { import {
@ -47,7 +48,10 @@ const ParamsConfig = ({
const isValid = () => { const isValid = () => {
let errMsg = '' let errMsg = ''
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
if (!tempDataSetConfigs.reranking_model?.reranking_model_name && (rerankDefaultModel && !isRerankDefaultModelValid)) if (tempDataSetConfigs.reranking_enable
&& tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
&& !isRerankDefaultModelValid
)
errMsg = t('appDebug.datasetConfig.rerankModelRequired') errMsg = t('appDebug.datasetConfig.rerankModelRequired')
} }
if (errMsg) { if (errMsg) {
@ -62,7 +66,9 @@ const ParamsConfig = ({
if (!isValid()) if (!isValid())
return return
const config = { ...tempDataSetConfigs } const config = { ...tempDataSetConfigs }
if (config.retrieval_model === RETRIEVE_TYPE.multiWay && !config.reranking_model) { if (config.retrieval_model === RETRIEVE_TYPE.multiWay
&& config.reranking_mode === RerankingModeEnum.RerankingModel
&& !config.reranking_model) {
config.reranking_model = { config.reranking_model = {
reranking_provider_name: rerankDefaultModel?.provider?.provider, reranking_provider_name: rerankDefaultModel?.provider?.provider,
reranking_model_name: rerankDefaultModel?.model, reranking_model_name: rerankDefaultModel?.model,

View File

@ -253,12 +253,18 @@ const Configuration: FC = () => {
} }
hideSelectDataSet() hideSelectDataSet()
const { const {
allEconomic, allExternal,
allInternal,
mixtureInternalAndExternal,
mixtureHighQualityAndEconomic, mixtureHighQualityAndEconomic,
inconsistentEmbeddingModel, inconsistentEmbeddingModel,
} = getSelectedDatasetsMode(newDatasets) } = getSelectedDatasetsMode(newDatasets)
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel) if (
(allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
|| mixtureInternalAndExternal
|| allExternal
)
setRerankSettingModalOpen(true) setRerankSettingModalOpen(true)
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs

View File

@ -240,7 +240,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
if ( if (
(allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)) (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
|| mixtureInternalAndExternal || mixtureInternalAndExternal
|| (allExternal && newDatasets.length > 1) || allExternal
) )
setRerankModelOpen(true) setRerankModelOpen(true)
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel]) }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])