mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 17:39:00 +08:00
Feat: rerank model verification in front end (#9271)
This commit is contained in:
parent
c6b74daa0a
commit
793205afc5
@ -1,6 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { memo, useEffect, useMemo } from 'react'
|
import { memo, useCallback, useEffect, useMemo } from 'react'
|
||||||
import type { FC } from 'react'
|
import type { FC } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import WeightedScore from './weighted-score'
|
import WeightedScore from './weighted-score'
|
||||||
@ -11,7 +11,7 @@ import type {
|
|||||||
DatasetConfigs,
|
DatasetConfigs,
|
||||||
} from '@/models/debug'
|
} from '@/models/debug'
|
||||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||||
import type { ModelConfig } from '@/app/components/workflow/types'
|
import type { ModelConfig } from '@/app/components/workflow/types'
|
||||||
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
|
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
|
||||||
import Tooltip from '@/app/components/base/tooltip'
|
import Tooltip from '@/app/components/base/tooltip'
|
||||||
@ -23,6 +23,7 @@ import { RerankingModeEnum } from '@/models/datasets'
|
|||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks'
|
import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks'
|
||||||
import Switch from '@/app/components/base/switch'
|
import Switch from '@/app/components/base/switch'
|
||||||
|
import Toast from '@/app/components/base/toast'
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
datasetConfigs: DatasetConfigs
|
datasetConfigs: DatasetConfigs
|
||||||
@ -60,6 +61,24 @@ const ConfigContent: FC<Props> = ({
|
|||||||
modelList: rerankModelList,
|
modelList: rerankModelList,
|
||||||
defaultModel: rerankDefaultModel,
|
defaultModel: rerankDefaultModel,
|
||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||||
|
|
||||||
|
const {
|
||||||
|
currentModel,
|
||||||
|
} = useCurrentProviderAndModel(
|
||||||
|
rerankModelList,
|
||||||
|
rerankDefaultModel
|
||||||
|
? {
|
||||||
|
...rerankDefaultModel,
|
||||||
|
provider: rerankDefaultModel.provider.provider,
|
||||||
|
}
|
||||||
|
: undefined,
|
||||||
|
)
|
||||||
|
|
||||||
|
const handleDisabledSwitchClick = useCallback(() => {
|
||||||
|
if (!currentModel)
|
||||||
|
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||||
|
}, [currentModel, rerankDefaultModel, t])
|
||||||
|
|
||||||
const rerankModel = (() => {
|
const rerankModel = (() => {
|
||||||
if (datasetConfigs.reranking_model?.reranking_provider_name) {
|
if (datasetConfigs.reranking_model?.reranking_provider_name) {
|
||||||
return {
|
return {
|
||||||
@ -231,16 +250,22 @@ const ConfigContent: FC<Props> = ({
|
|||||||
<div className='flex items-center'>
|
<div className='flex items-center'>
|
||||||
{
|
{
|
||||||
selectedDatasetsMode.allEconomic && (
|
selectedDatasetsMode.allEconomic && (
|
||||||
<Switch
|
<div
|
||||||
size='md'
|
className='flex items-center'
|
||||||
defaultValue={showRerankModel}
|
onClick={handleDisabledSwitchClick}
|
||||||
onChange={(v) => {
|
>
|
||||||
onChange({
|
<Switch
|
||||||
...datasetConfigs,
|
size='md'
|
||||||
reranking_enable: v,
|
defaultValue={currentModel ? showRerankModel : false}
|
||||||
})
|
disabled={!currentModel}
|
||||||
}}
|
onChange={(v) => {
|
||||||
/>
|
onChange({
|
||||||
|
...datasetConfigs,
|
||||||
|
reranking_enable: v,
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
|
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
import type { FC } from 'react'
|
import type { FC } from 'react'
|
||||||
import React from 'react'
|
import React, { useCallback } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
@ -11,7 +11,7 @@ import Switch from '@/app/components/base/switch'
|
|||||||
import Tooltip from '@/app/components/base/tooltip'
|
import Tooltip from '@/app/components/base/tooltip'
|
||||||
import type { RetrievalConfig } from '@/types/app'
|
import type { RetrievalConfig } from '@/types/app'
|
||||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||||
import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
import {
|
import {
|
||||||
DEFAULT_WEIGHTED_SCORE,
|
DEFAULT_WEIGHTED_SCORE,
|
||||||
@ -19,6 +19,7 @@ import {
|
|||||||
WeightedScoreEnum,
|
WeightedScoreEnum,
|
||||||
} from '@/models/datasets'
|
} from '@/models/datasets'
|
||||||
import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score'
|
import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score'
|
||||||
|
import Toast from '@/app/components/base/toast'
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
type: RETRIEVE_METHOD
|
type: RETRIEVE_METHOD
|
||||||
@ -38,6 +39,24 @@ const RetrievalParamConfig: FC<Props> = ({
|
|||||||
defaultModel: rerankDefaultModel,
|
defaultModel: rerankDefaultModel,
|
||||||
modelList: rerankModelList,
|
modelList: rerankModelList,
|
||||||
} = useModelListAndDefaultModel(ModelTypeEnum.rerank)
|
} = useModelListAndDefaultModel(ModelTypeEnum.rerank)
|
||||||
|
|
||||||
|
const {
|
||||||
|
currentModel,
|
||||||
|
} = useCurrentProviderAndModel(
|
||||||
|
rerankModelList,
|
||||||
|
rerankDefaultModel
|
||||||
|
? {
|
||||||
|
...rerankDefaultModel,
|
||||||
|
provider: rerankDefaultModel.provider.provider,
|
||||||
|
}
|
||||||
|
: undefined,
|
||||||
|
)
|
||||||
|
|
||||||
|
const handleDisabledSwitchClick = useCallback(() => {
|
||||||
|
if (!currentModel)
|
||||||
|
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||||
|
}, [currentModel, rerankDefaultModel, t])
|
||||||
|
|
||||||
const isHybridSearch = type === RETRIEVE_METHOD.hybrid
|
const isHybridSearch = type === RETRIEVE_METHOD.hybrid
|
||||||
|
|
||||||
const rerankModel = (() => {
|
const rerankModel = (() => {
|
||||||
@ -99,16 +118,22 @@ const RetrievalParamConfig: FC<Props> = ({
|
|||||||
<div>
|
<div>
|
||||||
<div className='flex h-8 items-center text-[13px] font-medium text-gray-900 space-x-2'>
|
<div className='flex h-8 items-center text-[13px] font-medium text-gray-900 space-x-2'>
|
||||||
{canToggleRerankModalEnable && (
|
{canToggleRerankModalEnable && (
|
||||||
<Switch
|
<div
|
||||||
size='md'
|
className='flex items-center'
|
||||||
defaultValue={value.reranking_enable}
|
onClick={handleDisabledSwitchClick}
|
||||||
onChange={(v) => {
|
>
|
||||||
onChange({
|
<Switch
|
||||||
...value,
|
size='md'
|
||||||
reranking_enable: v,
|
defaultValue={currentModel ? value.reranking_enable : false}
|
||||||
})
|
onChange={(v) => {
|
||||||
}}
|
onChange({
|
||||||
/>
|
...value,
|
||||||
|
reranking_enable: v,
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
disabled={!currentModel}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
)}
|
)}
|
||||||
<div className='flex items-center'>
|
<div className='flex items-center'>
|
||||||
<span className='mr-0.5'>{t('common.modelProvider.rerankModel.key')}</span>
|
<span className='mr-0.5'>{t('common.modelProvider.rerankModel.key')}</span>
|
||||||
|
@ -1,17 +1,25 @@
|
|||||||
import { useCallback } from 'react'
|
import { useCallback } from 'react'
|
||||||
import { useStoreApi } from 'reactflow'
|
import { useStoreApi } from 'reactflow'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
import { useWorkflowStore } from '../store'
|
import { useWorkflowStore } from '../store'
|
||||||
import {
|
import {
|
||||||
BlockEnum,
|
BlockEnum,
|
||||||
WorkflowRunningStatus,
|
WorkflowRunningStatus,
|
||||||
} from '../types'
|
} from '../types'
|
||||||
|
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
|
||||||
|
import type { Node } from '../types'
|
||||||
|
import { useWorkflow } from './use-workflow'
|
||||||
import {
|
import {
|
||||||
useIsChatMode,
|
useIsChatMode,
|
||||||
useNodesSyncDraft,
|
useNodesSyncDraft,
|
||||||
useWorkflowInteractions,
|
useWorkflowInteractions,
|
||||||
useWorkflowRun,
|
useWorkflowRun,
|
||||||
} from './index'
|
} from './index'
|
||||||
|
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
|
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||||
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
||||||
|
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
|
||||||
|
import Toast from '@/app/components/base/toast'
|
||||||
|
|
||||||
export const useWorkflowStartRun = () => {
|
export const useWorkflowStartRun = () => {
|
||||||
const store = useStoreApi()
|
const store = useStoreApi()
|
||||||
@ -20,7 +28,26 @@ export const useWorkflowStartRun = () => {
|
|||||||
const isChatMode = useIsChatMode()
|
const isChatMode = useIsChatMode()
|
||||||
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
|
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
|
||||||
const { handleRun } = useWorkflowRun()
|
const { handleRun } = useWorkflowRun()
|
||||||
|
const { isFromStartNode } = useWorkflow()
|
||||||
const { doSyncWorkflowDraft } = useNodesSyncDraft()
|
const { doSyncWorkflowDraft } = useNodesSyncDraft()
|
||||||
|
const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const {
|
||||||
|
modelList: rerankModelList,
|
||||||
|
defaultModel: rerankDefaultModel,
|
||||||
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||||
|
|
||||||
|
const {
|
||||||
|
currentModel,
|
||||||
|
} = useCurrentProviderAndModel(
|
||||||
|
rerankModelList,
|
||||||
|
rerankDefaultModel
|
||||||
|
? {
|
||||||
|
...rerankDefaultModel,
|
||||||
|
provider: rerankDefaultModel.provider.provider,
|
||||||
|
}
|
||||||
|
: undefined,
|
||||||
|
)
|
||||||
|
|
||||||
const handleWorkflowStartRunInWorkflow = useCallback(async () => {
|
const handleWorkflowStartRunInWorkflow = useCallback(async () => {
|
||||||
const {
|
const {
|
||||||
@ -33,6 +60,9 @@ export const useWorkflowStartRun = () => {
|
|||||||
const { getNodes } = store.getState()
|
const { getNodes } = store.getState()
|
||||||
const nodes = getNodes()
|
const nodes = getNodes()
|
||||||
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
|
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
|
||||||
|
const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
|
||||||
|
node.data.type === BlockEnum.KnowledgeRetrieval,
|
||||||
|
)
|
||||||
const startVariables = startNode?.data.variables || []
|
const startVariables = startNode?.data.variables || []
|
||||||
const fileSettings = featuresStore!.getState().features.file
|
const fileSettings = featuresStore!.getState().features.file
|
||||||
const {
|
const {
|
||||||
@ -42,6 +72,31 @@ export const useWorkflowStartRun = () => {
|
|||||||
setShowEnvPanel,
|
setShowEnvPanel,
|
||||||
} = workflowStore.getState()
|
} = workflowStore.getState()
|
||||||
|
|
||||||
|
if (knowledgeRetrievalNodes.length > 0) {
|
||||||
|
for (const node of knowledgeRetrievalNodes) {
|
||||||
|
if (isFromStartNode(node.id)) {
|
||||||
|
const res = checkKnowledgeRetrievalValid(node.data, t)
|
||||||
|
if (!res.isValid || !currentModel || !rerankDefaultModel) {
|
||||||
|
const errorMessage = res.errorMessage
|
||||||
|
if (errorMessage) {
|
||||||
|
Toast.notify({
|
||||||
|
type: 'error',
|
||||||
|
message: errorMessage,
|
||||||
|
})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
Toast.notify({
|
||||||
|
type: 'error',
|
||||||
|
message: t('appDebug.datasetConfig.rerankModelRequired'),
|
||||||
|
})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
setShowEnvPanel(false)
|
setShowEnvPanel(false)
|
||||||
|
|
||||||
if (showDebugAndPreviewPanel) {
|
if (showDebugAndPreviewPanel) {
|
||||||
|
@ -235,6 +235,33 @@ export const useWorkflow = () => {
|
|||||||
return nodes.filter(node => node.parentId === nodeId)
|
return nodes.filter(node => node.parentId === nodeId)
|
||||||
}, [store])
|
}, [store])
|
||||||
|
|
||||||
|
const isFromStartNode = useCallback((nodeId: string) => {
|
||||||
|
const { getNodes } = store.getState()
|
||||||
|
const nodes = getNodes()
|
||||||
|
const currentNode = nodes.find(node => node.id === nodeId)
|
||||||
|
|
||||||
|
if (!currentNode)
|
||||||
|
return false
|
||||||
|
|
||||||
|
if (currentNode.data.type === BlockEnum.Start)
|
||||||
|
return true
|
||||||
|
|
||||||
|
const checkPreviousNodes = (node: Node) => {
|
||||||
|
const previousNodes = getBeforeNodeById(node.id)
|
||||||
|
|
||||||
|
for (const prevNode of previousNodes) {
|
||||||
|
if (prevNode.data.type === BlockEnum.Start)
|
||||||
|
return true
|
||||||
|
if (checkPreviousNodes(prevNode))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return checkPreviousNodes(currentNode)
|
||||||
|
}, [store, getBeforeNodeById])
|
||||||
|
|
||||||
const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => {
|
const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => {
|
||||||
const { getNodes, setNodes } = store.getState()
|
const { getNodes, setNodes } = store.getState()
|
||||||
const afterNodes = getAfterNodesInSameBranch(nodeId)
|
const afterNodes = getAfterNodesInSameBranch(nodeId)
|
||||||
@ -389,6 +416,7 @@ export const useWorkflow = () => {
|
|||||||
checkParallelLimit,
|
checkParallelLimit,
|
||||||
checkNestedParallelLimit,
|
checkNestedParallelLimit,
|
||||||
isValidConnection,
|
isValidConnection,
|
||||||
|
isFromStartNode,
|
||||||
formatTimeFromNow,
|
formatTimeFromNow,
|
||||||
getNode,
|
getNode,
|
||||||
getBeforeNodeById,
|
getBeforeNodeById,
|
||||||
|
@ -172,6 +172,7 @@ const translation = {
|
|||||||
},
|
},
|
||||||
errorMsg: {
|
errorMsg: {
|
||||||
fieldRequired: '{{field}} is required',
|
fieldRequired: '{{field}} is required',
|
||||||
|
rerankModelRequired: 'Before turning on the Rerank Model, please confirm that the model has been successfully configured in the settings.',
|
||||||
authRequired: 'Authorization is required',
|
authRequired: 'Authorization is required',
|
||||||
invalidJson: '{{field}} is invalid JSON',
|
invalidJson: '{{field}} is invalid JSON',
|
||||||
fields: {
|
fields: {
|
||||||
|
@ -172,6 +172,7 @@ const translation = {
|
|||||||
},
|
},
|
||||||
errorMsg: {
|
errorMsg: {
|
||||||
fieldRequired: '{{field}} 不能为空',
|
fieldRequired: '{{field}} 不能为空',
|
||||||
|
rerankModelRequired: '开启 Rerank 模型前,请务必确认模型已在设置中成功配置。',
|
||||||
authRequired: '请先授权',
|
authRequired: '请先授权',
|
||||||
invalidJson: '{{field}} 是非法的 JSON',
|
invalidJson: '{{field}} 是非法的 JSON',
|
||||||
fields: {
|
fields: {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user