diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index f402da030f..db07e52f3f 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): :param query: str :return: dict """ + # FIXME(-LAN-): Avoid import service into core workflow_service = WorkflowService() node_id = "1919810" node_data = ParameterExtractorNodeData( @@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): :param query: str :return: dict """ + # FIXME(-LAN-): Avoid import service into core workflow_service = WorkflowService() node_id = "1919810" node_data = QuestionClassifierNodeData( diff --git a/api/core/repository/workflow_node_execution_repository.py b/api/core/repository/workflow_node_execution_repository.py index 6dea4566de..9bb790cb0f 100644 --- a/api/core/repository/workflow_node_execution_repository.py +++ b/api/core/repository/workflow_node_execution_repository.py @@ -86,3 +86,12 @@ class WorkflowNodeExecutionRepository(Protocol): execution: The WorkflowNodeExecution instance to update """ ... + + def clear(self) -> None: + """ + Clear all WorkflowNodeExecution records based on implementation-specific criteria. + + This method is intended to be used for bulk deletion operations, such as removing + all records associated with a specific app_id and tenant_id in multi-tenant implementations. + """ + ... diff --git a/api/models/workflow.py b/api/models/workflow.py index 045fa0aaa0..51f2f4cc9f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -630,6 +630,7 @@ class WorkflowNodeExecution(Base): @property def created_by_account(self): created_by_role = CreatedByRole(self.created_by_role) + # TODO(-LAN-): Avoid using db.session.get() here. return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property @@ -637,6 +638,7 @@ class WorkflowNodeExecution(Base): from models.model import EndUser created_by_role = CreatedByRole(self.created_by_role) + # TODO(-LAN-): Avoid using db.session.get() here. return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property diff --git a/api/repositories/workflow_node_execution/sqlalchemy_repository.py b/api/repositories/workflow_node_execution/sqlalchemy_repository.py index c9c6e70ff3..0594d816a2 100644 --- a/api/repositories/workflow_node_execution/sqlalchemy_repository.py +++ b/api/repositories/workflow_node_execution/sqlalchemy_repository.py @@ -6,7 +6,7 @@ import logging from collections.abc import Sequence from typing import Optional -from sqlalchemy import UnaryExpression, asc, desc, select +from sqlalchemy import UnaryExpression, asc, delete, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -168,3 +168,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository: session.merge(execution) session.commit() + + def clear(self) -> None: + """ + Clear all WorkflowNodeExecution records for the current tenant_id and app_id. + + This method deletes all WorkflowNodeExecution records that match the tenant_id + and app_id (if provided) associated with this repository instance. + """ + with self._session_factory() as session: + stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + result = session.execute(stmt) + session.commit() + + deleted_count = result.rowcount + logger.info( + f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" + + (f" and app {self._app_id}" if self._app_id else "") + ) diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 0ddd18ea27..ff3b33eecd 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,13 +2,14 @@ import threading from typing import Optional import contexts +from core.repository import RepositoryFactory +from core.repository.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.model import App from models.workflow import ( WorkflowNodeExecution, - WorkflowNodeExecutionTriggeredFrom, WorkflowRun, ) @@ -127,17 +128,17 @@ class WorkflowRunService: if not workflow_run: return [] - node_executions = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == app_model.tenant_id, - WorkflowNodeExecution.app_id == app_model.id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == run_id, - ) - .order_by(WorkflowNodeExecution.index.desc()) - .all() + # Use the repository to get the node executions + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": app_model.tenant_id, + "app_id": app_model.id, + "session_factory": db.session.get_bind, + } ) - return node_executions + # Use the repository to get the node executions with ordering + order_config = OrderConfig(order_by=["index"], order_direction="desc") + node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + + return list(node_executions) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 992942fc70..b88c7b296d 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.model_runtime.utils.encoders import jsonable_encoder +from core.repository import RepositoryFactory from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError @@ -282,8 +283,15 @@ class WorkflowService: workflow_node_execution.created_by = account.id workflow_node_execution.workflow_id = draft_workflow.id - db.session.add(workflow_node_execution) - db.session.commit() + # Use the repository to save the workflow node execution + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": app_model.tenant_id, + "app_id": app_model.id, + "session_factory": db.session.get_bind, + } + ) + repository.save(workflow_node_execution) return workflow_node_execution diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index c3910e2be3..4542b1b923 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -7,6 +7,7 @@ from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError +from core.repository import RepositoryFactory from extensions.ext_database import db from models.dataset import AppDatasetJoin from models.model import ( @@ -30,7 +31,7 @@ from models.model import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun @shared_task(queue="app_deletion", bind=True, max_retries=3) @@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): - def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete( - synchronize_session=False - ) - - _delete_records( - """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_node_execution, - "workflow node execution", + # Create a repository instance for WorkflowNodeExecution + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": tenant_id, + "app_id": app_id, + "session_factory": db.session.get_bind, + } ) + # Use the clear method to delete all records for this tenant_id and app_id + repository.clear() + + logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green")) + def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index f31adab2a8..36847f8a13 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -152,3 +152,27 @@ def test_update(repository, session): # Assert session.merge was called session_obj.merge.assert_called_once_with(execution) + + +def test_clear(repository, session, mocker: MockerFixture): + """Test clear method.""" + session_obj, _ = session + # Set up mock + mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete") + mock_stmt = mocker.MagicMock() + mock_delete.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + + # Mock the execute result with rowcount + mock_result = mocker.MagicMock() + mock_result.rowcount = 5 # Simulate 5 records deleted + session_obj.execute.return_value = mock_result + + # Call method + repository.clear() + + # Assert delete was called with correct parameters + mock_delete.assert_called_once_with(WorkflowNodeExecution) + mock_stmt.where.assert_called() + session_obj.execute.assert_called_once_with(mock_stmt) + session_obj.commit.assert_called_once()