feat(workflow): integrate parallel into workflow apps

This commit is contained in:
takatost 2024-08-16 21:33:09 +08:00
parent 1973f5003b
commit 352c45c8a2
13 changed files with 233 additions and 94 deletions

View File

@ -21,6 +21,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@ -304,6 +307,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')

View File

@ -13,6 +13,7 @@ from core.app.entities.queue_entities import (
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
@ -261,14 +262,16 @@ class WorkflowBasedAppRunner(AppRunner):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id
parallel_start_node_id=event.parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id
parallel_start_node_id=event.parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
@ -276,6 +279,7 @@ class WorkflowBasedAppRunner(AppRunner):
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
error=event.error
)
)

View File

@ -15,10 +15,10 @@ from core.workflow.graph_engine.entities.event import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
@ -36,9 +36,6 @@ class WorkflowLoggingCallback(WorkflowCallback):
def on_event(
self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: GraphEngineEvent
) -> None:
if isinstance(event, GraphRunStartedEvent):
@ -49,49 +46,38 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(
graph=graph,
event=event
)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event
)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event
)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(
event=event
)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(
event=event
)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event
)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event
)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(
graph=graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
event=event
)
else:
@ -99,39 +85,29 @@ class WorkflowLoggingCallback(WorkflowCallback):
def on_workflow_node_execute_started(
self,
graph: Graph,
event: NodeRunStartedEvent
) -> None:
"""
Workflow node execute started
"""
route_node_state = event.route_node_state
node_config = graph.node_id_config_mapping.get(route_node_state.node_id)
node_type = None
if node_config:
node_type = node_config.get("data", {}).get("type")
node_type = event.node_type.value
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
self.print_text(f"Node ID: {route_node_state.node_id}", color='yellow')
self.print_text(f"Type: {node_type}", color='yellow')
def on_workflow_node_execute_succeeded(
self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: NodeRunSucceededEvent
) -> None:
"""
Workflow node execute succeeded
"""
route_node_state = event.route_node_state
node_config = graph.node_id_config_mapping.get(route_node_state.node_id)
node_type = None
if node_config:
node_type = node_config.get("data", {}).get("type")
node_type = event.node_type.value
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
self.print_text("\n[NodeRunSucceededEvent]", color='green')
self.print_text(f"Node ID: {route_node_state.node_id}", color='green')
self.print_text(f"Type: {node_type}", color='green')
@ -150,21 +126,15 @@ class WorkflowLoggingCallback(WorkflowCallback):
def on_workflow_node_execute_failed(
self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: NodeRunFailedEvent
) -> None:
"""
Workflow node execute failed
"""
route_node_state = event.route_node_state
node_config = graph.node_id_config_mapping.get(route_node_state.node_id)
node_type = None
if node_config:
node_type = node_config.get("data", {}).get("type")
node_type = event.node_type.value
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
self.print_text("\n[NodeRunFailedEvent]", color='red')
self.print_text(f"Node ID: {route_node_state.node_id}", color='red')
self.print_text(f"Type: {node_type}", color='red')
@ -181,9 +151,6 @@ class WorkflowLoggingCallback(WorkflowCallback):
def on_node_text_chunk(
self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: NodeRunStreamChunkEvent
) -> None:
"""
@ -192,7 +159,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text('\n[on_node_text_chunk]')
self.print_text('\n[NodeRunStreamChunkEvent]')
self.print_text(f"Node ID: {route_node_state.node_id}")
node_run_result = route_node_state.node_run_result
@ -202,43 +169,69 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(
    self,
    event: ParallelBranchRunStartedEvent
) -> None:
    """
    Log the start of a parallel branch run.

    :param event: parallel branch run started event
    """
    color = 'blue'
    self.print_text("\n[ParallelBranchRunStartedEvent]", color=color)
    self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
    self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
    # the iteration id is only present when the branch runs inside an iteration node
    if event.in_iteration_id:
        self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
def on_workflow_parallel_completed(
    self,
    event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
    """
    Log the completion (success or failure) of a parallel branch run.

    Successful branches are printed in blue, failed branches in red with
    the error message appended.

    :param event: parallel branch run succeeded or failed event
    """
    failed = isinstance(event, ParallelBranchRunFailedEvent)
    # conditional expression keeps `color` always bound; the original
    # if/elif chain had no else branch, leaving `color` potentially
    # undefined for an unexpected event subtype
    color = 'red' if failed else 'blue'
    self.print_text(
        "\n[ParallelBranchRunFailedEvent]" if failed else "\n[ParallelBranchRunSucceededEvent]",
        color=color
    )
    self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
    self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
    # the iteration id is only present when the branch runs inside an iteration node
    if event.in_iteration_id:
        self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
    if failed:
        self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(
self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: IterationRunStartedEvent
) -> None:
"""
Publish iteration started
"""
self.print_text("\n[on_workflow_iteration_started]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
self.print_text("\n[IterationRunStartedEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
def on_workflow_iteration_next(
self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: IterationRunNextEvent
) -> None:
"""
Publish iteration next
"""
self.print_text("\n[on_workflow_iteration_next]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
self.print_text("\n[IterationRunNextEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
self.print_text(f"Iteration Index: {event.index}", color='blue')
def on_workflow_iteration_completed(
self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: IterationRunSucceededEvent | IterationRunFailedEvent
) -> None:
"""
Publish iteration completed
"""
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
def print_text(

View File

@ -334,7 +334,7 @@ class QueueStopEvent(AppQueueEvent):
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
}
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
@ -370,6 +370,8 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent):
parallel_id: str
parallel_start_node_id: str
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
@ -380,6 +382,8 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
parallel_id: str
parallel_start_node_id: str
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
@ -390,4 +394,6 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent):
parallel_id: str
parallel_start_node_id: str
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
error: str

View File

@ -47,6 +47,8 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@ -295,6 +297,46 @@ class NodeFinishStreamResponse(StreamResponse):
"files": []
}
}
class ParallelBranchStartStreamResponse(StreamResponse):
    """
    Stream response emitted when a parallel branch starts running.
    """
    class Data(BaseModel):
        """
        Payload describing the started parallel branch.
        """
        parallel_id: str  # id of the parallel group this branch belongs to
        parallel_branch_id: str  # start node id identifying the branch
        iteration_id: Optional[str] = None  # set only when the branch runs inside an iteration node
        created_at: int  # unix timestamp (seconds)

    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
    workflow_run_id: str
    data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
    """
    Stream response emitted when a parallel branch finishes running
    (either successfully or with an error).
    """
    class Data(BaseModel):
        """
        Payload describing the finished parallel branch.
        """
        parallel_id: str  # id of the parallel group this branch belongs to
        parallel_branch_id: str  # start node id identifying the branch
        iteration_id: Optional[str] = None  # set only when the branch runs inside an iteration node
        status: str  # 'succeeded' or 'failed'
        error: Optional[str] = None  # error message, populated only on failure
        created_at: int  # unix timestamp (seconds)

    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
    workflow_run_id: str
    data: Data
class IterationNodeStartStreamResponse(StreamResponse):

View File

@ -11,6 +11,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse,
@ -18,6 +21,8 @@ from core.app.entities.task_entities import (
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
@ -433,6 +438,56 @@ class WorkflowCycleManage:
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
),
)
def _workflow_parallel_branch_start_to_stream_response(
    self,
    task_id: str,
    workflow_run: WorkflowRun,
    event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
    """
    Build the stream response announcing that a parallel branch started.

    :param task_id: task id
    :param workflow_run: workflow run
    :param event: parallel branch run started event
    :return: stream response for the started branch
    """
    payload = ParallelBranchStartStreamResponse.Data(
        parallel_id=event.parallel_id,
        parallel_branch_id=event.parallel_start_node_id,
        iteration_id=event.in_iteration_id,
        created_at=int(time.time()),
    )
    return ParallelBranchStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=payload
    )
def _workflow_parallel_branch_finished_to_stream_response(
    self,
    task_id: str,
    workflow_run: WorkflowRun,
    event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
) -> ParallelBranchFinishedStreamResponse:
    """
    Build the stream response announcing that a parallel branch finished.

    :param task_id: task id
    :param workflow_run: workflow run
    :param event: parallel branch run succeeded or failed event
    :return: stream response carrying the branch status (and error on failure)
    """
    failed = isinstance(event, QueueParallelBranchRunFailedEvent)
    payload = ParallelBranchFinishedStreamResponse.Data(
        parallel_id=event.parallel_id,
        parallel_branch_id=event.parallel_start_node_id,
        iteration_id=event.in_iteration_id,
        status='failed' if failed else 'succeeded',
        error=event.error if failed else None,
        created_at=int(time.time()),
    )
    return ParallelBranchFinishedStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=payload
    )
def _workflow_iteration_start_to_stream_response(
self,

View File

@ -1,18 +1,12 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.event import GraphEngineEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
class WorkflowCallback(ABC):
@abstractmethod
def on_event(
self,
graph: Graph,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
event: GraphEngineEvent
) -> None:
"""

