mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 06:59:01 +08:00
refactor: Refactors workflow node execution handling (#18382)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
20df6e9c00
commit
44a2eca449
@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
:param query: str
|
||||
:return: dict
|
||||
"""
|
||||
# FIXME(-LAN-): Avoid import service into core
|
||||
workflow_service = WorkflowService()
|
||||
node_id = "1919810"
|
||||
node_data = ParameterExtractorNodeData(
|
||||
@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
:param query: str
|
||||
:return: dict
|
||||
"""
|
||||
# FIXME(-LAN-): Avoid import service into core
|
||||
workflow_service = WorkflowService()
|
||||
node_id = "1919810"
|
||||
node_data = QuestionClassifierNodeData(
|
||||
|
@ -86,3 +86,12 @@ class WorkflowNodeExecutionRepository(Protocol):
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
"""
|
||||
...
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
|
||||
|
||||
This method is intended to be used for bulk deletion operations, such as removing
|
||||
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
|
||||
"""
|
||||
...
|
||||
|
@ -630,6 +630,7 @@ class WorkflowNodeExecution(Base):
|
||||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
||||
|
||||
@property
|
||||
@ -637,6 +638,7 @@ class WorkflowNodeExecution(Base):
|
||||
from models.model import EndUser
|
||||
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||
|
||||
@property
|
||||
|
@ -6,7 +6,7 @@ import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import UnaryExpression, asc, desc, select
|
||||
from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -168,3 +168,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
|
||||
session.merge(execution)
|
||||
session.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
|
||||
|
||||
This method deletes all WorkflowNodeExecution records that match the tenant_id
|
||||
and app_id (if provided) associated with this repository instance.
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(
|
||||
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
|
||||
+ (f" and app {self._app_id}" if self._app_id else "")
|
||||
)
|
||||
|
@ -2,13 +2,14 @@ import threading
|
||||
from typing import Optional
|
||||
|
||||
import contexts
|
||||
from core.repository import RepositoryFactory
|
||||
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import App
|
||||
from models.workflow import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
@ -127,17 +128,17 @@ class WorkflowRunService:
|
||||
if not workflow_run:
|
||||
return []
|
||||
|
||||
node_executions = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
|
||||
WorkflowNodeExecution.app_id == app_model.id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == run_id,
|
||||
)
|
||||
.order_by(WorkflowNodeExecution.index.desc())
|
||||
.all()
|
||||
# Use the repository to get the node executions
|
||||
repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||
params={
|
||||
"tenant_id": app_model.tenant_id,
|
||||
"app_id": app_model.id,
|
||||
"session_factory": db.session.get_bind,
|
||||
}
|
||||
)
|
||||
|
||||
return node_executions
|
||||
# Use the repository to get the node executions with ordering
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
|
||||
|
||||
return list(node_executions)
|
||||
|
@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.repository import RepositoryFactory
|
||||
from core.variables import Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@ -282,8 +283,15 @@ class WorkflowService:
|
||||
workflow_node_execution.created_by = account.id
|
||||
workflow_node_execution.workflow_id = draft_workflow.id
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
# Use the repository to save the workflow node execution
|
||||
repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||
params={
|
||||
"tenant_id": app_model.tenant_id,
|
||||
"app_id": app_model.id,
|
||||
"session_factory": db.session.get_bind,
|
||||
}
|
||||
)
|
||||
repository.save(workflow_node_execution)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
|
@ -7,6 +7,7 @@ from celery import shared_task # type: ignore
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from core.repository import RepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import AppDatasetJoin
|
||||
from models.model import (
|
||||
@ -30,7 +31,7 @@ from models.model import (
|
||||
)
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.web import PinnedConversation, SavedMessage
|
||||
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
|
||||
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun
|
||||
|
||||
|
||||
@shared_task(queue="app_deletion", bind=True, max_retries=3)
|
||||
@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
|
||||
|
||||
|
||||
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
||||
def del_workflow_node_execution(workflow_node_execution_id: str):
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
{"tenant_id": tenant_id, "app_id": app_id},
|
||||
del_workflow_node_execution,
|
||||
"workflow node execution",
|
||||
# Create a repository instance for WorkflowNodeExecution
|
||||
repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||
params={
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"session_factory": db.session.get_bind,
|
||||
}
|
||||
)
|
||||
|
||||
# Use the clear method to delete all records for this tenant_id and app_id
|
||||
repository.clear()
|
||||
|
||||
logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green"))
|
||||
|
||||
|
||||
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
||||
def del_workflow_app_log(workflow_app_log_id: str):
|
||||
|
@ -152,3 +152,27 @@ def test_update(repository, session):
|
||||
|
||||
# Assert session.merge was called
|
||||
session_obj.merge.assert_called_once_with(execution)
|
||||
|
||||
|
||||
def test_clear(repository, session, mocker: MockerFixture):
|
||||
"""Test clear method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_delete.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
|
||||
# Mock the execute result with rowcount
|
||||
mock_result = mocker.MagicMock()
|
||||
mock_result.rowcount = 5 # Simulate 5 records deleted
|
||||
session_obj.execute.return_value = mock_result
|
||||
|
||||
# Call method
|
||||
repository.clear()
|
||||
|
||||
# Assert delete was called with correct parameters
|
||||
mock_delete.assert_called_once_with(WorkflowNodeExecution)
|
||||
mock_stmt.where.assert_called()
|
||||
session_obj.execute.assert_called_once_with(mock_stmt)
|
||||
session_obj.commit.assert_called_once()
|
||||
|
Loading…
x
Reference in New Issue
Block a user