feat: add datasets detail context and provider for improved data vali… (#16451)

Wu Tianwei 2025-03-24 14:30:26 +08:00 committed by GitHub
parent 83cd14104d
commit 9701b573e0
7 changed files with 192 additions and 35 deletions

View File

@@ -0,0 +1,53 @@
import type { FC } from 'react'
import { createContext, useCallback, useEffect, useRef } from 'react'
import { createDatasetsDetailStore } from './store'
import type { CommonNodeType, Node } from '../types'
import { BlockEnum } from '../types'
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
import { fetchDatasets } from '@/service/datasets'

type DatasetsDetailStoreApi = ReturnType<typeof createDatasetsDetailStore>

type DatasetsDetailContextType = DatasetsDetailStoreApi | undefined

export const DatasetsDetailContext = createContext<DatasetsDetailContextType>(undefined)

type DatasetsDetailProviderProps = {
  nodes: Node[]
  children: React.ReactNode
}

const DatasetsDetailProvider: FC<DatasetsDetailProviderProps> = ({
  nodes,
  children,
}) => {
  const storeRef = useRef<DatasetsDetailStoreApi>()

  if (!storeRef.current)
    storeRef.current = createDatasetsDetailStore()

  const updateDatasetsDetail = useCallback(async (datasetIds: string[]) => {
    const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
    if (datasetsDetail && datasetsDetail.length > 0)
      storeRef.current!.getState().updateDatasetsDetail(datasetsDetail)
  }, [])

  useEffect(() => {
    if (!storeRef.current) return
    const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
    const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
      return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
    }, [])
    if (allDatasetIds.length === 0) return
    updateDatasetsDetail(allDatasetIds)
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  return (
    <DatasetsDetailContext.Provider value={storeRef.current!}>
      {children}
    </DatasetsDetailContext.Provider>
  )
}

export default DatasetsDetailProvider
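
A usage sketch for the new provider (illustrative only, not part of this commit): it takes the current workflow nodes, pre-fetches details for every dataset referenced by knowledge-retrieval nodes on mount, and exposes the store to its subtree. The wrapper name below is hypothetical; the real wiring appears in the workflow wrapper later in this diff.

// Hypothetical wrapper component, for illustration only.
import type { FC, ReactNode } from 'react'
import type { Node } from '../types'
import DatasetsDetailProvider from './provider'

const EditorWithDatasetsDetail: FC<{ nodes: Node[], children: ReactNode }> = ({ nodes, children }) => (
  // Every descendant can now read cached dataset details via the store (see store.ts below)
  <DatasetsDetailProvider nodes={nodes}>
    {children}
  </DatasetsDetailProvider>
)

export default EditorWithDatasetsDetail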

View File

@@ -0,0 +1,38 @@
import { useContext } from 'react'
import { createStore, useStore } from 'zustand'
import type { DataSet } from '@/models/datasets'
import { DatasetsDetailContext } from './provider'
import produce from 'immer'

type DatasetsDetailStore = {
  datasetsDetail: Record<string, DataSet>
  updateDatasetsDetail: (datasetsDetail: DataSet[]) => void
}

export const createDatasetsDetailStore = () => {
  return createStore<DatasetsDetailStore>((set, get) => ({
    datasetsDetail: {},
    updateDatasetsDetail: (datasets: DataSet[]) => {
      const oldDatasetsDetail = get().datasetsDetail
      const datasetsDetail = datasets.reduce<Record<string, DataSet>>((acc, dataset) => {
        acc[dataset.id] = dataset
        return acc
      }, {})
      // Merge new datasets detail into old one
      const newDatasetsDetail = produce(oldDatasetsDetail, (draft) => {
        Object.entries(datasetsDetail).forEach(([key, value]) => {
          draft[key] = value
        })
      })
      set({ datasetsDetail: newDatasetsDetail })
    },
  }))
}

export const useDatasetsDetailStore = <T>(selector: (state: DatasetsDetailStore) => T): T => {
  const store = useContext(DatasetsDetailContext)
  if (!store)
    throw new Error('Missing DatasetsDetailContext.Provider in the tree')

  return useStore(store, selector)
}
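
An illustrative consumer (not part of this commit): any component rendered under DatasetsDetailProvider can select a cached entry through the hook above. The component and the `name` field access are assumptions for the sketch.

import type { FC } from 'react'
import { useDatasetsDetailStore } from './store'

// Hypothetical component: renders the cached name of one dataset, if present.
const DatasetLabel: FC<{ datasetId: string }> = ({ datasetId }) => {
  // Selecting a single entry re-renders this component only when that entry changes
  const dataset = useDatasetsDetailStore(s => s.datasetsDetail[datasetId])

  if (!dataset)
    return null

  return <span>{dataset.name}</span>
}

export default DatasetLabel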

View File

@@ -160,7 +160,7 @@ const Header: FC = () => {
  const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!)

  const onPublish = useCallback(async (params?: PublishWorkflowParams) => {
    if (handleCheckBeforePublish()) {
    if (await handleCheckBeforePublish()) {
      const res = await publishWorkflow({
        title: params?.title || '',
        releaseNotes: params?.releaseNotes || '',

View File

@@ -1,10 +1,12 @@
import {
  useCallback,
  useMemo,
  useRef,
} from 'react'
import { useTranslation } from 'react-i18next'
import { useStoreApi } from 'reactflow'
import type {
  CommonNodeType,
  Edge,
  Node,
} from '../types'
@@ -27,6 +29,10 @@ import { useGetLanguage } from '@/context/i18n'
import type { AgentNodeType } from '../nodes/agent/types'
import { useStrategyProviders } from '@/service/use-strategy'
import { canFindTool } from '@/utils'
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'

export const useChecklist = (nodes: Node[], edges: Edge[]) => {
  const { t } = useTranslation()
@@ -37,6 +43,24 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
  const customTools = useStore(s => s.customTools)
  const workflowTools = useStore(s => s.workflowTools)
  const { data: strategyProviders } = useStrategyProviders()
  const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)

  const getCheckData = useCallback((data: CommonNodeType<{}>) => {
    let checkData = data
    if (data.type === BlockEnum.KnowledgeRetrieval) {
      const datasetIds = (data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids
      const _datasets = datasetIds.reduce<DataSet[]>((acc, id) => {
        if (datasetsDetail[id])
          acc.push(datasetsDetail[id])
        return acc
      }, [])
      checkData = {
        ...data,
        _datasets,
      } as CommonNodeType<KnowledgeRetrievalNodeType>
    }
    return checkData
  }, [datasetsDetail])

  const needWarningNodes = useMemo(() => {
    const list = []
@@ -75,7 +99,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
      }

      if (node.type === CUSTOM_NODE) {
        const { errorMessage } = nodesExtraData[node.data.type].checkValid(node.data, t, moreDataForCheckValid)
        const checkData = getCheckData(node.data)
        const { errorMessage } = nodesExtraData[node.data.type].checkValid(checkData, t, moreDataForCheckValid)

        if (errorMessage || !validNodes.find(n => n.id === node.id)) {
          list.push({
@@ -109,7 +134,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
    }
    return list
  }, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders])
  }, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders, getCheckData])

  return needWarningNodes
}
@@ -125,8 +150,31 @@ export const useChecklistBeforePublish = () => {
  const store = useStoreApi()
  const nodesExtraData = useNodesExtraData()
  const { data: strategyProviders } = useStrategyProviders()
  const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
  const updateTime = useRef(0)

  const handleCheckBeforePublish = useCallback(() => {
  const getCheckData = useCallback((data: CommonNodeType<{}>, datasets: DataSet[]) => {
    let checkData = data
    if (data.type === BlockEnum.KnowledgeRetrieval) {
      const datasetIds = (data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids
      const datasetsDetail = datasets.reduce<Record<string, DataSet>>((acc, dataset) => {
        acc[dataset.id] = dataset
        return acc
      }, {})
      const _datasets = datasetIds.reduce<DataSet[]>((acc, id) => {
        if (datasetsDetail[id])
          acc.push(datasetsDetail[id])
        return acc
      }, [])
      checkData = {
        ...data,
        _datasets,
      } as CommonNodeType<KnowledgeRetrievalNodeType>
    }
    return checkData
  }, [])

  const handleCheckBeforePublish = useCallback(async () => {
    const {
      getNodes,
      edges,
@@ -141,6 +189,24 @@ export const useChecklistBeforePublish = () => {
      notify({ type: 'error', message: t('workflow.common.maxTreeDepth', { depth: MAX_TREE_DEPTH }) })
      return false
    }

    // Before publishing, re-fetch the dataset details in case their settings have changed
    const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
    const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
      return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
    }, [])
    let datasets: DataSet[] = []
    if (allDatasetIds.length > 0) {
      updateTime.current = updateTime.current + 1
      const currUpdateTime = updateTime.current
      const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } })
      if (datasetsDetail && datasetsDetail.length > 0) {
        // avoid letting stale data overwrite newer data
        if (currUpdateTime < updateTime.current)
          return false
        datasets = datasetsDetail
        updateDatasetsDetail(datasetsDetail)
      }
    }
    for (let i = 0; i < nodes.length; i++) {
      const node = nodes[i]
@@ -161,7 +227,8 @@ export const useChecklistBeforePublish = () => {
        }
      }

      const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(node.data, t, moreDataForCheckValid)
      const checkData = getCheckData(node.data, datasets)
      const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid)

      if (errorMessage) {
        notify({ type: 'error', message: `[${node.data.title}] ${errorMessage}` })
@@ -185,7 +252,7 @@ export const useChecklistBeforePublish = () => {
    }

    return true
  }, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders])
  }, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders, updateDatasetsDetail, getCheckData])

  return {
    handleCheckBeforePublish,
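
For reference, the `updateTime` ref above acts as a last-write-wins guard against overlapping requests. A standalone sketch of the same pattern follows (hook name and fetcher parameter are hypothetical), assuming any promise-returning fetcher:

// Sketch of the stale-response guard used by handleCheckBeforePublish (names are illustrative).
// Each invocation bumps a shared counter; after the await, a call that is no longer
// the latest discards its result instead of acting on stale data.
import { useCallback, useRef } from 'react'

export const useLatestResult = <T>(fetcher: () => Promise<T>) => {
  const seq = useRef(0)

  return useCallback(async (): Promise<T | undefined> => {
    seq.current += 1
    const current = seq.current
    const result = await fetcher()
    // A newer call started while this one was awaiting; drop the outdated result
    if (current < seq.current)
      return undefined
    return result
  }, [fetcher])
}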

View File

@@ -99,6 +99,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter'
import Confirm from '@/app/components/base/confirm'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import { fetchFileUploadConfig } from '@/service/common'
import DatasetsDetailProvider from './datasets-detail-store/provider'

const nodeTypes = {
  [CUSTOM_NODE]: CustomNode,
@@ -448,11 +449,13 @@ const WorkflowWrap = memo(() => {
        nodes={nodesData}
        edges={edgesData} >
        <FeaturesProvider features={initialFeatures}>
          <Workflow
            nodes={nodesData}
            edges={edgesData}
            viewport={data?.graph.viewport}
          />
          <DatasetsDetailProvider nodes={nodesData}>
            <Workflow
              nodes={nodesData}
              edges={edgesData}
              viewport={data?.graph.viewport}
            />
          </DatasetsDetailProvider>
        </FeaturesProvider>
      </WorkflowHistoryProvider>
    </ReactFlowProvider>

View File

@@ -1,33 +1,30 @@
import { type FC, useEffect, useRef, useState } from 'react'
import { type FC, useEffect, useState } from 'react'
import React from 'react'
import type { KnowledgeRetrievalNodeType } from './types'
import { Folder } from '@/app/components/base/icons/src/vender/solid/files'
import type { NodeProps } from '@/app/components/workflow/types'
import { fetchDatasets } from '@/service/datasets'
import type { DataSet } from '@/models/datasets'
import { useDatasetsDetailStore } from '../../datasets-detail-store/store'

const Node: FC<NodeProps<KnowledgeRetrievalNodeType>> = ({
  data,
}) => {
  const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
  const updateTime = useRef(0)
  useEffect(() => {
    (async () => {
      updateTime.current = updateTime.current + 1
      const currUpdateTime = updateTime.current
  const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
      if (data.dataset_ids?.length > 0) {
        const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: data.dataset_ids } })
        // avoid old data overwrite new data
        if (currUpdateTime < updateTime.current)
          return
        setSelectedDatasets(dataSetsWithDetail)
      }
      else {
        setSelectedDatasets([])
      }
    })()
  }, [data.dataset_ids])
  useEffect(() => {
    if (data.dataset_ids?.length > 0) {
      const dataSetsWithDetail = data.dataset_ids.reduce<DataSet[]>((acc, id) => {
        if (datasetsDetail[id])
          acc.push(datasetsDetail[id])
        return acc
      }, [])
      setSelectedDatasets(dataSetsWithDetail)
    }
    else {
      setSelectedDatasets([])
    }
  }, [data.dataset_ids, datasetsDetail])

  if (!selectedDatasets.length)
    return null

View File

@@ -41,6 +41,7 @@ import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-s
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list'
import { useDatasetsDetailStore } from '../../datasets-detail-store/store'

const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
  const { nodesReadOnly: readOnly } = useNodesReadOnly()
@@ -49,6 +50,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
  const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
  const startNodeId = startNode?.id
  const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
  const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)

  const inputRef = useRef(inputs)
@@ -218,15 +220,12 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
    (async () => {
      const inputs = inputRef.current
      const datasetIds = inputs.dataset_ids
      let _datasets = selectedDatasets
      if (datasetIds?.length > 0) {
        const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } as any })
        _datasets = dataSetsWithDetail
        setSelectedDatasets(dataSetsWithDetail)
      }
      const newInputs = produce(inputs, (draft) => {
        draft.dataset_ids = datasetIds
        draft._datasets = _datasets
      })
      setInputs(newInputs)
      setSelectedDatasetsLoaded(true)
@@ -256,7 +255,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
    } = getSelectedDatasetsMode(newDatasets)
    const newInputs = produce(inputs, (draft) => {
      draft.dataset_ids = newDatasets.map(d => d.id)
      draft._datasets = newDatasets
      if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
        const multipleRetrievalConfig = draft.multiple_retrieval_config
@@ -266,6 +264,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
        })
      }
    })
    updateDatasetsDetail(newDatasets)
    setInputs(newInputs)
    setSelectedDatasets(newDatasets)
@@ -275,7 +274,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
      || allExternal
    )
      setRerankModelOpen(true)
  }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider])
  }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider, updateDatasetsDetail])

  const filterVar = useCallback((varPayload: Var) => {
    return varPayload.type === VarType.string