feat(workflow_service): workflow version control api. (#14860)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-03-10 13:34:31 +08:00 committed by GitHub
parent f2b7df94d7
commit 3254018ddb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1732 additions and 885 deletions

View File

@ -1,8 +1,10 @@
import json import json
import logging import logging
from typing import cast
from flask import abort, request from flask import abort, request
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
@ -13,6 +15,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from factories import variable_factory from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields from fields.workflow_run_fields import workflow_run_node_execution_fields
@ -24,7 +27,7 @@ from models.account import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -439,10 +442,38 @@ class PublishedWorkflowApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
workflow_service = WorkflowService() parser = reqparse.RequestParser()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
args = parser.parse_args()
return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} # Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
workflow_service = WorkflowService()
with Session(db.engine) as session:
workflow = workflow_service.publish_workflow(
session=session,
app_model=app_model,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
)
app_model.workflow_id = workflow.id
db.session.commit()
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
"created_at": workflow_created_at,
}
class DefaultBlockConfigsApi(Resource): class DefaultBlockConfigsApi(Resource):
@ -564,37 +595,193 @@ class PublishedAllWorkflowApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
args = parser.parse_args() args = parser.parse_args()
page = args.get("page") page = int(args.get("page", 1))
limit = args.get("limit") limit = int(args.get("limit", 10))
user_id = args.get("user_id")
named_only = args.get("named_only", False)
if user_id:
if user_id != current_user.id:
raise Forbidden()
user_id = cast(str, user_id)
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit) with Session(db.engine) as session:
workflows, has_more = workflow_service.get_all_published_workflow(
session=session,
app_model=app_model,
page=page,
limit=limit,
user_id=user_id,
named_only=named_only,
)
return {"items": workflows, "page": page, "limit": limit, "has_more": has_more} return {
"items": workflows,
"page": page,
"limit": limit,
"has_more": has_more,
}
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") class WorkflowByIdApi(Resource):
api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config") @setup_required
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") @login_required
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") @account_initialization_required
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run") @marshal_with(workflow_fields)
def patch(self, app_model: App, workflow_id: str):
"""
Update workflow attributes
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
args = parser.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = parser.parse_args()
# Prepare update data
update_data = {}
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]
if not update_data:
return {"message": "No valid fields to update"}, 400
workflow_service = WorkflowService()
# Create a session and manage the transaction
with Session(db.engine, expire_on_commit=False) as session:
workflow = workflow_service.update_workflow(
session=session,
workflow_id=workflow_id,
tenant_id=app_model.tenant_id,
account_id=current_user.id,
data=update_data,
)
if not workflow:
raise NotFound("Workflow not found")
# Commit the transaction in the controller
session.commit()
return workflow
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def delete(self, app_model: App, workflow_id: str):
"""
Delete workflow
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
# Create a session and manage the transaction
with Session(db.engine) as session:
try:
workflow_service.delete_workflow(
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
)
# Commit the transaction in the controller
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:
abort(400, description=str(e))
except ValueError as e:
raise NotFound(str(e))
return None, 204
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
)
api.add_resource(
WorkflowConfigApi,
"/apps/<uuid:app_id>/workflows/draft/config",
)
api.add_resource(
AdvancedChatDraftWorkflowRunApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
)
api.add_resource(
DraftWorkflowRunApi,
"/apps/<uuid:app_id>/workflows/draft/run",
)
api.add_resource(
WorkflowTaskStopApi,
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
DraftWorkflowNodeRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource( api.add_resource(
AdvancedChatDraftRunIterationNodeApi, AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run", "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
) )
api.add_resource( api.add_resource(
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run" WorkflowDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
) )
api.add_resource( api.add_resource(
AdvancedChatDraftRunLoopNodeApi, AdvancedChatDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run", "/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
) )
api.add_resource(WorkflowDraftRunLoopNodeApi, "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run")
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(PublishedAllWorkflowApi, "/apps/<uuid:app_id>/workflows")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource( api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>" WorkflowDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedWorkflowApi,
"/apps/<uuid:app_id>/workflows/publish",
)
api.add_resource(
PublishedAllWorkflowApi,
"/apps/<uuid:app_id>/workflows",
)
api.add_resource(
DefaultBlockConfigsApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultBlockConfigApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
ConvertToWorkflowApi,
"/apps/<uuid:app_id>/convert-to-workflow",
)
api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
) )
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")

View File

