diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 138503d404..2abee5bef5 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -4,6 +4,8 @@ from collections.abc import Mapping, Sequence from datetime import datetime, timezone from typing import Any, Optional, Union, cast +from sqlalchemy.orm import Session + from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( QueueIterationCompletedEvent, @@ -232,30 +234,30 @@ class WorkflowCycleManage: self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent ) -> WorkflowNodeExecution: # init workflow node execution - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.index = event.node_run_index - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() + with Session(db.engine, expire_on_commit=False) as session: + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) + + session.add(workflow_node_execution) + session.commit() + session.refresh(workflow_node_execution) self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution - return workflow_node_execution def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 5f932c0a8e..2a7c7234ea 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -96,9 +96,6 @@ class VariablePool(BaseModel): if len(selector) < 2: raise ValueError("Invalid selector") - if value is None: - return - if isinstance(value, Segment): v = value else: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 24e479153e..f310a67b76 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -320,11 +320,11 @@ class LLMNode(BaseNode[LLMNodeData]): variable_selectors = variable_template_parser.extract_variable_selectors() for variable_selector in variable_selectors: - variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) + variable_value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable_value is None: raise ValueError(f"Variable {variable_selector.variable} not found") - inputs[variable_selector.variable] = variable_value + inputs[variable_selector.variable] = variable_value.to_object() memory = node_data.memory if memory and memory.query_prompt_template: @@ -332,11 +332,11 @@ class LLMNode(BaseNode[LLMNodeData]): template=memory.query_prompt_template ).extract_variable_selectors() for variable_selector in query_variable_selectors: - variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) + variable_value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable_value is None: raise ValueError(f"Variable {variable_selector.variable} not found") - inputs[variable_selector.variable] = variable_value + inputs[variable_selector.variable] = variable_value.to_object() return inputs