feat(api): implement draft var related api

This commit is contained in:
QuantumGhost 2025-05-22 17:31:27 +08:00
parent 0f7ea8d5fa
commit be098dee35
24 changed files with 1139 additions and 37 deletions

View File

@ -1,4 +1,39 @@
import os
def _setup_gevent():
"""Do gevent monkey patching.
This function should be called as early as possible. Ideally
it should be the first statement in the entrypoint file.
It should be
"""
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() not in {"false", "0", "no"}:
return
if os.environ.get("GEVENT_SUPPORT", "0") == "0":
return
from gevent import monkey
# gevent
monkey.patch_all()
from grpc.experimental import gevent as grpc_gevent # type: ignore
# grpc gevent
grpc_gevent.init_gevent()
import psycogreen.gevent # type: ignore
psycogreen.gevent.patch_psycopg()
_setup_gevent()
import sys
@ -14,24 +49,6 @@ if is_db_command():
app = create_migrations_app()
else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey
# gevent
monkey.patch_all()
from grpc.experimental import gevent as grpc_gevent # type: ignore
# grpc gevent
grpc_gevent.init_gevent()
import psycogreen.gevent # type: ignore
psycogreen.gevent.patch_psycopg()
from app_factory import create_app
app = create_app()

View File

@ -63,6 +63,7 @@ from .app import (
statistic,
workflow,
workflow_app_log,
workflow_draft_variable,
workflow_run,
workflow_statistic,
)

View File

