diff --git a/web/app/components/app/configuration/config/index.tsx b/web/app/components/app/configuration/config/index.tsx index 6f16ace26e..13162bfce9 100644 --- a/web/app/components/app/configuration/config/index.tsx +++ b/web/app/components/app/configuration/config/index.tsx @@ -26,6 +26,7 @@ import { useModalContext } from '@/context/modal-context' import ConfigParamModal from '@/app/components/app/configuration/toolbox/annotation/config-param-modal' import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' const Config: FC = () => { const { @@ -61,8 +62,8 @@ const Config: FC = () => { setModerationConfig, } = useContext(ConfigContext) const isChatApp = mode === AppType.chat - const { data: speech2textDefaultModel } = useDefaultModel(4) - const { data: text2speechDefaultModel } = useDefaultModel(5) + const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text) + const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.tts) const { setShowModerationSettingModal } = useModalContext() const formattingChangedDispatcher = useFormattingChangedDispatcher() diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 8aecaf734f..952002a7cb 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -20,6 +20,7 @@ import { } from '@/app/components/base/icons/src/public/common' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' const ParamsConfig: FC = () => { const { t } = useTranslation() @@ -41,7 +42,7 @@ const ParamsConfig: FC = () => { modelList: rerankModelList, defaultModel: rerankDefaultModel, currentModel: isRerankDefaultModelVaild, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(3) + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const rerankModel = (() => { if (tempDataSetConfigs.reranking_model) { diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index c16605defa..945fe1b3f4 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -22,6 +22,7 @@ import { useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel, } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' type SettingsModalProps = { currentDataset: DataSet @@ -42,12 +43,12 @@ const SettingsModal: FC = ({ onCancel, onSave, }) => { - const { data: embeddingsModelList } = useModelList(2) + const { data: embeddingsModelList } = useModelList(ModelTypeEnum.textEmbedding) const { modelList: rerankModelList, defaultModel: rerankDefaultModel, currentModel: isRerankDefaultModelVaild, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(3) + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const { t } = useTranslation() const { notify } = useToastContext() const ref = useRef(null) diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index a7fd2d5ef7..45d6728097 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -31,7 +31,7 @@ import { IS_CE_EDITION } from '@/config' import type { Inputs } from '@/models/debug' import { fetchFileUploadConfig } from '@/service/common' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' -import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import type { ModelParameterModalProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import { Plus } from '@/app/components/base/icons/src/vender/line/general' import { useEventEmitterContextContext } from '@/context/event-emitter' @@ -84,7 +84,7 @@ const Debug: FC = ({ setVisionConfig, } = useContext(ConfigContext) const { eventEmitter } = useEventEmitterContextContext() - const { data: text2speechDefaultModel } = useDefaultModel(5) + const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding) const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) useEffect(() => { setAutoFreeze(false) diff --git a/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx b/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx index c4f449628e..c06c851149 100644 --- a/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx +++ b/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx @@ -11,6 +11,7 @@ import type { AnnotationReplyConfig } from '@/models/debug' import { ANNOTATION_DEFAULT } from '@/config' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' type Props = { appId: string @@ -36,7 +37,7 @@ const ConfigParamModal: FC = ({ modelList: embeddingsModelList, defaultModel: embeddingsDefaultModel, currentModel: isEmbeddingsDefaultModelValid, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(2) + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textEmbedding) const [annotationConfig, setAnnotationConfig] = useState(oldAnnotationConfig) const [isLoading, setLoading] = useState(false) diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index 9e161fb3bc..1e407b62e1 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -10,6 +10,7 @@ import { PatternRecognition, Semantic } from '@/app/components/base/icons/src/ve import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files' import { useProviderContext } from '@/context/provider-context' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' type Props = { value: RetrievalConfig @@ -22,7 +23,7 @@ const RetrievalMethodConfig: FC = ({ }) => { const { t } = useTranslation() const { supportRetrievalMethods } = useProviderContext() - const { data: rerankDefaultModel } = useDefaultModel(3) + const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank) const value = (() => { if (!passValue.reranking_model.reranking_model_name) { return { diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index dc578ae0ce..6fe896fda2 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -12,6 +12,7 @@ import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general' import type { RetrievalConfig } from '@/types/app' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' type Props = { type: RETRIEVE_METHOD @@ -30,7 +31,7 @@ const RetrievalParamConfig: FC = ({ const { defaultModel: rerankDefaultModel, modelList: rerankModelList, - } = useModelListAndDefaultModel(3) + } = useModelListAndDefaultModel(ModelTypeEnum.rerank) const rerankModel = (() => { if (value.reranking_model) { diff --git a/web/app/components/datasets/create/index.tsx b/web/app/components/datasets/create/index.tsx index 535a9a4ad2..ce32f29fc5 100644 --- a/web/app/components/datasets/create/index.tsx +++ b/web/app/components/datasets/create/index.tsx @@ -2,6 +2,7 @@ import React, { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '../../base/app-unavailable' +import { ModelTypeEnum } from '../../header/account-setting/model-provider-page/declarations' import StepsNavBar from './steps-nav-bar' import StepOne from './step-one' import StepTwo from './step-two' @@ -28,7 +29,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { const [fileList, setFiles] = useState([]) const [result, setResult] = useState() const [hasError, setHasError] = useState(false) - const { data: embeddingsDefaultModel } = useDefaultModel(2) + const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding) const [notionPages, setNotionPages] = useState([]) const updateNotionPages = (value: NotionPage[]) => { diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 55ea939292..90b798025e 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -43,6 +43,7 @@ import Tooltip from '@/app/components/base/tooltip' import TooltipPlus from '@/app/components/base/tooltip-plus' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { LanguagesSupported } from '@/i18n/language' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' type ValueOf = T[keyof T] type StepTwoProps = { @@ -275,7 +276,7 @@ const StepTwo = ({ modelList: rerankModelList, defaultModel: rerankDefaultModel, currentModel: isRerankDefaultModelVaild, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(3) + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const getCreationParams = () => { let params if (segmentationType === SegmentType.CUSTOM && overlap > max) { diff --git a/web/app/components/datasets/documents/detail/settings/index.tsx b/web/app/components/datasets/documents/detail/settings/index.tsx index 149b5e5546..891e284c95 100644 --- a/web/app/components/datasets/documents/detail/settings/index.tsx +++ b/web/app/components/datasets/documents/detail/settings/index.tsx @@ -14,6 +14,7 @@ import StepTwo from '@/app/components/datasets/create/step-two' import AccountSetting from '@/app/components/header/account-setting' import AppUnavailable from '@/app/components/base/app-unavailable' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' type DocumentSettingsProps = { datasetId: string @@ -26,7 +27,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean() const [hasError, setHasError] = useState(false) const { indexingTechnique, dataset } = useContext(DatasetDetailContext) - const { data: embeddingsDefaultModel } = useDefaultModel(2) + const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding) const saveHandler = () => router.push(`/datasets/${datasetId}/documents/${documentId}`) diff --git a/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx b/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx index 0c1e5d8621..c52f850ee0 100644 --- a/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx +++ b/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx @@ -3,6 +3,7 @@ import type { FC } from 'react' import React, { useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Toast from '../../base/toast' +import { ModelTypeEnum } from '../../header/account-setting/model-provider-page/declarations' import { XClose } from '@/app/components/base/icons/src/vender/line/general' import type { RetrievalConfig } from '@/types/app' import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' @@ -39,7 +40,7 @@ const ModifyRetrievalModal: FC = ({ modelList: rerankModelList, defaultModel: rerankDefaultModel, currentModel: isRerankDefaultModelVaild, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(3) + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const handleSave = () => { if ( diff --git a/web/app/components/datasets/settings/form/index.tsx b/web/app/components/datasets/settings/form/index.tsx index d212b4a135..0fd09486fc 100644 --- a/web/app/components/datasets/settings/form/index.tsx +++ b/web/app/components/datasets/settings/form/index.tsx @@ -24,6 +24,7 @@ import { useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel, } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' const rowClass = ` flex justify-between py-4 flex-wrap gap-y-2 @@ -63,8 +64,8 @@ const Form = () => { modelList: rerankModelList, defaultModel: rerankDefaultModel, currentModel: isRerankDefaultModelVaild, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(3) - const { data: embeddingModelList } = useModelList(2) + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) const handleSave = async () => { if (loading) diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 3b5bdbb682..58c7dc906d 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -11,10 +11,11 @@ import type { DefaultModel, DefaultModelResponse, Model, + + ModelTypeEnum, } from './declarations' import { ConfigurateMethodEnum, - ModelTypeEnum, } from './declarations' import I18n from '@/context/i18n' import { @@ -99,17 +100,8 @@ export const useProviderCrenditialsFormSchemasValue = ( return value } -export type ModelTypeIndex = 1 | 2 | 3 | 4 | 5 -export const MODEL_TYPE_MAPS = { - 1: ModelTypeEnum.textGeneration, - 2: ModelTypeEnum.textEmbedding, - 3: ModelTypeEnum.rerank, - 4: ModelTypeEnum.speech2text, - 5: ModelTypeEnum.tts, -} - -export const useModelList = (type: ModelTypeIndex) => { - const { data, mutate, isLoading } = useSWR(`/workspaces/current/models/model-types/${MODEL_TYPE_MAPS[type]}`, fetchModelList) +export const useModelList = (type: ModelTypeEnum) => { + const { data, mutate, isLoading } = useSWR(`/workspaces/current/models/model-types/${type}`, fetchModelList) return { data: data?.data || [], @@ -118,8 +110,8 @@ export const useModelList = (type: ModelTypeIndex) => { } } -export const useDefaultModel = (type: ModelTypeIndex) => { - const { data, mutate, isLoading } = useSWR(`/workspaces/current/default-model?model_type=${MODEL_TYPE_MAPS[type]}`, fetchDefaultModal) +export const useDefaultModel = (type: ModelTypeEnum) => { + const { data, mutate, isLoading } = useSWR(`/workspaces/current/default-model?model_type=${type}`, fetchDefaultModal) return { data: data?.data, @@ -152,7 +144,7 @@ export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultMode } } -export const useModelListAndDefaultModel = (type: ModelTypeIndex) => { +export const useModelListAndDefaultModel = (type: ModelTypeEnum) => { const { data: modelList } = useModelList(type) const { data: defaultModel } = useDefaultModel(type) @@ -162,7 +154,7 @@ export const useModelListAndDefaultModel = (type: ModelTypeIndex) => { } } -export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: ModelTypeIndex) => { +export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: ModelTypeEnum) => { const { modelList, defaultModel } = useModelListAndDefaultModel(type) const { currentProvider, currentModel } = useCurrentProviderAndModel( modelList, @@ -180,9 +172,8 @@ export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: Mode export const useUpdateModelList = () => { const { mutate } = useSWRConfig() - const updateModelList = useCallback((type: ModelTypeIndex | ModelTypeEnum) => { - const modelType = typeof type === 'number' ? MODEL_TYPE_MAPS[type] : type - mutate(`/workspaces/current/models/model-types/${modelType}`) + const updateModelList = useCallback((type: ModelTypeEnum) => { + mutate(`/workspaces/current/models/model-types/${type}`) }, [mutate]) return updateModelList diff --git a/web/app/components/header/account-setting/model-provider-page/index.tsx b/web/app/components/header/account-setting/model-provider-page/index.tsx index 6cb672673b..1fff60db88 100644 --- a/web/app/components/header/account-setting/model-provider-page/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/index.tsx @@ -10,6 +10,7 @@ import type { import { ConfigurateMethodEnum, CustomConfigurationStatusEnum, + ModelTypeEnum, } from './declarations' import { useDefaultModel, @@ -26,11 +27,11 @@ const ModelProviderPage = () => { const { eventEmitter } = useEventEmitterContextContext() const updateModelProviders = useUpdateModelProviders() const updateModelList = useUpdateModelList() - const { data: textGenerationDefaultModel } = useDefaultModel(1) - const { data: embeddingsDefaultModel } = useDefaultModel(2) - const { data: rerankDefaultModel } = useDefaultModel(3) - const { data: speech2textDefaultModel } = useDefaultModel(4) - const { data: ttsDefaultModel } = useDefaultModel(5) + const { data: textGenerationDefaultModel } = useDefaultModel(ModelTypeEnum.textGeneration) + const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding) + const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank) + const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text) + const { data: ttsDefaultModel } = useDefaultModel(ModelTypeEnum.tts) const { modelProviders: providers } = useProviderContext() const { setShowModelModal } = useModalContext() const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx index 4215bbfdec..88cefe147c 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx @@ -42,10 +42,10 @@ const SystemModel: FC = ({ const { notify } = useToastContext() const { textGenerationModelList } = useProviderContext() const updateModelList = useUpdateModelList() - const { data: embeddingModelList } = useModelList(2) - const { data: rerankModelList } = useModelList(3) - const { data: speech2textModelList } = useModelList(4) - const { data: ttsModelList } = useModelList(5) + const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) + const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank) + const { data: speech2textModelList } = useModelList(ModelTypeEnum.speech2text) + const { data: ttsModelList } = useModelList(ModelTypeEnum.tts) const [changedModelTypes, setChangedModelTypes] = useState([]) const [currentTextGenerationDefaultModel, changeCurrentTextGenerationDefaultModel] = useSystemDefaultModelAndModelList(textGenerationDefaultModel, textGenerationModelList) const [currentEmbeddingsDefaultModel, changeCurrentEmbeddingsDefaultModel] = useSystemDefaultModelAndModelList(embeddingsDefaultModel, embeddingModelList)