check node edge

This commit is contained in:
StyleZhang 2024-09-04 13:27:01 +08:00
parent cd42dbdae8
commit 4962b2c460
2 changed files with 22 additions and 3 deletions

View File

@ -29,7 +29,9 @@ import {
useStore, useStore,
useWorkflowStore, useWorkflowStore,
} from '../store' } from '../store'
import { getParallelInfo } from '../utils' import {
getParallelInfo,
} from '../utils'
import { import {
PARALLEL_DEPTH_LIMIT, PARALLEL_DEPTH_LIMIT,
PARALLEL_LIMIT, PARALLEL_LIMIT,
@ -299,7 +301,13 @@ export const useWorkflow = () => {
}, [store, workflowStore, t]) }, [store, workflowStore, t])
const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => { const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], parentNodeId?: string) => {
const parallelList = getParallelInfo(nodes, edges, parentNodeId) const {
parallelList,
hasAbnormalEdges,
} = getParallelInfo(nodes, edges, parentNodeId)
if (hasAbnormalEdges)
return false
for (let i = 0; i < parallelList.length; i++) { for (let i = 0; i < parallelList.length; i++) {
const parallel = parallelList[i] const parallel = parallelList[i]
@ -329,6 +337,9 @@ export const useWorkflow = () => {
if (sourceNode.type === CUSTOM_NOTE_NODE || targetNode.type === CUSTOM_NOTE_NODE) if (sourceNode.type === CUSTOM_NOTE_NODE || targetNode.type === CUSTOM_NOTE_NODE)
return false return false
if (sourceNode.parentId !== targetNode.parentId)
return false
if (sourceNode && targetNode) { if (sourceNode && targetNode) {
const sourceNodeAvailableNextNodes = nodesExtraData[sourceNode.data.type].availableNextNodes const sourceNodeAvailableNextNodes = nodesExtraData[sourceNode.data.type].availableNextNodes
const targetNodeAvailablePrevNodes = [...nodesExtraData[targetNode.data.type].availablePrevNodes, BlockEnum.Start] const targetNodeAvailablePrevNodes = [...nodesExtraData[targetNode.data.type].availablePrevNodes, BlockEnum.Start]

View File

@ -629,6 +629,7 @@ export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: str
const parallelList = [] as ParallelInfoItem[] const parallelList = [] as ParallelInfoItem[]
const nextNodeHandles = [{ node: startNode, handle: 'source' }] const nextNodeHandles = [{ node: startNode, handle: 'source' }]
let hasAbnormalEdges = false
const traverse = (firstNodeHandle: NodeHandle) => { const traverse = (firstNodeHandle: NodeHandle) => {
const nodeEdgesSet = {} as Record<string, Set<string>> const nodeEdgesSet = {} as Record<string, Set<string>>
@ -681,6 +682,10 @@ export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: str
outgoers.forEach((outgoer) => { outgoers.forEach((outgoer) => {
const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id) const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id)
const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle') const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle')
const incomers = getIncomers(outgoer, nodes, edges)
if (outgoers.length > 1 && incomers.length > 1)
hasAbnormalEdges = true
Object.keys(sourceEdgesGroup).forEach((sourceHandle) => { Object.keys(sourceEdgesGroup).forEach((sourceHandle) => {
nextHandles.push({ node: outgoer, handle: sourceHandle }) nextHandles.push({ node: outgoer, handle: sourceHandle })
@ -746,5 +751,8 @@ export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: str
traverse(nodeHandle) traverse(nodeHandle)
} }
return parallelList return {
parallelList,
hasAbnormalEdges,
}
} }