diff --git a/web/src/pages/flow/store.ts b/web/src/pages/flow/store.ts index 459d60a0d..8acb0aac5 100644 --- a/web/src/pages/flow/store.ts +++ b/web/src/pages/flow/store.ts @@ -19,6 +19,7 @@ import { create } from 'zustand'; import { devtools } from 'zustand/middleware'; import { Operator } from './constant'; import { NodeData } from './interface'; +import { getOperatorTypeFromId } from './utils'; export type RFState = { nodes: Node[]; @@ -35,6 +36,7 @@ export type RFState = { addNode: (nodes: Node) => void; getNode: (id: string) => Node | undefined; addEdge: (connection: Connection) => void; + deletePreviousEdgeOfClassificationNode: (connection: Connection) => void; duplicateNode: (id: string) => void; deleteEdge: () => void; deleteEdgeById: (id: string) => void; @@ -66,6 +68,7 @@ const useGraphStore = create()( set({ edges: addEdge(connection, get().edges), }); + get().deletePreviousEdgeOfClassificationNode(connection); }, onSelectionChange: ({ nodes, edges }: OnSelectionChangeParams) => { set({ @@ -90,6 +93,23 @@ const useGraphStore = create()( edges: addEdge(connection, get().edges), }); }, + deletePreviousEdgeOfClassificationNode: (connection: Connection) => { + // Delete the edge on the classification node anchor when the anchor is connected to other nodes + const { edges } = get(); + if (getOperatorTypeFromId(connection.source) === Operator.Categorize) { + const previousEdge = edges.find( + (x) => + x.source === connection.source && + x.sourceHandle === connection.sourceHandle && + x.target !== connection.target, + ); + if (previousEdge) { + set({ + edges: edges.filter((edge) => edge !== previousEdge), + }); + } + } + }, // addOnlyOneEdgeBetweenTwoNodes: (connection: Connection) => { // }, diff --git a/web/src/pages/flow/utils.ts b/web/src/pages/flow/utils.ts index f6f50c140..404f4d144 100644 --- a/web/src/pages/flow/utils.ts +++ b/web/src/pages/flow/utils.ts @@ -167,13 +167,13 @@ export const buildDslComponentsByGraph = ( return components; }; -export const getOperatorType = (id: string | null) => { +export const getOperatorTypeFromId = (id: string | null) => { return id?.split(':')[0] as Operator | undefined; }; // restricted lines cannot be connected successfully. export const isValidConnection = (connection: Connection) => { return RestrictedUpstreamMap[ - getOperatorType(connection.source) as Operator - ]?.every((x) => x !== getOperatorType(connection.target)); + getOperatorTypeFromId(connection.source) as Operator + ]?.every((x) => x !== getOperatorTypeFromId(connection.target)); };