mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-05-29 17:45:12 +08:00

Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: jZonG <jzongcode@gmail.com>
252 lines
8.9 KiB
TypeScript
252 lines
8.9 KiB
TypeScript
import type {
|
|
FC,
|
|
ReactNode,
|
|
} from 'react'
|
|
import { useMemo, useState } from 'react'
|
|
import { useTranslation } from 'react-i18next'
|
|
import type {
|
|
DefaultModel,
|
|
FormValue,
|
|
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
|
import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
|
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
|
import {
|
|
useModelList,
|
|
} from '@/app/components/header/account-setting/model-provider-page/hooks'
|
|
import AgentModelTrigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger'
|
|
import Trigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
|
|
import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
|
|
import {
|
|
PortalToFollowElem,
|
|
PortalToFollowElemContent,
|
|
PortalToFollowElemTrigger,
|
|
} from '@/app/components/base/portal-to-follow-elem'
|
|
import LLMParamsPanel from './llm-params-panel'
|
|
import TTSParamsPanel from './tts-params-panel'
|
|
import { useProviderContext } from '@/context/provider-context'
|
|
import cn from '@/utils/classnames'
|
|
|
|
/**
 * Props accepted by `ModelParameterModal`.
 */
export type ModelParameterModalProps = {
  /** Extra class names merged into the popup panel element. */
  popupClassName?: string
  /** Extra class names merged into the `PortalToFollowElemContent` wrapper. */
  portalToFollowElemContentClassName?: string
  /** Forwarded unchanged to `LLMParamsPanel`. */
  isAdvancedMode: boolean
  /**
   * Currently selected model configuration. The component reads `provider`,
   * `model`, `completion_params`, `language` and `voice` from it.
   */
  value: any
  /** Called with the updated model configuration whenever the model or one of its parameters changes. */
  setModel: (model: any) => void
  /** Optional custom trigger renderer; when provided it replaces the built-in triggers. */
  renderTrigger?: (v: TriggerProps) => ReactNode
  /** When true, clicking the trigger does not toggle the popup. */
  readonly?: boolean
  /** When true, the popup is placed to the `left` instead of `bottom-end`; also forwarded to `Trigger`. */
  isInWorkflow?: boolean
  /** When true (and no `renderTrigger` is given), `AgentModelTrigger` is rendered instead of `Trigger`. */
  isAgentStrategy?: boolean
  /**
   * `&`-separated scope string. Model type entries choose which model list is
   * offered, `all` offers every list, and any remaining entries are passed to
   * `ModelSelector` as `scopeFeatures`. Defaults to `ModelTypeEnum.textGeneration`.
   */
  scope?: string
}
|
|
|
|
/**
 * Popup that lets the user pick a model and edit its parameters.
 *
 * The list of selectable models is derived from `scope`; once a model is
 * selected, a parameter panel is shown for text-generation models
 * (`LLMParamsPanel`) and for TTS models (`TTSParamsPanel`). Every change is
 * reported through `setModel` with the full updated model configuration.
 */
const ModelParameterModal: FC<ModelParameterModalProps> = ({
  popupClassName,
  portalToFollowElemContentClassName,
  isAdvancedMode,
  value,
  setModel,
  renderTrigger,
  readonly,
  isInWorkflow,
  isAgentStrategy,
  scope = ModelTypeEnum.textGeneration,
}) => {
  const { t } = useTranslation()
  const { isAPIKeySet } = useProviderContext()
  const [open, setOpen] = useState(false)

  // `split` returns a new array on every call. Memoize it on `scope` so the
  // memos below that depend on it are not invalidated on every render.
  const scopeArray = useMemo(() => scope.split('&'), [scope])

  // Scope entries that are not model types are treated as feature filters.
  const scopeFeatures = useMemo(() => {
    if (scopeArray.includes('all'))
      return []
    return scopeArray.filter(item => ![
      ModelTypeEnum.textGeneration,
      ModelTypeEnum.textEmbedding,
      ModelTypeEnum.rerank,
      ModelTypeEnum.moderation,
      ModelTypeEnum.speech2text,
      ModelTypeEnum.tts,
    ].includes(item as ModelTypeEnum))
  }, [scopeArray])

  // Hooks must be called unconditionally, so every list is fetched regardless of scope.
  const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration)
  const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding)
  const { data: rerankList } = useModelList(ModelTypeEnum.rerank)
  const { data: moderationList } = useModelList(ModelTypeEnum.moderation)
  const { data: sttList } = useModelList(ModelTypeEnum.speech2text)
  const { data: ttsList } = useModelList(ModelTypeEnum.tts)

  // Model list offered in the selector: every list for `all`, otherwise the
  // list of the first matching model type (checked in the order below).
  const scopedModelList = useMemo(() => {
    const resultList: any[] = []
    if (scopeArray.includes('all')) {
      return [
        ...textGenerationList,
        ...textEmbeddingList,
        ...rerankList,
        ...sttList,
        ...ttsList,
        ...moderationList,
      ]
    }
    if (scopeArray.includes(ModelTypeEnum.textGeneration))
      return textGenerationList
    if (scopeArray.includes(ModelTypeEnum.textEmbedding))
      return textEmbeddingList
    if (scopeArray.includes(ModelTypeEnum.rerank))
      return rerankList
    if (scopeArray.includes(ModelTypeEnum.moderation))
      return moderationList
    if (scopeArray.includes(ModelTypeEnum.speech2text))
      return sttList
    if (scopeArray.includes(ModelTypeEnum.tts))
      return ttsList
    // No recognised model type in scope: nothing to offer.
    return resultList
  }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])

  // Resolve the selected provider/model entries from the scoped list.
  const { currentProvider, currentModel } = useMemo(() => {
    const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
    const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
    return {
      currentProvider,
      currentModel,
    }
  }, [scopedModelList, value?.provider, value?.model])

  // The selection is flagged as deprecated when it cannot be found in the scoped list.
  const hasDeprecated = useMemo(() => {
    return !currentProvider || !currentModel
  }, [currentModel, currentProvider])
  // Also true when no model is resolved, since `undefined !== active`.
  const modelDisabled = useMemo(() => {
    return currentModel?.status !== ModelStatusEnum.active
  }, [currentModel?.status])
  const disabled = useMemo(() => {
    return !isAPIKeySet || hasDeprecated || modelDisabled
  }, [hasDeprecated, isAPIKeySet, modelDisabled])

  /**
   * Handles a selection in `ModelSelector`. Reports the new provider/model
   * together with its `model_type`; for text-generation models the mode is
   * included and `completion_params` is reset to an empty object.
   */
  const handleChangeModel = ({ provider, model }: DefaultModel) => {
    const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
    const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
    const model_type = targetModelItem?.model_type as string
    setModel({
      provider,
      model,
      model_type,
      ...(model_type === ModelTypeEnum.textGeneration ? {
        mode: targetModelItem?.model_properties.mode as string,
        completion_params: {},
      } : {}),
    })
  }

  /** Replaces `completion_params` on the current value with the edited parameters. */
  const handleLLMParamsChange = (newParams: FormValue) => {
    // NOTE(review): `completionParams` (camelCase) is not read anywhere else in
    // this file, which uses `completion_params`; this spread looks like a
    // leftover. Kept as-is to preserve behaviour - confirm and remove.
    const newValue = {
      ...(value?.completionParams || {}),
      completion_params: newParams,
    }
    setModel({
      ...value,
      ...newValue,
    })
  }

  /** Stores the chosen TTS language and voice on the current value. */
  const handleTTSParamsChange = (language: string, voice: string) => {
    setModel({
      ...value,
      language,
      voice,
    })
  }

  return (
    <PortalToFollowElem
      open={open}
      onOpenChange={setOpen}
      placement={isInWorkflow ? 'left' : 'bottom-end'}
      offset={4}
    >
      <div className='relative'>
        <PortalToFollowElemTrigger
          onClick={() => {
            if (readonly)
              return
            setOpen(v => !v)
          }}
          className='block'
        >
          {
            renderTrigger
              ? renderTrigger({
                open,
                disabled,
                modelDisabled,
                hasDeprecated,
                currentProvider,
                currentModel,
                providerName: value?.provider,
                modelId: value?.model,
              })
              : (isAgentStrategy
                ? <AgentModelTrigger
                  disabled={disabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                  scope={scope}
                />
                : <Trigger
                  disabled={disabled}
                  isInWorkflow={isInWorkflow}
                  modelDisabled={modelDisabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                />
              )
          }
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent className={cn('z-50', portalToFollowElemContentClassName)}>
          <div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
            <div className={cn('max-h-[420px] overflow-y-auto p-4 pt-3')}>
              <div className='relative'>
                <div className={cn('system-sm-semibold mb-1 flex h-6 items-center text-text-secondary')}>
                  {t('common.modelProvider.model').toLocaleUpperCase()}
                </div>
                <ModelSelector
                  defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
                  modelList={scopedModelList}
                  scopeFeatures={scopeFeatures}
                  onSelect={handleChangeModel}
                />
              </div>
              {(currentModel?.model_type === ModelTypeEnum.textGeneration || currentModel?.model_type === ModelTypeEnum.tts) && (
                <div className='my-3 h-[1px] bg-divider-subtle' />
              )}
              {currentModel?.model_type === ModelTypeEnum.textGeneration && (
                <LLMParamsPanel
                  provider={value?.provider}
                  modelId={value?.model}
                  completionParams={value?.completion_params || {}}
                  onCompletionParamsChange={handleLLMParamsChange}
                  isAdvancedMode={isAdvancedMode}
                />
              )}
              {currentModel?.model_type === ModelTypeEnum.tts && (
                <TTSParamsPanel
                  currentModel={currentModel}
                  language={value?.language}
                  voice={value?.voice}
                  onChange={handleTTSParamsChange}
                />
              )}
            </div>
          </div>
        </PortalToFollowElemContent>
      </div>
    </PortalToFollowElem>
  )
}
|
|
|
|
// Default export: the model selection / parameter popup component.
export default ModelParameterModal
|