@ -0,0 +1,319 @@
import logging
from typing import NoReturn
from flask import Response
from flask_restful import Resource, fields, inputs, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from libs.login import current_user, login_required
from models import App, AppMode, db
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=lambda variable: variable.get_value().value),
)
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda _: "env"),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
}
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
return var_list.variables
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
"total": fields.Raw(),
}
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
- Dify has been property setup.
- The request user has logged in and initialized.
- The requested app is a workflow or a chat flow.
- The request user has the edit permission for the app.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs):
if not current_user.is_editor:
raise Forbidden()
return f(*args, **kwargs)
return wrapper
class WorkflowVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
def get(self, app_model: App):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
if not workflow_exist:
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=app_model.id,
page=args.page,
limit=args.limit,
)
return workflow_vars
@_api_prerequisite
def delete(self, app_model: App):
draft_var_srv = WorkflowDraftVariableService(
session=db.session,
)
draft_var_srv.delete_workflow_variables(app_model.id)
db.session.commit()
return Response("", 204)
def validate_node_id(node_id: str) -> NoReturn | None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
]:
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
# with specific `node_id` in database, we still want to make the API separated. By disallowing
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
# we mitigate the risk that user of the API depending on the implementation detail of the API.
#
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None
class NodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
return node_vars
@_api_prerequisite
def delete(self, app_model: App, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session)
srv.delete_node_variables(app_model.id, node_id)
db.session.commit()
return Response("", 204)
class VariableApi(Resource):
_PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def get(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session,
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def patch(self, app_model: App, variable_id: str):
parser = reqparse.RequestParser()
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
parser.add_argument(self._PATCH_VALUE_FIELD, type=build_segment, required=False, nullable=True, location="json")
draft_var_srv = WorkflowDraftVariableService(
session=db.session,
)
args = parser.parse_args(strict=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
new_value = args.get(self._PATCH_VALUE_FIELD, None)
if new_name is None and new_value is None:
return variable
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@_api_prerequisite
def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session,
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(app_model.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
return draft_vars
class ConversationVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App):
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
class SystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App):
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
class EnvironmentVariableCollectionApi(Resource):
@_api_prerequisite
def get(self, app_model: App):
"""
Get draft workflow
"""
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
if workflow is None:
raise DraftWorkflowNotExist()
env_vars = workflow.environment_variables
env_vars_list = []
for v in env_vars:
env_vars_list.append(
{
"id": v.id,
"type": "env",
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,
"visible": True,
"editable": True,
}
)
return {"items": env_vars_list}
api.add_resource(
WorkflowVariableCollectionApi,
"/apps/<uuid:app_id>/workflows/draft/variables",
)
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")

View File

@ -0,0 +1,196 @@
import datetime
import uuid
from collections import OrderedDict
from typing import NamedTuple
from flask_restful import marshal
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList
from .workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
)
_TEST_APP_ID = "test_app_id"
class TestWorkflowDraftVariableFields:
def test_conversation_variable(self):
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)
conv_var.id = str(uuid.uuid4())
conv_var.visible = True
expected_without_value = OrderedDict(
{
"id": str(conv_var.id),
"type": conv_var.get_variable_type().value,
"name": "conv_var",
"description": "",
"selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
"value_type": "number",
"edited": False,
"visible": True,
}
)
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = 1
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
def test_create_sys_variable(self):
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=_TEST_APP_ID,
name="sys_var",
value=build_segment("a"),
editable=True,
)
sys_var.id = str(uuid.uuid4())
sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
sys_var.visible = True
expected_without_value = OrderedDict(
{
"id": str(sys_var.id),
"type": sys_var.get_variable_type().value,
"name": "sys_var",
"description": "",
"selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"],
"value_type": "string",
"edited": True,
"visible": True,
}
)
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = "a"
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
def test_node_variable(self):
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="node_var",
value=build_segment([1, "a"]),
visible=False,
)
node_var.id = str(uuid.uuid4())
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
expected_without_value = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "node_var",
"description": "",
"selector": ["test_node", "node_var"],
"value_type": "array[any]",
"edited": True,
"visible": False,
}
)
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = [1, "a"]
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
class TestWorkflowDraftVariableList:
def test_workflow_draft_variable_list(self):
class TestCase(NamedTuple):
name: str
var_list: WorkflowDraftVariableList
expected: dict
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="test_var",
value=build_segment("a"),
visible=True,
)
node_var.id = str(uuid.uuid4())
node_var_dict = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "test_var",
"description": "",
"selector": ["test_node", "test_var"],
"value_type": "string",
"edited": False,
"visible": True,
}
)
cases = [
TestCase(
name="empty variable list",
var_list=WorkflowDraftVariableList(variables=[]),
expected=OrderedDict(
{
"items": [],
"total": None,
}
),
),
TestCase(
name="empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[], total=10),
expected=OrderedDict(
{
"items": [],
"total": 10,
}
),
),
TestCase(
name="non-empty variable list",
var_list=WorkflowDraftVariableList(variables=[node_var], total=None),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": None,
}
),
),
TestCase(
name="non-empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[node_var], total=10),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": 10,
}
),
),
]
for idx, case in enumerate(cases, 1):
assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, (
f"Test case {idx} failed, {case.name=}"
)
def test_workflow_node_variables_fields():
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)
resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "conv_var"
assert item_dict["value"] == 1

View File

@ -8,6 +8,15 @@ from libs.login import current_user
from models import App, AppMode
def _load_app_model(app_id: str) -> Optional[App]:
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
return app_model
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func):
@wraps(view_func)
@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
del kwargs["app_id"]
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
app_model = _load_app_model(app_id)
if not app_model:
raise AppNotFoundError()

View File

@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
class NotFoundError(BaseHTTPException):
error_code = "unknown"
code = 404
class InvalidArgumentError(BaseHTTPException):
error_code = "invalid_param"
code = 400

View File