View File

@ -56,6 +56,8 @@ class NodeRunMetadataKey(Enum):
TOOL_INFO = 'tool_info'
ITERATION_ID = 'iteration_id'
ITERATION_INDEX = 'iteration_index'
PARALLEL_ID = 'parallel_id'
PARALLEL_START_NODE_ID = 'parallel_start_node_id'
class NodeRunResult(BaseModel):

View File

@ -99,6 +99,10 @@ class Graph(BaseModel):
if target_node_id not in reverse_edge_mapping:
reverse_edge_mapping[target_node_id] = []
# is target node id in source node id edge mapping
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
continue
target_edge_ids.add(target_node_id)
# parse run condition

View File

@ -244,6 +244,7 @@ class GraphEngine:
if len(edge_mappings) == 1:
edge = edge_mappings[0]
if edge.run_condition:
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@ -296,14 +297,20 @@ class GraphEngine:
# run parallel nodes, run in new thread and use queue to get results
q: queue.Queue = queue.Queue()
# Create a list to store the threads
threads = []
# new thread
for edge in edge_mappings:
threading.Thread(target=self._run_parallel_node, kwargs={
thread = threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'q': q
}).start()
})
threads.append(thread)
thread.start()
succeeded_count = 0
while True:
@ -315,8 +322,8 @@ class GraphEngine:
yield event
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(edge_mappings):
break
if succeeded_count == len(threads):
q.put(None)
continue
elif isinstance(event, ParallelBranchRunFailedEvent):
@ -324,6 +331,10 @@ class GraphEngine:
except queue.Empty:
continue
# Join all threads
for thread in threads:
thread.join()
# get final node id
final_node_id = parallel.end_to_node_id
if not final_node_id:
@ -331,8 +342,8 @@ class GraphEngine:
next_node_id = final_node_id
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
# break
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
break
def _run_parallel_node(self,
flask_app: Flask,
@ -449,6 +460,14 @@ class GraphEngine:
variable_value=variable_value
)
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,

View File

@ -37,6 +37,10 @@ class AnswerStreamProcessor(StreamProcessor):
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id:
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id

View File

@ -157,14 +157,15 @@ class IterationNode(BaseNode):
event.in_iteration_id = self.node_id
if isinstance(event, NodeRunSucceededEvent):
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
event.route_node_state.node_run_result.metadata = metadata
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
event.route_node_state.node_run_result.metadata = metadata
yield event

View File

@ -98,9 +98,6 @@ class WorkflowEntry:
if callbacks:
for callback in callbacks:
callback.on_event(
graph=self.graph_engine.graph,
graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state,
event=event
)
yield event
@ -111,9 +108,6 @@ class WorkflowEntry:
if callbacks:
for callback in callbacks:
callback.on_event(
graph=self.graph_engine.graph,
graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state,
event=GraphRunFailedEvent(
error=str(e)
)