fix issues when merging from main

This commit is contained in:
takatost 2024-08-13 17:11:19 +08:00
parent 14d020fffe
commit 2980e31ddf
5 changed files with 9 additions and 44 deletions

View File

@ -43,7 +43,6 @@ from models.workflow import (
WorkflowRunStatus, WorkflowRunStatus,
WorkflowRunTriggeredFrom, WorkflowRunTriggeredFrom,
) )
from services.workflow_service import WorkflowService
class WorkflowCycleManage: class WorkflowCycleManage:

View File

@ -1,14 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any
from core.workflow.utils.condition.entities import Condition
class IterableNodeMixin(ABC):
@classmethod
@abstractmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get conditions.
"""
raise NotImplementedError

View File

@ -13,6 +13,7 @@ from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
node_classes = { node_classes = {
NodeType.START: StartNode, NodeType.START: StartNode,
@ -29,5 +30,6 @@ node_classes = {
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode, NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
} }

View File

@ -36,23 +36,23 @@ class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData _node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult: def _run(self) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data) data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector) original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable): if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found') raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode: match data.write_mode:
case WriteMode.OVER_WRITE: case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value: if not income_value:
raise VariableAssignerNodeError('input value not found') raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value}) updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND: case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value: if not income_value:
raise VariableAssignerNodeError('input value not found') raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value] updated_value = original_variable.value + [income_value.value]
@ -66,11 +66,11 @@ class VariableAssignerNode(BaseNode):
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable. # Over write the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable) self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
# Update conversation variable. # Update conversation variable.
# TODO: Find a better way to use the database. # TODO: Find a better way to use the database.
conversation_id = variable_pool.get(['sys', 'conversation_id']) conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
if not conversation_id: if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found') raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)

View File

@ -336,25 +336,3 @@ class WorkflowService:
) )
else: else:
raise ValueError(f"Invalid app mode: {app_model.mode}") raise ValueError(f"Invalid app mode: {app_model.mode}")
@classmethod
def get_elapsed_time(cls, workflow_run_id: str) -> float:
"""
Get elapsed time
"""
elapsed_time = 0.0
# fetch workflow node execution by workflow_run_id
workflow_nodes = (
db.session.query(WorkflowNodeExecution)
.filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id)
.order_by(WorkflowNodeExecution.created_at.asc())
.all()
)
if not workflow_nodes:
return elapsed_time
for node in workflow_nodes:
elapsed_time += node.elapsed_time
return elapsed_time