mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 06:49:03 +08:00
fix: correct type mismatch in WorkflowService node execution handling (#19846)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
df631591f2
commit
7d0106b220
@ -64,9 +64,9 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
)
|
||||
|
||||
return {
|
||||
"inputs": execution.inputs_dict,
|
||||
"outputs": execution.outputs_dict,
|
||||
"process_data": execution.process_data_dict,
|
||||
"inputs": execution.inputs,
|
||||
"outputs": execution.outputs,
|
||||
"process_data": execution.process_data,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -113,7 +113,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
)
|
||||
|
||||
return {
|
||||
"inputs": execution.inputs_dict,
|
||||
"outputs": execution.outputs_dict,
|
||||
"process_data": execution.process_data_dict,
|
||||
"inputs": execution.inputs,
|
||||
"outputs": execution.outputs,
|
||||
"process_data": execution.process_data,
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
finished_at=db_model.finished_at,
|
||||
)
|
||||
|
||||
def _to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
|
||||
def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Convert a domain model to a database model.
|
||||
|
||||
@ -174,27 +174,35 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
|
||||
def save(self, execution: NodeExecution) -> None:
|
||||
"""
|
||||
Save or update a NodeExecution instance and commit changes to the database.
|
||||
Save or update a NodeExecution domain entity to the database.
|
||||
|
||||
This method handles both creating new records and updating existing ones.
|
||||
It determines whether to create or update based on whether the record
|
||||
already exists in the database. It also updates the in-memory cache.
|
||||
This method serves as a domain-to-database adapter that:
|
||||
1. Converts the domain entity to its database representation
|
||||
2. Persists the database model using SQLAlchemy's merge operation
|
||||
3. Maintains proper multi-tenancy by including tenant context during conversion
|
||||
4. Updates the in-memory cache for faster subsequent lookups
|
||||
|
||||
The method handles both creating new records and updating existing ones through
|
||||
SQLAlchemy's merge operation.
|
||||
|
||||
Args:
|
||||
execution: The NodeExecution instance to save or update
|
||||
execution: The NodeExecution domain entity to persist
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
# Convert domain model to database model using instance attributes
|
||||
db_model = self._to_db_model(execution)
|
||||
# Convert domain model to database model using tenant context and other attributes
|
||||
db_model = self.to_db_model(execution)
|
||||
|
||||
# Use merge which will handle both insert and update
|
||||
# Create a new database session
|
||||
with self._session_factory() as session:
|
||||
# SQLAlchemy merge intelligently handles both insert and update operations
|
||||
# based on the presence of the primary key
|
||||
session.merge(db_model)
|
||||
session.commit()
|
||||
|
||||
# Update the cache if node_execution_id is present
|
||||
if execution.node_execution_id:
|
||||
logger.debug(f"Updating cache for node_execution_id: {execution.node_execution_id}")
|
||||
self._node_execution_cache[execution.node_execution_id] = execution
|
||||
# Update the in-memory cache for faster subsequent lookups
|
||||
# Only cache if we have a node_execution_id to use as the cache key
|
||||
if db_model.node_execution_id:
|
||||
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
|
||||
self._node_execution_cache[db_model.node_execution_id] = db_model
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
|
||||
"""
|
||||
@ -257,41 +265,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
Returns:
|
||||
A list of NodeExecution instances
|
||||
"""
|
||||
# Get the raw database models using the new method
|
||||
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
|
||||
|
||||
# Convert database models to domain models and update cache
|
||||
domain_models = []
|
||||
for model in db_models:
|
||||
domain_model = self._to_domain_model(model)
|
||||
# Update cache if node_execution_id is present
|
||||
if domain_model.node_execution_id:
|
||||
self._node_execution_cache[domain_model.node_execution_id] = domain_model
|
||||
domain_models.append(domain_model)
|
||||
|
||||
return domain_models
|
||||
|
||||
def get_db_models_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
|
||||
|
||||
This method is similar to get_by_workflow_run but returns the raw database models
|
||||
instead of converting them to domain models. This can be useful when direct access
|
||||
to database model properties is needed.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution database models
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
@ -319,10 +292,16 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
|
||||
db_models = session.scalars(stmt).all()
|
||||
|
||||
# Note: We don't update the cache here since we're returning raw DB models
|
||||
# and not converting to domain models
|
||||
# Convert database models to domain models and update cache
|
||||
domain_models = []
|
||||
for model in db_models:
|
||||
domain_model = self._to_domain_model(model)
|
||||
# Update cache if node_execution_id is present
|
||||
if domain_model.node_execution_id:
|
||||
self._node_execution_cache[domain_model.node_execution_id] = domain_model
|
||||
domain_models.append(domain_model)
|
||||
|
||||
return db_models
|
||||
return domain_models
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
|
||||
"""
|
||||
|
@ -145,6 +145,9 @@ class WorkflowRunService:
|
||||
|
||||
# Use the repository to get the node executions with ordering
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
node_executions = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
|
||||
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
|
||||
|
||||
return node_executions
|
||||
# Convert domain models to database models
|
||||
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
|
||||
|
||||
return workflow_node_executions
|
||||
|
@ -10,10 +10,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.variables import Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
@ -26,7 +26,6 @@ from core.workflow.workflow_entry import WorkflowEntry
|
||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import (
|
||||
@ -268,35 +267,37 @@ class WorkflowService:
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
workflow_node_execution = self._handle_node_run_result(
|
||||
getter=lambda: WorkflowEntry.single_step_run(
|
||||
node_execution = self._handle_node_run_result(
|
||||
invoke_node_fn=lambda: WorkflowEntry.single_step_run(
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
user_id=account.id,
|
||||
),
|
||||
start_at=start_at,
|
||||
tenant_id=app_model.tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
workflow_node_execution.app_id = app_model.id
|
||||
workflow_node_execution.created_by = account.id
|
||||
workflow_node_execution.workflow_id = draft_workflow.id
|
||||
# Set workflow_id on the NodeExecution
|
||||
node_execution.workflow_id = draft_workflow.id
|
||||
|
||||
# Create repository and save the node execution
|
||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=db.engine,
|
||||
user=account,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
repository.save(workflow_node_execution)
|
||||
repository.save(node_execution)
|
||||
|
||||
# Convert node_execution to WorkflowNodeExecution after save
|
||||
workflow_node_execution = repository.to_db_model(node_execution)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
) -> NodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
@ -304,7 +305,7 @@ class WorkflowService:
|
||||
start_at = time.perf_counter()
|
||||
|
||||
workflow_node_execution = self._handle_node_run_result(
|
||||
getter=lambda: WorkflowEntry.run_free_node(
|
||||
invoke_node_fn=lambda: WorkflowEntry.run_free_node(
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
tenant_id=tenant_id,
|
||||
@ -312,7 +313,6 @@ class WorkflowService:
|
||||
user_inputs=user_inputs,
|
||||
),
|
||||
start_at=start_at,
|
||||
tenant_id=tenant_id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
@ -320,21 +320,12 @@ class WorkflowService:
|
||||
|
||||
def _handle_node_run_result(
|
||||
self,
|
||||
getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
|
||||
invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
|
||||
start_at: float,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Handle node run result
|
||||
|
||||
:param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
|
||||
:param start_at: float
|
||||
:param tenant_id: str
|
||||
:param node_id: str
|
||||
"""
|
||||
) -> NodeExecution:
|
||||
try:
|
||||
node_instance, generator = getter()
|
||||
node_instance, generator = invoke_node_fn()
|
||||
|
||||
node_run_result: NodeRunResult | None = None
|
||||
for event in generator:
|
||||
@ -383,20 +374,21 @@ class WorkflowService:
|
||||
node_run_result = None
|
||||
error = e.error
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = tenant_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
|
||||
workflow_node_execution.index = 1
|
||||
workflow_node_execution.node_id = node_id
|
||||
workflow_node_execution.node_type = node_instance.node_type
|
||||
workflow_node_execution.title = node_instance.node_data.title
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
# Create a NodeExecution domain model
|
||||
node_execution = NodeExecution(
|
||||
id=str(uuid4()),
|
||||
workflow_id="", # This is a single-step execution, so no workflow ID
|
||||
index=1,
|
||||
node_id=node_id,
|
||||
node_type=node_instance.node_type,
|
||||
title=node_instance.node_data.title,
|
||||
elapsed_time=time.perf_counter() - start_at,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
finished_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
if run_succeeded and node_run_result:
|
||||
# create workflow node execution
|
||||
# Set inputs, process_data, and outputs as dictionaries (not JSON strings)
|
||||
inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
|
||||
process_data = (
|
||||
WorkflowEntry.handle_special_values(node_run_result.process_data)
|
||||
@ -405,23 +397,23 @@ class WorkflowService:
|
||||
)
|
||||
outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
|
||||
|
||||
workflow_node_execution.inputs = json.dumps(inputs)
|
||||
workflow_node_execution.process_data = json.dumps(process_data)
|
||||
workflow_node_execution.outputs = json.dumps(outputs)
|
||||
workflow_node_execution.execution_metadata = (
|
||||
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
|
||||
)
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
workflow_node_execution.error = node_run_result.error
|
||||
else:
|
||||
# create workflow node execution
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
node_execution.inputs = inputs
|
||||
node_execution.process_data = process_data
|
||||
node_execution.outputs = outputs
|
||||
node_execution.metadata = node_run_result.metadata
|
||||
|
||||
return workflow_node_execution
|
||||
# Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
node_execution.status = NodeExecutionStatus.SUCCEEDED
|
||||
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
|
||||
node_execution.status = NodeExecutionStatus.EXCEPTION
|
||||
node_execution.error = node_run_result.error
|
||||
else:
|
||||
# Set failed status and error
|
||||
node_execution.status = NodeExecutionStatus.FAILED
|
||||
node_execution.error = error
|
||||
|
||||
return node_execution
|
||||
|
||||
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
|
||||
"""
|
||||
|
@ -88,15 +88,15 @@ def test_save(repository, session):
|
||||
execution.outputs = None
|
||||
execution.metadata = None
|
||||
|
||||
# Mock the _to_db_model method to return the execution itself
|
||||
# Mock the to_db_model method to return the execution itself
|
||||
# This simulates the behavior of setting tenant_id and app_id
|
||||
repository._to_db_model = MagicMock(return_value=execution)
|
||||
repository.to_db_model = MagicMock(return_value=execution)
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert _to_db_model was called with the execution
|
||||
repository._to_db_model.assert_called_once_with(execution)
|
||||
# Assert to_db_model was called with the execution
|
||||
repository.to_db_model.assert_called_once_with(execution)
|
||||
|
||||
# Assert session.merge was called (now using merge for both save and update)
|
||||
session_obj.merge.assert_called_once_with(execution)
|
||||
@ -119,14 +119,14 @@ def test_save_with_existing_tenant_id(repository, session):
|
||||
modified_execution.tenant_id = "existing-tenant" # Tenant ID should not change
|
||||
modified_execution.app_id = repository._app_id # App ID should be set
|
||||
|
||||
# Mock the _to_db_model method to return the modified execution
|
||||
repository._to_db_model = MagicMock(return_value=modified_execution)
|
||||
# Mock the to_db_model method to return the modified execution
|
||||
repository.to_db_model = MagicMock(return_value=modified_execution)
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert _to_db_model was called with the execution
|
||||
repository._to_db_model.assert_called_once_with(execution)
|
||||
# Assert to_db_model was called with the execution
|
||||
repository.to_db_model.assert_called_once_with(execution)
|
||||
|
||||
# Assert session.merge was called with the modified execution (now using merge for both save and update)
|
||||
session_obj.merge.assert_called_once_with(modified_execution)
|
||||
@ -197,40 +197,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
assert result[0] is mock_domain_model
|
||||
|
||||
|
||||
def test_get_db_models_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
"""Test get_db_models_by_workflow_run method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
mock_stmt.order_by.return_value = mock_stmt
|
||||
|
||||
# Create a properly configured mock execution
|
||||
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
|
||||
configure_mock_execution(mock_execution)
|
||||
session_obj.scalars.return_value.all.return_value = [mock_execution]
|
||||
|
||||
# Mock the _to_domain_model method
|
||||
to_domain_model_mock = mocker.patch.object(repository, "_to_domain_model")
|
||||
|
||||
# Call method
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
result = repository.get_db_models_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
|
||||
# Assert the result contains our mock db model directly (without conversion to domain model)
|
||||
assert len(result) == 1
|
||||
assert result[0] is mock_execution
|
||||
|
||||
# Verify that _to_domain_model was NOT called (since we're returning raw DB models)
|
||||
to_domain_model_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_get_running_executions(repository, session, mocker: MockerFixture):
|
||||
"""Test get_running_executions method."""
|
||||
session_obj, _ = session
|
||||
@ -275,15 +241,15 @@ def test_update_via_save(repository, session):
|
||||
execution.outputs = None
|
||||
execution.metadata = None
|
||||
|
||||
# Mock the _to_db_model method to return the execution itself
|
||||
# Mock the to_db_model method to return the execution itself
|
||||
# This simulates the behavior of setting tenant_id and app_id
|
||||
repository._to_db_model = MagicMock(return_value=execution)
|
||||
repository.to_db_model = MagicMock(return_value=execution)
|
||||
|
||||
# Call save method to update an existing record
|
||||
repository.save(execution)
|
||||
|
||||
# Assert _to_db_model was called with the execution
|
||||
repository._to_db_model.assert_called_once_with(execution)
|
||||
# Assert to_db_model was called with the execution
|
||||
repository.to_db_model.assert_called_once_with(execution)
|
||||
|
||||
# Assert session.merge was called (for updates)
|
||||
session_obj.merge.assert_called_once_with(execution)
|
||||
@ -314,7 +280,7 @@ def test_clear(repository, session, mocker: MockerFixture):
|
||||
|
||||
|
||||
def test_to_db_model(repository):
|
||||
"""Test _to_db_model method."""
|
||||
"""Test to_db_model method."""
|
||||
# Create a domain model
|
||||
domain_model = NodeExecution(
|
||||
id="test-id",
|
||||
@ -338,7 +304,7 @@ def test_to_db_model(repository):
|
||||
)
|
||||
|
||||
# Convert to DB model
|
||||
db_model = repository._to_db_model(domain_model)
|
||||
db_model = repository.to_db_model(domain_model)
|
||||
|
||||
# Assert DB model has correct values
|
||||
assert isinstance(db_model, WorkflowNodeExecution)
|
||||
|
Loading…
x
Reference in New Issue
Block a user