fix(workflow): improve database session handling and variable management (#9581)

This commit is contained in:
-LAN- 2024-10-22 00:42:40 +08:00 committed by GitHub
parent 38a4f0234d
commit c063617553
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 28 deletions

View File

@ -4,6 +4,8 @@ from collections.abc import Mapping, Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
@ -232,30 +234,30 @@ class WorkflowCycleManage:
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
# init workflow node execution # init workflow node execution
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.add(workflow_node_execution) with Session(db.engine, expire_on_commit=False) as session:
db.session.commit() workflow_node_execution = WorkflowNodeExecution()
db.session.refresh(workflow_node_execution) workflow_node_execution.tenant_id = workflow_run.tenant_id
db.session.close() workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
session.add(workflow_node_execution)
session.commit()
session.refresh(workflow_node_execution)
self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:

View File

@ -96,9 +96,6 @@ class VariablePool(BaseModel):
if len(selector) < 2: if len(selector) < 2:
raise ValueError("Invalid selector") raise ValueError("Invalid selector")
if value is None:
return
if isinstance(value, Segment): if isinstance(value, Segment):
v = value v = value
else: else:

View File

@ -320,11 +320,11 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_selectors = variable_template_parser.extract_variable_selectors() variable_selectors = variable_template_parser.extract_variable_selectors()
for variable_selector in variable_selectors: for variable_selector in variable_selectors:
variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variable_value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable_value is None: if variable_value is None:
raise ValueError(f"Variable {variable_selector.variable} not found") raise ValueError(f"Variable {variable_selector.variable} not found")
inputs[variable_selector.variable] = variable_value inputs[variable_selector.variable] = variable_value.to_object()
memory = node_data.memory memory = node_data.memory
if memory and memory.query_prompt_template: if memory and memory.query_prompt_template:
@ -332,11 +332,11 @@ class LLMNode(BaseNode[LLMNodeData]):
template=memory.query_prompt_template template=memory.query_prompt_template
).extract_variable_selectors() ).extract_variable_selectors()
for variable_selector in query_variable_selectors: for variable_selector in query_variable_selectors:
variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) variable_value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable_value is None: if variable_value is None:
raise ValueError(f"Variable {variable_selector.variable} not found") raise ValueError(f"Variable {variable_selector.variable} not found")
inputs[variable_selector.variable] = variable_value inputs[variable_selector.variable] = variable_value.to_object()
return inputs return inputs