From 4977bb21ec7701a6bdfdd73d015c73b4eccfce46 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 17 May 2025 00:56:16 +0800 Subject: [PATCH] feat(workflow): domain model for workflow node execution (#19430) Signed-off-by: -LAN- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/console/app/workflow_run.py | 13 +- .../app/apps/advanced_chat/app_generator.py | 13 +- .../advanced_chat/generate_task_pipeline.py | 10 +- .../app/apps/message_based_app_generator.py | 4 +- api/core/app/apps/workflow/app_generator.py | 14 +- api/core/app/entities/task_entities.py | 20 +- .../entities/langfuse_trace_entity.py | 5 +- api/core/ops/langfuse_trace/langfuse_trace.py | 48 ++- .../entities/langsmith_trace_entity.py | 5 +- .../ops/langsmith_trace/langsmith_trace.py | 56 ++-- api/core/ops/opik_trace/opik_trace.py | 54 ++-- .../entities/weave_trace_entity.py | 5 +- api/core/ops/weave_trace/weave_trace.py | 89 ++--- api/core/rag/extractor/word_extractor.py | 4 +- ...hemy_workflow_node_execution_repository.py | 284 +++++++++++++--- api/core/tools/tool_engine.py | 6 +- .../entities/node_execution_entities.py | 98 ++++++ .../workflow_node_execution_repository.py | 47 ++- .../workflow_app_generate_task_pipeline.py | 6 +- api/core/workflow/workflow_cycle_manager.py | 302 +++++++++-------- api/models/__init__.py | 4 +- api/models/enums.py | 2 +- api/models/model.py | 8 +- api/models/workflow.py | 26 +- api/services/file_service.py | 6 +- api/services/workflow_app_service.py | 4 +- api/services/workflow_run_service.py | 26 +- api/services/workflow_service.py | 10 +- api/tasks/remove_app_and_related_data_task.py | 26 +- .../workflow/test_workflow_cycle_manager.py | 92 +++--- .../test_sqlalchemy_repository.py | 304 ++++++++++++++++-- 31 files changed, 1108 insertions(+), 483 deletions(-) create mode 100644 api/core/workflow/entities/node_execution_entities.py diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 08ab61bbb9..9099700213 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,3 +1,6 @@ +from typing import cast + +from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range @@ -12,8 +15,7 @@ from fields.workflow_run_fields import ( ) from libs.helper import uuid_value from libs.login import login_required -from models import App -from models.model import AppMode +from models import Account, App, AppMode, EndUser from services.workflow_run_service import WorkflowRunService @@ -90,7 +92,12 @@ class WorkflowRunNodeExecutionListApi(Resource): run_id = str(run_id) workflow_run_service = WorkflowRunService() - node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) + user = cast("Account | EndUser", current_user) + node_executions = workflow_run_service.get_workflow_run_node_executions( + app_model=app_model, + run_id=run_id, + user=user, + ) return {"data": node_executions} diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 4b0e64130b..b74100bb19 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,9 +29,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from 
extensions.ext_database import db from factories import file_factory -from models.account import Account -from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError @@ -165,8 +163,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) return self._generate( @@ -231,8 +230,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( @@ -295,8 +295,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f71c49d112..735b2a9709 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -70,7 +70,7 @@ from events.message_event import message_was_created from extensions.ext_database import db from models import Conversation, EndUser, Message, MessageFile from models.account import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.workflow import ( Workflow, WorkflowRunStatus, @@ -105,11 +105,11 @@ class AdvancedChatAppGenerateTaskPipeline: if isinstance(user, EndUser): self._user_id = user.id user_session_id = user.session_id - self._created_by_role = CreatedByRole.END_USER + self._created_by_role = CreatorUserRole.END_USER elif isinstance(user, Account): self._user_id = user.id user_session_id = user.id - self._created_by_role = CreatedByRole.ACCOUNT + self._created_by_role = CreatorUserRole.ACCOUNT else: raise NotImplementedError(f"User type not supported: {type(user)}") @@ -739,9 +739,9 @@ class AdvancedChatAppGenerateTaskPipeline: url=file["remote_url"], belongs_to="assistant", upload_file_id=file["related_id"], - created_by_role=CreatedByRole.ACCOUNT + created_by_role=CreatorUserRole.ACCOUNT if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatedByRole.END_USER, + else CreatorUserRole.END_USER, created_by=message.from_account_id or message.from_end_user_id or "", ) for file in self._recorded_files diff --git 
a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 995082b79d..58b94f4d43 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -25,7 +25,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError @@ -223,7 +223,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): belongs_to="user", url=file.remote_url, upload_file_id=file.related_id, - created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER), + created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), created_by=account_id or end_user_id or "", ) db.session.add(message_file) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 1d67671974..d49ff682b9 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -27,7 +27,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline from extensions.ext_database import db from factories import file_factory -from models import Account, App, EndUser, Workflow +from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -138,10 +138,12 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) return self._generate( @@ -262,10 +264,12 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( @@ -325,10 +329,12 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 817699bd20..0c2d617f80 
100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import AgentNodeStrategyInit +from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey from models.workflow import WorkflowNodeExecutionStatus @@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None created_at: int extras: dict = {} parallel_id: Optional[str] = None @@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None status: str error: Optional[str] = None elapsed_time: float - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None created_at: int finished_at: int files: Optional[Sequence[Mapping[str, Any]]] = [] @@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None status: str error: Optional[str] = None elapsed_time: float - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None created_at: int finished_at: int files: Optional[Sequence[Mapping[str, Any]]] = [] diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index f486da3a6d..46ba1c45b9 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel): description="The status message of the span. Additional field for context of the event. E.g. the error " "message of an error event.", ) - input: Optional[Union[str, dict[str, Any], list, None]] = Field( + input: Optional[Union[str, Mapping[str, Any], list, None]] = Field( default=None, description="The input of the span. Can be any JSON object." ) - output: Optional[Union[str, dict[str, Any], list, None]] = Field( + output: Optional[Union[str, Mapping[str, Any], list, None]] = Field( default=None, description="The output of the span. Can be any JSON object." 
) version: Optional[str] = Field( diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index c74617e558..120c36f53d 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -1,11 +1,10 @@ -import json import logging import os from datetime import datetime, timedelta from typing import Optional from langfuse import Langfuse # type: ignore -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangfuseConfig @@ -30,8 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser +from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -113,8 +113,29 @@ class LangFuseDataTrace(BaseTraceInstance): # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + with Session(db.engine, expire_on_commit=False) as session: + # Get the app to find its creator + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + app = session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).filter(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=trace_info.tenant_id + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) # Get all executions for this workflow run @@ -124,23 +145,22 @@ class LangFuseDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - metadata = 
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} + execution_metadata = node_execution.metadata if node_execution.metadata else {} + metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -152,7 +172,7 @@ class LangFuseDataTrace(BaseTraceInstance): "status": status, } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} model_provider = process_data.get("model_provider", None) model_name = process_data.get("model_name", None) if model_provider is not None and model_name is not None: diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 348b7ba501..4fd01136ba 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel): class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): name: Optional[str] = Field(..., description="Name of the run") - inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run") - outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run") + inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run") + outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run") run_type: LangSmithRunType = Field(..., description="Type of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index d1e16d3152..6631727c79 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -1,4 +1,3 @@ -import json import logging import os import uuid @@ -7,7 +6,7 @@ from typing import Optional, cast from langsmith import Client from langsmith.schemas import RunBase -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangSmithConfig @@ -29,8 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser, MessageFile +from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -137,8 +138,29 @@ class LangSmithDataTrace(BaseTraceInstance): # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + with Session(db.engine, expire_on_commit=False) as session: 
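# --- Reviewer annotation (not part of the patch) --------------------------------------------
# This creator-account lookup is repeated across the Langfuse, LangSmith, Opik, and Weave
# integrations in this patch. Condensed sketch of the shared pattern, assuming
# trace_info.metadata carries "app_id"; the resolved account becomes the `user` passed to
# SQLAlchemyWorkflowNodeExecutionRepository in each integration:
#
#     app = session.query(App).filter(App.id == app_id).first()
#     service_account = session.query(Account).filter(Account.id == app.created_by).first()
#     repo = SQLAlchemyWorkflowNodeExecutionRepository(
#         session_factory=sessionmaker(bind=db.engine),
#         user=service_account,
#         app_id=app_id,
#         triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
#     )
# ---------------------------------------------------------------------------------------------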
+ # Get the app to find its creator + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + app = session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).filter(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) # Get all executions for this workflow run @@ -148,27 +170,23 @@ class LangSmithDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} - ) - node_total_tokens = execution_metadata.get("total_tokens", 0) - metadata = execution_metadata.copy() + execution_metadata = node_execution.metadata if node_execution.metadata else {} + node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 + metadata = {str(key): value for key, value in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -181,7 +199,7 @@ class LangSmithDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": run_type = LangSmithRunType.llm @@ -191,7 +209,7 @@ class LangSmithDataTrace(BaseTraceInstance): "ls_model_name": process_data.get("model_name", ""), } ) - elif node_type == "knowledge-retrieval": + elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: run_type = LangSmithRunType.retriever else: run_type = LangSmithRunType.tool diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 1484041447..c22df55357 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ 
b/api/core/ops/opik_trace/opik_trace.py @@ -1,4 +1,3 @@ -import json import logging import os import uuid @@ -7,7 +6,7 @@ from typing import Optional, cast from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import OpikConfig @@ -23,8 +22,10 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser, MessageFile +from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -150,8 +151,29 @@ class OpikDataTrace(BaseTraceInstance): # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + with Session(db.engine, expire_on_commit=False) as session: + # Get the app to find its creator + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + app = session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).filter(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) # Get all executions for this workflow run @@ -161,26 +183,22 @@ class OpikDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} - ) - metadata = execution_metadata.copy() + 
execution_metadata = node_execution.metadata if node_execution.metadata else {} + metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -193,7 +211,7 @@ class OpikDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} provider = None model = None @@ -226,7 +244,7 @@ class OpikDataTrace(BaseTraceInstance): parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id if not total_tokens: - total_tokens = execution_metadata.get("total_tokens", 0) + total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 span_data = { "trace_id": opik_trace_id, diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index e423f5ccbb..7f489f37ac 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel): class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): id: str = Field(..., description="ID of the trace") op: str = Field(..., description="Name of the operation") - inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace") - outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace") + inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace") + outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace") attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( None, description="Metadata and attributes associated with trace" ) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 49594cb0f1..a4f38dfbba 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -1,4 +1,3 @@ -import json import logging import os import uuid @@ -7,6 +6,7 @@ from typing import Any, Optional, cast import wandb import weave +from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import WeaveConfig @@ -22,9 +22,11 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser, MessageFile -from models.workflow import WorkflowNodeExecution +from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -128,58 +130,57 @@ class WeaveDataTrace(BaseTraceInstance): self.start_call(workflow_run, parent_run_id=trace_info.message_id) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) 
- .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + with Session(db.engine, expire_on_commit=False) as session: + # Get the app to find its creator + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + app = session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).filter(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} - ) - node_total_tokens = execution_metadata.get("total_tokens", 0) - attributes = execution_metadata.copy() + execution_metadata = node_execution.metadata if node_execution.metadata else {} + node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 + attributes = {str(k): v for k, v in execution_metadata.items()} attributes.update( { "workflow_run_id": 
trace_info.workflow_run_id, @@ -192,7 +193,7 @@ class WeaveDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": attributes.update( { diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index a4ccdcafd3..bff0acc48f 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -19,7 +19,7 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import UploadFile logger = logging.getLogger(__name__) @@ -116,7 +116,7 @@ class WordExtractor(BaseExtractor): extension=str(image_ext), mime_type=mime_type or "", created_by=self.user_id, - created_by_role=CreatedByRole.ACCOUNT, + created_by_role=CreatorUserRole.ACCOUNT, created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=self.user_id, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 8bf2ab8761..2dc53479dd 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -2,16 +2,29 @@ SQLAlchemy implementation of the WorkflowNodeExecutionRepository. """ +import json import logging from collections.abc import Sequence -from typing import Optional +from typing import Optional, Union from sqlalchemy import UnaryExpression, asc, delete, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.workflow.entities.node_execution_entities import ( + NodeExecution, + NodeExecutionStatus, +) +from core.workflow.nodes.enums import NodeType from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom +from models import ( + Account, + CreatorUserRole, + EndUser, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, +) logger = logging.getLogger(__name__) @@ -23,16 +36,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) This implementation supports multi-tenancy by filtering operations based on tenant_id. Each method creates its own session, handles the transaction, and commits changes to the database. This prevents long-running connections in the workflow core. + + This implementation also includes an in-memory cache for node executions to improve + performance by reducing database queries. """ - def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None): + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: Optional[str], + triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + ): """ - Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context. + Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. 
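# Annotation (not patch content): minimal sketch of constructing the repository under the new
# signature, mirroring the app_generator hunks earlier in this patch; `user` and `app_config`
# are placeholders for the caller's Account/EndUser and app configuration:
from sqlalchemy.orm import sessionmaker

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models import WorkflowNodeExecutionTriggeredFrom

session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
repository = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=session_factory,
    user=user,  # Account or EndUser; tenant_id, created_by, and created_by_role are derived from it
    app_id=app_config.app_id,
    triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)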
Args: session_factory: SQLAlchemy sessionmaker or engine for creating sessions - tenant_id: Tenant ID for multi-tenancy - app_id: Optional app ID for filtering by application + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) """ # If an engine is provided, create a sessionmaker from it if isinstance(session_factory, Engine): @@ -44,38 +67,155 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" ) + # Extract tenant_id from user + tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") self._tenant_id = tenant_id + + # Store app context self._app_id = app_id - def save(self, execution: WorkflowNodeExecution) -> None: + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Initialize in-memory cache for node executions + # Key: node_execution_id, Value: NodeExecution + self._node_execution_cache: dict[str, NodeExecution] = {} + + def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution: """ - Save a WorkflowNodeExecution instance and commit changes to the database. + Convert a database model to a domain model. Args: - execution: The WorkflowNodeExecution instance to save + db_model: The database model to convert + + Returns: + The domain model + """ + # Parse JSON fields + inputs = db_model.inputs_dict + process_data = db_model.process_data_dict + outputs = db_model.outputs_dict + metadata = db_model.execution_metadata_dict + + # Convert status to domain enum + status = NodeExecutionStatus(db_model.status) + + return NodeExecution( + id=db_model.id, + node_execution_id=db_model.node_execution_id, + workflow_id=db_model.workflow_id, + workflow_run_id=db_model.workflow_run_id, + index=db_model.index, + predecessor_node_id=db_model.predecessor_node_id, + node_id=db_model.node_id, + node_type=NodeType(db_model.node_type), + title=db_model.title, + inputs=inputs, + process_data=process_data, + outputs=outputs, + status=status, + error=db_model.error, + elapsed_time=db_model.elapsed_time, + metadata=metadata, + created_at=db_model.created_at, + finished_at=db_model.finished_at, + ) + + def _to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution: + """ + Convert a domain model to a database model. 
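# Annotation (not patch content): the converter pair assumes JSON text columns on the ORM side
# and plain mappings on the domain side. Minimal round-trip illustrating that assumption:
import json

inputs = {"query": "hello"}
stored = json.dumps(inputs)    # what _to_db_model writes into WorkflowNodeExecution.inputs
restored = json.loads(stored)  # what inputs_dict is expected to return for _to_domain_model
assert restored == inputs
# tenant_id, app_id, created_by, created_by_role, and triggered_from exist only on the database
# model and are filled from the repository's constructor context, never from NodeExecution.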
+ + Args: + domain_model: The domain model to convert + + Returns: + The database model + """ + # Use values from constructor if provided + if not self._triggered_from: + raise ValueError("triggered_from is required in repository constructor") + if not self._creator_user_id: + raise ValueError("created_by is required in repository constructor") + if not self._creator_user_role: + raise ValueError("created_by_role is required in repository constructor") + + db_model = WorkflowNodeExecution() + db_model.id = domain_model.id + db_model.tenant_id = self._tenant_id + if self._app_id is not None: + db_model.app_id = self._app_id + db_model.workflow_id = domain_model.workflow_id + db_model.triggered_from = self._triggered_from + db_model.workflow_run_id = domain_model.workflow_run_id + db_model.index = domain_model.index + db_model.predecessor_node_id = domain_model.predecessor_node_id + db_model.node_execution_id = domain_model.node_execution_id + db_model.node_id = domain_model.node_id + db_model.node_type = domain_model.node_type + db_model.title = domain_model.title + db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None + db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None + db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.status = domain_model.status + db_model.error = domain_model.error + db_model.elapsed_time = domain_model.elapsed_time + db_model.execution_metadata = json.dumps(domain_model.metadata) if domain_model.metadata else None + db_model.created_at = domain_model.created_at + db_model.created_by_role = self._creator_user_role + db_model.created_by = self._creator_user_id + db_model.finished_at = domain_model.finished_at + return db_model + + def save(self, execution: NodeExecution) -> None: + """ + Save or update a NodeExecution instance and commit changes to the database. + + This method handles both creating new records and updating existing ones. + It determines whether to create or update based on whether the record + already exists in the database. It also updates the in-memory cache. + + Args: + execution: The NodeExecution instance to save or update """ with self._session_factory() as session: - # Ensure tenant_id is set - if not execution.tenant_id: - execution.tenant_id = self._tenant_id + # Convert domain model to database model using instance attributes + db_model = self._to_db_model(execution) - # Set app_id if provided and not already set - if self._app_id and not execution.app_id: - execution.app_id = self._app_id - - session.add(execution) + # Use merge which will handle both insert and update + session.merge(db_model) session.commit() - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + # Update the cache if node_execution_id is present + if execution.node_execution_id: + logger.debug(f"Updating cache for node_execution_id: {execution.node_execution_id}") + self._node_execution_cache[execution.node_execution_id] = execution + + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]: """ - Retrieve a WorkflowNodeExecution by its node_execution_id. + Retrieve a NodeExecution by its node_execution_id. + + First checks the in-memory cache, and if not found, queries the database. + If found in the database, adds it to the cache for future lookups. 
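# Annotation (not patch content): the lookup described above is a read-through cache. Generic
# sketch of the pattern; the real method additionally scopes the database fallback by
# tenant_id and app_id:
def cached_lookup(cache: dict[str, NodeExecution], key: str, load_from_db):
    if key in cache:             # cache hit: no database session is opened
        return cache[key]
    found = load_from_db(key)    # cache miss: query the database
    if found is not None:
        cache[key] = found       # remember the result for later lookups in this run
    return found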
Args: node_execution_id: The node execution ID Returns: - The WorkflowNodeExecution instance if found, None otherwise + The NodeExecution instance if found, None otherwise """ + # First check the cache + if node_execution_id in self._node_execution_cache: + logger.debug(f"Cache hit for node_execution_id: {node_execution_id}") + return self._node_execution_cache[node_execution_id] + + # If not in cache, query the database + logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database") with self._session_factory() as session: stmt = select(WorkflowNodeExecution).where( WorkflowNodeExecution.node_execution_id == node_execution_id, @@ -85,15 +225,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if self._app_id: stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) - return session.scalar(stmt) + db_model = session.scalar(stmt) + if db_model: + # Convert to domain model + domain_model = self._to_domain_model(db_model) + + # Add to cache + self._node_execution_cache[node_execution_id] = domain_model + + return domain_model + + return None def get_by_workflow_run( self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, - ) -> Sequence[WorkflowNodeExecution]: + ) -> Sequence[NodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run. + Retrieve all NodeExecution instances for a specific workflow run. + + This method always queries the database to ensure complete and ordered results, + but updates the cache with any retrieved executions. Args: workflow_run_id: The workflow run ID @@ -102,7 +255,42 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of WorkflowNodeExecution instances + A list of NodeExecution instances + """ + # Get the raw database models using the new method + db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config) + + # Convert database models to domain models and update cache + domain_models = [] + for model in db_models: + domain_model = self._to_domain_model(model) + # Update cache if node_execution_id is present + if domain_model.node_execution_id: + self._node_execution_cache[domain_model.node_execution_id] = domain_model + domain_models.append(domain_model) + + return domain_models + + def get_db_models_by_workflow_run( + self, + workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all WorkflowNodeExecution database models for a specific workflow run. + + This method is similar to get_by_workflow_run but returns the raw database models + instead of converting them to domain models. This can be useful when direct access + to database model properties is needed. 
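# Annotation (not patch content): sketch of the two retrieval paths side by side; `repository`
# and `run_id` are placeholders, OrderConfig is the dataclass from the repository protocol
# module further down in this patch:
from core.workflow.repository.workflow_node_execution_repository import OrderConfig

order = OrderConfig(order_by=["index"], order_direction="desc")

# Domain objects for core workflow and tracing code:
executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order)

# Raw ORM rows for callers that still need database columns (tenant_id, created_by,
# triggered_from) directly:
rows = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order)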
+ + Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") + + Returns: + A list of WorkflowNodeExecution database models """ with self._session_factory() as session: stmt = select(WorkflowNodeExecution).where( @@ -129,17 +317,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if order_columns: stmt = stmt.order_by(*order_columns) - return session.scalars(stmt).all() + db_models = session.scalars(stmt).all() - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + # Note: We don't update the cache here since we're returning raw DB models + # and not converting to domain models + + return db_models + + def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]: """ - Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + Retrieve all running NodeExecution instances for a specific workflow run. + + This method queries the database directly and updates the cache with any + retrieved executions that have a node_execution_id. Args: workflow_run_id: The workflow run ID Returns: - A list of running WorkflowNodeExecution instances + A list of running NodeExecution instances """ with self._session_factory() as session: stmt = select(WorkflowNodeExecution).where( @@ -152,26 +348,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if self._app_id: stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) - return session.scalars(stmt).all() + db_models = session.scalars(stmt).all() + domain_models = [] - def update(self, execution: WorkflowNodeExecution) -> None: - """ - Update an existing WorkflowNodeExecution instance and commit changes to the database. + for model in db_models: + domain_model = self._to_domain_model(model) + # Update cache if node_execution_id is present + if domain_model.node_execution_id: + self._node_execution_cache[domain_model.node_execution_id] = domain_model + domain_models.append(domain_model) - Args: - execution: The WorkflowNodeExecution instance to update - """ - with self._session_factory() as session: - # Ensure tenant_id is set - if not execution.tenant_id: - execution.tenant_id = self._tenant_id - - # Set app_id if provided and not already set - if self._app_id and not execution.app_id: - execution.app_id = self._app_id - - session.merge(execution) - session.commit() + return domain_models def clear(self) -> None: """ @@ -179,6 +366,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) This method deletes all WorkflowNodeExecution records that match the tenant_id and app_id (if provided) associated with this repository instance. + It also clears the in-memory cache. 
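# Annotation (not patch content): with the separate update() method removed, save() is the
# single write path; session.merge() inserts on first save and updates on later saves of the
# same id, refreshing the in-memory cache each time. Hedged sketch; `outputs` and `metadata`
# are placeholders:
execution = repository.get_by_node_execution_id(node_execution_id)  # served from cache when warm
if execution is not None:
    execution.status = NodeExecutionStatus.SUCCEEDED
    execution.update_from_mapping(outputs=outputs, metadata=metadata)
    repository.save(execution)  # upsert; also refreshes the cache entry
# clear() (documented below) bulk-deletes this tenant's/app's rows and empties the cache.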
""" with self._session_factory() as session: stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) @@ -194,3 +382,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" + (f" and app {self._app_id}" if self._app_id else "") ) + + # Clear the in-memory cache + self._node_execution_cache.clear() + logger.info("Cleared in-memory node execution cache") diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 3dce1ca293..178f2b9689 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -32,7 +32,7 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import Message, MessageFile @@ -339,9 +339,9 @@ class ToolEngine: url=message.url, upload_file_id=tool_file_id, created_by_role=( - CreatedByRole.ACCOUNT + CreatorUserRole.ACCOUNT if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatedByRole.END_USER + else CreatorUserRole.END_USER ), created_by=user_id, ) diff --git a/api/core/workflow/entities/node_execution_entities.py b/api/core/workflow/entities/node_execution_entities.py new file mode 100644 index 0000000000..5e5ead062f --- /dev/null +++ b/api/core/workflow/entities/node_execution_entities.py @@ -0,0 +1,98 @@ +""" +Domain entities for workflow node execution. + +This module contains the domain model for workflow node execution, which is used +by the core workflow module. These models are independent of the storage mechanism +and don't contain implementation details like tenant_id, app_id, etc. +""" + +from collections.abc import Mapping +from datetime import datetime +from enum import StrEnum +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.nodes.enums import NodeType + + +class NodeExecutionStatus(StrEnum): + """ + Node Execution Status Enum. + """ + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + EXCEPTION = "exception" + RETRY = "retry" + + +class NodeExecution(BaseModel): + """ + Domain model for workflow node execution. + + This model represents the core business entity of a node execution, + without implementation details like tenant_id, app_id, etc. + + Note: User/context-specific fields (triggered_from, created_by, created_by_role) + have been moved to the repository implementation to keep the domain model clean. + These fields are still accepted in the constructor for backward compatibility, + but they are not stored in the model. 
+ """ + + # Core identification fields + id: str # Unique identifier for this execution record + node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing + workflow_id: str # ID of the workflow this node belongs to + workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) + + # Execution positioning and flow + index: int # Sequence number for ordering in trace visualization + predecessor_node_id: Optional[str] = None # ID of the node that executed before this one + node_id: str # ID of the node being executed + node_type: NodeType # Type of node (e.g., start, llm, knowledge) + title: str # Display title of the node + + # Execution data + inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node + process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data + outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node + + # Execution state + status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status + error: Optional[str] = None # Error message if execution failed + elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds + + # Additional metadata + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) + + # Timing information + created_at: datetime # When execution started + finished_at: Optional[datetime] = None # When execution completed + + def update_from_mapping( + self, + inputs: Optional[Mapping[str, Any]] = None, + process_data: Optional[Mapping[str, Any]] = None, + outputs: Optional[Mapping[str, Any]] = None, + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None, + ) -> None: + """ + Update the model from mappings. + + Args: + inputs: The inputs to update + process_data: The process data to update + outputs: The outputs to update + metadata: The metadata to update + """ + if inputs is not None: + self.inputs = dict(inputs) + if process_data is not None: + self.process_data = dict(process_data) + if outputs is not None: + self.outputs = dict(outputs) + if metadata is not None: + self.metadata = dict(metadata) diff --git a/api/core/workflow/repository/workflow_node_execution_repository.py b/api/core/workflow/repository/workflow_node_execution_repository.py index 9bb790cb0f..3ca9e2ecab 100644 --- a/api/core/workflow/repository/workflow_node_execution_repository.py +++ b/api/core/workflow/repository/workflow_node_execution_repository.py @@ -2,12 +2,12 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Optional, Protocol -from models.workflow import WorkflowNodeExecution +from core.workflow.entities.node_execution_entities import NodeExecution @dataclass class OrderConfig: - """Configuration for ordering WorkflowNodeExecution instances.""" + """Configuration for ordering NodeExecution instances.""" order_by: list[str] order_direction: Optional[Literal["asc", "desc"]] = None @@ -15,10 +15,10 @@ class OrderConfig: class WorkflowNodeExecutionRepository(Protocol): """ - Repository interface for WorkflowNodeExecution. + Repository interface for NodeExecution. This interface defines the contract for accessing and manipulating - WorkflowNodeExecution data, regardless of the underlying storage mechanism. + NodeExecution data, regardless of the underlying storage mechanism. 
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), and trigger sources (triggered_from) should be handled at the implementation level, not in @@ -26,24 +26,28 @@ class WorkflowNodeExecutionRepository(Protocol): application domains or deployment scenarios. """ - def save(self, execution: WorkflowNodeExecution) -> None: + def save(self, execution: NodeExecution) -> None: """ - Save a WorkflowNodeExecution instance. + Save or update a NodeExecution instance. + + This method handles both creating new records and updating existing ones. + The implementation should determine whether to create or update based on + the execution's ID or other identifying fields. Args: - execution: The WorkflowNodeExecution instance to save + execution: The NodeExecution instance to save or update """ ... - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]: """ - Retrieve a WorkflowNodeExecution by its node_execution_id. + Retrieve a NodeExecution by its node_execution_id. Args: node_execution_id: The node execution ID Returns: - The WorkflowNodeExecution instance if found, None otherwise + The NodeExecution instance if found, None otherwise """ ... @@ -51,9 +55,9 @@ class WorkflowNodeExecutionRepository(Protocol): self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, - ) -> Sequence[WorkflowNodeExecution]: + ) -> Sequence[NodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run. + Retrieve all NodeExecution instances for a specific workflow run. Args: workflow_run_id: The workflow run ID @@ -62,34 +66,25 @@ class WorkflowNodeExecutionRepository(Protocol): order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of WorkflowNodeExecution instances + A list of NodeExecution instances """ ... - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]: """ - Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + Retrieve all running NodeExecution instances for a specific workflow run. Args: workflow_run_id: The workflow run ID Returns: - A list of running WorkflowNodeExecution instances - """ - ... - - def update(self, execution: WorkflowNodeExecution) -> None: - """ - Update an existing WorkflowNodeExecution instance. - - Args: - execution: The WorkflowNodeExecution instance to update + A list of running NodeExecution instances """ ... def clear(self) -> None: """ - Clear all WorkflowNodeExecution records based on implementation-specific criteria. + Clear all NodeExecution records based on implementation-specific criteria. This method is intended to be used for bulk deletion operations, such as removing all records associated with a specific app_id and tenant_id in multi-tenant implementations. 
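[Editor's illustration] To make the contract above concrete, here is a minimal in-memory sketch of a class that satisfies the WorkflowNodeExecutionRepository protocol using the NodeExecution domain model introduced earlier in this patch. The class name InMemoryWorkflowNodeExecutionRepository and its dict-based storage are illustrative assumptions, not part of this change; only the methods defined by the protocol are implemented.

from collections.abc import Sequence
from typing import Optional

from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.repository.workflow_node_execution_repository import OrderConfig


class InMemoryWorkflowNodeExecutionRepository:
    """Illustrative in-memory implementation of the WorkflowNodeExecutionRepository protocol."""

    def __init__(self) -> None:
        # Keyed by the execution's primary id; node_execution_id is treated as a secondary index.
        self._by_id: dict[str, NodeExecution] = {}

    def save(self, execution: NodeExecution) -> None:
        # save() acts as an upsert: saving the same id again overwrites the previous record.
        self._by_id[execution.id] = execution

    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
        return next(
            (e for e in self._by_id.values() if e.node_execution_id == node_execution_id),
            None,
        )

    def get_by_workflow_run(
        self, workflow_run_id: str, order_config: Optional[OrderConfig] = None
    ) -> Sequence[NodeExecution]:
        executions = [e for e in self._by_id.values() if e.workflow_run_id == workflow_run_id]
        if order_config and order_config.order_by == ["index"]:
            executions.sort(key=lambda e: e.index, reverse=order_config.order_direction == "desc")
        return executions

    def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
        return [
            e
            for e in self._by_id.values()
            if e.workflow_run_id == workflow_run_id and e.status == NodeExecutionStatus.RUNNING
        ]

    def clear(self) -> None:
        self._by_id.clear()

This is also the usage pattern the WorkflowCycleManager changes below rely on: fetch the domain model with get_by_node_execution_id(), mutate its status, outputs, and timing fields, then call save() again, since the dedicated update() method has been removed in favor of upsert semantics.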
diff --git a/api/core/workflow/workflow_app_generate_task_pipeline.py b/api/core/workflow/workflow_app_generate_task_pipeline.py index 10a2d8b38b..0396fa8157 100644 --- a/api/core/workflow/workflow_app_generate_task_pipeline.py +++ b/api/core/workflow/workflow_app_generate_task_pipeline.py @@ -58,7 +58,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow from core.workflow.workflow_cycle_manager import WorkflowCycleManager from extensions.ext_database import db from models.account import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import EndUser from models.workflow import ( Workflow, @@ -94,11 +94,11 @@ class WorkflowAppGenerateTaskPipeline: if isinstance(user, EndUser): self._user_id = user.id user_session_id = user.session_id - self._created_by_role = CreatedByRole.END_USER + self._created_by_role = CreatorUserRole.END_USER elif isinstance(user, Account): self._user_id = user.id user_session_id = user.id - self._created_by_role = CreatedByRole.ACCOUNT + self._created_by_role = CreatorUserRole.ACCOUNT else: raise ValueError(f"Invalid user type: {type(user)}") diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 01d5db4303..6d33d7372c 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -46,26 +46,28 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.exc import WorkflowRunNotFoundError from core.file import FILE_MODEL_IDENTITY, File -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_execution_entities import ( + NodeExecution, + NodeExecutionStatus, +) from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_entry import WorkflowEntry -from models.account import Account -from models.enums import CreatedByRole, WorkflowRunTriggeredFrom -from models.model import EndUser -from models.workflow import ( +from models import ( + Account, + CreatorUserRole, + EndUser, Workflow, - WorkflowNodeExecution, WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, WorkflowRun, WorkflowRunStatus, + WorkflowRunTriggeredFrom, ) @@ -78,7 +80,6 @@ class WorkflowCycleManager: workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._workflow_run: WorkflowRun | None = None - self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables self._workflow_node_execution_repository = workflow_node_execution_repository @@ -89,7 +90,7 @@ class WorkflowCycleManager: session: Session, workflow_id: str, user_id: str, - created_by_role: CreatedByRole, + created_by_role: CreatorUserRole, ) -> WorkflowRun: workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) workflow = session.scalar(workflow_stmt) @@ -258,21 +259,22 @@ class WorkflowCycleManager: workflow_run.exceptions_count = exceptions_count # Use the 
instance repository to find running executions for a workflow run - running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions( + running_domain_executions = self._workflow_node_execution_repository.get_running_executions( workflow_run_id=workflow_run.id ) - # Update the cache with the retrieved executions - for execution in running_workflow_node_executions: - if execution.node_execution_id: - self._workflow_node_executions[execution.node_execution_id] = execution + # Update the domain models + now = datetime.now(UTC).replace(tzinfo=None) + for domain_execution in running_domain_executions: + if domain_execution.node_execution_id: + # Update the domain model + domain_execution.status = NodeExecutionStatus.FAILED + domain_execution.error = error + domain_execution.finished_at = now + domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds() - for workflow_node_execution in running_workflow_node_executions: - now = datetime.now(UTC).replace(tzinfo=None) - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.finished_at = now - workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) if trace_manager: trace_manager.add_trace_task( @@ -286,63 +288,67 @@ class WorkflowCycleManager: return workflow_run - def _handle_node_execution_start( - self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent - ) -> WorkflowNodeExecution: - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.index = event.node_run_index - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.execution_metadata = json.dumps( - { - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - NodeRunMetadataKey.LOOP_ID: event.in_loop_id, - } + def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution: + # Create a domain model + created_at = datetime.now(UTC).replace(tzinfo=None) + metadata = { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + NodeRunMetadataKey.LOOP_ID: event.in_loop_id, + } + + domain_execution = NodeExecution( + id=str(uuid4()), + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + predecessor_node_id=event.predecessor_node_id, + index=event.node_run_index, + 
node_execution_id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type, + title=event.node_data.title, + status=NodeExecutionStatus.RUNNING, + metadata=metadata, + created_at=created_at, ) - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - # Use the instance repository to save the workflow node execution - self._workflow_node_execution_repository.save(workflow_node_execution) + # Use the instance repository to save the domain model + self._workflow_node_execution_repository.save(domain_execution) - self._workflow_node_executions[event.node_execution_id] = workflow_node_execution - return workflow_node_execution + return domain_execution - def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) + def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: + # Get the domain model from repository + domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {event.node_execution_id}") + + # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) - execution_metadata_dict = dict(event.execution_metadata or {}) - execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None + + # Convert metadata keys to strings + execution_metadata_dict = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[key] = value + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - process_data = WorkflowEntry.handle_special_values(event.process_data) + # Update domain model + domain_execution.status = NodeExecutionStatus.SUCCEEDED + domain_execution.update_from_mapping( + inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + ) + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = elapsed_time - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) - # Use the instance repository to update the workflow node execution - self._workflow_node_execution_repository.update(workflow_node_execution) - return workflow_node_execution + return domain_execution def _handle_workflow_node_execution_failed( self, @@ -351,43 +357,52 @@ class WorkflowCycleManager: | QueueNodeInIterationFailedEvent | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, - ) -> WorkflowNodeExecution: + ) -> NodeExecution: """ Workflow node execution failed :param event: queue node failed event :return: """ - workflow_node_execution = 
self._get_workflow_node_execution(node_execution_id=event.node_execution_id) + # Get the domain model from repository + domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {event.node_execution_id}") + # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) + + # Convert metadata keys to strings + execution_metadata_dict = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[key] = value + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - execution_metadata = ( - json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None - ) - process_data = WorkflowEntry.handle_special_values(event.process_data) - workflow_node_execution.status = ( - WorkflowNodeExecutionStatus.FAILED.value + + # Update domain model + domain_execution.status = ( + NodeExecutionStatus.FAILED if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION.value + else NodeExecutionStatus.EXCEPTION ) - workflow_node_execution.error = event.error - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution.execution_metadata = execution_metadata + domain_execution.error = event.error + domain_execution.update_from_mapping( + inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + ) + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = elapsed_time - self._workflow_node_execution_repository.update(workflow_node_execution) + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) - return workflow_node_execution + return domain_execution def _handle_workflow_node_execution_retried( self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent - ) -> WorkflowNodeExecution: + ) -> NodeExecution: """ Workflow node execution failed :param workflow_run: workflow run @@ -399,47 +414,47 @@ class WorkflowCycleManager: elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) outputs = WorkflowEntry.handle_special_values(event.outputs) + + # Convert metadata keys to strings origin_metadata = { NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, NodeRunMetadataKey.LOOP_ID: event.in_loop_id, } - merged_metadata = ( - {**jsonable_encoder(event.execution_metadata), **origin_metadata} - if event.execution_metadata is not None - else origin_metadata + + # Convert execution metadata keys to strings + execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[key] = value + + merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata + + # Create 
a domain model + domain_execution = NodeExecution( + id=str(uuid4()), + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + predecessor_node_id=event.predecessor_node_id, + node_execution_id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type, + title=event.node_data.title, + status=NodeExecutionStatus.RETRY, + created_at=created_at, + finished_at=finished_at, + elapsed_time=elapsed_time, + error=event.error, + index=event.node_run_index, ) - execution_metadata = json.dumps(merged_metadata) - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.created_at = created_at - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution.error = event.error - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution.index = event.node_run_index + # Update with mappings + domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata) - # Use the instance repository to save the workflow node execution - self._workflow_node_execution_repository.save(workflow_node_execution) + # Use the instance repository to save the domain model + self._workflow_node_execution_repository.save(domain_execution) - self._workflow_node_executions[event.node_execution_id] = workflow_node_execution - return workflow_node_execution + return domain_execution def _workflow_start_to_stream_response( self, @@ -469,7 +484,7 @@ class WorkflowCycleManager: workflow_run: WorkflowRun, ) -> WorkflowFinishStreamResponse: created_by = None - if workflow_run.created_by_role == CreatedByRole.ACCOUNT: + if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: stmt = select(Account).where(Account.id == workflow_run.created_by) account = session.scalar(stmt) if account: @@ -478,7 +493,7 @@ class WorkflowCycleManager: "name": account.name, "email": account.email, } - elif workflow_run.created_by_role == CreatedByRole.END_USER: + elif workflow_run.created_by_role == CreatorUserRole.END_USER: stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) end_user = session.scalar(stmt) if end_user: @@ -515,9 +530,9 @@ class WorkflowCycleManager: *, event: QueueNodeStartedEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, + workflow_node_execution: NodeExecution, ) -> Optional[NodeStartStreamResponse]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, 
NodeType.LOOP.value}: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_run_id: return None @@ -532,7 +547,7 @@ class WorkflowCycleManager: title=workflow_node_execution.title, index=workflow_node_execution.index, predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, + inputs=workflow_node_execution.inputs, created_at=int(workflow_node_execution.created_at.timestamp()), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, @@ -565,9 +580,9 @@ class WorkflowCycleManager: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, + workflow_node_execution: NodeExecution, ) -> Optional[NodeFinishStreamResponse]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_run_id: return None @@ -584,16 +599,16 @@ class WorkflowCycleManager: index=workflow_node_execution.index, title=workflow_node_execution.title, predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, - process_data=workflow_node_execution.process_data_dict, - outputs=workflow_node_execution.outputs_dict, + inputs=workflow_node_execution.inputs, + process_data=workflow_node_execution.process_data, + outputs=workflow_node_execution.outputs, status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.execution_metadata_dict, + execution_metadata=workflow_node_execution.metadata, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, @@ -608,9 +623,9 @@ class WorkflowCycleManager: *, event: QueueNodeRetryEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, + workflow_node_execution: NodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_run_id: return None @@ -627,16 +642,16 @@ class WorkflowCycleManager: index=workflow_node_execution.index, title=workflow_node_execution.title, predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, - process_data=workflow_node_execution.process_data_dict, - outputs=workflow_node_execution.outputs_dict, + inputs=workflow_node_execution.inputs, + process_data=workflow_node_execution.process_data, + outputs=workflow_node_execution.outputs, status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.execution_metadata_dict, + execution_metadata=workflow_node_execution.metadata, created_at=int(workflow_node_execution.created_at.timestamp()), 
finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, @@ -908,23 +923,6 @@ class WorkflowCycleManager: return workflow_run - def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: - # First check the cache for performance - if node_execution_id in self._workflow_node_executions: - cached_execution = self._workflow_node_executions[node_execution_id] - # No need to merge with session since expire_on_commit=False - return cached_execution - - # If not in cache, use the instance repository to get by node_execution_id - execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id) - - if not execution: - raise ValueError(f"Workflow node execution not found: {node_execution_id}") - - # Update cache - self._workflow_node_executions[node_execution_id] = execution - return execution - def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: """ Handle agent log diff --git a/api/models/__init__.py b/api/models/__init__.py index 2066481a61..f652449e98 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -27,7 +27,7 @@ from .dataset import ( Whitelist, ) from .engine import db -from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom +from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom from .model import ( ApiRequest, ApiToken, @@ -112,7 +112,7 @@ __all__ = [ "CeleryTaskSet", "Conversation", "ConversationVariable", - "CreatedByRole", + "CreatorUserRole", "DataSourceApiKeyAuthBinding", "DataSourceOauthBinding", "Dataset", diff --git a/api/models/enums.py b/api/models/enums.py index 7b9500ebe4..7d9f6068bb 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,7 +1,7 @@ from enum import StrEnum -class CreatedByRole(StrEnum): +class CreatorUserRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" diff --git a/api/models/model.py b/api/models/model.py index ab426649c5..ee79fbd6b5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -29,7 +29,7 @@ from libs.helper import generate_string from .account import Account, Tenant from .base import Base from .engine import db -from .enums import CreatedByRole +from .enums import CreatorUserRole from .types import StringUUID from .workflow import WorkflowRunStatus @@ -1270,7 +1270,7 @@ class MessageFile(Base): url: str | None = None, belongs_to: Literal["user", "assistant"] | None = None, upload_file_id: str | None = None, - created_by_role: CreatedByRole, + created_by_role: CreatorUserRole, created_by: str, ): self.message_id = message_id @@ -1417,7 +1417,7 @@ class EndUser(Base, UserMixin): ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True) @@ -1547,7 +1547,7 @@ class UploadFile(Base): size: int, extension: str, mime_type: str, - created_by_role: CreatedByRole, + created_by_role: CreatorUserRole, created_by: str, created_at: datetime, used: bool, diff --git a/api/models/workflow.py 
b/api/models/workflow.py index 730ddcfcb5..fd0d279d50 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -22,7 +22,7 @@ from libs import helper from .account import Account from .base import Base from .engine import db -from .enums import CreatedByRole +from .enums import CreatorUserRole from .types import StringUUID if TYPE_CHECKING: @@ -429,15 +429,15 @@ class WorkflowRun(Base): @property def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None @property def graph_dict(self): @@ -634,17 +634,17 @@ class WorkflowNodeExecution(Base): @property def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) + created_by_role = CreatorUserRole(self.created_by_role) # TODO(-LAN-): Avoid using db.session.get() here. - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole(self.created_by_role) + created_by_role = CreatorUserRole(self.created_by_role) # TODO(-LAN-): Avoid using db.session.get() here. 
- return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None @property def inputs_dict(self): @@ -755,15 +755,15 @@ class WorkflowAppLog(Base): @property def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None class ConversationVariable(Base): diff --git a/api/services/file_service.py b/api/services/file_service.py index 2ca6b4f9aa..2d68f30c5a 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import EndUser, UploadFile from .errors.file import FileTooLargeError, UnsupportedFileTypeError @@ -81,7 +81,7 @@ class FileService: size=file_size, extension=extension, mime_type=mimetype, - created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), + created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER), created_by=user.id, created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=False, @@ -133,7 +133,7 @@ class FileService: extension="txt", mime_type="text/plain", created_by=current_user.id, - created_by_role=CreatedByRole.ACCOUNT, + created_by_role=CreatorUserRole.ACCOUNT, created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=current_user.id, diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index e526517b51..a899ebe278 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from models import App, EndUser, WorkflowAppLog, WorkflowRun -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.workflow import WorkflowRunStatus @@ -58,7 +58,7 @@ class WorkflowAppService: stmt = stmt.outerjoin( EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER), + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER), ).where(or_(*keyword_conditions)) if status: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 6d5b737962..a760b0f586 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,4 +1,5 @@ import threading +from collections.abc import Sequence from typing import Optional import contexts @@ -6,11 +7,13 @@ from core.repositories 
import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.enums import WorkflowRunTriggeredFrom -from models.model import App -from models.workflow import ( +from models import ( + Account, + App, + EndUser, WorkflowNodeExecution, WorkflowRun, + WorkflowRunTriggeredFrom, ) @@ -116,7 +119,12 @@ class WorkflowRunService: return workflow_run - def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]: + def get_workflow_run_node_executions( + self, + app_model: App, + run_id: str, + user: Account | EndUser, + ) -> Sequence[WorkflowNodeExecution]: """ Get workflow run node execution list """ @@ -128,13 +136,15 @@ class WorkflowRunService: if not workflow_run: return [] - # Use the repository to get the node executions repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id + session_factory=db.engine, + user=user, + app_id=app_model.id, + triggered_from=None, ) # Use the repository to get the node executions with ordering order_config = OrderConfig(order_by=["index"], order_direction="desc") - node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + node_executions = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order_config) - return list(node_executions) + return node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 331dba8bf1..04e5f2eb41 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -26,7 +26,7 @@ from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import ( @@ -284,9 +284,11 @@ class WorkflowService: workflow_node_execution.created_by = account.id workflow_node_execution.workflow_id = draft_workflow.id - # Use the repository to save the workflow node execution repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id + session_factory=db.engine, + user=account, + app_id=app_model.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) repository.save(workflow_node_execution) @@ -390,7 +392,7 @@ class WorkflowService: workflow_node_execution.node_type = node_instance.node_type workflow_node_execution.title = node_instance.node_data.title workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value + workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) if run_succeeded and node_run_result: diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index d5a783396a..d9c1980d3f 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ 
b/api/tasks/remove_app_and_related_data_task.py @@ -4,16 +4,19 @@ from collections.abc import Callable import click from celery import shared_task # type: ignore -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from extensions.ext_database import db -from models.dataset import AppDatasetJoin -from models.model import ( +from models import ( + Account, ApiToken, + App, AppAnnotationHitHistory, AppAnnotationSetting, + AppDatasetJoin, AppModelConfig, Conversation, EndUser, @@ -188,9 +191,24 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): + # Get app's owner + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id) + user = session.scalar(stmt) + + if user is None: + errmsg = ( + f"Failed to delete workflow node executions for tenant {tenant_id} and app {app_id}, app's owner not found" + ) + logging.error(errmsg) + raise ValueError(errmsg) + # Create a repository instance for WorkflowNodeExecution repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=tenant_id, app_id=app_id + session_factory=db.engine, + user=user, + app_id=app_id, + triggered_from=None, ) # Use the clear method to delete all records for this tenant_id and app_id diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 6b00b203c4..94b9d3e2c6 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -16,10 +16,9 @@ from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_cycle_manager import WorkflowCycleManager -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.workflow import ( Workflow, - WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus, @@ -94,7 +93,7 @@ def mock_workflow_run(): workflow_run.app_id = "test-app-id" workflow_run.workflow_id = "test-workflow-id" workflow_run.status = WorkflowRunStatus.RUNNING - workflow_run.created_by_role = CreatedByRole.ACCOUNT + workflow_run.created_by_role = CreatorUserRole.ACCOUNT workflow_run.created_by = "test-user-id" workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.inputs_dict = {"query": "test query"} @@ -107,7 +106,6 @@ def test_init( ): """Test initialization of WorkflowCycleManager""" assert workflow_cycle_manager._workflow_run is None - assert workflow_cycle_manager._workflow_node_executions == {} assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository @@ -123,7 +121,7 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo session=mock_session, workflow_id="test-workflow-id", user_id="test-user-id", - created_by_role=CreatedByRole.ACCOUNT, + created_by_role=CreatorUserRole.ACCOUNT, ) # Verify the 
result @@ -132,7 +130,7 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo assert workflow_run.workflow_id == mock_workflow.id assert workflow_run.sequence_number == 6 # max_sequence + 1 assert workflow_run.status == WorkflowRunStatus.RUNNING - assert workflow_run.created_by_role == CreatedByRole.ACCOUNT + assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT assert workflow_run.created_by == "test-user-id" # Verify session.add was called @@ -215,24 +213,23 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): ) # Verify the result - assert result.tenant_id == mock_workflow_run.tenant_id - assert result.app_id == mock_workflow_run.app_id + # NodeExecution doesn't have tenant_id attribute, it's handled at repository level + # assert result.tenant_id == mock_workflow_run.tenant_id + # assert result.app_id == mock_workflow_run.app_id assert result.workflow_id == mock_workflow_run.workflow_id assert result.workflow_run_id == mock_workflow_run.id assert result.node_execution_id == event.node_execution_id assert result.node_id == event.node_id - assert result.node_type == event.node_type.value + assert result.node_type == event.node_type assert result.title == event.node_data.title assert result.status == WorkflowNodeExecutionStatus.RUNNING.value - assert result.created_by_role == mock_workflow_run.created_by_role - assert result.created_by == mock_workflow_run.created_by + # NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level + # assert result.created_by_role == mock_workflow_run.created_by_role + # assert result.created_by == mock_workflow_run.created_by # Verify save was called workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result) - # Verify the node execution was added to the cache - assert workflow_cycle_manager._workflow_node_executions[event.node_execution_id] == result - def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run): """Test _get_workflow_run method""" @@ -261,28 +258,24 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): event.execution_metadata = {"metadata": "test metadata"} event.start_at = datetime.now(UTC).replace(tzinfo=None) - # Create a mock workflow node execution - node_execution = MagicMock(spec=WorkflowNodeExecution) + # Create a mock node execution + node_execution = MagicMock() node_execution.node_execution_id = "test-node-execution-id" - # Mock _get_workflow_node_execution to return the mock node execution - with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution): - # Call the method - result = workflow_cycle_manager._handle_workflow_node_execution_success( - event=event, - ) + # Mock the repository to return the node execution + workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution - # Verify the result - assert result == node_execution - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value - assert result.inputs == json.dumps(event.inputs) - assert result.process_data == json.dumps(event.process_data) - assert result.outputs == json.dumps(event.outputs) - assert result.finished_at is not None - assert result.elapsed_time is not None + # Call the method + result = workflow_cycle_manager._handle_workflow_node_execution_success( + event=event, + ) - # Verify update was called - 
workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution) + # Verify the result + assert result == node_execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value + + # Verify save was called + workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run): @@ -322,27 +315,22 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager): event.start_at = datetime.now(UTC).replace(tzinfo=None) event.error = "Test error message" - # Create a mock workflow node execution - node_execution = MagicMock(spec=WorkflowNodeExecution) + # Create a mock node execution + node_execution = MagicMock() node_execution.node_execution_id = "test-node-execution-id" - # Mock _get_workflow_node_execution to return the mock node execution - with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution): - # Call the method - result = workflow_cycle_manager._handle_workflow_node_execution_failed( - event=event, - ) + # Mock the repository to return the node execution + workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution - # Verify the result - assert result == node_execution - assert result.status == WorkflowNodeExecutionStatus.FAILED.value - assert result.error == "Test error message" - assert result.inputs == json.dumps(event.inputs) - assert result.process_data == json.dumps(event.process_data) - assert result.outputs == json.dumps(event.outputs) - assert result.finished_at is not None - assert result.elapsed_time is not None - assert result.execution_metadata == json.dumps(event.execution_metadata) + # Call the method + result = workflow_cycle_manager._handle_workflow_node_execution_failed( + event=event, + ) - # Verify update was called - workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution) + # Verify the result + assert result == node_execution + assert result.status == WorkflowNodeExecutionStatus.FAILED.value + assert result.error == "Test error message" + + # Verify save was called + workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index 9cda873e90..64c3b66f12 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -2,15 +2,36 @@ Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository. 
""" -from unittest.mock import MagicMock +import json +from datetime import datetime +from unittest.mock import MagicMock, PropertyMock import pytest from pytest_mock import MockerFixture from sqlalchemy.orm import Session, sessionmaker from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus +from core.workflow.nodes.enums import NodeType from core.workflow.repository.workflow_node_execution_repository import OrderConfig -from models.workflow import WorkflowNodeExecution +from models.account import Account, Tenant +from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom + + +def configure_mock_execution(mock_execution): + """Configure a mock execution with proper JSON serializable values.""" + # Configure inputs, outputs, process_data, and execution_metadata to return JSON serializable values + type(mock_execution).inputs = PropertyMock(return_value='{"key": "value"}') + type(mock_execution).outputs = PropertyMock(return_value='{"result": "success"}') + type(mock_execution).process_data = PropertyMock(return_value='{"process": "data"}') + type(mock_execution).execution_metadata = PropertyMock(return_value='{"metadata": "info"}') + + # Configure status and triggered_from to be valid enum values + mock_execution.status = "running" + mock_execution.triggered_from = "workflow-run" + + return mock_execution @pytest.fixture @@ -28,13 +49,30 @@ def session(): @pytest.fixture -def repository(session): +def mock_user(): + """Create a user instance for testing.""" + user = Account() + user.id = "test-user-id" + + tenant = Tenant() + tenant.id = "test-tenant" + tenant.name = "Test Workspace" + user._current_tenant = MagicMock() + user._current_tenant.id = "test-tenant" + + return user + + +@pytest.fixture +def repository(session, mock_user): """Create a repository instance with test data.""" _, session_factory = session - tenant_id = "test-tenant" app_id = "test-app" return SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=tenant_id, app_id=app_id + session_factory=session_factory, + user=mock_user, + app_id=app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) @@ -45,16 +83,23 @@ def test_save(repository, session): execution = MagicMock(spec=WorkflowNodeExecution) execution.tenant_id = None execution.app_id = None + execution.inputs = None + execution.process_data = None + execution.outputs = None + execution.metadata = None + + # Mock the _to_db_model method to return the execution itself + # This simulates the behavior of setting tenant_id and app_id + repository._to_db_model = MagicMock(return_value=execution) # Call save method repository.save(execution) - # Assert tenant_id and app_id are set - assert execution.tenant_id == repository._tenant_id - assert execution.app_id == repository._app_id + # Assert _to_db_model was called with the execution + repository._to_db_model.assert_called_once_with(execution) - # Assert session.add was called - session_obj.add.assert_called_once_with(execution) + # Assert session.merge was called (now using merge for both save and update) + session_obj.merge.assert_called_once_with(execution) def test_save_with_existing_tenant_id(repository, session): @@ -64,16 +109,27 @@ def test_save_with_existing_tenant_id(repository, session): execution = 
MagicMock(spec=WorkflowNodeExecution)
     execution.tenant_id = "existing-tenant"
     execution.app_id = None
+    execution.inputs = None
+    execution.process_data = None
+    execution.outputs = None
+    execution.metadata = None
+
+    # Create a modified execution that will be returned by _to_db_model
+    modified_execution = MagicMock(spec=WorkflowNodeExecution)
+    modified_execution.tenant_id = "existing-tenant"  # Tenant ID should not change
+    modified_execution.app_id = repository._app_id  # App ID should be set
+
+    # Mock the _to_db_model method to return the modified execution
+    repository._to_db_model = MagicMock(return_value=modified_execution)
 
     # Call save method
     repository.save(execution)
 
-    # Assert tenant_id is not changed and app_id is set
-    assert execution.tenant_id == "existing-tenant"
-    assert execution.app_id == repository._app_id
+    # Assert _to_db_model was called with the execution
+    repository._to_db_model.assert_called_once_with(execution)
 
-    # Assert session.add was called
-    session_obj.add.assert_called_once_with(execution)
+    # Assert session.merge was called with the modified execution (now using merge for both save and update)
+    session_obj.merge.assert_called_once_with(modified_execution)
 
 
 def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
@@ -84,7 +140,16 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
     mock_stmt = mocker.MagicMock()
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
-    session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
+
+    # Create a properly configured mock execution
+    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
+    configure_mock_execution(mock_execution)
+    session_obj.scalar.return_value = mock_execution
+
+    # Create a mock domain model to be returned by _to_domain_model
+    mock_domain_model = mocker.MagicMock()
+    # Mock the _to_domain_model method to return our mock domain model
+    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
 
     # Call method
     result = repository.get_by_node_execution_id("test-node-execution-id")
@@ -92,7 +157,10 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
     # Assert select was called with correct parameters
     mock_select.assert_called_once()
     session_obj.scalar.assert_called_once_with(mock_stmt)
-    assert result is not None
+    # Assert _to_domain_model was called with the mock execution
+    repository._to_domain_model.assert_called_once_with(mock_execution)
+    # Assert the result is our mock domain model
+    assert result is mock_domain_model
 
 
 def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
@@ -104,7 +172,16 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
     mock_stmt.order_by.return_value = mock_stmt
-    session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
+
+    # Create a properly configured mock execution
+    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
+    configure_mock_execution(mock_execution)
+    session_obj.scalars.return_value.all.return_value = [mock_execution]
+
+    # Create a mock domain model to be returned by _to_domain_model
+    mock_domain_model = mocker.MagicMock()
+    # Mock the _to_domain_model method to return our mock domain model
+    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
 
     # Call method
     order_config = OrderConfig(order_by=["index"], order_direction="desc")
@@ -113,7 +190,45 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     # Assert select was called with correct parameters
     mock_select.assert_called_once()
     session_obj.scalars.assert_called_once_with(mock_stmt)
+    # Assert _to_domain_model was called with the mock execution
+    repository._to_domain_model.assert_called_once_with(mock_execution)
+    # Assert the result contains our mock domain model
     assert len(result) == 1
+    assert result[0] is mock_domain_model
+
+
+def test_get_db_models_by_workflow_run(repository, session, mocker: MockerFixture):
+    """Test get_db_models_by_workflow_run method."""
+    session_obj, _ = session
+    # Set up mock
+    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
+    mock_stmt = mocker.MagicMock()
+    mock_select.return_value = mock_stmt
+    mock_stmt.where.return_value = mock_stmt
+    mock_stmt.order_by.return_value = mock_stmt
+
+    # Create a properly configured mock execution
+    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
+    configure_mock_execution(mock_execution)
+    session_obj.scalars.return_value.all.return_value = [mock_execution]
+
+    # Mock the _to_domain_model method
+    to_domain_model_mock = mocker.patch.object(repository, "_to_domain_model")
+
+    # Call method
+    order_config = OrderConfig(order_by=["index"], order_direction="desc")
+    result = repository.get_db_models_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
+
+    # Assert select was called with correct parameters
+    mock_select.assert_called_once()
+    session_obj.scalars.assert_called_once_with(mock_stmt)
+
+    # Assert the result contains our mock db model directly (without conversion to domain model)
+    assert len(result) == 1
+    assert result[0] is mock_execution
+
+    # Verify that _to_domain_model was NOT called (since we're returning raw DB models)
+    to_domain_model_mock.assert_not_called()
 
 
 def test_get_running_executions(repository, session, mocker: MockerFixture):
@@ -124,7 +239,16 @@ def test_get_running_executions(repository, session, mocker: MockerFixture):
     mock_stmt = mocker.MagicMock()
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
-    session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
+
+    # Create a properly configured mock execution
+    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
+    configure_mock_execution(mock_execution)
+    session_obj.scalars.return_value.all.return_value = [mock_execution]
+
+    # Create a mock domain model to be returned by _to_domain_model
+    mock_domain_model = mocker.MagicMock()
+    # Mock the _to_domain_model method to return our mock domain model
+    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
 
     # Call method
     result = repository.get_running_executions("test-workflow-run-id")
@@ -132,25 +256,36 @@ def test_get_running_executions(repository, session, mocker: MockerFixture):
     # Assert select was called with correct parameters
     mock_select.assert_called_once()
     session_obj.scalars.assert_called_once_with(mock_stmt)
+    # Assert _to_domain_model was called with the mock execution
+    repository._to_domain_model.assert_called_once_with(mock_execution)
+    # Assert the result contains our mock domain model
     assert len(result) == 1
+    assert result[0] is mock_domain_model
 
 
-def test_update(repository, session):
-    """Test update method."""
+def test_update_via_save(repository, session):
+    """Test updating an existing record via save method."""
     session_obj, _ = session
 
     # Create a mock execution
     execution = MagicMock(spec=WorkflowNodeExecution)
     execution.tenant_id = None
     execution.app_id = None
+    execution.inputs = None
+    execution.process_data = None
+    execution.outputs = None
+    execution.metadata = None
 
-    # Call update method
-    repository.update(execution)
+    # Mock the _to_db_model method to return the execution itself
+    # This simulates the behavior of setting tenant_id and app_id
+    repository._to_db_model = MagicMock(return_value=execution)
 
-    # Assert tenant_id and app_id are set
-    assert execution.tenant_id == repository._tenant_id
-    assert execution.app_id == repository._app_id
+    # Call save method to update an existing record
+    repository.save(execution)
 
-    # Assert session.merge was called
+    # Assert _to_db_model was called with the execution
+    repository._to_db_model.assert_called_once_with(execution)
+
+    # Assert session.merge was called (for updates)
     session_obj.merge.assert_called_once_with(execution)
 
 
@@ -176,3 +311,118 @@ def test_clear(repository, session, mocker: MockerFixture):
     mock_stmt.where.assert_called()
     session_obj.execute.assert_called_once_with(mock_stmt)
     session_obj.commit.assert_called_once()
+
+
+def test_to_db_model(repository):
+    """Test _to_db_model method."""
+    # Create a domain model
+    domain_model = NodeExecution(
+        id="test-id",
+        workflow_id="test-workflow-id",
+        node_execution_id="test-node-execution-id",
+        workflow_run_id="test-workflow-run-id",
+        index=1,
+        predecessor_node_id="test-predecessor-id",
+        node_id="test-node-id",
+        node_type=NodeType.START,
+        title="Test Node",
+        inputs={"input_key": "input_value"},
+        process_data={"process_key": "process_value"},
+        outputs={"output_key": "output_value"},
+        status=NodeExecutionStatus.RUNNING,
+        error=None,
+        elapsed_time=1.5,
+        metadata={NodeRunMetadataKey.TOTAL_TOKENS: 100},
+        created_at=datetime.now(),
+        finished_at=None,
+    )
+
+    # Convert to DB model
+    db_model = repository._to_db_model(domain_model)
+
+    # Assert DB model has correct values
+    assert isinstance(db_model, WorkflowNodeExecution)
+    assert db_model.id == domain_model.id
+    assert db_model.tenant_id == repository._tenant_id
+    assert db_model.app_id == repository._app_id
+    assert db_model.workflow_id == domain_model.workflow_id
+    assert db_model.triggered_from == repository._triggered_from
+    assert db_model.workflow_run_id == domain_model.workflow_run_id
+    assert db_model.index == domain_model.index
+    assert db_model.predecessor_node_id == domain_model.predecessor_node_id
+    assert db_model.node_execution_id == domain_model.node_execution_id
+    assert db_model.node_id == domain_model.node_id
+    assert db_model.node_type == domain_model.node_type
+    assert db_model.title == domain_model.title
+
+    assert db_model.inputs_dict == domain_model.inputs
+    assert db_model.process_data_dict == domain_model.process_data
+    assert db_model.outputs_dict == domain_model.outputs
+    assert db_model.execution_metadata_dict == domain_model.metadata
+
+    assert db_model.status == domain_model.status
+    assert db_model.error == domain_model.error
+    assert db_model.elapsed_time == domain_model.elapsed_time
+    assert db_model.created_at == domain_model.created_at
+    assert db_model.created_by_role == repository._creator_user_role
+    assert db_model.created_by == repository._creator_user_id
+    assert db_model.finished_at == domain_model.finished_at
+
+
+def test_to_domain_model(repository):
+    """Test _to_domain_model method."""
+    # Create input dictionaries
+    inputs_dict = {"input_key": "input_value"}
+    process_data_dict = {"process_key": "process_value"}
+    outputs_dict = {"output_key": "output_value"}
+    metadata_dict = {str(NodeRunMetadataKey.TOTAL_TOKENS): 100}
+
+    # Create a DB model using our custom subclass
+    db_model = WorkflowNodeExecution()
+    db_model.id = "test-id"
+    db_model.tenant_id = "test-tenant-id"
+    db_model.app_id = "test-app-id"
+    db_model.workflow_id = "test-workflow-id"
+    db_model.triggered_from = "workflow-run"
+    db_model.workflow_run_id = "test-workflow-run-id"
+    db_model.index = 1
+    db_model.predecessor_node_id = "test-predecessor-id"
+    db_model.node_execution_id = "test-node-execution-id"
+    db_model.node_id = "test-node-id"
+    db_model.node_type = NodeType.START.value
+    db_model.title = "Test Node"
+    db_model.inputs = json.dumps(inputs_dict)
+    db_model.process_data = json.dumps(process_data_dict)
+    db_model.outputs = json.dumps(outputs_dict)
+    db_model.status = WorkflowNodeExecutionStatus.RUNNING
+    db_model.error = None
+    db_model.elapsed_time = 1.5
+    db_model.execution_metadata = json.dumps(metadata_dict)
+    db_model.created_at = datetime.now()
+    db_model.created_by_role = "account"
+    db_model.created_by = "test-user-id"
+    db_model.finished_at = None
+
+    # Convert to domain model
+    domain_model = repository._to_domain_model(db_model)
+
+    # Assert domain model has correct values
+    assert isinstance(domain_model, NodeExecution)
+    assert domain_model.id == db_model.id
+    assert domain_model.workflow_id == db_model.workflow_id
+    assert domain_model.workflow_run_id == db_model.workflow_run_id
+    assert domain_model.index == db_model.index
+    assert domain_model.predecessor_node_id == db_model.predecessor_node_id
+    assert domain_model.node_execution_id == db_model.node_execution_id
+    assert domain_model.node_id == db_model.node_id
+    assert domain_model.node_type == NodeType(db_model.node_type)
+    assert domain_model.title == db_model.title
+    assert domain_model.inputs == inputs_dict
+    assert domain_model.process_data == process_data_dict
+    assert domain_model.outputs == outputs_dict
+    assert domain_model.status == NodeExecutionStatus(db_model.status)
+    assert domain_model.error == db_model.error
+    assert domain_model.elapsed_time == db_model.elapsed_time
+    assert domain_model.metadata == metadata_dict
+    assert domain_model.created_at == db_model.created_at
+    assert domain_model.finished_at == db_model.finished_at
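A note on the conversion the two tests above exercise: the repository maps between the NodeExecution domain model and the WorkflowNodeExecution DB row by serializing inputs, process_data, outputs, and metadata to JSON columns on the way in and parsing them back on the way out. The sketch below illustrates that round trip only; the dataclass and helper names (NodeExecutionSketch, to_db_columns_sketch, to_domain_sketch) are hypothetical stand-ins and not part of this patch.

# Illustrative sketch only -- hypothetical stand-ins, not the patch's code.
import json
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class NodeExecutionSketch:
    """Simplified stand-in for the NodeExecution domain model."""
    id: str
    node_id: str
    status: str
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    metadata: dict[str, Any] = field(default_factory=dict)


def to_db_columns_sketch(execution: NodeExecutionSketch) -> dict[str, Any]:
    """Serialize dict fields to JSON strings, as the DB row stores them."""
    return {
        "id": execution.id,
        "node_id": execution.node_id,
        "status": execution.status,
        "inputs": json.dumps(execution.inputs) if execution.inputs is not None else None,
        "outputs": json.dumps(execution.outputs) if execution.outputs is not None else None,
        "execution_metadata": json.dumps(execution.metadata) if execution.metadata else None,
    }


def to_domain_sketch(row: dict[str, Any]) -> NodeExecutionSketch:
    """Parse JSON columns back into dicts for the domain model."""
    return NodeExecutionSketch(
        id=row["id"],
        node_id=row["node_id"],
        status=row["status"],
        inputs=json.loads(row["inputs"]) if row["inputs"] else None,
        outputs=json.loads(row["outputs"]) if row["outputs"] else None,
        metadata=json.loads(row["execution_metadata"]) if row["execution_metadata"] else {},
    )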