mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 04:18:58 +08:00
refactor: streamline initialization of application_generate_entity and task_state in task pipeline classes (#12326)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
478150e850
commit
7ed6485f86
@ -67,24 +67,17 @@ from models.account import Account
|
|||||||
from models.enums import CreatedByRole
|
from models.enums import CreatedByRole
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNodeExecution,
|
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
|
class AdvancedChatAppGenerateTaskPipeline:
|
||||||
"""
|
"""
|
||||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: WorkflowTaskState
|
|
||||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
|
||||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
|
||||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
|
||||||
_conversation_name_generate_thread: Optional[Thread] = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
@ -96,7 +89,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
stream: bool,
|
stream: bool,
|
||||||
dialogue_count: int,
|
dialogue_count: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
@ -113,16 +106,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||||
|
|
||||||
self._workflow_id = workflow.id
|
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||||
self._workflow_features_dict = workflow.features_dict
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_system_variables={
|
||||||
self._conversation_id = conversation.id
|
|
||||||
self._conversation_mode = conversation.mode
|
|
||||||
|
|
||||||
self._message_id = message.id
|
|
||||||
self._message_created_at = int(message.created_at.timestamp())
|
|
||||||
|
|
||||||
self._workflow_system_variables = {
|
|
||||||
SystemVariableKey.QUERY: message.query,
|
SystemVariableKey.QUERY: message.query,
|
||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
SystemVariableKey.FILES: application_generate_entity.files,
|
||||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||||
@ -131,14 +117,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||||
}
|
},
|
||||||
|
)
|
||||||
|
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
self._wip_workflow_node_executions = {}
|
self._message_cycle_manager = MessageCycleManage(
|
||||||
|
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||||
|
)
|
||||||
|
|
||||||
self._conversation_name_generate_thread = None
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._workflow_id = workflow.id
|
||||||
|
self._workflow_features_dict = workflow.features_dict
|
||||||
|
self._conversation_id = conversation.id
|
||||||
|
self._conversation_mode = conversation.mode
|
||||||
|
self._message_id = message.id
|
||||||
|
self._message_created_at = int(message.created_at.timestamp())
|
||||||
|
self._conversation_name_generate_thread: Thread | None = None
|
||||||
self._recorded_files: list[Mapping[str, Any]] = []
|
self._recorded_files: list[Mapping[str, Any]] = []
|
||||||
self._workflow_run_id = ""
|
self._workflow_run_id: str = ""
|
||||||
|
|
||||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||||
"""
|
"""
|
||||||
@ -146,13 +142,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# start generate conversation name thread
|
# start generate conversation name thread
|
||||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
|
||||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||||
)
|
)
|
||||||
|
|
||||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||||
|
|
||||||
if self._stream:
|
if self._base_task_pipeline._stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
@ -269,24 +265,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
# init fake graph runtime state
|
# init fake graph runtime state
|
||||||
graph_runtime_state: Optional[GraphRuntimeState] = None
|
graph_runtime_state: Optional[GraphRuntimeState] = None
|
||||||
|
|
||||||
for queue_message in self._queue_manager.listen():
|
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||||
event = queue_message.event
|
event = queue_message.event
|
||||||
|
|
||||||
if isinstance(event, QueuePingEvent):
|
if isinstance(event, QueuePingEvent):
|
||||||
yield self._ping_stream_response()
|
yield self._base_task_pipeline._ping_stream_response()
|
||||||
elif isinstance(event, QueueErrorEvent):
|
elif isinstance(event, QueueErrorEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
err = self._handle_error(event=event, session=session, message_id=self._message_id)
|
err = self._base_task_pipeline._handle_error(
|
||||||
|
event=event, session=session, message_id=self._message_id
|
||||||
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
yield self._error_to_stream_response(err)
|
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||||
# override graph runtime state
|
# override graph runtime state
|
||||||
graph_runtime_state = event.graph_runtime_state
|
graph_runtime_state = event.graph_runtime_state
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
# init workflow run
|
# init workflow run
|
||||||
workflow_run = self._handle_workflow_run_start(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_id=self._workflow_id,
|
workflow_id=self._workflow_id,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
@ -297,7 +295,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not message:
|
if not message:
|
||||||
raise ValueError(f"Message not found: {self._message_id}")
|
raise ValueError(f"Message not found: {self._message_id}")
|
||||||
message.workflow_run_id = workflow_run.id
|
message.workflow_run_id = workflow_run.id
|
||||||
workflow_start_resp = self._workflow_start_to_stream_response(
|
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -310,12 +308,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
node_retry_resp = self._workflow_node_retry_to_stream_response(
|
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -329,13 +329,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_node_execution_start(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
|
|
||||||
node_start_resp = self._workflow_node_start_to_stream_response(
|
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -348,12 +350,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
elif isinstance(event, QueueNodeSucceededEvent):
|
elif isinstance(event, QueueNodeSucceededEvent):
|
||||||
# Record files if it's an answer node or end node
|
# Record files if it's an answer node or end node
|
||||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||||
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
self._recorded_files.extend(
|
||||||
|
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
|
||||||
|
)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
|
session=session, event=event
|
||||||
|
)
|
||||||
|
|
||||||
node_finish_resp = self._workflow_node_finish_to_stream_response(
|
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -364,10 +370,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if node_finish_resp:
|
if node_finish_resp:
|
||||||
yield node_finish_resp
|
yield node_finish_resp
|
||||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
|
session=session, event=event
|
||||||
|
)
|
||||||
|
|
||||||
node_finish_resp = self._workflow_node_finish_to_stream_response(
|
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -381,37 +389,47 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
parallel_start_resp = (
|
||||||
|
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield parallel_start_resp
|
yield parallel_start_resp
|
||||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
parallel_finish_resp = (
|
||||||
|
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield parallel_finish_resp
|
yield parallel_finish_resp
|
||||||
elif isinstance(event, QueueIterationStartEvent):
|
elif isinstance(event, QueueIterationStartEvent):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -423,9 +441,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -437,9 +457,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -454,8 +476,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -466,21 +488,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
|
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_partial_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -491,21 +515,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
|
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_failed(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -517,20 +543,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
exceptions_count=event.exceptions_count,
|
exceptions_count=event.exceptions_count,
|
||||||
)
|
)
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
|
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
|
||||||
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
|
err = self._base_task_pipeline._handle_error(
|
||||||
|
event=err_event, session=session, message_id=self._message_id
|
||||||
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
yield self._error_to_stream_response(err)
|
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueStopEvent):
|
elif isinstance(event, QueueStopEvent):
|
||||||
if self._workflow_run_id and graph_runtime_state:
|
if self._workflow_run_id and graph_runtime_state:
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_failed(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -541,7 +569,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
conversation_id=self._conversation_id,
|
conversation_id=self._conversation_id,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -555,18 +583,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
yield self._message_end_to_stream_response()
|
yield self._message_end_to_stream_response()
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
self._handle_retriever_resources(event)
|
self._message_cycle_manager._handle_retriever_resources(event)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.message_metadata = (
|
message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||||
self._handle_annotation_reply(event)
|
self._message_cycle_manager._handle_annotation_reply(event)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.message_metadata = (
|
message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
@ -587,23 +615,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
tts_publisher.publish(queue_message)
|
tts_publisher.publish(queue_message)
|
||||||
|
|
||||||
self._task_state.answer += delta_text
|
self._task_state.answer += delta_text
|
||||||
yield self._message_to_stream_response(
|
yield self._message_cycle_manager._message_to_stream_response(
|
||||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||||
)
|
)
|
||||||
elif isinstance(event, QueueMessageReplaceEvent):
|
elif isinstance(event, QueueMessageReplaceEvent):
|
||||||
# published by moderation
|
# published by moderation
|
||||||
yield self._message_replace_to_stream_response(answer=event.text)
|
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
|
||||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
|
||||||
|
self._task_state.answer
|
||||||
|
)
|
||||||
if output_moderation_answer:
|
if output_moderation_answer:
|
||||||
self._task_state.answer = output_moderation_answer
|
self._task_state.answer = output_moderation_answer
|
||||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||||
|
answer=output_moderation_answer
|
||||||
|
)
|
||||||
|
|
||||||
# Save message
|
# Save message
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
@ -621,7 +653,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.answer = self._task_state.answer
|
message.answer = self._task_state.answer
|
||||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||||
message.message_metadata = (
|
message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
)
|
)
|
||||||
@ -685,20 +717,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
:param text: text
|
:param text: text
|
||||||
:return: True if output moderation should direct output, otherwise False
|
:return: True if output moderation should direct output, otherwise False
|
||||||
"""
|
"""
|
||||||
if self._output_moderation_handler:
|
if self._base_task_pipeline._output_moderation_handler:
|
||||||
if self._output_moderation_handler.should_direct_output():
|
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
|
||||||
# stop subscribe new token when output moderation should direct output
|
# stop subscribe new token when output moderation should direct output
|
||||||
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
|
||||||
self._queue_manager.publish(
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
||||||
)
|
)
|
||||||
|
|
||||||
self._queue_manager.publish(
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
self._output_moderation_handler.append_new_token(text)
|
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -58,7 +58,6 @@ from models.workflow import (
|
|||||||
Workflow,
|
Workflow,
|
||||||
WorkflowAppLog,
|
WorkflowAppLog,
|
||||||
WorkflowAppLogCreatedFrom,
|
WorkflowAppLogCreatedFrom,
|
||||||
WorkflowNodeExecution,
|
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
)
|
)
|
||||||
@ -66,16 +65,11 @@ from models.workflow import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
|
class WorkflowAppGenerateTaskPipeline:
|
||||||
"""
|
"""
|
||||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: WorkflowTaskState
|
|
||||||
_application_generate_entity: WorkflowAppGenerateEntity
|
|
||||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
|
||||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: WorkflowAppGenerateEntity,
|
application_generate_entity: WorkflowAppGenerateEntity,
|
||||||
@ -84,7 +78,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
stream: bool,
|
stream: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
@ -101,19 +95,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid user type: {type(user)}")
|
raise ValueError(f"Invalid user type: {type(user)}")
|
||||||
|
|
||||||
self._workflow_id = workflow.id
|
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||||
self._workflow_features_dict = workflow.features_dict
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_system_variables={
|
||||||
self._workflow_system_variables = {
|
|
||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
SystemVariableKey.FILES: application_generate_entity.files,
|
||||||
SystemVariableKey.USER_ID: user_session_id,
|
SystemVariableKey.USER_ID: user_session_id,
|
||||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||||
}
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._workflow_id = workflow.id
|
||||||
|
self._workflow_features_dict = workflow.features_dict
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
self._wip_workflow_node_executions = {}
|
|
||||||
self._workflow_run_id = ""
|
self._workflow_run_id = ""
|
||||||
|
|
||||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||||
@ -122,7 +118,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||||
if self._stream:
|
if self._base_task_pipeline._stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
@ -237,29 +233,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
"""
|
"""
|
||||||
graph_runtime_state = None
|
graph_runtime_state = None
|
||||||
|
|
||||||
for queue_message in self._queue_manager.listen():
|
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||||
event = queue_message.event
|
event = queue_message.event
|
||||||
|
|
||||||
if isinstance(event, QueuePingEvent):
|
if isinstance(event, QueuePingEvent):
|
||||||
yield self._ping_stream_response()
|
yield self._base_task_pipeline._ping_stream_response()
|
||||||
elif isinstance(event, QueueErrorEvent):
|
elif isinstance(event, QueueErrorEvent):
|
||||||
err = self._handle_error(event=event)
|
err = self._base_task_pipeline._handle_error(event=event)
|
||||||
yield self._error_to_stream_response(err)
|
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||||
# override graph runtime state
|
# override graph runtime state
|
||||||
graph_runtime_state = event.graph_runtime_state
|
graph_runtime_state = event.graph_runtime_state
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
# init workflow run
|
# init workflow run
|
||||||
workflow_run = self._handle_workflow_run_start(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_id=self._workflow_id,
|
workflow_id=self._workflow_id,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
created_by_role=self._created_by_role,
|
created_by_role=self._created_by_role,
|
||||||
)
|
)
|
||||||
self._workflow_run_id = workflow_run.id
|
self._workflow_run_id = workflow_run.id
|
||||||
start_resp = self._workflow_start_to_stream_response(
|
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -271,12 +267,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
):
|
):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
response = self._workflow_node_retry_to_stream_response(
|
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -290,12 +288,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_node_execution_start(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
node_start_response = self._workflow_node_start_to_stream_response(
|
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -306,9 +306,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if node_start_response:
|
if node_start_response:
|
||||||
yield node_start_response
|
yield node_start_response
|
||||||
elif isinstance(event, QueueNodeSucceededEvent):
|
elif isinstance(event, QueueNodeSucceededEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
node_success_response = self._workflow_node_finish_to_stream_response(
|
session=session, event=event
|
||||||
|
)
|
||||||
|
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -319,12 +321,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if node_success_response:
|
if node_success_response:
|
||||||
yield node_success_response
|
yield node_success_response
|
||||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_failed(
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
node_failed_response = self._workflow_node_finish_to_stream_response(
|
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -339,14 +341,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
parallel_start_resp = (
|
||||||
|
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield parallel_start_resp
|
yield parallel_start_resp
|
||||||
|
|
||||||
@ -354,14 +360,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
parallel_finish_resp = (
|
||||||
|
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield parallel_finish_resp
|
yield parallel_finish_resp
|
||||||
|
|
||||||
@ -369,9 +379,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -384,9 +396,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -399,9 +413,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -416,8 +432,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -431,7 +447,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
# save workflow app log
|
# save workflow app log
|
||||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -445,8 +461,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_partial_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -461,7 +477,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
# save workflow app log
|
# save workflow app log
|
||||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -473,8 +489,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_failed(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -492,7 +508,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
# save workflow app log
|
# save workflow app log
|
||||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
|
|||||||
from core.app.entities.task_entities import (
|
from core.app.entities.task_entities import (
|
||||||
ErrorStreamResponse,
|
ErrorStreamResponse,
|
||||||
PingStreamResponse,
|
PingStreamResponse,
|
||||||
TaskState,
|
|
||||||
)
|
)
|
||||||
from core.errors.error import QuotaExceededError
|
from core.errors.error import QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
|
|||||||
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: TaskState
|
|
||||||
_application_generate_entity: AppGenerateEntity
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: AppGenerateEntity,
|
application_generate_entity: AppGenerateEntity,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
Initialize GenerateTaskPipeline.
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:param queue_manager: queue manager
|
|
||||||
:param user: user
|
|
||||||
:param stream: stream
|
|
||||||
"""
|
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
self._start_at = time.perf_counter()
|
self._start_at = time.perf_counter()
|
||||||
|
@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
|
|||||||
|
|
||||||
|
|
||||||
class MessageCycleManage:
|
class MessageCycleManage:
|
||||||
_application_generate_entity: Union[
|
def __init__(
|
||||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
|
self,
|
||||||
]
|
*,
|
||||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
application_generate_entity: Union[
|
||||||
|
ChatAppGenerateEntity,
|
||||||
|
CompletionAppGenerateEntity,
|
||||||
|
AgentChatAppGenerateEntity,
|
||||||
|
AdvancedChatAppGenerateEntity,
|
||||||
|
],
|
||||||
|
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||||
|
) -> None:
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._task_state = task_state
|
||||||
|
|
||||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||||
"""
|
"""
|
||||||
|
@ -34,7 +34,6 @@ from core.app.entities.task_entities import (
|
|||||||
ParallelBranchStartStreamResponse,
|
ParallelBranchStartStreamResponse,
|
||||||
WorkflowFinishStreamResponse,
|
WorkflowFinishStreamResponse,
|
||||||
WorkflowStartStreamResponse,
|
WorkflowStartStreamResponse,
|
||||||
WorkflowTaskState,
|
|
||||||
)
|
)
|
||||||
from core.file import FILE_MODEL_IDENTITY, File
|
from core.file import FILE_MODEL_IDENTITY, File
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
@ -58,13 +57,20 @@ from models.workflow import (
|
|||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
|
from .exc import WorkflowRunNotFoundError
|
||||||
|
|
||||||
|
|
||||||
class WorkflowCycleManage:
|
class WorkflowCycleManage:
|
||||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
def __init__(
|
||||||
_task_state: WorkflowTaskState
|
self,
|
||||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
*,
|
||||||
|
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||||
|
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||||
|
) -> None:
|
||||||
|
self._workflow_run: WorkflowRun | None = None
|
||||||
|
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._workflow_system_variables = workflow_system_variables
|
||||||
|
|
||||||
def _handle_workflow_run_start(
|
def _handle_workflow_run_start(
|
||||||
self,
|
self,
|
||||||
@ -240,7 +246,7 @@ class WorkflowCycleManage:
|
|||||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_run.exceptions_count = exceptions_count
|
workflow_run.exceptions_count = exceptions_count
|
||||||
|
|
||||||
stmt = select(WorkflowNodeExecution).where(
|
stmt = select(WorkflowNodeExecution.node_execution_id).where(
|
||||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||||
@ -248,16 +254,18 @@ class WorkflowCycleManage:
|
|||||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||||
)
|
)
|
||||||
|
ids = session.scalars(stmt).all()
|
||||||
running_workflow_node_executions = session.scalars(stmt).all()
|
# Use self._get_workflow_node_execution here to make sure the cache is updated
|
||||||
|
running_workflow_node_executions = [
|
||||||
|
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
|
||||||
|
]
|
||||||
|
|
||||||
for workflow_node_execution in running_workflow_node_executions:
|
for workflow_node_execution in running_workflow_node_executions:
|
||||||
|
now = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||||
workflow_node_execution.error = error
|
workflow_node_execution.error = error
|
||||||
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_node_execution.finished_at = now
|
||||||
workflow_node_execution.elapsed_time = (
|
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||||
workflow_node_execution.finished_at - workflow_node_execution.created_at
|
|
||||||
).total_seconds()
|
|
||||||
|
|
||||||
if trace_manager:
|
if trace_manager:
|
||||||
trace_manager.add_trace_task(
|
trace_manager.add_trace_task(
|
||||||
@ -299,6 +307,8 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
session.add(workflow_node_execution)
|
session.add(workflow_node_execution)
|
||||||
|
|
||||||
|
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_success(
|
def _handle_workflow_node_execution_success(
|
||||||
@ -326,6 +336,7 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.finished_at = finished_at
|
workflow_node_execution.finished_at = finished_at
|
||||||
workflow_node_execution.elapsed_time = elapsed_time
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
|
|
||||||
|
workflow_node_execution = session.merge(workflow_node_execution)
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_failed(
|
def _handle_workflow_node_execution_failed(
|
||||||
@ -365,6 +376,7 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.elapsed_time = elapsed_time
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
workflow_node_execution.execution_metadata = execution_metadata
|
workflow_node_execution.execution_metadata = execution_metadata
|
||||||
|
|
||||||
|
workflow_node_execution = session.merge(workflow_node_execution)
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_retried(
|
def _handle_workflow_node_execution_retried(
|
||||||
@ -416,6 +428,8 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.index = event.node_run_index
|
workflow_node_execution.index = event.node_run_index
|
||||||
|
|
||||||
session.add(workflow_node_execution)
|
session.add(workflow_node_execution)
|
||||||
|
|
||||||
|
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
@ -812,22 +826,20 @@ class WorkflowCycleManage:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
||||||
"""
|
if self._workflow_run and self._workflow_run.id == workflow_run_id:
|
||||||
Refetch workflow run
|
cached_workflow_run = self._workflow_run
|
||||||
:param workflow_run_id: workflow run id
|
cached_workflow_run = session.merge(cached_workflow_run)
|
||||||
:return:
|
return cached_workflow_run
|
||||||
"""
|
|
||||||
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||||
workflow_run = session.scalar(stmt)
|
workflow_run = session.scalar(stmt)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
raise WorkflowRunNotFoundError(workflow_run_id)
|
raise WorkflowRunNotFoundError(workflow_run_id)
|
||||||
|
self._workflow_run = workflow_run
|
||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
||||||
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
|
if node_execution_id not in self._workflow_node_executions:
|
||||||
workflow_node_execution = session.scalar(stmt)
|
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||||
if not workflow_node_execution:
|
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||||
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
|
return cached_workflow_node_execution
|
||||||
|
|
||||||
return workflow_node_execution
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user