mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 03:49:04 +08:00
feat: add datasets detail context and provider for improved data vali… (#16451)
This commit is contained in:
parent
83cd14104d
commit
9701b573e0
@ -0,0 +1,53 @@
|
||||
import type { FC } from 'react'
|
||||
import { createContext, useCallback, useEffect, useRef } from 'react'
|
||||
import { createDatasetsDetailStore } from './store'
|
||||
import type { CommonNodeType, Node } from '../types'
|
||||
import { BlockEnum } from '../types'
|
||||
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
|
||||
// Type of the store instance returned by createDatasetsDetailStore.
type DatasetsDetailStoreApi = ReturnType<typeof createDatasetsDetailStore>
|
||||
|
||||
// Value carried by the context: a store instance, or undefined when nothing has been provided.
type DatasetsDetailContextType = DatasetsDetailStoreApi | undefined
|
||||
|
||||
// React context that hands the datasets-detail store to descendants; its default value is undefined.
export const DatasetsDetailContext = createContext<DatasetsDetailContextType>(undefined)
|
||||
|
||||
// Props accepted by DatasetsDetailProvider.
type DatasetsDetailProviderProps = {
|
||||
// Workflow nodes; the provider reads the dataset ids of the knowledge-retrieval nodes among them.
nodes: Node[]
|
||||
// Subtree rendered inside the context provider.
children: React.ReactNode
|
||||
}
|
||||
|
||||
/**
 * Collects the de-duplicated dataset ids referenced by every
 * knowledge-retrieval node in `nodes`, in first-seen order.
 */
const collectDatasetIds = (nodes: Node[]): string[] => {
  const ids = new Set<string>()
  for (const node of nodes) {
    if (node.data.type !== BlockEnum.KnowledgeRetrieval)
      continue
    const { dataset_ids } = node.data as CommonNodeType<KnowledgeRetrievalNodeType>
    // Tolerate a node without a dataset list instead of throwing while iterating it.
    for (const id of dataset_ids ?? [])
      ids.add(id)
  }
  return Array.from(ids)
}

/**
 * Creates one datasets-detail store per provider instance, exposes it through
 * DatasetsDetailContext and, on mount, fills it with the details of the
 * datasets referenced by the knowledge-retrieval nodes in `nodes`.
 *
 * @param nodes workflow nodes inspected once, when the provider mounts
 * @param children subtree that can read the store through the context
 */
const DatasetsDetailProvider: FC<DatasetsDetailProviderProps> = ({
  nodes,
  children,
}) => {
  const storeRef = useRef<DatasetsDetailStoreApi>()

  // Create the store during the first render so it exists before any child reads the context.
  if (!storeRef.current)
    storeRef.current = createDatasetsDetailStore()

  const updateDatasetsDetail = useCallback(async (datasetIds: string[]) => {
    try {
      const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
      if (datasetsDetail && datasetsDetail.length > 0)
        storeRef.current!.getState().updateDatasetsDetail(datasetsDetail)
    }
    catch (error) {
      // This runs fire-and-forget from an effect: report the failure here rather than
      // leaving an unhandled promise rejection. The store simply stays as it was.
      console.error('Failed to fetch datasets detail', error)
    }
  }, [])

  useEffect(() => {
    const allDatasetIds = collectDatasetIds(nodes)
    if (allDatasetIds.length === 0)
      return
    void updateDatasetsDetail(allDatasetIds)
    // Intentionally mount-only: `nodes` is read once to seed the store.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  return (
    <DatasetsDetailContext.Provider value={storeRef.current!}>
      {children}
    </DatasetsDetailContext.Provider>
  )
}
|
||||
|
||||
export default DatasetsDetailProvider
|
38
web/app/components/workflow/datasets-detail-store/store.ts
Normal file
38
web/app/components/workflow/datasets-detail-store/store.ts
Normal file
@ -0,0 +1,38 @@
|
||||
import { useContext } from 'react'
|
||||
import { createStore, useStore } from 'zustand'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { DatasetsDetailContext } from './provider'
|
||||
import produce from 'immer'
|
||||
|
||||
// State and actions held by the datasets-detail store.
type DatasetsDetailStore = {
|
||||
// Dataset details keyed by dataset id.
datasetsDetail: Record<string, DataSet>
|
||||
// Merges the given datasets into `datasetsDetail`, replacing entries that share an id.
updateDatasetsDetail: (datasetsDetail: DataSet[]) => void
|
||||
}
|
||||
|
||||
/**
 * Builds a fresh store holding dataset details keyed by dataset id.
 *
 * `updateDatasetsDetail` overlays the given datasets on the cached map:
 * entries with a matching id are replaced, all other cached entries are kept.
 */
export const createDatasetsDetailStore = () => createStore<DatasetsDetailStore>((set, get) => ({
  datasetsDetail: {},
  updateDatasetsDetail: (datasets: DataSet[]) => {
    // Index the incoming datasets by id; a later duplicate replaces an earlier one.
    const incomingById: Record<string, DataSet> = {}
    for (const dataset of datasets)
      incomingById[dataset.id] = dataset

    // Overlay the incoming entries on top of the currently cached ones.
    const mergedDetail = produce(get().datasetsDetail, (draft) => {
      for (const [id, dataset] of Object.entries(incomingById))
        draft[id] = dataset
    })

    set({ datasetsDetail: mergedDetail })
  },
}))
|
||||
|
||||
/**
 * Reads a slice of the datasets-detail store supplied through DatasetsDetailContext.
 *
 * @param selector picks the part of the store state to return
 * @returns the value produced by `selector`
 * @throws Error when the context holds no store (no provider above the caller)
 */
export const useDatasetsDetailStore = <T>(selector: (state: DatasetsDetailStore) => T): T => {
  const contextStore = useContext(DatasetsDetailContext)

  if (!contextStore) {
    throw new Error('Missing DatasetsDetailContext.Provider in the tree')
  }

  const selected = useStore(contextStore, selector)
  return selected
}
|
@ -160,7 +160,7 @@ const Header: FC = () => {
|
||||
const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!)
|
||||
|
||||
const onPublish = useCallback(async (params?: PublishWorkflowParams) => {
|
||||
if (handleCheckBeforePublish()) {
|
||||
if (await handleCheckBeforePublish()) {
|
||||
const res = await publishWorkflow({
|
||||
title: params?.title || '',
|
||||
releaseNotes: params?.releaseNotes || '',
|
||||
|
@ -1,10 +1,12 @@
|
||||
import {
|
||||
useCallback,
|
||||
useMemo,
|
||||
useRef,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useStoreApi } from 'reactflow'
|
||||
import type {
|
||||
CommonNodeType,
|
||||
Edge,
|
||||
Node,
|
||||
} from '../types'
|
||||
@ -27,6 +29,10 @@ import { useGetLanguage } from '@/context/i18n'
|
||||
import type { AgentNodeType } from '../nodes/agent/types'
|
||||
import { useStrategyProviders } from '@/service/use-strategy'
|
||||
import { canFindTool } from '@/utils'
|
||||
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
|
||||
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
|
||||
export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const { t } = useTranslation()
|
||||
@ -37,6 +43,24 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const customTools = useStore(s => s.customTools)
|
||||
const workflowTools = useStore(s => s.workflowTools)
|
||||
const { data: strategyProviders } = useStrategyProviders()
|
||||
const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
|
||||
|
||||
// Returns the node data to validate. Knowledge-retrieval nodes get a `_datasets`
// list resolved from the cached dataset details (ids without a cached entry are
// dropped); every other node type is returned as-is.
const getCheckData = useCallback((data: CommonNodeType<{}>): CommonNodeType<{}> => {
  if (data.type !== BlockEnum.KnowledgeRetrieval)
    return data

  const { dataset_ids: datasetIds } = data as CommonNodeType<KnowledgeRetrievalNodeType>
  const _datasets: DataSet[] = []
  for (const id of datasetIds) {
    const detail = datasetsDetail[id]
    if (detail)
      _datasets.push(detail)
  }

  return { ...data, _datasets } as CommonNodeType<KnowledgeRetrievalNodeType>
}, [datasetsDetail])
|
||||
|
||||
const needWarningNodes = useMemo(() => {
|
||||
const list = []
|
||||
@ -75,7 +99,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
}
|
||||
|
||||
if (node.type === CUSTOM_NODE) {
|
||||
const { errorMessage } = nodesExtraData[node.data.type].checkValid(node.data, t, moreDataForCheckValid)
|
||||
const checkData = getCheckData(node.data)
|
||||
const { errorMessage } = nodesExtraData[node.data.type].checkValid(checkData, t, moreDataForCheckValid)
|
||||
|
||||
if (errorMessage || !validNodes.find(n => n.id === node.id)) {
|
||||
list.push({
|
||||
@ -109,7 +134,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
}
|
||||
|
||||
return list
|
||||
}, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders])
|
||||
}, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders, getCheckData])
|
||||
|
||||
return needWarningNodes
|
||||
}
|
||||
@ -125,8 +150,31 @@ export const useChecklistBeforePublish = () => {
|
||||
const store = useStoreApi()
|
||||
const nodesExtraData = useNodesExtraData()
|
||||
const { data: strategyProviders } = useStrategyProviders()
|
||||
const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
|
||||
const updateTime = useRef(0)
|
||||
|
||||
const handleCheckBeforePublish = useCallback(() => {
|
||||
// Returns the node data to validate. For knowledge-retrieval nodes, `_datasets`
// is filled from the `datasets` argument (matched by id, unmatched ids dropped);
// every other node type is returned as-is.
const getCheckData = useCallback((data: CommonNodeType<{}>, datasets: DataSet[]): CommonNodeType<{}> => {
  if (data.type !== BlockEnum.KnowledgeRetrieval)
    return data

  // Index the supplied datasets by id; a later duplicate replaces an earlier one.
  const detailById: Record<string, DataSet> = {}
  for (const dataset of datasets)
    detailById[dataset.id] = dataset

  const { dataset_ids: datasetIds } = data as CommonNodeType<KnowledgeRetrievalNodeType>
  const _datasets: DataSet[] = []
  for (const id of datasetIds) {
    const detail = detailById[id]
    if (detail)
      _datasets.push(detail)
  }

  return { ...data, _datasets } as CommonNodeType<KnowledgeRetrievalNodeType>
}, [])
|
||||
|
||||
const handleCheckBeforePublish = useCallback(async () => {
|
||||
const {
|
||||
getNodes,
|
||||
edges,
|
||||
@ -141,6 +189,24 @@ export const useChecklistBeforePublish = () => {
|
||||
notify({ type: 'error', message: t('workflow.common.maxTreeDepth', { depth: MAX_TREE_DEPTH }) })
|
||||
return false
|
||||
}
|
||||
// Before publish, we need to fetch datasets detail, in case of the settings of datasets have been changed
|
||||
const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
|
||||
const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
|
||||
return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
|
||||
}, [])
|
||||
let datasets: DataSet[] = []
|
||||
if (allDatasetIds.length > 0) {
|
||||
updateTime.current = updateTime.current + 1
|
||||
const currUpdateTime = updateTime.current
|
||||
const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } })
|
||||
if (datasetsDetail && datasetsDetail.length > 0) {
|
||||
// avoid old data to overwrite the new data
|
||||
if (currUpdateTime < updateTime.current)
|
||||
return false
|
||||
datasets = datasetsDetail
|
||||
updateDatasetsDetail(datasetsDetail)
|
||||
}
|
||||
}
|
||||
|
||||
for (let i = 0; i < nodes.length; i++) {
|
||||
const node = nodes[i]
|
||||
@ -161,7 +227,8 @@ export const useChecklistBeforePublish = () => {
|
||||
}
|
||||
}
|
||||
|
||||
const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(node.data, t, moreDataForCheckValid)
|
||||
const checkData = getCheckData(node.data, datasets)
|
||||
const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid)
|
||||
|
||||
if (errorMessage) {
|
||||
notify({ type: 'error', message: `[${node.data.title}] ${errorMessage}` })
|
||||
@ -185,7 +252,7 @@ export const useChecklistBeforePublish = () => {
|
||||
}
|
||||
|
||||
return true
|
||||
}, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders])
|
||||
}, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders, updateDatasetsDetail, getCheckData])
|
||||
|
||||
return {
|
||||
handleCheckBeforePublish,
|
||||
|
@ -99,6 +99,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
|
||||
import { fetchFileUploadConfig } from '@/service/common'
|
||||
import DatasetsDetailProvider from './datasets-detail-store/provider'
|
||||
|
||||
const nodeTypes = {
|
||||
[CUSTOM_NODE]: CustomNode,
|
||||
@ -448,11 +449,13 @@ const WorkflowWrap = memo(() => {
|
||||
nodes={nodesData}
|
||||
edges={edgesData} >
|
||||
<FeaturesProvider features={initialFeatures}>
|
||||
<Workflow
|
||||
nodes={nodesData}
|
||||
edges={edgesData}
|
||||
viewport={data?.graph.viewport}
|
||||
/>
|
||||
<DatasetsDetailProvider nodes={nodesData}>
|
||||
<Workflow
|
||||
nodes={nodesData}
|
||||
edges={edgesData}
|
||||
viewport={data?.graph.viewport}
|
||||
/>
|
||||
</DatasetsDetailProvider>
|
||||
</FeaturesProvider>
|
||||
</WorkflowHistoryProvider>
|
||||
</ReactFlowProvider>
|
||||
|
@ -1,33 +1,30 @@
|
||||
import { type FC, useEffect, useRef, useState } from 'react'
|
||||
import { type FC, useEffect, useState } from 'react'
|
||||
import React from 'react'
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
import { Folder } from '@/app/components/base/icons/src/vender/solid/files'
|
||||
import type { NodeProps } from '@/app/components/workflow/types'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { useDatasetsDetailStore } from '../../datasets-detail-store/store'
|
||||
|
||||
const Node: FC<NodeProps<KnowledgeRetrievalNodeType>> = ({
|
||||
data,
|
||||
}) => {
|
||||
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
|
||||
const updateTime = useRef(0)
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
updateTime.current = updateTime.current + 1
|
||||
const currUpdateTime = updateTime.current
|
||||
const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
|
||||
|
||||
if (data.dataset_ids?.length > 0) {
|
||||
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: data.dataset_ids } })
|
||||
// avoid old data overwrite new data
|
||||
if (currUpdateTime < updateTime.current)
|
||||
return
|
||||
setSelectedDatasets(dataSetsWithDetail)
|
||||
}
|
||||
else {
|
||||
setSelectedDatasets([])
|
||||
}
|
||||
})()
|
||||
}, [data.dataset_ids])
|
||||
useEffect(() => {
|
||||
if (data.dataset_ids?.length > 0) {
|
||||
const dataSetsWithDetail = data.dataset_ids.reduce<DataSet[]>((acc, id) => {
|
||||
if (datasetsDetail[id])
|
||||
acc.push(datasetsDetail[id])
|
||||
return acc
|
||||
}, [])
|
||||
setSelectedDatasets(dataSetsWithDetail)
|
||||
}
|
||||
else {
|
||||
setSelectedDatasets([])
|
||||
}
|
||||
}, [data.dataset_ids, datasetsDetail])
|
||||
|
||||
if (!selectedDatasets.length)
|
||||
return null
|
||||
|
@ -41,6 +41,7 @@ import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-s
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list'
|
||||
import { useDatasetsDetailStore } from '../../datasets-detail-store/store'
|
||||
|
||||
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
const { nodesReadOnly: readOnly } = useNodesReadOnly()
|
||||
@ -49,6 +50,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
|
||||
const startNodeId = startNode?.id
|
||||
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
||||
const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
|
||||
|
||||
const inputRef = useRef(inputs)
|
||||
|
||||
@ -218,15 +220,12 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
(async () => {
|
||||
const inputs = inputRef.current
|
||||
const datasetIds = inputs.dataset_ids
|
||||
let _datasets = selectedDatasets
|
||||
if (datasetIds?.length > 0) {
|
||||
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } as any })
|
||||
_datasets = dataSetsWithDetail
|
||||
setSelectedDatasets(dataSetsWithDetail)
|
||||
}
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = datasetIds
|
||||
draft._datasets = _datasets
|
||||
})
|
||||
setInputs(newInputs)
|
||||
setSelectedDatasetsLoaded(true)
|
||||
@ -256,7 +255,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
} = getSelectedDatasetsMode(newDatasets)
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = newDatasets.map(d => d.id)
|
||||
draft._datasets = newDatasets
|
||||
|
||||
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
@ -266,6 +264,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
})
|
||||
}
|
||||
})
|
||||
updateDatasetsDetail(newDatasets)
|
||||
setInputs(newInputs)
|
||||
setSelectedDatasets(newDatasets)
|
||||
|
||||
@ -275,7 +274,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
|| allExternal
|
||||
)
|
||||
setRerankModelOpen(true)
|
||||
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider])
|
||||
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider, updateDatasetsDetail])
|
||||
|
||||
const filterVar = useCallback((varPayload: Var) => {
|
||||
return varPayload.type === VarType.string
|
||||
|
Loading…
x
Reference in New Issue
Block a user