diff --git a/api/.env.example b/api/.env.example index 502461f658..01ddb4adfd 100644 --- a/api/.env.example +++ b/api/.env.example @@ -424,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5 WORKFLOW_PARALLEL_DEPTH_LIMIT=3 MAX_VARIABLE_SIZE=204800 +# Workflow storage configuration +# Options: rdbms, hybrid +# rdbms: Use only the relational database (default) +# hybrid: Save new data to object storage, read from both object storage and RDBMS +WORKFLOW_NODE_EXECUTION_STORAGE=rdbms + # App configuration APP_MAX_EXECUTION_TIME=1200 APP_MAX_ACTIVE_REQUESTS=0 diff --git a/api/app_factory.py b/api/app_factory.py index 1c886ac5c7..586f2ded9e 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp): ext_otel, ext_proxy_fix, ext_redis, + ext_repositories, ext_sentry, ext_set_secretkey, ext_storage, @@ -74,6 +75,7 @@ def initialize_extensions(app: DifyApp): ext_migrate, ext_redis, ext_storage, + ext_repositories, ext_celery, ext_login, ext_mail, diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d35a74e3ee..f498dccbbc 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -12,7 +12,7 @@ from pydantic import ( ) from pydantic_settings import BaseSettings -from configs.feature.hosted_service import HostedServiceConfig +from .hosted_service import HostedServiceConfig class SecurityConfig(BaseSettings): @@ -519,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings): default=100, ) + WORKFLOW_NODE_EXECUTION_STORAGE: str = Field( + default="rdbms", + description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'", + ) + class AuthConfig(BaseSettings): """ diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 66f2c754bb..3bf6c330db 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline: with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( - session=session, event=event + event=event ) node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ 
-383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( - session=session, event=event - ) + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + event=event + ) - node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_finish_resp: yield node_finish_resp diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 14441ada40..1f998edb6a 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline: if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( - session=session, event=event - ) - node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + event=event + ) + node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_success_response: yield node_success_response @@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( - 
session=session, - event=event, - ) - node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + event=event, + ) + node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_failed_response: yield node_failed_response @@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline: workflow_app_log.created_by = self._user_id session.add(workflow_app_log) + session.commit() def _text_chunk_to_stream_response( self, text: str, from_variable_selector: Optional[list[str]] = None diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4d629ca186..5ce9f737d1 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast from uuid import uuid4 from sqlalchemy import func, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repository import RepositoryFactory from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db from models.account import Account from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser @@ -80,6 +82,21 @@ class WorkflowCycleManage: self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables + # Initialize the session factory and repository + # We use the global db engine instead of the session passed to methods + # Disable expire_on_commit to avoid the need for merging objects + self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": self._application_generate_entity.app_config.tenant_id, + "app_id": self._application_generate_entity.app_config.app_id, + "session_factory": self._session_factory, + } + ) + + # We'll still keep the cache for backward compatibility and performance + # but use the repository for database operations + def _handle_workflow_run_start( self, *, @@ -254,19 +271,15 @@ class WorkflowCycleManage: workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - stmt = select(WorkflowNodeExecution.node_execution_id).where( - WorkflowNodeExecution.tenant_id 
== workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, + # Use the instance repository to find running executions for a workflow run + running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions( + workflow_run_id=workflow_run.id ) - ids = session.scalars(stmt).all() - # Use self._get_workflow_node_execution here to make sure the cache is updated - running_workflow_node_executions = [ - self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id - ] + + # Update the cache with the retrieved executions + for execution in running_workflow_node_executions: + if execution.node_execution_id: + self._workflow_node_executions[execution.node_execution_id] = execution for workflow_node_execution in running_workflow_node_executions: now = datetime.now(UTC).replace(tzinfo=None) @@ -288,7 +301,7 @@ class WorkflowCycleManage: return workflow_run def _handle_node_execution_start( - self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent ) -> WorkflowNodeExecution: workflow_node_execution = WorkflowNodeExecution() workflow_node_execution.id = str(uuid4()) @@ -315,17 +328,14 @@ class WorkflowCycleManage: ) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - session.add(workflow_node_execution) + # Use the instance repository to save the workflow node execution + self._workflow_node_execution_repository.save(workflow_node_execution) self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution - def _handle_workflow_node_execution_success( - self, *, session: Session, event: QueueNodeSucceededEvent - ) -> WorkflowNodeExecution: - workflow_node_execution = self._get_workflow_node_execution( - session=session, node_execution_id=event.node_execution_id - ) + def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: + workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) @@ -344,13 +354,13 @@ class WorkflowCycleManage: workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution = session.merge(workflow_node_execution) + # Use the instance repository to update the workflow node execution + self._workflow_node_execution_repository.update(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_failed( self, *, - session: Session, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeInLoopFailedEvent @@ -361,9 +371,7 @@ class WorkflowCycleManage: :param event: queue node failed event :return: """ - workflow_node_execution = self._get_workflow_node_execution( - session=session, node_execution_id=event.node_execution_id - ) + workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) inputs = 
WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) @@ -387,14 +395,14 @@ workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution = session.merge(workflow_node_execution) + # Use the instance repository to update the workflow node execution; + # without this, the failure status is never persisted now that each + # repository method commits in its own session. + self._workflow_node_execution_repository.update(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_retried( - self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent ) -> WorkflowNodeExecution: """ - Workflow node execution failed + Workflow node execution retried + :param workflow_run: workflow run - :param event: queue node failed event + :param event: queue node retry event :return: """ @@ -439,15 +447,12 @@ workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.index = event.node_run_index - session.add(workflow_node_execution) + # Use the instance repository to save the workflow node execution + self._workflow_node_execution_repository.save(workflow_node_execution) self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution - ################################################# - # to stream responses # - ################################################# - def _workflow_start_to_stream_response( self, *, @@ -455,7 +460,6 @@ task_id: str, workflow_run: WorkflowRun, ) -> WorkflowStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return WorkflowStartStreamResponse( task_id=task_id, @@ -521,14 +525,10 @@ def _workflow_node_start_to_stream_response( self, *, - session: Session, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeStartStreamResponse]: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this - _ = session - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None if not workflow_node_execution.workflow_run_id: @@ -571,7 +571,6 @@ def _workflow_node_finish_to_stream_response( self, *, - session: Session, event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent @@ -580,8 +579,6 @@ task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this - _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None if not workflow_node_execution.workflow_run_id: @@ -621,13 +618,10 @@ def _workflow_node_retry_to_stream_response( self, *, - session: Session, event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this - _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None if not workflow_node_execution.workflow_run_id: @@ -668,7 +662,6 @@ def _workflow_parallel_branch_start_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event:
QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return ParallelBranchStartStreamResponse( task_id=task_id, @@ -692,7 +685,6 @@ class WorkflowCycleManage: workflow_run: WorkflowRun, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, ) -> ParallelBranchFinishedStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return ParallelBranchFinishedStreamResponse( task_id=task_id, @@ -713,7 +705,6 @@ class WorkflowCycleManage: def _workflow_iteration_start_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent ) -> IterationNodeStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return IterationNodeStartStreamResponse( task_id=task_id, @@ -735,7 +726,6 @@ class WorkflowCycleManage: def _workflow_iteration_next_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent ) -> IterationNodeNextStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return IterationNodeNextStreamResponse( task_id=task_id, @@ -759,7 +749,6 @@ class WorkflowCycleManage: def _workflow_iteration_completed_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent ) -> IterationNodeCompletedStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return IterationNodeCompletedStreamResponse( task_id=task_id, @@ -790,7 +779,6 @@ class WorkflowCycleManage: def _workflow_loop_start_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent ) -> LoopNodeStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return LoopNodeStartStreamResponse( task_id=task_id, @@ -812,7 +800,6 @@ class WorkflowCycleManage: def _workflow_loop_next_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent ) -> LoopNodeNextStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return LoopNodeNextStreamResponse( task_id=task_id, @@ -836,7 +823,6 @@ class WorkflowCycleManage: def _workflow_loop_completed_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent ) -> LoopNodeCompletedStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return LoopNodeCompletedStreamResponse( task_id=task_id, @@ -934,11 +920,22 @@ class WorkflowCycleManage: return workflow_run - def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: - if node_execution_id not in self._workflow_node_executions: + def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: + # First check the cache for performance + if node_execution_id in self._workflow_node_executions: + cached_execution = 
self._workflow_node_executions[node_execution_id] + # No need to merge with session since expire_on_commit=False + return cached_execution + + # If not in cache, use the instance repository to get by node_execution_id + execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id) + + if not execution: raise ValueError(f"Workflow node execution not found: {node_execution_id}") - cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] - return session.merge(cached_workflow_node_execution) + + # Update cache + self._workflow_node_executions[node_execution_id] = execution + return execution def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: """ diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index f67e270ab1..fa78b7b8e9 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta from typing import Optional from langfuse import Langfuse # type: ignore +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangfuseConfig @@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values +from core.repository.repository_factory import RepositoryFactory from extensions.ext_database import db from models.model import EndUser -from models.workflow import WorkflowNodeExecution logger = logging.getLogger(__name__) @@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.add_trace(langfuse_trace_data=trace_data) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory}, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index e3494e2f23..85a0eafdc1 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ 
-7,6 +7,7 @@ from typing import Optional, cast from langsmith import Client from langsmith.schemas import RunBase +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangSmithConfig @@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order +from core.repository.repository_factory import RepositoryFactory from extensions.ext_database import db from models.model import EndUser, MessageFile -from models.workflow import WorkflowNodeExecution logger = logging.getLogger(__name__) @@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance): self.add_run(langsmith_run) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": trace_info.tenant_id, + "app_id": trace_info.metadata.get("app_id"), + "session_factory": session_factory, + }, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fabf38fbd6..923b9a24ed 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -7,6 +7,7 @@ from typing import Optional, cast from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import OpikConfig @@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) +from core.repository.repository_factory import RepositoryFactory from extensions.ext_database import db from models.model import EndUser, MessageFile -from models.workflow import WorkflowNodeExecution logger = logging.getLogger(__name__) @@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance): } self.add_trace(trace_data) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get 
all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": trace_info.tenant_id, + "app_id": trace_info.metadata.get("app_id"), + "session_factory": session_factory, + }, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/repository/__init__.py b/api/core/repository/__init__.py new file mode 100644 index 0000000000..253df1251d --- /dev/null +++ b/api/core/repository/__init__.py @@ -0,0 +1,15 @@ +""" +Repository interfaces for data access. + +This package contains repository interfaces that define the contract +for accessing and manipulating data, regardless of the underlying +storage mechanism. +""" + +from core.repository.repository_factory import RepositoryFactory +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository + +__all__ = [ + "RepositoryFactory", + "WorkflowNodeExecutionRepository", +] diff --git a/api/core/repository/repository_factory.py b/api/core/repository/repository_factory.py new file mode 100644 index 0000000000..02e343d7ff --- /dev/null +++ b/api/core/repository/repository_factory.py @@ -0,0 +1,97 @@ +""" +Repository factory for creating repository instances. + +This module provides a simple factory interface for creating repository instances. +It does not contain any implementation details or dependencies on specific repositories. +""" + +from collections.abc import Callable, Mapping +from typing import Any, Literal, Optional, cast + +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository + +# Type for factory functions - takes a dict of parameters and returns any repository type +RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any] + +# Type for workflow node execution factory function +WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository] + +# Repository type literals +RepositoryType = Literal["workflow_node_execution"] + + +class RepositoryFactory: + """ + Factory class for creating repository instances. + + This factory delegates the actual repository creation to implementation-specific + factory functions that are registered with the factory at runtime. 
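For orientation, this is how the register/create pair defined in this class is meant to be used. A minimal sketch only: `_StubRepository` is a hypothetical stand-in, not part of this PR; the real implementation is registered by `ext_repositories` further down.

```python
from collections.abc import Mapping
from typing import Any

from core.repository import RepositoryFactory


class _StubRepository:
    """Hypothetical stand-in for a concrete WorkflowNodeExecutionRepository."""

    def __init__(self, tenant_id: str) -> None:
        self.tenant_id = tenant_id


def _make_repository(params: Mapping[str, Any]) -> Any:
    # tenant_id is the only parameter this stub cares about.
    return _StubRepository(tenant_id=params["tenant_id"])


# Registered once at startup (ext_repositories does this for the real implementation):
RepositoryFactory.register_workflow_node_execution_factory(_make_repository)

# Created wherever a repository is needed:
repo = RepositoryFactory.create_workflow_node_execution_repository(params={"tenant_id": "tenant-1"})
```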
+ """ + + # Dictionary to store factory functions + _factory_functions: dict[str, RepositoryFactoryFunc] = {} + + @classmethod + def _register_factory(cls, repository_type: RepositoryType, factory_func: RepositoryFactoryFunc) -> None: + """ + Register a factory function for a specific repository type. + This is a private method and should not be called directly. + + Args: + repository_type: The type of repository (e.g., 'workflow_node_execution') + factory_func: A function that takes parameters and returns a repository instance + """ + cls._factory_functions[repository_type] = factory_func + + @classmethod + def _create_repository(cls, repository_type: RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any: + """ + Create a new repository instance with the provided parameters. + This is a private method and should not be called directly. + + Args: + repository_type: The type of repository to create + params: A dictionary of parameters to pass to the factory function + + Returns: + A new instance of the requested repository + + Raises: + ValueError: If no factory function is registered for the repository type + """ + if repository_type not in cls._factory_functions: + raise ValueError(f"No factory function registered for repository type '{repository_type}'") + + # Use empty dict if params is None + params = params or {} + + return cls._factory_functions[repository_type](params) + + @classmethod + def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None: + """ + Register a factory function for the workflow node execution repository. + + Args: + factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance + """ + cls._register_factory("workflow_node_execution", factory_func) + + @classmethod + def create_workflow_node_execution_repository( + cls, params: Optional[Mapping[str, Any]] = None + ) -> WorkflowNodeExecutionRepository: + """ + Create a new WorkflowNodeExecutionRepository instance with the provided parameters. + + Args: + params: A dictionary of parameters to pass to the factory function + + Returns: + A new instance of the WorkflowNodeExecutionRepository + + Raises: + ValueError: If no factory function is registered for the workflow_node_execution repository type + """ + # We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc + return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params)) diff --git a/api/core/repository/workflow_node_execution_repository.py b/api/core/repository/workflow_node_execution_repository.py new file mode 100644 index 0000000000..6dea4566de --- /dev/null +++ b/api/core/repository/workflow_node_execution_repository.py @@ -0,0 +1,88 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, Optional, Protocol + +from models.workflow import WorkflowNodeExecution + + +@dataclass +class OrderConfig: + """Configuration for ordering WorkflowNodeExecution instances.""" + + order_by: list[str] + order_direction: Optional[Literal["asc", "desc"]] = None + + +class WorkflowNodeExecutionRepository(Protocol): + """ + Repository interface for WorkflowNodeExecution. + + This interface defines the contract for accessing and manipulating + WorkflowNodeExecution data, regardless of the underlying storage mechanism. 
+ + Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), + and trigger sources (triggered_from) should be handled at the implementation level, not in + the core interface. This keeps the core domain model clean and independent of specific + application domains or deployment scenarios. + """ + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save a WorkflowNodeExecution instance. + + Args: + execution: The WorkflowNodeExecution instance to save + """ + ... + + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + """ + Retrieve a WorkflowNodeExecution by its node_execution_id. + + Args: + node_execution_id: The node execution ID + + Returns: + The WorkflowNodeExecution instance if found, None otherwise + """ + ... + + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") + + Returns: + A list of WorkflowNodeExecution instances + """ + ... + + def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + + Returns: + A list of running WorkflowNodeExecution instances + """ + ... + + def update(self, execution: WorkflowNodeExecution) -> None: + """ + Update an existing WorkflowNodeExecution instance. + + Args: + execution: The WorkflowNodeExecution instance to update + """ + ... diff --git a/api/extensions/ext_repositories.py b/api/extensions/ext_repositories.py new file mode 100644 index 0000000000..27d8408ec1 --- /dev/null +++ b/api/extensions/ext_repositories.py @@ -0,0 +1,18 @@ +""" +Extension for initializing repositories. + +This extension registers repository implementations with the RepositoryFactory. +""" + +from dify_app import DifyApp +from repositories.repository_registry import register_repositories + + +def init_app(_app: DifyApp) -> None: + """ + Initialize repository implementations. + + Args: + _app: The Flask application instance (unused) + """ + register_repositories() diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 588bdb2d27..4c811c66ba 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -73,11 +73,7 @@ class Storage: raise ValueError(f"unsupported storage type {storage_type}") def save(self, filename, data): - try: - self.storage_runner.save(filename, data) - except Exception as e: - logger.exception(f"Failed to save file {filename}") - raise e + self.storage_runner.save(filename, data) @overload def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... @@ -86,49 +82,25 @@ class Storage: def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... 
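With the try/except wrappers stripped from `Storage` here, failures now propagate to callers unchanged. A caller that still wants the old log lines can reproduce them at the call site; a sketch using the module-level `storage` singleton:

```python
import logging

from extensions.ext_storage import storage

logger = logging.getLogger(__name__)


def save_with_logging(filename: str, data: bytes) -> None:
    # Reproduces the log-and-reraise behavior Storage.save() used to provide internally.
    try:
        storage.save(filename, data)
    except Exception:
        logger.exception(f"Failed to save file {filename}")
        raise
```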
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: - try: - if stream: - return self.load_stream(filename) - else: - return self.load_once(filename) - except Exception as e: - logger.exception(f"Failed to load file {filename}") - raise e + if stream: + return self.load_stream(filename) + else: + return self.load_once(filename) def load_once(self, filename: str) -> bytes: - try: - return self.storage_runner.load_once(filename) - except Exception as e: - logger.exception(f"Failed to load_once file {filename}") - raise e + return self.storage_runner.load_once(filename) def load_stream(self, filename: str) -> Generator: - try: - return self.storage_runner.load_stream(filename) - except Exception as e: - logger.exception(f"Failed to load_stream file {filename}") - raise e + return self.storage_runner.load_stream(filename) def download(self, filename, target_filepath): - try: - self.storage_runner.download(filename, target_filepath) - except Exception as e: - logger.exception(f"Failed to download file {filename}") - raise e + self.storage_runner.download(filename, target_filepath) def exists(self, filename): - try: - return self.storage_runner.exists(filename) - except Exception as e: - logger.exception(f"Failed to check file exists {filename}") - raise e + return self.storage_runner.exists(filename) def delete(self, filename): - try: - return self.storage_runner.delete(filename) - except Exception as e: - logger.exception(f"Failed to delete file {filename}") - raise e + return self.storage_runner.delete(filename) storage = Storage() diff --git a/api/models/workflow.py b/api/models/workflow.py index 8b7c376e4b..045fa0aaa0 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -510,7 +510,7 @@ class WorkflowRun(Base): ) -class WorkflowNodeExecutionTriggeredFrom(Enum): +class WorkflowNodeExecutionTriggeredFrom(StrEnum): """ Workflow Node Execution Triggered From Enum """ @@ -518,21 +518,8 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): SINGLE_STEP = "single-step" WORKFLOW_RUN = "workflow-run" - @classmethod - def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom": - """ - Get value of given mode. - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid workflow node execution triggered from value {value}") - - -class WorkflowNodeExecutionStatus(Enum): +class WorkflowNodeExecutionStatus(StrEnum): """ Workflow Node Execution Status Enum """ @@ -543,19 +530,6 @@ class WorkflowNodeExecutionStatus(Enum): EXCEPTION = "exception" RETRY = "retry" - @classmethod - def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid workflow node execution status value {value}") - class WorkflowNodeExecution(Base): """ diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py new file mode 100644 index 0000000000..4cc339688b --- /dev/null +++ b/api/repositories/__init__.py @@ -0,0 +1,6 @@ +""" +Repository implementations for data access. + +This package contains concrete implementations of the repository interfaces +defined in the core.repository package. 
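The `StrEnum` migration in `models/workflow.py` above is what makes the `value_of()` helpers removable: the standard `Enum` constructor already performs lookup by value (raising `ValueError` on unknown input), and members compare equal to plain strings. A quick illustration, using an abridged mirror of the model enum (requires Python 3.11+):

```python
from enum import StrEnum


class WorkflowNodeExecutionStatus(StrEnum):  # abridged mirror of the model definition above
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    EXCEPTION = "exception"
    RETRY = "retry"


# The plain Enum constructor replaces value_of(), including the ValueError on bad input:
assert WorkflowNodeExecutionStatus("running") is WorkflowNodeExecutionStatus.RUNNING

# StrEnum members compare equal to their string values, so `.value` becomes optional in filters:
assert WorkflowNodeExecutionStatus.RUNNING == "running"
```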
+""" diff --git a/api/repositories/repository_registry.py b/api/repositories/repository_registry.py new file mode 100644 index 0000000000..aa0a208d8e --- /dev/null +++ b/api/repositories/repository_registry.py @@ -0,0 +1,87 @@ +""" +Registry for repository implementations. + +This module is responsible for registering factory functions with the repository factory. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repository.repository_factory import RepositoryFactory +from extensions.ext_database import db +from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository + +logger = logging.getLogger(__name__) + +# Storage type constants +STORAGE_TYPE_RDBMS = "rdbms" +STORAGE_TYPE_HYBRID = "hybrid" + + +def register_repositories() -> None: + """ + Register repository factory functions with the RepositoryFactory. + + This function reads configuration settings to determine which repository + implementations to register. + """ + # Configure WorkflowNodeExecutionRepository factory based on configuration + workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE + + # Check storage type and register appropriate implementation + if workflow_node_execution_storage == STORAGE_TYPE_RDBMS: + # Register SQLAlchemy implementation for RDBMS storage + logger.info("Registering WorkflowNodeExecution repository with RDBMS storage") + RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository) + elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID: + # Hybrid storage is not yet implemented + raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented") + else: + # Unknown storage type + raise ValueError( + f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. " + f"Supported types: {STORAGE_TYPE_RDBMS}" + ) + + +def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository: + """ + Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation. + + This factory function creates a repository for the RDBMS storage type. + + Args: + params: Parameters for creating the repository, including: + - tenant_id: Required. The tenant ID for multi-tenancy. + - app_id: Optional. The application ID for filtering. + - session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided, + a new sessionmaker will be created using the global database engine. 
+ + Returns: + A WorkflowNodeExecutionRepository instance + + Raises: + ValueError: If required parameters are missing + """ + # Extract required parameters + tenant_id = params.get("tenant_id") + if tenant_id is None: + raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage") + + # Extract optional parameters + app_id = params.get("app_id") + + # Use the session_factory from params if provided, otherwise create one using the global db engine + session_factory = params.get("session_factory") + if session_factory is None: + # Create a sessionmaker using the same engine as the global db session + session_factory = sessionmaker(bind=db.engine) + + # Create and return the repository + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, tenant_id=tenant_id, app_id=app_id + ) diff --git a/api/repositories/workflow_node_execution/__init__.py b/api/repositories/workflow_node_execution/__init__.py new file mode 100644 index 0000000000..eed827bd05 --- /dev/null +++ b/api/repositories/workflow_node_execution/__init__.py @@ -0,0 +1,9 @@ +""" +WorkflowNodeExecution repository implementations. +""" + +from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository + +__all__ = [ + "SQLAlchemyWorkflowNodeExecutionRepository", +] diff --git a/api/repositories/workflow_node_execution/sqlalchemy_repository.py b/api/repositories/workflow_node_execution/sqlalchemy_repository.py new file mode 100644 index 0000000000..01c54dfcd7 --- /dev/null +++ b/api/repositories/workflow_node_execution/sqlalchemy_repository.py @@ -0,0 +1,170 @@ +""" +SQLAlchemy implementation of the WorkflowNodeExecutionRepository. +""" + +import logging +from collections.abc import Sequence +from typing import Optional + +from sqlalchemy import UnaryExpression, asc, desc, select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.repository.workflow_node_execution_repository import OrderConfig +from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +class SQLAlchemyWorkflowNodeExecutionRepository: + """ + SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface. + + This implementation supports multi-tenancy by filtering operations based on tenant_id. + Each method creates its own session, handles the transaction, and commits changes + to the database. This prevents long-running connections in the workflow core. + """ + + def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None): + """ + Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for creating sessions + tenant_id: Tenant ID for multi-tenancy + app_id: Optional app ID for filtering by application + """ + # If an engine is provided, create a sessionmaker from it + if isinstance(session_factory, Engine): + self._session_factory = sessionmaker(bind=session_factory) + else: + self._session_factory = session_factory + + self._tenant_id = tenant_id + self._app_id = app_id + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save a WorkflowNodeExecution instance and commit changes to the database. 
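The constructor's `sessionmaker | Engine` union keeps call sites flexible; both forms below yield an equivalent repository. A sketch with a placeholder engine:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository

engine = create_engine("sqlite://")  # placeholder; real code binds to the global db engine

# Passing an Engine: the repository builds its own sessionmaker internally.
repo_a = SQLAlchemyWorkflowNodeExecutionRepository(session_factory=engine, tenant_id="tenant-1")

# Passing a sessionmaker: used as-is. WorkflowCycleManage passes one with
# expire_on_commit=False so returned objects stay readable after each commit.
repo_b = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=sessionmaker(bind=engine, expire_on_commit=False),
    tenant_id="tenant-1",
    app_id="app-1",
)
```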
+ + Args: + execution: The WorkflowNodeExecution instance to save + """ + with self._session_factory() as session: + # Ensure tenant_id is set + if not execution.tenant_id: + execution.tenant_id = self._tenant_id + + # Set app_id if provided and not already set + if self._app_id and not execution.app_id: + execution.app_id = self._app_id + + session.add(execution) + session.commit() + + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + """ + Retrieve a WorkflowNodeExecution by its node_execution_id. + + Args: + node_execution_id: The node execution ID + + Returns: + The WorkflowNodeExecution instance if found, None otherwise + """ + with self._session_factory() as session: + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.node_execution_id == node_execution_id, + WorkflowNodeExecution.tenant_id == self._tenant_id, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + return session.scalar(stmt) + + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") + + Returns: + A list of WorkflowNodeExecution instances + """ + with self._session_factory() as session: + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.workflow_run_id == workflow_run_id, + WorkflowNodeExecution.tenant_id == self._tenant_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + # Apply ordering if provided + if order_config and order_config.order_by: + order_columns: list[UnaryExpression] = [] + for field in order_config.order_by: + column = getattr(WorkflowNodeExecution, field, None) + if not column: + continue + if order_config.order_direction == "desc": + order_columns.append(desc(column)) + else: + order_columns.append(asc(column)) + + if order_columns: + stmt = stmt.order_by(*order_columns) + + return session.scalars(stmt).all() + + def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + + Returns: + A list of running WorkflowNodeExecution instances + """ + with self._session_factory() as session: + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.workflow_run_id == workflow_run_id, + WorkflowNodeExecution.tenant_id == self._tenant_id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + return session.scalars(stmt).all() + + def update(self, execution: WorkflowNodeExecution) -> None: + """ + Update an existing WorkflowNodeExecution instance and commit changes to the database. 
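For reference, this is how a caller combines the query methods above with `OrderConfig`, plus a mutate-then-update round trip. A sketch assuming `repo` is constructed as in the previous example; IDs and the mutation are placeholders:

```python
from core.repository.workflow_node_execution_repository import OrderConfig

# Node executions for a run, in execution order (ties broken by creation time):
ordered = repo.get_by_workflow_run(
    workflow_run_id="run-1",
    order_config=OrderConfig(order_by=["index", "created_at"], order_direction="asc"),
)

# Executions still marked RUNNING, e.g. when finalizing a failed run:
running = repo.get_running_executions(workflow_run_id="run-1")

# Mutate-then-update round trip; update() merges and commits in its own session:
execution = repo.get_by_node_execution_id("node-exec-1")
if execution:
    execution.error = "timed out"  # placeholder mutation
    repo.update(execution)
```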
+ + Args: + execution: The WorkflowNodeExecution instance to update + """ + with self._session_factory() as session: + # Ensure tenant_id is set + if not execution.tenant_id: + execution.tenant_id = self._tenant_id + + # Set app_id if provided and not already set + if self._app_id and not execution.app_id: + execution.app_id = self._app_id + + session.merge(execution) + session.commit() diff --git a/api/tests/unit_tests/repositories/__init__.py b/api/tests/unit_tests/repositories/__init__.py new file mode 100644 index 0000000000..bc0d6e78c9 --- /dev/null +++ b/api/tests/unit_tests/repositories/__init__.py @@ -0,0 +1,3 @@ +""" +Unit tests for repositories. +""" diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py b/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py new file mode 100644 index 0000000000..78815a8d1a --- /dev/null +++ b/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py @@ -0,0 +1,3 @@ +""" +Unit tests for workflow_node_execution repositories. +""" diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py new file mode 100644 index 0000000000..f31adab2a8 --- /dev/null +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -0,0 +1,154 @@ +""" +Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository. +""" + +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from sqlalchemy.orm import Session, sessionmaker + +from core.repository.workflow_node_execution_repository import OrderConfig +from models.workflow import WorkflowNodeExecution +from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository + + +@pytest.fixture +def session(): + """Create a mock SQLAlchemy session.""" + session = MagicMock(spec=Session) + # Configure the session to be used as a context manager + session.__enter__ = MagicMock(return_value=session) + session.__exit__ = MagicMock(return_value=None) + + # Configure the session factory to return the session + session_factory = MagicMock(spec=sessionmaker) + session_factory.return_value = session + return session, session_factory + + +@pytest.fixture +def repository(session): + """Create a repository instance with test data.""" + _, session_factory = session + tenant_id = "test-tenant" + app_id = "test-app" + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, tenant_id=tenant_id, app_id=app_id + ) + + +def test_save(repository, session): + """Test save method.""" + session_obj, _ = session + # Create a mock execution + execution = MagicMock(spec=WorkflowNodeExecution) + execution.tenant_id = None + execution.app_id = None + + # Call save method + repository.save(execution) + + # Assert tenant_id and app_id are set + assert execution.tenant_id == repository._tenant_id + assert execution.app_id == repository._app_id + + # Assert session.add was called + session_obj.add.assert_called_once_with(execution) + + +def test_save_with_existing_tenant_id(repository, session): + """Test save method with existing tenant_id.""" + session_obj, _ = session + # Create a mock execution with existing tenant_id + execution = MagicMock(spec=WorkflowNodeExecution) + execution.tenant_id = "existing-tenant" + execution.app_id = None + + # Call save method + repository.save(execution) + + # Assert 
tenant_id is not changed and app_id is set + assert execution.tenant_id == "existing-tenant" + assert execution.app_id == repository._app_id + + # Assert session.add was called + session_obj.add.assert_called_once_with(execution) + + +def test_get_by_node_execution_id(repository, session, mocker: MockerFixture): + """Test get_by_node_execution_id method.""" + session_obj, _ = session + # Set up mock + mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select") + mock_stmt = mocker.MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution) + + # Call method + result = repository.get_by_node_execution_id("test-node-execution-id") + + # Assert select was called with correct parameters + mock_select.assert_called_once() + session_obj.scalar.assert_called_once_with(mock_stmt) + assert result is not None + + +def test_get_by_workflow_run(repository, session, mocker: MockerFixture): + """Test get_by_workflow_run method.""" + session_obj, _ = session + # Set up mock + mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select") + mock_stmt = mocker.MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)] + + # Call method + order_config = OrderConfig(order_by=["index"], order_direction="desc") + result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config) + + # Assert select was called with correct parameters + mock_select.assert_called_once() + session_obj.scalars.assert_called_once_with(mock_stmt) + assert len(result) == 1 + + +def test_get_running_executions(repository, session, mocker: MockerFixture): + """Test get_running_executions method.""" + session_obj, _ = session + # Set up mock + mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select") + mock_stmt = mocker.MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)] + + # Call method + result = repository.get_running_executions("test-workflow-run-id") + + # Assert select was called with correct parameters + mock_select.assert_called_once() + session_obj.scalars.assert_called_once_with(mock_stmt) + assert len(result) == 1 + + +def test_update(repository, session): + """Test update method.""" + session_obj, _ = session + # Create a mock execution + execution = MagicMock(spec=WorkflowNodeExecution) + execution.tenant_id = None + execution.app_id = None + + # Call update method + repository.update(execution) + + # Assert tenant_id and app_id are set + assert execution.tenant_id == repository._tenant_id + assert execution.app_id == repository._app_id + + # Assert session.merge was called + session_obj.merge.assert_called_once_with(execution) diff --git a/docker/.env.example b/docker/.env.example index 9b372dcec9..82ef4174c2 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -744,6 +744,12 @@ MAX_VARIABLE_SIZE=204800 WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_FILE_UPLOAD_LIMIT=10 +# Workflow storage configuration +# Options: rdbms, hybrid +# rdbms: Use only the relational database (default) +# hybrid: Save new data to object storage, read from both 
object storage and RDBMS +WORKFLOW_NODE_EXECUTION_STORAGE=rdbms + # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 172cbe2d2f..e01b9f7e9a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -327,6 +327,7 @@ x-shared-env: &shared-api-worker-env MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800} WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3} WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} + WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
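End to end, the pieces added in this PR compose as follows at runtime. A sketch, not application code: IDs are placeholders, and the `hybrid` setting intentionally raises `NotImplementedError` for now.

```python
from sqlalchemy.orm import sessionmaker

from core.repository import RepositoryFactory
from extensions.ext_database import db

# 1. .env / docker env selects the backend:  WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# 2. At startup, ext_repositories.init_app() runs register_repositories(),
#    which wires the SQLAlchemy factory for the "rdbms" setting.
# 3. Call sites such as WorkflowCycleManage then resolve a repository:
repository = RepositoryFactory.create_workflow_node_execution_repository(
    params={
        "tenant_id": "tenant-1",  # placeholder
        "app_id": "app-1",  # placeholder
        "session_factory": sessionmaker(bind=db.engine, expire_on_commit=False),
    }
)

execution = repository.get_by_node_execution_id("node-exec-1")  # placeholder ID
if execution is not None:
    repository.update(execution)
```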