@ -17,9 +17,24 @@ class InvokeFrom(Enum):
Invoke From.
"""
# SERVICE_API indicates that this invocation is from an API call to Dify app.
#
# Description of service api in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/developing-with-apis
SERVICE_API = "service-api"
# WEB_APP indicates that this invocation is from
# the web app of the workflow (or chatflow).
#
# Description of web app in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
WEB_APP = "web-app"
# EXPLORE indicates that this invocation is from
# the workflow (or chatflow) explore page.
EXPLORE = "explore"
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
@classmethod

View File

@ -1 +1,21 @@
from typing import Any
FILE_MODEL_IDENTITY = "__dify__file__"
# DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes.
# Its sole possible value is `None`.
#
# This is used to signal the execution of a workflow node when it has no other outputs.
_DUMMY_OUTPUT_IDENTITY = "__dummy__"
_DUMMY_OUTPUT_VALUE: None = None
def add_dummy_output(original: dict[str, Any] | None) -> dict[str, Any]:
if original is None:
original = {}
original[_DUMMY_OUTPUT_IDENTITY] = _DUMMY_OUTPUT_VALUE
return original
def is_dummy_output_variable(name: str) -> bool:
return name == _DUMMY_OUTPUT_IDENTITY

View File

@ -7,12 +7,12 @@ from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from factories import variable_factory
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from ..enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, File]
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@ -91,7 +91,7 @@ class VariablePool(BaseModel):
Returns:
None
"""
if len(selector) < 2:
if len(selector) < MIN_SELECTORS_LENGTH:
raise ValueError("Invalid selector")
if isinstance(value, Variable):
@ -118,7 +118,7 @@ class VariablePool(BaseModel):
Raises:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
if len(selector) < MIN_SELECTORS_LENGTH:
return None
hash_key = hash(tuple(selector[1:]))

View File

