refactor: Refactors workflow node execution handling (#18382)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-04-18 21:06:24 +09:00 committed by GitHub
parent 20df6e9c00
commit 44a2eca449
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 98 additions and 27 deletions

View File

@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param query: str
:return: dict
"""
# FIXME(-LAN-): Avoid import service into core
workflow_service = WorkflowService()
node_id = "1919810"
node_data = ParameterExtractorNodeData(
@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param query: str
:return: dict
"""
# FIXME(-LAN-): Avoid import service into core
workflow_service = WorkflowService()
node_id = "1919810"
node_data = QuestionClassifierNodeData(

View File

@ -86,3 +86,12 @@ class WorkflowNodeExecutionRepository(Protocol):
execution: The WorkflowNodeExecution instance to update
"""
...
def clear(self) -> None:
"""
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
This method is intended to be used for bulk deletion operations, such as removing
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
"""
...

View File

@ -630,6 +630,7 @@ class WorkflowNodeExecution(Base):
@property
def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role)
# TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
@property
@ -637,6 +638,7 @@ class WorkflowNodeExecution(Base):
from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role)
# TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
@property

View File

@ -6,7 +6,7 @@ import logging
from collections.abc import Sequence
from typing import Optional
from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@ -168,3 +168,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository:
session.merge(execution)
session.commit()
def clear(self) -> None:
"""
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
This method deletes all WorkflowNodeExecution records that match the tenant_id
and app_id (if provided) associated with this repository instance.
"""
with self._session_factory() as session:
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
result = session.execute(stmt)
session.commit()
deleted_count = result.rowcount
logger.info(
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
+ (f" and app {self._app_id}" if self._app_id else "")
)

View File

@ -2,13 +2,14 @@ import threading
from typing import Optional
import contexts
from core.repository import RepositoryFactory
from core.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.model import App
from models.workflow import (
WorkflowNodeExecution,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)
@ -127,17 +128,17 @@ class WorkflowRunService:
if not workflow_run:
return []
node_executions = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
WorkflowNodeExecution.app_id == app_model.id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == run_id,
)
.order_by(WorkflowNodeExecution.index.desc())
.all()
# Use the repository to get the node executions
repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": app_model.tenant_id,
"app_id": app_model.id,
"session_factory": db.session.get_bind,
}
)
return node_executions
# Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
return list(node_executions)

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repository import RepositoryFactory
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
@ -282,8 +283,15 @@ class WorkflowService:
workflow_node_execution.created_by = account.id
workflow_node_execution.workflow_id = draft_workflow.id
db.session.add(workflow_node_execution)
db.session.commit()
# Use the repository to save the workflow node execution
repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": app_model.tenant_id,
"app_id": app_model.id,
"session_factory": db.session.get_bind,
}
)
repository.save(workflow_node_execution)
return workflow_node_execution

View File

@ -7,6 +7,7 @@ from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from core.repository import RepositoryFactory
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.model import (
@ -30,7 +31,7 @@ from models.model import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun
@shared_task(queue="app_deletion", bind=True, max_retries=3)
@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def del_workflow_node_execution(workflow_node_execution_id: str):
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete(
synchronize_session=False
)
_delete_records(
"""select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_workflow_node_execution,
"workflow node execution",
# Create a repository instance for WorkflowNodeExecution
repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": tenant_id,
"app_id": app_id,
"session_factory": db.session.get_bind,
}
)
# Use the clear method to delete all records for this tenant_id and app_id
repository.clear()
logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green"))
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):

View File

@ -152,3 +152,27 @@ def test_update(repository, session):
# Assert session.merge was called
session_obj.merge.assert_called_once_with(execution)
def test_clear(repository, session, mocker: MockerFixture):
"""Test clear method."""
session_obj, _ = session
# Set up mock
mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete")
mock_stmt = mocker.MagicMock()
mock_delete.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
# Mock the execute result with rowcount
mock_result = mocker.MagicMock()
mock_result.rowcount = 5 # Simulate 5 records deleted
session_obj.execute.return_value = mock_result
# Call method
repository.clear()
# Assert delete was called with correct parameters
mock_delete.assert_called_once_with(WorkflowNodeExecution)
mock_stmt.where.assert_called()
session_obj.execute.assert_called_once_with(mock_stmt)
session_obj.commit.assert_called_once()