mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 22:35:54 +08:00
feat(workflow): domain model for workflow node execution (#19430)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
aeceb200ec
commit
4977bb21ec
@ -1,3 +1,6 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, marshal_with, reqparse
|
from flask_restful import Resource, marshal_with, reqparse
|
||||||
from flask_restful.inputs import int_range
|
from flask_restful.inputs import int_range
|
||||||
|
|
||||||
@ -12,8 +15,7 @@ from fields.workflow_run_fields import (
|
|||||||
)
|
)
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models import App
|
from models import Account, App, AppMode, EndUser
|
||||||
from models.model import AppMode
|
|
||||||
from services.workflow_run_service import WorkflowRunService
|
from services.workflow_run_service import WorkflowRunService
|
||||||
|
|
||||||
|
|
||||||
@ -90,7 +92,12 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
|||||||
run_id = str(run_id)
|
run_id = str(run_id)
|
||||||
|
|
||||||
workflow_run_service = WorkflowRunService()
|
workflow_run_service = WorkflowRunService()
|
||||||
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
|
user = cast("Account | EndUser", current_user)
|
||||||
|
node_executions = workflow_run_service.get_workflow_run_node_executions(
|
||||||
|
app_model=app_model,
|
||||||
|
run_id=run_id,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
return {"data": node_executions}
|
return {"data": node_executions}
|
||||||
|
|
||||||
|
@ -29,9 +29,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
|||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models.account import Account
|
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
from models.model import App, Conversation, EndUser, Message
|
|
||||||
from models.workflow import Workflow
|
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
|
|
||||||
@ -165,8 +163,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
user=user,
|
||||||
app_id=application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -231,8 +230,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
user=user,
|
||||||
app_id=application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -295,8 +295,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
user=user,
|
||||||
app_id=application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
|
@ -70,7 +70,7 @@ from events.message_event import message_was_created
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models import Conversation, EndUser, Message, MessageFile
|
from models import Conversation, EndUser, Message, MessageFile
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatorUserRole
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
@ -105,11 +105,11 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
if isinstance(user, EndUser):
|
if isinstance(user, EndUser):
|
||||||
self._user_id = user.id
|
self._user_id = user.id
|
||||||
user_session_id = user.session_id
|
user_session_id = user.session_id
|
||||||
self._created_by_role = CreatedByRole.END_USER
|
self._created_by_role = CreatorUserRole.END_USER
|
||||||
elif isinstance(user, Account):
|
elif isinstance(user, Account):
|
||||||
self._user_id = user.id
|
self._user_id = user.id
|
||||||
user_session_id = user.id
|
user_session_id = user.id
|
||||||
self._created_by_role = CreatedByRole.ACCOUNT
|
self._created_by_role = CreatorUserRole.ACCOUNT
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||||
|
|
||||||
@ -739,9 +739,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
url=file["remote_url"],
|
url=file["remote_url"],
|
||||||
belongs_to="assistant",
|
belongs_to="assistant",
|
||||||
upload_file_id=file["related_id"],
|
upload_file_id=file["related_id"],
|
||||||
created_by_role=CreatedByRole.ACCOUNT
|
created_by_role=CreatorUserRole.ACCOUNT
|
||||||
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||||
else CreatedByRole.END_USER,
|
else CreatorUserRole.END_USER,
|
||||||
created_by=message.from_account_id or message.from_end_user_id or "",
|
created_by=message.from_account_id or message.from_end_user_id or "",
|
||||||
)
|
)
|
||||||
for file in self._recorded_files
|
for file in self._recorded_files
|
||||||
|
@ -25,7 +25,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa
|
|||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models import Account
|
from models import Account
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatorUserRole
|
||||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
@ -223,7 +223,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||||||
belongs_to="user",
|
belongs_to="user",
|
||||||
url=file.remote_url,
|
url=file.remote_url,
|
||||||
upload_file_id=file.related_id,
|
upload_file_id=file.related_id,
|
||||||
created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
|
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
|
||||||
created_by=account_id or end_user_id or "",
|
created_by=account_id or end_user_id or "",
|
||||||
)
|
)
|
||||||
db.session.add(message_file)
|
db.session.add(message_file)
|
||||||
|
@ -27,7 +27,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow
|
|||||||
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models import Account, App, EndUser, Workflow
|
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -138,10 +138,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
user=user,
|
||||||
app_id=application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -262,10 +264,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
user=user,
|
||||||
app_id=application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -325,10 +329,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
user=user,
|
||||||
app_id=application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
|
@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
|
|||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
|
|||||||
title: str
|
title: str
|
||||||
index: int
|
index: int
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: Optional[str] = None
|
||||||
inputs: Optional[dict] = None
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
created_at: int
|
created_at: int
|
||||||
extras: dict = {}
|
extras: dict = {}
|
||||||
parallel_id: Optional[str] = None
|
parallel_id: Optional[str] = None
|
||||||
@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||||||
title: str
|
title: str
|
||||||
index: int
|
index: int
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: Optional[str] = None
|
||||||
inputs: Optional[dict] = None
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
process_data: Optional[dict] = None
|
process_data: Optional[Mapping[str, Any]] = None
|
||||||
outputs: Optional[dict] = None
|
outputs: Optional[Mapping[str, Any]] = None
|
||||||
status: str
|
status: str
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
execution_metadata: Optional[dict] = None
|
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||||
created_at: int
|
created_at: int
|
||||||
finished_at: int
|
finished_at: int
|
||||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||||
@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse):
|
|||||||
title: str
|
title: str
|
||||||
index: int
|
index: int
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: Optional[str] = None
|
||||||
inputs: Optional[dict] = None
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
process_data: Optional[dict] = None
|
process_data: Optional[Mapping[str, Any]] = None
|
||||||
outputs: Optional[dict] = None
|
outputs: Optional[Mapping[str, Any]] = None
|
||||||
status: str
|
status: str
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
execution_metadata: Optional[dict] = None
|
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||||
created_at: int
|
created_at: int
|
||||||
finished_at: int
|
finished_at: int
|
||||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel):
|
|||||||
description="The status message of the span. Additional field for context of the event. E.g. the error "
|
description="The status message of the span. Additional field for context of the event. E.g. the error "
|
||||||
"message of an error event.",
|
"message of an error event.",
|
||||||
)
|
)
|
||||||
input: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
input: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
|
||||||
default=None, description="The input of the span. Can be any JSON object."
|
default=None, description="The input of the span. Can be any JSON object."
|
||||||
)
|
)
|
||||||
output: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
output: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
|
||||||
default=None, description="The output of the span. Can be any JSON object."
|
default=None, description="The output of the span. Can be any JSON object."
|
||||||
)
|
)
|
||||||
version: Optional[str] = Field(
|
version: Optional[str] = Field(
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langfuse import Langfuse # type: ignore
|
from langfuse import Langfuse # type: ignore
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.ops.base_trace_instance import BaseTraceInstance
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
from core.ops.entities.config_entity import LangfuseConfig
|
from core.ops.entities.config_entity import LangfuseConfig
|
||||||
@ -30,8 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
|||||||
)
|
)
|
||||||
from core.ops.utils import filter_none_values
|
from core.ops.utils import filter_none_values
|
||||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser
|
from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -113,8 +113,29 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution using repository
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
session_factory = sessionmaker(bind=db.engine)
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
|
# Find the app's creator account
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
# Get the app to find its creator
|
||||||
|
app_id = trace_info.metadata.get("app_id")
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("No app_id found in trace_info metadata")
|
||||||
|
|
||||||
|
app = session.query(App).filter(App.id == app_id).first()
|
||||||
|
if not app:
|
||||||
|
raise ValueError(f"App with id {app_id} not found")
|
||||||
|
|
||||||
|
if not app.created_by:
|
||||||
|
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||||
|
|
||||||
|
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||||
|
if not service_account:
|
||||||
|
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||||
|
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory, tenant_id=trace_info.tenant_id
|
session_factory=session_factory,
|
||||||
|
user=service_account,
|
||||||
|
app_id=trace_info.metadata.get("app_id"),
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all executions for this workflow run
|
# Get all executions for this workflow run
|
||||||
@ -124,23 +145,22 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
for node_execution in workflow_node_executions:
|
for node_execution in workflow_node_executions:
|
||||||
node_execution_id = node_execution.id
|
node_execution_id = node_execution.id
|
||||||
tenant_id = node_execution.tenant_id
|
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||||
app_id = node_execution.app_id
|
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||||
node_name = node_execution.title
|
node_name = node_execution.title
|
||||||
node_type = node_execution.node_type
|
node_type = node_execution.node_type
|
||||||
status = node_execution.status
|
status = node_execution.status
|
||||||
if node_type == "llm":
|
if node_type == NodeType.LLM:
|
||||||
inputs = (
|
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||||
created_at = node_execution.created_at or datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||||
|
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||||
metadata.update(
|
metadata.update(
|
||||||
{
|
{
|
||||||
"workflow_run_id": trace_info.workflow_run_id,
|
"workflow_run_id": trace_info.workflow_run_id,
|
||||||
@ -152,7 +172,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
"status": status,
|
"status": status,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||||
model_provider = process_data.get("model_provider", None)
|
model_provider = process_data.get("model_provider", None)
|
||||||
model_name = process_data.get("model_name", None)
|
model_name = process_data.get("model_name", None)
|
||||||
if model_provider is not None and model_name is not None:
|
if model_provider is not None and model_name is not None:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel):
|
|||||||
|
|
||||||
class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||||
name: Optional[str] = Field(..., description="Name of the run")
|
name: Optional[str] = Field(..., description="Name of the run")
|
||||||
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run")
|
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run")
|
||||||
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run")
|
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run")
|
||||||
run_type: LangSmithRunType = Field(..., description="Type of the run")
|
run_type: LangSmithRunType = Field(..., description="Type of the run")
|
||||||
start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
|
start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
|
||||||
end_time: Optional[datetime | str] = Field(None, description="End time of the run")
|
end_time: Optional[datetime | str] = Field(None, description="End time of the run")
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
@ -7,7 +6,7 @@ from typing import Optional, cast
|
|||||||
|
|
||||||
from langsmith import Client
|
from langsmith import Client
|
||||||
from langsmith.schemas import RunBase
|
from langsmith.schemas import RunBase
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.ops.base_trace_instance import BaseTraceInstance
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
from core.ops.entities.config_entity import LangSmithConfig
|
from core.ops.entities.config_entity import LangSmithConfig
|
||||||
@ -29,8 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
|||||||
)
|
)
|
||||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser, MessageFile
|
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -137,8 +138,29 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution using repository
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
session_factory = sessionmaker(bind=db.engine)
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
|
# Find the app's creator account
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
# Get the app to find its creator
|
||||||
|
app_id = trace_info.metadata.get("app_id")
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("No app_id found in trace_info metadata")
|
||||||
|
|
||||||
|
app = session.query(App).filter(App.id == app_id).first()
|
||||||
|
if not app:
|
||||||
|
raise ValueError(f"App with id {app_id} not found")
|
||||||
|
|
||||||
|
if not app.created_by:
|
||||||
|
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||||
|
|
||||||
|
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||||
|
if not service_account:
|
||||||
|
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||||
|
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
|
session_factory=session_factory,
|
||||||
|
user=service_account,
|
||||||
|
app_id=trace_info.metadata.get("app_id"),
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all executions for this workflow run
|
# Get all executions for this workflow run
|
||||||
@ -148,27 +170,23 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
for node_execution in workflow_node_executions:
|
for node_execution in workflow_node_executions:
|
||||||
node_execution_id = node_execution.id
|
node_execution_id = node_execution.id
|
||||||
tenant_id = node_execution.tenant_id
|
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||||
app_id = node_execution.app_id
|
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||||
node_name = node_execution.title
|
node_name = node_execution.title
|
||||||
node_type = node_execution.node_type
|
node_type = node_execution.node_type
|
||||||
status = node_execution.status
|
status = node_execution.status
|
||||||
if node_type == "llm":
|
if node_type == NodeType.LLM:
|
||||||
inputs = (
|
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||||
created_at = node_execution.created_at or datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
execution_metadata = (
|
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||||
)
|
metadata = {str(key): value for key, value in execution_metadata.items()}
|
||||||
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
|
||||||
metadata = execution_metadata.copy()
|
|
||||||
metadata.update(
|
metadata.update(
|
||||||
{
|
{
|
||||||
"workflow_run_id": trace_info.workflow_run_id,
|
"workflow_run_id": trace_info.workflow_run_id,
|
||||||
@ -181,7 +199,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||||
|
|
||||||
if process_data and process_data.get("model_mode") == "chat":
|
if process_data and process_data.get("model_mode") == "chat":
|
||||||
run_type = LangSmithRunType.llm
|
run_type = LangSmithRunType.llm
|
||||||
@ -191,7 +209,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
"ls_model_name": process_data.get("model_name", ""),
|
"ls_model_name": process_data.get("model_name", ""),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif node_type == "knowledge-retrieval":
|
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||||
run_type = LangSmithRunType.retriever
|
run_type = LangSmithRunType.retriever
|
||||||
else:
|
else:
|
||||||
run_type = LangSmithRunType.tool
|
run_type = LangSmithRunType.tool
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
@ -7,7 +6,7 @@ from typing import Optional, cast
|
|||||||
|
|
||||||
from opik import Opik, Trace
|
from opik import Opik, Trace
|
||||||
from opik.id_helpers import uuid4_to_uuid7
|
from opik.id_helpers import uuid4_to_uuid7
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.ops.base_trace_instance import BaseTraceInstance
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
from core.ops.entities.config_entity import OpikConfig
|
from core.ops.entities.config_entity import OpikConfig
|
||||||
@ -23,8 +22,10 @@ from core.ops.entities.trace_entity import (
|
|||||||
WorkflowTraceInfo,
|
WorkflowTraceInfo,
|
||||||
)
|
)
|
||||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser, MessageFile
|
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -150,8 +151,29 @@ class OpikDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution using repository
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
session_factory = sessionmaker(bind=db.engine)
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
|
# Find the app's creator account
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
# Get the app to find its creator
|
||||||
|
app_id = trace_info.metadata.get("app_id")
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("No app_id found in trace_info metadata")
|
||||||
|
|
||||||
|
app = session.query(App).filter(App.id == app_id).first()
|
||||||
|
if not app:
|
||||||
|
raise ValueError(f"App with id {app_id} not found")
|
||||||
|
|
||||||
|
if not app.created_by:
|
||||||
|
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||||
|
|
||||||
|
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||||
|
if not service_account:
|
||||||
|
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||||
|
|
||||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
|
session_factory=session_factory,
|
||||||
|
user=service_account,
|
||||||
|
app_id=trace_info.metadata.get("app_id"),
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all executions for this workflow run
|
# Get all executions for this workflow run
|
||||||
@ -161,26 +183,22 @@ class OpikDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
for node_execution in workflow_node_executions:
|
for node_execution in workflow_node_executions:
|
||||||
node_execution_id = node_execution.id
|
node_execution_id = node_execution.id
|
||||||
tenant_id = node_execution.tenant_id
|
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||||
app_id = node_execution.app_id
|
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||||
node_name = node_execution.title
|
node_name = node_execution.title
|
||||||
node_type = node_execution.node_type
|
node_type = node_execution.node_type
|
||||||
status = node_execution.status
|
status = node_execution.status
|
||||||
if node_type == "llm":
|
if node_type == NodeType.LLM:
|
||||||
inputs = (
|
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||||
created_at = node_execution.created_at or datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
execution_metadata = (
|
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||||
)
|
|
||||||
metadata = execution_metadata.copy()
|
|
||||||
metadata.update(
|
metadata.update(
|
||||||
{
|
{
|
||||||
"workflow_run_id": trace_info.workflow_run_id,
|
"workflow_run_id": trace_info.workflow_run_id,
|
||||||
@ -193,7 +211,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||||
|
|
||||||
provider = None
|
provider = None
|
||||||
model = None
|
model = None
|
||||||
@ -226,7 +244,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||||||
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
||||||
|
|
||||||
if not total_tokens:
|
if not total_tokens:
|
||||||
total_tokens = execution_metadata.get("total_tokens", 0)
|
total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||||
|
|
||||||
span_data = {
|
span_data = {
|
||||||
"trace_id": opik_trace_id,
|
"trace_id": opik_trace_id,
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel):
|
|||||||
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||||
id: str = Field(..., description="ID of the trace")
|
id: str = Field(..., description="ID of the trace")
|
||||||
op: str = Field(..., description="Name of the operation")
|
op: str = Field(..., description="Name of the operation")
|
||||||
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
|
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace")
|
||||||
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
|
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace")
|
||||||
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
||||||
None, description="Metadata and attributes associated with trace"
|
None, description="Metadata and attributes associated with trace"
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
@ -7,6 +6,7 @@ from typing import Any, Optional, cast
|
|||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
import weave
|
import weave
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.ops.base_trace_instance import BaseTraceInstance
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
from core.ops.entities.config_entity import WeaveConfig
|
from core.ops.entities.config_entity import WeaveConfig
|
||||||
@ -22,9 +22,11 @@ from core.ops.entities.trace_entity import (
|
|||||||
WorkflowTraceInfo,
|
WorkflowTraceInfo,
|
||||||
)
|
)
|
||||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||||
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser, MessageFile
|
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||||
from models.workflow import WorkflowNodeExecution
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -128,58 +130,57 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
|
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
|
||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
workflow_nodes_execution_id_records = (
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
db.session.query(WorkflowNodeExecution.id)
|
# Find the app's creator account
|
||||||
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
.all()
|
# Get the app to find its creator
|
||||||
|
app_id = trace_info.metadata.get("app_id")
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("No app_id found in trace_info metadata")
|
||||||
|
|
||||||
|
app = session.query(App).filter(App.id == app_id).first()
|
||||||
|
if not app:
|
||||||
|
raise ValueError(f"App with id {app_id} not found")
|
||||||
|
|
||||||
|
if not app.created_by:
|
||||||
|
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||||
|
|
||||||
|
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||||
|
if not service_account:
|
||||||
|
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||||
|
|
||||||
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=service_account,
|
||||||
|
app_id=trace_info.metadata.get("app_id"),
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
)
|
)
|
||||||
|
|
||||||
for node_execution_id_record in workflow_nodes_execution_id_records:
|
# Get all executions for this workflow run
|
||||||
node_execution = (
|
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||||
db.session.query(
|
workflow_run_id=trace_info.workflow_run_id
|
||||||
WorkflowNodeExecution.id,
|
|
||||||
WorkflowNodeExecution.tenant_id,
|
|
||||||
WorkflowNodeExecution.app_id,
|
|
||||||
WorkflowNodeExecution.title,
|
|
||||||
WorkflowNodeExecution.node_type,
|
|
||||||
WorkflowNodeExecution.status,
|
|
||||||
WorkflowNodeExecution.inputs,
|
|
||||||
WorkflowNodeExecution.outputs,
|
|
||||||
WorkflowNodeExecution.created_at,
|
|
||||||
WorkflowNodeExecution.elapsed_time,
|
|
||||||
WorkflowNodeExecution.process_data,
|
|
||||||
WorkflowNodeExecution.execution_metadata,
|
|
||||||
)
|
|
||||||
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not node_execution:
|
for node_execution in workflow_node_executions:
|
||||||
continue
|
|
||||||
|
|
||||||
node_execution_id = node_execution.id
|
node_execution_id = node_execution.id
|
||||||
tenant_id = node_execution.tenant_id
|
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||||
app_id = node_execution.app_id
|
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||||
node_name = node_execution.title
|
node_name = node_execution.title
|
||||||
node_type = node_execution.node_type
|
node_type = node_execution.node_type
|
||||||
status = node_execution.status
|
status = node_execution.status
|
||||||
if node_type == "llm":
|
if node_type == NodeType.LLM:
|
||||||
inputs = (
|
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||||
created_at = node_execution.created_at or datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
execution_metadata = (
|
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||||
)
|
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||||
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
|
||||||
attributes = execution_metadata.copy()
|
|
||||||
attributes.update(
|
attributes.update(
|
||||||
{
|
{
|
||||||
"workflow_run_id": trace_info.workflow_run_id,
|
"workflow_run_id": trace_info.workflow_run_id,
|
||||||
@ -192,7 +193,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||||
if process_data and process_data.get("model_mode") == "chat":
|
if process_data and process_data.get("model_mode") == "chat":
|
||||||
attributes.update(
|
attributes.update(
|
||||||
{
|
{
|
||||||
|
@ -19,7 +19,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
|||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatorUserRole
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -116,7 +116,7 @@ class WordExtractor(BaseExtractor):
|
|||||||
extension=str(image_ext),
|
extension=str(image_ext),
|
||||||
mime_type=mime_type or "",
|
mime_type=mime_type or "",
|
||||||
created_by=self.user_id,
|
created_by=self.user_id,
|
||||||
created_by_role=CreatedByRole.ACCOUNT,
|
created_by_role=CreatorUserRole.ACCOUNT,
|
||||||
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||||
used=True,
|
used=True,
|
||||||
used_by=self.user_id,
|
used_by=self.user_id,
|
||||||
|
@ -2,16 +2,29 @@
|
|||||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
|
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from core.workflow.entities.node_execution_entities import (
|
||||||
|
NodeExecution,
|
||||||
|
NodeExecutionStatus,
|
||||||
|
)
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
from models import (
|
||||||
|
Account,
|
||||||
|
CreatorUserRole,
|
||||||
|
EndUser,
|
||||||
|
WorkflowNodeExecution,
|
||||||
|
WorkflowNodeExecutionStatus,
|
||||||
|
WorkflowNodeExecutionTriggeredFrom,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -23,16 +36,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
This implementation supports multi-tenancy by filtering operations based on tenant_id.
|
This implementation supports multi-tenancy by filtering operations based on tenant_id.
|
||||||
Each method creates its own session, handles the transaction, and commits changes
|
Each method creates its own session, handles the transaction, and commits changes
|
||||||
to the database. This prevents long-running connections in the workflow core.
|
to the database. This prevents long-running connections in the workflow core.
|
||||||
|
|
||||||
|
This implementation also includes an in-memory cache for node executions to improve
|
||||||
|
performance by reducing database queries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_factory: sessionmaker | Engine,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
app_id: Optional[str],
|
||||||
|
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
|
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
|
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
|
||||||
tenant_id: Tenant ID for multi-tenancy
|
user: Account or EndUser object containing tenant_id, user ID, and role information
|
||||||
app_id: Optional app ID for filtering by application
|
app_id: App ID for filtering by application (can be None)
|
||||||
|
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
|
||||||
"""
|
"""
|
||||||
# If an engine is provided, create a sessionmaker from it
|
# If an engine is provided, create a sessionmaker from it
|
||||||
if isinstance(session_factory, Engine):
|
if isinstance(session_factory, Engine):
|
||||||
@ -44,38 +67,155 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract tenant_id from user
|
||||||
|
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
|
||||||
|
if not tenant_id:
|
||||||
|
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||||
self._tenant_id = tenant_id
|
self._tenant_id = tenant_id
|
||||||
|
|
||||||
|
# Store app context
|
||||||
self._app_id = app_id
|
self._app_id = app_id
|
||||||
|
|
||||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
# Extract user context
|
||||||
|
self._triggered_from = triggered_from
|
||||||
|
self._creator_user_id = user.id
|
||||||
|
|
||||||
|
# Determine user role based on user type
|
||||||
|
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
||||||
|
|
||||||
|
# Initialize in-memory cache for node executions
|
||||||
|
# Key: node_execution_id, Value: NodeExecution
|
||||||
|
self._node_execution_cache: dict[str, NodeExecution] = {}
|
||||||
|
|
||||||
|
def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
|
||||||
"""
|
"""
|
||||||
Save a WorkflowNodeExecution instance and commit changes to the database.
|
Convert a database model to a domain model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
execution: The WorkflowNodeExecution instance to save
|
db_model: The database model to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The domain model
|
||||||
|
"""
|
||||||
|
# Parse JSON fields
|
||||||
|
inputs = db_model.inputs_dict
|
||||||
|
process_data = db_model.process_data_dict
|
||||||
|
outputs = db_model.outputs_dict
|
||||||
|
metadata = db_model.execution_metadata_dict
|
||||||
|
|
||||||
|
# Convert status to domain enum
|
||||||
|
status = NodeExecutionStatus(db_model.status)
|
||||||
|
|
||||||
|
return NodeExecution(
|
||||||
|
id=db_model.id,
|
||||||
|
node_execution_id=db_model.node_execution_id,
|
||||||
|
workflow_id=db_model.workflow_id,
|
||||||
|
workflow_run_id=db_model.workflow_run_id,
|
||||||
|
index=db_model.index,
|
||||||
|
predecessor_node_id=db_model.predecessor_node_id,
|
||||||
|
node_id=db_model.node_id,
|
||||||
|
node_type=NodeType(db_model.node_type),
|
||||||
|
title=db_model.title,
|
||||||
|
inputs=inputs,
|
||||||
|
process_data=process_data,
|
||||||
|
outputs=outputs,
|
||||||
|
status=status,
|
||||||
|
error=db_model.error,
|
||||||
|
elapsed_time=db_model.elapsed_time,
|
||||||
|
metadata=metadata,
|
||||||
|
created_at=db_model.created_at,
|
||||||
|
finished_at=db_model.finished_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
|
||||||
|
"""
|
||||||
|
Convert a domain model to a database model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain_model: The domain model to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The database model
|
||||||
|
"""
|
||||||
|
# Use values from constructor if provided
|
||||||
|
if not self._triggered_from:
|
||||||
|
raise ValueError("triggered_from is required in repository constructor")
|
||||||
|
if not self._creator_user_id:
|
||||||
|
raise ValueError("created_by is required in repository constructor")
|
||||||
|
if not self._creator_user_role:
|
||||||
|
raise ValueError("created_by_role is required in repository constructor")
|
||||||
|
|
||||||
|
db_model = WorkflowNodeExecution()
|
||||||
|
db_model.id = domain_model.id
|
||||||
|
db_model.tenant_id = self._tenant_id
|
||||||
|
if self._app_id is not None:
|
||||||
|
db_model.app_id = self._app_id
|
||||||
|
db_model.workflow_id = domain_model.workflow_id
|
||||||
|
db_model.triggered_from = self._triggered_from
|
||||||
|
db_model.workflow_run_id = domain_model.workflow_run_id
|
||||||
|
db_model.index = domain_model.index
|
||||||
|
db_model.predecessor_node_id = domain_model.predecessor_node_id
|
||||||
|
db_model.node_execution_id = domain_model.node_execution_id
|
||||||
|
db_model.node_id = domain_model.node_id
|
||||||
|
db_model.node_type = domain_model.node_type
|
||||||
|
db_model.title = domain_model.title
|
||||||
|
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
|
||||||
|
db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None
|
||||||
|
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
|
||||||
|
db_model.status = domain_model.status
|
||||||
|
db_model.error = domain_model.error
|
||||||
|
db_model.elapsed_time = domain_model.elapsed_time
|
||||||
|
db_model.execution_metadata = json.dumps(domain_model.metadata) if domain_model.metadata else None
|
||||||
|
db_model.created_at = domain_model.created_at
|
||||||
|
db_model.created_by_role = self._creator_user_role
|
||||||
|
db_model.created_by = self._creator_user_id
|
||||||
|
db_model.finished_at = domain_model.finished_at
|
||||||
|
return db_model
|
||||||
|
|
||||||
|
def save(self, execution: NodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Save or update a NodeExecution instance and commit changes to the database.
|
||||||
|
|
||||||
|
This method handles both creating new records and updating existing ones.
|
||||||
|
It determines whether to create or update based on whether the record
|
||||||
|
already exists in the database. It also updates the in-memory cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The NodeExecution instance to save or update
|
||||||
"""
|
"""
|
||||||
with self._session_factory() as session:
|
with self._session_factory() as session:
|
||||||
# Ensure tenant_id is set
|
# Convert domain model to database model using instance attributes
|
||||||
if not execution.tenant_id:
|
db_model = self._to_db_model(execution)
|
||||||
execution.tenant_id = self._tenant_id
|
|
||||||
|
|
||||||
# Set app_id if provided and not already set
|
# Use merge which will handle both insert and update
|
||||||
if self._app_id and not execution.app_id:
|
session.merge(db_model)
|
||||||
execution.app_id = self._app_id
|
|
||||||
|
|
||||||
session.add(execution)
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
# Update the cache if node_execution_id is present
|
||||||
|
if execution.node_execution_id:
|
||||||
|
logger.debug(f"Updating cache for node_execution_id: {execution.node_execution_id}")
|
||||||
|
self._node_execution_cache[execution.node_execution_id] = execution
|
||||||
|
|
||||||
|
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
|
||||||
"""
|
"""
|
||||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
Retrieve a NodeExecution by its node_execution_id.
|
||||||
|
|
||||||
|
First checks the in-memory cache, and if not found, queries the database.
|
||||||
|
If found in the database, adds it to the cache for future lookups.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_execution_id: The node execution ID
|
node_execution_id: The node execution ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The WorkflowNodeExecution instance if found, None otherwise
|
The NodeExecution instance if found, None otherwise
|
||||||
"""
|
"""
|
||||||
|
# First check the cache
|
||||||
|
if node_execution_id in self._node_execution_cache:
|
||||||
|
logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
|
||||||
|
return self._node_execution_cache[node_execution_id]
|
||||||
|
|
||||||
|
# If not in cache, query the database
|
||||||
|
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
|
||||||
with self._session_factory() as session:
|
with self._session_factory() as session:
|
||||||
stmt = select(WorkflowNodeExecution).where(
|
stmt = select(WorkflowNodeExecution).where(
|
||||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||||
@ -85,15 +225,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
if self._app_id:
|
if self._app_id:
|
||||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
return session.scalar(stmt)
|
db_model = session.scalar(stmt)
|
||||||
|
if db_model:
|
||||||
|
# Convert to domain model
|
||||||
|
domain_model = self._to_domain_model(db_model)
|
||||||
|
|
||||||
|
# Add to cache
|
||||||
|
self._node_execution_cache[node_execution_id] = domain_model
|
||||||
|
|
||||||
|
return domain_model
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def get_by_workflow_run(
|
def get_by_workflow_run(
|
||||||
self,
|
self,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
order_config: Optional[OrderConfig] = None,
|
order_config: Optional[OrderConfig] = None,
|
||||||
) -> Sequence[WorkflowNodeExecution]:
|
) -> Sequence[NodeExecution]:
|
||||||
"""
|
"""
|
||||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
Retrieve all NodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
This method always queries the database to ensure complete and ordered results,
|
||||||
|
but updates the cache with any retrieved executions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workflow_run_id: The workflow run ID
|
workflow_run_id: The workflow run ID
|
||||||
@ -102,7 +255,42 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of WorkflowNodeExecution instances
|
A list of NodeExecution instances
|
||||||
|
"""
|
||||||
|
# Get the raw database models using the new method
|
||||||
|
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
|
||||||
|
|
||||||
|
# Convert database models to domain models and update cache
|
||||||
|
domain_models = []
|
||||||
|
for model in db_models:
|
||||||
|
domain_model = self._to_domain_model(model)
|
||||||
|
# Update cache if node_execution_id is present
|
||||||
|
if domain_model.node_execution_id:
|
||||||
|
self._node_execution_cache[domain_model.node_execution_id] = domain_model
|
||||||
|
domain_models.append(domain_model)
|
||||||
|
|
||||||
|
return domain_models
|
||||||
|
|
||||||
|
def get_db_models_by_workflow_run(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
order_config: Optional[OrderConfig] = None,
|
||||||
|
) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
|
||||||
|
|
||||||
|
This method is similar to get_by_workflow_run but returns the raw database models
|
||||||
|
instead of converting them to domain models. This can be useful when direct access
|
||||||
|
to database model properties is needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
order_config: Optional configuration for ordering results
|
||||||
|
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||||
|
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of WorkflowNodeExecution database models
|
||||||
"""
|
"""
|
||||||
with self._session_factory() as session:
|
with self._session_factory() as session:
|
||||||
stmt = select(WorkflowNodeExecution).where(
|
stmt = select(WorkflowNodeExecution).where(
|
||||||
@ -129,17 +317,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
if order_columns:
|
if order_columns:
|
||||||
stmt = stmt.order_by(*order_columns)
|
stmt = stmt.order_by(*order_columns)
|
||||||
|
|
||||||
return session.scalars(stmt).all()
|
db_models = session.scalars(stmt).all()
|
||||||
|
|
||||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
# Note: We don't update the cache here since we're returning raw DB models
|
||||||
|
# and not converting to domain models
|
||||||
|
|
||||||
|
return db_models
|
||||||
|
|
||||||
|
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
|
||||||
"""
|
"""
|
||||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
Retrieve all running NodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
This method queries the database directly and updates the cache with any
|
||||||
|
retrieved executions that have a node_execution_id.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workflow_run_id: The workflow run ID
|
workflow_run_id: The workflow run ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of running WorkflowNodeExecution instances
|
A list of running NodeExecution instances
|
||||||
"""
|
"""
|
||||||
with self._session_factory() as session:
|
with self._session_factory() as session:
|
||||||
stmt = select(WorkflowNodeExecution).where(
|
stmt = select(WorkflowNodeExecution).where(
|
||||||
@ -152,26 +348,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
if self._app_id:
|
if self._app_id:
|
||||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
return session.scalars(stmt).all()
|
db_models = session.scalars(stmt).all()
|
||||||
|
domain_models = []
|
||||||
|
|
||||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
for model in db_models:
|
||||||
"""
|
domain_model = self._to_domain_model(model)
|
||||||
Update an existing WorkflowNodeExecution instance and commit changes to the database.
|
# Update cache if node_execution_id is present
|
||||||
|
if domain_model.node_execution_id:
|
||||||
|
self._node_execution_cache[domain_model.node_execution_id] = domain_model
|
||||||
|
domain_models.append(domain_model)
|
||||||
|
|
||||||
Args:
|
return domain_models
|
||||||
execution: The WorkflowNodeExecution instance to update
|
|
||||||
"""
|
|
||||||
with self._session_factory() as session:
|
|
||||||
# Ensure tenant_id is set
|
|
||||||
if not execution.tenant_id:
|
|
||||||
execution.tenant_id = self._tenant_id
|
|
||||||
|
|
||||||
# Set app_id if provided and not already set
|
|
||||||
if self._app_id and not execution.app_id:
|
|
||||||
execution.app_id = self._app_id
|
|
||||||
|
|
||||||
session.merge(execution)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -179,6 +366,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
|
|
||||||
This method deletes all WorkflowNodeExecution records that match the tenant_id
|
This method deletes all WorkflowNodeExecution records that match the tenant_id
|
||||||
and app_id (if provided) associated with this repository instance.
|
and app_id (if provided) associated with this repository instance.
|
||||||
|
It also clears the in-memory cache.
|
||||||
"""
|
"""
|
||||||
with self._session_factory() as session:
|
with self._session_factory() as session:
|
||||||
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
||||||
@ -194,3 +382,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
|
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
|
||||||
+ (f" and app {self._app_id}" if self._app_id else "")
|
+ (f" and app {self._app_id}" if self._app_id else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Clear the in-memory cache
|
||||||
|
self._node_execution_cache.clear()
|
||||||
|
logger.info("Cleared in-memory node execution cache")
|
||||||
|
@ -32,7 +32,7 @@ from core.tools.errors import (
|
|||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatorUserRole
|
||||||
from models.model import Message, MessageFile
|
from models.model import Message, MessageFile
|
||||||
|
|
||||||
|
|
||||||
@ -339,9 +339,9 @@ class ToolEngine:
|
|||||||
url=message.url,
|
url=message.url,
|
||||||
upload_file_id=tool_file_id,
|
upload_file_id=tool_file_id,
|
||||||
created_by_role=(
|
created_by_role=(
|
||||||
CreatedByRole.ACCOUNT
|
CreatorUserRole.ACCOUNT
|
||||||
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||||
else CreatedByRole.END_USER
|
else CreatorUserRole.END_USER
|
||||||
),
|
),
|
||||||
created_by=user_id,
|
created_by=user_id,
|
||||||
)
|
)
|
||||||
|
98
api/core/workflow/entities/node_execution_entities.py
Normal file
98
api/core/workflow/entities/node_execution_entities.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Domain entities for workflow node execution.
|
||||||
|
|
||||||
|
This module contains the domain model for workflow node execution, which is used
|
||||||
|
by the core workflow module. These models are independent of the storage mechanism
|
||||||
|
and don't contain implementation details like tenant_id, app_id, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
|
|
||||||
|
class NodeExecutionStatus(StrEnum):
|
||||||
|
"""
|
||||||
|
Node Execution Status Enum.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RUNNING = "running"
|
||||||
|
SUCCEEDED = "succeeded"
|
||||||
|
FAILED = "failed"
|
||||||
|
EXCEPTION = "exception"
|
||||||
|
RETRY = "retry"
|
||||||
|
|
||||||
|
|
||||||
|
class NodeExecution(BaseModel):
|
||||||
|
"""
|
||||||
|
Domain model for workflow node execution.
|
||||||
|
|
||||||
|
This model represents the core business entity of a node execution,
|
||||||
|
without implementation details like tenant_id, app_id, etc.
|
||||||
|
|
||||||
|
Note: User/context-specific fields (triggered_from, created_by, created_by_role)
|
||||||
|
have been moved to the repository implementation to keep the domain model clean.
|
||||||
|
These fields are still accepted in the constructor for backward compatibility,
|
||||||
|
but they are not stored in the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Core identification fields
|
||||||
|
id: str # Unique identifier for this execution record
|
||||||
|
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
|
||||||
|
workflow_id: str # ID of the workflow this node belongs to
|
||||||
|
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||||
|
|
||||||
|
# Execution positioning and flow
|
||||||
|
index: int # Sequence number for ordering in trace visualization
|
||||||
|
predecessor_node_id: Optional[str] = None # ID of the node that executed before this one
|
||||||
|
node_id: str # ID of the node being executed
|
||||||
|
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
|
||||||
|
title: str # Display title of the node
|
||||||
|
|
||||||
|
# Execution data
|
||||||
|
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
|
||||||
|
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
|
||||||
|
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
|
||||||
|
|
||||||
|
# Execution state
|
||||||
|
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
|
||||||
|
error: Optional[str] = None # Error message if execution failed
|
||||||
|
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
|
||||||
|
|
||||||
|
# Additional metadata
|
||||||
|
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||||
|
|
||||||
|
# Timing information
|
||||||
|
created_at: datetime # When execution started
|
||||||
|
finished_at: Optional[datetime] = None # When execution completed
|
||||||
|
|
||||||
|
def update_from_mapping(
|
||||||
|
self,
|
||||||
|
inputs: Optional[Mapping[str, Any]] = None,
|
||||||
|
process_data: Optional[Mapping[str, Any]] = None,
|
||||||
|
outputs: Optional[Mapping[str, Any]] = None,
|
||||||
|
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update the model from mappings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: The inputs to update
|
||||||
|
process_data: The process data to update
|
||||||
|
outputs: The outputs to update
|
||||||
|
metadata: The metadata to update
|
||||||
|
"""
|
||||||
|
if inputs is not None:
|
||||||
|
self.inputs = dict(inputs)
|
||||||
|
if process_data is not None:
|
||||||
|
self.process_data = dict(process_data)
|
||||||
|
if outputs is not None:
|
||||||
|
self.outputs = dict(outputs)
|
||||||
|
if metadata is not None:
|
||||||
|
self.metadata = dict(metadata)
|
@ -2,12 +2,12 @@ from collections.abc import Sequence
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal, Optional, Protocol
|
from typing import Literal, Optional, Protocol
|
||||||
|
|
||||||
from models.workflow import WorkflowNodeExecution
|
from core.workflow.entities.node_execution_entities import NodeExecution
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OrderConfig:
|
class OrderConfig:
|
||||||
"""Configuration for ordering WorkflowNodeExecution instances."""
|
"""Configuration for ordering NodeExecution instances."""
|
||||||
|
|
||||||
order_by: list[str]
|
order_by: list[str]
|
||||||
order_direction: Optional[Literal["asc", "desc"]] = None
|
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||||
@ -15,10 +15,10 @@ class OrderConfig:
|
|||||||
|
|
||||||
class WorkflowNodeExecutionRepository(Protocol):
|
class WorkflowNodeExecutionRepository(Protocol):
|
||||||
"""
|
"""
|
||||||
Repository interface for WorkflowNodeExecution.
|
Repository interface for NodeExecution.
|
||||||
|
|
||||||
This interface defines the contract for accessing and manipulating
|
This interface defines the contract for accessing and manipulating
|
||||||
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
|
NodeExecution data, regardless of the underlying storage mechanism.
|
||||||
|
|
||||||
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||||
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
||||||
@ -26,24 +26,28 @@ class WorkflowNodeExecutionRepository(Protocol):
|
|||||||
application domains or deployment scenarios.
|
application domains or deployment scenarios.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
def save(self, execution: NodeExecution) -> None:
|
||||||
"""
|
"""
|
||||||
Save a WorkflowNodeExecution instance.
|
Save or update a NodeExecution instance.
|
||||||
|
|
||||||
|
This method handles both creating new records and updating existing ones.
|
||||||
|
The implementation should determine whether to create or update based on
|
||||||
|
the execution's ID or other identifying fields.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
execution: The WorkflowNodeExecution instance to save
|
execution: The NodeExecution instance to save or update
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
|
||||||
"""
|
"""
|
||||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
Retrieve a NodeExecution by its node_execution_id.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_execution_id: The node execution ID
|
node_execution_id: The node execution ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The WorkflowNodeExecution instance if found, None otherwise
|
The NodeExecution instance if found, None otherwise
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -51,9 +55,9 @@ class WorkflowNodeExecutionRepository(Protocol):
|
|||||||
self,
|
self,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
order_config: Optional[OrderConfig] = None,
|
order_config: Optional[OrderConfig] = None,
|
||||||
) -> Sequence[WorkflowNodeExecution]:
|
) -> Sequence[NodeExecution]:
|
||||||
"""
|
"""
|
||||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
Retrieve all NodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workflow_run_id: The workflow run ID
|
workflow_run_id: The workflow run ID
|
||||||
@ -62,34 +66,25 @@ class WorkflowNodeExecutionRepository(Protocol):
|
|||||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of WorkflowNodeExecution instances
|
A list of NodeExecution instances
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
|
||||||
"""
|
"""
|
||||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
Retrieve all running NodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workflow_run_id: The workflow run ID
|
workflow_run_id: The workflow run ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of running WorkflowNodeExecution instances
|
A list of running NodeExecution instances
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
|
||||||
"""
|
|
||||||
Update an existing WorkflowNodeExecution instance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
execution: The WorkflowNodeExecution instance to update
|
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""
|
"""
|
||||||
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
|
Clear all NodeExecution records based on implementation-specific criteria.
|
||||||
|
|
||||||
This method is intended to be used for bulk deletion operations, such as removing
|
This method is intended to be used for bulk deletion operations, such as removing
|
||||||
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
|
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
|
||||||
|
@ -58,7 +58,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow
|
|||||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatorUserRole
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
@ -94,11 +94,11 @@ class WorkflowAppGenerateTaskPipeline:
|
|||||||
if isinstance(user, EndUser):
|
if isinstance(user, EndUser):
|
||||||
self._user_id = user.id
|
self._user_id = user.id
|
||||||
user_session_id = user.session_id
|
user_session_id = user.session_id
|
||||||
self._created_by_role = CreatedByRole.END_USER
|
self._created_by_role = CreatorUserRole.END_USER
|
||||||
elif isinstance(user, Account):
|
elif isinstance(user, Account):
|
||||||
self._user_id = user.id
|
self._user_id = user.id
|
||||||
user_session_id = user.id
|
user_session_id = user.id
|
||||||
self._created_by_role = CreatedByRole.ACCOUNT
|
self._created_by_role = CreatorUserRole.ACCOUNT
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid user type: {type(user)}")
|
raise ValueError(f"Invalid user type: {type(user)}")
|
||||||
|
|
||||||
|
@ -46,26 +46,28 @@ from core.app.entities.task_entities import (
|
|||||||
)
|
)
|
||||||
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
|
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
|
||||||
from core.file import FILE_MODEL_IDENTITY, File
|
from core.file import FILE_MODEL_IDENTITY, File
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
||||||
from core.ops.entities.trace_entity import TraceTaskName
|
from core.ops.entities.trace_entity import TraceTaskName
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
|
from core.workflow.entities.node_execution_entities import (
|
||||||
|
NodeExecution,
|
||||||
|
NodeExecutionStatus,
|
||||||
|
)
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from models.account import Account
|
from models import (
|
||||||
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
|
Account,
|
||||||
from models.model import EndUser
|
CreatorUserRole,
|
||||||
from models.workflow import (
|
EndUser,
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNodeExecution,
|
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
WorkflowNodeExecutionTriggeredFrom,
|
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
|
WorkflowRunTriggeredFrom,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -78,7 +80,6 @@ class WorkflowCycleManager:
|
|||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._workflow_run: WorkflowRun | None = None
|
self._workflow_run: WorkflowRun | None = None
|
||||||
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._workflow_system_variables = workflow_system_variables
|
self._workflow_system_variables = workflow_system_variables
|
||||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||||
@ -89,7 +90,7 @@ class WorkflowCycleManager:
|
|||||||
session: Session,
|
session: Session,
|
||||||
workflow_id: str,
|
workflow_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
created_by_role: CreatedByRole,
|
created_by_role: CreatorUserRole,
|
||||||
) -> WorkflowRun:
|
) -> WorkflowRun:
|
||||||
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
|
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
|
||||||
workflow = session.scalar(workflow_stmt)
|
workflow = session.scalar(workflow_stmt)
|
||||||
@ -258,21 +259,22 @@ class WorkflowCycleManager:
|
|||||||
workflow_run.exceptions_count = exceptions_count
|
workflow_run.exceptions_count = exceptions_count
|
||||||
|
|
||||||
# Use the instance repository to find running executions for a workflow run
|
# Use the instance repository to find running executions for a workflow run
|
||||||
running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
|
running_domain_executions = self._workflow_node_execution_repository.get_running_executions(
|
||||||
workflow_run_id=workflow_run.id
|
workflow_run_id=workflow_run.id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the cache with the retrieved executions
|
# Update the domain models
|
||||||
for execution in running_workflow_node_executions:
|
|
||||||
if execution.node_execution_id:
|
|
||||||
self._workflow_node_executions[execution.node_execution_id] = execution
|
|
||||||
|
|
||||||
for workflow_node_execution in running_workflow_node_executions:
|
|
||||||
now = datetime.now(UTC).replace(tzinfo=None)
|
now = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
for domain_execution in running_domain_executions:
|
||||||
workflow_node_execution.error = error
|
if domain_execution.node_execution_id:
|
||||||
workflow_node_execution.finished_at = now
|
# Update the domain model
|
||||||
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
domain_execution.status = NodeExecutionStatus.FAILED
|
||||||
|
domain_execution.error = error
|
||||||
|
domain_execution.finished_at = now
|
||||||
|
domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds()
|
||||||
|
|
||||||
|
# Update the repository with the domain model
|
||||||
|
self._workflow_node_execution_repository.save(domain_execution)
|
||||||
|
|
||||||
if trace_manager:
|
if trace_manager:
|
||||||
trace_manager.add_trace_task(
|
trace_manager.add_trace_task(
|
||||||
@ -286,63 +288,67 @@ class WorkflowCycleManager:
|
|||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _handle_node_execution_start(
|
def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution:
|
||||||
self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
# Create a domain model
|
||||||
) -> WorkflowNodeExecution:
|
created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_node_execution = WorkflowNodeExecution()
|
metadata = {
|
||||||
workflow_node_execution.id = str(uuid4())
|
|
||||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
|
||||||
workflow_node_execution.app_id = workflow_run.app_id
|
|
||||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
|
||||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
|
||||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
|
||||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
|
||||||
workflow_node_execution.index = event.node_run_index
|
|
||||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
|
||||||
workflow_node_execution.node_id = event.node_id
|
|
||||||
workflow_node_execution.node_type = event.node_type.value
|
|
||||||
workflow_node_execution.title = event.node_data.title
|
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
|
||||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
|
||||||
workflow_node_execution.created_by = workflow_run.created_by
|
|
||||||
workflow_node_execution.execution_metadata = json.dumps(
|
|
||||||
{
|
|
||||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
domain_execution = NodeExecution(
|
||||||
|
id=str(uuid4()),
|
||||||
|
workflow_id=workflow_run.workflow_id,
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
|
index=event.node_run_index,
|
||||||
|
node_execution_id=event.node_execution_id,
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
title=event.node_data.title,
|
||||||
|
status=NodeExecutionStatus.RUNNING,
|
||||||
|
metadata=metadata,
|
||||||
|
created_at=created_at,
|
||||||
)
|
)
|
||||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
|
||||||
|
|
||||||
# Use the instance repository to save the workflow node execution
|
# Use the instance repository to save the domain model
|
||||||
self._workflow_node_execution_repository.save(workflow_node_execution)
|
self._workflow_node_execution_repository.save(domain_execution)
|
||||||
|
|
||||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
return domain_execution
|
||||||
return workflow_node_execution
|
|
||||||
|
|
||||||
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
|
||||||
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
|
# Get the domain model from repository
|
||||||
|
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
|
||||||
|
if not domain_execution:
|
||||||
|
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
|
||||||
|
|
||||||
|
# Process data
|
||||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||||
execution_metadata_dict = dict(event.execution_metadata or {})
|
|
||||||
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
|
# Convert metadata keys to strings
|
||||||
|
execution_metadata_dict = {}
|
||||||
|
if event.execution_metadata:
|
||||||
|
for key, value in event.execution_metadata.items():
|
||||||
|
execution_metadata_dict[key] = value
|
||||||
|
|
||||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||||
|
|
||||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
# Update domain model
|
||||||
|
domain_execution.status = NodeExecutionStatus.SUCCEEDED
|
||||||
|
domain_execution.update_from_mapping(
|
||||||
|
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
|
||||||
|
)
|
||||||
|
domain_execution.finished_at = finished_at
|
||||||
|
domain_execution.elapsed_time = elapsed_time
|
||||||
|
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
# Update the repository with the domain model
|
||||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
self._workflow_node_execution_repository.save(domain_execution)
|
||||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
|
||||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
|
||||||
workflow_node_execution.execution_metadata = execution_metadata
|
|
||||||
workflow_node_execution.finished_at = finished_at
|
|
||||||
workflow_node_execution.elapsed_time = elapsed_time
|
|
||||||
|
|
||||||
# Use the instance repository to update the workflow node execution
|
return domain_execution
|
||||||
self._workflow_node_execution_repository.update(workflow_node_execution)
|
|
||||||
return workflow_node_execution
|
|
||||||
|
|
||||||
def _handle_workflow_node_execution_failed(
|
def _handle_workflow_node_execution_failed(
|
||||||
self,
|
self,
|
||||||
@ -351,43 +357,52 @@ class WorkflowCycleManager:
|
|||||||
| QueueNodeInIterationFailedEvent
|
| QueueNodeInIterationFailedEvent
|
||||||
| QueueNodeInLoopFailedEvent
|
| QueueNodeInLoopFailedEvent
|
||||||
| QueueNodeExceptionEvent,
|
| QueueNodeExceptionEvent,
|
||||||
) -> WorkflowNodeExecution:
|
) -> NodeExecution:
|
||||||
"""
|
"""
|
||||||
Workflow node execution failed
|
Workflow node execution failed
|
||||||
:param event: queue node failed event
|
:param event: queue node failed event
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
|
# Get the domain model from repository
|
||||||
|
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
|
||||||
|
if not domain_execution:
|
||||||
|
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
|
||||||
|
|
||||||
|
# Process data
|
||||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||||
|
|
||||||
|
# Convert metadata keys to strings
|
||||||
|
execution_metadata_dict = {}
|
||||||
|
if event.execution_metadata:
|
||||||
|
for key, value in event.execution_metadata.items():
|
||||||
|
execution_metadata_dict[key] = value
|
||||||
|
|
||||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||||
execution_metadata = (
|
|
||||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
# Update domain model
|
||||||
)
|
domain_execution.status = (
|
||||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
NodeExecutionStatus.FAILED
|
||||||
workflow_node_execution.status = (
|
|
||||||
WorkflowNodeExecutionStatus.FAILED.value
|
|
||||||
if not isinstance(event, QueueNodeExceptionEvent)
|
if not isinstance(event, QueueNodeExceptionEvent)
|
||||||
else WorkflowNodeExecutionStatus.EXCEPTION.value
|
else NodeExecutionStatus.EXCEPTION
|
||||||
)
|
)
|
||||||
workflow_node_execution.error = event.error
|
domain_execution.error = event.error
|
||||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
domain_execution.update_from_mapping(
|
||||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
|
||||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
)
|
||||||
workflow_node_execution.finished_at = finished_at
|
domain_execution.finished_at = finished_at
|
||||||
workflow_node_execution.elapsed_time = elapsed_time
|
domain_execution.elapsed_time = elapsed_time
|
||||||
workflow_node_execution.execution_metadata = execution_metadata
|
|
||||||
|
|
||||||
self._workflow_node_execution_repository.update(workflow_node_execution)
|
# Update the repository with the domain model
|
||||||
|
self._workflow_node_execution_repository.save(domain_execution)
|
||||||
|
|
||||||
return workflow_node_execution
|
return domain_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_retried(
|
def _handle_workflow_node_execution_retried(
|
||||||
self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
|
self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
|
||||||
) -> WorkflowNodeExecution:
|
) -> NodeExecution:
|
||||||
"""
|
"""
|
||||||
Workflow node execution failed
|
Workflow node execution failed
|
||||||
:param workflow_run: workflow run
|
:param workflow_run: workflow run
|
||||||
@ -399,47 +414,47 @@ class WorkflowCycleManager:
|
|||||||
elapsed_time = (finished_at - created_at).total_seconds()
|
elapsed_time = (finished_at - created_at).total_seconds()
|
||||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||||
|
|
||||||
|
# Convert metadata keys to strings
|
||||||
origin_metadata = {
|
origin_metadata = {
|
||||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||||
}
|
}
|
||||||
merged_metadata = (
|
|
||||||
{**jsonable_encoder(event.execution_metadata), **origin_metadata}
|
# Convert execution metadata keys to strings
|
||||||
if event.execution_metadata is not None
|
execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {}
|
||||||
else origin_metadata
|
if event.execution_metadata:
|
||||||
|
for key, value in event.execution_metadata.items():
|
||||||
|
execution_metadata_dict[key] = value
|
||||||
|
|
||||||
|
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
|
||||||
|
|
||||||
|
# Create a domain model
|
||||||
|
domain_execution = NodeExecution(
|
||||||
|
id=str(uuid4()),
|
||||||
|
workflow_id=workflow_run.workflow_id,
|
||||||
|
workflow_run_id=workflow_run.id,
|
||||||
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
|
node_execution_id=event.node_execution_id,
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
title=event.node_data.title,
|
||||||
|
status=NodeExecutionStatus.RETRY,
|
||||||
|
created_at=created_at,
|
||||||
|
finished_at=finished_at,
|
||||||
|
elapsed_time=elapsed_time,
|
||||||
|
error=event.error,
|
||||||
|
index=event.node_run_index,
|
||||||
)
|
)
|
||||||
execution_metadata = json.dumps(merged_metadata)
|
|
||||||
|
|
||||||
workflow_node_execution = WorkflowNodeExecution()
|
# Update with mappings
|
||||||
workflow_node_execution.id = str(uuid4())
|
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)
|
||||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
|
||||||
workflow_node_execution.app_id = workflow_run.app_id
|
|
||||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
|
||||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
|
||||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
|
||||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
|
||||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
|
||||||
workflow_node_execution.node_id = event.node_id
|
|
||||||
workflow_node_execution.node_type = event.node_type.value
|
|
||||||
workflow_node_execution.title = event.node_data.title
|
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
|
|
||||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
|
||||||
workflow_node_execution.created_by = workflow_run.created_by
|
|
||||||
workflow_node_execution.created_at = created_at
|
|
||||||
workflow_node_execution.finished_at = finished_at
|
|
||||||
workflow_node_execution.elapsed_time = elapsed_time
|
|
||||||
workflow_node_execution.error = event.error
|
|
||||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
|
||||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
|
||||||
workflow_node_execution.execution_metadata = execution_metadata
|
|
||||||
workflow_node_execution.index = event.node_run_index
|
|
||||||
|
|
||||||
# Use the instance repository to save the workflow node execution
|
# Use the instance repository to save the domain model
|
||||||
self._workflow_node_execution_repository.save(workflow_node_execution)
|
self._workflow_node_execution_repository.save(domain_execution)
|
||||||
|
|
||||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
return domain_execution
|
||||||
return workflow_node_execution
|
|
||||||
|
|
||||||
def _workflow_start_to_stream_response(
|
def _workflow_start_to_stream_response(
|
||||||
self,
|
self,
|
||||||
@ -469,7 +484,7 @@ class WorkflowCycleManager:
|
|||||||
workflow_run: WorkflowRun,
|
workflow_run: WorkflowRun,
|
||||||
) -> WorkflowFinishStreamResponse:
|
) -> WorkflowFinishStreamResponse:
|
||||||
created_by = None
|
created_by = None
|
||||||
if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
|
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||||
account = session.scalar(stmt)
|
account = session.scalar(stmt)
|
||||||
if account:
|
if account:
|
||||||
@ -478,7 +493,7 @@ class WorkflowCycleManager:
|
|||||||
"name": account.name,
|
"name": account.name,
|
||||||
"email": account.email,
|
"email": account.email,
|
||||||
}
|
}
|
||||||
elif workflow_run.created_by_role == CreatedByRole.END_USER:
|
elif workflow_run.created_by_role == CreatorUserRole.END_USER:
|
||||||
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
|
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
|
||||||
end_user = session.scalar(stmt)
|
end_user = session.scalar(stmt)
|
||||||
if end_user:
|
if end_user:
|
||||||
@ -515,9 +530,9 @@ class WorkflowCycleManager:
|
|||||||
*,
|
*,
|
||||||
event: QueueNodeStartedEvent,
|
event: QueueNodeStartedEvent,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
workflow_node_execution: WorkflowNodeExecution,
|
workflow_node_execution: NodeExecution,
|
||||||
) -> Optional[NodeStartStreamResponse]:
|
) -> Optional[NodeStartStreamResponse]:
|
||||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||||
return None
|
return None
|
||||||
if not workflow_node_execution.workflow_run_id:
|
if not workflow_node_execution.workflow_run_id:
|
||||||
return None
|
return None
|
||||||
@ -532,7 +547,7 @@ class WorkflowCycleManager:
|
|||||||
title=workflow_node_execution.title,
|
title=workflow_node_execution.title,
|
||||||
index=workflow_node_execution.index,
|
index=workflow_node_execution.index,
|
||||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||||
inputs=workflow_node_execution.inputs_dict,
|
inputs=workflow_node_execution.inputs,
|
||||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
@ -565,9 +580,9 @@ class WorkflowCycleManager:
|
|||||||
| QueueNodeInLoopFailedEvent
|
| QueueNodeInLoopFailedEvent
|
||||||
| QueueNodeExceptionEvent,
|
| QueueNodeExceptionEvent,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
workflow_node_execution: WorkflowNodeExecution,
|
workflow_node_execution: NodeExecution,
|
||||||
) -> Optional[NodeFinishStreamResponse]:
|
) -> Optional[NodeFinishStreamResponse]:
|
||||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||||
return None
|
return None
|
||||||
if not workflow_node_execution.workflow_run_id:
|
if not workflow_node_execution.workflow_run_id:
|
||||||
return None
|
return None
|
||||||
@ -584,16 +599,16 @@ class WorkflowCycleManager:
|
|||||||
index=workflow_node_execution.index,
|
index=workflow_node_execution.index,
|
||||||
title=workflow_node_execution.title,
|
title=workflow_node_execution.title,
|
||||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||||
inputs=workflow_node_execution.inputs_dict,
|
inputs=workflow_node_execution.inputs,
|
||||||
process_data=workflow_node_execution.process_data_dict,
|
process_data=workflow_node_execution.process_data,
|
||||||
outputs=workflow_node_execution.outputs_dict,
|
outputs=workflow_node_execution.outputs,
|
||||||
status=workflow_node_execution.status,
|
status=workflow_node_execution.status,
|
||||||
error=workflow_node_execution.error,
|
error=workflow_node_execution.error,
|
||||||
elapsed_time=workflow_node_execution.elapsed_time,
|
elapsed_time=workflow_node_execution.elapsed_time,
|
||||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
execution_metadata=workflow_node_execution.metadata,
|
||||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
parent_parallel_id=event.parent_parallel_id,
|
||||||
@ -608,9 +623,9 @@ class WorkflowCycleManager:
|
|||||||
*,
|
*,
|
||||||
event: QueueNodeRetryEvent,
|
event: QueueNodeRetryEvent,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
workflow_node_execution: WorkflowNodeExecution,
|
workflow_node_execution: NodeExecution,
|
||||||
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
||||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||||
return None
|
return None
|
||||||
if not workflow_node_execution.workflow_run_id:
|
if not workflow_node_execution.workflow_run_id:
|
||||||
return None
|
return None
|
||||||
@ -627,16 +642,16 @@ class WorkflowCycleManager:
|
|||||||
index=workflow_node_execution.index,
|
index=workflow_node_execution.index,
|
||||||
title=workflow_node_execution.title,
|
title=workflow_node_execution.title,
|
||||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||||
inputs=workflow_node_execution.inputs_dict,
|
inputs=workflow_node_execution.inputs,
|
||||||
process_data=workflow_node_execution.process_data_dict,
|
process_data=workflow_node_execution.process_data,
|
||||||
outputs=workflow_node_execution.outputs_dict,
|
outputs=workflow_node_execution.outputs,
|
||||||
status=workflow_node_execution.status,
|
status=workflow_node_execution.status,
|
||||||
error=workflow_node_execution.error,
|
error=workflow_node_execution.error,
|
||||||
elapsed_time=workflow_node_execution.elapsed_time,
|
elapsed_time=workflow_node_execution.elapsed_time,
|
||||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
execution_metadata=workflow_node_execution.metadata,
|
||||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
parent_parallel_id=event.parent_parallel_id,
|
||||||
@ -908,23 +923,6 @@ class WorkflowCycleManager:
|
|||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
|
||||||
# First check the cache for performance
|
|
||||||
if node_execution_id in self._workflow_node_executions:
|
|
||||||
cached_execution = self._workflow_node_executions[node_execution_id]
|
|
||||||
# No need to merge with session since expire_on_commit=False
|
|
||||||
return cached_execution
|
|
||||||
|
|
||||||
# If not in cache, use the instance repository to get by node_execution_id
|
|
||||||
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
|
|
||||||
|
|
||||||
if not execution:
|
|
||||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
|
||||||
|
|
||||||
# Update cache
|
|
||||||
self._workflow_node_executions[node_execution_id] = execution
|
|
||||||
return execution
|
|
||||||
|
|
||||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||||
"""
|
"""
|
||||||
Handle agent log
|
Handle agent log
|
||||||
|
@ -27,7 +27,7 @@ from .dataset import (
|
|||||||
Whitelist,
|
Whitelist,
|
||||||
)
|
)
|
||||||
from .engine import db
|
from .engine import db
|
||||||
from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
|
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
|
||||||
from .model import (
|
from .model import (
|
||||||
ApiRequest,
|
ApiRequest,
|
||||||
ApiToken,
|
ApiToken,
|
||||||
@ -112,7 +112,7 @@ __all__ = [
|
|||||||
"CeleryTaskSet",
|
"CeleryTaskSet",
|
||||||
"Conversation",
|
"Conversation",
|
||||||
"ConversationVariable",
|
"ConversationVariable",
|
||||||
"CreatedByRole",
|
"CreatorUserRole",
|
||||||
"DataSourceApiKeyAuthBinding",
|
"DataSourceApiKeyAuthBinding",
|
||||||
"DataSourceOauthBinding",
|
"DataSourceOauthBinding",
|
||||||
"Dataset",
|
"Dataset",
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
class CreatedByRole(StrEnum):
|
class CreatorUserRole(StrEnum):
|
||||||
ACCOUNT = "account"
|
ACCOUNT = "account"
|
||||||
END_USER = "end_user"
|
END_USER = "end_user"
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ from libs.helper import generate_string
|
|||||||
from .account import Account, Tenant
|
from .account import Account, Tenant
|
||||||
from .base import Base
|
from .base import Base
|
||||||
from .engine import db
|
from .engine import db
|
||||||
from .enums import CreatedByRole
|
from .enums import CreatorUserRole
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
from .workflow import WorkflowRunStatus
|
from .workflow import WorkflowRunStatus
|
||||||
|
|
||||||
@ -1270,7 +1270,7 @@ class MessageFile(Base):
|
|||||||
url: str | None = None,
|
url: str | None = None,
|
||||||
belongs_to: Literal["user", "assistant"] | None = None,
|
belongs_to: Literal["user", "assistant"] | None = None,
|
||||||
upload_file_id: str | None = None,
|
upload_file_id: str | None = None,
|
||||||
created_by_role: CreatedByRole,
|
created_by_role: CreatorUserRole,
|
||||||
created_by: str,
|
created_by: str,
|
||||||
):
|
):
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
@ -1417,7 +1417,7 @@ class EndUser(Base, UserMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
tenant_id = db.Column(StringUUID, nullable=False)
|
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||||
app_id = db.Column(StringUUID, nullable=True)
|
app_id = db.Column(StringUUID, nullable=True)
|
||||||
type = db.Column(db.String(255), nullable=False)
|
type = db.Column(db.String(255), nullable=False)
|
||||||
external_user_id = db.Column(db.String(255), nullable=True)
|
external_user_id = db.Column(db.String(255), nullable=True)
|
||||||
@ -1547,7 +1547,7 @@ class UploadFile(Base):
|
|||||||
size: int,
|
size: int,
|
||||||
extension: str,
|
extension: str,
|
||||||
mime_type: str,
|
mime_type: str,
|
||||||
created_by_role: CreatedByRole,
|
created_by_role: CreatorUserRole,
|
||||||
created_by: str,
|
created_by: str,
|
||||||
created_at: datetime,
|
created_at: datetime,
|
||||||
used: bool,
|
used: bool,
|
||||||
|
@ -22,7 +22,7 @@ from libs import helper
|
|||||||
from .account import Account
|
from .account import Account
|
||||||
from .base import Base
|
from .base import Base
|
||||||
from .engine import db
|
from .engine import db
|
||||||
from .enums import CreatedByRole
|
from .enums import CreatorUserRole
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -429,15 +429,15 @@ class WorkflowRun(Base):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def created_by_account(self):
|
def created_by_account(self):
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def created_by_end_user(self):
|
def created_by_end_user(self):
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def graph_dict(self):
|
def graph_dict(self):
|
||||||
@ -634,17 +634,17 @@ class WorkflowNodeExecution(Base):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def created_by_account(self):
|
def created_by_account(self):
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
# TODO(-LAN-): Avoid using db.session.get() here.
|
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def created_by_end_user(self):
|
def created_by_end_user(self):
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
# TODO(-LAN-): Avoid using db.session.get() here.
|
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inputs_dict(self):
|
def inputs_dict(self):
|
||||||
@ -755,15 +755,15 @@ class WorkflowAppLog(Base):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def created_by_account(self):
|
def created_by_account(self):
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def created_by_end_user(self):
|
def created_by_end_user(self):
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||||
|
|
||||||
|
|
||||||
class ConversationVariable(Base):
|
class ConversationVariable(Base):
|
||||||
|
@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatorUserRole
|
||||||
from models.model import EndUser, UploadFile
|
from models.model import EndUser, UploadFile
|
||||||
|
|
||||||
from .errors.file import FileTooLargeError, UnsupportedFileTypeError
|
from .errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||||
@ -81,7 +81,7 @@ class FileService:
|
|||||||
size=file_size,
|
size=file_size,
|
||||||
extension=extension,
|
extension=extension,
|
||||||
mime_type=mimetype,
|
mime_type=mimetype,
|
||||||
created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER),
|
created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
|
||||||
created_by=user.id,
|
created_by=user.id,
|
||||||
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||||
used=False,
|
used=False,
|
||||||
@ -133,7 +133,7 @@ class FileService:
|
|||||||
extension="txt",
|
extension="txt",
|
||||||
mime_type="text/plain",
|
mime_type="text/plain",
|
||||||
created_by=current_user.id,
|
created_by=current_user.id,
|
||||||
created_by_role=CreatedByRole.ACCOUNT,
|
created_by_role=CreatorUserRole.ACCOUNT,
|
||||||
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||||
used=True,
|
used=True,
|
||||||
used_by=current_user.id,
|
used_by=current_user.id,
|
||||||
|
@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from models import App, EndUser, WorkflowAppLog, WorkflowRun
|
from models import App, EndUser, WorkflowAppLog, WorkflowRun
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatorUserRole
|
||||||
from models.workflow import WorkflowRunStatus
|
from models.workflow import WorkflowRunStatus
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ class WorkflowAppService:
|
|||||||
|
|
||||||
stmt = stmt.outerjoin(
|
stmt = stmt.outerjoin(
|
||||||
EndUser,
|
EndUser,
|
||||||
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER),
|
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER),
|
||||||
).where(or_(*keyword_conditions))
|
).where(or_(*keyword_conditions))
|
||||||
|
|
||||||
if status:
|
if status:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import threading
|
import threading
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
@ -6,11 +7,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
|||||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models import (
|
||||||
from models.model import App
|
Account,
|
||||||
from models.workflow import (
|
App,
|
||||||
|
EndUser,
|
||||||
WorkflowNodeExecution,
|
WorkflowNodeExecution,
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
|
WorkflowRunTriggeredFrom,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -116,7 +119,12 @@ class WorkflowRunService:
|
|||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]:
|
def get_workflow_run_node_executions(
|
||||||
|
self,
|
||||||
|
app_model: App,
|
||||||
|
run_id: str,
|
||||||
|
user: Account | EndUser,
|
||||||
|
) -> Sequence[WorkflowNodeExecution]:
|
||||||
"""
|
"""
|
||||||
Get workflow run node execution list
|
Get workflow run node execution list
|
||||||
"""
|
"""
|
||||||
@ -128,13 +136,15 @@ class WorkflowRunService:
|
|||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Use the repository to get the node executions
|
|
||||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
|
session_factory=db.engine,
|
||||||
|
user=user,
|
||||||
|
app_id=app_model.id,
|
||||||
|
triggered_from=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the repository to get the node executions with ordering
|
# Use the repository to get the node executions with ordering
|
||||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||||
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
|
node_executions = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
|
||||||
|
|
||||||
return list(node_executions)
|
return node_executions
|
||||||
|
@ -26,7 +26,7 @@ from core.workflow.workflow_entry import WorkflowEntry
|
|||||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatorUserRole
|
||||||
from models.model import App, AppMode
|
from models.model import App, AppMode
|
||||||
from models.tools import WorkflowToolProvider
|
from models.tools import WorkflowToolProvider
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
@ -284,9 +284,11 @@ class WorkflowService:
|
|||||||
workflow_node_execution.created_by = account.id
|
workflow_node_execution.created_by = account.id
|
||||||
workflow_node_execution.workflow_id = draft_workflow.id
|
workflow_node_execution.workflow_id = draft_workflow.id
|
||||||
|
|
||||||
# Use the repository to save the workflow node execution
|
|
||||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
|
session_factory=db.engine,
|
||||||
|
user=account,
|
||||||
|
app_id=app_model.id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||||
)
|
)
|
||||||
repository.save(workflow_node_execution)
|
repository.save(workflow_node_execution)
|
||||||
|
|
||||||
@ -390,7 +392,7 @@ class WorkflowService:
|
|||||||
workflow_node_execution.node_type = node_instance.node_type
|
workflow_node_execution.node_type = node_instance.node_type
|
||||||
workflow_node_execution.title = node_instance.node_data.title
|
workflow_node_execution.title = node_instance.node_data.title
|
||||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||||
workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
|
workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value
|
||||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
if run_succeeded and node_run_result:
|
if run_succeeded and node_run_result:
|
||||||
|
@ -4,16 +4,19 @@ from collections.abc import Callable
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
from celery import shared_task # type: ignore
|
from celery import shared_task # type: ignore
|
||||||
from sqlalchemy import delete
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import AppDatasetJoin
|
from models import (
|
||||||
from models.model import (
|
Account,
|
||||||
ApiToken,
|
ApiToken,
|
||||||
|
App,
|
||||||
AppAnnotationHitHistory,
|
AppAnnotationHitHistory,
|
||||||
AppAnnotationSetting,
|
AppAnnotationSetting,
|
||||||
|
AppDatasetJoin,
|
||||||
AppModelConfig,
|
AppModelConfig,
|
||||||
Conversation,
|
Conversation,
|
||||||
EndUser,
|
EndUser,
|
||||||
@ -188,9 +191,24 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
|
|||||||
|
|
||||||
|
|
||||||
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
||||||
|
# Get app's owner
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id)
|
||||||
|
user = session.scalar(stmt)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
errmsg = (
|
||||||
|
f"Failed to delete workflow node executions for tenant {tenant_id} and app {app_id}, app's owner not found"
|
||||||
|
)
|
||||||
|
logging.error(errmsg)
|
||||||
|
raise ValueError(errmsg)
|
||||||
|
|
||||||
# Create a repository instance for WorkflowNodeExecution
|
# Create a repository instance for WorkflowNodeExecution
|
||||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=db.engine, tenant_id=tenant_id, app_id=app_id
|
session_factory=db.engine,
|
||||||
|
user=user,
|
||||||
|
app_id=app_id,
|
||||||
|
triggered_from=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the clear method to delete all records for this tenant_id and app_id
|
# Use the clear method to delete all records for this tenant_id and app_id
|
||||||
|
@ -16,10 +16,9 @@ from core.workflow.enums import SystemVariableKey
|
|||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatorUserRole
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNodeExecution,
|
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
@ -94,7 +93,7 @@ def mock_workflow_run():
|
|||||||
workflow_run.app_id = "test-app-id"
|
workflow_run.app_id = "test-app-id"
|
||||||
workflow_run.workflow_id = "test-workflow-id"
|
workflow_run.workflow_id = "test-workflow-id"
|
||||||
workflow_run.status = WorkflowRunStatus.RUNNING
|
workflow_run.status = WorkflowRunStatus.RUNNING
|
||||||
workflow_run.created_by_role = CreatedByRole.ACCOUNT
|
workflow_run.created_by_role = CreatorUserRole.ACCOUNT
|
||||||
workflow_run.created_by = "test-user-id"
|
workflow_run.created_by = "test-user-id"
|
||||||
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_run.inputs_dict = {"query": "test query"}
|
workflow_run.inputs_dict = {"query": "test query"}
|
||||||
@ -107,7 +106,6 @@ def test_init(
|
|||||||
):
|
):
|
||||||
"""Test initialization of WorkflowCycleManager"""
|
"""Test initialization of WorkflowCycleManager"""
|
||||||
assert workflow_cycle_manager._workflow_run is None
|
assert workflow_cycle_manager._workflow_run is None
|
||||||
assert workflow_cycle_manager._workflow_node_executions == {}
|
|
||||||
assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity
|
assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity
|
||||||
assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables
|
assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables
|
||||||
assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
|
assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
|
||||||
@ -123,7 +121,7 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo
|
|||||||
session=mock_session,
|
session=mock_session,
|
||||||
workflow_id="test-workflow-id",
|
workflow_id="test-workflow-id",
|
||||||
user_id="test-user-id",
|
user_id="test-user-id",
|
||||||
created_by_role=CreatedByRole.ACCOUNT,
|
created_by_role=CreatorUserRole.ACCOUNT,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the result
|
# Verify the result
|
||||||
@ -132,7 +130,7 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo
|
|||||||
assert workflow_run.workflow_id == mock_workflow.id
|
assert workflow_run.workflow_id == mock_workflow.id
|
||||||
assert workflow_run.sequence_number == 6 # max_sequence + 1
|
assert workflow_run.sequence_number == 6 # max_sequence + 1
|
||||||
assert workflow_run.status == WorkflowRunStatus.RUNNING
|
assert workflow_run.status == WorkflowRunStatus.RUNNING
|
||||||
assert workflow_run.created_by_role == CreatedByRole.ACCOUNT
|
assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT
|
||||||
assert workflow_run.created_by == "test-user-id"
|
assert workflow_run.created_by == "test-user-id"
|
||||||
|
|
||||||
# Verify session.add was called
|
# Verify session.add was called
|
||||||
@ -215,24 +213,23 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Verify the result
|
# Verify the result
|
||||||
assert result.tenant_id == mock_workflow_run.tenant_id
|
# NodeExecution doesn't have tenant_id attribute, it's handled at repository level
|
||||||
assert result.app_id == mock_workflow_run.app_id
|
# assert result.tenant_id == mock_workflow_run.tenant_id
|
||||||
|
# assert result.app_id == mock_workflow_run.app_id
|
||||||
assert result.workflow_id == mock_workflow_run.workflow_id
|
assert result.workflow_id == mock_workflow_run.workflow_id
|
||||||
assert result.workflow_run_id == mock_workflow_run.id
|
assert result.workflow_run_id == mock_workflow_run.id
|
||||||
assert result.node_execution_id == event.node_execution_id
|
assert result.node_execution_id == event.node_execution_id
|
||||||
assert result.node_id == event.node_id
|
assert result.node_id == event.node_id
|
||||||
assert result.node_type == event.node_type.value
|
assert result.node_type == event.node_type
|
||||||
assert result.title == event.node_data.title
|
assert result.title == event.node_data.title
|
||||||
assert result.status == WorkflowNodeExecutionStatus.RUNNING.value
|
assert result.status == WorkflowNodeExecutionStatus.RUNNING.value
|
||||||
assert result.created_by_role == mock_workflow_run.created_by_role
|
# NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level
|
||||||
assert result.created_by == mock_workflow_run.created_by
|
# assert result.created_by_role == mock_workflow_run.created_by_role
|
||||||
|
# assert result.created_by == mock_workflow_run.created_by
|
||||||
|
|
||||||
# Verify save was called
|
# Verify save was called
|
||||||
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
|
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
|
||||||
|
|
||||||
# Verify the node execution was added to the cache
|
|
||||||
assert workflow_cycle_manager._workflow_node_executions[event.node_execution_id] == result
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run):
|
def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run):
|
||||||
"""Test _get_workflow_run method"""
|
"""Test _get_workflow_run method"""
|
||||||
@ -261,12 +258,13 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
|
|||||||
event.execution_metadata = {"metadata": "test metadata"}
|
event.execution_metadata = {"metadata": "test metadata"}
|
||||||
event.start_at = datetime.now(UTC).replace(tzinfo=None)
|
event.start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
# Create a mock workflow node execution
|
# Create a mock node execution
|
||||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
node_execution = MagicMock()
|
||||||
node_execution.node_execution_id = "test-node-execution-id"
|
node_execution.node_execution_id = "test-node-execution-id"
|
||||||
|
|
||||||
# Mock _get_workflow_node_execution to return the mock node execution
|
# Mock the repository to return the node execution
|
||||||
with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution):
|
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
|
||||||
|
|
||||||
# Call the method
|
# Call the method
|
||||||
result = workflow_cycle_manager._handle_workflow_node_execution_success(
|
result = workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
event=event,
|
event=event,
|
||||||
@ -275,14 +273,9 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
|
|||||||
# Verify the result
|
# Verify the result
|
||||||
assert result == node_execution
|
assert result == node_execution
|
||||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value
|
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||||
assert result.inputs == json.dumps(event.inputs)
|
|
||||||
assert result.process_data == json.dumps(event.process_data)
|
|
||||||
assert result.outputs == json.dumps(event.outputs)
|
|
||||||
assert result.finished_at is not None
|
|
||||||
assert result.elapsed_time is not None
|
|
||||||
|
|
||||||
# Verify update was called
|
# Verify save was called
|
||||||
workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution)
|
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
|
||||||
|
|
||||||
|
|
||||||
def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run):
|
def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run):
|
||||||
@ -322,12 +315,13 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
|
|||||||
event.start_at = datetime.now(UTC).replace(tzinfo=None)
|
event.start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
event.error = "Test error message"
|
event.error = "Test error message"
|
||||||
|
|
||||||
# Create a mock workflow node execution
|
# Create a mock node execution
|
||||||
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
node_execution = MagicMock()
|
||||||
node_execution.node_execution_id = "test-node-execution-id"
|
node_execution.node_execution_id = "test-node-execution-id"
|
||||||
|
|
||||||
# Mock _get_workflow_node_execution to return the mock node execution
|
# Mock the repository to return the node execution
|
||||||
with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution):
|
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
|
||||||
|
|
||||||
# Call the method
|
# Call the method
|
||||||
result = workflow_cycle_manager._handle_workflow_node_execution_failed(
|
result = workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
event=event,
|
event=event,
|
||||||
@ -337,12 +331,6 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
|
|||||||
assert result == node_execution
|
assert result == node_execution
|
||||||
assert result.status == WorkflowNodeExecutionStatus.FAILED.value
|
assert result.status == WorkflowNodeExecutionStatus.FAILED.value
|
||||||
assert result.error == "Test error message"
|
assert result.error == "Test error message"
|
||||||
assert result.inputs == json.dumps(event.inputs)
|
|
||||||
assert result.process_data == json.dumps(event.process_data)
|
|
||||||
assert result.outputs == json.dumps(event.outputs)
|
|
||||||
assert result.finished_at is not None
|
|
||||||
assert result.elapsed_time is not None
|
|
||||||
assert result.execution_metadata == json.dumps(event.execution_metadata)
|
|
||||||
|
|
||||||
# Verify update was called
|
# Verify save was called
|
||||||
workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution)
|
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
|
||||||
|
@ -2,15 +2,36 @@
|
|||||||
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock, PropertyMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
|
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
||||||
from models.workflow import WorkflowNodeExecution
|
from models.account import Account, Tenant
|
||||||
|
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
|
||||||
|
|
||||||
|
def configure_mock_execution(mock_execution):
|
||||||
|
"""Configure a mock execution with proper JSON serializable values."""
|
||||||
|
# Configure inputs, outputs, process_data, and execution_metadata to return JSON serializable values
|
||||||
|
type(mock_execution).inputs = PropertyMock(return_value='{"key": "value"}')
|
||||||
|
type(mock_execution).outputs = PropertyMock(return_value='{"result": "success"}')
|
||||||
|
type(mock_execution).process_data = PropertyMock(return_value='{"process": "data"}')
|
||||||
|
type(mock_execution).execution_metadata = PropertyMock(return_value='{"metadata": "info"}')
|
||||||
|
|
||||||
|
# Configure status and triggered_from to be valid enum values
|
||||||
|
mock_execution.status = "running"
|
||||||
|
mock_execution.triggered_from = "workflow-run"
|
||||||
|
|
||||||
|
return mock_execution
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -28,13 +49,30 @@ def session():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def repository(session):
|
def mock_user():
|
||||||
|
"""Create a user instance for testing."""
|
||||||
|
user = Account()
|
||||||
|
user.id = "test-user-id"
|
||||||
|
|
||||||
|
tenant = Tenant()
|
||||||
|
tenant.id = "test-tenant"
|
||||||
|
tenant.name = "Test Workspace"
|
||||||
|
user._current_tenant = MagicMock()
|
||||||
|
user._current_tenant.id = "test-tenant"
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def repository(session, mock_user):
|
||||||
"""Create a repository instance with test data."""
|
"""Create a repository instance with test data."""
|
||||||
_, session_factory = session
|
_, session_factory = session
|
||||||
tenant_id = "test-tenant"
|
|
||||||
app_id = "test-app"
|
app_id = "test-app"
|
||||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
session_factory=session_factory,
|
||||||
|
user=mock_user,
|
||||||
|
app_id=app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -45,16 +83,23 @@ def test_save(repository, session):
|
|||||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
execution.tenant_id = None
|
execution.tenant_id = None
|
||||||
execution.app_id = None
|
execution.app_id = None
|
||||||
|
execution.inputs = None
|
||||||
|
execution.process_data = None
|
||||||
|
execution.outputs = None
|
||||||
|
execution.metadata = None
|
||||||
|
|
||||||
|
# Mock the _to_db_model method to return the execution itself
|
||||||
|
# This simulates the behavior of setting tenant_id and app_id
|
||||||
|
repository._to_db_model = MagicMock(return_value=execution)
|
||||||
|
|
||||||
# Call save method
|
# Call save method
|
||||||
repository.save(execution)
|
repository.save(execution)
|
||||||
|
|
||||||
# Assert tenant_id and app_id are set
|
# Assert _to_db_model was called with the execution
|
||||||
assert execution.tenant_id == repository._tenant_id
|
repository._to_db_model.assert_called_once_with(execution)
|
||||||
assert execution.app_id == repository._app_id
|
|
||||||
|
|
||||||
# Assert session.add was called
|
# Assert session.merge was called (now using merge for both save and update)
|
||||||
session_obj.add.assert_called_once_with(execution)
|
session_obj.merge.assert_called_once_with(execution)
|
||||||
|
|
||||||
|
|
||||||
def test_save_with_existing_tenant_id(repository, session):
|
def test_save_with_existing_tenant_id(repository, session):
|
||||||
@ -64,16 +109,27 @@ def test_save_with_existing_tenant_id(repository, session):
|
|||||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
execution.tenant_id = "existing-tenant"
|
execution.tenant_id = "existing-tenant"
|
||||||
execution.app_id = None
|
execution.app_id = None
|
||||||
|
execution.inputs = None
|
||||||
|
execution.process_data = None
|
||||||
|
execution.outputs = None
|
||||||
|
execution.metadata = None
|
||||||
|
|
||||||
|
# Create a modified execution that will be returned by _to_db_model
|
||||||
|
modified_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
modified_execution.tenant_id = "existing-tenant" # Tenant ID should not change
|
||||||
|
modified_execution.app_id = repository._app_id # App ID should be set
|
||||||
|
|
||||||
|
# Mock the _to_db_model method to return the modified execution
|
||||||
|
repository._to_db_model = MagicMock(return_value=modified_execution)
|
||||||
|
|
||||||
# Call save method
|
# Call save method
|
||||||
repository.save(execution)
|
repository.save(execution)
|
||||||
|
|
||||||
# Assert tenant_id is not changed and app_id is set
|
# Assert _to_db_model was called with the execution
|
||||||
assert execution.tenant_id == "existing-tenant"
|
repository._to_db_model.assert_called_once_with(execution)
|
||||||
assert execution.app_id == repository._app_id
|
|
||||||
|
|
||||||
# Assert session.add was called
|
# Assert session.merge was called with the modified execution (now using merge for both save and update)
|
||||||
session_obj.add.assert_called_once_with(execution)
|
session_obj.merge.assert_called_once_with(modified_execution)
|
||||||
|
|
||||||
|
|
||||||
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
||||||
@ -84,7 +140,16 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
|||||||
mock_stmt = mocker.MagicMock()
|
mock_stmt = mocker.MagicMock()
|
||||||
mock_select.return_value = mock_stmt
|
mock_select.return_value = mock_stmt
|
||||||
mock_stmt.where.return_value = mock_stmt
|
mock_stmt.where.return_value = mock_stmt
|
||||||
session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
|
|
||||||
|
# Create a properly configured mock execution
|
||||||
|
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
configure_mock_execution(mock_execution)
|
||||||
|
session_obj.scalar.return_value = mock_execution
|
||||||
|
|
||||||
|
# Create a mock domain model to be returned by _to_domain_model
|
||||||
|
mock_domain_model = mocker.MagicMock()
|
||||||
|
# Mock the _to_domain_model method to return our mock domain model
|
||||||
|
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
|
||||||
|
|
||||||
# Call method
|
# Call method
|
||||||
result = repository.get_by_node_execution_id("test-node-execution-id")
|
result = repository.get_by_node_execution_id("test-node-execution-id")
|
||||||
@ -92,7 +157,10 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
|||||||
# Assert select was called with correct parameters
|
# Assert select was called with correct parameters
|
||||||
mock_select.assert_called_once()
|
mock_select.assert_called_once()
|
||||||
session_obj.scalar.assert_called_once_with(mock_stmt)
|
session_obj.scalar.assert_called_once_with(mock_stmt)
|
||||||
assert result is not None
|
# Assert _to_domain_model was called with the mock execution
|
||||||
|
repository._to_domain_model.assert_called_once_with(mock_execution)
|
||||||
|
# Assert the result is our mock domain model
|
||||||
|
assert result is mock_domain_model
|
||||||
|
|
||||||
|
|
||||||
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||||
@ -104,7 +172,16 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
|||||||
mock_select.return_value = mock_stmt
|
mock_select.return_value = mock_stmt
|
||||||
mock_stmt.where.return_value = mock_stmt
|
mock_stmt.where.return_value = mock_stmt
|
||||||
mock_stmt.order_by.return_value = mock_stmt
|
mock_stmt.order_by.return_value = mock_stmt
|
||||||
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
|
||||||
|
# Create a properly configured mock execution
|
||||||
|
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
configure_mock_execution(mock_execution)
|
||||||
|
session_obj.scalars.return_value.all.return_value = [mock_execution]
|
||||||
|
|
||||||
|
# Create a mock domain model to be returned by _to_domain_model
|
||||||
|
mock_domain_model = mocker.MagicMock()
|
||||||
|
# Mock the _to_domain_model method to return our mock domain model
|
||||||
|
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
|
||||||
|
|
||||||
# Call method
|
# Call method
|
||||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||||
@ -113,7 +190,45 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
|||||||
# Assert select was called with correct parameters
|
# Assert select was called with correct parameters
|
||||||
mock_select.assert_called_once()
|
mock_select.assert_called_once()
|
||||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||||
|
# Assert _to_domain_model was called with the mock execution
|
||||||
|
repository._to_domain_model.assert_called_once_with(mock_execution)
|
||||||
|
# Assert the result contains our mock domain model
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
assert result[0] is mock_domain_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_db_models_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||||
|
"""Test get_db_models_by_workflow_run method."""
|
||||||
|
session_obj, _ = session
|
||||||
|
# Set up mock
|
||||||
|
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||||
|
mock_stmt = mocker.MagicMock()
|
||||||
|
mock_select.return_value = mock_stmt
|
||||||
|
mock_stmt.where.return_value = mock_stmt
|
||||||
|
mock_stmt.order_by.return_value = mock_stmt
|
||||||
|
|
||||||
|
# Create a properly configured mock execution
|
||||||
|
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
configure_mock_execution(mock_execution)
|
||||||
|
session_obj.scalars.return_value.all.return_value = [mock_execution]
|
||||||
|
|
||||||
|
# Mock the _to_domain_model method
|
||||||
|
to_domain_model_mock = mocker.patch.object(repository, "_to_domain_model")
|
||||||
|
|
||||||
|
# Call method
|
||||||
|
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||||
|
result = repository.get_db_models_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
|
||||||
|
|
||||||
|
# Assert select was called with correct parameters
|
||||||
|
mock_select.assert_called_once()
|
||||||
|
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||||
|
|
||||||
|
# Assert the result contains our mock db model directly (without conversion to domain model)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0] is mock_execution
|
||||||
|
|
||||||
|
# Verify that _to_domain_model was NOT called (since we're returning raw DB models)
|
||||||
|
to_domain_model_mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_get_running_executions(repository, session, mocker: MockerFixture):
|
def test_get_running_executions(repository, session, mocker: MockerFixture):
|
||||||
@ -124,7 +239,16 @@ def test_get_running_executions(repository, session, mocker: MockerFixture):
|
|||||||
mock_stmt = mocker.MagicMock()
|
mock_stmt = mocker.MagicMock()
|
||||||
mock_select.return_value = mock_stmt
|
mock_select.return_value = mock_stmt
|
||||||
mock_stmt.where.return_value = mock_stmt
|
mock_stmt.where.return_value = mock_stmt
|
||||||
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
|
||||||
|
# Create a properly configured mock execution
|
||||||
|
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
configure_mock_execution(mock_execution)
|
||||||
|
session_obj.scalars.return_value.all.return_value = [mock_execution]
|
||||||
|
|
||||||
|
# Create a mock domain model to be returned by _to_domain_model
|
||||||
|
mock_domain_model = mocker.MagicMock()
|
||||||
|
# Mock the _to_domain_model method to return our mock domain model
|
||||||
|
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
|
||||||
|
|
||||||
# Call method
|
# Call method
|
||||||
result = repository.get_running_executions("test-workflow-run-id")
|
result = repository.get_running_executions("test-workflow-run-id")
|
||||||
@ -132,25 +256,36 @@ def test_get_running_executions(repository, session, mocker: MockerFixture):
|
|||||||
# Assert select was called with correct parameters
|
# Assert select was called with correct parameters
|
||||||
mock_select.assert_called_once()
|
mock_select.assert_called_once()
|
||||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||||
|
# Assert _to_domain_model was called with the mock execution
|
||||||
|
repository._to_domain_model.assert_called_once_with(mock_execution)
|
||||||
|
# Assert the result contains our mock domain model
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
|
assert result[0] is mock_domain_model
|
||||||
|
|
||||||
|
|
||||||
def test_update(repository, session):
|
def test_update_via_save(repository, session):
|
||||||
"""Test update method."""
|
"""Test updating an existing record via save method."""
|
||||||
session_obj, _ = session
|
session_obj, _ = session
|
||||||
# Create a mock execution
|
# Create a mock execution
|
||||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
execution.tenant_id = None
|
execution.tenant_id = None
|
||||||
execution.app_id = None
|
execution.app_id = None
|
||||||
|
execution.inputs = None
|
||||||
|
execution.process_data = None
|
||||||
|
execution.outputs = None
|
||||||
|
execution.metadata = None
|
||||||
|
|
||||||
# Call update method
|
# Mock the _to_db_model method to return the execution itself
|
||||||
repository.update(execution)
|
# This simulates the behavior of setting tenant_id and app_id
|
||||||
|
repository._to_db_model = MagicMock(return_value=execution)
|
||||||
|
|
||||||
# Assert tenant_id and app_id are set
|
# Call save method to update an existing record
|
||||||
assert execution.tenant_id == repository._tenant_id
|
repository.save(execution)
|
||||||
assert execution.app_id == repository._app_id
|
|
||||||
|
|
||||||
# Assert session.merge was called
|
# Assert _to_db_model was called with the execution
|
||||||
|
repository._to_db_model.assert_called_once_with(execution)
|
||||||
|
|
||||||
|
# Assert session.merge was called (for updates)
|
||||||
session_obj.merge.assert_called_once_with(execution)
|
session_obj.merge.assert_called_once_with(execution)
|
||||||
|
|
||||||
|
|
||||||
@ -176,3 +311,118 @@ def test_clear(repository, session, mocker: MockerFixture):
|
|||||||
mock_stmt.where.assert_called()
|
mock_stmt.where.assert_called()
|
||||||
session_obj.execute.assert_called_once_with(mock_stmt)
|
session_obj.execute.assert_called_once_with(mock_stmt)
|
||||||
session_obj.commit.assert_called_once()
|
session_obj.commit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_db_model(repository):
|
||||||
|
"""Test _to_db_model method."""
|
||||||
|
# Create a domain model
|
||||||
|
domain_model = NodeExecution(
|
||||||
|
id="test-id",
|
||||||
|
workflow_id="test-workflow-id",
|
||||||
|
node_execution_id="test-node-execution-id",
|
||||||
|
workflow_run_id="test-workflow-run-id",
|
||||||
|
index=1,
|
||||||
|
predecessor_node_id="test-predecessor-id",
|
||||||
|
node_id="test-node-id",
|
||||||
|
node_type=NodeType.START,
|
||||||
|
title="Test Node",
|
||||||
|
inputs={"input_key": "input_value"},
|
||||||
|
process_data={"process_key": "process_value"},
|
||||||
|
outputs={"output_key": "output_value"},
|
||||||
|
status=NodeExecutionStatus.RUNNING,
|
||||||
|
error=None,
|
||||||
|
elapsed_time=1.5,
|
||||||
|
metadata={NodeRunMetadataKey.TOTAL_TOKENS: 100},
|
||||||
|
created_at=datetime.now(),
|
||||||
|
finished_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to DB model
|
||||||
|
db_model = repository._to_db_model(domain_model)
|
||||||
|
|
||||||
|
# Assert DB model has correct values
|
||||||
|
assert isinstance(db_model, WorkflowNodeExecution)
|
||||||
|
assert db_model.id == domain_model.id
|
||||||
|
assert db_model.tenant_id == repository._tenant_id
|
||||||
|
assert db_model.app_id == repository._app_id
|
||||||
|
assert db_model.workflow_id == domain_model.workflow_id
|
||||||
|
assert db_model.triggered_from == repository._triggered_from
|
||||||
|
assert db_model.workflow_run_id == domain_model.workflow_run_id
|
||||||
|
assert db_model.index == domain_model.index
|
||||||
|
assert db_model.predecessor_node_id == domain_model.predecessor_node_id
|
||||||
|
assert db_model.node_execution_id == domain_model.node_execution_id
|
||||||
|
assert db_model.node_id == domain_model.node_id
|
||||||
|
assert db_model.node_type == domain_model.node_type
|
||||||
|
assert db_model.title == domain_model.title
|
||||||
|
|
||||||
|
assert db_model.inputs_dict == domain_model.inputs
|
||||||
|
assert db_model.process_data_dict == domain_model.process_data
|
||||||
|
assert db_model.outputs_dict == domain_model.outputs
|
||||||
|
assert db_model.execution_metadata_dict == domain_model.metadata
|
||||||
|
|
||||||
|
assert db_model.status == domain_model.status
|
||||||
|
assert db_model.error == domain_model.error
|
||||||
|
assert db_model.elapsed_time == domain_model.elapsed_time
|
||||||
|
assert db_model.created_at == domain_model.created_at
|
||||||
|
assert db_model.created_by_role == repository._creator_user_role
|
||||||
|
assert db_model.created_by == repository._creator_user_id
|
||||||
|
assert db_model.finished_at == domain_model.finished_at
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_domain_model(repository):
|
||||||
|
"""Test _to_domain_model method."""
|
||||||
|
# Create input dictionaries
|
||||||
|
inputs_dict = {"input_key": "input_value"}
|
||||||
|
process_data_dict = {"process_key": "process_value"}
|
||||||
|
outputs_dict = {"output_key": "output_value"}
|
||||||
|
metadata_dict = {str(NodeRunMetadataKey.TOTAL_TOKENS): 100}
|
||||||
|
|
||||||
|
# Create a DB model using our custom subclass
|
||||||
|
db_model = WorkflowNodeExecution()
|
||||||
|
db_model.id = "test-id"
|
||||||
|
db_model.tenant_id = "test-tenant-id"
|
||||||
|
db_model.app_id = "test-app-id"
|
||||||
|
db_model.workflow_id = "test-workflow-id"
|
||||||
|
db_model.triggered_from = "workflow-run"
|
||||||
|
db_model.workflow_run_id = "test-workflow-run-id"
|
||||||
|
db_model.index = 1
|
||||||
|
db_model.predecessor_node_id = "test-predecessor-id"
|
||||||
|
db_model.node_execution_id = "test-node-execution-id"
|
||||||
|
db_model.node_id = "test-node-id"
|
||||||
|
db_model.node_type = NodeType.START.value
|
||||||
|
db_model.title = "Test Node"
|
||||||
|
db_model.inputs = json.dumps(inputs_dict)
|
||||||
|
db_model.process_data = json.dumps(process_data_dict)
|
||||||
|
db_model.outputs = json.dumps(outputs_dict)
|
||||||
|
db_model.status = WorkflowNodeExecutionStatus.RUNNING
|
||||||
|
db_model.error = None
|
||||||
|
db_model.elapsed_time = 1.5
|
||||||
|
db_model.execution_metadata = json.dumps(metadata_dict)
|
||||||
|
db_model.created_at = datetime.now()
|
||||||
|
db_model.created_by_role = "account"
|
||||||
|
db_model.created_by = "test-user-id"
|
||||||
|
db_model.finished_at = None
|
||||||
|
|
||||||
|
# Convert to domain model
|
||||||
|
domain_model = repository._to_domain_model(db_model)
|
||||||
|
|
||||||
|
# Assert domain model has correct values
|
||||||
|
assert isinstance(domain_model, NodeExecution)
|
||||||
|
assert domain_model.id == db_model.id
|
||||||
|
assert domain_model.workflow_id == db_model.workflow_id
|
||||||
|
assert domain_model.workflow_run_id == db_model.workflow_run_id
|
||||||
|
assert domain_model.index == db_model.index
|
||||||
|
assert domain_model.predecessor_node_id == db_model.predecessor_node_id
|
||||||
|
assert domain_model.node_execution_id == db_model.node_execution_id
|
||||||
|
assert domain_model.node_id == db_model.node_id
|
||||||
|
assert domain_model.node_type == NodeType(db_model.node_type)
|
||||||
|
assert domain_model.title == db_model.title
|
||||||
|
assert domain_model.inputs == inputs_dict
|
||||||
|
assert domain_model.process_data == process_data_dict
|
||||||
|
assert domain_model.outputs == outputs_dict
|
||||||
|
assert domain_model.status == NodeExecutionStatus(db_model.status)
|
||||||
|
assert domain_model.error == db_model.error
|
||||||
|
assert domain_model.elapsed_time == db_model.elapsed_time
|
||||||
|
assert domain_model.metadata == metadata_dict
|
||||||
|
assert domain_model.created_at == db_model.created_at
|
||||||
|
assert domain_model.finished_at == db_model.finished_at
|
||||||
|
Loading…
x
Reference in New Issue
Block a user