@ -1,8 +1,8 @@
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, TypeAlias, TypeVar, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
@ -19,7 +20,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from factories import file_factory
from models.enums import UserFrom
@ -120,6 +121,7 @@ class WorkflowEntry:
node_id: str,
user_id: str,
user_inputs: dict,
conversation_variables: dict | None = None,
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
"""
Single step run workflow node
@ -144,13 +146,19 @@ class WorkflowEntry:
except StopIteration:
raise ValueError("node id not found in workflow graph")
node_config_data = node_config.get("data", {})
# Get node class
node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1")
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
metadata_attacher = _attach_execution_metadata_based_on_node_config(node_config_data)
# init variable pool
variable_pool = VariablePool(environment_variables=workflow.environment_variables)
variable_pool = VariablePool(
environment_variables=workflow.environment_variables,
conversation_variable=conversation_variables or {},
)
# init graph
graph = Graph.init(graph_config=workflow.graph_dict)
@ -188,11 +196,15 @@ class WorkflowEntry:
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
cls._load_persisted_draft_var_and_populate_pool(app_id=workflow.app_id, variable_pool=variable_pool)
try:
# run node
generator = node_instance.run()
except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
if metadata_attacher:
generator = _wrap_generator(generator, metadata_attacher)
return node_instance, generator
@classmethod
@ -319,6 +331,16 @@ class WorkflowEntry:
return value.to_dict()
return value
@classmethod
def _load_persisted_draft_var_and_populate_pool(cls, app_id: str, variable_pool: VariablePool) -> None:
"""
Load persisted draft variables and populate the variable pool.
:param app_id: The application ID.
:param variable_pool: The variable pool to populate.
"""
# TODO(QuantumGhost):
pass
@classmethod
def mapping_user_inputs_to_variable_pool(
cls,
@ -367,3 +389,61 @@ class WorkflowEntry:
# append variable and value to variable pool
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
variable_pool.add([variable_node_id] + variable_key_list, input_value)
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
_YieldR_co = TypeVar("_YieldR_co", covariant=True)
def _wrap_generator(
gen: Generator[_YieldT_co, None, None],
mapper: Callable[[_YieldT_co], _YieldR_co],
) -> Generator[_YieldR_co, None, None]:
for item in gen:
yield mapper(item)
_NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent
def _attach_execution_metadata(
extra_metadata: dict[NodeRunMetadataKey, Any],
) -> Callable[[_NodeOrInNodeEvent], _NodeOrInNodeEvent]:
def _execution_metadata_mapper(e: NodeEvent | InNodeEvent) -> NodeEvent | InNodeEvent:
if not isinstance(e, RunCompletedEvent):
return e
run_result = e.run_result
if run_result.metadata is None:
run_result.metadata = {}
for k, v in extra_metadata.items():
run_result.metadata[k] = v
return e
return _execution_metadata_mapper
def _attach_execution_metadata_based_on_node_config(
node_config: dict,
) -> Callable[[_NodeOrInNodeEvent], _NodeOrInNodeEvent] | None:
in_loop = node_config.get("isInLoop", False)
in_iteration = node_config.get("isInIteration", False)
if in_loop:
loop_id = node_config.get("loop_id")
if loop_id is None:
raise Exception("invalid graph")
return _attach_execution_metadata(
{
NodeRunMetadataKey.LOOP_ID: loop_id,
}
)
elif in_iteration:
iteration_id = node_config.get("iteration_id")
if iteration_id is None:
raise Exception("invalid graph")
return _attach_execution_metadata(
{
NodeRunMetadataKey.ITERATION_ID: iteration_id,
}
)
else:
return None

View File

@ -114,6 +114,10 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
return cast(Variable, result)
def infer_segment_type_from_value(value: Any, /) -> SegmentType:
return build_segment(value).value_type
def build_segment(value: Any, /) -> Segment:
if value is None:
return NoneSegment()

View File

@ -602,6 +602,14 @@ class InstalledApp(Base):
return tenant
class ConversationSource(StrEnum):
"""This enumeration is designed for use with `Conversation.from_source`."""
# NOTE(QuantumGhost): The enumeration members may not cover all possible cases.
API = "api"
CONSOLE = "console"
class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (
@ -623,7 +631,14 @@ class Conversation(Base):
system_instruction = db.Column(db.Text)
system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
status = db.Column(db.String(255), nullable=False)
# The `invoke_from` records how the conversation is created.
#
# Its value corresponds to the members of `InvokeFrom`.
# (api/core/app/entities/app_invoke_entities.py)
invoke_from = db.Column(db.String(255), nullable=True)
# ref: ConversationSource.
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(StringUUID)

View File

@ -135,6 +135,8 @@ class Workflow(Base):
"conversation_variables", db.Text, nullable=False, server_default="{}"
)
VERSION_DRAFT = "draft"
@classmethod
def new(
cls,
@ -356,6 +358,10 @@ class Workflow(Base):
ensure_ascii=False,
)
@staticmethod
def version_from_datetime(d: datetime) -> str:
return str(d)
class WorkflowRunStatus(StrEnum):
"""
@ -823,7 +829,7 @@ def _naive_utc_datetime():
class WorkflowDraftVariable(Base):
@staticmethod
def unique_columns() -> list[str]:
def unique_app_id_node_id_name() -> list[str]:
return [
"app_id",
"node_id",
@ -831,7 +837,7 @@ class WorkflowDraftVariable(Base):
]
__tablename__ = "workflow_draft_variables"
__table_args__ = (UniqueConstraint(*unique_columns()),)
__table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),)
# id is the unique identifier of a draft variable.
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
@ -996,10 +1002,11 @@ class WorkflowDraftVariable(Base):
name: str,
value: Segment,
visible: bool = True,
editable: bool = True,
) -> "WorkflowDraftVariable":
variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value)
variable.visible = visible
variable.editable = True
variable.editable = editable
return variable
@property

View File

@ -149,6 +149,7 @@ dev = [
"types-tqdm~=4.67.0",
"types-ujson~=5.10.0",
"boto3-stubs>=1.38.20",
"hypothesis>=6.131.15",
]
############################################################

View File

