mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-16 10:25:58 +08:00
feat(workflow): integrate parallel into workflow apps
This commit is contained in:
parent
1973f5003b
commit
352c45c8a2
@ -21,6 +21,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
@ -304,6 +307,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
@ -13,6 +13,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
@ -261,14 +262,16 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
QueueParallelBranchRunSucceededEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
@ -276,6 +279,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
QueueParallelBranchRunFailedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error
|
||||
)
|
||||
)
|
||||
|
@ -15,10 +15,10 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
@ -36,9 +36,6 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
|
||||
def on_event(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
event: GraphEngineEvent
|
||||
) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
@ -49,49 +46,38 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.on_workflow_node_execute_started(
|
||||
graph=graph,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self.on_workflow_node_execute_succeeded(
|
||||
graph=graph,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self.on_workflow_node_execute_failed(
|
||||
graph=graph,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self.on_node_text_chunk(
|
||||
graph=graph,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self.on_workflow_parallel_started(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
|
||||
self.on_workflow_parallel_completed(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
self.on_workflow_iteration_started(
|
||||
graph=graph,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
self.on_workflow_iteration_next(
|
||||
graph=graph,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||
self.on_workflow_iteration_completed(
|
||||
graph=graph,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
event=event
|
||||
)
|
||||
else:
|
||||
@ -99,39 +85,29 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
|
||||
def on_workflow_node_execute_started(
|
||||
self,
|
||||
graph: Graph,
|
||||
event: NodeRunStartedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
node_config = graph.node_id_config_mapping.get(route_node_state.node_id)
|
||||
node_type = None
|
||||
if node_config:
|
||||
node_type = node_config.get("data", {}).get("type")
|
||||
node_type = event.node_type.value
|
||||
|
||||
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
|
||||
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
|
||||
self.print_text(f"Node ID: {route_node_state.node_id}", color='yellow')
|
||||
self.print_text(f"Type: {node_type}", color='yellow')
|
||||
|
||||
def on_workflow_node_execute_succeeded(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
node_config = graph.node_id_config_mapping.get(route_node_state.node_id)
|
||||
node_type = None
|
||||
if node_config:
|
||||
node_type = node_config.get("data", {}).get("type")
|
||||
node_type = event.node_type.value
|
||||
|
||||
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
|
||||
self.print_text("\n[NodeRunSucceededEvent]", color='green')
|
||||
self.print_text(f"Node ID: {route_node_state.node_id}", color='green')
|
||||
self.print_text(f"Type: {node_type}", color='green')
|
||||
|
||||
@ -150,21 +126,15 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
|
||||
def on_workflow_node_execute_failed(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
event: NodeRunFailedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
node_config = graph.node_id_config_mapping.get(route_node_state.node_id)
|
||||
node_type = None
|
||||
if node_config:
|
||||
node_type = node_config.get("data", {}).get("type")
|
||||
node_type = event.node_type.value
|
||||
|
||||
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
|
||||
self.print_text("\n[NodeRunFailedEvent]", color='red')
|
||||
self.print_text(f"Node ID: {route_node_state.node_id}", color='red')
|
||||
self.print_text(f"Type: {node_type}", color='red')
|
||||
|
||||
@ -181,9 +151,6 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
|
||||
def on_node_text_chunk(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
event: NodeRunStreamChunkEvent
|
||||
) -> None:
|
||||
"""
|
||||
@ -192,7 +159,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
route_node_state = event.route_node_state
|
||||
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
|
||||
self.current_node_id = route_node_state.node_id
|
||||
self.print_text('\n[on_node_text_chunk]')
|
||||
self.print_text('\n[NodeRunStreamChunkEvent]')
|
||||
self.print_text(f"Node ID: {route_node_state.node_id}")
|
||||
|
||||
node_run_result = route_node_state.node_run_result
|
||||
@ -202,43 +169,69 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
|
||||
self.print_text(event.chunk_content, color="pink", end="")
|
||||
|
||||
def on_workflow_parallel_started(
|
||||
self,
|
||||
event: ParallelBranchRunStartedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish parallel started
|
||||
"""
|
||||
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
|
||||
|
||||
def on_workflow_parallel_completed(
|
||||
self,
|
||||
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish parallel completed
|
||||
"""
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
color = 'blue'
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
color = 'red'
|
||||
|
||||
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
|
||||
|
||||
if isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self.print_text(f"Error: {event.error}", color=color)
|
||||
|
||||
def on_workflow_iteration_started(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
event: IterationRunStartedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self.print_text("\n[on_workflow_iteration_started]", color='blue')
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text("\n[IterationRunStartedEvent]", color='blue')
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||
|
||||
def on_workflow_iteration_next(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
event: IterationRunNextEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self.print_text("\n[on_workflow_iteration_next]", color='blue')
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text("\n[IterationRunNextEvent]", color='blue')
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text(f"Iteration Index: {event.index}", color='blue')
|
||||
|
||||
def on_workflow_iteration_completed(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
event: IterationRunSucceededEvent | IterationRunFailedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
|
||||
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
||||
|
||||
def print_text(
|
||||
|
@ -334,7 +334,7 @@ class QueueStopEvent(AppQueueEvent):
|
||||
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
|
||||
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
|
||||
}
|
||||
|
||||
|
||||
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
|
||||
|
||||
|
||||
@ -370,6 +370,8 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
@ -380,6 +382,8 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
@ -390,4 +394,6 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
error: str
|
||||
|
@ -47,6 +47,8 @@ class StreamEvent(Enum):
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
@ -295,6 +297,46 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
"files": []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
iteration_id: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class ParallelBranchFinishedStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchFinishedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
iteration_id: Optional[str] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
|
@ -11,6 +11,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
IterationNodeCompletedStreamResponse,
|
||||
@ -18,6 +21,8 @@ from core.app.entities.task_entities import (
|
||||
IterationNodeStartStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
@ -433,6 +438,56 @@ class WorkflowCycleManage:
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_start_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunStartedEvent
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
"""
|
||||
Workflow parallel branch start to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: parallel branch run started event
|
||||
:return:
|
||||
"""
|
||||
return ParallelBranchStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchStartStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
"""
|
||||
Workflow parallel branch finished to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: parallel branch run succeeded or failed event
|
||||
:return:
|
||||
"""
|
||||
return ParallelBranchFinishedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchFinishedStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
)
|
||||
|
||||
def _workflow_iteration_start_to_stream_response(
|
||||
self,
|
||||
|
@ -1,18 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
|
||||
class WorkflowCallback(ABC):
|
||||
@abstractmethod
|
||||
def on_event(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
event: GraphEngineEvent
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -56,6 +56,8 @@ class NodeRunMetadataKey(Enum):
|
||||
TOOL_INFO = 'tool_info'
|
||||
ITERATION_ID = 'iteration_id'
|
||||
ITERATION_INDEX = 'iteration_index'
|
||||
PARALLEL_ID = 'parallel_id'
|
||||
PARALLEL_START_NODE_ID = 'parallel_start_node_id'
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
|
@ -99,6 +99,10 @@ class Graph(BaseModel):
|
||||
if target_node_id not in reverse_edge_mapping:
|
||||
reverse_edge_mapping[target_node_id] = []
|
||||
|
||||
# is target node id in source node id edge mapping
|
||||
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
|
||||
continue
|
||||
|
||||
target_edge_ids.add(target_node_id)
|
||||
|
||||
# parse run condition
|
||||
|
@ -244,6 +244,7 @@ class GraphEngine:
|
||||
|
||||
if len(edge_mappings) == 1:
|
||||
edge = edge_mappings[0]
|
||||
|
||||
if edge.run_condition:
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
@ -296,14 +297,20 @@ class GraphEngine:
|
||||
# run parallel nodes, run in new thread and use queue to get results
|
||||
q: queue.Queue = queue.Queue()
|
||||
|
||||
# Create a list to store the threads
|
||||
threads = []
|
||||
|
||||
# new thread
|
||||
for edge in edge_mappings:
|
||||
threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
thread = threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
'parallel_id': parallel_id,
|
||||
'parallel_start_node_id': edge.target_node_id,
|
||||
'q': q
|
||||
}).start()
|
||||
})
|
||||
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
@ -315,8 +322,8 @@ class GraphEngine:
|
||||
yield event
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(edge_mappings):
|
||||
break
|
||||
if succeeded_count == len(threads):
|
||||
q.put(None)
|
||||
|
||||
continue
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
@ -324,6 +331,10 @@ class GraphEngine:
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
# Join all threads
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# get final node id
|
||||
final_node_id = parallel.end_to_node_id
|
||||
if not final_node_id:
|
||||
@ -331,8 +342,8 @@ class GraphEngine:
|
||||
|
||||
next_node_id = final_node_id
|
||||
|
||||
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
|
||||
# break
|
||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
||||
break
|
||||
|
||||
def _run_parallel_node(self,
|
||||
flask_app: Flask,
|
||||
@ -449,6 +460,14 @@ class GraphEngine:
|
||||
variable_value=variable_value
|
||||
)
|
||||
|
||||
# add parallel info to run result metadata
|
||||
if parallel_id and parallel_start_node_id:
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
|
@ -37,6 +37,10 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
|
||||
yield event
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.in_iteration_id:
|
||||
yield event
|
||||
continue
|
||||
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
|
@ -157,14 +157,15 @@ class IterationNode(BaseNode):
|
||||
event.in_iteration_id = self.node_id
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
|
||||
yield event
|
||||
|
||||
|
@ -98,9 +98,6 @@ class WorkflowEntry:
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_event(
|
||||
graph=self.graph_engine.graph,
|
||||
graph_init_params=graph_engine.init_params,
|
||||
graph_runtime_state=graph_engine.graph_runtime_state,
|
||||
event=event
|
||||
)
|
||||
yield event
|
||||
@ -111,9 +108,6 @@ class WorkflowEntry:
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_event(
|
||||
graph=self.graph_engine.graph,
|
||||
graph_init_params=graph_engine.init_params,
|
||||
graph_runtime_state=graph_engine.graph_runtime_state,
|
||||
event=GraphRunFailedEvent(
|
||||
error=str(e)
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user