@ -45,7 +45,9 @@ workflow_fields = {
"graph": fields.Raw(attribute="graph_dict"), "graph": fields.Raw(attribute="graph_dict"),
"features": fields.Raw(attribute="features_dict"), "features": fields.Raw(attribute="features_dict"),
"hash": fields.String(attribute="unique_hash"), "hash": fields.String(attribute="unique_hash"),
"version": fields.String(attribute="version"), "version": fields.String,
"marked_name": fields.String,
"marked_comment": fields.String,
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"),
"created_at": TimestampField, "created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),

View File

@ -0,0 +1,29 @@
"""add marked_name and marked_comment in workflows
Revision ID: ee79d9b1c156
Revises: 4413929e1ec2
Create Date: 2025-03-03 14:36:05.750346
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ee79d9b1c156'
down_revision = '5511c782ee4c'
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.add_column(sa.Column('marked_name', sa.String(), nullable=False, server_default=''))
batch_op.add_column(sa.Column('marked_comment', sa.String(), nullable=False, server_default=''))
def downgrade():
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.drop_column('marked_comment')
batch_op.drop_column('marked_name')

View File

@ -2,7 +2,8 @@ import json
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Self, Union
from uuid import uuid4
if TYPE_CHECKING: if TYPE_CHECKING:
from models.model import AppMode from models.model import AppMode
@ -108,7 +109,9 @@ class Workflow(Base):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(db.String(255), nullable=False) type: Mapped[str] = mapped_column(db.String(255), nullable=False)
version: Mapped[str] = mapped_column(db.String(255), nullable=False) version: Mapped[str]
marked_name: Mapped[str] = mapped_column(default="", server_default="")
marked_comment: Mapped[str] = mapped_column(default="", server_default="")
graph: Mapped[str] = mapped_column(sa.Text) graph: Mapped[str] = mapped_column(sa.Text)
_features: Mapped[str] = mapped_column("features", sa.TEXT) _features: Mapped[str] = mapped_column("features", sa.TEXT)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -127,8 +130,9 @@ class Workflow(Base):
"conversation_variables", db.Text, nullable=False, server_default="{}" "conversation_variables", db.Text, nullable=False, server_default="{}"
) )
def __init__( @classmethod
self, def new(
cls,
*, *,
tenant_id: str, tenant_id: str,
app_id: str, app_id: str,
@ -139,16 +143,25 @@ class Workflow(Base):
created_by: str, created_by: str,
environment_variables: Sequence[Variable], environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[Variable],
): marked_name: str = "",
self.tenant_id = tenant_id marked_comment: str = "",
self.app_id = app_id ) -> Self:
self.type = type workflow = Workflow()
self.version = version workflow.id = str(uuid4())
self.graph = graph workflow.tenant_id = tenant_id
self.features = features workflow.app_id = app_id
self.created_by = created_by workflow.type = type
self.environment_variables = environment_variables or [] workflow.version = version
self.conversation_variables = conversation_variables or [] workflow.graph = graph
workflow.features = features
workflow.created_by = created_by
workflow.environment_variables = environment_variables or []
workflow.conversation_variables = conversation_variables or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow.updated_at = workflow.created_at
return workflow
@property @property
def created_by_account(self): def created_by_account(self):

2006
api/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -151,6 +151,7 @@ pytest-benchmark = "~4.0.0"
pytest-env = "~1.1.3" pytest-env = "~1.1.3"
pytest-mock = "~3.14.0" pytest-mock = "~3.14.0"
types-beautifulsoup4 = "~4.12.0.20241020" types-beautifulsoup4 = "~4.12.0.20241020"
types-deprecated = "~1.2.15.20250304"
types-flask-cors = "~5.0.0.20240902" types-flask-cors = "~5.0.0.20240902"
types-flask-migrate = "~4.1.0.20250112" types-flask-migrate = "~4.1.0.20250112"
types-html5lib = "~1.1.11.20241018" types-html5lib = "~1.1.11.20241018"

View File

@ -0,0 +1,10 @@
class WorkflowInUseError(ValueError):
"""Raised when attempting to delete a workflow that's in use by an app"""
pass
class DraftWorkflowDeletionError(ValueError):
"""Raised when attempting to delete a draft workflow"""
pass

View File

@ -5,7 +5,8 @@ from datetime import UTC, datetime
from typing import Any, Optional from typing import Any, Optional
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import desc from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -36,6 +37,8 @@ from models.workflow import (
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter from services.workflow.workflow_converter import WorkflowConverter
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
class WorkflowService: class WorkflowService:
""" """
@ -79,22 +82,38 @@ class WorkflowService:
return workflow return workflow
def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]: def get_all_published_workflow(
self,
*,
session: Session,
app_model: App,
page: int,
limit: int,
user_id: str | None,
named_only: bool = False,
) -> tuple[Sequence[Workflow], bool]:
""" """
Get published workflow with pagination Get published workflow with pagination
""" """
if not app_model.workflow_id: if not app_model.workflow_id:
return [], False return [], False
workflows = ( stmt = (
db.session.query(Workflow) select(Workflow)
.filter(Workflow.app_id == app_model.id) .where(Workflow.app_id == app_model.id)
.order_by(desc(Workflow.version)) .order_by(Workflow.version.desc())
.offset((page - 1) * limit)
.limit(limit + 1) .limit(limit + 1)
.all() .offset((page - 1) * limit)
) )
if user_id:
stmt = stmt.where(Workflow.created_by == user_id)
if named_only:
stmt = stmt.where(Workflow.marked_name != "")
workflows = session.scalars(stmt).all()
has_more = len(workflows) > limit has_more = len(workflows) > limit
if has_more: if has_more:
workflows = workflows[:-1] workflows = workflows[:-1]
@ -157,23 +176,26 @@ class WorkflowService:
# return draft workflow # return draft workflow
return workflow return workflow
def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: def publish_workflow(
""" self,
Publish workflow from draft *,
session: Session,
:param app_model: App instance app_model: App,
:param account: Account instance account: Account,
:param draft_workflow: Workflow instance marked_name: str = "",
""" marked_comment: str = "",
if not draft_workflow: ) -> Workflow:
# fetch draft workflow by app_model draft_workflow_stmt = select(Workflow).where(
draft_workflow = self.get_draft_workflow(app_model=app_model) Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == "draft",
)
draft_workflow = session.scalar(draft_workflow_stmt)
if not draft_workflow: if not draft_workflow:
raise ValueError("No valid workflow found.") raise ValueError("No valid workflow found.")
# create new workflow # create new workflow
workflow = Workflow( workflow = Workflow.new(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
app_id=app_model.id, app_id=app_model.id,
type=draft_workflow.type, type=draft_workflow.type,
@ -183,15 +205,12 @@ class WorkflowService:
created_by=account.id, created_by=account.id,
environment_variables=draft_workflow.environment_variables, environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables, conversation_variables=draft_workflow.conversation_variables,
marked_name=marked_name,
marked_comment=marked_comment,
) )
# commit db session changes # commit db session changes
db.session.add(workflow) session.add(workflow)
db.session.flush()
db.session.commit()
app_model.workflow_id = workflow.id
db.session.commit()
# trigger app workflow events # trigger app workflow events
app_published_workflow_was_updated.send(app_model, published_workflow=workflow) app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
@ -436,3 +455,65 @@ class WorkflowService:
) )
else: else:
raise ValueError(f"Invalid app mode: {app_model.mode}") raise ValueError(f"Invalid app mode: {app_model.mode}")
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
"""
Update workflow attributes
:param session: SQLAlchemy database session
:param workflow_id: Workflow ID
:param tenant_id: Tenant ID
:param account_id: Account ID (for permission check)
:param data: Dictionary containing fields to update
:return: Updated workflow or None if not found
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
workflow = session.scalar(stmt)
if not workflow:
return None
allowed_fields = ["marked_name", "marked_comment"]
for field, value in data.items():
if field in allowed_fields:
setattr(workflow, field, value)
workflow.updated_by = account_id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
return workflow
def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
"""
Delete a workflow
:param session: SQLAlchemy database session
:param workflow_id: Workflow ID
:param tenant_id: Tenant ID
:return: True if successful
:raises: ValueError if workflow not found
:raises: WorkflowInUseError if workflow is in use
:raises: DraftWorkflowDeletionError if workflow is a draft version
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError(f"Workflow with ID {workflow_id} not found")
# Check if workflow is a draft version
if workflow.version == "draft":
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
# Check if this workflow is currently referenced by an app
stmt = select(App).where(App.workflow_id == workflow_id)
app = session.scalar(stmt)
if app:
# Cannot delete a workflow that's currently in use by an app
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
session.delete(workflow)
return True