@ -1,4 +1,16 @@
import os
import random
import secrets
from collections.abc import Generator
import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from app_factory import create_app
from models import Account, DifySetup, Tenant, TenantAccountJoin, db
from services.account_service import AccountService, RegisterService
# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
@ -17,3 +29,61 @@ def _load_env() -> None:
_load_env()
_CACHED_APP = create_app()
@pytest.fixture
def flask_app() -> Flask:
return _CACHED_APP
@pytest.fixture(scope="session")
def setup_account(request) -> Generator[Account, None, None]:
"""`dify_setup` completes the setup process for the Dify application.
It creates `Account` and `Tenant`, and inserts a `DifySetup` record into the database.
Most tests in the `controllers` package may require dify has been successfully setup.
"""
with _CACHED_APP.test_request_context():
rand_suffix = random.randint(int(1e6), int(1e7))
name = f"test-user-{rand_suffix}"
email = f"{name}@example.com"
RegisterService.setup(
email=email,
name=name,
password=secrets.token_hex(16),
ip_address="localhost",
)
with _CACHED_APP.test_request_context():
with Session(bind=db.engine, expire_on_commit=False) as session:
account = session.query(Account).filter_by(email=email).one()
yield account
with _CACHED_APP.test_request_context():
db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Account).delete()
db.session.query(Tenant).delete()
db.session.commit()
@pytest.fixture
def flask_req_ctx():
with _CACHED_APP.test_request_context():
yield
@pytest.fixture
def auth_header(setup_account) -> dict[str, str]:
token = AccountService.get_account_jwt_token(setup_account)
return {"Authorization": f"Bearer {token}"}
@pytest.fixture
def test_client() -> Generator[FlaskClient, None, None]:
with _CACHED_APP.test_client() as client:
yield client

View File

@ -0,0 +1,46 @@
import uuid
from unittest import mock
from controllers.console.app import workflow_draft_variable as draft_variable_api
from controllers.console.app import wraps
from factories.variable_factory import build_segment
from models import App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
def _get_mock_srv_class() -> type[WorkflowDraftVariableService]:
return mock.create_autospec(WorkflowDraftVariableService)
class TestWorkflowDraftNodeVariableListApi:
def test_get(self, test_client, auth_header, monkeypatch):
srv_class = _get_mock_srv_class()
mock_app_model: App = App()
mock_app_model.id = str(uuid.uuid4())
test_node_id = "test_node_id"
mock_app_model.mode = AppMode.ADVANCED_CHAT
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
var1 = WorkflowDraftVariable.create_node_variable(
app_id="test_app_1",
node_id="test_node_1",
name="str_var",
value=build_segment("str_value"),
)
srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True)
srv_class.return_value = srv_instance
srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1])
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables",
headers=auth_header,
)
assert response.status_code == 200
response_dict = response.json
assert isinstance(response_dict, dict)
assert "items" in response_dict
assert len(response_dict["items"]) == 1

View File

