feat: Introduce WorkflowExecution Domain Entity and Repository, Replace WorkflowRun Direct Usage, and Unify Stream Response Logic (#20067)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-05-21 22:01:53 +08:00 committed by GitHub
parent 7d230acf40
commit d31235ca13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1710 additions and 644 deletions

View File

@ -26,10 +26,13 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
@ -159,8 +162,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
@ -173,6 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
@ -226,8 +244,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
@ -240,6 +268,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
@ -291,8 +320,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
@ -305,6 +344,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
@ -317,6 +357,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
stream: bool = True,
@ -381,6 +422,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
)
@ -453,6 +495,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@ -476,9 +519,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream,
dialogue_count=self._dialogue_count,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
)
try:

View File

@ -64,6 +64,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from events.message_event import message_was_created
@ -94,6 +95,7 @@ class AdvancedChatAppGenerateTaskPipeline:
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
@ -125,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
@ -294,21 +297,19 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
self._workflow_run_id = workflow_execution.id
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
message.workflow_run_id = workflow_execution.id
workflow_start_resp = self._workflow_cycle_manager.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_start_resp
elif isinstance(
@ -319,13 +320,10 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
node_retry_resp = self._workflow_cycle_manager.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -338,20 +336,15 @@ class AdvancedChatAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
workflow_run=workflow_run, event=event
)
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
node_start_resp = self._workflow_cycle_manager.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_resp:
yield node_start_resp
@ -359,15 +352,15 @@ class AdvancedChatAppGenerateTaskPipeline:
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
self._workflow_cycle_manager.fetch_files_from_node_outputs(event.outputs or {})
)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
event=event
)
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -383,11 +376,11 @@ class AdvancedChatAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event
)
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -399,132 +392,90 @@ class AdvancedChatAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
parallel_start_resp = self._workflow_cycle_manager.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
parallel_finish_resp = (
self._workflow_cycle_manager.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
yield parallel_finish_resp
elif isinstance(event, QueueIterationStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_start_resp = self._workflow_cycle_manager.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_start_resp
elif isinstance(event, QueueIterationNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_next_resp = self._workflow_cycle_manager.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_next_resp
elif isinstance(event, QueueIterationCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_finish_resp = self._workflow_cycle_manager.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_finish_resp
elif isinstance(event, QueueLoopStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_start_resp = self._workflow_cycle_manager.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_start_resp
elif isinstance(event, QueueLoopNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_next_resp = self._workflow_cycle_manager.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_next_resp
elif isinstance(event, QueueLoopCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_finish_resp = self._workflow_cycle_manager.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_finish_resp
elif isinstance(event, QueueWorkflowSucceededEvent):
@ -535,10 +486,8 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -546,10 +495,11 @@ class AdvancedChatAppGenerateTaskPipeline:
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
self._base_task_pipeline._queue_manager.publish(
@ -562,10 +512,8 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -573,10 +521,11 @@ class AdvancedChatAppGenerateTaskPipeline:
conversation_id=None,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
self._base_task_pipeline._queue_manager.publish(
@ -589,26 +538,25 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
error_message=event.error,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit()
yield workflow_finish_resp
yield self._base_task_pipeline._error_to_stream_response(err)
@ -616,21 +564,19 @@ class AdvancedChatAppGenerateTaskPipeline:
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
error_message=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
workflow_execution=workflow_execution,
)
# Save message
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
@ -711,7 +657,7 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
yield self._workflow_cycle_manager.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:

View File

@ -18,16 +18,19 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
@ -136,9 +139,22 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
@ -152,6 +168,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
@ -165,6 +182,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
@ -209,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)
@ -262,6 +281,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
@ -278,6 +308,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@ -327,6 +358,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
@ -343,6 +385,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@ -400,6 +443,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -419,8 +463,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
)
try:

View File

@ -0,0 +1,591 @@
import logging
import time
from collections.abc import Generator
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowRun,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline:
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
)
if isinstance(user, EndUser):
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._task_state = WorkflowTaskState()
self._workflow_run_id = ""
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
"""
To blocking response.
:return:
"""
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at),
),
)
return response
else:
continue
raise ValueError("queue listening stopped unexpectedly.")
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:
"""
To stream response.
:return:
"""
workflow_run_id = None
for stream_response in generator:
if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow_features_dict
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception:
logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
graph_runtime_state = None
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
)
self._workflow_run_id = workflow_execution.id
start_resp = self._workflow_cycle_manager.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
yield start_resp
elif isinstance(
event,
QueueNodeRetryEvent,
):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_cycle_manager.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if response:
yield response
elif isinstance(event, QueueNodeStartedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_cycle_manager.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
event=event
)
node_success_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_success_response:
yield node_success_response
elif isinstance(
event,
QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_failed_response:
yield node_failed_response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
parallel_start_resp = self._workflow_cycle_manager.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
parallel_finish_resp = (
self._workflow_cycle_manager.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
yield parallel_finish_resp
elif isinstance(event, QueueIterationStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
iter_start_resp = self._workflow_cycle_manager.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_start_resp
elif isinstance(event, QueueIterationNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
iter_next_resp = self._workflow_cycle_manager.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_next_resp
elif isinstance(event, QueueIterationCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
iter_finish_resp = self._workflow_cycle_manager.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_finish_resp
elif isinstance(event, QueueLoopStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
loop_start_resp = self._workflow_cycle_manager.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_start_resp
elif isinstance(event, QueueLoopNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
loop_next_resp = self._workflow_cycle_manager.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_next_resp
elif isinstance(event, QueueLoopCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
loop_finish_resp = self._workflow_cycle_manager.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_finish_resp
elif isinstance(event, QueueWorkflowSucceededEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error_message=event.error
if isinstance(event, QueueWorkflowFailedEvent)
else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
assert workflow_run is not None
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
elif invoke_from == InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
elif invoke_from == InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
else:
# not save log for debugging
return
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = workflow_run.tenant_id
workflow_app_log.app_id = workflow_run.app_id
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id
session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
) -> TextChunkStreamResponse:
"""
Handle completed event.
:param text: text
:return:
"""
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
)
return response

View File

@ -190,7 +190,7 @@ class WorkflowStartStreamResponse(StreamResponse):
id: str
workflow_id: str
sequence_number: int
inputs: dict
inputs: Mapping[str, Any]
created_at: int
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
@ -212,7 +212,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
workflow_id: str
sequence_number: int
status: str
outputs: Optional[dict] = None
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
elapsed_time: float
total_tokens: int
@ -788,7 +788,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: str
outputs: Optional[dict] = None
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
elapsed_time: float
total_tokens: int

View File

@ -30,6 +30,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
@ -373,7 +374,7 @@ class TraceTask:
self,
trace_type: Any,
message_id: Optional[str] = None,
workflow_run: Optional[WorkflowRun] = None,
workflow_execution: Optional[WorkflowExecution] = None,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
timer: Optional[Any] = None,
@ -381,7 +382,7 @@ class TraceTask:
):
self.trace_type = trace_type
self.message_id = message_id
self.workflow_run_id = workflow_run.id if workflow_run else None
self.workflow_run_id = workflow_execution.id if workflow_execution else None
self.conversation_id = conversation_id
self.user_id = user_id
self.timer = timer

View File

@ -0,0 +1,242 @@
"""
SQLAlchemy implementation of the WorkflowExecutionRepository.
"""
import json
import logging
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution_entities import (
WorkflowExecution,
WorkflowExecutionStatus,
WorkflowType,
)
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowRun,
)
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
"""
SQLAlchemy implementation of the WorkflowExecutionRepository interface.
This implementation supports multi-tenancy by filtering operations based on tenant_id.
Each method creates its own session, handles the transaction, and commits changes
to the database. This prevents long-running connections in the workflow core.
This implementation also includes an in-memory cache for workflow executions to improve
performance by reducing database queries.
"""
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom],
):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
# If an engine is provided, create a sessionmaker from it
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize in-memory cache for workflow executions
# Key: execution_id, Value: WorkflowRun (DB model)
self._execution_cache: dict[str, WorkflowRun] = {}
def _to_domain_model(self, db_model: WorkflowRun) -> WorkflowExecution:
"""
Convert a database model to a domain model.
Args:
db_model: The database model to convert
Returns:
The domain model
"""
# Parse JSON fields
inputs = db_model.inputs_dict
outputs = db_model.outputs_dict
graph = db_model.graph_dict
# Convert status to domain enum
status = WorkflowExecutionStatus(db_model.status)
return WorkflowExecution(
id=db_model.id,
workflow_id=db_model.workflow_id,
sequence_number=db_model.sequence_number,
type=WorkflowType(db_model.type),
workflow_version=db_model.version,
graph=graph,
inputs=inputs,
outputs=outputs,
status=status,
error_message=db_model.error or "",
total_tokens=db_model.total_tokens,
total_steps=db_model.total_steps,
exceptions_count=db_model.exceptions_count,
started_at=db_model.created_at,
finished_at=db_model.finished_at,
)
def _to_db_model(self, domain_model: WorkflowExecution) -> WorkflowRun:
"""
Convert a domain model to a database model.
Args:
domain_model: The domain model to convert
Returns:
The database model
"""
# Use values from constructor if provided
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
db_model = WorkflowRun()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
if self._app_id is not None:
db_model.app_id = self._app_id
db_model.workflow_id = domain_model.workflow_id
db_model.triggered_from = self._triggered_from
db_model.sequence_number = domain_model.sequence_number
db_model.type = domain_model.type
db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
db_model.status = domain_model.status
db_model.error = domain_model.error_message if domain_model.error_message else None
db_model.total_tokens = domain_model.total_tokens
db_model.total_steps = domain_model.total_steps
db_model.exceptions_count = domain_model.exceptions_count
db_model.created_by_role = self._creator_user_role
db_model.created_by = self._creator_user_id
db_model.created_at = domain_model.started_at
db_model.finished_at = domain_model.finished_at
# Calculate elapsed time if finished_at is available
if domain_model.finished_at:
db_model.elapsed_time = (domain_model.finished_at - domain_model.started_at).total_seconds()
else:
db_model.elapsed_time = 0
return db_model
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution domain entity to the database.
This method serves as a domain-to-database adapter that:
1. Converts the domain entity to its database representation
2. Persists the database model using SQLAlchemy's merge operation
3. Maintains proper multi-tenancy by including tenant context during conversion
4. Updates the in-memory cache for faster subsequent lookups
The method handles both creating new records and updating existing ones through
SQLAlchemy's merge operation.
Args:
execution: The WorkflowExecution domain entity to persist
"""
# Convert domain model to database model using tenant context and other attributes
db_model = self._to_db_model(execution)
# Create a new database session
with self._session_factory() as session:
# SQLAlchemy merge intelligently handles both insert and update operations
# based on the presence of the primary key
session.merge(db_model)
session.commit()
# Update the in-memory cache for faster subsequent lookups
logger.debug(f"Updating cache for execution_id: {db_model.id}")
self._execution_cache[db_model.id] = db_model
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
# First check the cache
if execution_id in self._execution_cache:
logger.debug(f"Cache hit for execution_id: {execution_id}")
# Convert cached DB model to domain model
cached_db_model = self._execution_cache[execution_id]
return self._to_domain_model(cached_db_model)
# If not in cache, query the database
logger.debug(f"Cache miss for execution_id: {execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowRun).where(
WorkflowRun.id == execution_id,
WorkflowRun.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowRun.app_id == self._app_id)
db_model = session.scalar(stmt)
if db_model:
# Add DB model to cache
self._execution_cache[execution_id] = db_model
# Convert to domain model and return
return self._to_domain_model(db_model)
return None

View File

@ -0,0 +1,91 @@
"""
Domain entities for workflow execution.
Models are independent of the storage mechanism and don't contain
implementation details like tenant_id, app_id, etc.
"""
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, Field
class WorkflowType(StrEnum):
"""
Workflow Type Enum for domain layer
"""
WORKFLOW = "workflow"
CHAT = "chat"
class WorkflowExecutionStatus(StrEnum):
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"
class WorkflowExecution(BaseModel):
"""
Domain model for workflow execution based on WorkflowRun but without
user, tenant, and app attributes.
"""
id: str = Field(...)
workflow_id: str = Field(...)
workflow_version: str = Field(...)
sequence_number: int = Field(...)
type: WorkflowType = Field(...)
graph: Mapping[str, Any] = Field(...)
inputs: Mapping[str, Any] = Field(...)
outputs: Optional[Mapping[str, Any]] = None
status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
error_message: str = Field(default="")
total_tokens: int = Field(default=0)
total_steps: int = Field(default=0)
exceptions_count: int = Field(default=0)
started_at: datetime = Field(...)
finished_at: Optional[datetime] = None
@property
def elapsed_time(self) -> float:
"""
Calculate elapsed time in seconds.
If workflow is not finished, use current time.
"""
end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None)
return (end_time - self.started_at).total_seconds()
@classmethod
def new(
cls,
*,
id: str,
workflow_id: str,
sequence_number: int,
type: WorkflowType,
workflow_version: str,
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> "WorkflowExecution":
return WorkflowExecution(
id=id,
workflow_id=workflow_id,
sequence_number=sequence_number,
type=type,
workflow_version=workflow_version,
graph=graph,
inputs=inputs,
status=WorkflowExecutionStatus.RUNNING,
started_at=started_at,
)

View File

@ -0,0 +1,42 @@
from typing import Optional, Protocol
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
class WorkflowExecutionRepository(Protocol):
"""
Repository interface for WorkflowExecution.
This interface defines the contract for accessing and manipulating
WorkflowExecution data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and other implementation details should be handled at the implementation level, not in
the core interface. This keeps the core domain model clean and independent of specific
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution instance.
This method handles both creating new records and updating existing ones.
The implementation should determine whether to create or update based on
the execution's ID or other identifying fields.
Args:
execution: The WorkflowExecution instance to save or update
"""
...
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
...

View File

@ -3,6 +3,7 @@ import time
from collections.abc import Generator
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@ -53,7 +54,9 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from extensions.ext_database import db
@ -83,6 +86,7 @@ class WorkflowAppGenerateTaskPipeline:
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
@ -111,6 +115,7 @@ class WorkflowAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
@ -258,17 +263,15 @@ class WorkflowAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
self._workflow_run_id = workflow_execution.id
start_resp = self._workflow_cycle_manager.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield start_resp
elif isinstance(
@ -278,13 +281,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
response = self._workflow_cycle_manager.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -297,27 +298,22 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_cycle_manager.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_success_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -332,10 +328,10 @@ class WorkflowAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_failed_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -348,18 +344,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
parallel_start_resp = self._workflow_cycle_manager.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
@ -367,18 +356,13 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
parallel_finish_resp = (
self._workflow_cycle_manager.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
yield parallel_finish_resp
@ -386,16 +370,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_start_resp = self._workflow_cycle_manager.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_start_resp
@ -403,16 +382,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_next_resp = self._workflow_cycle_manager.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_next_resp
@ -420,16 +394,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_finish_resp = self._workflow_cycle_manager.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_finish_resp
@ -437,16 +406,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_start_resp = self._workflow_cycle_manager.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_start_resp
@ -454,16 +418,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_next_resp = self._workflow_cycle_manager.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_next_resp
@ -471,16 +430,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_finish_resp = self._workflow_cycle_manager.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_finish_resp
@ -491,10 +445,8 @@ class WorkflowAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -503,12 +455,12 @@ class WorkflowAppGenerateTaskPipeline:
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
workflow_execution=workflow_execution,
)
session.commit()
@ -520,10 +472,8 @@ class WorkflowAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -533,10 +483,12 @@ class WorkflowAppGenerateTaskPipeline:
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
@ -548,26 +500,28 @@ class WorkflowAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
error_message=event.error
if isinstance(event, QueueWorkflowFailedEvent)
else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
@ -586,7 +540,7 @@ class WorkflowAppGenerateTaskPipeline:
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
yield self._workflow_cycle_manager.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
@ -595,11 +549,9 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
:return:
"""
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
assert workflow_run is not None
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API

View File

@ -1,4 +1,3 @@
import json
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
@ -8,7 +7,7 @@ from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
@ -54,9 +53,11 @@ from core.workflow.entities.node_execution_entities import (
NodeExecution,
NodeExecutionStatus,
)
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from models import (
@ -67,7 +68,6 @@ from models import (
WorkflowNodeExecutionStatus,
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
@ -77,21 +77,20 @@ class WorkflowCycleManager:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._workflow_run: WorkflowRun | None = None
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def _handle_workflow_run_start(
def handle_workflow_run_start(
self,
*,
session: Session,
workflow_id: str,
user_id: str,
created_by_role: CreatorUserRole,
) -> WorkflowRun:
) -> WorkflowExecution:
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.scalar(workflow_stmt)
if not workflow:
@ -110,157 +109,116 @@ class WorkflowCycleManager:
continue
inputs[f"sys.{key.value}"] = value
triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN
)
# handle special values
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
execution = WorkflowExecution.new(
id=execution_id,
workflow_id=workflow.id,
sequence_number=new_sequence_number,
type=WorkflowType(workflow.type),
workflow_version=workflow.version,
graph=workflow.graph_dict,
inputs=inputs,
started_at=datetime.now(UTC).replace(tzinfo=None),
)
workflow_run = WorkflowRun()
workflow_run.id = workflow_run_id
workflow_run.tenant_id = workflow.tenant_id
workflow_run.app_id = workflow.app_id
workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = workflow.id
workflow_run.type = workflow.type
workflow_run.triggered_from = triggered_from.value
workflow_run.version = workflow.version
workflow_run.graph = workflow.graph
workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.created_by_role = created_by_role
workflow_run.created_by = user_id
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
self._workflow_execution_repository.save(execution)
session.add(workflow_run)
return execution
return workflow_run
def _handle_workflow_run_success(
def handle_workflow_run_success(
self,
*,
session: Session,
workflow_run_id: str,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run_id: workflow run id
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.SUCCEEDED
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
workflow_execution.outputs = outputs or {}
workflow_execution.total_tokens = total_tokens
workflow_execution.total_steps = total_steps
workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
return workflow_run
return workflow_execution
def _handle_workflow_run_partial_success(
def handle_workflow_run_partial_success(
self,
*,
session: Session,
workflow_run_id: str,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCEEDED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
execution.exceptions_count = exceptions_count
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
return workflow_run
return execution
def _handle_workflow_run_failed(
def handle_workflow_run_failed(
self,
*,
session: Session,
workflow_run_id: str,
start_at: float,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str,
error_message: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
:param workflow_run_id: workflow run id
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param status: status
:param error: error message
:return:
"""
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
workflow_run.status = status.value
workflow_run.error = error
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
execution.status = WorkflowExecutionStatus(status.value)
execution.error_message = error_message
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
execution.exceptions_count = exceptions_count
# Use the instance repository to find running executions for a workflow run
running_domain_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_run.id
workflow_run_id=execution.id
)
# Update the domain models
@ -269,7 +227,7 @@ class WorkflowCycleManager:
if domain_execution.node_execution_id:
# Update the domain model
domain_execution.status = NodeExecutionStatus.FAILED
domain_execution.error = error
domain_execution.error = error_message
domain_execution.finished_at = now
domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds()
@ -280,15 +238,22 @@ class WorkflowCycleManager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
return workflow_run
return execution
def handle_node_execution_start(
self,
*,
workflow_execution_id: str,
event: QueueNodeStartedEvent,
) -> NodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution:
# Create a domain model
created_at = datetime.now(UTC).replace(tzinfo=None)
metadata = {
@ -299,8 +264,8 @@ class WorkflowCycleManager:
domain_execution = NodeExecution(
id=str(uuid4()),
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.id,
workflow_id=workflow_execution.workflow_id,
workflow_run_id=workflow_execution.id,
predecessor_node_id=event.predecessor_node_id,
index=event.node_run_index,
node_execution_id=event.node_execution_id,
@ -317,7 +282,7 @@ class WorkflowCycleManager:
return domain_execution
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
@ -350,7 +315,7 @@ class WorkflowCycleManager:
return domain_execution
def _handle_workflow_node_execution_failed(
def handle_workflow_node_execution_failed(
self,
*,
event: QueueNodeFailedEvent
@ -400,15 +365,10 @@ class WorkflowCycleManager:
return domain_execution
def _handle_workflow_node_execution_retried(
self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> NodeExecution:
"""
Workflow node execution failed
:param workflow_run: workflow run
:param event: queue node failed event
:return:
"""
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
@ -433,8 +393,8 @@ class WorkflowCycleManager:
# Create a domain model
domain_execution = NodeExecution(
id=str(uuid4()),
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.id,
workflow_id=workflow_execution.workflow_id,
workflow_run_id=workflow_execution.id,
predecessor_node_id=event.predecessor_node_id,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
@ -456,34 +416,34 @@ class WorkflowCycleManager:
return domain_execution
def _workflow_start_to_stream_response(
def workflow_start_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
workflow_execution: WorkflowExecution,
) -> WorkflowStartStreamResponse:
_ = session
return WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution.id,
data=WorkflowStartStreamResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=dict(workflow_run.inputs_dict or {}),
created_at=int(workflow_run.created_at.timestamp()),
id=workflow_execution.id,
workflow_id=workflow_execution.workflow_id,
sequence_number=workflow_execution.sequence_number,
inputs=workflow_execution.inputs,
created_at=int(workflow_execution.started_at.timestamp()),
),
)
def _workflow_finish_to_stream_response(
def workflow_finish_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
workflow_execution: WorkflowExecution,
) -> WorkflowFinishStreamResponse:
created_by = None
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
assert workflow_run is not None
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
stmt = select(Account).where(Account.id == workflow_run.created_by)
account = session.scalar(stmt)
@ -504,28 +464,35 @@ class WorkflowCycleManager:
else:
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (
int(workflow_execution.finished_at.timestamp())
if workflow_execution.finished_at
else int(datetime.now(UTC).timestamp())
)
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution.id,
data=WorkflowFinishStreamResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
status=workflow_run.status,
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
id=workflow_execution.id,
workflow_id=workflow_execution.workflow_id,
sequence_number=workflow_execution.sequence_number,
status=workflow_execution.status,
outputs=workflow_execution.outputs,
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
total_steps=workflow_execution.total_steps,
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
exceptions_count=workflow_run.exceptions_count,
created_at=int(workflow_execution.started_at.timestamp()),
finished_at=finished_at_timestamp,
files=self.fetch_files_from_node_outputs(workflow_execution.outputs),
exceptions_count=workflow_execution.exceptions_count,
),
)
def _workflow_node_start_to_stream_response(
def workflow_node_start_to_stream_response(
self,
*,
event: QueueNodeStartedEvent,
@ -571,7 +538,7 @@ class WorkflowCycleManager:
return response
def _workflow_node_finish_to_stream_response(
def workflow_node_finish_to_stream_response(
self,
*,
event: QueueNodeSucceededEvent
@ -608,7 +575,7 @@ class WorkflowCycleManager:
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
@ -618,7 +585,7 @@ class WorkflowCycleManager:
),
)
def _workflow_node_retry_to_stream_response(
def workflow_node_retry_to_stream_response(
self,
*,
event: QueueNodeRetryEvent,
@ -651,7 +618,7 @@ class WorkflowCycleManager:
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
@ -662,13 +629,16 @@ class WorkflowCycleManager:
),
)
def _workflow_parallel_branch_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
def workflow_parallel_branch_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunStartedEvent,
) -> ParallelBranchStartStreamResponse:
_ = session
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
@ -680,18 +650,16 @@ class WorkflowCycleManager:
),
)
def _workflow_parallel_branch_finished_to_stream_response(
def workflow_parallel_branch_finished_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
workflow_execution_id: str,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
_ = session
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
@ -705,13 +673,16 @@ class WorkflowCycleManager:
),
)
def _workflow_iteration_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
def workflow_iteration_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueIterationStartEvent,
) -> IterationNodeStartStreamResponse:
_ = session
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
@ -726,13 +697,16 @@ class WorkflowCycleManager:
),
)
def _workflow_iteration_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
def workflow_iteration_next_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueIterationNextEvent,
) -> IterationNodeNextStreamResponse:
_ = session
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
@ -749,13 +723,16 @@ class WorkflowCycleManager:
),
)
def _workflow_iteration_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
def workflow_iteration_completed_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
_ = session
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
@ -779,13 +756,12 @@ class WorkflowCycleManager:
),
)
def _workflow_loop_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
def workflow_loop_start_to_stream_response(
self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse:
_ = session
return LoopNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution_id,
data=LoopNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
@ -800,13 +776,16 @@ class WorkflowCycleManager:
),
)
def _workflow_loop_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
def workflow_loop_next_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueLoopNextEvent,
) -> LoopNodeNextStreamResponse:
_ = session
return LoopNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution_id,
data=LoopNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
@ -823,13 +802,16 @@ class WorkflowCycleManager:
),
)
def _workflow_loop_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
def workflow_loop_completed_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueLoopCompletedEvent,
) -> LoopNodeCompletedStreamResponse:
_ = session
return LoopNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution_id,
data=LoopNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
@ -853,7 +835,7 @@ class WorkflowCycleManager:
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@ -910,20 +892,13 @@ class WorkflowCycleManager:
return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
if self._workflow_run and self._workflow_run.id == workflow_run_id:
cached_workflow_run = self._workflow_run
cached_workflow_run = session.merge(cached_workflow_run)
return cached_workflow_run
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
execution = self._workflow_execution_repository.get(id)
if not execution:
raise WorkflowRunNotFoundError(id)
return execution
return workflow_run
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
def handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log
:param task_id: task id

View File

@ -425,14 +425,14 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0"))
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at = db.Column(db.DateTime)
exceptions_count = db.Column(db.Integer, server_default=db.text("0"))
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
@property
def created_by_account(self):
@ -447,7 +447,7 @@ class WorkflowRun(Base):
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
@property
def graph_dict(self):
def graph_dict(self) -> Mapping[str, Any]:
return json.loads(self.graph) if self.graph else {}
@property
@ -752,12 +752,12 @@ class WorkflowAppLog(Base):
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id = db.Column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from = db.Column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_from: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def workflow_run(self):
@ -782,9 +782,11 @@ class ConversationVariable(Base):
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data = mapped_column(db.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True)
updated_at = mapped_column(
data: Mapped[str] = mapped_column(db.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@ -832,14 +834,14 @@ class WorkflowDraftVariable(Base):
# id is the unique identifier of a draft variable.
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
created_at = mapped_column(
created_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
default=_naive_utc_datetime,
server_default=func.current_timestamp(),
)
updated_at = mapped_column(
updated_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
default=_naive_utc_datetime,

View File

@ -1,45 +1,73 @@
import json
import time
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import Session
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import (
Workflow,
WorkflowNodeExecutionStatus,
WorkflowRun,
WorkflowRunStatus,
)
@pytest.fixture
def mock_app_generate_entity():
entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
entity.inputs = {"query": "test query"}
entity.invoke_from = InvokeFrom.WEB_APP
# Create app_config as a separate mock
app_config = MagicMock()
app_config.tenant_id = "test-tenant-id"
app_config.app_id = "test-app-id"
entity.app_config = app_config
def real_app_generate_entity():
additional_features = AppAdditionalFeatures(
file_upload=None,
opening_statement=None,
suggested_questions=[],
suggested_questions_after_answer=False,
show_retrieve_source=False,
more_like_this=False,
speech_to_text=False,
text_to_speech=None,
trace_config=None,
)
app_config = WorkflowUIBasedAppConfig(
tenant_id="test-tenant-id",
app_id="test-app-id",
app_mode=AppMode.WORKFLOW,
additional_features=additional_features,
workflow_id="test-workflow-id",
)
entity = AdvancedChatAppGenerateEntity(
task_id="test-task-id",
app_config=app_config,
inputs={"query": "test query"},
files=[],
user_id="test-user-id",
stream=False,
invoke_from=InvokeFrom.WEB_APP,
query="test query",
conversation_id="test-conversation-id",
)
return entity
@pytest.fixture
def mock_workflow_system_variables():
def real_workflow_system_variables():
return {
SystemVariableKey.QUERY: "test query",
SystemVariableKey.CONVERSATION_ID: "test-conversation-id",
@ -59,10 +87,23 @@ def mock_node_execution_repository():
@pytest.fixture
def workflow_cycle_manager(mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository):
def mock_workflow_execution_repository():
repo = MagicMock(spec=WorkflowExecutionRepository)
repo.get.return_value = None
return repo
@pytest.fixture
def workflow_cycle_manager(
real_app_generate_entity,
real_workflow_system_variables,
mock_workflow_execution_repository,
mock_node_execution_repository,
):
return WorkflowCycleManager(
application_generate_entity=mock_app_generate_entity,
workflow_system_variables=mock_workflow_system_variables,
application_generate_entity=real_app_generate_entity,
workflow_system_variables=real_workflow_system_variables,
workflow_execution_repository=mock_workflow_execution_repository,
workflow_node_execution_repository=mock_node_execution_repository,
)
@ -74,121 +115,173 @@ def mock_session():
@pytest.fixture
def mock_workflow():
workflow = MagicMock(spec=Workflow)
def real_workflow():
workflow = Workflow()
workflow.id = "test-workflow-id"
workflow.tenant_id = "test-tenant-id"
workflow.app_id = "test-app-id"
workflow.type = "chat"
workflow.version = "1.0"
workflow.graph = json.dumps({"nodes": [], "edges": []})
graph_data = {"nodes": [], "edges": []}
workflow.graph = json.dumps(graph_data)
workflow.features = json.dumps({"file_upload": {"enabled": False}})
workflow.created_by = "test-user-id"
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow._environment_variables = "{}"
workflow._conversation_variables = "{}"
return workflow
@pytest.fixture
def mock_workflow_run():
workflow_run = MagicMock(spec=WorkflowRun)
def real_workflow_run():
workflow_run = WorkflowRun()
workflow_run.id = "test-workflow-run-id"
workflow_run.tenant_id = "test-tenant-id"
workflow_run.app_id = "test-app-id"
workflow_run.workflow_id = "test-workflow-id"
workflow_run.sequence_number = 1
workflow_run.type = "chat"
workflow_run.triggered_from = "app-run"
workflow_run.version = "1.0"
workflow_run.graph = json.dumps({"nodes": [], "edges": []})
workflow_run.inputs = json.dumps({"query": "test query"})
workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.outputs = json.dumps({"answer": "test answer"})
workflow_run.created_by_role = CreatorUserRole.ACCOUNT
workflow_run.created_by = "test-user-id"
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.inputs_dict = {"query": "test query"}
workflow_run.outputs_dict = {"answer": "test answer"}
return workflow_run
def test_init(
workflow_cycle_manager, mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository
workflow_cycle_manager,
real_app_generate_entity,
real_workflow_system_variables,
mock_workflow_execution_repository,
mock_node_execution_repository,
):
"""Test initialization of WorkflowCycleManager"""
assert workflow_cycle_manager._workflow_run is None
assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity
assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables
assert workflow_cycle_manager._application_generate_entity == real_app_generate_entity
assert workflow_cycle_manager._workflow_system_variables == real_workflow_system_variables
assert workflow_cycle_manager._workflow_execution_repository == mock_workflow_execution_repository
assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_workflow):
"""Test _handle_workflow_run_start method"""
def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, real_workflow):
"""Test handle_workflow_run_start method"""
# Mock session.scalar to return the workflow and max sequence
mock_session.scalar.side_effect = [mock_workflow, 5]
mock_session.scalar.side_effect = [real_workflow, 5]
# Call the method
workflow_run = workflow_cycle_manager._handle_workflow_run_start(
workflow_execution = workflow_cycle_manager.handle_workflow_run_start(
session=mock_session,
workflow_id="test-workflow-id",
user_id="test-user-id",
created_by_role=CreatorUserRole.ACCOUNT,
)
# Verify the result
assert workflow_run.tenant_id == mock_workflow.tenant_id
assert workflow_run.app_id == mock_workflow.app_id
assert workflow_run.workflow_id == mock_workflow.id
assert workflow_run.sequence_number == 6 # max_sequence + 1
assert workflow_run.status == WorkflowRunStatus.RUNNING
assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT
assert workflow_run.created_by == "test-user-id"
assert workflow_execution.workflow_id == real_workflow.id
assert workflow_execution.sequence_number == 6 # max_sequence + 1
# Verify session.add was called
mock_session.add.assert_called_once_with(workflow_run)
# Verify the workflow_execution_repository.save was called
workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(workflow_execution)
def test_handle_workflow_run_success(workflow_cycle_manager, mock_session, mock_workflow_run):
"""Test _handle_workflow_run_success method"""
# Mock _get_workflow_run to return the mock_workflow_run
with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
# Call the method
result = workflow_cycle_manager._handle_workflow_run_success(
session=mock_session,
workflow_run_id="test-workflow-run-id",
start_at=time.perf_counter() - 10, # 10 seconds ago
total_tokens=100,
total_steps=5,
outputs={"answer": "test answer"},
)
def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test handle_workflow_run_success method"""
# Create a real WorkflowExecution
# Verify the result
assert result == mock_workflow_run
assert result.status == WorkflowRunStatus.SUCCEEDED
assert result.outputs == json.dumps({"answer": "test answer"})
assert result.total_tokens == 100
assert result.total_steps == 5
assert result.finished_at is not None
workflow_execution = WorkflowExecution(
id="test-workflow-run-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id="test-workflow-run-id",
total_tokens=100,
total_steps=5,
outputs={"answer": "test answer"},
)
# Verify the result
assert result == workflow_execution
assert result.status == WorkflowExecutionStatus.SUCCEEDED
assert result.outputs == {"answer": "test answer"}
assert result.total_tokens == 100
assert result.total_steps == 5
assert result.finished_at is not None
def test_handle_workflow_run_failed(workflow_cycle_manager, mock_session, mock_workflow_run):
"""Test _handle_workflow_run_failed method"""
# Mock _get_workflow_run to return the mock_workflow_run
with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
# Mock get_running_executions to return an empty list
workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test handle_workflow_run_failed method"""
# Create a real WorkflowExecution
# Call the method
result = workflow_cycle_manager._handle_workflow_run_failed(
session=mock_session,
workflow_run_id="test-workflow-run-id",
start_at=time.perf_counter() - 10, # 10 seconds ago
total_tokens=50,
total_steps=3,
status=WorkflowRunStatus.FAILED,
error="Test error message",
)
workflow_execution = WorkflowExecution(
id="test-workflow-run-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Verify the result
assert result == mock_workflow_run
assert result.status == WorkflowRunStatus.FAILED.value
assert result.error == "Test error message"
assert result.total_tokens == 50
assert result.total_steps == 3
assert result.finished_at is not None
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Mock get_running_executions to return an empty list
workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
# Call the method
result = workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id="test-workflow-run-id",
total_tokens=50,
total_steps=3,
status=WorkflowRunStatus.FAILED,
error_message="Test error message",
)
# Verify the result
assert result == workflow_execution
assert result.status == WorkflowExecutionStatus(WorkflowRunStatus.FAILED.value)
assert result.error_message == "Test error message"
assert result.total_tokens == 50
assert result.total_steps == 3
assert result.finished_at is not None
def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run):
"""Test _handle_node_execution_start method"""
def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test handle_node_execution_start method"""
# Create a real WorkflowExecution
workflow_execution = WorkflowExecution(
id="test-workflow-execution-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Create a mock event
event = MagicMock(spec=QueueNodeStartedEvent)
event.node_execution_id = "test-node-execution-id"
@ -207,129 +300,171 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run):
event.in_loop_id = "test-loop-id"
# Call the method
result = workflow_cycle_manager._handle_node_execution_start(
workflow_run=mock_workflow_run,
result = workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=workflow_execution.id,
event=event,
)
# Verify the result
# NodeExecution doesn't have tenant_id attribute, it's handled at repository level
# assert result.tenant_id == mock_workflow_run.tenant_id
# assert result.app_id == mock_workflow_run.app_id
assert result.workflow_id == mock_workflow_run.workflow_id
assert result.workflow_run_id == mock_workflow_run.id
assert result.workflow_id == workflow_execution.workflow_id
assert result.workflow_run_id == workflow_execution.id
assert result.node_execution_id == event.node_execution_id
assert result.node_id == event.node_id
assert result.node_type == event.node_type
assert result.title == event.node_data.title
assert result.status == WorkflowNodeExecutionStatus.RUNNING.value
# NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level
# assert result.created_by_role == mock_workflow_run.created_by_role
# assert result.created_by == mock_workflow_run.created_by
assert result.status == NodeExecutionStatus.RUNNING
# Verify save was called
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run):
"""Test _get_workflow_run method"""
# Mock session.scalar to return the workflow run
mock_session.scalar.return_value = mock_workflow_run
def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test _get_workflow_execution_or_raise_error method"""
# Create a real WorkflowExecution
# Call the method
result = workflow_cycle_manager._get_workflow_run(
session=mock_session,
workflow_run_id="test-workflow-run-id",
workflow_execution = WorkflowExecution(
id="test-workflow-run-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository get method to return the real execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Call the method
result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
# Verify the result
assert result == mock_workflow_run
assert workflow_cycle_manager._workflow_run == mock_workflow_run
assert result == workflow_execution
# Test error case
workflow_cycle_manager._workflow_execution_repository.get.return_value = None
# Expect an error when execution is not found
with pytest.raises(ValueError):
workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
def test_handle_workflow_node_execution_success(workflow_cycle_manager):
"""Test _handle_workflow_node_execution_success method"""
"""Test handle_workflow_node_execution_success method"""
# Create a mock event
event = MagicMock(spec=QueueNodeSucceededEvent)
event.node_execution_id = "test-node-execution-id"
event.inputs = {"input": "test input"}
event.process_data = {"process": "test process"}
event.outputs = {"output": "test output"}
event.execution_metadata = {"metadata": "test metadata"}
event.execution_metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 100}
event.start_at = datetime.now(UTC).replace(tzinfo=None)
# Create a mock node execution
node_execution = MagicMock()
node_execution.node_execution_id = "test-node-execution-id"
# Create a real node execution
node_execution = NodeExecution(
id="test-node-execution-record-id",
node_execution_id="test-node-execution-id",
workflow_id="test-workflow-id",
workflow_run_id="test-workflow-run-id",
index=1,
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository to return the node execution
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
# Call the method
result = workflow_cycle_manager._handle_workflow_node_execution_success(
result = workflow_cycle_manager.handle_workflow_node_execution_success(
event=event,
)
# Verify the result
assert result == node_execution
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value
assert result.status == NodeExecutionStatus.SUCCEEDED
# Verify save was called
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run):
"""Test _handle_workflow_run_partial_success method"""
# Mock _get_workflow_run to return the mock_workflow_run
with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
# Call the method
result = workflow_cycle_manager._handle_workflow_run_partial_success(
session=mock_session,
workflow_run_id="test-workflow-run-id",
start_at=time.perf_counter() - 10, # 10 seconds ago
total_tokens=75,
total_steps=4,
outputs={"partial_answer": "test partial answer"},
exceptions_count=2,
)
def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test handle_workflow_run_partial_success method"""
# Create a real WorkflowExecution
# Verify the result
assert result == mock_workflow_run
assert result.status == WorkflowRunStatus.PARTIAL_SUCCEEDED.value
assert result.outputs == json.dumps({"partial_answer": "test partial answer"})
assert result.total_tokens == 75
assert result.total_steps == 4
assert result.exceptions_count == 2
assert result.finished_at is not None
workflow_execution = WorkflowExecution(
id="test-workflow-run-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id="test-workflow-run-id",
total_tokens=75,
total_steps=4,
outputs={"partial_answer": "test partial answer"},
exceptions_count=2,
)
# Verify the result
assert result == workflow_execution
assert result.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED
assert result.outputs == {"partial_answer": "test partial answer"}
assert result.total_tokens == 75
assert result.total_steps == 4
assert result.exceptions_count == 2
assert result.finished_at is not None
def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
"""Test _handle_workflow_node_execution_failed method"""
"""Test handle_workflow_node_execution_failed method"""
# Create a mock event
event = MagicMock(spec=QueueNodeFailedEvent)
event.node_execution_id = "test-node-execution-id"
event.inputs = {"input": "test input"}
event.process_data = {"process": "test process"}
event.outputs = {"output": "test output"}
event.execution_metadata = {"metadata": "test metadata"}
event.execution_metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 100}
event.start_at = datetime.now(UTC).replace(tzinfo=None)
event.error = "Test error message"
# Create a mock node execution
node_execution = MagicMock()
node_execution.node_execution_id = "test-node-execution-id"
# Create a real node execution
node_execution = NodeExecution(
id="test-node-execution-record-id",
node_execution_id="test-node-execution-id",
workflow_id="test-workflow-id",
workflow_run_id="test-workflow-run-id",
index=1,
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository to return the node execution
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
# Call the method
result = workflow_cycle_manager._handle_workflow_node_execution_failed(
result = workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
# Verify the result
assert result == node_execution
assert result.status == WorkflowNodeExecutionStatus.FAILED.value
assert result.status == NodeExecutionStatus.FAILED
assert result.error == "Test error message"
# Verify save was called