fix: correct type mismatch in WorkflowService node execution handling (#19846)

Signed-off-by: -LAN- <laipz8200@outlook.com>

parent df631591f2
commit 7d0106b220
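
At a glance: this commit replaces the raw database model WorkflowNodeExecution with the domain entity NodeExecution throughout the single-step node-run path. _handle_node_run_result now builds and returns the domain entity, and conversion to the database model happens only at the persistence boundary, via the repository's now-public to_db_model. The sketch below illustrates that split; the classes are simplified stand-ins, not Dify's real ones. Note how dict-valued fields stay plain dicts in the domain layer and become JSON strings only at the database boundary, which is the type mismatch the commit title refers to.

# A minimal sketch of the domain/DB split, using hypothetical stand-in
# classes (the real ones live in core.workflow and models.workflow).
import json
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DomainNodeExecution:  # stand-in for the NodeExecution domain entity
    id: str
    node_id: str
    inputs: Optional[dict[str, Any]] = None  # plain dicts in the domain layer
    outputs: Optional[dict[str, Any]] = None


@dataclass
class DbNodeExecution:  # stand-in for the WorkflowNodeExecution ORM model
    id: str
    node_id: str
    inputs: Optional[str] = None  # JSON strings at the database boundary
    outputs: Optional[str] = None


def to_db_model(domain: DomainNodeExecution) -> DbNodeExecution:
    # Serialization happens exactly once, when crossing into persistence.
    return DbNodeExecution(
        id=domain.id,
        node_id=domain.node_id,
        inputs=json.dumps(domain.inputs) if domain.inputs is not None else None,
        outputs=json.dumps(domain.outputs) if domain.outputs is not None else None,
    )
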
@@ -64,9 +64,9 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         )

         return {
-            "inputs": execution.inputs_dict,
-            "outputs": execution.outputs_dict,
-            "process_data": execution.process_data_dict,
+            "inputs": execution.inputs,
+            "outputs": execution.outputs,
+            "process_data": execution.process_data,
         }

     @classmethod
@@ -113,7 +113,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         )

         return {
-            "inputs": execution.inputs_dict,
-            "outputs": execution.outputs_dict,
-            "process_data": execution.process_data_dict,
+            "inputs": execution.inputs,
+            "outputs": execution.outputs,
+            "process_data": execution.process_data,
         }
@@ -127,7 +127,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
             finished_at=db_model.finished_at,
         )

-    def _to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
+    def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
         """
         Convert a domain model to a database model.

@@ -174,27 +174,35 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)

     def save(self, execution: NodeExecution) -> None:
         """
-        Save or update a NodeExecution instance and commit changes to the database.
+        Save or update a NodeExecution domain entity to the database.

-        This method handles both creating new records and updating existing ones.
-        It determines whether to create or update based on whether the record
-        already exists in the database. It also updates the in-memory cache.
+        This method serves as a domain-to-database adapter that:
+        1. Converts the domain entity to its database representation
+        2. Persists the database model using SQLAlchemy's merge operation
+        3. Maintains proper multi-tenancy by including tenant context during conversion
+        4. Updates the in-memory cache for faster subsequent lookups
+
+        The method handles both creating new records and updating existing ones through
+        SQLAlchemy's merge operation.

         Args:
-            execution: The NodeExecution instance to save or update
+            execution: The NodeExecution domain entity to persist
         """
-        with self._session_factory() as session:
-            # Convert domain model to database model using instance attributes
-            db_model = self._to_db_model(execution)
+        # Convert domain model to database model using tenant context and other attributes
+        db_model = self.to_db_model(execution)

-            # Use merge which will handle both insert and update
+        # Create a new database session
+        with self._session_factory() as session:
+            # SQLAlchemy merge intelligently handles both insert and update operations
+            # based on the presence of the primary key
             session.merge(db_model)
             session.commit()

-            # Update the cache if node_execution_id is present
-            if execution.node_execution_id:
-                logger.debug(f"Updating cache for node_execution_id: {execution.node_execution_id}")
-                self._node_execution_cache[execution.node_execution_id] = execution
+        # Update the in-memory cache for faster subsequent lookups
+        # Only cache if we have a node_execution_id to use as the cache key
+        if db_model.node_execution_id:
+            logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
+            self._node_execution_cache[db_model.node_execution_id] = db_model

     def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
         """
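
The rewritten save() leans on SQLAlchemy's merge(), which inserts when no row with the given primary key exists and updates otherwise. Below is a self-contained sketch of that upsert behavior, using a toy table rather than Dify's schema:

# Sketch of the merge-based upsert save() relies on; the Record table is
# illustrative only.
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Record(Base):
    __tablename__ = "records"
    id = Column(String, primary_key=True)
    status = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
SessionFactory = sessionmaker(bind=engine)

with SessionFactory() as session:
    session.merge(Record(id="n1", status="running"))  # no row with this PK -> INSERT
    session.commit()

with SessionFactory() as session:
    session.merge(Record(id="n1", status="succeeded"))  # PK already present -> UPDATE
    session.commit()
    assert session.get(Record, "n1").status == "succeeded"
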
@@ -257,41 +265,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
         Returns:
             A list of NodeExecution instances
         """
-        # Get the raw database models using the new method
-        db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
-
-        # Convert database models to domain models and update cache
-        domain_models = []
-        for model in db_models:
-            domain_model = self._to_domain_model(model)
-            # Update cache if node_execution_id is present
-            if domain_model.node_execution_id:
-                self._node_execution_cache[domain_model.node_execution_id] = domain_model
-            domain_models.append(domain_model)
-
-        return domain_models
-
-    def get_db_models_by_workflow_run(
-        self,
-        workflow_run_id: str,
-        order_config: Optional[OrderConfig] = None,
-    ) -> Sequence[WorkflowNodeExecution]:
-        """
-        Retrieve all WorkflowNodeExecution database models for a specific workflow run.
-
-        This method is similar to get_by_workflow_run but returns the raw database models
-        instead of converting them to domain models. This can be useful when direct access
-        to database model properties is needed.
-
-        Args:
-            workflow_run_id: The workflow run ID
-            order_config: Optional configuration for ordering results
-                order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
-                order_config.order_direction: Direction to order ("asc" or "desc")
-
-        Returns:
-            A list of WorkflowNodeExecution database models
-        """
         with self._session_factory() as session:
             stmt = select(WorkflowNodeExecution).where(
                 WorkflowNodeExecution.workflow_run_id == workflow_run_id,
@@ -319,10 +292,16 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)

             db_models = session.scalars(stmt).all()

-            # Note: We don't update the cache here since we're returning raw DB models
-            # and not converting to domain models
+            # Convert database models to domain models and update cache
+            domain_models = []
+            for model in db_models:
+                domain_model = self._to_domain_model(model)
+                # Update cache if node_execution_id is present
+                if domain_model.node_execution_id:
+                    self._node_execution_cache[domain_model.node_execution_id] = domain_model
+                domain_models.append(domain_model)

-            return db_models
+            return domain_models

     def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
         """
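
With this change, the read path converts rows to domain models as they are fetched and warms the per-instance cache keyed by node_execution_id. A minimal, runnable sketch of this read-through caching pattern, with a dict standing in for the database and a simplified stand-in type:

from dataclasses import dataclass
from typing import Optional


@dataclass
class DomainExec:  # hypothetical domain-model stand-in
    node_execution_id: str
    status: str


class CachingRepoSketch:
    """Illustrative read-through cache: convert on read, then remember the
    domain model by node_execution_id, as the repository above does."""

    def __init__(self, rows: dict[str, dict]):
        self._rows = rows  # stands in for the database
        self._cache: dict[str, DomainExec] = {}

    def _to_domain_model(self, row: dict) -> DomainExec:
        return DomainExec(node_execution_id=row["node_execution_id"], status=row["status"])

    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[DomainExec]:
        if node_execution_id in self._cache:  # hit: skip the DB entirely
            return self._cache[node_execution_id]
        row = self._rows.get(node_execution_id)
        if row is None:
            return None
        domain = self._to_domain_model(row)
        self._cache[node_execution_id] = domain  # warm cache for later lookups
        return domain


repo = CachingRepoSketch({"n1": {"node_execution_id": "n1", "status": "running"}})
assert repo.get_by_node_execution_id("n1") is repo.get_by_node_execution_id("n1")
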
@@ -145,6 +145,9 @@ class WorkflowRunService:

         # Use the repository to get the node executions with ordering
         order_config = OrderConfig(order_by=["index"], order_direction="desc")
-        node_executions = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
+        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)

-        return node_executions
+        # Convert domain models to database models
+        workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
+
+        return workflow_node_executions
@@ -10,10 +10,10 @@ from sqlalchemy.orm import Session

 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
-from core.model_runtime.utils.encoders import jsonable_encoder
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
 from core.workflow.errors import WorkflowNodeRunFailedError
 from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes import NodeType
@@ -26,7 +26,6 @@ from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
 from extensions.ext_database import db
 from models.account import Account
-from models.enums import CreatorUserRole
 from models.model import App, AppMode
 from models.tools import WorkflowToolProvider
 from models.workflow import (
@@ -268,35 +267,37 @@ class WorkflowService:
         # run draft workflow node
         start_at = time.perf_counter()

-        workflow_node_execution = self._handle_node_run_result(
-            getter=lambda: WorkflowEntry.single_step_run(
+        node_execution = self._handle_node_run_result(
+            invoke_node_fn=lambda: WorkflowEntry.single_step_run(
                 workflow=draft_workflow,
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
             ),
             start_at=start_at,
-            tenant_id=app_model.tenant_id,
             node_id=node_id,
         )

-        workflow_node_execution.app_id = app_model.id
-        workflow_node_execution.created_by = account.id
-        workflow_node_execution.workflow_id = draft_workflow.id
+        # Set workflow_id on the NodeExecution
+        node_execution.workflow_id = draft_workflow.id

+        # Create repository and save the node execution
         repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=db.engine,
             user=account,
             app_id=app_model.id,
             triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
         )
-        repository.save(workflow_node_execution)
+        repository.save(node_execution)

+        # Convert node_execution to WorkflowNodeExecution after save
+        workflow_node_execution = repository.to_db_model(node_execution)

         return workflow_node_execution

     def run_free_workflow_node(
         self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
-    ) -> WorkflowNodeExecution:
+    ) -> NodeExecution:
         """
         Run draft workflow node
         """
@@ -304,7 +305,7 @@ class WorkflowService:
         start_at = time.perf_counter()

         workflow_node_execution = self._handle_node_run_result(
-            getter=lambda: WorkflowEntry.run_free_node(
+            invoke_node_fn=lambda: WorkflowEntry.run_free_node(
                 node_id=node_id,
                 node_data=node_data,
                 tenant_id=tenant_id,
@@ -312,7 +313,6 @@ class WorkflowService:
                 user_inputs=user_inputs,
             ),
             start_at=start_at,
-            tenant_id=tenant_id,
             node_id=node_id,
         )

@@ -320,21 +320,12 @@ class WorkflowService:

     def _handle_node_run_result(
         self,
-        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
+        invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
         start_at: float,
-        tenant_id: str,
         node_id: str,
-    ) -> WorkflowNodeExecution:
-        """
-        Handle node run result
-
-        :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
-        :param start_at: float
-        :param tenant_id: str
-        :param node_id: str
-        """
+    ) -> NodeExecution:
         try:
-            node_instance, generator = getter()
+            node_instance, generator = invoke_node_fn()

             node_run_result: NodeRunResult | None = None
             for event in generator:
@@ -383,20 +374,21 @@ class WorkflowService:
             node_run_result = None
             error = e.error

-        workflow_node_execution = WorkflowNodeExecution()
-        workflow_node_execution.id = str(uuid4())
-        workflow_node_execution.tenant_id = tenant_id
-        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
-        workflow_node_execution.index = 1
-        workflow_node_execution.node_id = node_id
-        workflow_node_execution.node_type = node_instance.node_type
-        workflow_node_execution.title = node_instance.node_data.title
-        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
-        workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value
-        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
-        workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
+        # Create a NodeExecution domain model
+        node_execution = NodeExecution(
+            id=str(uuid4()),
+            workflow_id="",  # This is a single-step execution, so no workflow ID
+            index=1,
+            node_id=node_id,
+            node_type=node_instance.node_type,
+            title=node_instance.node_data.title,
+            elapsed_time=time.perf_counter() - start_at,
+            created_at=datetime.now(UTC).replace(tzinfo=None),
+            finished_at=datetime.now(UTC).replace(tzinfo=None),
+        )

         if run_succeeded and node_run_result:
-            # create workflow node execution
+            # Set inputs, process_data, and outputs as dictionaries (not JSON strings)
             inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
             process_data = (
                 WorkflowEntry.handle_special_values(node_run_result.process_data)
@@ -405,23 +397,23 @@ class WorkflowService:
             )
             outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None

-            workflow_node_execution.inputs = json.dumps(inputs)
-            workflow_node_execution.process_data = json.dumps(process_data)
-            workflow_node_execution.outputs = json.dumps(outputs)
-            workflow_node_execution.execution_metadata = (
-                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
-            )
-            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
-                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
-            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
-                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
-                workflow_node_execution.error = node_run_result.error
-        else:
-            # create workflow node execution
-            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
-            workflow_node_execution.error = error
+            node_execution.inputs = inputs
+            node_execution.process_data = process_data
+            node_execution.outputs = outputs
+            node_execution.metadata = node_run_result.metadata

-        return workflow_node_execution
+            # Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus
+            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+                node_execution.status = NodeExecutionStatus.SUCCEEDED
+            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
+                node_execution.status = NodeExecutionStatus.EXCEPTION
+                node_execution.error = node_run_result.error
+        else:
+            # Set failed status and error
+            node_execution.status = NodeExecutionStatus.FAILED
+            node_execution.error = error
+
+        return node_execution

     def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
         """
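
Instead of assigning raw .value strings onto the ORM model, the handler now maps the workflow-layer status enum onto the domain-layer one member by member. A compact sketch of such an enum-to-enum mapping; the enums here are simplified stand-ins listing only the members that appear in the diff:

from enum import Enum


class WorkflowStatusSketch(Enum):  # stand-in for WorkflowNodeExecutionStatus
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    EXCEPTION = "exception"


class NodeStatusSketch(Enum):  # stand-in for NodeExecutionStatus
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    EXCEPTION = "exception"


def map_status(src: WorkflowStatusSketch) -> NodeStatusSketch:
    # Lookup by member name keeps the two enums decoupled while staying
    # explicit about which members are expected, like the if/elif chain above.
    return NodeStatusSketch[src.name]


assert map_status(WorkflowStatusSketch.EXCEPTION) is NodeStatusSketch.EXCEPTION
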
@@ -88,15 +88,15 @@ def test_save(repository, session):
     execution.outputs = None
     execution.metadata = None

-    # Mock the _to_db_model method to return the execution itself
+    # Mock the to_db_model method to return the execution itself
     # This simulates the behavior of setting tenant_id and app_id
-    repository._to_db_model = MagicMock(return_value=execution)
+    repository.to_db_model = MagicMock(return_value=execution)

     # Call save method
     repository.save(execution)

-    # Assert _to_db_model was called with the execution
-    repository._to_db_model.assert_called_once_with(execution)
+    # Assert to_db_model was called with the execution
+    repository.to_db_model.assert_called_once_with(execution)

     # Assert session.merge was called (now using merge for both save and update)
     session_obj.merge.assert_called_once_with(execution)
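
The updated tests patch the now-public to_db_model directly on the repository instance. A condensed, runnable sketch of that arrange-act-assert flow, with a toy repository standing in for SQLAlchemyWorkflowNodeExecutionRepository:

from unittest.mock import MagicMock


class RepoSketch:  # illustrative stand-in, not the real repository
    def __init__(self, session):
        self._session = session

    def to_db_model(self, execution):
        raise NotImplementedError  # replaced by the instance-level mock below

    def save(self, execution):
        db_model = self.to_db_model(execution)
        self._session.merge(db_model)
        self._session.commit()


session = MagicMock()
repo = RepoSketch(session)
execution = MagicMock()
repo.to_db_model = MagicMock(return_value=execution)  # patch on the instance, as in the tests

repo.save(execution)

repo.to_db_model.assert_called_once_with(execution)
session.merge.assert_called_once_with(execution)
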
@@ -119,14 +119,14 @@ def test_save_with_existing_tenant_id(repository, session):
     modified_execution.tenant_id = "existing-tenant"  # Tenant ID should not change
     modified_execution.app_id = repository._app_id  # App ID should be set

-    # Mock the _to_db_model method to return the modified execution
-    repository._to_db_model = MagicMock(return_value=modified_execution)
+    # Mock the to_db_model method to return the modified execution
+    repository.to_db_model = MagicMock(return_value=modified_execution)

     # Call save method
     repository.save(execution)

-    # Assert _to_db_model was called with the execution
-    repository._to_db_model.assert_called_once_with(execution)
+    # Assert to_db_model was called with the execution
+    repository.to_db_model.assert_called_once_with(execution)

     # Assert session.merge was called with the modified execution (now using merge for both save and update)
     session_obj.merge.assert_called_once_with(modified_execution)
@@ -197,40 +197,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     assert result[0] is mock_domain_model


-def test_get_db_models_by_workflow_run(repository, session, mocker: MockerFixture):
-    """Test get_db_models_by_workflow_run method."""
-    session_obj, _ = session
-    # Set up mock
-    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
-    mock_stmt = mocker.MagicMock()
-    mock_select.return_value = mock_stmt
-    mock_stmt.where.return_value = mock_stmt
-    mock_stmt.order_by.return_value = mock_stmt
-
-    # Create a properly configured mock execution
-    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
-    configure_mock_execution(mock_execution)
-    session_obj.scalars.return_value.all.return_value = [mock_execution]
-
-    # Mock the _to_domain_model method
-    to_domain_model_mock = mocker.patch.object(repository, "_to_domain_model")
-
-    # Call method
-    order_config = OrderConfig(order_by=["index"], order_direction="desc")
-    result = repository.get_db_models_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
-
-    # Assert select was called with correct parameters
-    mock_select.assert_called_once()
-    session_obj.scalars.assert_called_once_with(mock_stmt)
-
-    # Assert the result contains our mock db model directly (without conversion to domain model)
-    assert len(result) == 1
-    assert result[0] is mock_execution
-
-    # Verify that _to_domain_model was NOT called (since we're returning raw DB models)
-    to_domain_model_mock.assert_not_called()
-
-
 def test_get_running_executions(repository, session, mocker: MockerFixture):
     """Test get_running_executions method."""
     session_obj, _ = session
@@ -275,15 +241,15 @@ def test_update_via_save(repository, session):
     execution.outputs = None
     execution.metadata = None

-    # Mock the _to_db_model method to return the execution itself
+    # Mock the to_db_model method to return the execution itself
     # This simulates the behavior of setting tenant_id and app_id
-    repository._to_db_model = MagicMock(return_value=execution)
+    repository.to_db_model = MagicMock(return_value=execution)

     # Call save method to update an existing record
     repository.save(execution)

-    # Assert _to_db_model was called with the execution
-    repository._to_db_model.assert_called_once_with(execution)
+    # Assert to_db_model was called with the execution
+    repository.to_db_model.assert_called_once_with(execution)

     # Assert session.merge was called (for updates)
     session_obj.merge.assert_called_once_with(execution)
@@ -314,7 +280,7 @@ def test_clear(repository, session, mocker: MockerFixture):


 def test_to_db_model(repository):
-    """Test _to_db_model method."""
+    """Test to_db_model method."""
     # Create a domain model
     domain_model = NodeExecution(
         id="test-id",
@@ -338,7 +304,7 @@ def test_to_db_model(repository):
     )

     # Convert to DB model
-    db_model = repository._to_db_model(domain_model)
+    db_model = repository.to_db_model(domain_model)

     # Assert DB model has correct values
     assert isinstance(db_model, WorkflowNodeExecution)
|
Loading…
x
Reference in New Issue
Block a user