diff --git a/web/src/pages/flow/canvas/index.tsx b/web/src/pages/flow/canvas/index.tsx index b1f4864e9..f275fe267 100644 --- a/web/src/pages/flow/canvas/index.tsx +++ b/web/src/pages/flow/canvas/index.tsx @@ -16,11 +16,11 @@ import { useHandleKeyUp, useSelectCanvasData, useShowDrawer, + useValidateConnection, } from '../hooks'; import { RagNode } from './node'; import ChatDrawer from '../chat/drawer'; -import { isValidConnection } from '../utils'; import styles from './index.less'; import { BeginNode } from './node/begin-node'; import { CategorizeNode } from './node/categorize-node'; @@ -49,6 +49,7 @@ function FlowCanvas({ chatDrawerVisible, hideChatDrawer }: IProps) { onNodesChange, onSelectionChange, } = useSelectCanvasData(); + const isValidConnection = useValidateConnection(); const { drawerVisible, hideDrawer, showDrawer, clickedNode } = useShowDrawer(); diff --git a/web/src/pages/flow/hooks.ts b/web/src/pages/flow/hooks.ts index 2492edd2f..0b2d9ff49 100644 --- a/web/src/pages/flow/hooks.ts +++ b/web/src/pages/flow/hooks.ts @@ -13,7 +13,7 @@ import React, { useEffect, useState, } from 'react'; -import { Node, Position, ReactFlowInstance } from 'reactflow'; +import { Connection, Node, Position, ReactFlowInstance } from 'reactflow'; // import { shallow } from 'zustand/shallow'; import { variableEnabledFieldMap } from '@/constants/chat'; import { @@ -25,9 +25,9 @@ import { useDebounceEffect } from 'ahooks'; import { FormInstance } from 'antd'; import { humanId } from 'human-id'; import { useParams } from 'umi'; -import { NodeMap, Operator } from './constant'; +import { NodeMap, Operator, RestrictedUpstreamMap } from './constant'; import useGraphStore, { RFState } from './store'; -import { buildDslComponentsByGraph } from './utils'; +import { buildDslComponentsByGraph, getOperatorTypeFromId } from './utils'; const selector = (state: RFState) => ({ nodes: state.nodes, @@ -247,3 +247,26 @@ export const useSetLlmSetting = (form?: FormInstance) => { form?.setFieldsValue({ ...switchBoxValues, ...otherValues }); }, [form, initialLlmSetting]); }; + +export const useValidateConnection = () => { + const edges = useGraphStore((state) => state.edges); + // restricted lines cannot be connected successfully. + const isValidConnection = useCallback( + (connection: Connection) => { + // limit there to be only one line between two nodes + const hasLine = edges.some( + (x) => x.source === connection.source && x.target === connection.target, + ); + + const ret = + !hasLine && + RestrictedUpstreamMap[ + getOperatorTypeFromId(connection.source) as Operator + ]?.every((x) => x !== getOperatorTypeFromId(connection.target)); + return ret; + }, + [edges], + ); + + return isValidConnection; +}; diff --git a/web/src/pages/flow/utils.ts b/web/src/pages/flow/utils.ts index b2a691616..e1e1e15cb 100644 --- a/web/src/pages/flow/utils.ts +++ b/web/src/pages/flow/utils.ts @@ -3,13 +3,9 @@ import { removeUselessFieldsFromValues } from '@/utils/form'; import dagre from 'dagre'; import { curry, isEmpty } from 'lodash'; import pipe from 'lodash/fp/pipe'; -import { Connection, Edge, MarkerType, Node, Position } from 'reactflow'; +import { Edge, MarkerType, Node, Position } from 'reactflow'; import { v4 as uuidv4 } from 'uuid'; -import { - Operator, - RestrictedUpstreamMap, - initialFormValuesMap, -} from './constant'; +import { Operator, initialFormValuesMap } from './constant'; import { NodeData } from './interface'; const buildEdges = ( @@ -170,11 +166,3 @@ export const buildDslComponentsByGraph = ( export const getOperatorTypeFromId = (id: string | null) => { return id?.split(':')[0] as Operator | undefined; }; - -// restricted lines cannot be connected successfully. -export const isValidConnection = (connection: Connection) => { - const ret = RestrictedUpstreamMap[ - getOperatorTypeFromId(connection.source) as Operator - ]?.every((x) => x !== getOperatorTypeFromId(connection.target)); - return ret; -};