diff --git a/web/app/components/workflow/datasets-detail-store/provider.tsx b/web/app/components/workflow/datasets-detail-store/provider.tsx
new file mode 100644
index 0000000000..7d6ede823a
--- /dev/null
+++ b/web/app/components/workflow/datasets-detail-store/provider.tsx
@@ -0,0 +1,73 @@
+import type { FC } from 'react'
+import { createContext, useCallback, useEffect, useRef } from 'react'
+import { createDatasetsDetailStore } from './store'
+import type { CommonNodeType, Node } from '../types'
+import { BlockEnum } from '../types'
+import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
+import { fetchDatasets } from '@/service/datasets'
+
+type DatasetsDetailStoreApi = ReturnType<typeof createDatasetsDetailStore>
+
+type DatasetsDetailContextType = DatasetsDetailStoreApi | undefined
+
+/** Context carrying the datasets-detail store; `undefined` when no provider is mounted. */
+export const DatasetsDetailContext = createContext<DatasetsDetailContextType>(undefined)
+
+type DatasetsDetailProviderProps = {
+  nodes: Node[]
+  children: React.ReactNode
+}
+
+/**
+ * Creates one datasets-detail store per mounted provider and, on mount,
+ * fetches the details of every dataset referenced by the knowledge
+ * retrieval nodes in `nodes`, writing them into that store.
+ */
+const DatasetsDetailProvider: FC<DatasetsDetailProviderProps> = ({
+  nodes,
+  children,
+}) => {
+  const storeRef = useRef<DatasetsDetailStoreApi>()
+
+  // Create the store lazily so it is built once instead of on every render.
+  if (!storeRef.current)
+    storeRef.current = createDatasetsDetailStore()
+  const store = storeRef.current
+
+  const updateDatasetsDetail = useCallback(async (datasetIds: string[]) => {
+    try {
+      const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
+      if (datasetsDetail && datasetsDetail.length > 0)
+        store.getState().updateDatasetsDetail(datasetsDetail)
+    }
+    catch (error) {
+      // The caller does not await this promise, so report the failure here
+      // rather than leaving an unhandled rejection; the store stays untouched.
+      console.error('Failed to fetch datasets detail', error)
+    }
+  }, [store])
+
+  useEffect(() => {
+    // Collect the unique dataset ids of all knowledge retrieval nodes in one pass.
+    // `dataset_ids` may be missing on a node, so fall back to an empty list.
+    const datasetIdSet = new Set<string>()
+    nodes.forEach((node) => {
+      if (node.data.type !== BlockEnum.KnowledgeRetrieval)
+        return
+      const datasetIds = (node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids ?? []
+      datasetIds.forEach(id => datasetIdSet.add(id))
+    })
+    if (datasetIdSet.size === 0) return
+    updateDatasetsDetail(Array.from(datasetIdSet))
+    // Intentionally mount-only: this effect performs the initial load.
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [])
+
+  return (
+    <DatasetsDetailContext.Provider value={store}>
+      {children}
+    </DatasetsDetailContext.Provider>
+  )
+}
+
+export default DatasetsDetailProvider
diff
--git a/web/app/components/workflow/datasets-detail-store/store.ts b/web/app/components/workflow/datasets-detail-store/store.ts
new file mode 100644
index 0000000000..4bc8c335e5
--- /dev/null
+++ b/web/app/components/workflow/datasets-detail-store/store.ts
@@ -0,0 +1,41 @@
+import { useContext } from 'react'
+import { createStore, useStore } from 'zustand'
+import type { DataSet } from '@/models/datasets'
+import { DatasetsDetailContext } from './provider'
+import produce from 'immer'
+
+type DatasetsDetailStore = {
+  datasetsDetail: Record<string, DataSet>
+  updateDatasetsDetail: (datasetsDetail: DataSet[]) => void
+}
+
+/** Creates a store holding dataset details keyed by dataset id. */
+export const createDatasetsDetailStore = () => createStore<DatasetsDetailStore>((set, get) => ({
+  datasetsDetail: {},
+  updateDatasetsDetail: (datasets: DataSet[]) => {
+    // Index the incoming datasets by id; a later duplicate replaces an earlier one.
+    const latestById = new Map<string, DataSet>()
+    for (const dataset of datasets)
+      latestById.set(dataset.id, dataset)
+
+    // Write the incoming entries over the stored ones, keeping all other entries.
+    const mergedDetail = produce(get().datasetsDetail, (draft) => {
+      latestById.forEach((dataset, id) => {
+        draft[id] = dataset
+      })
+    })
+    set({ datasetsDetail: mergedDetail })
+  },
+}))
+
+/**
+ * Reads a slice of the datasets-detail store provided through `DatasetsDetailContext`.
+ * Throws when the context holds no store.
+ */
+export const useDatasetsDetailStore = <T>(selector: (state: DatasetsDetailStore) => T): T => {
+  const detailStore = useContext(DatasetsDetailContext)
+  if (!detailStore)
+    throw new Error('Missing DatasetsDetailContext.Provider in the tree')
+
+  return useStore(detailStore, selector)
+}
diff --git a/web/app/components/workflow/header/index.tsx b/web/app/components/workflow/header/index.tsx
index eb259dd01f..7e99f5dd6b 100644
--- a/web/app/components/workflow/header/index.tsx
+++ b/web/app/components/workflow/header/index.tsx
@@ -160,7 +160,7 @@ const Header: FC = () => {
   const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!)
const onPublish = useCallback(async (params?: PublishWorkflowParams) => { - if (handleCheckBeforePublish()) { + if (await handleCheckBeforePublish()) { const res = await publishWorkflow({ title: params?.title || '', releaseNotes: params?.releaseNotes || '', diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts index 7a3a99ab38..c1b0189b8b 100644 --- a/web/app/components/workflow/hooks/use-checklist.ts +++ b/web/app/components/workflow/hooks/use-checklist.ts @@ -1,10 +1,12 @@ import { useCallback, useMemo, + useRef, } from 'react' import { useTranslation } from 'react-i18next' import { useStoreApi } from 'reactflow' import type { + CommonNodeType, Edge, Node, } from '../types' @@ -27,6 +29,10 @@ import { useGetLanguage } from '@/context/i18n' import type { AgentNodeType } from '../nodes/agent/types' import { useStrategyProviders } from '@/service/use-strategy' import { canFindTool } from '@/utils' +import { useDatasetsDetailStore } from '../datasets-detail-store/store' +import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types' +import type { DataSet } from '@/models/datasets' +import { fetchDatasets } from '@/service/datasets' export const useChecklist = (nodes: Node[], edges: Edge[]) => { const { t } = useTranslation() @@ -37,6 +43,24 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { const customTools = useStore(s => s.customTools) const workflowTools = useStore(s => s.workflowTools) const { data: strategyProviders } = useStrategyProviders() + const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail) + + const getCheckData = useCallback((data: CommonNodeType<{}>) => { + let checkData = data + if (data.type === BlockEnum.KnowledgeRetrieval) { + const datasetIds = (data as CommonNodeType).dataset_ids + const _datasets = datasetIds.reduce((acc, id) => { + if (datasetsDetail[id]) + acc.push(datasetsDetail[id]) + return acc + }, []) + checkData = { + 
...data, + _datasets, + } as CommonNodeType + } + return checkData + }, [datasetsDetail]) const needWarningNodes = useMemo(() => { const list = [] @@ -75,7 +99,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { } if (node.type === CUSTOM_NODE) { - const { errorMessage } = nodesExtraData[node.data.type].checkValid(node.data, t, moreDataForCheckValid) + const checkData = getCheckData(node.data) + const { errorMessage } = nodesExtraData[node.data.type].checkValid(checkData, t, moreDataForCheckValid) if (errorMessage || !validNodes.find(n => n.id === node.id)) { list.push({ @@ -109,7 +134,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { } return list - }, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders]) + }, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders, getCheckData]) return needWarningNodes } @@ -125,8 +150,31 @@ export const useChecklistBeforePublish = () => { const store = useStoreApi() const nodesExtraData = useNodesExtraData() const { data: strategyProviders } = useStrategyProviders() + const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail) + const updateTime = useRef(0) - const handleCheckBeforePublish = useCallback(() => { + const getCheckData = useCallback((data: CommonNodeType<{}>, datasets: DataSet[]) => { + let checkData = data + if (data.type === BlockEnum.KnowledgeRetrieval) { + const datasetIds = (data as CommonNodeType).dataset_ids + const datasetsDetail = datasets.reduce>((acc, dataset) => { + acc[dataset.id] = dataset + return acc + }, {}) + const _datasets = datasetIds.reduce((acc, id) => { + if (datasetsDetail[id]) + acc.push(datasetsDetail[id]) + return acc + }, []) + checkData = { + ...data, + _datasets, + } as CommonNodeType + } + return checkData + }, []) + + const handleCheckBeforePublish = useCallback(async () => { const { getNodes, edges, @@ 
-141,6 +189,24 @@ export const useChecklistBeforePublish = () => { notify({ type: 'error', message: t('workflow.common.maxTreeDepth', { depth: MAX_TREE_DEPTH }) }) return false } + // Before publish, we need to fetch datasets detail, in case of the settings of datasets have been changed + const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval) + const allDatasetIds = knowledgeRetrievalNodes.reduce((acc, node) => { + return Array.from(new Set([...acc, ...(node.data as CommonNodeType).dataset_ids])) + }, []) + let datasets: DataSet[] = [] + if (allDatasetIds.length > 0) { + updateTime.current = updateTime.current + 1 + const currUpdateTime = updateTime.current + const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } }) + if (datasetsDetail && datasetsDetail.length > 0) { + // avoid old data to overwrite the new data + if (currUpdateTime < updateTime.current) + return false + datasets = datasetsDetail + updateDatasetsDetail(datasetsDetail) + } + } for (let i = 0; i < nodes.length; i++) { const node = nodes[i] @@ -161,7 +227,8 @@ export const useChecklistBeforePublish = () => { } } - const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(node.data, t, moreDataForCheckValid) + const checkData = getCheckData(node.data, datasets) + const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid) if (errorMessage) { notify({ type: 'error', message: `[${node.data.title}] ${errorMessage}` }) @@ -185,7 +252,7 @@ export const useChecklistBeforePublish = () => { } return true - }, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders]) + }, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders, updateDatasetsDetail, getCheckData]) return { handleCheckBeforePublish, diff --git 
a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index 3384e39707..0c4d5aa671 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -99,6 +99,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter' import Confirm from '@/app/components/base/confirm' import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' import { fetchFileUploadConfig } from '@/service/common' +import DatasetsDetailProvider from './datasets-detail-store/provider' const nodeTypes = { [CUSTOM_NODE]: CustomNode, @@ -448,11 +449,13 @@ const WorkflowWrap = memo(() => { nodes={nodesData} edges={edgesData} > - + + + diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/node.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/node.tsx index 8dcb0d5ce0..c24c3e089c 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/node.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/node.tsx @@ -1,33 +1,30 @@ -import { type FC, useEffect, useRef, useState } from 'react' +import { type FC, useEffect, useState } from 'react' import React from 'react' import type { KnowledgeRetrievalNodeType } from './types' import { Folder } from '@/app/components/base/icons/src/vender/solid/files' import type { NodeProps } from '@/app/components/workflow/types' -import { fetchDatasets } from '@/service/datasets' import type { DataSet } from '@/models/datasets' +import { useDatasetsDetailStore } from '../../datasets-detail-store/store' const Node: FC> = ({ data, }) => { const [selectedDatasets, setSelectedDatasets] = useState([]) - const updateTime = useRef(0) - useEffect(() => { - (async () => { - updateTime.current = updateTime.current + 1 - const currUpdateTime = updateTime.current + const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail) - if (data.dataset_ids?.length > 0) { - const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { 
page: 1, ids: data.dataset_ids } }) - // avoid old data overwrite new data - if (currUpdateTime < updateTime.current) - return - setSelectedDatasets(dataSetsWithDetail) - } - else { - setSelectedDatasets([]) - } - })() - }, [data.dataset_ids]) + useEffect(() => { + if (data.dataset_ids?.length > 0) { + const dataSetsWithDetail = data.dataset_ids.reduce((acc, id) => { + if (datasetsDetail[id]) + acc.push(datasetsDetail[id]) + return acc + }, []) + setSelectedDatasets(dataSetsWithDetail) + } + else { + setSelectedDatasets([]) + } + }, [data.dataset_ids, datasetsDetail]) if (!selectedDatasets.length) return null diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts index 05c50f5ff7..42aa7def25 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts @@ -41,6 +41,7 @@ import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-s import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list' +import { useDatasetsDetailStore } from '../../datasets-detail-store/store' const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { const { nodesReadOnly: readOnly } = useNodesReadOnly() @@ -49,6 +50,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start) const startNodeId = startNode?.id const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) + const updateDatasetsDetail = useDatasetsDetailStore(s => 
s.updateDatasetsDetail) const inputRef = useRef(inputs) @@ -218,15 +220,12 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { (async () => { const inputs = inputRef.current const datasetIds = inputs.dataset_ids - let _datasets = selectedDatasets if (datasetIds?.length > 0) { const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } as any }) - _datasets = dataSetsWithDetail setSelectedDatasets(dataSetsWithDetail) } const newInputs = produce(inputs, (draft) => { draft.dataset_ids = datasetIds - draft._datasets = _datasets }) setInputs(newInputs) setSelectedDatasetsLoaded(true) @@ -256,7 +255,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } = getSelectedDatasetsMode(newDatasets) const newInputs = produce(inputs, (draft) => { draft.dataset_ids = newDatasets.map(d => d.id) - draft._datasets = newDatasets if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { const multipleRetrievalConfig = draft.multiple_retrieval_config @@ -266,6 +264,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { }) } }) + updateDatasetsDetail(newDatasets) setInputs(newInputs) setSelectedDatasets(newDatasets) @@ -275,7 +274,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { || allExternal ) setRerankModelOpen(true) - }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider]) + }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider, updateDatasetsDetail]) const filterVar = useCallback((varPayload: Var) => { return varPayload.type === VarType.string