@ -0,0 +1,142 @@
import unittest
import uuid
import pytest
from sqlalchemy.orm import Session
from factories.variable_factory import build_segment
from models import db
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableService
@pytest.mark.usefixtures("flask_req_ctx")
class TestWorkflowDraftVariableService(unittest.TestCase):
_test_app_id: str
_session: Session
_node2_id = "test_node_2"
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._session: Session = db.session
sys_var = WorkflowDraftVariable.create_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
)
conv_var = WorkflowDraftVariable.create_conversation_variable(
app_id=self._test_app_id,
name="conv_var",
value=build_segment("conv_value"),
)
node2_vars = [
WorkflowDraftVariable.create_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="int_var",
value=build_segment(1),
visible=False,
),
WorkflowDraftVariable.create_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
),
]
node1_var = WorkflowDraftVariable.create_node_variable(
app_id=self._test_app_id,
node_id="node_1",
name="str_var",
value=build_segment("str_value"),
visible=True,
)
_variables = list(node2_vars)
_variables.extend(
[
node1_var,
sys_var,
conv_var,
]
)
db.session.add_all(_variables)
db.session.flush()
self._variable_ids = [v.id for v in _variables]
self._node1_str_var_id = node1_var.id
self._sys_var_id = sys_var.id
self._conv_var_id = conv_var.id
self._node2_var_ids = [v.id for v in node2_vars]
def _get_test_srv(self) -> WorkflowDraftVariableService:
return WorkflowDraftVariableService(session=self._session)
def tearDown(self):
self._session.rollback()
def test_list_variables(self):
srv = self._get_test_srv()
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2)
assert var_list.total == 5
assert len(var_list.variables) == 2
page1_var_ids = {v.id for v in var_list.variables}
assert page1_var_ids.issubset(self._variable_ids)
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2)
assert var_list_2.total is None
assert len(var_list_2.variables) == 2
page2_var_ids = {v.id for v in var_list_2.variables}
assert page2_var_ids.isdisjoint(page1_var_ids)
assert page2_var_ids.issubset(self._variable_ids)
def test_get_node_variable(self):
srv = self._get_test_srv()
node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var")
assert node_var.id == self._node1_str_var_id
assert node_var.name == "str_var"
assert node_var.get_value() == build_segment("str_value")
def test_get_system_variable(self):
srv = self._get_test_srv()
sys_var = srv.get_system_variable(self._test_app_id, "sys_var")
assert sys_var.id == self._sys_var_id
assert sys_var.name == "sys_var"
assert sys_var.get_value() == build_segment("sys_value")
def test_get_conversation_variable(self):
srv = self._get_test_srv()
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var")
assert conv_var.id == self._conv_var_id
assert conv_var.name == "conv_var"
assert conv_var.get_value() == build_segment("conv_value")
def test_delete_node_variables(self):
srv = self._get_test_srv()
srv.delete_node_variables(self._test_app_id, self._node2_id)
node2_var_count = (
self._session.query(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == self._test_app_id,
WorkflowDraftVariable.node_id == self._node2_id,
)
.count()
)
assert node2_var_count == 0
def test_delete_variable(self):
srv = self._get_test_srv()
node_1_var = (
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
)
srv.delete_variable(node_1_var)
exists = bool(
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
)
assert exists is False
def test__list_node_variables(self):
srv = self._get_test_srv()
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
assert len(node_vars) == 2
assert {v.id for v in node_vars} == set(self._node2_var_ids)

View File

@ -1,7 +1,11 @@
from dataclasses import dataclass
from uuid import uuid4
import pytest
from hypothesis import given
from hypothesis import strategies as st
from core.file import File, FileTransferMethod, FileType
from core.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
@ -10,6 +14,7 @@ from core.variables import (
IntegerVariable,
ObjectSegment,
SecretVariable,
SegmentType,
StringVariable,
)
from core.variables.exc import VariableError
@ -163,3 +168,103 @@ def test_array_none_variable():
var = variable_factory.build_segment([None, None, None, None])
assert isinstance(var, ArrayAnySegment)
assert var.value == [None, None, None, None]
@st.composite
def _generate_file(draw) -> File:
file_id = draw(st.text(min_size=1, max_size=10))
tenant_id = draw(st.text(min_size=1, max_size=10))
file_type, mime_type, extension = draw(
st.sampled_from(
[
(FileType.IMAGE, "image/png", ".png"),
(FileType.VIDEO, "video/mp4", ".mp4"),
(FileType.DOCUMENT, "text/plain", ".txt"),
(FileType.AUDIO, "audio/mpeg", ".mp3"),
]
)
)
filename = "test-file"
size = draw(st.integers(min_value=0, max_value=1024 * 1024))
transfer_method = draw(st.sampled_from(list(FileTransferMethod)))
if transfer_method == FileTransferMethod.REMOTE_URL:
url = "https://test.example.com/test-file"
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=file_type,
transfer_method=transfer_method,
remote_url=url,
related_id=None,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
else:
relation_id = draw(st.uuids(version=4))
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=file_type,
transfer_method=transfer_method,
related_id=str(relation_id),
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
return file
def _scalar_value() -> st.SearchStrategy[int | float | str | File]:
return st.one_of(
st.none(),
st.integers(),
st.floats(),
st.text(),
_generate_file(),
)
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
assert seg.value == value
@given(st.lists(_scalar_value()))
def test_build_segment_and_extract_values_for_array_types(values):
seg = variable_factory.build_segment(values)
assert seg.value == values
def test_build_segment_type_for_scalar():
@dataclass(frozen=True)
class TestCase:
value: int | float | str | File
expected_type: SegmentType
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
)
cases = [
TestCase(0, SegmentType.NUMBER),
TestCase(0.0, SegmentType.NUMBER),
TestCase("", SegmentType.STRING),
TestCase(file, SegmentType.FILE),
]
for idx, c in enumerate(cases, 1):
segment = variable_factory.build_segment(c.value)
assert segment.value_type == c.expected_type, f"test case {idx} failed."

