diff --git a/web/src/hooks/flow-hooks.ts b/web/src/hooks/flow-hooks.ts index 754c6ca30..cfdc09be9 100644 --- a/web/src/hooks/flow-hooks.ts +++ b/web/src/hooks/flow-hooks.ts @@ -93,6 +93,8 @@ export const useFetchFlow = (): { } = useQuery({ queryKey: ['flowDetail'], initialData: {} as IFlow, + refetchOnReconnect: false, + refetchOnMount: false, queryFn: async () => { const { data } = await flowService.getCanvas({}, id); diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 2645f0b71..bb646fe69 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -589,7 +589,7 @@ The above is the content you need to summarize.`, answer: 'Answer', categorize: 'Categorize', relevant: 'Relevant', - rewriteQuestion: 'RewriteQuestion', + rewriteQuestion: 'Rewrite', rewrite: 'Rewrite', begin: 'Begin', message: 'Message', diff --git a/web/src/pages/flow/canvas/edge/index.tsx b/web/src/pages/flow/canvas/edge/index.tsx index 84ac87afb..9e5bf20d7 100644 --- a/web/src/pages/flow/canvas/edge/index.tsx +++ b/web/src/pages/flow/canvas/edge/index.tsx @@ -6,7 +6,8 @@ import { } from 'reactflow'; import useGraphStore from '../../store'; -import { useFetchFlow } from '@/hooks/flow-hooks'; +import { IFlow } from '@/interfaces/database/flow'; +import { useQueryClient } from '@tanstack/react-query'; import { useMemo } from 'react'; import styles from './index.less'; @@ -43,10 +44,12 @@ export function ButtonEdge({ }; // highlight the nodes that the workflow passes through - const { data: flowDetail } = useFetchFlow(); + const queryClient = useQueryClient(); + const flowDetail = queryClient.getQueryData(['flowDetail']); + const graphPath = useMemo(() => { // TODO: this will be called multiple times - const path = flowDetail.dsl.path ?? []; + const path = flowDetail?.dsl.path ?? []; // The second to last const previousGraphPath: string[] = path.at(-2) ?? []; let graphPath: string[] = path.at(-1) ?? []; @@ -56,7 +59,7 @@ export function ButtonEdge({ graphPath = [previousLatestElement, ...graphPath]; } return graphPath; - }, [flowDetail.dsl.path]); + }, [flowDetail?.dsl.path]); const highlightStyle = useMemo(() => { const idx = graphPath.findIndex((x) => x === source); diff --git a/web/src/pages/flow/canvas/index.tsx b/web/src/pages/flow/canvas/index.tsx index 00029a77b..fde7eb8e1 100644 --- a/web/src/pages/flow/canvas/index.tsx +++ b/web/src/pages/flow/canvas/index.tsx @@ -16,6 +16,7 @@ import { useSelectCanvasData, useShowDrawer, useValidateConnection, + useWatchNodeFormDataChange, } from '../hooks'; import { RagNode } from './node'; @@ -69,6 +70,7 @@ function FlowCanvas({ chatDrawerVisible, hideChatDrawer }: IProps) { const { onDrop, onDragOver, setReactFlowInstance } = useHandleDrop(); const { handleKeyUp } = useHandleKeyUp(); + useWatchNodeFormDataChange(); return (
diff --git a/web/src/pages/flow/canvas/node/categorize-node.tsx b/web/src/pages/flow/canvas/node/categorize-node.tsx index 2ea40cac1..324d215e8 100644 --- a/web/src/pages/flow/canvas/node/categorize-node.tsx +++ b/web/src/pages/flow/canvas/node/categorize-node.tsx @@ -1,25 +1,68 @@ import { useTranslate } from '@/hooks/commonHooks'; import { Flex } from 'antd'; import classNames from 'classnames'; +import { pick } from 'lodash'; import get from 'lodash/get'; +import intersectionWith from 'lodash/intersectionWith'; +import isEqual from 'lodash/isEqual'; import lowerFirst from 'lodash/lowerFirst'; -import { Handle, NodeProps, Position } from 'reactflow'; -import { - CategorizeAnchorPointPositions, - Operator, - operatorMap, -} from '../../constant'; -import { NodeData } from '../../interface'; +import { useEffect, useMemo, useState } from 'react'; +import { Handle, NodeProps, Position, useUpdateNodeInternals } from 'reactflow'; +import { Operator, operatorMap } from '../../constant'; +import { IPosition, NodeData } from '../../interface'; import OperatorIcon from '../../operator-icon'; +import { buildNewPositionMap } from '../../utils'; import CategorizeHandle from './categorize-handle'; import NodeDropdown from './dropdown'; import styles from './index.less'; import NodePopover from './popover'; export function CategorizeNode({ id, data, selected }: NodeProps) { - const categoryData = get(data, 'form.category_description') ?? {}; + const updateNodeInternals = useUpdateNodeInternals(); + const [postionMap, setPositionMap] = useState>({}); + const categoryData = useMemo( + () => get(data, 'form.category_description') ?? {}, + [data], + ); const style = operatorMap[data.label as Operator]; const { t } = useTranslate('flow'); + + useEffect(() => { + // Cache used coordinates + setPositionMap((state) => { + // index in use + const indexesInUse = Object.values(state).map((x) => x.idx); + const categoryDataKeys = Object.keys(categoryData); + const stateKeys = Object.keys(state); + if (!isEqual(categoryDataKeys.sort(), stateKeys.sort())) { + const intersectionKeys = intersectionWith( + stateKeys, + categoryDataKeys, + (categoryDataKey, postionMapKey) => categoryDataKey === postionMapKey, + ); + const newPositionMap = buildNewPositionMap( + categoryDataKeys.filter( + (x) => !intersectionKeys.some((y) => y === x), + ), + indexesInUse, + ); + console.info('newPositionMap:', newPositionMap); + + const nextPostionMap = { + ...pick(state, intersectionKeys), + ...newPositionMap, + }; + + return nextPostionMap; + } + return state; + }); + }, [categoryData]); + + useEffect(() => { + updateNodeInternals(id); + }, [id, updateNodeInternals, postionMap]); + return (
) { id={'c'} > {Object.keys(categoryData).map((x, idx) => { + const position = postionMap[x]; return ( - + position && ( + + ) ); })} diff --git a/web/src/pages/flow/categorize-form/dynamic-categorize.tsx b/web/src/pages/flow/categorize-form/dynamic-categorize.tsx index 69516967c..6b7a41fab 100644 --- a/web/src/pages/flow/categorize-form/dynamic-categorize.tsx +++ b/web/src/pages/flow/categorize-form/dynamic-categorize.tsx @@ -1,12 +1,10 @@ import { useTranslate } from '@/hooks/commonHooks'; import { CloseOutlined } from '@ant-design/icons'; import { Button, Card, Form, Input, Select } from 'antd'; +import { humanId } from 'human-id'; import { useUpdateNodeInternals } from 'reactflow'; import { Operator } from '../constant'; -import { - useBuildFormSelectOptions, - useHandleFormSelectChange, -} from '../form-hooks'; +import { useBuildFormSelectOptions } from '../form-hooks'; import { ICategorizeItem } from '../interface'; interface IProps { @@ -20,7 +18,6 @@ const DynamicCategorize = ({ nodeId }: IProps) => { Operator.Categorize, nodeId, ); - const { handleSelectChange } = useHandleFormSelectChange(nodeId); const { t } = useTranslate('flow'); return ( @@ -28,8 +25,7 @@ const DynamicCategorize = ({ nodeId }: IProps) => { {(fields, { add, remove }) => { const handleAdd = () => { - const idx = fields.length; - add({ name: `Categorize ${idx + 1}` }); + add({ name: humanId() }); if (nodeId) updateNodeInternals(nodeId); }; return ( @@ -79,9 +75,6 @@ const DynamicCategorize = ({ nodeId }: IProps) => { form.getFieldValue(['items', field.name, 'to']), ), )} - onChange={handleSelectChange( - form.getFieldValue(['items', field.name, 'name']), - )} /> diff --git a/web/src/pages/flow/categorize-form/hooks.ts b/web/src/pages/flow/categorize-form/hooks.ts index b986fd7ba..8002d4556 100644 --- a/web/src/pages/flow/categorize-form/hooks.ts +++ b/web/src/pages/flow/categorize-form/hooks.ts @@ -1,12 +1,10 @@ import get from 'lodash/get'; import omit from 'lodash/omit'; import { useCallback, useEffect } from 'react'; -import { Edge, Node } from 'reactflow'; import { ICategorizeItem, ICategorizeItemResult, IOperatorForm, - NodeData, } from '../interface'; import useGraphStore from '../store'; @@ -23,18 +21,14 @@ import useGraphStore from '../store'; */ const buildCategorizeListFromObject = ( categorizeItem: ICategorizeItemResult, - edges: Edge[], - node?: Node, ) => { // Categorize's to field has two data sources, with edges as the data source. // Changes in the edge or to field need to be synchronized to the form field. return Object.keys(categorizeItem).reduce>( (pre, cur) => { // synchronize edge data to the to field - const edge = edges.find( - (x) => x.source === node?.id && x.sourceHandle === cur, - ); - pre.push({ name: cur, ...categorizeItem[cur], to: edge?.target }); + + pre.push({ name: cur, ...categorizeItem[cur] }); return pre; }, [], @@ -68,7 +62,6 @@ export const useHandleFormValuesChange = ({ form, nodeId, }: IOperatorForm) => { - const edges = useGraphStore((state) => state.edges); const getNode = useGraphStore((state) => state.getNode); const node = getNode(nodeId); @@ -86,13 +79,12 @@ export const useHandleFormValuesChange = ({ useEffect(() => { const items = buildCategorizeListFromObject( get(node, 'data.form.category_description', {}), - edges, - node, ); + console.info('effect:', items); form?.setFieldsValue({ items, }); - }, [form, node, edges]); + }, [form, node]); return { handleValuesChange }; }; diff --git a/web/src/pages/flow/form-hooks.ts b/web/src/pages/flow/form-hooks.ts index 53c06fb18..cdd9b19f6 100644 --- a/web/src/pages/flow/form-hooks.ts +++ b/web/src/pages/flow/form-hooks.ts @@ -33,6 +33,11 @@ export const useBuildFormSelectOptions = ( return buildCategorizeToOptions; }; +/** + * dumped + * @param nodeId + * @returns + */ export const useHandleFormSelectChange = (nodeId?: string) => { const { addEdge, deleteEdgeBySourceAndSourceHandle } = useGraphStore( (state) => state, diff --git a/web/src/pages/flow/generate-form/hooks.ts b/web/src/pages/flow/generate-form/hooks.ts index 28334ad68..4876b607e 100644 --- a/web/src/pages/flow/generate-form/hooks.ts +++ b/web/src/pages/flow/generate-form/hooks.ts @@ -27,12 +27,10 @@ export const useHandleOperateParameters = (nodeId: string) => { const { getNode, updateNodeForm } = useGraphStore((state) => state); const node = getNode(nodeId); const dataSource: IGenerateParameter[] = useMemo( - () => get(node, 'data.form.parameters', []), + () => get(node, 'data.form.parameters', []) as IGenerateParameter[], [node], ); - // const [x, setDataSource] = useState([]); - const handleComponentIdChange = useCallback( (row: IGenerateParameter) => (value: string) => { const newData = [...dataSource]; @@ -44,7 +42,6 @@ export const useHandleOperateParameters = (nodeId: string) => { }); updateNodeForm(nodeId, { parameters: newData }); - // setDataSource(newData); }, [updateNodeForm, nodeId, dataSource], ); @@ -53,20 +50,11 @@ export const useHandleOperateParameters = (nodeId: string) => { (id?: string) => () => { const newData = dataSource.filter((item) => item.id !== id); updateNodeForm(nodeId, { parameters: newData }); - // setDataSource(newData); }, [updateNodeForm, nodeId, dataSource], ); const handleAdd = useCallback(() => { - // setDataSource((state) => [ - // ...state, - // { - // id: uuid(), - // key: '', - // component_id: undefined, - // }, - // ]); updateNodeForm(nodeId, { parameters: [ ...dataSource, @@ -89,7 +77,6 @@ export const useHandleOperateParameters = (nodeId: string) => { }); updateNodeForm(nodeId, { parameters: newData }); - // setDataSource(newData); }; return { diff --git a/web/src/pages/flow/hooks.ts b/web/src/pages/flow/hooks.ts index 7c19eb723..2c516491a 100644 --- a/web/src/pages/flow/hooks.ts +++ b/web/src/pages/flow/hooks.ts @@ -10,7 +10,7 @@ import React, { useEffect, useState, } from 'react'; -import { Connection, Node, Position, ReactFlowInstance } from 'reactflow'; +import { Connection, Edge, Node, Position, ReactFlowInstance } from 'reactflow'; // import { shallow } from 'zustand/shallow'; import { variableEnabledFieldMap } from '@/constants/chat'; import { @@ -25,6 +25,7 @@ import { FormInstance, message } from 'antd'; import { humanId } from 'human-id'; import trim from 'lodash/trim'; import { useParams } from 'umi'; +import { v4 as uuid } from 'uuid'; import { NodeMap, Operator, @@ -37,6 +38,7 @@ import { initialRetrievalValues, initialRewriteQuestionValues, } from './constant'; +import { ICategorizeForm, IRelevantForm } from './interface'; import useGraphStore, { RFState } from './store'; import { buildDslComponentsByGraph, @@ -253,7 +255,7 @@ const useSetGraphInfo = () => { }; export const useFetchDataOnMount = () => { - const { loading, data } = useFetchFlow(); + const { loading, data, refetch } = useFetchFlow(); const setGraphInfo = useSetGraphInfo(); useEffect(() => { @@ -264,6 +266,10 @@ export const useFetchDataOnMount = () => { useFetchLlmList(); + useEffect(() => { + refetch(); + }, [refetch]); + return { loading, flowDetail: data }; }; @@ -390,3 +396,78 @@ export const useReplaceIdWithText = (output: unknown) => { return replaceIdWithText(output, getNameById); }; + +/** + * monitor changes in the data.form field of the categorize and relevant operators + * and then synchronize them to the edge + */ +export const useWatchNodeFormDataChange = () => { + const { getNode, nodes, setEdgesByNodeId } = useGraphStore((state) => state); + + const buildCategorizeEdgesByFormData = useCallback( + (nodeId: string, form: ICategorizeForm) => { + // add + // delete + // edit + const categoryDescription = form.category_description; + const downstreamEdges = Object.keys(categoryDescription).reduce( + (pre, sourceHandle) => { + const target = categoryDescription[sourceHandle]?.to; + if (target) { + pre.push({ + id: uuid(), + source: nodeId, + target, + sourceHandle, + }); + } + + return pre; + }, + [], + ); + + setEdgesByNodeId(nodeId, downstreamEdges); + }, + [setEdgesByNodeId], + ); + + const buildRelevantEdgesByFormData = useCallback( + (nodeId: string, form: IRelevantForm) => { + const downstreamEdges = ['yes', 'no'].reduce((pre, cur) => { + const target = form[cur as keyof IRelevantForm] as string; + if (target) { + pre.push({ id: uuid(), source: nodeId, target, sourceHandle: cur }); + } + + return pre; + }, []); + + setEdgesByNodeId(nodeId, downstreamEdges); + }, + [setEdgesByNodeId], + ); + + useEffect(() => { + nodes.forEach((node) => { + const currentNode = getNode(node.id); + const form = currentNode?.data.form ?? {}; + const operatorType = currentNode?.data.label; + switch (operatorType) { + case Operator.Relevant: + buildRelevantEdgesByFormData(node.id, form as IRelevantForm); + break; + case Operator.Categorize: + buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm); + break; + default: + break; + } + }); + }, [ + nodes, + buildCategorizeEdgesByFormData, + getNode, + buildRelevantEdgesByFormData, + ]); +}; diff --git a/web/src/pages/flow/interface.ts b/web/src/pages/flow/interface.ts index 87e99871c..a252e45ce 100644 --- a/web/src/pages/flow/interface.ts +++ b/web/src/pages/flow/interface.ts @@ -70,3 +70,5 @@ export type NodeData = { color: string; form: IBeginForm | IRetrievalForm | IGenerateForm | ICategorizeForm; }; + +export type IPosition = { top: number; right: number; idx: number }; diff --git a/web/src/pages/flow/relevant-form/index.tsx b/web/src/pages/flow/relevant-form/index.tsx index 68ee09671..203a118f7 100644 --- a/web/src/pages/flow/relevant-form/index.tsx +++ b/web/src/pages/flow/relevant-form/index.tsx @@ -2,10 +2,7 @@ import LLMSelect from '@/components/llm-select'; import { useTranslate } from '@/hooks/commonHooks'; import { Form, Select } from 'antd'; import { Operator } from '../constant'; -import { - useBuildFormSelectOptions, - useHandleFormSelectChange, -} from '../form-hooks'; +import { useBuildFormSelectOptions } from '../form-hooks'; import { useSetLlmSetting } from '../hooks'; import { IOperatorForm } from '../interface'; import { useWatchConnectionChanges } from './hooks'; @@ -18,7 +15,6 @@ const RelevantForm = ({ onValuesChange, form, node }: IOperatorForm) => { node?.id, ); useWatchConnectionChanges({ nodeId: node?.id, form }); - const { handleSelectChange } = useHandleFormSelectChange(node?.id); return (
{
diff --git a/web/src/pages/flow/store.ts b/web/src/pages/flow/store.ts index 3e9f18693..a87a81919 100644 --- a/web/src/pages/flow/store.ts +++ b/web/src/pages/flow/store.ts @@ -1,5 +1,7 @@ import type {} from '@redux-devtools/extension'; import { humanId } from 'human-id'; +import differenceWith from 'lodash/differenceWith'; +import intersectionWith from 'lodash/intersectionWith'; import lodashSet from 'lodash/set'; import { Connection, @@ -21,6 +23,7 @@ import { devtools } from 'zustand/middleware'; import { immer } from 'zustand/middleware/immer'; import { Operator } from './constant'; import { NodeData } from './interface'; +import { isEdgeEqual } from './utils'; export type RFState = { nodes: Node[]; @@ -33,6 +36,7 @@ export type RFState = { onConnect: OnConnect; setNodes: (nodes: Node[]) => void; setEdges: (edges: Edge[]) => void; + setEdgesByNodeId: (nodeId: string, edges: Edge[]) => void; updateNodeForm: (nodeId: string, values: any, path?: string[]) => void; onSelectionChange: OnSelectionChangeFunc; addNode: (nodes: Node) => void; @@ -95,6 +99,55 @@ const useGraphStore = create()( setEdges: (edges: Edge[]) => { set({ edges }); }, + setEdgesByNodeId: (nodeId: string, currentDownstreamEdges: Edge[]) => { + const { edges, setEdges } = get(); + // the previous downstream edge of this node + const previousDownstreamEdges = edges.filter( + (x) => x.source === nodeId, + ); + const isDifferent = + previousDownstreamEdges.length !== currentDownstreamEdges.length || + !previousDownstreamEdges.every((x) => + currentDownstreamEdges.some( + (y) => + y.source === x.source && + y.target === x.target && + y.sourceHandle === x.sourceHandle, + ), + ) || + !currentDownstreamEdges.every((x) => + previousDownstreamEdges.some( + (y) => + y.source === x.source && + y.target === x.target && + y.sourceHandle === x.sourceHandle, + ), + ); + + const intersectionDownstreamEdges = intersectionWith( + previousDownstreamEdges, + currentDownstreamEdges, + isEdgeEqual, + ); + if (isDifferent) { + // other operator's edges + const irrelevantEdges = edges.filter((x) => x.source !== nodeId); + // the abandoned edges + const selfAbandonedEdges = []; + // the added downstream edges + const selfAddedDownstreamEdges = differenceWith( + currentDownstreamEdges, + intersectionDownstreamEdges, + isEdgeEqual, + ); + setEdges([ + ...irrelevantEdges, + ...intersectionDownstreamEdges, + ...selfAddedDownstreamEdges, + ]); + } + }, + addNode: (node: Node) => { set({ nodes: get().nodes.concat(node) }); }, @@ -242,10 +295,6 @@ const useGraphStore = create()( set({ nodes: get().nodes.map((node) => { if (node.id === nodeId) { - // node.data = { - // ...node.data, - // form: { ...node.data.form, ...values }, - // }; let nextForm: Record = { ...node.data.form }; if (path.length === 0) { nextForm = Object.assign(nextForm, values); diff --git a/web/src/pages/flow/utils.ts b/web/src/pages/flow/utils.ts index e070e5485..a5d44c850 100644 --- a/web/src/pages/flow/utils.ts +++ b/web/src/pages/flow/utils.ts @@ -2,13 +2,13 @@ import { DSLComponents } from '@/interfaces/database/flow'; import { removeUselessFieldsFromValues } from '@/utils/form'; import dagre from 'dagre'; import { humanId } from 'human-id'; -import { curry } from 'lodash'; +import { curry, sample } from 'lodash'; import pipe from 'lodash/fp/pipe'; import isObject from 'lodash/isObject'; import { Edge, Node, Position } from 'reactflow'; import { v4 as uuidv4 } from 'uuid'; -import { NodeMap, Operator } from './constant'; -import { ICategorizeItemResult, NodeData } from './interface'; +import { CategorizeAnchorPointPositions, NodeMap, Operator } from './constant'; +import { ICategorizeItemResult, IPosition, NodeData } from './interface'; const buildEdges = ( operatorIds: string[], @@ -208,3 +208,27 @@ export const replaceIdWithText = ( return obj; }; + +export const isEdgeEqual = (previous: Edge, current: Edge) => + previous.source === current.source && + previous.target === current.target && + previous.sourceHandle === current.sourceHandle; + +export const buildNewPositionMap = ( + categoryDataKeys: string[], + indexesInUse: number[], +) => { + return categoryDataKeys.reduce>((pre, cur) => { + // take a coordinate + const effectiveIdxes = CategorizeAnchorPointPositions.map( + (x, idx) => idx, + ).filter((x) => !indexesInUse.some((y) => y === x)); + const idx = sample(effectiveIdxes); + if (idx !== undefined) { + indexesInUse.push(idx); + pre[cur] = { ...CategorizeAnchorPointPositions[idx], idx }; + } + + return pre; + }, {}); +};