refactor: streamline initialization of application_generate_entity and task_state in task pipeline classes (#12326)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-01-03 18:41:44 +08:00 committed by GitHub
parent 478150e850
commit 7ed6485f86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 280 additions and 222 deletions

View File

@ -67,24 +67,17 @@ from models.account import Account
from models.enums import CreatedByRole from models.enums import CreatedByRole
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecution,
WorkflowRunStatus, WorkflowRunStatus,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage): class AdvancedChatAppGenerateTaskPipeline:
""" """
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None
def __init__( def __init__(
self, self,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
@ -96,7 +89,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream: bool, stream: bool,
dialogue_count: int, dialogue_count: int,
) -> None: ) -> None:
super().__init__( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
stream=stream, stream=stream,
@ -113,32 +106,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else: else:
raise NotImplementedError(f"User type not supported: {type(user)}") raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_id = workflow.id self._workflow_cycle_manager = WorkflowCycleManage(
self._workflow_features_dict = workflow.features_dict application_generate_entity=application_generate_entity,
workflow_system_variables={
self._conversation_id = conversation.id SystemVariableKey.QUERY: message.query,
self._conversation_mode = conversation.mode SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
self._message_id = message.id SystemVariableKey.USER_ID: user_session_id,
self._message_created_at = int(message.created_at.timestamp()) SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
self._workflow_system_variables = { SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.QUERY: message.query, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
SystemVariableKey.FILES: application_generate_entity.files, },
SystemVariableKey.CONVERSATION_ID: conversation.id, )
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {} self._message_cycle_manager = MessageCycleManage(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
self._conversation_name_generate_thread = None self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = [] self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id = "" self._workflow_run_id: str = ""
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
""" """
@ -146,13 +142,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return: :return:
""" """
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name( self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query conversation_id=self._conversation_id, query=self._application_generate_entity.query
) )
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream: if self._base_task_pipeline._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
@ -269,24 +265,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# init fake graph runtime state # init fake graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None graph_runtime_state: Optional[GraphRuntimeState] = None
for queue_message in self._queue_manager.listen(): for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event event = queue_message.event
if isinstance(event, QueuePingEvent): if isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent): elif isinstance(event, QueueErrorEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id) err = self._base_task_pipeline._handle_error(
event=event, session=session, message_id=self._message_id
)
session.commit() session.commit()
yield self._error_to_stream_response(err) yield self._base_task_pipeline._error_to_stream_response(err)
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state # override graph runtime state
graph_runtime_state = event.graph_runtime_state graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
# init workflow run # init workflow run
workflow_run = self._handle_workflow_run_start( workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session, session=session,
workflow_id=self._workflow_id, workflow_id=self._workflow_id,
user_id=self._user_id, user_id=self._user_id,
@ -297,7 +295,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not message: if not message:
raise ValueError(f"Message not found: {self._message_id}") raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_start_to_stream_response( workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
@ -310,12 +308,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_workflow_node_execution_retried( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
node_retry_resp = self._workflow_node_retry_to_stream_response( node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -329,13 +329,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_node_execution_start( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
node_start_resp = self._workflow_node_start_to_stream_response( node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -348,12 +350,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
# Record files if it's an answer node or end node # Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]: if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) self._recorded_files.extend(
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
)
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_finish_resp = self._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -364,10 +370,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if node_finish_resp: if node_finish_resp:
yield node_finish_resp yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event
)
node_finish_resp = self._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -381,13 +389,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_start_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_start_resp yield parallel_start_resp
@ -395,13 +407,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_finish_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_finish_resp yield parallel_finish_resp
@ -409,9 +425,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_start_resp = self._workflow_iteration_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -423,9 +441,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_next_resp = self._workflow_iteration_next_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -437,9 +457,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_finish_resp = self._workflow_iteration_completed_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -454,8 +476,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -466,21 +488,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager, trace_manager=trace_manager,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
yield workflow_finish_resp yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent): elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_partial_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -491,21 +515,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=None, conversation_id=None,
trace_manager=trace_manager, trace_manager=trace_manager,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
yield workflow_finish_resp yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent): elif isinstance(event, QueueWorkflowFailedEvent):
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_failed( workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -517,20 +543,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager, trace_manager=trace_manager,
exceptions_count=event.exceptions_count, exceptions_count=event.exceptions_count,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err = self._handle_error(event=err_event, session=session, message_id=self._message_id) err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit() session.commit()
yield workflow_finish_resp yield workflow_finish_resp
yield self._error_to_stream_response(err) yield self._base_task_pipeline._error_to_stream_response(err)
break break
elif isinstance(event, QueueStopEvent): elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state: if self._workflow_run_id and graph_runtime_state:
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_failed( workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -541,7 +569,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation_id, conversation_id=self._conversation_id,
trace_manager=trace_manager, trace_manager=trace_manager,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -555,18 +583,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
break break
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event) self._message_cycle_manager._handle_retriever_resources(event)
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session) message = self._get_message(session=session)
message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
session.commit() session.commit()
elif isinstance(event, QueueAnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event) self._message_cycle_manager._handle_annotation_reply(event)
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session) message = self._get_message(session=session)
message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@ -587,23 +615,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(queue_message) tts_publisher.publish(queue_message)
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._message_to_stream_response( yield self._message_cycle_manager._message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
) )
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation # published by moderation
yield self._message_replace_to_stream_response(answer=event.text) yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent): elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer: if output_moderation_answer:
self._task_state.answer = output_moderation_answer self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer) yield self._message_cycle_manager._message_replace_to_stream_response(
answer=output_moderation_answer
)
# Save message # Save message
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state) self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit() session.commit()
@ -621,7 +653,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session) message = self._get_message(session=session)
message.answer = self._task_state.answer message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._start_at message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
@ -685,20 +717,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:param text: text :param text: text
:return: True if output moderation should direct output, otherwise False :return: True if output moderation should direct output, otherwise False
""" """
if self._output_moderation_handler: if self._base_task_pipeline._output_moderation_handler:
if self._output_moderation_handler.should_direct_output(): if self._base_task_pipeline._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output # stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output() self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._queue_manager.publish( self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
) )
self._queue_manager.publish( self._base_task_pipeline._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
) )
return True return True
else: else:
self._output_moderation_handler.append_new_token(text) self._base_task_pipeline._output_moderation_handler.append_new_token(text)
return False return False

