From 86707928d4099f62e8ec52fc0690ea2d27168809 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Tue, 9 Apr 2024 12:24:41 +0800 Subject: [PATCH] fix: node connect self (#3194) --- .../workflow/hooks/use-nodes-interactions.ts | 2 + .../components/workflow/hooks/use-workflow.ts | 27 ++++++-- web/app/components/workflow/utils.ts | 61 ++++++++++++++++++- 3 files changed, 85 insertions(+), 5 deletions(-) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 99c0cd224b..523a8e9a6e 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -301,6 +301,8 @@ export const useNodesInteractions = () => { target, targetHandle, }) => { + if (source === target) + return if (getNodesReadOnly()) return diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 37feb1ae8d..0417b395db 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -160,8 +160,10 @@ export const useWorkflow = () => { if (incomers.length) { incomers.forEach((node) => { - callback(node) - traverse(node, callback) + if (!list.find(n => node.id === n.id)) { + callback(node) + traverse(node, callback) + } }) } } @@ -272,7 +274,10 @@ export const useWorkflow = () => { }, [isVarUsedInNodes]) const isValidConnection = useCallback(({ source, target }: Connection) => { - const { getNodes } = store.getState() + const { + edges, + getNodes, + } = store.getState() const nodes = getNodes() const sourceNode: Node = nodes.find(node => node.id === source)! const targetNode: Node = nodes.find(node => node.id === target)! @@ -287,7 +292,21 @@ export const useWorkflow = () => { return false } - return true + const hasCycle = (node: Node, visited = new Set()) => { + if (visited.has(node.id)) + return false + + visited.add(node.id) + + for (const outgoer of getOutgoers(node, nodes, edges)) { + if (outgoer.id === source) + return true + if (hasCycle(outgoer, visited)) + return true + } + } + + return !hasCycle(targetNode) }, [store, nodesExtraData]) const formatTimeFromNow = useCallback((time: number) => { diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index 9e7d376cae..f586992dc7 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -24,6 +24,60 @@ import type { ToolNodeType } from './nodes/tool/types' import { CollectionType } from '@/app/components/tools/types' import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' +const WHITE = 'WHITE' +const GRAY = 'GRAY' +const BLACK = 'BLACK' + +const isCyclicUtil = (nodeId: string, color: Record, adjaList: Record, stack: string[]) => { + color[nodeId] = GRAY + stack.push(nodeId) + + for (let i = 0; i < adjaList[nodeId].length; ++i) { + const childId = adjaList[nodeId][i] + + if (color[childId] === GRAY) { + stack.push(childId) + return true + } + if (color[childId] === WHITE && isCyclicUtil(childId, color, adjaList, stack)) + return true + } + color[nodeId] = BLACK + if (stack.length > 0 && stack[stack.length - 1] === nodeId) + stack.pop() + return false +} + +const getCycleEdges = (nodes: Node[], edges: Edge[]) => { + const adjaList: Record = {} + const color: Record = {} + const stack: string[] = [] + + for (const node of nodes) { + color[node.id] = WHITE + adjaList[node.id] = [] + } + + for (const edge of edges) + adjaList[edge.source].push(edge.target) + + for (let i = 0; i < nodes.length; i++) { + if (color[nodes[i].id] === WHITE) + isCyclicUtil(nodes[i].id, color, adjaList, stack) + } + + const cycleEdges = [] + if (stack.length > 0) { + const cycleNodes = new Set(stack) + for (const edge of edges) { + if (cycleNodes.has(edge.source) && cycleNodes.has(edge.target)) + cycleEdges.push(edge) + } + } + + return cycleEdges +} + export const initialNodes = (nodes: Node[], edges: Edge[]) => { const firstNode = nodes[0] @@ -35,6 +89,7 @@ export const initialNodes = (nodes: Node[], edges: Edge[]) => { } }) } + return nodes.map((node) => { node.type = 'custom' @@ -75,7 +130,11 @@ export const initialEdges = (edges: Edge[], nodes: Node[]) => { return acc }, {} as Record) - return edges.map((edge) => { + + const cycleEdges = getCycleEdges(nodes, edges) + return edges.filter((edge) => { + return !cycleEdges.find(cycEdge => cycEdge.source === edge.source && cycEdge.target === edge.target) + }).map((edge) => { edge.type = 'custom' if (!edge.sourceHandle)