diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index c98e90b18e..2c082d8815 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -13,6 +13,11 @@ import ContextVar from './context-var' import ConfigContext from '@/context/debug-configuration' import { AppType } from '@/types/app' import type { DataSet } from '@/models/datasets' +import { + getMultipleRetrievalConfig, +} from '@/app/components/workflow/nodes/knowledge-retrieval/utils' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' const Icon = ( @@ -31,13 +36,25 @@ const DatasetConfig: FC = () => { setModelConfig, showSelectDataSet, isAgent, + datasetConfigs, + setDatasetConfigs, } = useContext(ConfigContext) const formattingChangedDispatcher = useFormattingChangedDispatcher() const hasData = dataSet.length > 0 + const { + currentModel: currentRerankModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const onRemove = (id: string) => { - setDataSet(dataSet.filter(item => item.id !== id)) + const filteredDataSets = dataSet.filter(item => item.id !== id) + setDataSet(filteredDataSets) + const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel) + setDatasetConfigs({ + ...(datasetConfigs as any), + ...retrievalConfig, + }) formattingChangedDispatcher() } diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 0068a7abbf..f4c7c4ff19 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -55,7 +55,7 @@ const ConfigContent: FC = ({ retrieval_model: RETRIEVE_TYPE.multiWay, }, isInWorkflow) } - }, [type]) + }, [type, datasetConfigs, isInWorkflow, onChange]) const { modelList: rerankModelList, diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index c8c8acccd3..207d4ba81d 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -16,7 +16,6 @@ import type { DataSet } from '@/models/datasets' import type { DatasetConfigs } from '@/models/debug' import { getMultipleRetrievalConfig, - getSelectedDatasetsMode, } from '@/app/components/workflow/nodes/knowledge-retrieval/utils' type ParamsConfigProps = { @@ -37,57 +36,8 @@ const ParamsConfig = ({ const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs) useEffect(() => { - const { - allEconomic, - allHighQuality, - allHighQualityFullTextSearch, - allHighQualityVectorSearch, - allExternal, - mixtureHighQualityAndEconomic, - inconsistentEmbeddingModel, - mixtureInternalAndExternal, - } = getSelectedDatasetsMode(selectedDatasets) - - if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1)) - setRerankSettingModalOpen(false) - - if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal || (allExternal && selectedDatasets.length > 1)) - setRerankSettingModalOpen(true) - }, [selectedDatasets]) - - useEffect(() => { - const { - allEconomic, - allInternal, - allExternal, - } = getSelectedDatasetsMode(selectedDatasets) - const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs - let rerankEnable = restConfigs.reranking_enable - - if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) - rerankEnable = false - - setTempDataSetConfigs({ - ...getMultipleRetrievalConfig({ - top_k: restConfigs.top_k, - score_threshold: restConfigs.score_threshold, - reranking_model: restConfigs.reranking_model && { - provider: restConfigs.reranking_model.reranking_provider_name, - model: restConfigs.reranking_model.reranking_model_name, - }, - reranking_mode: restConfigs.reranking_mode, - weights: restConfigs.weights, - reranking_enable: rerankEnable, - }, selectedDatasets), - reranking_model: restConfigs.reranking_model && { - reranking_provider_name: restConfigs.reranking_model.reranking_provider_name, - reranking_model_name: restConfigs.reranking_model.reranking_model_name, - }, - retrieval_model, - score_threshold_enabled, - datasets, - }) - }, [selectedDatasets, datasetConfigs]) + setTempDataSetConfigs(datasetConfigs) + }, [datasetConfigs]) const { defaultModel: rerankDefaultModel, @@ -135,7 +85,7 @@ const ParamsConfig = ({ reranking_mode: restConfigs.reranking_mode, weights: restConfigs.weights, reranking_enable: restConfigs.reranking_enable, - }, selectedDatasets) + }, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid) setTempDataSetConfigs({ ...retrievalConfig, @@ -180,6 +130,7 @@ const ParamsConfig = ({
diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index fab7b238c4..434b54ab91 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -38,7 +38,7 @@ import ConfigContext from '@/context/debug-configuration' import Config from '@/app/components/app/configuration/config' import Debug from '@/app/components/app/configuration/debug' import Confirm from '@/app/components/base/confirm' -import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ToastContext } from '@/app/components/base/toast' import { fetchAppDetail, updateAppModelConfig } from '@/service/apps' import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config' @@ -53,7 +53,10 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import Drawer from '@/app/components/base/drawer' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { + useModelListAndDefaultModelAndCurrentProviderAndModel, + useTextGenerationCurrentProviderAndModelAndModelList, +} from '@/app/components/header/account-setting/model-provider-page/hooks' import { fetchCollectionList } from '@/service/tools' import { type Collection } from '@/app/components/tools/types' import { useStore as useAppStore } from '@/app/components/app/store' @@ -217,6 +220,9 @@ const Configuration: FC = () => { const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false) const selectedIds = dataSets.map(item => item.id) const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false) + const { + currentModel: currentRerankModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const handleSelect = (data: DataSet[]) => { if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) { hideSelectDataSet() @@ -263,7 +269,7 @@ const Configuration: FC = () => { reranking_mode: restConfigs.reranking_mode, weights: restConfigs.weights, reranking_enable: restConfigs.reranking_enable, - }, newDatasets) + }, newDatasets, dataSets, !!currentRerankModel) setDatasetConfigs({ ...retrievalConfig, @@ -603,9 +609,11 @@ const Configuration: FC = () => { syncToPublishedConfig(config) setPublishedConfig(config) + const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel) setDatasetConfigs({ retrieval_model: RETRIEVE_TYPE.multiWay, ...modelConfig.dataset_configs, + ...retrievalConfig, }) setHasFetchedDetail(true) }) diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts index 01c1e31ccc..d280a2d63e 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts @@ -163,7 +163,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { draft.retrieval_mode = newMode if (newMode === RETRIEVE_TYPE.multiWay) { const multipleRetrievalConfig = draft.multiple_retrieval_config - draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets) + draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) } else { const hasSetModel = draft.single_retrieval_config?.model?.provider @@ -180,14 +180,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } }) setInputs(newInputs) - }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets]) + }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel]) const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { const newInputs = produce(inputs, (draft) => { - draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets) + draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) }) setInputs(newInputs) - }, [inputs, setInputs, selectedDatasets]) + }, [inputs, setInputs, selectedDatasets, currentRerankModel]) // datasets useEffect(() => { @@ -231,7 +231,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { const multipleRetrievalConfig = draft.multiple_retrieval_config - draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets) + draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel) } }) setInputs(newInputs) @@ -243,7 +243,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { || (allExternal && newDatasets.length > 1) ) setRerankModelOpen(true) - }, [inputs, setInputs, payload.retrieval_mode]) + }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel]) const filterVar = useCallback((varPayload: Var) => { return varPayload.type === VarType.string diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts index e48777d948..fd3d3ebab9 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts @@ -1,4 +1,7 @@ -import { uniq } from 'lodash-es' +import { + uniq, + xorBy, +} from 'lodash-es' import type { MultipleRetrievalConfig } from './types' import type { DataSet, @@ -15,7 +18,9 @@ export const checkNodeValid = () => { return true } -export const getSelectedDatasetsMode = (datasets: DataSet[]) => { +export const getSelectedDatasetsMode = (datasets: DataSet[] = []) => { + if (datasets === null) + datasets = [] let allHighQuality = true let allHighQualityVectorSearch = true let allHighQualityFullTextSearch = true @@ -85,7 +90,14 @@ export const getSelectedDatasetsMode = (datasets: DataSet[]) => { } as SelectedDatasetsMode } -export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetrievalConfig, selectedDatasets: DataSet[]) => { +export const getMultipleRetrievalConfig = ( + multipleRetrievalConfig: MultipleRetrievalConfig, + selectedDatasets: DataSet[], + originalDatasets: DataSet[], + isValidRerankModel?: boolean, +) => { + const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 + const { allHighQuality, allHighQualityVectorSearch, @@ -123,6 +135,37 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr result.reranking_mode = RerankingModeEnum.WeightedScore if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) { + if (!isValidRerankModel) + result.reranking_mode = RerankingModeEnum.WeightedScore + else + result.reranking_mode = RerankingModeEnum.RerankingModel + + result.weights = { + vector_setting: { + vector_weight: allHighQualityVectorSearch + ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic + : allHighQualityFullTextSearch + ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic + : DEFAULT_WEIGHTED_SCORE.other.semantic, + embedding_provider_name: selectedDatasets[0].embedding_model_provider, + embedding_model_name: selectedDatasets[0].embedding_model, + }, + keyword_setting: { + keyword_weight: allHighQualityVectorSearch + ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword + : allHighQualityFullTextSearch + ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword + : DEFAULT_WEIGHTED_SCORE.other.keyword, + }, + } + } + + if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) { + if (!isValidRerankModel) + result.reranking_mode = RerankingModeEnum.WeightedScore + else + result.reranking_mode = RerankingModeEnum.RerankingModel + result.weights = { vector_setting: { vector_weight: allHighQualityVectorSearch diff --git a/web/models/datasets.ts b/web/models/datasets.ts index 81a750968b..1ecaa3e10b 100644 --- a/web/models/datasets.ts +++ b/web/models/datasets.ts @@ -566,14 +566,6 @@ export const DEFAULT_WEIGHTED_SCORE = { semantic: 0, keyword: 1.0, }, - semanticFirst: { - semantic: 0.7, - keyword: 0.3, - }, - keywordFirst: { - semantic: 0.3, - keyword: 0.7, - }, other: { semantic: 0.7, keyword: 0.3,