mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 23:16:04 +08:00
refactor: Refactors workflow node execution handling (#18382)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
20df6e9c00
commit
44a2eca449
@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
:param query: str
|
:param query: str
|
||||||
:return: dict
|
:return: dict
|
||||||
"""
|
"""
|
||||||
|
# FIXME(-LAN-): Avoid import service into core
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
node_id = "1919810"
|
node_id = "1919810"
|
||||||
node_data = ParameterExtractorNodeData(
|
node_data = ParameterExtractorNodeData(
|
||||||
@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
:param query: str
|
:param query: str
|
||||||
:return: dict
|
:return: dict
|
||||||
"""
|
"""
|
||||||
|
# FIXME(-LAN-): Avoid import service into core
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
node_id = "1919810"
|
node_id = "1919810"
|
||||||
node_data = QuestionClassifierNodeData(
|
node_data = QuestionClassifierNodeData(
|
||||||
|
@ -86,3 +86,12 @@ class WorkflowNodeExecutionRepository(Protocol):
|
|||||||
execution: The WorkflowNodeExecution instance to update
|
execution: The WorkflowNodeExecution instance to update
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""
|
||||||
|
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
|
||||||
|
|
||||||
|
This method is intended to be used for bulk deletion operations, such as removing
|
||||||
|
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
@ -630,6 +630,7 @@ class WorkflowNodeExecution(Base):
|
|||||||
@property
|
@property
|
||||||
def created_by_account(self):
|
def created_by_account(self):
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatedByRole(self.created_by_role)
|
||||||
|
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -637,6 +638,7 @@ class WorkflowNodeExecution(Base):
|
|||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatedByRole(self.created_by_role)
|
||||||
|
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -6,7 +6,7 @@ import logging
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import UnaryExpression, asc, desc, select
|
from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
@ -168,3 +168,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository:
|
|||||||
|
|
||||||
session.merge(execution)
|
session.merge(execution)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""
|
||||||
|
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
|
||||||
|
|
||||||
|
This method deletes all WorkflowNodeExecution records that match the tenant_id
|
||||||
|
and app_id (if provided) associated with this repository instance.
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
||||||
|
|
||||||
|
if self._app_id:
|
||||||
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
|
result = session.execute(stmt)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
deleted_count = result.rowcount
|
||||||
|
logger.info(
|
||||||
|
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
|
||||||
|
+ (f" and app {self._app_id}" if self._app_id else "")
|
||||||
|
)
|
||||||
|
@ -2,13 +2,14 @@ import threading
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
|
from core.repository import RepositoryFactory
|
||||||
|
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.model import App
|
from models.model import App
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
WorkflowNodeExecution,
|
WorkflowNodeExecution,
|
||||||
WorkflowNodeExecutionTriggeredFrom,
|
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -127,17 +128,17 @@ class WorkflowRunService:
|
|||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
node_executions = (
|
# Use the repository to get the node executions
|
||||||
db.session.query(WorkflowNodeExecution)
|
repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||||
.filter(
|
params={
|
||||||
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
|
"tenant_id": app_model.tenant_id,
|
||||||
WorkflowNodeExecution.app_id == app_model.id,
|
"app_id": app_model.id,
|
||||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
"session_factory": db.session.get_bind,
|
||||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
}
|
||||||
WorkflowNodeExecution.workflow_run_id == run_id,
|
|
||||||
)
|
|
||||||
.order_by(WorkflowNodeExecution.index.desc())
|
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return node_executions
|
# Use the repository to get the node executions with ordering
|
||||||
|
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||||
|
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
|
||||||
|
|
||||||
|
return list(node_executions)
|
||||||
|
@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
|||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.repository import RepositoryFactory
|
||||||
from core.variables import Variable
|
from core.variables import Variable
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
@ -282,8 +283,15 @@ class WorkflowService:
|
|||||||
workflow_node_execution.created_by = account.id
|
workflow_node_execution.created_by = account.id
|
||||||
workflow_node_execution.workflow_id = draft_workflow.id
|
workflow_node_execution.workflow_id = draft_workflow.id
|
||||||
|
|
||||||
db.session.add(workflow_node_execution)
|
# Use the repository to save the workflow node execution
|
||||||
db.session.commit()
|
repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||||
|
params={
|
||||||
|
"tenant_id": app_model.tenant_id,
|
||||||
|
"app_id": app_model.id,
|
||||||
|
"session_factory": db.session.get_bind,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
repository.save(workflow_node_execution)
|
||||||
|
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from celery import shared_task # type: ignore
|
|||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
from core.repository import RepositoryFactory
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import AppDatasetJoin
|
from models.dataset import AppDatasetJoin
|
||||||
from models.model import (
|
from models.model import (
|
||||||
@ -30,7 +31,7 @@ from models.model import (
|
|||||||
)
|
)
|
||||||
from models.tools import WorkflowToolProvider
|
from models.tools import WorkflowToolProvider
|
||||||
from models.web import PinnedConversation, SavedMessage
|
from models.web import PinnedConversation, SavedMessage
|
||||||
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
|
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun
|
||||||
|
|
||||||
|
|
||||||
@shared_task(queue="app_deletion", bind=True, max_retries=3)
|
@shared_task(queue="app_deletion", bind=True, max_retries=3)
|
||||||
@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
|
|||||||
|
|
||||||
|
|
||||||
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
||||||
def del_workflow_node_execution(workflow_node_execution_id: str):
|
# Create a repository instance for WorkflowNodeExecution
|
||||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete(
|
repository = RepositoryFactory.create_workflow_node_execution_repository(
|
||||||
synchronize_session=False
|
params={
|
||||||
)
|
"tenant_id": tenant_id,
|
||||||
|
"app_id": app_id,
|
||||||
_delete_records(
|
"session_factory": db.session.get_bind,
|
||||||
"""select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
}
|
||||||
{"tenant_id": tenant_id, "app_id": app_id},
|
|
||||||
del_workflow_node_execution,
|
|
||||||
"workflow node execution",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Use the clear method to delete all records for this tenant_id and app_id
|
||||||
|
repository.clear()
|
||||||
|
|
||||||
|
logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
||||||
def del_workflow_app_log(workflow_app_log_id: str):
|
def del_workflow_app_log(workflow_app_log_id: str):
|
||||||
|
@ -152,3 +152,27 @@ def test_update(repository, session):
|
|||||||
|
|
||||||
# Assert session.merge was called
|
# Assert session.merge was called
|
||||||
session_obj.merge.assert_called_once_with(execution)
|
session_obj.merge.assert_called_once_with(execution)
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear(repository, session, mocker: MockerFixture):
|
||||||
|
"""Test clear method."""
|
||||||
|
session_obj, _ = session
|
||||||
|
# Set up mock
|
||||||
|
mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete")
|
||||||
|
mock_stmt = mocker.MagicMock()
|
||||||
|
mock_delete.return_value = mock_stmt
|
||||||
|
mock_stmt.where.return_value = mock_stmt
|
||||||
|
|
||||||
|
# Mock the execute result with rowcount
|
||||||
|
mock_result = mocker.MagicMock()
|
||||||
|
mock_result.rowcount = 5 # Simulate 5 records deleted
|
||||||
|
session_obj.execute.return_value = mock_result
|
||||||
|
|
||||||
|
# Call method
|
||||||
|
repository.clear()
|
||||||
|
|
||||||
|
# Assert delete was called with correct parameters
|
||||||
|
mock_delete.assert_called_once_with(WorkflowNodeExecution)
|
||||||
|
mock_stmt.where.assert_called()
|
||||||
|
session_obj.execute.assert_called_once_with(mock_stmt)
|
||||||
|
session_obj.commit.assert_called_once()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user