diff --git a/web/app/components/workflow/hooks/use-config-vision.ts b/web/app/components/workflow/hooks/use-config-vision.ts new file mode 100644 index 0000000000..d952ae6114 --- /dev/null +++ b/web/app/components/workflow/hooks/use-config-vision.ts @@ -0,0 +1,77 @@ +import produce from 'immer' +import { useCallback } from 'react' +import type { ModelConfig, VisionSetting } from '@/app/components/workflow/types' +import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { + ModelFeatureEnum, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { Resolution } from '@/types/app' + +type Payload = { + enabled: boolean + configs?: VisionSetting +} + +type Params = { + payload: Payload + onChange: (payload: Payload) => void +} +const useConfigVision = (model: ModelConfig, { + payload, + onChange, +}: Params) => { + const { + currentModel: currModel, + } = useTextGenerationCurrentProviderAndModelAndModelList( + { + provider: model.provider, + model: model.name, + }, + ) + + const getIsVisionModel = useCallback(() => { + return !!currModel?.features?.includes(ModelFeatureEnum.vision) + }, [currModel]) + + const isVisionModel = getIsVisionModel() + + const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => { + const newPayload = produce(payload, (draft) => { + draft.enabled = enabled + }) + onChange(newPayload) + }, [onChange, payload]) + + const handleVisionResolutionChange = useCallback((config: VisionSetting) => { + const newPayload = produce(payload, (draft) => { + draft.configs = config + }) + onChange(newPayload) + }, [onChange, payload]) + + const handleModelChanged = useCallback(() => { + const isVisionModel = getIsVisionModel() + if (!isVisionModel) { + handleVisionResolutionEnabledChange(false) + return + } + if (payload.enabled) { + onChange({ + enabled: true, + configs: { + detail: Resolution.high, + valueSelector: [], + }, + }) + } + }, [getIsVisionModel, handleVisionResolutionEnabledChange, onChange, payload.enabled]) + + return { + isVisionModel, + handleVisionResolutionEnabledChange, + handleVisionResolutionChange, + handleModelChanged, + } +} + +export default useConfigVision diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index e4c973a99c..a195668417 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -1,18 +1,17 @@ import { useCallback, useEffect, useRef, useState } from 'react' import produce from 'immer' import { EditionType, VarType } from '../../types' -import type { Memory, PromptItem, ValueSelector, Var, Variable, VisionSetting } from '../../types' +import type { Memory, PromptItem, ValueSelector, Var, Variable } from '../../types' import { useStore } from '../../store' import { useIsChatMode, useNodesReadOnly, } from '../../hooks' import useAvailableVarList from '../_base/hooks/use-available-var-list' +import useConfigVision from '../../hooks/use-config-vision' import type { LLMNodeType } from './types' -import { Resolution } from '@/types/app' -import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { - ModelFeatureEnum, ModelTypeEnum, } from '@/app/components/header/account-setting/model-provider-page/declarations' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' @@ -109,6 +108,21 @@ const useConfig = (id: string, payload: LLMNodeType) => { currentModel, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) + const { + isVisionModel, + handleVisionResolutionEnabledChange, + handleVisionResolutionChange, + handleModelChanged: handleVisionConfigAfterModelChanged, + } = useConfigVision(model, { + payload: inputs.vision, + onChange: (newPayload) => { + const newInputs = produce(inputs, (draft) => { + draft.vision = newPayload + }) + setInputs(newInputs) + }, + }) + const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => { const newInputs = produce(inputRef.current, (draft) => { draft.model.provider = model.provider @@ -139,43 +153,12 @@ const useConfig = (id: string, payload: LLMNodeType) => { setInputs(newInputs) }, [inputs, setInputs]) - const { - currentModel: currModel, - } = useTextGenerationCurrentProviderAndModelAndModelList( - { - provider: model.provider, - model: model.name, - }, - ) - const isVisionModel = !!currModel?.features?.includes(ModelFeatureEnum.vision) // change to vision model to set vision enabled, else disabled useEffect(() => { if (!modelChanged) return setModelChanged(false) - if (!isVisionModel) { - const newInputs = produce(inputs, (draft) => { - draft.vision = { - enabled: false, - } - }) - setInputs(newInputs) - return - } - if (!inputs.vision?.enabled) { - const newInputs = produce(inputs, (draft) => { - if (!draft.vision?.enabled) { - draft.vision = { - enabled: true, - configs: { - detail: Resolution.high, - valueSelector: [], - }, - } - } - }) - setInputs(newInputs) - } + handleVisionConfigAfterModelChanged() // eslint-disable-next-line react-hooks/exhaustive-deps }, [isVisionModel, modelChanged]) @@ -294,27 +277,6 @@ const useConfig = (id: string, payload: LLMNodeType) => { setInputs(newInputs) }, [inputs, setInputs]) - const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => { - const newInputs = produce(inputs, (draft) => { - if (!draft.vision) { - draft.vision = { - enabled, - } - } - else { - draft.vision.enabled = enabled - } - }) - setInputs(newInputs) - }, [inputs, setInputs]) - - const handleVisionResolutionChange = useCallback((config: VisionSetting) => { - const newInputs = produce(inputs, (draft) => { - draft.vision.configs = config - }) - setInputs(newInputs) - }, [inputs, setInputs]) - const filterInputVar = useCallback((varPayload: Var) => { return [VarType.number, VarType.string, VarType.secret].includes(varPayload.type) }, [])