View File

@ -0,0 +1,25 @@
from core.file import File, FileTransferMethod, FileType
def test_file():
file = File(
id="test-file",
tenant_id="test-tenant-id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
extension=".png",
mime_type="image/png",
size=67,
storage_key="test-storage-key",
url="https://example.com/image.png",
)
assert file.tenant_id == "test-tenant-id"
assert file.type == FileType.IMAGE
assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.related_id == "test-related-id"
assert file.filename == "image.png"
assert file.extension == ".png"
assert file.mime_type == "image/png"
assert file.size == 67

View File

24
api/uv.lock generated
View File

@ -1295,6 +1295,7 @@ dev = [
{ name = "coverage" },
{ name = "dotenv-linter" },
{ name = "faker" },
{ name = "hypothesis" },
{ name = "lxml-stubs" },
{ name = "mypy" },
{ name = "pytest" },
@ -1466,6 +1467,7 @@ dev = [
{ name = "coverage", specifier = "~=7.2.4" },
{ name = "dotenv-linter", specifier = "~=0.5.0" },
{ name = "faker", specifier = "~=32.1.0" },
{ name = "hypothesis", specifier = ">=6.131.15" },
{ name = "lxml-stubs", specifier = "~=0.5.1" },
{ name = "mypy", specifier = "~=1.15.0" },
{ name = "pytest", specifier = "~=8.3.2" },
@ -2562,6 +2564,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" },
]
[[package]]
name = "hypothesis"
version = "6.131.15"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "attrs" },
{ name = "sortedcontainers" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f1/6f/1e291f80627f3e043b19a86f9f6b172b910e3575577917d3122a6558410d/hypothesis-6.131.15.tar.gz", hash = "sha256:11849998ae5eecc8c586c6c98e47677fcc02d97475065f62768cfffbcc15ef7a", size = 436596, upload_time = "2025-05-07T23:04:25.127Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b6/c7/78597bcec48e1585ea9029deb2bf2341516e90dd615a3db498413d68a4cc/hypothesis-6.131.15-py3-none-any.whl", hash = "sha256:e02e67e9f3cfd4cd4a67ccc03bf7431beccc1a084c5e90029799ddd36ce006d7", size = 501128, upload_time = "2025-05-07T23:04:22.045Z" },
]
[[package]]
name = "idna"
version = "3.10"
@ -5250,6 +5265,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/37/c3/6eeb6034408dac0fa653d126c9204ade96b819c936e136c5e8a6897eee9c/socksio-1.0.0-py3-none-any.whl", hash = "sha256:95dc1f15f9b34e8d7b16f06d74b8ccf48f609af32ab33c608d08761c5dcbb1f3", size = 12763, upload-time = "2020-04-17T15:50:31.878Z" },
]
[[package]]
name = "sortedcontainers"
version = "2.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload_time = "2021-05-16T22:03:42.897Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload_time = "2021-05-16T22:03:41.177Z" },
]
[[package]]
name = "soupsieve"
version = "2.7"