View File

@ -0,0 +1,162 @@
from unittest.mock import MagicMock
import pytest
from models.model import App
from models.workflow import Workflow
from services.workflow_service import WorkflowService
class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
return WorkflowService()
@pytest.fixture
def mock_app(self):
app = MagicMock(spec=App)
app.id = "app-id-1"
app.workflow_id = "workflow-id-1"
app.tenant_id = "tenant-id-1"
return app
@pytest.fixture
def mock_workflows(self):
workflows = []
for i in range(5):
workflow = MagicMock(spec=Workflow)
workflow.id = f"workflow-id-{i}"
workflow.app_id = "app-id-1"
workflow.created_at = f"2023-01-0{5 - i}" # Descending date order
workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2"
workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else ""
workflows.append(workflow)
return workflows
def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
mock_app.workflow_id = None
mock_session = MagicMock()
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)
assert workflows == []
assert has_more is False
mock_session.scalars.assert_not_called()
def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = mock_workflows[:3]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)
assert workflows == mock_workflows[:3]
assert has_more is False
mock_session.scalars.assert_called_once()
def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Return 4 items when limit is 3, which should indicate has_more=True
mock_scalar_result.all.return_value = mock_workflows[:4]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)
# Should return only the first 3 items
assert len(workflows) == 3
assert workflows == mock_workflows[:3]
assert has_more is True
# Test page 2
mock_scalar_result.all.return_value = mock_workflows[3:]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None
)
assert len(workflows) == 2
assert has_more is False
def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows for user-id-1
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1"
)
assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that the select contains a user filter clause
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)
def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows that have a marked_name
named_workflows = [w for w in mock_workflows if w.marked_name]
mock_scalar_result.all.return_value = named_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True
)
assert workflows == named_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that the select contains a named_only filter clause
args = mock_session.scalars.call_args[0][0]
assert "marked_name !=" in str(args)
def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Combined filter: user-id-1 and has marked_name
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True
)
assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that both filters are applied
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)
assert "marked_name !=" in str(args)
def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = []
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()