Joel 7709d9df20
Chore: frontend infrastructure upgrade (#16420)
Co-authored-by: NFish <douxc512@gmail.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: jZonG <jzongcode@gmail.com>
2025-03-21 17:41:03 +08:00

252 lines
8.9 KiB
TypeScript

import type {
FC,
ReactNode,
} from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import type {
DefaultModel,
FormValue,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import {
useModelList,
} from '@/app/components/header/account-setting/model-provider-page/hooks'
import AgentModelTrigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger'
import Trigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import LLMParamsPanel from './llm-params-panel'
import TTSParamsPanel from './tts-params-panel'
import { useProviderContext } from '@/context/provider-context'
import cn from '@/utils/classnames'
/**
 * Props accepted by `ModelParameterModal`.
 */
export type ModelParameterModalProps = {
  // ---- data -------------------------------------------------------------
  /**
   * Current model configuration. The component reads `provider`, `model`,
   * `completion_params`, `language` and `voice` from it.
   */
  value: any
  /** Called with the updated model configuration whenever the model or its parameters change. */
  setModel: (model: any) => void
  /**
   * `&`-separated scope string. Entries naming a model type select which
   * model list is offered; the special entry `all` offers every list; any
   * other entry is forwarded to the selector as a required feature.
   * Defaults to the text-generation model type.
   */
  scope?: string

  // ---- behaviour --------------------------------------------------------
  /** Forwarded to the LLM parameter panel. */
  isAdvancedMode: boolean
  /** When true, clicking the trigger does not toggle the popup. */
  readonly?: boolean
  /** Places the popup to the left of the trigger (instead of bottom-end) and is forwarded to the default trigger. */
  isInWorkflow?: boolean
  /** Uses the agent-specific trigger when no `renderTrigger` is supplied. */
  isAgentStrategy?: boolean
  /** Custom trigger renderer; takes precedence over both built-in triggers. */
  renderTrigger?: (v: TriggerProps) => ReactNode

  // ---- presentation -----------------------------------------------------
  /** Extra class names for the popup panel. */
  popupClassName?: string
  /** Extra class names for the floating content wrapper. */
  portalToFollowElemContentClassName?: string
}
/**
 * Scope entries that name a model type. Any other scope entry (except `all`)
 * is treated as a feature requirement and forwarded to the model selector.
 */
const SCOPE_MODEL_TYPES: ModelTypeEnum[] = [
  ModelTypeEnum.textGeneration,
  ModelTypeEnum.textEmbedding,
  ModelTypeEnum.rerank,
  ModelTypeEnum.moderation,
  ModelTypeEnum.speech2text,
  ModelTypeEnum.tts,
]

/**
 * Popup for choosing a model and editing its parameters.
 *
 * Renders a trigger (custom via `renderTrigger`, the agent trigger when
 * `isAgentStrategy` is set, otherwise the default trigger) and a floating
 * panel containing a model selector plus, depending on the selected model's
 * type, the LLM completion-parameter panel or the TTS language/voice panel.
 *
 * Every change is reported through `setModel` with a new model object; the
 * component keeps no model state of its own.
 */
const ModelParameterModal: FC<ModelParameterModalProps> = ({
  popupClassName,
  portalToFollowElemContentClassName,
  isAdvancedMode,
  value,
  setModel,
  renderTrigger,
  readonly,
  isInWorkflow,
  isAgentStrategy,
  scope = ModelTypeEnum.textGeneration,
}) => {
  const { t } = useTranslation()
  const { isAPIKeySet } = useProviderContext()
  const [open, setOpen] = useState(false)

  // Memoised on the scope string so the array identity is stable between
  // renders; the memos below depend on it and would otherwise recompute
  // (and hand new references to children) on every render.
  const scopeArray = useMemo(() => scope.split('&'), [scope])

  // Scope entries that are not model types are feature requirements.
  const scopeFeatures = useMemo(() => {
    if (scopeArray.includes('all'))
      return []
    return scopeArray.filter(item => !SCOPE_MODEL_TYPES.includes(item as ModelTypeEnum))
  }, [scopeArray])

  const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration)
  const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding)
  const { data: rerankList } = useModelList(ModelTypeEnum.rerank)
  const { data: moderationList } = useModelList(ModelTypeEnum.moderation)
  const { data: sttList } = useModelList(ModelTypeEnum.speech2text)
  const { data: ttsList } = useModelList(ModelTypeEnum.tts)

  // Provider list offered by the selector. `all` concatenates every list;
  // otherwise the first model type found in the scope (in the order checked
  // below) decides which single list is used.
  const scopedModelList = useMemo(() => {
    const emptyList: any[] = []
    if (scopeArray.includes('all')) {
      return [
        ...textGenerationList,
        ...textEmbeddingList,
        ...rerankList,
        ...sttList,
        ...ttsList,
        ...moderationList,
      ]
    }
    if (scopeArray.includes(ModelTypeEnum.textGeneration))
      return textGenerationList
    if (scopeArray.includes(ModelTypeEnum.textEmbedding))
      return textEmbeddingList
    if (scopeArray.includes(ModelTypeEnum.rerank))
      return rerankList
    if (scopeArray.includes(ModelTypeEnum.moderation))
      return moderationList
    if (scopeArray.includes(ModelTypeEnum.speech2text))
      return sttList
    if (scopeArray.includes(ModelTypeEnum.tts))
      return ttsList
    return emptyList
  }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])

  // Resolve the provider/model entries matching the current value, if any.
  const { currentProvider, currentModel } = useMemo(() => {
    const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
    const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
    return {
      currentProvider,
      currentModel,
    }
  }, [scopedModelList, value?.provider, value?.model])

  // The selected provider or model is no longer present in the scoped list.
  const hasDeprecated = !currentProvider || !currentModel
  // Also true when no model is resolved, since `status` is then undefined.
  const modelDisabled = currentModel?.status !== ModelStatusEnum.active
  const disabled = !isAPIKeySet || hasDeprecated || modelDisabled

  /**
   * Reports a newly selected model. Text-generation models additionally get
   * their `mode` and a fresh, empty `completion_params` object.
   */
  const handleChangeModel = ({ provider, model }: DefaultModel) => {
    const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
    const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
    const model_type = targetModelItem?.model_type as string
    setModel({
      provider,
      model,
      model_type,
      ...(model_type === ModelTypeEnum.textGeneration
        ? {
          // Optional chaining: tolerate a model item without `model_properties`.
          mode: targetModelItem?.model_properties?.mode as string,
          completion_params: {},
        }
        : {}),
    })
  }

  /** Replaces `completion_params` while keeping the rest of the current value. */
  const handleLLMParamsChange = (newParams: FormValue) => {
    setModel({
      ...value,
      completion_params: newParams,
    })
  }

  /** Updates the TTS `language` and `voice` while keeping the rest of the current value. */
  const handleTTSParamsChange = (language: string, voice: string) => {
    setModel({
      ...value,
      language,
      voice,
    })
  }

  const isTextGeneration = currentModel?.model_type === ModelTypeEnum.textGeneration
  const isTTS = currentModel?.model_type === ModelTypeEnum.tts

  return (
    <PortalToFollowElem
      open={open}
      onOpenChange={setOpen}
      placement={isInWorkflow ? 'left' : 'bottom-end'}
      offset={4}
    >
      <div className='relative'>
        <PortalToFollowElemTrigger
          onClick={() => {
            // Read-only mode: the trigger is shown but cannot open the popup.
            if (readonly)
              return
            setOpen(v => !v)
          }}
          className='block'
        >
          {
            renderTrigger
              ? renderTrigger({
                open,
                disabled,
                modelDisabled,
                hasDeprecated,
                currentProvider,
                currentModel,
                providerName: value?.provider,
                modelId: value?.model,
              })
              : (isAgentStrategy
                ? <AgentModelTrigger
                  disabled={disabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                  scope={scope}
                />
                : <Trigger
                  disabled={disabled}
                  isInWorkflow={isInWorkflow}
                  modelDisabled={modelDisabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                />
              )
          }
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent className={cn('z-50', portalToFollowElemContentClassName)}>
          <div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
            <div className={cn('max-h-[420px] overflow-y-auto p-4 pt-3')}>
              <div className='relative'>
                <div className={cn('system-sm-semibold mb-1 flex h-6 items-center text-text-secondary')}>
                  {t('common.modelProvider.model').toLocaleUpperCase()}
                </div>
                <ModelSelector
                  defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
                  modelList={scopedModelList}
                  scopeFeatures={scopeFeatures}
                  onSelect={handleChangeModel}
                />
              </div>
              {/* Divider is only needed when a parameter panel follows. */}
              {(isTextGeneration || isTTS) && (
                <div className='my-3 h-[1px] bg-divider-subtle' />
              )}
              {isTextGeneration && (
                <LLMParamsPanel
                  provider={value?.provider}
                  modelId={value?.model}
                  completionParams={value?.completion_params || {}}
                  onCompletionParamsChange={handleLLMParamsChange}
                  isAdvancedMode={isAdvancedMode}
                />
              )}
              {isTTS && (
                <TTSParamsPanel
                  currentModel={currentModel}
                  language={value?.language}
                  voice={value?.voice}
                  onChange={handleTTSParamsChange}
                />
              )}
            </div>
          </div>
        </PortalToFollowElemContent>
      </div>
    </PortalToFollowElem>
  )
}
export default ModelParameterModal