View File

@ -1,7 +1,7 @@
import logging import logging
import time import time
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union from typing import Optional, Union
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -58,7 +58,6 @@ from models.workflow import (
Workflow, Workflow,
WorkflowAppLog, WorkflowAppLog,
WorkflowAppLogCreatedFrom, WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun, WorkflowRun,
WorkflowRunStatus, WorkflowRunStatus,
) )
@ -66,16 +65,11 @@ from models.workflow import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage): class WorkflowAppGenerateTaskPipeline:
""" """
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def __init__( def __init__(
self, self,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
@ -84,7 +78,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
super().__init__( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
stream=stream, stream=stream,
@ -101,19 +95,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else: else:
raise ValueError(f"Invalid user type: {type(user)}") raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict self._workflow_features_dict = workflow.features_dict
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._workflow_run_id = "" self._workflow_run_id = ""
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -122,7 +118,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return: :return:
""" """
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream: if self._base_task_pipeline._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
@ -237,29 +233,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
""" """
graph_runtime_state = None graph_runtime_state = None
for queue_message in self._queue_manager.listen(): for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event event = queue_message.event
if isinstance(event, QueuePingEvent): if isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent): elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event=event) err = self._base_task_pipeline._handle_error(event=event)
yield self._error_to_stream_response(err) yield self._base_task_pipeline._error_to_stream_response(err)
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state # override graph runtime state
graph_runtime_state = event.graph_runtime_state graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
# init workflow run # init workflow run
workflow_run = self._handle_workflow_run_start( workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session, session=session,
workflow_id=self._workflow_id, workflow_id=self._workflow_id,
user_id=self._user_id, user_id=self._user_id,
created_by_role=self._created_by_role, created_by_role=self._created_by_role,
) )
self._workflow_run_id = workflow_run.id self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response( start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
@ -271,12 +267,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
): ):
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_workflow_node_execution_retried( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
response = self._workflow_node_retry_to_stream_response( response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -290,12 +288,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_node_execution_start( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
node_start_response = self._workflow_node_start_to_stream_response( node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -306,9 +306,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_start_response: if node_start_response:
yield node_start_response yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
node_success_response = self._workflow_node_finish_to_stream_response( session=session, event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -319,12 +321,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_success_response: if node_success_response:
yield node_success_response yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, session=session,
event=event, event=event,
) )
node_failed_response = self._workflow_node_finish_to_stream_response( node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -339,13 +341,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_start_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_start_resp yield parallel_start_resp
@ -354,13 +360,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_finish_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_finish_resp yield parallel_finish_resp
@ -369,9 +379,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_start_resp = self._workflow_iteration_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -384,9 +396,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_next_resp = self._workflow_iteration_next_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -399,9 +413,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_finish_resp = self._workflow_iteration_completed_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -416,8 +432,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -431,7 +447,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -445,8 +461,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_partial_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -461,7 +477,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
@ -473,8 +489,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_failed( workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -492,7 +508,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()

View File

@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
ErrorStreamResponse, ErrorStreamResponse,
PingStreamResponse, PingStreamResponse,
TaskState,
) )
from core.errors.error import QuotaExceededError from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application. BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__( def __init__(
self, self,
application_generate_entity: AppGenerateEntity, application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
stream: bool, stream: bool,
) -> None: ) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._start_at = time.perf_counter() self._start_at = time.perf_counter()

View File

@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage: class MessageCycleManage:
_application_generate_entity: Union[ def __init__(
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity self,
] *,
_task_state: Union[EasyUITaskState, WorkflowTaskState] application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
""" """

View File

@ -34,7 +34,6 @@ from core.app.entities.task_entities import (
ParallelBranchStartStreamResponse, ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse, WorkflowFinishStreamResponse,
WorkflowStartStreamResponse, WorkflowStartStreamResponse,
WorkflowTaskState,
) )
from core.file import FILE_MODEL_IDENTITY, File from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -58,13 +57,20 @@ from models.workflow import (
WorkflowRunStatus, WorkflowRunStatus,
) )
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError from .exc import WorkflowRunNotFoundError
class WorkflowCycleManage: class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] def __init__(
_task_state: WorkflowTaskState self,
_workflow_system_variables: dict[SystemVariableKey, Any] *,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
def _handle_workflow_run_start( def _handle_workflow_run_start(
self, self,
@ -240,7 +246,7 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution).where( stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id, WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
@ -248,16 +254,18 @@ class WorkflowCycleManage:
WorkflowNodeExecution.workflow_run_id == workflow_run.id, WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
) )
ids = session.scalars(stmt).all()
running_workflow_node_executions = session.scalars(stmt).all() # Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.finished_at = now
workflow_node_execution.elapsed_time = ( workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds()
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
@ -299,6 +307,8 @@ class WorkflowCycleManage:
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_success( def _handle_workflow_node_execution_success(
@ -326,6 +336,7 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_failed( def _handle_workflow_node_execution_failed(
@ -365,6 +376,7 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_retried( def _handle_workflow_node_execution_retried(
@ -416,6 +428,8 @@ class WorkflowCycleManage:
workflow_node_execution.index = event.node_run_index workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution) session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
################################################# #################################################
@ -812,22 +826,20 @@ class WorkflowCycleManage:
return None return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
""" if self._workflow_run and self._workflow_run.id == workflow_run_id:
Refetch workflow run cached_workflow_run = self._workflow_run
:param workflow_run_id: workflow run id cached_workflow_run = session.merge(cached_workflow_run)
:return: return cached_workflow_run
"""
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt) workflow_run = session.scalar(stmt)
if not workflow_run: if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id) raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run
return workflow_run return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id) if node_execution_id not in self._workflow_node_executions:
workflow_node_execution = session.scalar(stmt) raise ValueError(f"Workflow node execution not found: {node_execution_id}")
if not workflow_node_execution: cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
raise WorkflowNodeExecutionNotFoundError(node_execution_id) return cached_workflow_node_execution
return workflow_node_execution