mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 16:18:59 +08:00
feat(workflow_service): workflow version control api. (#14860)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
f2b7df94d7
commit
3254018ddb
@ -1,8 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@ -13,6 +15,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
@ -24,7 +27,7 @@ from models.account import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.workflow_service import WorkflowService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -439,10 +442,38 @@ class PublishedWorkflowApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
workflow = workflow_service.publish_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
marked_name=args.marked_name or "",
|
||||
marked_comment=args.marked_comment or "",
|
||||
)
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
class DefaultBlockConfigsApi(Resource):
|
||||
@ -564,37 +595,193 @@ class PublishedAllWorkflowApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument("user_id", type=str, required=False, location="args")
|
||||
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
page = args.get("page")
|
||||
limit = args.get("limit")
|
||||
page = int(args.get("page", 1))
|
||||
limit = int(args.get("limit", 10))
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
|
||||
if user_id:
|
||||
if user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
user_id = cast(str, user_id)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit)
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
page=page,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
)
|
||||
|
||||
return {"items": workflows, "page": page, "limit": limit, "has_more": has_more}
|
||||
return {
|
||||
"items": workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
|
||||
api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
|
||||
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
|
||||
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
|
||||
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_fields)
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
if args.get("marked_name") is not None:
|
||||
update_data["marked_name"] = args["marked_name"]
|
||||
if args.get("marked_comment") is not None:
|
||||
update_data["marked_comment"] = args["marked_comment"]
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = workflow_service.update_workflow(
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def delete(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Delete workflow
|
||||
"""
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
workflow_service.delete_workflow(
|
||||
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
|
||||
)
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
except DraftWorkflowDeletionError as e:
|
||||
abort(400, description=str(e))
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
|
||||
return None, 204
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/config",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowTaskStopApi,
|
||||
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowNodeRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
|
||||
WorkflowDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(WorkflowDraftRunLoopNodeApi, "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
|
||||
api.add_resource(PublishedAllWorkflowApi, "/apps/<uuid:app_id>/workflows")
|
||||
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
|
||||
WorkflowDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedAllWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigsApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
ConvertToWorkflowApi,
|
||||
"/apps/<uuid:app_id>/convert-to-workflow",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowByIdApi,
|
||||
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
|
||||
|
@ -45,7 +45,9 @@ workflow_fields = {
|
||||
"graph": fields.Raw(attribute="graph_dict"),
|
||||
"features": fields.Raw(attribute="features_dict"),
|
||||
"hash": fields.String(attribute="unique_hash"),
|
||||
"version": fields.String(attribute="version"),
|
||||
"version": fields.String,
|
||||
"marked_name": fields.String,
|
||||
"marked_comment": fields.String,
|
||||
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account"),
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
|
||||
|
@ -0,0 +1,29 @@
|
||||
"""add marked_name and marked_comment in workflows
|
||||
|
||||
Revision ID: ee79d9b1c156
|
||||
Revises: 4413929e1ec2
|
||||
Create Date: 2025-03-03 14:36:05.750346
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'ee79d9b1c156'
|
||||
down_revision = '5511c782ee4c'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('marked_name', sa.String(), nullable=False, server_default=''))
|
||||
batch_op.add_column(sa.Column('marked_comment', sa.String(), nullable=False, server_default=''))
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.drop_column('marked_comment')
|
||||
batch_op.drop_column('marked_name')
|
@ -2,7 +2,8 @@ import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Self, Union
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import AppMode
|
||||
@ -108,7 +109,9 @@ class Workflow(Base):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
version: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
version: Mapped[str]
|
||||
marked_name: Mapped[str] = mapped_column(default="", server_default="")
|
||||
marked_comment: Mapped[str] = mapped_column(default="", server_default="")
|
||||
graph: Mapped[str] = mapped_column(sa.Text)
|
||||
_features: Mapped[str] = mapped_column("features", sa.TEXT)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
@ -127,8 +130,9 @@ class Workflow(Base):
|
||||
"conversation_variables", db.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
@ -139,16 +143,25 @@ class Workflow(Base):
|
||||
created_by: str,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
self.type = type
|
||||
self.version = version
|
||||
self.graph = graph
|
||||
self.features = features
|
||||
self.created_by = created_by
|
||||
self.environment_variables = environment_variables or []
|
||||
self.conversation_variables = conversation_variables or []
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> Self:
|
||||
workflow = Workflow()
|
||||
workflow.id = str(uuid4())
|
||||
workflow.tenant_id = tenant_id
|
||||
workflow.app_id = app_id
|
||||
workflow.type = type
|
||||
workflow.version = version
|
||||
workflow.graph = graph
|
||||
workflow.features = features
|
||||
workflow.created_by = created_by
|
||||
workflow.environment_variables = environment_variables or []
|
||||
workflow.conversation_variables = conversation_variables or []
|
||||
workflow.marked_name = marked_name
|
||||
workflow.marked_comment = marked_comment
|
||||
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow.updated_at = workflow.created_at
|
||||
return workflow
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
|
2006
api/poetry.lock
generated
2006
api/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -151,6 +151,7 @@ pytest-benchmark = "~4.0.0"
|
||||
pytest-env = "~1.1.3"
|
||||
pytest-mock = "~3.14.0"
|
||||
types-beautifulsoup4 = "~4.12.0.20241020"
|
||||
types-deprecated = "~1.2.15.20250304"
|
||||
types-flask-cors = "~5.0.0.20240902"
|
||||
types-flask-migrate = "~4.1.0.20250112"
|
||||
types-html5lib = "~1.1.11.20241018"
|
||||
|
10
api/services/errors/workflow_service.py
Normal file
10
api/services/errors/workflow_service.py
Normal file
@ -0,0 +1,10 @@
|
||||
class WorkflowInUseError(ValueError):
|
||||
"""Raised when attempting to delete a workflow that's in use by an app"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DraftWorkflowDeletionError(ValueError):
|
||||
"""Raised when attempting to delete a draft workflow"""
|
||||
|
||||
pass
|
@ -5,7 +5,8 @@ from datetime import UTC, datetime
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
@ -36,6 +37,8 @@ from models.workflow import (
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
"""
|
||||
@ -79,22 +82,38 @@ class WorkflowService:
|
||||
|
||||
return workflow
|
||||
|
||||
def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]:
|
||||
def get_all_published_workflow(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
page: int,
|
||||
limit: int,
|
||||
user_id: str | None,
|
||||
named_only: bool = False,
|
||||
) -> tuple[Sequence[Workflow], bool]:
|
||||
"""
|
||||
Get published workflow with pagination
|
||||
"""
|
||||
if not app_model.workflow_id:
|
||||
return [], False
|
||||
|
||||
workflows = (
|
||||
db.session.query(Workflow)
|
||||
.filter(Workflow.app_id == app_model.id)
|
||||
.order_by(desc(Workflow.version))
|
||||
.offset((page - 1) * limit)
|
||||
stmt = (
|
||||
select(Workflow)
|
||||
.where(Workflow.app_id == app_model.id)
|
||||
.order_by(Workflow.version.desc())
|
||||
.limit(limit + 1)
|
||||
.all()
|
||||
.offset((page - 1) * limit)
|
||||
)
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(Workflow.created_by == user_id)
|
||||
|
||||
if named_only:
|
||||
stmt = stmt.where(Workflow.marked_name != "")
|
||||
|
||||
workflows = session.scalars(stmt).all()
|
||||
|
||||
has_more = len(workflows) > limit
|
||||
if has_more:
|
||||
workflows = workflows[:-1]
|
||||
@ -157,23 +176,26 @@ class WorkflowService:
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow:
|
||||
"""
|
||||
Publish workflow from draft
|
||||
|
||||
:param app_model: App instance
|
||||
:param account: Account instance
|
||||
:param draft_workflow: Workflow instance
|
||||
"""
|
||||
if not draft_workflow:
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
def publish_workflow(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
account: Account,
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> Workflow:
|
||||
draft_workflow_stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
draft_workflow = session.scalar(draft_workflow_stmt)
|
||||
if not draft_workflow:
|
||||
raise ValueError("No valid workflow found.")
|
||||
|
||||
# create new workflow
|
||||
workflow = Workflow(
|
||||
workflow = Workflow.new(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=draft_workflow.type,
|
||||
@ -183,15 +205,12 @@ class WorkflowService:
|
||||
created_by=account.id,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=draft_workflow.conversation_variables,
|
||||
marked_name=marked_name,
|
||||
marked_comment=marked_comment,
|
||||
)
|
||||
|
||||
# commit db session changes
|
||||
db.session.add(workflow)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
session.add(workflow)
|
||||
|
||||
# trigger app workflow events
|
||||
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
|
||||
@ -436,3 +455,65 @@ class WorkflowService:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode: {app_model.mode}")
|
||||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
|
||||
) -> Optional[Workflow]:
|
||||
"""
|
||||
Update workflow attributes
|
||||
|
||||
:param session: SQLAlchemy database session
|
||||
:param workflow_id: Workflow ID
|
||||
:param tenant_id: Tenant ID
|
||||
:param account_id: Account ID (for permission check)
|
||||
:param data: Dictionary containing fields to update
|
||||
:return: Updated workflow or None if not found
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
return None
|
||||
|
||||
allowed_fields = ["marked_name", "marked_comment"]
|
||||
|
||||
for field, value in data.items():
|
||||
if field in allowed_fields:
|
||||
setattr(workflow, field, value)
|
||||
|
||||
workflow.updated_by = account_id
|
||||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
return workflow
|
||||
|
||||
def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
|
||||
"""
|
||||
Delete a workflow
|
||||
|
||||
:param session: SQLAlchemy database session
|
||||
:param workflow_id: Workflow ID
|
||||
:param tenant_id: Tenant ID
|
||||
:return: True if successful
|
||||
:raises: ValueError if workflow not found
|
||||
:raises: WorkflowInUseError if workflow is in use
|
||||
:raises: DraftWorkflowDeletionError if workflow is a draft version
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
# Check if workflow is a draft version
|
||||
if workflow.version == "draft":
|
||||
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
|
||||
|
||||
# Check if this workflow is currently referenced by an app
|
||||
stmt = select(App).where(App.workflow_id == workflow_id)
|
||||
app = session.scalar(stmt)
|
||||
if app:
|
||||
# Cannot delete a workflow that's currently in use by an app
|
||||
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
|
||||
|
||||
session.delete(workflow)
|
||||
return True
|
||||
|
162
api/tests/unit_tests/services/workflow/test_workflow_service.py
Normal file
162
api/tests/unit_tests/services/workflow/test_workflow_service.py
Normal file
@ -0,0 +1,162 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class TestWorkflowService:
|
||||
@pytest.fixture
|
||||
def workflow_service(self):
|
||||
return WorkflowService()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "app-id-1"
|
||||
app.workflow_id = "workflow-id-1"
|
||||
app.tenant_id = "tenant-id-1"
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflows(self):
|
||||
workflows = []
|
||||
for i in range(5):
|
||||
workflow = MagicMock(spec=Workflow)
|
||||
workflow.id = f"workflow-id-{i}"
|
||||
workflow.app_id = "app-id-1"
|
||||
workflow.created_at = f"2023-01-0{5 - i}" # Descending date order
|
||||
workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2"
|
||||
workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else ""
|
||||
workflows.append(workflow)
|
||||
return workflows
|
||||
|
||||
def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
|
||||
mock_app.workflow_id = None
|
||||
mock_session = MagicMock()
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
|
||||
)
|
||||
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_not_called()
|
||||
|
||||
def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
mock_scalar_result.all.return_value = mock_workflows[:3]
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
|
||||
)
|
||||
|
||||
assert workflows == mock_workflows[:3]
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Return 4 items when limit is 3, which should indicate has_more=True
|
||||
mock_scalar_result.all.return_value = mock_workflows[:4]
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
|
||||
)
|
||||
|
||||
# Should return only the first 3 items
|
||||
assert len(workflows) == 3
|
||||
assert workflows == mock_workflows[:3]
|
||||
assert has_more is True
|
||||
|
||||
# Test page 2
|
||||
mock_scalar_result.all.return_value = mock_workflows[3:]
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None
|
||||
)
|
||||
|
||||
assert len(workflows) == 2
|
||||
assert has_more is False
|
||||
|
||||
def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Filter workflows for user-id-1
|
||||
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"]
|
||||
mock_scalar_result.all.return_value = filtered_workflows
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1"
|
||||
)
|
||||
|
||||
assert workflows == filtered_workflows
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
# Verify that the select contains a user filter clause
|
||||
args = mock_session.scalars.call_args[0][0]
|
||||
assert "created_by" in str(args)
|
||||
|
||||
def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Filter workflows that have a marked_name
|
||||
named_workflows = [w for w in mock_workflows if w.marked_name]
|
||||
mock_scalar_result.all.return_value = named_workflows
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True
|
||||
)
|
||||
|
||||
assert workflows == named_workflows
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
# Verify that the select contains a named_only filter clause
|
||||
args = mock_session.scalars.call_args[0][0]
|
||||
assert "marked_name !=" in str(args)
|
||||
|
||||
def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Combined filter: user-id-1 and has marked_name
|
||||
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name]
|
||||
mock_scalar_result.all.return_value = filtered_workflows
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True
|
||||
)
|
||||
|
||||
assert workflows == filtered_workflows
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
# Verify that both filters are applied
|
||||
args = mock_session.scalars.call_args[0][0]
|
||||
assert "created_by" in str(args)
|
||||
assert "marked_name !=" in str(args)
|
||||
|
||||
def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
mock_scalar_result.all.return_value = []
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
|
||||
)
|
||||
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
Loading…
x
Reference in New Issue
Block a user