feat(workflow): domain model for workflow node execution (#19430)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
-LAN- authored on 2025-05-17 00:56:16 +08:00, committed by GitHub
parent aeceb200ec
commit 4977bb21ec
31 changed files with 1108 additions and 483 deletions

View File

@@ -1,3 +1,6 @@
+from typing import cast
+
+from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
@@ -12,8 +15,7 @@ from fields.workflow_run_fields import (
 )
 from libs.helper import uuid_value
 from libs.login import login_required
-from models import App
-from models.model import AppMode
+from models import Account, App, AppMode, EndUser
 from services.workflow_run_service import WorkflowRunService
@@ -90,7 +92,12 @@ class WorkflowRunNodeExecutionListApi(Resource):
         run_id = str(run_id)

         workflow_run_service = WorkflowRunService()
-        node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
+        user = cast("Account | EndUser", current_user)
+        node_executions = workflow_run_service.get_workflow_run_node_executions(
+            app_model=app_model,
+            run_id=run_id,
+            user=user,
+        )

         return {"data": node_executions}

View File

@@ -29,9 +29,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from factories import file_factory
-from models.account import Account
-from models.model import App, Conversation, EndUser, Message
-from models.workflow import Workflow
+from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
 from services.conversation_service import ConversationService
 from services.errors.message import MessageNotExistsError
@@ -165,8 +163,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )

         return self._generate(
@@ -231,8 +230,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
         )

         return self._generate(
@@ -295,8 +295,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
         )

         return self._generate(
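
The repository now derives the tenant from the user object instead of receiving tenant_id directly. A standalone sketch of that resolution rule (stand-in dataclasses, not the real models): EndUser carries tenant_id, Account carries current_tenant_id, and a missing value is rejected up front.

from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class Account:  # stand-in for models.Account
    id: str
    current_tenant_id: Optional[str] = None


@dataclass
class EndUser:  # stand-in for models.EndUser
    id: str
    tenant_id: Optional[str] = None


def resolve_tenant_id(user: Union[Account, EndUser]) -> str:
    # Mirrors the check in SQLAlchemyWorkflowNodeExecutionRepository.__init__
    tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
    if not tenant_id:
        raise ValueError("User must have a tenant_id or current_tenant_id")
    return tenant_id


print(resolve_tenant_id(EndUser(id="u1", tenant_id="t1")))          # -> t1
print(resolve_tenant_id(Account(id="a1", current_tenant_id="t2")))  # -> t2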

View File

@@ -70,7 +70,7 @@ from events.message_event import message_was_created
 from extensions.ext_database import db
 from models import Conversation, EndUser, Message, MessageFile
 from models.account import Account
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
 from models.workflow import (
     Workflow,
     WorkflowRunStatus,
@@ -105,11 +105,11 @@ class AdvancedChatAppGenerateTaskPipeline:
         if isinstance(user, EndUser):
             self._user_id = user.id
             user_session_id = user.session_id
-            self._created_by_role = CreatedByRole.END_USER
+            self._created_by_role = CreatorUserRole.END_USER
         elif isinstance(user, Account):
             self._user_id = user.id
             user_session_id = user.id
-            self._created_by_role = CreatedByRole.ACCOUNT
+            self._created_by_role = CreatorUserRole.ACCOUNT
         else:
             raise NotImplementedError(f"User type not supported: {type(user)}")
@@ -739,9 +739,9 @@ class AdvancedChatAppGenerateTaskPipeline:
                     url=file["remote_url"],
                     belongs_to="assistant",
                     upload_file_id=file["related_id"],
-                    created_by_role=CreatedByRole.ACCOUNT
+                    created_by_role=CreatorUserRole.ACCOUNT
                     if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
-                    else CreatedByRole.END_USER,
+                    else CreatorUserRole.END_USER,
                     created_by=message.from_account_id or message.from_end_user_id or "",
                 )
                 for file in self._recorded_files

View File

@@ -25,7 +25,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
 from models import Account
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
 from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.conversation import ConversationNotExistsError
@@ -223,7 +223,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
                     belongs_to="user",
                     url=file.remote_url,
                     upload_file_id=file.related_id,
-                    created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
+                    created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
                     created_by=account_id or end_user_id or "",
                 )
                 db.session.add(message_file)

View File

@@ -27,7 +27,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow
 from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from extensions.ext_database import db
 from factories import file_factory
-from models import Account, App, EndUser, Workflow
+from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom

 logger = logging.getLogger(__name__)
@@ -138,10 +138,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )

         return self._generate(
@@ -262,10 +264,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
         )

         return self._generate(
@@ -325,10 +329,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
         )

         return self._generate(

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.node_entities import AgentNodeStrategyInit
+from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
 from models.workflow import WorkflowNodeExecutionStatus
@@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
         title: str
         index: int
         predecessor_node_id: Optional[str] = None
-        inputs: Optional[dict] = None
+        inputs: Optional[Mapping[str, Any]] = None
         created_at: int
         extras: dict = {}
         parallel_id: Optional[str] = None
@@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse):
         title: str
         index: int
         predecessor_node_id: Optional[str] = None
-        inputs: Optional[dict] = None
-        process_data: Optional[dict] = None
-        outputs: Optional[dict] = None
+        inputs: Optional[Mapping[str, Any]] = None
+        process_data: Optional[Mapping[str, Any]] = None
+        outputs: Optional[Mapping[str, Any]] = None
         status: str
         error: Optional[str] = None
         elapsed_time: float
-        execution_metadata: Optional[dict] = None
+        execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
         created_at: int
         finished_at: int
         files: Optional[Sequence[Mapping[str, Any]]] = []
@@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse):
         title: str
         index: int
         predecessor_node_id: Optional[str] = None
-        inputs: Optional[dict] = None
-        process_data: Optional[dict] = None
-        outputs: Optional[dict] = None
+        inputs: Optional[Mapping[str, Any]] = None
+        process_data: Optional[Mapping[str, Any]] = None
+        outputs: Optional[Mapping[str, Any]] = None
         status: str
         error: Optional[str] = None
         elapsed_time: float
-        execution_metadata: Optional[dict] = None
+        execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
         created_at: int
         finished_at: int
         files: Optional[Sequence[Mapping[str, Any]]] = []
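
A small sketch of what loosening dict to collections.abc.Mapping buys in these Pydantic models (hypothetical model name; Pydantic v2 assumed): read-only mappings such as types.MappingProxyType now validate, and the annotation signals that consumers should not mutate the value.

from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional

from pydantic import BaseModel


class NodeData(BaseModel):  # hypothetical, not the real NodeFinishStreamResponse
    inputs: Optional[Mapping[str, Any]] = None


frozen = MappingProxyType({"query": "hello"})
data = NodeData(inputs=frozen)  # accepted because MappingProxyType is a Mapping
print(data.inputs["query"])  # -> hello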

View File

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from datetime import datetime
 from enum import StrEnum
 from typing import Any, Optional, Union
@@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel):
         description="The status message of the span. Additional field for context of the event. E.g. the error "
         "message of an error event.",
     )
-    input: Optional[Union[str, dict[str, Any], list, None]] = Field(
+    input: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
         default=None, description="The input of the span. Can be any JSON object."
     )
-    output: Optional[Union[str, dict[str, Any], list, None]] = Field(
+    output: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
        default=None, description="The output of the span. Can be any JSON object."
     )
     version: Optional[str] = Field(

View File

@@ -1,11 +1,10 @@
-import json
 import logging
 import os
 from datetime import datetime, timedelta
 from typing import Optional

 from langfuse import Langfuse  # type: ignore
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import LangfuseConfig
@@ -30,8 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
 )
 from core.ops.utils import filter_none_values
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.model import EndUser
+from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom

 logger = logging.getLogger(__name__)
@@ -113,8 +113,29 @@ class LangFuseDataTrace(BaseTraceInstance):
         # through workflow_run_id get all_nodes_execution using repository
         session_factory = sessionmaker(bind=db.engine)
+        # Find the app's creator account
+        with Session(db.engine, expire_on_commit=False) as session:
+            # Get the app to find its creator
+            app_id = trace_info.metadata.get("app_id")
+            if not app_id:
+                raise ValueError("No app_id found in trace_info metadata")
+
+            app = session.query(App).filter(App.id == app_id).first()
+            if not app:
+                raise ValueError(f"App with id {app_id} not found")
+
+            if not app.created_by:
+                raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
+
+            service_account = session.query(Account).filter(Account.id == app.created_by).first()
+            if not service_account:
+                raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
+
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=session_factory, tenant_id=trace_info.tenant_id
+            session_factory=session_factory,
+            user=service_account,
+            app_id=trace_info.metadata.get("app_id"),
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )

         # Get all executions for this workflow run
@@ -124,23 +145,22 @@ class LangFuseDataTrace(BaseTraceInstance):
         for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
-            tenant_id = node_execution.tenant_id
-            app_id = node_execution.app_id
+            tenant_id = trace_info.tenant_id  # Use from trace_info instead
+            app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
             node_name = node_execution.title
             node_type = node_execution.node_type
             status = node_execution.status
-            if node_type == "llm":
-                inputs = (
-                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
-                )
+            if node_type == NodeType.LLM:
+                inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
             else:
-                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                inputs = node_execution.inputs if node_execution.inputs else {}
+            outputs = node_execution.outputs if node_execution.outputs else {}
             created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
-            metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
+            execution_metadata = node_execution.metadata if node_execution.metadata else {}
+            metadata = {str(k): v for k, v in execution_metadata.items()}
             metadata.update(
                 {
                     "workflow_run_id": trace_info.workflow_run_id,
@@ -152,7 +172,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                     "status": status,
                 }
             )
-            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            process_data = node_execution.process_data if node_execution.process_data else {}
             model_provider = process_data.get("model_provider", None)
             model_name = process_data.get("model_name", None)
             if model_provider is not None and model_name is not None:
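
The {str(k): v ...} comprehension exists because the domain model now keys its metadata by the NodeRunMetadataKey enum, while the tracing SDKs expect plain string keys. A sketch with a stand-in enum:

from enum import StrEnum


class NodeRunMetadataKey(StrEnum):  # stand-in for the real enum
    TOTAL_TOKENS = "total_tokens"


execution_metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 42}
metadata = {str(k): v for k, v in execution_metadata.items()}
print(metadata)  # -> {'total_tokens': 42}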

View File

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from datetime import datetime
 from enum import StrEnum
 from typing import Any, Optional, Union
@@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel):
 class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
     name: Optional[str] = Field(..., description="Name of the run")
-    inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run")
-    outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run")
+    inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run")
+    outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run")
     run_type: LangSmithRunType = Field(..., description="Type of the run")
     start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
     end_time: Optional[datetime | str] = Field(None, description="End time of the run")

View File

@@ -1,4 +1,3 @@
-import json
 import logging
 import os
 import uuid
@@ -7,7 +6,7 @@ from typing import Optional, cast
 from langsmith import Client
 from langsmith.schemas import RunBase
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import LangSmithConfig
@@ -29,8 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
 )
 from core.ops.utils import filter_none_values, generate_dotted_order
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.model import EndUser, MessageFile
+from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

 logger = logging.getLogger(__name__)
@@ -137,8 +138,29 @@ class LangSmithDataTrace(BaseTraceInstance):
         # through workflow_run_id get all_nodes_execution using repository
         session_factory = sessionmaker(bind=db.engine)
+        # Find the app's creator account
+        with Session(db.engine, expire_on_commit=False) as session:
+            # Get the app to find its creator
+            app_id = trace_info.metadata.get("app_id")
+            if not app_id:
+                raise ValueError("No app_id found in trace_info metadata")
+
+            app = session.query(App).filter(App.id == app_id).first()
+            if not app:
+                raise ValueError(f"App with id {app_id} not found")
+
+            if not app.created_by:
+                raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
+
+            service_account = session.query(Account).filter(Account.id == app.created_by).first()
+            if not service_account:
+                raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
+
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
+            session_factory=session_factory,
+            user=service_account,
+            app_id=trace_info.metadata.get("app_id"),
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )

         # Get all executions for this workflow run
@@ -148,27 +170,23 @@ class LangSmithDataTrace(BaseTraceInstance):
         for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
-            tenant_id = node_execution.tenant_id
-            app_id = node_execution.app_id
+            tenant_id = trace_info.tenant_id  # Use from trace_info instead
+            app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
             node_name = node_execution.title
             node_type = node_execution.node_type
             status = node_execution.status
-            if node_type == "llm":
-                inputs = (
-                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
-                )
+            if node_type == NodeType.LLM:
+                inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
             else:
-                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                inputs = node_execution.inputs if node_execution.inputs else {}
+            outputs = node_execution.outputs if node_execution.outputs else {}
             created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
-            execution_metadata = (
-                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
-            )
-            node_total_tokens = execution_metadata.get("total_tokens", 0)
-            metadata = execution_metadata.copy()
+            execution_metadata = node_execution.metadata if node_execution.metadata else {}
+            node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
+            metadata = {str(key): value for key, value in execution_metadata.items()}
             metadata.update(
                 {
                     "workflow_run_id": trace_info.workflow_run_id,
@@ -181,7 +199,7 @@ class LangSmithDataTrace(BaseTraceInstance):
                 }
             )

-            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            process_data = node_execution.process_data if node_execution.process_data else {}

             if process_data and process_data.get("model_mode") == "chat":
                 run_type = LangSmithRunType.llm
@@ -191,7 +209,7 @@ class LangSmithDataTrace(BaseTraceInstance):
                         "ls_model_name": process_data.get("model_name", ""),
                     }
                 )
-            elif node_type == "knowledge-retrieval":
+            elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
                 run_type = LangSmithRunType.retriever
             else:
                 run_type = LangSmithRunType.tool
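
The "llm" and "knowledge-retrieval" string literals are replaced with NodeType members because the domain model now stores node_type as an enum. Assuming NodeType is a str-based enum (as its usage here suggests), both comparisons below hold, so the change costs nothing at runtime while letting the type checker catch typos:

from enum import StrEnum


class NodeType(StrEnum):  # stand-in for core.workflow.nodes.enums.NodeType
    LLM = "llm"
    KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"


node_type = NodeType.LLM
print(node_type == NodeType.LLM)  # True
print(node_type == "llm")         # True: str-enum members equal their string values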

View File

@@ -1,4 +1,3 @@
-import json
 import logging
 import os
 import uuid
@@ -7,7 +6,7 @@ from typing import Optional, cast
 from opik import Opik, Trace
 from opik.id_helpers import uuid4_to_uuid7
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import OpikConfig
@@ -23,8 +22,10 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.model import EndUser, MessageFile
+from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

 logger = logging.getLogger(__name__)
@@ -150,8 +151,29 @@ class OpikDataTrace(BaseTraceInstance):
         # through workflow_run_id get all_nodes_execution using repository
         session_factory = sessionmaker(bind=db.engine)
+        # Find the app's creator account
+        with Session(db.engine, expire_on_commit=False) as session:
+            # Get the app to find its creator
+            app_id = trace_info.metadata.get("app_id")
+            if not app_id:
+                raise ValueError("No app_id found in trace_info metadata")
+
+            app = session.query(App).filter(App.id == app_id).first()
+            if not app:
+                raise ValueError(f"App with id {app_id} not found")
+
+            if not app.created_by:
+                raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
+
+            service_account = session.query(Account).filter(Account.id == app.created_by).first()
+            if not service_account:
+                raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
+
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
+            session_factory=session_factory,
+            user=service_account,
+            app_id=trace_info.metadata.get("app_id"),
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )

         # Get all executions for this workflow run
@@ -161,26 +183,22 @@ class OpikDataTrace(BaseTraceInstance):
         for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
-            tenant_id = node_execution.tenant_id
-            app_id = node_execution.app_id
+            tenant_id = trace_info.tenant_id  # Use from trace_info instead
+            app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
             node_name = node_execution.title
             node_type = node_execution.node_type
             status = node_execution.status
-            if node_type == "llm":
-                inputs = (
-                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
-                )
+            if node_type == NodeType.LLM:
+                inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
             else:
-                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                inputs = node_execution.inputs if node_execution.inputs else {}
+            outputs = node_execution.outputs if node_execution.outputs else {}
             created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
-            execution_metadata = (
-                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
-            )
-            metadata = execution_metadata.copy()
+            execution_metadata = node_execution.metadata if node_execution.metadata else {}
+            metadata = {str(k): v for k, v in execution_metadata.items()}
             metadata.update(
                 {
                     "workflow_run_id": trace_info.workflow_run_id,
@@ -193,7 +211,7 @@ class OpikDataTrace(BaseTraceInstance):
                 }
             )

-            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            process_data = node_execution.process_data if node_execution.process_data else {}

             provider = None
             model = None
@@ -226,7 +244,7 @@ class OpikDataTrace(BaseTraceInstance):
                 parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id

             if not total_tokens:
-                total_tokens = execution_metadata.get("total_tokens", 0)
+                total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0

             span_data = {
                 "trace_id": opik_trace_id,

View File

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from typing import Any, Optional, Union

 from pydantic import BaseModel, Field, field_validator
@@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel):
 class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
     id: str = Field(..., description="ID of the trace")
     op: str = Field(..., description="Name of the operation")
-    inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
-    outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
+    inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace")
+    outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace")
     attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
         None, description="Metadata and attributes associated with trace"
     )

View File

@@ -1,4 +1,3 @@
-import json
 import logging
 import os
 import uuid
@@ -7,6 +6,7 @@ from typing import Any, Optional, cast
 import wandb
 import weave
+from sqlalchemy.orm import Session, sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import WeaveConfig
@@ -22,9 +22,11 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.model import EndUser, MessageFile
-from models.workflow import WorkflowNodeExecution
+from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

 logger = logging.getLogger(__name__)
@@ -128,58 +130,57 @@ class WeaveDataTrace(BaseTraceInstance):
             self.start_call(workflow_run, parent_run_id=trace_info.message_id)

-        # through workflow_run_id get all_nodes_execution
-        workflow_nodes_execution_id_records = (
-            db.session.query(WorkflowNodeExecution.id)
-            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
-            .all()
-        )
+        # through workflow_run_id get all_nodes_execution using repository
+        session_factory = sessionmaker(bind=db.engine)
+        # Find the app's creator account
+        with Session(db.engine, expire_on_commit=False) as session:
+            # Get the app to find its creator
+            app_id = trace_info.metadata.get("app_id")
+            if not app_id:
+                raise ValueError("No app_id found in trace_info metadata")
+
+            app = session.query(App).filter(App.id == app_id).first()
+            if not app:
+                raise ValueError(f"App with id {app_id} not found")
+
+            if not app.created_by:
+                raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
+
+            service_account = session.query(Account).filter(Account.id == app.created_by).first()
+            if not service_account:
+                raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
+
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory,
+            user=service_account,
+            app_id=trace_info.metadata.get("app_id"),
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+        )

-        for node_execution_id_record in workflow_nodes_execution_id_records:
-            node_execution = (
-                db.session.query(
-                    WorkflowNodeExecution.id,
-                    WorkflowNodeExecution.tenant_id,
-                    WorkflowNodeExecution.app_id,
-                    WorkflowNodeExecution.title,
-                    WorkflowNodeExecution.node_type,
-                    WorkflowNodeExecution.status,
-                    WorkflowNodeExecution.inputs,
-                    WorkflowNodeExecution.outputs,
-                    WorkflowNodeExecution.created_at,
-                    WorkflowNodeExecution.elapsed_time,
-                    WorkflowNodeExecution.process_data,
-                    WorkflowNodeExecution.execution_metadata,
-                )
-                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
-                .first()
-            )
-
-            if not node_execution:
-                continue
+        # Get all executions for this workflow run
+        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+            workflow_run_id=trace_info.workflow_run_id
+        )

+        for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
-            tenant_id = node_execution.tenant_id
-            app_id = node_execution.app_id
+            tenant_id = trace_info.tenant_id  # Use from trace_info instead
+            app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
             node_name = node_execution.title
             node_type = node_execution.node_type
             status = node_execution.status
-            if node_type == "llm":
-                inputs = (
-                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
-                )
+            if node_type == NodeType.LLM:
+                inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
             else:
-                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                inputs = node_execution.inputs if node_execution.inputs else {}
+            outputs = node_execution.outputs if node_execution.outputs else {}
             created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
-            execution_metadata = (
-                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
-            )
-            node_total_tokens = execution_metadata.get("total_tokens", 0)
-            attributes = execution_metadata.copy()
+            execution_metadata = node_execution.metadata if node_execution.metadata else {}
+            node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
+            attributes = {str(k): v for k, v in execution_metadata.items()}
             attributes.update(
                 {
                     "workflow_run_id": trace_info.workflow_run_id,
@@ -192,7 +193,7 @@ class WeaveDataTrace(BaseTraceInstance):
                 }
             )

-            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            process_data = node_execution.process_data if node_execution.process_data else {}
             if process_data and process_data.get("model_mode") == "chat":
                 attributes.update(
                     {

View File

@@ -19,7 +19,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_storage import storage
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
 from models.model import UploadFile

 logger = logging.getLogger(__name__)
@@ -116,7 +116,7 @@ class WordExtractor(BaseExtractor):
                     extension=str(image_ext),
                     mime_type=mime_type or "",
                     created_by=self.user_id,
-                    created_by_role=CreatedByRole.ACCOUNT,
+                    created_by_role=CreatorUserRole.ACCOUNT,
                     created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                     used=True,
                     used_by=self.user_id,

View File

@ -2,16 +2,29 @@
SQLAlchemy implementation of the WorkflowNodeExecutionRepository. SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
""" """
import json
import logging import logging
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional from typing import Optional, Union
from sqlalchemy import UnaryExpression, asc, delete, desc, select from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from core.workflow.entities.node_execution_entities import (
NodeExecution,
NodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,16 +36,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
This implementation supports multi-tenancy by filtering operations based on tenant_id. This implementation supports multi-tenancy by filtering operations based on tenant_id.
Each method creates its own session, handles the transaction, and commits changes Each method creates its own session, handles the transaction, and commits changes
to the database. This prevents long-running connections in the workflow core. to the database. This prevents long-running connections in the workflow core.
This implementation also includes an in-memory cache for node executions to improve
performance by reducing database queries.
""" """
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None): def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
):
""" """
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context. Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
Args: Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions session_factory: SQLAlchemy sessionmaker or engine for creating sessions
tenant_id: Tenant ID for multi-tenancy user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: Optional app ID for filtering by application app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
""" """
# If an engine is provided, create a sessionmaker from it # If an engine is provided, create a sessionmaker from it
if isinstance(session_factory, Engine): if isinstance(session_factory, Engine):
@ -44,38 +67,155 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
) )
# Extract tenant_id from user
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id self._tenant_id = tenant_id
# Store app context
self._app_id = app_id self._app_id = app_id
def save(self, execution: WorkflowNodeExecution) -> None: # Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize in-memory cache for node executions
# Key: node_execution_id, Value: NodeExecution
self._node_execution_cache: dict[str, NodeExecution] = {}
def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
""" """
Save a WorkflowNodeExecution instance and commit changes to the database. Convert a database model to a domain model.
Args: Args:
execution: The WorkflowNodeExecution instance to save db_model: The database model to convert
Returns:
The domain model
"""
# Parse JSON fields
inputs = db_model.inputs_dict
process_data = db_model.process_data_dict
outputs = db_model.outputs_dict
metadata = db_model.execution_metadata_dict
# Convert status to domain enum
status = NodeExecutionStatus(db_model.status)
return NodeExecution(
id=db_model.id,
node_execution_id=db_model.node_execution_id,
workflow_id=db_model.workflow_id,
workflow_run_id=db_model.workflow_run_id,
index=db_model.index,
predecessor_node_id=db_model.predecessor_node_id,
node_id=db_model.node_id,
node_type=NodeType(db_model.node_type),
title=db_model.title,
inputs=inputs,
process_data=process_data,
outputs=outputs,
status=status,
error=db_model.error,
elapsed_time=db_model.elapsed_time,
metadata=metadata,
created_at=db_model.created_at,
finished_at=db_model.finished_at,
)
def _to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
"""
Convert a domain model to a database model.
Args:
domain_model: The domain model to convert
Returns:
The database model
"""
# Use values from constructor if provided
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
db_model = WorkflowNodeExecution()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
if self._app_id is not None:
db_model.app_id = self._app_id
db_model.workflow_id = domain_model.workflow_id
db_model.triggered_from = self._triggered_from
db_model.workflow_run_id = domain_model.workflow_run_id
db_model.index = domain_model.index
db_model.predecessor_node_id = domain_model.predecessor_node_id
db_model.node_execution_id = domain_model.node_execution_id
db_model.node_id = domain_model.node_id
db_model.node_type = domain_model.node_type
db_model.title = domain_model.title
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
db_model.status = domain_model.status
db_model.error = domain_model.error
db_model.elapsed_time = domain_model.elapsed_time
db_model.execution_metadata = json.dumps(domain_model.metadata) if domain_model.metadata else None
db_model.created_at = domain_model.created_at
db_model.created_by_role = self._creator_user_role
db_model.created_by = self._creator_user_id
db_model.finished_at = domain_model.finished_at
return db_model
def save(self, execution: NodeExecution) -> None:
"""
Save or update a NodeExecution instance and commit changes to the database.
This method handles both creating new records and updating existing ones.
It determines whether to create or update based on whether the record
already exists in the database. It also updates the in-memory cache.
Args:
execution: The NodeExecution instance to save or update
""" """
with self._session_factory() as session: with self._session_factory() as session:
# Ensure tenant_id is set # Convert domain model to database model using instance attributes
if not execution.tenant_id: db_model = self._to_db_model(execution)
execution.tenant_id = self._tenant_id
# Set app_id if provided and not already set # Use merge which will handle both insert and update
if self._app_id and not execution.app_id: session.merge(db_model)
execution.app_id = self._app_id
session.add(execution)
session.commit() session.commit()
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: # Update the cache if node_execution_id is present
if execution.node_execution_id:
logger.debug(f"Updating cache for node_execution_id: {execution.node_execution_id}")
self._node_execution_cache[execution.node_execution_id] = execution
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
""" """
Retrieve a WorkflowNodeExecution by its node_execution_id. Retrieve a NodeExecution by its node_execution_id.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args: Args:
node_execution_id: The node execution ID node_execution_id: The node execution ID
Returns: Returns:
The WorkflowNodeExecution instance if found, None otherwise The NodeExecution instance if found, None otherwise
""" """
# First check the cache
if node_execution_id in self._node_execution_cache:
logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
return self._node_execution_cache[node_execution_id]
# If not in cache, query the database
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
with self._session_factory() as session: with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where( stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.node_execution_id == node_execution_id, WorkflowNodeExecution.node_execution_id == node_execution_id,
@ -85,15 +225,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if self._app_id: if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
return session.scalar(stmt) db_model = session.scalar(stmt)
if db_model:
# Convert to domain model
domain_model = self._to_domain_model(db_model)
# Add to cache
self._node_execution_cache[node_execution_id] = domain_model
return domain_model
return None
def get_by_workflow_run( def get_by_workflow_run(
self, self,
workflow_run_id: str, workflow_run_id: str,
order_config: Optional[OrderConfig] = None, order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]: ) -> Sequence[NodeExecution]:
""" """
Retrieve all WorkflowNodeExecution instances for a specific workflow run. Retrieve all NodeExecution instances for a specific workflow run.
This method always queries the database to ensure complete and ordered results,
but updates the cache with any retrieved executions.
Args: Args:
workflow_run_id: The workflow run ID workflow_run_id: The workflow run ID
@ -102,7 +255,42 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
order_config.order_direction: Direction to order ("asc" or "desc") order_config.order_direction: Direction to order ("asc" or "desc")
Returns: Returns:
A list of WorkflowNodeExecution instances A list of NodeExecution instances
"""
# Get the raw database models using the new method
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
# Convert database models to domain models and update cache
domain_models = []
for model in db_models:
domain_model = self._to_domain_model(model)
# Update cache if node_execution_id is present
if domain_model.node_execution_id:
self._node_execution_cache[domain_model.node_execution_id] = domain_model
domain_models.append(domain_model)
return domain_models
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
This method is similar to get_by_workflow_run but returns the raw database models
instead of converting them to domain models. This can be useful when direct access
to database model properties is needed.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of WorkflowNodeExecution database models
""" """
with self._session_factory() as session: with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where( stmt = select(WorkflowNodeExecution).where(
@ -129,17 +317,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if order_columns: if order_columns:
stmt = stmt.order_by(*order_columns) stmt = stmt.order_by(*order_columns)
return session.scalars(stmt).all() db_models = session.scalars(stmt).all()
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: # Note: We don't update the cache here since we're returning raw DB models
# and not converting to domain models
return db_models
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
""" """
Retrieve all running WorkflowNodeExecution instances for a specific workflow run. Retrieve all running NodeExecution instances for a specific workflow run.
This method queries the database directly and updates the cache with any
retrieved executions that have a node_execution_id.
Args: Args:
workflow_run_id: The workflow run ID workflow_run_id: The workflow run ID
Returns: Returns:
A list of running WorkflowNodeExecution instances A list of running NodeExecution instances
""" """
with self._session_factory() as session: with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where( stmt = select(WorkflowNodeExecution).where(
@ -152,26 +348,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if self._app_id: if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
return session.scalars(stmt).all() db_models = session.scalars(stmt).all()
domain_models = []
def update(self, execution: WorkflowNodeExecution) -> None: for model in db_models:
""" domain_model = self._to_domain_model(model)
Update an existing WorkflowNodeExecution instance and commit changes to the database. # Update cache if node_execution_id is present
if domain_model.node_execution_id:
self._node_execution_cache[domain_model.node_execution_id] = domain_model
domain_models.append(domain_model)
Args: return domain_models
execution: The WorkflowNodeExecution instance to update
"""
with self._session_factory() as session:
# Ensure tenant_id is set
if not execution.tenant_id:
execution.tenant_id = self._tenant_id
# Set app_id if provided and not already set
if self._app_id and not execution.app_id:
execution.app_id = self._app_id
session.merge(execution)
session.commit()
    def clear(self) -> None:
        """
@@ -179,6 +366,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
        This method deletes all WorkflowNodeExecution records that match the tenant_id
        and app_id (if provided) associated with this repository instance.
+        It also clears the in-memory cache.
        """
        with self._session_factory() as session:
            stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
@@ -194,3 +382,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
                f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
                + (f" and app {self._app_id}" if self._app_id else "")
            )
+            # Clear the in-memory cache
+            self._node_execution_cache.clear()
+            logger.info("Cleared in-memory node execution cache")
View File
@@ -32,7 +32,7 @@ from core.tools.errors import (
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
from models.model import Message, MessageFile
@@ -339,9 +339,9 @@ class ToolEngine:
                url=message.url,
                upload_file_id=tool_file_id,
                created_by_role=(
-                    CreatedByRole.ACCOUNT
+                    CreatorUserRole.ACCOUNT
                    if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
-                    else CreatedByRole.END_USER
+                    else CreatorUserRole.END_USER
                ),
                created_by=user_id,
            )
View File
@@ -0,0 +1,98 @@
"""
Domain entities for workflow node execution.
This module contains the domain model for workflow node execution, which is used
by the core workflow module. These models are independent of the storage mechanism
and don't contain implementation details like tenant_id, app_id, etc.
"""
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
class NodeExecutionStatus(StrEnum):
"""
Node Execution Status Enum.
"""
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
RETRY = "retry"
class NodeExecution(BaseModel):
"""
Domain model for workflow node execution.
This model represents the core business entity of a node execution,
without implementation details like tenant_id, app_id, etc.
Note: User/context-specific fields (triggered_from, created_by, created_by_role)
have been moved to the repository implementation to keep the domain model clean.
These fields are still accepted in the constructor for backward compatibility,
but they are not stored in the model.
"""
# Core identification fields
id: str # Unique identifier for this execution record
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
workflow_id: str # ID of the workflow this node belongs to
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization
predecessor_node_id: Optional[str] = None # ID of the node that executed before this one
node_id: str # ID of the node being executed
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
title: str # Display title of the node
# Execution data
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
# Execution state
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
error: Optional[str] = None # Error message if execution failed
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
# Additional metadata
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
# Timing information
created_at: datetime # When execution started
finished_at: Optional[datetime] = None # When execution completed
def update_from_mapping(
self,
inputs: Optional[Mapping[str, Any]] = None,
process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
) -> None:
"""
Update the model from mappings.
Args:
inputs: The inputs to update
process_data: The process data to update
outputs: The outputs to update
metadata: The metadata to update
"""
if inputs is not None:
self.inputs = dict(inputs)
if process_data is not None:
self.process_data = dict(process_data)
if outputs is not None:
self.outputs = dict(outputs)
if metadata is not None:
self.metadata = dict(metadata)
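
As a quick orientation for the new domain model, a hypothetical construction and update follow (field values are invented; the API is the one defined above, assuming the NodeType enum exposes an LLM member):

# Hypothetical usage; values are invented for illustration.
from datetime import UTC, datetime
from uuid import uuid4

execution = NodeExecution(
    id=str(uuid4()),
    workflow_id="wf-example",
    node_execution_id=str(uuid4()),
    index=1,
    node_id="llm-node-1",
    node_type=NodeType.LLM,
    title="Generate Answer",
    created_at=datetime.now(UTC).replace(tzinfo=None),
)

# Later, when the node finishes:
execution.status = NodeExecutionStatus.SUCCEEDED
execution.update_from_mapping(outputs={"text": "hello"})
execution.finished_at = datetime.now(UTC).replace(tzinfo=None)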
View File
@@ -2,12 +2,12 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol

-from models.workflow import WorkflowNodeExecution
+from core.workflow.entities.node_execution_entities import NodeExecution


@dataclass
class OrderConfig:
-    """Configuration for ordering WorkflowNodeExecution instances."""
+    """Configuration for ordering NodeExecution instances."""
    order_by: list[str]
    order_direction: Optional[Literal["asc", "desc"]] = None
@@ -15,10 +15,10 @@ class OrderConfig:
class WorkflowNodeExecutionRepository(Protocol):
    """
-    Repository interface for WorkflowNodeExecution.
+    Repository interface for NodeExecution.

    This interface defines the contract for accessing and manipulating
-    WorkflowNodeExecution data, regardless of the underlying storage mechanism.
+    NodeExecution data, regardless of the underlying storage mechanism.

    Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
    and trigger sources (triggered_from) should be handled at the implementation level, not in
@@ -26,24 +26,28 @@ class WorkflowNodeExecutionRepository(Protocol):
    application domains or deployment scenarios.
    """

-    def save(self, execution: WorkflowNodeExecution) -> None:
+    def save(self, execution: NodeExecution) -> None:
        """
-        Save a WorkflowNodeExecution instance.
+        Save or update a NodeExecution instance.
+
+        This method handles both creating new records and updating existing ones.
+        The implementation should determine whether to create or update based on
+        the execution's ID or other identifying fields.

        Args:
-            execution: The WorkflowNodeExecution instance to save
+            execution: The NodeExecution instance to save or update
        """
        ...

-    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
+    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
        """
-        Retrieve a WorkflowNodeExecution by its node_execution_id.
+        Retrieve a NodeExecution by its node_execution_id.

        Args:
            node_execution_id: The node execution ID

        Returns:
-            The WorkflowNodeExecution instance if found, None otherwise
+            The NodeExecution instance if found, None otherwise
        """
        ...
@@ -51,9 +55,9 @@ class WorkflowNodeExecutionRepository(Protocol):
        self,
        workflow_run_id: str,
        order_config: Optional[OrderConfig] = None,
-    ) -> Sequence[WorkflowNodeExecution]:
+    ) -> Sequence[NodeExecution]:
        """
-        Retrieve all WorkflowNodeExecution instances for a specific workflow run.
+        Retrieve all NodeExecution instances for a specific workflow run.

        Args:
            workflow_run_id: The workflow run ID
@@ -62,34 +66,25 @@ class WorkflowNodeExecutionRepository(Protocol):
            order_config.order_direction: Direction to order ("asc" or "desc")

        Returns:
-            A list of WorkflowNodeExecution instances
+            A list of NodeExecution instances
        """
        ...

-    def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
+    def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
        """
-        Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
+        Retrieve all running NodeExecution instances for a specific workflow run.

        Args:
            workflow_run_id: The workflow run ID

        Returns:
-            A list of running WorkflowNodeExecution instances
+            A list of running NodeExecution instances
        """
        ...

-    def update(self, execution: WorkflowNodeExecution) -> None:
-        """
-        Update an existing WorkflowNodeExecution instance.
-
-        Args:
-            execution: The WorkflowNodeExecution instance to update
-        """
-        ...

    def clear(self) -> None:
        """
-        Clear all WorkflowNodeExecution records based on implementation-specific criteria.
+        Clear all NodeExecution records based on implementation-specific criteria.

        This method is intended to be used for bulk deletion operations, such as removing
        all records associated with a specific app_id and tenant_id in multi-tenant implementations.
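
Since WorkflowNodeExecutionRepository is a typing.Protocol, any structurally compatible class satisfies it without inheriting from it. A hypothetical in-memory implementation, of the kind one might use in tests, could look like this (everything below is illustrative, not part of the commit):

# Hypothetical in-memory implementation satisfying the Protocol; for tests only.
from collections.abc import Sequence
from typing import Optional

from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.repository.workflow_node_execution_repository import OrderConfig

class InMemoryNodeExecutionRepository:
    def __init__(self) -> None:
        self._store: dict[str, NodeExecution] = {}

    def save(self, execution: NodeExecution) -> None:
        self._store[execution.id] = execution  # create or update in one call

    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
        return next((e for e in self._store.values() if e.node_execution_id == node_execution_id), None)

    def get_by_workflow_run(
        self, workflow_run_id: str, order_config: Optional[OrderConfig] = None
    ) -> Sequence[NodeExecution]:
        items = [e for e in self._store.values() if e.workflow_run_id == workflow_run_id]
        if order_config and order_config.order_by:
            reverse = order_config.order_direction == "desc"
            items.sort(key=lambda e: tuple(getattr(e, f) for f in order_config.order_by), reverse=reverse)
        return items

    def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
        return [
            e for e in self._store.values()
            if e.workflow_run_id == workflow_run_id and e.status == NodeExecutionStatus.RUNNING
        ]

    def clear(self) -> None:
        self._store.clear()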
View File
@@ -58,7 +58,7 @@ from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import (
    Workflow,
@@ -94,11 +94,11 @@ class WorkflowAppGenerateTaskPipeline:
        if isinstance(user, EndUser):
            self._user_id = user.id
            user_session_id = user.session_id
-            self._created_by_role = CreatedByRole.END_USER
+            self._created_by_role = CreatorUserRole.END_USER
        elif isinstance(user, Account):
            self._user_id = user.id
            user_session_id = user.id
-            self._created_by_role = CreatedByRole.ACCOUNT
+            self._created_by_role = CreatorUserRole.ACCOUNT
        else:
            raise ValueError(f"Invalid user type: {type(user)}")
View File
@@ -46,26 +46,28 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
from core.file import FILE_MODEL_IDENTITY, File
-from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.entities.node_execution_entities import (
+    NodeExecution,
+    NodeExecutionStatus,
+)
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
-from models.account import Account
-from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
-from models.model import EndUser
-from models.workflow import (
+from models import (
+    Account,
+    CreatorUserRole,
+    EndUser,
    Workflow,
-    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
-    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowRunStatus,
+    WorkflowRunTriggeredFrom,
)
@@ -78,7 +80,6 @@ class WorkflowCycleManager:
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
    ) -> None:
        self._workflow_run: WorkflowRun | None = None
-        self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
        self._application_generate_entity = application_generate_entity
        self._workflow_system_variables = workflow_system_variables
        self._workflow_node_execution_repository = workflow_node_execution_repository
@@ -89,7 +90,7 @@ class WorkflowCycleManager:
        session: Session,
        workflow_id: str,
        user_id: str,
-        created_by_role: CreatedByRole,
+        created_by_role: CreatorUserRole,
    ) -> WorkflowRun:
        workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
        workflow = session.scalar(workflow_stmt)
@@ -258,21 +259,22 @@ class WorkflowCycleManager:
        workflow_run.exceptions_count = exceptions_count

        # Use the instance repository to find running executions for a workflow run
-        running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
+        running_domain_executions = self._workflow_node_execution_repository.get_running_executions(
            workflow_run_id=workflow_run.id
        )

-        # Update the cache with the retrieved executions
-        for execution in running_workflow_node_executions:
-            if execution.node_execution_id:
-                self._workflow_node_executions[execution.node_execution_id] = execution
-
-        for workflow_node_execution in running_workflow_node_executions:
-            now = datetime.now(UTC).replace(tzinfo=None)
-            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
-            workflow_node_execution.error = error
-            workflow_node_execution.finished_at = now
-            workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
+        # Update the domain models
+        now = datetime.now(UTC).replace(tzinfo=None)
+        for domain_execution in running_domain_executions:
+            if domain_execution.node_execution_id:
+                # Update the domain model
+                domain_execution.status = NodeExecutionStatus.FAILED
+                domain_execution.error = error
+                domain_execution.finished_at = now
+                domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds()
+
+                # Update the repository with the domain model
+                self._workflow_node_execution_repository.save(domain_execution)

        if trace_manager:
            trace_manager.add_trace_task(
@@ -286,63 +288,67 @@ class WorkflowCycleManager:
        return workflow_run

-    def _handle_node_execution_start(
-        self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
-    ) -> WorkflowNodeExecution:
-        workflow_node_execution = WorkflowNodeExecution()
-        workflow_node_execution.id = str(uuid4())
-        workflow_node_execution.tenant_id = workflow_run.tenant_id
-        workflow_node_execution.app_id = workflow_run.app_id
-        workflow_node_execution.workflow_id = workflow_run.workflow_id
-        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
-        workflow_node_execution.workflow_run_id = workflow_run.id
-        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
-        workflow_node_execution.index = event.node_run_index
-        workflow_node_execution.node_execution_id = event.node_execution_id
-        workflow_node_execution.node_id = event.node_id
-        workflow_node_execution.node_type = event.node_type.value
-        workflow_node_execution.title = event.node_data.title
-        workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
-        workflow_node_execution.created_by_role = workflow_run.created_by_role
-        workflow_node_execution.created_by = workflow_run.created_by
-        workflow_node_execution.execution_metadata = json.dumps(
-            {
-                NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
-                NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
-                NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
-            }
-        )
-        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
+    def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution:
+        # Create a domain model
+        created_at = datetime.now(UTC).replace(tzinfo=None)
+        metadata = {
+            NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+            NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+            NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
+        }
+        domain_execution = NodeExecution(
+            id=str(uuid4()),
+            workflow_id=workflow_run.workflow_id,
+            workflow_run_id=workflow_run.id,
+            predecessor_node_id=event.predecessor_node_id,
+            index=event.node_run_index,
+            node_execution_id=event.node_execution_id,
+            node_id=event.node_id,
+            node_type=event.node_type,
+            title=event.node_data.title,
+            status=NodeExecutionStatus.RUNNING,
+            metadata=metadata,
+            created_at=created_at,
+        )

-        # Use the instance repository to save the workflow node execution
-        self._workflow_node_execution_repository.save(workflow_node_execution)
-        self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
+        # Use the instance repository to save the domain model
+        self._workflow_node_execution_repository.save(domain_execution)

-        return workflow_node_execution
+        return domain_execution
-    def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
-        workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
+    def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
+        # Get the domain model from repository
+        domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
+        if not domain_execution:
+            raise ValueError(f"Domain node execution not found: {event.node_execution_id}")

+        # Process data
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        process_data = WorkflowEntry.handle_special_values(event.process_data)
        outputs = WorkflowEntry.handle_special_values(event.outputs)
-        execution_metadata_dict = dict(event.execution_metadata or {})
-        execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
+
+        # Convert metadata keys to strings
+        execution_metadata_dict = {}
+        if event.execution_metadata:
+            for key, value in event.execution_metadata.items():
+                execution_metadata_dict[key] = value
+
        finished_at = datetime.now(UTC).replace(tzinfo=None)
        elapsed_time = (finished_at - event.start_at).total_seconds()
-        process_data = WorkflowEntry.handle_special_values(event.process_data)

-        workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
-        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
-        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = execution_metadata
-        workflow_node_execution.finished_at = finished_at
-        workflow_node_execution.elapsed_time = elapsed_time
+        # Update domain model
+        domain_execution.status = NodeExecutionStatus.SUCCEEDED
+        domain_execution.update_from_mapping(
+            inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
+        )
+        domain_execution.finished_at = finished_at
+        domain_execution.elapsed_time = elapsed_time

-        # Use the instance repository to update the workflow node execution
-        self._workflow_node_execution_repository.update(workflow_node_execution)
+        # Update the repository with the domain model
+        self._workflow_node_execution_repository.save(domain_execution)

-        return workflow_node_execution
+        return domain_execution
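
With update() removed from the protocol, save() is the single write path and must behave as an upsert. Judging from the removed update() body earlier in this diff, a SQLAlchemy-backed implementation can get that behavior from session.merge(); the following is a sketch under that assumption, not the commit's exact code (_to_db_model is a hypothetical converter, the inverse of _to_domain_model):

# Sketch (assumption): save()-as-upsert in a SQLAlchemy implementation.
def save(self, execution: NodeExecution) -> None:
    db_model = self._to_db_model(execution)  # hypothetical converter to the DB model
    with self._session_factory() as session:
        session.merge(db_model)  # INSERT for a new primary key, UPDATE otherwise
        session.commit()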
    def _handle_workflow_node_execution_failed(
        self,
@@ -351,43 +357,52 @@ class WorkflowCycleManager:
        | QueueNodeInIterationFailedEvent
        | QueueNodeInLoopFailedEvent
        | QueueNodeExceptionEvent,
-    ) -> WorkflowNodeExecution:
+    ) -> NodeExecution:
        """
        Workflow node execution failed
        :param event: queue node failed event
        :return:
        """
-        workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
+        # Get the domain model from repository
+        domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
+        if not domain_execution:
+            raise ValueError(f"Domain node execution not found: {event.node_execution_id}")

+        # Process data
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        process_data = WorkflowEntry.handle_special_values(event.process_data)
        outputs = WorkflowEntry.handle_special_values(event.outputs)
+
+        # Convert metadata keys to strings
+        execution_metadata_dict = {}
+        if event.execution_metadata:
+            for key, value in event.execution_metadata.items():
+                execution_metadata_dict[key] = value
+
        finished_at = datetime.now(UTC).replace(tzinfo=None)
        elapsed_time = (finished_at - event.start_at).total_seconds()
-        execution_metadata = (
-            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
-        )
-        process_data = WorkflowEntry.handle_special_values(event.process_data)
-        workflow_node_execution.status = (
-            WorkflowNodeExecutionStatus.FAILED.value
+
+        # Update domain model
+        domain_execution.status = (
+            NodeExecutionStatus.FAILED
            if not isinstance(event, QueueNodeExceptionEvent)
-            else WorkflowNodeExecutionStatus.EXCEPTION.value
+            else NodeExecutionStatus.EXCEPTION
        )
-        workflow_node_execution.error = event.error
-        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
-        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.finished_at = finished_at
-        workflow_node_execution.elapsed_time = elapsed_time
-        workflow_node_execution.execution_metadata = execution_metadata
-        self._workflow_node_execution_repository.update(workflow_node_execution)
+        domain_execution.error = event.error
+        domain_execution.update_from_mapping(
+            inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
+        )
+        domain_execution.finished_at = finished_at
+        domain_execution.elapsed_time = elapsed_time

-        return workflow_node_execution
+        # Update the repository with the domain model
+        self._workflow_node_execution_repository.save(domain_execution)
+
+        return domain_execution
    def _handle_workflow_node_execution_retried(
        self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
-    ) -> WorkflowNodeExecution:
+    ) -> NodeExecution:
        """
        Workflow node execution failed
        :param workflow_run: workflow run
@@ -399,47 +414,47 @@ class WorkflowCycleManager:
        elapsed_time = (finished_at - created_at).total_seconds()
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        outputs = WorkflowEntry.handle_special_values(event.outputs)
+
+        # Convert metadata keys to strings
        origin_metadata = {
            NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
            NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
            NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
        }
-        merged_metadata = (
-            {**jsonable_encoder(event.execution_metadata), **origin_metadata}
-            if event.execution_metadata is not None
-            else origin_metadata
-        )
-        execution_metadata = json.dumps(merged_metadata)

-        workflow_node_execution = WorkflowNodeExecution()
-        workflow_node_execution.id = str(uuid4())
-        workflow_node_execution.tenant_id = workflow_run.tenant_id
-        workflow_node_execution.app_id = workflow_run.app_id
-        workflow_node_execution.workflow_id = workflow_run.workflow_id
-        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
-        workflow_node_execution.workflow_run_id = workflow_run.id
-        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
-        workflow_node_execution.node_execution_id = event.node_execution_id
-        workflow_node_execution.node_id = event.node_id
-        workflow_node_execution.node_type = event.node_type.value
-        workflow_node_execution.title = event.node_data.title
-        workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
-        workflow_node_execution.created_by_role = workflow_run.created_by_role
-        workflow_node_execution.created_by = workflow_run.created_by
-        workflow_node_execution.created_at = created_at
-        workflow_node_execution.finished_at = finished_at
-        workflow_node_execution.elapsed_time = elapsed_time
-        workflow_node_execution.error = event.error
-        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = execution_metadata
-        workflow_node_execution.index = event.node_run_index
+        # Convert execution metadata keys to strings
+        execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {}
+        if event.execution_metadata:
+            for key, value in event.execution_metadata.items():
+                execution_metadata_dict[key] = value
+        merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
+
+        # Create a domain model
+        domain_execution = NodeExecution(
+            id=str(uuid4()),
+            workflow_id=workflow_run.workflow_id,
+            workflow_run_id=workflow_run.id,
+            predecessor_node_id=event.predecessor_node_id,
+            node_execution_id=event.node_execution_id,
+            node_id=event.node_id,
+            node_type=event.node_type,
+            title=event.node_data.title,
+            status=NodeExecutionStatus.RETRY,
+            created_at=created_at,
+            finished_at=finished_at,
+            elapsed_time=elapsed_time,
+            error=event.error,
+            index=event.node_run_index,
+        )
+
+        # Update with mappings
+        domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)

-        # Use the instance repository to save the workflow node execution
-        self._workflow_node_execution_repository.save(workflow_node_execution)
-        self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
+        # Use the instance repository to save the domain model
+        self._workflow_node_execution_repository.save(domain_execution)

-        return workflow_node_execution
+        return domain_execution
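
One detail worth calling out in the retry path: in {**execution_metadata_dict, **origin_metadata}, the right-hand mapping wins on key collisions, so the iteration, parallel-mode, and loop keys always reflect the retry event itself. A quick illustration with invented plain-string keys:

# Dict-merge precedence: later mappings override earlier ones on shared keys.
event_metadata = {"iteration_id": "stale", "total_tokens": 42}
origin_metadata = {"iteration_id": "it-7"}
merged = {**event_metadata, **origin_metadata}
assert merged == {"iteration_id": "it-7", "total_tokens": 42}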
    def _workflow_start_to_stream_response(
        self,
@@ -469,7 +484,7 @@ class WorkflowCycleManager:
        workflow_run: WorkflowRun,
    ) -> WorkflowFinishStreamResponse:
        created_by = None
-        if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
+        if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
            stmt = select(Account).where(Account.id == workflow_run.created_by)
            account = session.scalar(stmt)
            if account:
@@ -478,7 +493,7 @@ class WorkflowCycleManager:
                    "name": account.name,
                    "email": account.email,
                }
-        elif workflow_run.created_by_role == CreatedByRole.END_USER:
+        elif workflow_run.created_by_role == CreatorUserRole.END_USER:
            stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
            end_user = session.scalar(stmt)
            if end_user:
@@ -515,9 +530,9 @@ class WorkflowCycleManager:
        *,
        event: QueueNodeStartedEvent,
        task_id: str,
-        workflow_node_execution: WorkflowNodeExecution,
+        workflow_node_execution: NodeExecution,
    ) -> Optional[NodeStartStreamResponse]:
-        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
+        if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
            return None
        if not workflow_node_execution.workflow_run_id:
            return None
@@ -532,7 +547,7 @@ class WorkflowCycleManager:
            title=workflow_node_execution.title,
            index=workflow_node_execution.index,
            predecessor_node_id=workflow_node_execution.predecessor_node_id,
-            inputs=workflow_node_execution.inputs_dict,
+            inputs=workflow_node_execution.inputs,
            created_at=int(workflow_node_execution.created_at.timestamp()),
            parallel_id=event.parallel_id,
            parallel_start_node_id=event.parallel_start_node_id,
@@ -565,9 +580,9 @@ class WorkflowCycleManager:
        | QueueNodeInLoopFailedEvent
        | QueueNodeExceptionEvent,
        task_id: str,
-        workflow_node_execution: WorkflowNodeExecution,
+        workflow_node_execution: NodeExecution,
    ) -> Optional[NodeFinishStreamResponse]:
-        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
+        if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
            return None
        if not workflow_node_execution.workflow_run_id:
            return None
@@ -584,16 +599,16 @@ class WorkflowCycleManager:
            index=workflow_node_execution.index,
            title=workflow_node_execution.title,
            predecessor_node_id=workflow_node_execution.predecessor_node_id,
-            inputs=workflow_node_execution.inputs_dict,
-            process_data=workflow_node_execution.process_data_dict,
-            outputs=workflow_node_execution.outputs_dict,
+            inputs=workflow_node_execution.inputs,
+            process_data=workflow_node_execution.process_data,
+            outputs=workflow_node_execution.outputs,
            status=workflow_node_execution.status,
            error=workflow_node_execution.error,
            elapsed_time=workflow_node_execution.elapsed_time,
-            execution_metadata=workflow_node_execution.execution_metadata_dict,
+            execution_metadata=workflow_node_execution.metadata,
            created_at=int(workflow_node_execution.created_at.timestamp()),
            finished_at=int(workflow_node_execution.finished_at.timestamp()),
-            files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+            files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
            parallel_id=event.parallel_id,
            parallel_start_node_id=event.parallel_start_node_id,
            parent_parallel_id=event.parent_parallel_id,
@@ -608,9 +623,9 @@ class WorkflowCycleManager:
        *,
        event: QueueNodeRetryEvent,
        task_id: str,
-        workflow_node_execution: WorkflowNodeExecution,
+        workflow_node_execution: NodeExecution,
    ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
-        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
+        if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
            return None
        if not workflow_node_execution.workflow_run_id:
            return None
@@ -627,16 +642,16 @@ class WorkflowCycleManager:
            index=workflow_node_execution.index,
            title=workflow_node_execution.title,
            predecessor_node_id=workflow_node_execution.predecessor_node_id,
-            inputs=workflow_node_execution.inputs_dict,
-            process_data=workflow_node_execution.process_data_dict,
-            outputs=workflow_node_execution.outputs_dict,
+            inputs=workflow_node_execution.inputs,
+            process_data=workflow_node_execution.process_data,
+            outputs=workflow_node_execution.outputs,
            status=workflow_node_execution.status,
            error=workflow_node_execution.error,
            elapsed_time=workflow_node_execution.elapsed_time,
-            execution_metadata=workflow_node_execution.execution_metadata_dict,
+            execution_metadata=workflow_node_execution.metadata,
            created_at=int(workflow_node_execution.created_at.timestamp()),
            finished_at=int(workflow_node_execution.finished_at.timestamp()),
-            files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+            files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
            parallel_id=event.parallel_id,
            parallel_start_node_id=event.parallel_start_node_id,
            parent_parallel_id=event.parent_parallel_id,
@@ -908,23 +923,6 @@ class WorkflowCycleManager:
        return workflow_run

-    def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
-        # First check the cache for performance
-        if node_execution_id in self._workflow_node_executions:
-            cached_execution = self._workflow_node_executions[node_execution_id]
-            # No need to merge with session since expire_on_commit=False
-            return cached_execution
-
-        # If not in cache, use the instance repository to get by node_execution_id
-        execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
-        if not execution:
-            raise ValueError(f"Workflow node execution not found: {node_execution_id}")
-
-        # Update cache
-        self._workflow_node_executions[node_execution_id] = execution
-        return execution

    def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
        """
        Handle agent log
View File
@@ -27,7 +27,7 @@ from .dataset import (
    Whitelist,
)
from .engine import db
-from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
+from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
    ApiRequest,
    ApiToken,
@@ -112,7 +112,7 @@ __all__ = [
    "CeleryTaskSet",
    "Conversation",
    "ConversationVariable",
-    "CreatedByRole",
+    "CreatorUserRole",
    "DataSourceApiKeyAuthBinding",
    "DataSourceOauthBinding",
    "Dataset",
View File
@@ -1,7 +1,7 @@
from enum import StrEnum


-class CreatedByRole(StrEnum):
+class CreatorUserRole(StrEnum):
    ACCOUNT = "account"
    END_USER = "end_user"
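
Because CreatorUserRole is a StrEnum, its members compare equal to their plain string values, which is why the rename is safe for code that stores the role in string database columns. A quick illustration:

# StrEnum members behave as plain strings (Python 3.11+).
from enum import StrEnum

class CreatorUserRole(StrEnum):
    ACCOUNT = "account"
    END_USER = "end_user"

assert CreatorUserRole.ACCOUNT == "account"  # string equality holds
assert CreatorUserRole("end_user") is CreatorUserRole.END_USER  # round-trip from a stored DB value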
View File
@@ -29,7 +29,7 @@ from libs.helper import generate_string
from .account import Account, Tenant
from .base import Base
from .engine import db
-from .enums import CreatedByRole
+from .enums import CreatorUserRole
from .types import StringUUID
from .workflow import WorkflowRunStatus
@@ -1270,7 +1270,7 @@ class MessageFile(Base):
        url: str | None = None,
        belongs_to: Literal["user", "assistant"] | None = None,
        upload_file_id: str | None = None,
-        created_by_role: CreatedByRole,
+        created_by_role: CreatorUserRole,
        created_by: str,
    ):
        self.message_id = message_id
@@ -1417,7 +1417,7 @@ class EndUser(Base, UserMixin):
    )
    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
+    tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
    app_id = db.Column(StringUUID, nullable=True)
    type = db.Column(db.String(255), nullable=False)
    external_user_id = db.Column(db.String(255), nullable=True)
@@ -1547,7 +1547,7 @@ class UploadFile(Base):
        size: int,
        extension: str,
        mime_type: str,
-        created_by_role: CreatedByRole,
+        created_by_role: CreatorUserRole,
        created_by: str,
        created_at: datetime,
        used: bool,
View File
@@ -22,7 +22,7 @@ from libs import helper
from .account import Account
from .base import Base
from .engine import db
-from .enums import CreatedByRole
+from .enums import CreatorUserRole
from .types import StringUUID

if TYPE_CHECKING:
@@ -429,15 +429,15 @@ class WorkflowRun(Base):
    @property
    def created_by_account(self):
-        created_by_role = CreatedByRole(self.created_by_role)
-        return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
+        created_by_role = CreatorUserRole(self.created_by_role)
+        return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None

    @property
    def created_by_end_user(self):
        from models.model import EndUser

-        created_by_role = CreatedByRole(self.created_by_role)
-        return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
+        created_by_role = CreatorUserRole(self.created_by_role)
+        return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None

    @property
    def graph_dict(self):
@@ -634,17 +634,17 @@ class WorkflowNodeExecution(Base):
    @property
    def created_by_account(self):
-        created_by_role = CreatedByRole(self.created_by_role)
+        created_by_role = CreatorUserRole(self.created_by_role)
        # TODO(-LAN-): Avoid using db.session.get() here.
-        return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
+        return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None

    @property
    def created_by_end_user(self):
        from models.model import EndUser

-        created_by_role = CreatedByRole(self.created_by_role)
+        created_by_role = CreatorUserRole(self.created_by_role)
        # TODO(-LAN-): Avoid using db.session.get() here.
-        return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
+        return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None

    @property
    def inputs_dict(self):
@@ -755,15 +755,15 @@ class WorkflowAppLog(Base):
    @property
    def created_by_account(self):
-        created_by_role = CreatedByRole(self.created_by_role)
-        return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
+        created_by_role = CreatorUserRole(self.created_by_role)
+        return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None

    @property
    def created_by_end_user(self):
        from models.model import EndUser

-        created_by_role = CreatedByRole(self.created_by_role)
-        return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
+        created_by_role = CreatorUserRole(self.created_by_role)
+        return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None


class ConversationVariable(Base):
View File
@@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Account
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile

from .errors.file import FileTooLargeError, UnsupportedFileTypeError
@@ -81,7 +81,7 @@ class FileService:
            size=file_size,
            extension=extension,
            mime_type=mimetype,
-            created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER),
+            created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
            created_by=user.id,
            created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            used=False,
@@ -133,7 +133,7 @@ class FileService:
            extension="txt",
            mime_type="text/plain",
            created_by=current_user.id,
-            created_by_role=CreatedByRole.ACCOUNT,
+            created_by_role=CreatorUserRole.ACCOUNT,
            created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            used=True,
            used_by=current_user.id,
View File
@@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session

from models import App, EndUser, WorkflowAppLog, WorkflowRun
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
from models.workflow import WorkflowRunStatus
@@ -58,7 +58,7 @@ class WorkflowAppService:
            stmt = stmt.outerjoin(
                EndUser,
-                and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER),
+                and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER),
            ).where(or_(*keyword_conditions))

        if status:
View File
@@ -1,4 +1,5 @@
import threading
+from collections.abc import Sequence
from typing import Optional

import contexts
@@ -6,11 +7,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
-from models.enums import WorkflowRunTriggeredFrom
-from models.model import App
-from models.workflow import (
+from models import (
+    Account,
+    App,
+    EndUser,
    WorkflowNodeExecution,
    WorkflowRun,
+    WorkflowRunTriggeredFrom,
)
@@ -116,7 +119,12 @@ class WorkflowRunService:
        return workflow_run

-    def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]:
+    def get_workflow_run_node_executions(
+        self,
+        app_model: App,
+        run_id: str,
+        user: Account | EndUser,
+    ) -> Sequence[WorkflowNodeExecution]:
        """
        Get workflow run node execution list
        """
@@ -128,13 +136,15 @@ class WorkflowRunService:
        if not workflow_run:
            return []

-        # Use the repository to get the node executions
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
+            session_factory=db.engine,
+            user=user,
+            app_id=app_model.id,
+            triggered_from=None,
        )

        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
-        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
+        node_executions = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order_config)

-        return list(node_executions)
+        return node_executions
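
For reference, the new constructor call shape, as used by this service (a hypothetical caller; current_account is an invented name, and deriving the tenant from the user is an assumption implied by the removal of the tenant_id parameter):

# Hypothetical caller of the new repository signature.
repository = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=db.engine,  # an Engine or sessionmaker, as in this diff
    user=current_account,  # Account | EndUser; tenant context presumably derived from the user
    app_id=app_model.id,
    triggered_from=None,  # None: don't scope queries to a trigger source
)
order_config = OrderConfig(order_by=["index"], order_direction="desc")
executions = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order_config)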
View File
@@ -26,7 +26,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from models.account import Account
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import (
@@ -284,9 +284,11 @@ class WorkflowService:
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = draft_workflow.id

-        # Use the repository to save the workflow node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
+            session_factory=db.engine,
+            user=account,
+            app_id=app_model.id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        repository.save(workflow_node_execution)
@@ -390,7 +392,7 @@ class WorkflowService:
        workflow_node_execution.node_type = node_instance.node_type
        workflow_node_execution.title = node_instance.node_data.title
        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
-        workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
+        workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value
        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
        workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
        if run_succeeded and node_run_result:
View File
@@ -4,16 +4,19 @@ from collections.abc import Callable

import click
from celery import shared_task  # type: ignore
-from sqlalchemy import delete
+from sqlalchemy import delete, select
from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import Session

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
-from models.dataset import AppDatasetJoin
-from models.model import (
+from models import (
+    Account,
    ApiToken,
+    App,
    AppAnnotationHitHistory,
    AppAnnotationSetting,
+    AppDatasetJoin,
    AppModelConfig,
    Conversation,
    EndUser,
@@ -188,9 +191,24 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):

def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
+    # Get app's owner
+    with Session(db.engine, expire_on_commit=False) as session:
+        stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id)
+        user = session.scalar(stmt)
+    if user is None:
+        errmsg = (
+            f"Failed to delete workflow node executions for tenant {tenant_id} and app {app_id}, app's owner not found"
+        )
+        logging.error(errmsg)
+        raise ValueError(errmsg)
+
    # Create a repository instance for WorkflowNodeExecution
    repository = SQLAlchemyWorkflowNodeExecutionRepository(
-        session_factory=db.engine, tenant_id=tenant_id, app_id=app_id
+        session_factory=db.engine,
+        user=user,
+        app_id=app_id,
+        triggered_from=None,
    )

    # Use the clear method to delete all records for this tenant_id and app_id
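
A small note on the owner lookup above: the two WHERE clauses relate Account and App without a JOIN keyword, which SQLAlchemy renders as an implicit join in the FROM clause. An equivalent, more explicit formulation (a sketch, not part of the commit) would be:

# Equivalent owner lookup with an explicit join; same rows, clearer intent.
from sqlalchemy import select

stmt = (
    select(Account)
    .join(App, App.owner_id == Account.id)  # relate the two tables explicitly
    .where(App.id == app_id)
)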
View File
@ -16,10 +16,9 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from models.enums import CreatedByRole from models.enums import CreatorUserRole
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus, WorkflowNodeExecutionStatus,
WorkflowRun, WorkflowRun,
WorkflowRunStatus, WorkflowRunStatus,
@ -94,7 +93,7 @@ def mock_workflow_run():
workflow_run.app_id = "test-app-id" workflow_run.app_id = "test-app-id"
workflow_run.workflow_id = "test-workflow-id" workflow_run.workflow_id = "test-workflow-id"
workflow_run.status = WorkflowRunStatus.RUNNING workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.created_by_role = CreatedByRole.ACCOUNT workflow_run.created_by_role = CreatorUserRole.ACCOUNT
workflow_run.created_by = "test-user-id" workflow_run.created_by = "test-user-id"
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.inputs_dict = {"query": "test query"} workflow_run.inputs_dict = {"query": "test query"}
@ -107,7 +106,6 @@ def test_init(
): ):
"""Test initialization of WorkflowCycleManager""" """Test initialization of WorkflowCycleManager"""
assert workflow_cycle_manager._workflow_run is None assert workflow_cycle_manager._workflow_run is None
assert workflow_cycle_manager._workflow_node_executions == {}
assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity
assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables
assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
@ -123,7 +121,7 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo
session=mock_session, session=mock_session,
workflow_id="test-workflow-id", workflow_id="test-workflow-id",
user_id="test-user-id", user_id="test-user-id",
created_by_role=CreatedByRole.ACCOUNT, created_by_role=CreatorUserRole.ACCOUNT,
) )
# Verify the result # Verify the result
@ -132,7 +130,7 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo
assert workflow_run.workflow_id == mock_workflow.id assert workflow_run.workflow_id == mock_workflow.id
assert workflow_run.sequence_number == 6 # max_sequence + 1 assert workflow_run.sequence_number == 6 # max_sequence + 1
assert workflow_run.status == WorkflowRunStatus.RUNNING assert workflow_run.status == WorkflowRunStatus.RUNNING
assert workflow_run.created_by_role == CreatedByRole.ACCOUNT assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT
assert workflow_run.created_by == "test-user-id" assert workflow_run.created_by == "test-user-id"
# Verify session.add was called # Verify session.add was called
@@ -215,24 +213,23 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run):
     )

     # Verify the result
-    assert result.tenant_id == mock_workflow_run.tenant_id
-    assert result.app_id == mock_workflow_run.app_id
+    # NodeExecution doesn't have tenant_id attribute, it's handled at repository level
+    # assert result.tenant_id == mock_workflow_run.tenant_id
+    # assert result.app_id == mock_workflow_run.app_id
     assert result.workflow_id == mock_workflow_run.workflow_id
     assert result.workflow_run_id == mock_workflow_run.id
     assert result.node_execution_id == event.node_execution_id
     assert result.node_id == event.node_id
-    assert result.node_type == event.node_type.value
+    assert result.node_type == event.node_type
     assert result.title == event.node_data.title
     assert result.status == WorkflowNodeExecutionStatus.RUNNING.value
-    assert result.created_by_role == mock_workflow_run.created_by_role
-    assert result.created_by == mock_workflow_run.created_by
+    # NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level
+    # assert result.created_by_role == mock_workflow_run.created_by_role
+    # assert result.created_by == mock_workflow_run.created_by

     # Verify save was called
     workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
-
-    # Verify the node execution was added to the cache
-    assert workflow_cycle_manager._workflow_node_executions[event.node_execution_id] == result


 def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run):
     """Test _get_workflow_run method"""
@@ -261,12 +258,13 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
     event.execution_metadata = {"metadata": "test metadata"}
     event.start_at = datetime.now(UTC).replace(tzinfo=None)

-    # Create a mock workflow node execution
-    node_execution = MagicMock(spec=WorkflowNodeExecution)
+    # Create a mock node execution
+    node_execution = MagicMock()
     node_execution.node_execution_id = "test-node-execution-id"

-    # Mock _get_workflow_node_execution to return the mock node execution
-    with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution):
-        # Call the method
-        result = workflow_cycle_manager._handle_workflow_node_execution_success(
-            event=event,
+    # Mock the repository to return the node execution
+    workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
+
+    # Call the method
+    result = workflow_cycle_manager._handle_workflow_node_execution_success(
+        event=event,
@@ -275,14 +273,9 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
     # Verify the result
     assert result == node_execution
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value
-    assert result.inputs == json.dumps(event.inputs)
-    assert result.process_data == json.dumps(event.process_data)
-    assert result.outputs == json.dumps(event.outputs)
-    assert result.finished_at is not None
-    assert result.elapsed_time is not None

-    # Verify update was called
-    workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution)
+    # Verify save was called
+    workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)


 def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run):
@@ -322,12 +315,13 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
     event.start_at = datetime.now(UTC).replace(tzinfo=None)
     event.error = "Test error message"

-    # Create a mock workflow node execution
-    node_execution = MagicMock(spec=WorkflowNodeExecution)
+    # Create a mock node execution
+    node_execution = MagicMock()
     node_execution.node_execution_id = "test-node-execution-id"

-    # Mock _get_workflow_node_execution to return the mock node execution
-    with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution):
-        # Call the method
-        result = workflow_cycle_manager._handle_workflow_node_execution_failed(
-            event=event,
+    # Mock the repository to return the node execution
+    workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
+
+    # Call the method
+    result = workflow_cycle_manager._handle_workflow_node_execution_failed(
+        event=event,
@@ -337,12 +331,6 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
     assert result == node_execution
     assert result.status == WorkflowNodeExecutionStatus.FAILED.value
     assert result.error == "Test error message"
-    assert result.inputs == json.dumps(event.inputs)
-    assert result.process_data == json.dumps(event.process_data)
-    assert result.outputs == json.dumps(event.outputs)
-    assert result.finished_at is not None
-    assert result.elapsed_time is not None
-    assert result.execution_metadata == json.dumps(event.execution_metadata)

-    # Verify update was called
-    workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution)
+    # Verify save was called
+    workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
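Success and failure handlers now share one shape: fetch the domain execution through the repository (no per-manager cache), mutate its status, and persist with save(). A hedged sketch of that flow, with simplified signature and field names rather than the actual WorkflowCycleManager code:

def handle_node_execution_finished(manager, event, status, error=None):
    # Look the execution up via the repository instead of a local dict cache.
    execution = manager._workflow_node_execution_repository.get_by_node_execution_id(
        event.node_execution_id
    )
    execution.status = status
    if error is not None:
        execution.error = error
    # A single persistence entry point: save() replaces the old update().
    manager._workflow_node_execution_repository.save(execution)
    return execution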

View File

@@ -2,15 +2,36 @@
 Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
 """

-from unittest.mock import MagicMock
+import json
+from datetime import datetime
+from unittest.mock import MagicMock, PropertyMock

 import pytest
 from pytest_mock import MockerFixture
 from sqlalchemy.orm import Session, sessionmaker

 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
+from core.workflow.nodes.enums import NodeType
 from core.workflow.repository.workflow_node_execution_repository import OrderConfig
-from models.workflow import WorkflowNodeExecution
+from models.account import Account, Tenant
+from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
+
+
+def configure_mock_execution(mock_execution):
+    """Configure a mock execution with proper JSON serializable values."""
+    # Configure inputs, outputs, process_data, and execution_metadata to return JSON serializable values
+    type(mock_execution).inputs = PropertyMock(return_value='{"key": "value"}')
+    type(mock_execution).outputs = PropertyMock(return_value='{"result": "success"}')
+    type(mock_execution).process_data = PropertyMock(return_value='{"process": "data"}')
+    type(mock_execution).execution_metadata = PropertyMock(return_value='{"metadata": "info"}')
+
+    # Configure status and triggered_from to be valid enum values
+    mock_execution.status = "running"
+    mock_execution.triggered_from = "workflow-run"
+
+    return mock_execution


 @pytest.fixture
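configure_mock_execution attaches PropertyMock to the mock's type because Python looks properties up on the class; assigning to the instance would only set a plain attribute. A tiny self-contained illustration:

from unittest.mock import MagicMock, PropertyMock

m = MagicMock()
# Patch on the type so attribute access goes through the property protocol.
type(m).inputs = PropertyMock(return_value='{"key": "value"}')
assert m.inputs == '{"key": "value"}'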
@@ -28,13 +49,30 @@ def session():
 @pytest.fixture
-def repository(session):
+def mock_user():
+    """Create a user instance for testing."""
+    user = Account()
+    user.id = "test-user-id"
+    tenant = Tenant()
+    tenant.id = "test-tenant"
+    tenant.name = "Test Workspace"
+    user._current_tenant = MagicMock()
+    user._current_tenant.id = "test-tenant"
+    return user
+
+
+@pytest.fixture
+def repository(session, mock_user):
     """Create a repository instance with test data."""
     _, session_factory = session
-    tenant_id = "test-tenant"
     app_id = "test-app"
     return SQLAlchemyWorkflowNodeExecutionRepository(
-        session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
+        session_factory=session_factory,
+        user=mock_user,
+        app_id=app_id,
+        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
     )
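The fixture reflects the constructor change: the repository now receives the acting user plus a triggered_from value instead of a bare tenant_id, and derives its tenant and creator context itself. A rough sketch of that derivation, consistent with the private attributes asserted in test_to_db_model below; the attribute lookups on Account/EndUser are assumptions, not the actual implementation:

class RepositorySketch:
    def __init__(self, session_factory, user, app_id, triggered_from):
        self._session_factory = session_factory
        self._app_id = app_id
        self._triggered_from = triggered_from
        # Derive tenant and creator context from the user object (assumed attributes).
        self._tenant_id = getattr(user, "current_tenant_id", None) or getattr(user, "tenant_id", None)
        self._creator_user_id = user.id
        # Accounts map to the "account" creator role, end users to "end_user" (assumed).
        self._creator_user_role = "account" if type(user).__name__ == "Account" else "end_user"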
@@ -45,16 +83,23 @@ def test_save(repository, session):
     execution = MagicMock(spec=WorkflowNodeExecution)
     execution.tenant_id = None
     execution.app_id = None
+    execution.inputs = None
+    execution.process_data = None
+    execution.outputs = None
+    execution.metadata = None
+
+    # Mock the _to_db_model method to return the execution itself
+    # This simulates the behavior of setting tenant_id and app_id
+    repository._to_db_model = MagicMock(return_value=execution)

     # Call save method
     repository.save(execution)

-    # Assert tenant_id and app_id are set
-    assert execution.tenant_id == repository._tenant_id
-    assert execution.app_id == repository._app_id
+    # Assert _to_db_model was called with the execution
+    repository._to_db_model.assert_called_once_with(execution)

-    # Assert session.add was called
-    session_obj.add.assert_called_once_with(execution)
+    # Assert session.merge was called (now using merge for both save and update)
+    session_obj.merge.assert_called_once_with(execution)
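The rewritten test pins the new save() contract: convert the domain model via _to_db_model, then hand the result to session.merge(), which covers both insert and update. A minimal sketch of a save() consistent with these assertions; the session/commit handling here is an assumption:

def save(self, execution):
    # Conversion stamps tenant_id/app_id/creator fields onto the DB model.
    db_model = self._to_db_model(execution)
    with self._session_factory() as session:
        # merge() inserts when the primary key is new, updates otherwise.
        session.merge(db_model)
        session.commit()  # commit placement is an assumption in this sketch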
 def test_save_with_existing_tenant_id(repository, session):
@@ -64,16 +109,27 @@ def test_save_with_existing_tenant_id(repository, session):
     execution = MagicMock(spec=WorkflowNodeExecution)
     execution.tenant_id = "existing-tenant"
     execution.app_id = None
+    execution.inputs = None
+    execution.process_data = None
+    execution.outputs = None
+    execution.metadata = None
+
+    # Create a modified execution that will be returned by _to_db_model
+    modified_execution = MagicMock(spec=WorkflowNodeExecution)
+    modified_execution.tenant_id = "existing-tenant"  # Tenant ID should not change
+    modified_execution.app_id = repository._app_id  # App ID should be set
+
+    # Mock the _to_db_model method to return the modified execution
+    repository._to_db_model = MagicMock(return_value=modified_execution)

     # Call save method
     repository.save(execution)

-    # Assert tenant_id is not changed and app_id is set
-    assert execution.tenant_id == "existing-tenant"
-    assert execution.app_id == repository._app_id
+    # Assert _to_db_model was called with the execution
+    repository._to_db_model.assert_called_once_with(execution)

-    # Assert session.add was called
-    session_obj.add.assert_called_once_with(execution)
+    # Assert session.merge was called with the modified execution (now using merge for both save and update)
+    session_obj.merge.assert_called_once_with(modified_execution)
 def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
@@ -84,7 +140,16 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
     mock_stmt = mocker.MagicMock()
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
-    session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
+
+    # Create a properly configured mock execution
+    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
+    configure_mock_execution(mock_execution)
+    session_obj.scalar.return_value = mock_execution
+
+    # Create a mock domain model to be returned by _to_domain_model
+    mock_domain_model = mocker.MagicMock()
+    # Mock the _to_domain_model method to return our mock domain model
+    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)

     # Call method
     result = repository.get_by_node_execution_id("test-node-execution-id")
@@ -92,7 +157,10 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
     # Assert select was called with correct parameters
     mock_select.assert_called_once()
     session_obj.scalar.assert_called_once_with(mock_stmt)
-    assert result is not None
+    # Assert _to_domain_model was called with the mock execution
+    repository._to_domain_model.assert_called_once_with(mock_execution)
+    # Assert the result is our mock domain model
+    assert result is mock_domain_model
 def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
@@ -104,7 +172,16 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
     mock_stmt.order_by.return_value = mock_stmt
-    session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
+
+    # Create a properly configured mock execution
+    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
+    configure_mock_execution(mock_execution)
+    session_obj.scalars.return_value.all.return_value = [mock_execution]
+
+    # Create a mock domain model to be returned by _to_domain_model
+    mock_domain_model = mocker.MagicMock()
+    # Mock the _to_domain_model method to return our mock domain model
+    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)

     # Call method
     order_config = OrderConfig(order_by=["index"], order_direction="desc")
@@ -113,7 +190,45 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     # Assert select was called with correct parameters
     mock_select.assert_called_once()
     session_obj.scalars.assert_called_once_with(mock_stmt)
+    # Assert _to_domain_model was called with the mock execution
+    repository._to_domain_model.assert_called_once_with(mock_execution)
+    # Assert the result contains our mock domain model
     assert len(result) == 1
+    assert result[0] is mock_domain_model
+
+
+def test_get_db_models_by_workflow_run(repository, session, mocker: MockerFixture):
+    """Test get_db_models_by_workflow_run method."""
+    session_obj, _ = session
+    # Set up mock
+    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
+    mock_stmt = mocker.MagicMock()
+    mock_select.return_value = mock_stmt
+    mock_stmt.where.return_value = mock_stmt
+    mock_stmt.order_by.return_value = mock_stmt
+
+    # Create a properly configured mock execution
+    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
+    configure_mock_execution(mock_execution)
+    session_obj.scalars.return_value.all.return_value = [mock_execution]
+
+    # Mock the _to_domain_model method
+    to_domain_model_mock = mocker.patch.object(repository, "_to_domain_model")
+
+    # Call method
+    order_config = OrderConfig(order_by=["index"], order_direction="desc")
+    result = repository.get_db_models_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
+
+    # Assert select was called with correct parameters
+    mock_select.assert_called_once()
+    session_obj.scalars.assert_called_once_with(mock_stmt)
+
+    # Assert the result contains our mock db model directly (without conversion to domain model)
+    assert len(result) == 1
+    assert result[0] is mock_execution
+
+    # Verify that _to_domain_model was NOT called (since we're returning raw DB models)
+    to_domain_model_mock.assert_not_called()
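This new test fixes the asymmetry between the two accessors: get_by_workflow_run returns converted domain models, while get_db_models_by_workflow_run deliberately skips _to_domain_model so callers can still reach DB-only fields such as tenant_id. A small usage sketch under those assumptions; fetch_both_views is a hypothetical helper, not part of the repository:

from core.workflow.repository.workflow_node_execution_repository import OrderConfig


def fetch_both_views(repository, run_id):
    order = OrderConfig(order_by=["index"], order_direction="desc")
    # Domain view: rows converted through _to_domain_model.
    domain_rows = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order)
    # Raw DB view: unconverted WorkflowNodeExecution rows.
    db_rows = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order)
    return domain_rows, db_rows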
 def test_get_running_executions(repository, session, mocker: MockerFixture):
@@ -124,7 +239,16 @@ def test_get_running_executions(repository, session, mocker: MockerFixture):
     mock_stmt = mocker.MagicMock()
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
-    session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
+
+    # Create a properly configured mock execution
+    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
+    configure_mock_execution(mock_execution)
+    session_obj.scalars.return_value.all.return_value = [mock_execution]
+
+    # Create a mock domain model to be returned by _to_domain_model
+    mock_domain_model = mocker.MagicMock()
+    # Mock the _to_domain_model method to return our mock domain model
+    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)

     # Call method
     result = repository.get_running_executions("test-workflow-run-id")
@@ -132,25 +256,36 @@ def test_get_running_executions(repository, session, mocker: MockerFixture):
     # Assert select was called with correct parameters
     mock_select.assert_called_once()
     session_obj.scalars.assert_called_once_with(mock_stmt)
+    # Assert _to_domain_model was called with the mock execution
+    repository._to_domain_model.assert_called_once_with(mock_execution)
+    # Assert the result contains our mock domain model
     assert len(result) == 1
+    assert result[0] is mock_domain_model

-def test_update(repository, session):
-    """Test update method."""
+def test_update_via_save(repository, session):
+    """Test updating an existing record via save method."""
     session_obj, _ = session
     # Create a mock execution
     execution = MagicMock(spec=WorkflowNodeExecution)
     execution.tenant_id = None
     execution.app_id = None
+    execution.inputs = None
+    execution.process_data = None
+    execution.outputs = None
+    execution.metadata = None

-    # Call update method
-    repository.update(execution)
+    # Mock the _to_db_model method to return the execution itself
+    # This simulates the behavior of setting tenant_id and app_id
+    repository._to_db_model = MagicMock(return_value=execution)

-    # Assert tenant_id and app_id are set
-    assert execution.tenant_id == repository._tenant_id
-    assert execution.app_id == repository._app_id
+    # Call save method to update an existing record
+    repository.save(execution)
+
+    # Assert _to_db_model was called with the execution
+    repository._to_db_model.assert_called_once_with(execution)

-    # Assert session.merge was called
+    # Assert session.merge was called (for updates)
     session_obj.merge.assert_called_once_with(execution)
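With update() gone, updates ride the same save() path; what makes that safe is session.merge(), which keys on the primary key to choose between INSERT and UPDATE. A self-contained illustration of merge semantics with plain SQLAlchemy and an in-memory SQLite database (toy model, not the Dify schema):

from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Row(Base):
    __tablename__ = "rows"
    id = Column(String, primary_key=True)
    status = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.merge(Row(id="1", status="running"))    # no row yet -> INSERT
    session.merge(Row(id="1", status="succeeded"))  # row exists -> UPDATE
    session.commit()
    assert session.get(Row, "1").status == "succeeded"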
@@ -176,3 +311,118 @@ def test_clear(repository, session, mocker: MockerFixture):
     mock_stmt.where.assert_called()
     session_obj.execute.assert_called_once_with(mock_stmt)
     session_obj.commit.assert_called_once()
+
+
+def test_to_db_model(repository):
+    """Test _to_db_model method."""
+    # Create a domain model
+    domain_model = NodeExecution(
+        id="test-id",
+        workflow_id="test-workflow-id",
+        node_execution_id="test-node-execution-id",
+        workflow_run_id="test-workflow-run-id",
+        index=1,
+        predecessor_node_id="test-predecessor-id",
+        node_id="test-node-id",
+        node_type=NodeType.START,
+        title="Test Node",
+        inputs={"input_key": "input_value"},
+        process_data={"process_key": "process_value"},
+        outputs={"output_key": "output_value"},
+        status=NodeExecutionStatus.RUNNING,
+        error=None,
+        elapsed_time=1.5,
+        metadata={NodeRunMetadataKey.TOTAL_TOKENS: 100},
+        created_at=datetime.now(),
+        finished_at=None,
+    )
+
+    # Convert to DB model
+    db_model = repository._to_db_model(domain_model)
+
+    # Assert DB model has correct values
+    assert isinstance(db_model, WorkflowNodeExecution)
+    assert db_model.id == domain_model.id
+    assert db_model.tenant_id == repository._tenant_id
+    assert db_model.app_id == repository._app_id
+    assert db_model.workflow_id == domain_model.workflow_id
+    assert db_model.triggered_from == repository._triggered_from
+    assert db_model.workflow_run_id == domain_model.workflow_run_id
+    assert db_model.index == domain_model.index
+    assert db_model.predecessor_node_id == domain_model.predecessor_node_id
+    assert db_model.node_execution_id == domain_model.node_execution_id
+    assert db_model.node_id == domain_model.node_id
+    assert db_model.node_type == domain_model.node_type
+    assert db_model.title == domain_model.title
+    assert db_model.inputs_dict == domain_model.inputs
+    assert db_model.process_data_dict == domain_model.process_data
+    assert db_model.outputs_dict == domain_model.outputs
+    assert db_model.execution_metadata_dict == domain_model.metadata
+    assert db_model.status == domain_model.status
+    assert db_model.error == domain_model.error
+    assert db_model.elapsed_time == domain_model.elapsed_time
+    assert db_model.created_at == domain_model.created_at
+    assert db_model.created_by_role == repository._creator_user_role
+    assert db_model.created_by == repository._creator_user_id
+    assert db_model.finished_at == domain_model.finished_at
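Read top to bottom, these assertions specify _to_db_model field by field: identity and execution fields copy across, dict fields become JSON text, and the tenant/app/creator/triggered_from context comes from repository state. A condensed sketch of a conversion consistent with them; simplified, and the actual implementation may differ:

import json

from models.workflow import WorkflowNodeExecution


def to_db_model(repo, domain):
    db = WorkflowNodeExecution()
    # Identity and execution fields copy across unchanged.
    db.id = domain.id
    db.workflow_id = domain.workflow_id
    db.workflow_run_id = domain.workflow_run_id
    db.index = domain.index
    db.predecessor_node_id = domain.predecessor_node_id
    db.node_execution_id = domain.node_execution_id
    db.node_id = domain.node_id
    db.node_type = domain.node_type
    db.title = domain.title
    db.status = domain.status
    db.error = domain.error
    db.elapsed_time = domain.elapsed_time
    db.created_at = domain.created_at
    db.finished_at = domain.finished_at
    # Dict-valued fields are serialized to JSON text columns.
    db.inputs = json.dumps(domain.inputs) if domain.inputs else None
    db.process_data = json.dumps(domain.process_data) if domain.process_data else None
    db.outputs = json.dumps(domain.outputs) if domain.outputs else None
    db.execution_metadata = json.dumps(domain.metadata) if domain.metadata else None
    # Context the domain model no longer carries comes from the repository.
    db.tenant_id = repo._tenant_id
    db.app_id = repo._app_id
    db.triggered_from = repo._triggered_from
    db.created_by_role = repo._creator_user_role
    db.created_by = repo._creator_user_id
    return db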
+
+
+def test_to_domain_model(repository):
+    """Test _to_domain_model method."""
+    # Create input dictionaries
+    inputs_dict = {"input_key": "input_value"}
+    process_data_dict = {"process_key": "process_value"}
+    outputs_dict = {"output_key": "output_value"}
+    metadata_dict = {str(NodeRunMetadataKey.TOTAL_TOKENS): 100}
+
+    # Create a DB model using our custom subclass
+    db_model = WorkflowNodeExecution()
+    db_model.id = "test-id"
+    db_model.tenant_id = "test-tenant-id"
+    db_model.app_id = "test-app-id"
+    db_model.workflow_id = "test-workflow-id"
+    db_model.triggered_from = "workflow-run"
+    db_model.workflow_run_id = "test-workflow-run-id"
+    db_model.index = 1
+    db_model.predecessor_node_id = "test-predecessor-id"
+    db_model.node_execution_id = "test-node-execution-id"
+    db_model.node_id = "test-node-id"
+    db_model.node_type = NodeType.START.value
+    db_model.title = "Test Node"
+    db_model.inputs = json.dumps(inputs_dict)
+    db_model.process_data = json.dumps(process_data_dict)
+    db_model.outputs = json.dumps(outputs_dict)
+    db_model.status = WorkflowNodeExecutionStatus.RUNNING
+    db_model.error = None
+    db_model.elapsed_time = 1.5
+    db_model.execution_metadata = json.dumps(metadata_dict)
+    db_model.created_at = datetime.now()
+    db_model.created_by_role = "account"
+    db_model.created_by = "test-user-id"
+    db_model.finished_at = None
+
+    # Convert to domain model
+    domain_model = repository._to_domain_model(db_model)
+
+    # Assert domain model has correct values
+    assert isinstance(domain_model, NodeExecution)
+    assert domain_model.id == db_model.id
+    assert domain_model.workflow_id == db_model.workflow_id
+    assert domain_model.workflow_run_id == db_model.workflow_run_id
+    assert domain_model.index == db_model.index
+    assert domain_model.predecessor_node_id == db_model.predecessor_node_id
+    assert domain_model.node_execution_id == db_model.node_execution_id
+    assert domain_model.node_id == db_model.node_id
+    assert domain_model.node_type == NodeType(db_model.node_type)
+    assert domain_model.title == db_model.title
+    assert domain_model.inputs == inputs_dict
+    assert domain_model.process_data == process_data_dict
+    assert domain_model.outputs == outputs_dict
+    assert domain_model.status == NodeExecutionStatus(db_model.status)
+    assert domain_model.error == db_model.error
+    assert domain_model.elapsed_time == db_model.elapsed_time
+    assert domain_model.metadata == metadata_dict
+    assert domain_model.created_at == db_model.created_at
+    assert domain_model.finished_at == db_model.finished_at
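The inverse mapping follows directly from these assertions: JSON columns parse back into dicts, and stored strings are promoted to their enums. A minimal sketch under the same assumptions, using the entities imported at the top of this test module:

import json

from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.nodes.enums import NodeType


def to_domain_model(db):
    return NodeExecution(
        id=db.id,
        workflow_id=db.workflow_id,
        workflow_run_id=db.workflow_run_id,
        index=db.index,
        predecessor_node_id=db.predecessor_node_id,
        node_execution_id=db.node_execution_id,
        node_id=db.node_id,
        node_type=NodeType(db.node_type),       # stored string -> NodeType enum
        title=db.title,
        inputs=json.loads(db.inputs) if db.inputs else None,
        process_data=json.loads(db.process_data) if db.process_data else None,
        outputs=json.loads(db.outputs) if db.outputs else None,
        status=NodeExecutionStatus(db.status),  # stored string -> status enum
        error=db.error,
        elapsed_time=db.elapsed_time,
        metadata=json.loads(db.execution_metadata) if db.execution_metadata else None,
        created_at=db.created_at,
        finished_at=db.finished_at,
    )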