mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-20 12:04:26 +08:00
feat(api): implement draft var related api
This commit is contained in:
parent
0f7ea8d5fa
commit
be098dee35
53
api/app.py
53
api/app.py
@ -1,4 +1,39 @@
|
||||
import os
|
||||
|
||||
|
||||
def _setup_gevent():
|
||||
"""Do gevent monkey patching.
|
||||
|
||||
This function should be called as early as possible. Ideally
|
||||
it should be the first statement in the entrypoint file.
|
||||
|
||||
It should be
|
||||
"""
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() not in {"false", "0", "no"}:
|
||||
return
|
||||
if os.environ.get("GEVENT_SUPPORT", "0") == "0":
|
||||
return
|
||||
|
||||
from gevent import monkey
|
||||
|
||||
# gevent
|
||||
monkey.patch_all()
|
||||
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
|
||||
import psycogreen.gevent # type: ignore
|
||||
|
||||
psycogreen.gevent.patch_psycopg()
|
||||
|
||||
|
||||
_setup_gevent()
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
@ -14,24 +49,6 @@ if is_db_command():
|
||||
|
||||
app = create_migrations_app()
|
||||
else:
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
from gevent import monkey
|
||||
|
||||
# gevent
|
||||
monkey.patch_all()
|
||||
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
|
||||
import psycogreen.gevent # type: ignore
|
||||
|
||||
psycogreen.gevent.patch_psycopg()
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
app = create_app()
|
||||
|
@ -63,6 +63,7 @@ from .app import (
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
)
|
||||
|
319
api/controllers/console/app/workflow_draft_variable.py
Normal file
319
api/controllers/console/app/workflow_draft_variable.py
Normal file
@ -0,0 +1,319 @@
|
||||
import logging
|
||||
from typing import NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restful import Resource, fields, inputs, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import build_segment
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=lambda variable: variable.get_value().value),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda _: "env"),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
|
||||
}
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
return var_list.variables
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||
"total": fields.Raw(),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
|
||||
- Dify has been property setup.
|
||||
- The request user has logged in and initialized.
|
||||
- The requested app is a workflow or a chat flow.
|
||||
- The request user has the edit permission for the app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
|
||||
if not workflow_exist:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=app_model.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session,
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
]:
|
||||
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||
#
|
||||
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class NodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session)
|
||||
srv.delete_node_variables(app_model.id, node_id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class VariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session,
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=build_segment, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session,
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
new_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
|
||||
if new_name is None and new_value is None:
|
||||
return variable
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session,
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
class ConversationVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class EnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars = workflow.environment_variables
|
||||
env_vars_list = []
|
||||
for v in env_vars:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
|
||||
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
196
api/controllers/console/app/workflow_draft_variables_test.py
Normal file
196
api/controllers/console/app/workflow_draft_variables_test.py
Normal file
@ -0,0 +1,196 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from typing import NamedTuple
|
||||
|
||||
from flask_restful import marshal
|
||||
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import build_segment
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
from .workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
)
|
||||
|
||||
_TEST_APP_ID = "test_app_id"
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableFields:
|
||||
def test_conversation_variable(self):
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
|
||||
)
|
||||
|
||||
conv_var.id = str(uuid.uuid4())
|
||||
conv_var.visible = True
|
||||
|
||||
expected_without_value = OrderedDict(
|
||||
{
|
||||
"id": str(conv_var.id),
|
||||
"type": conv_var.get_variable_type().value,
|
||||
"name": "conv_var",
|
||||
"description": "",
|
||||
"selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
|
||||
"value_type": "number",
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
}
|
||||
)
|
||||
|
||||
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = 1
|
||||
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
def test_create_sys_variable(self):
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=_TEST_APP_ID,
|
||||
name="sys_var",
|
||||
value=build_segment("a"),
|
||||
editable=True,
|
||||
)
|
||||
|
||||
sys_var.id = str(uuid.uuid4())
|
||||
sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
sys_var.visible = True
|
||||
|
||||
expected_without_value = OrderedDict(
|
||||
{
|
||||
"id": str(sys_var.id),
|
||||
"type": sys_var.get_variable_type().value,
|
||||
"name": "sys_var",
|
||||
"description": "",
|
||||
"selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"],
|
||||
"value_type": "string",
|
||||
"edited": True,
|
||||
"visible": True,
|
||||
}
|
||||
)
|
||||
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = "a"
|
||||
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
def test_node_variable(self):
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=_TEST_APP_ID,
|
||||
node_id="test_node",
|
||||
name="node_var",
|
||||
value=build_segment([1, "a"]),
|
||||
visible=False,
|
||||
)
|
||||
|
||||
node_var.id = str(uuid.uuid4())
|
||||
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
|
||||
expected_without_value = OrderedDict(
|
||||
{
|
||||
"id": str(node_var.id),
|
||||
"type": node_var.get_variable_type().value,
|
||||
"name": "node_var",
|
||||
"description": "",
|
||||
"selector": ["test_node", "node_var"],
|
||||
"value_type": "array[any]",
|
||||
"edited": True,
|
||||
"visible": False,
|
||||
}
|
||||
)
|
||||
|
||||
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = [1, "a"]
|
||||
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableList:
|
||||
def test_workflow_draft_variable_list(self):
|
||||
class TestCase(NamedTuple):
|
||||
name: str
|
||||
var_list: WorkflowDraftVariableList
|
||||
expected: dict
|
||||
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=_TEST_APP_ID,
|
||||
node_id="test_node",
|
||||
name="test_var",
|
||||
value=build_segment("a"),
|
||||
visible=True,
|
||||
)
|
||||
node_var.id = str(uuid.uuid4())
|
||||
node_var_dict = OrderedDict(
|
||||
{
|
||||
"id": str(node_var.id),
|
||||
"type": node_var.get_variable_type().value,
|
||||
"name": "test_var",
|
||||
"description": "",
|
||||
"selector": ["test_node", "test_var"],
|
||||
"value_type": "string",
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
}
|
||||
)
|
||||
|
||||
cases = [
|
||||
TestCase(
|
||||
name="empty variable list",
|
||||
var_list=WorkflowDraftVariableList(variables=[]),
|
||||
expected=OrderedDict(
|
||||
{
|
||||
"items": [],
|
||||
"total": None,
|
||||
}
|
||||
),
|
||||
),
|
||||
TestCase(
|
||||
name="empty variable list with total",
|
||||
var_list=WorkflowDraftVariableList(variables=[], total=10),
|
||||
expected=OrderedDict(
|
||||
{
|
||||
"items": [],
|
||||
"total": 10,
|
||||
}
|
||||
),
|
||||
),
|
||||
TestCase(
|
||||
name="non-empty variable list",
|
||||
var_list=WorkflowDraftVariableList(variables=[node_var], total=None),
|
||||
expected=OrderedDict(
|
||||
{
|
||||
"items": [node_var_dict],
|
||||
"total": None,
|
||||
}
|
||||
),
|
||||
),
|
||||
TestCase(
|
||||
name="non-empty variable list with total",
|
||||
var_list=WorkflowDraftVariableList(variables=[node_var], total=10),
|
||||
expected=OrderedDict(
|
||||
{
|
||||
"items": [node_var_dict],
|
||||
"total": 10,
|
||||
}
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
for idx, case in enumerate(cases, 1):
|
||||
assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, (
|
||||
f"Test case {idx} failed, {case.name=}"
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_node_variables_fields():
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
|
||||
)
|
||||
resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
assert isinstance(resp, dict)
|
||||
assert len(resp["items"]) == 1
|
||||
item_dict = resp["items"][0]
|
||||
assert item_dict["name"] == "conv_var"
|
||||
assert item_dict["value"] == 1
|
@ -8,6 +8,15 @@ from libs.login import current_user
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> Optional[App]:
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
app_model = _load_app_model(app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||
error_code = "rate_limit_error"
|
||||
description = "Rate Limit Error"
|
||||
code = 429
|
||||
|
||||
|
||||
class NotFoundError(BaseHTTPException):
|
||||
error_code = "unknown"
|
||||
code = 404
|
||||
|
||||
|
||||
class InvalidArgumentError(BaseHTTPException):
|
||||
error_code = "invalid_param"
|
||||
code = 400
|
||||
|
@ -17,9 +17,24 @@ class InvokeFrom(Enum):
|
||||
Invoke From.
|
||||
"""
|
||||
|
||||
# SERVICE_API indicates that this invocation is from an API call to Dify app.
|
||||
#
|
||||
# Description of service api in Dify docs:
|
||||
# https://docs.dify.ai/en/guides/application-publishing/developing-with-apis
|
||||
SERVICE_API = "service-api"
|
||||
|
||||
# WEB_APP indicates that this invocation is from
|
||||
# the web app of the workflow (or chatflow).
|
||||
#
|
||||
# Description of web app in Dify docs:
|
||||
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
|
||||
WEB_APP = "web-app"
|
||||
|
||||
# EXPLORE indicates that this invocation is from
|
||||
# the workflow (or chatflow) explore page.
|
||||
EXPLORE = "explore"
|
||||
# DEBUGGER indicates that this invocation is from
|
||||
# the workflow (or chatflow) edit page.
|
||||
DEBUGGER = "debugger"
|
||||
|
||||
@classmethod
|
||||
|
@ -1 +1,21 @@
|
||||
from typing import Any
|
||||
|
||||
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||
|
||||
# DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes.
|
||||
# Its sole possible value is `None`.
|
||||
#
|
||||
# This is used to signal the execution of a workflow node when it has no other outputs.
|
||||
_DUMMY_OUTPUT_IDENTITY = "__dummy__"
|
||||
_DUMMY_OUTPUT_VALUE: None = None
|
||||
|
||||
|
||||
def add_dummy_output(original: dict[str, Any] | None) -> dict[str, Any]:
|
||||
if original is None:
|
||||
original = {}
|
||||
original[_DUMMY_OUTPUT_IDENTITY] = _DUMMY_OUTPUT_VALUE
|
||||
return original
|
||||
|
||||
|
||||
def is_dummy_output_variable(name: str) -> bool:
|
||||
return name == _DUMMY_OUTPUT_IDENTITY
|
||||
|
@ -7,12 +7,12 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, NoneSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from factories import variable_factory
|
||||
|
||||
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from ..enums import SystemVariableKey
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, File]
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
||||
@ -91,7 +91,7 @@ class VariablePool(BaseModel):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
raise ValueError("Invalid selector")
|
||||
|
||||
if isinstance(value, Variable):
|
||||
@ -118,7 +118,7 @@ class VariablePool(BaseModel):
|
||||
Raises:
|
||||
ValueError: If the selector is invalid.
|
||||
"""
|
||||
if len(selector) < 2:
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
return None
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
|
@ -1,8 +1,8 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, TypeAlias, TypeVar, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.workflow.callbacks import WorkflowCallback
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
|
||||
@ -19,7 +20,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.event import NodeEvent
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
@ -120,6 +121,7 @@ class WorkflowEntry:
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict,
|
||||
conversation_variables: dict | None = None,
|
||||
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
@ -144,13 +146,19 @@ class WorkflowEntry:
|
||||
except StopIteration:
|
||||
raise ValueError("node id not found in workflow graph")
|
||||
|
||||
node_config_data = node_config.get("data", {})
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_version = node_config_data.get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
metadata_attacher = _attach_execution_metadata_based_on_node_config(node_config_data)
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(environment_variables=workflow.environment_variables)
|
||||
variable_pool = VariablePool(
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variable=conversation_variables or {},
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=workflow.graph_dict)
|
||||
@ -188,11 +196,15 @@ class WorkflowEntry:
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
cls._load_persisted_draft_var_and_populate_pool(app_id=workflow.app_id, variable_pool=variable_pool)
|
||||
|
||||
try:
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
except Exception as e:
|
||||
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
||||
if metadata_attacher:
|
||||
generator = _wrap_generator(generator, metadata_attacher)
|
||||
return node_instance, generator
|
||||
|
||||
@classmethod
|
||||
@ -319,6 +331,16 @@ class WorkflowEntry:
|
||||
return value.to_dict()
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _load_persisted_draft_var_and_populate_pool(cls, app_id: str, variable_pool: VariablePool) -> None:
|
||||
"""
|
||||
Load persisted draft variables and populate the variable pool.
|
||||
:param app_id: The application ID.
|
||||
:param variable_pool: The variable pool to populate.
|
||||
"""
|
||||
# TODO(QuantumGhost):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def mapping_user_inputs_to_variable_pool(
|
||||
cls,
|
||||
@ -367,3 +389,61 @@ class WorkflowEntry:
|
||||
# append variable and value to variable pool
|
||||
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
|
||||
|
||||
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
|
||||
_YieldR_co = TypeVar("_YieldR_co", covariant=True)
|
||||
|
||||
|
||||
def _wrap_generator(
|
||||
gen: Generator[_YieldT_co, None, None],
|
||||
mapper: Callable[[_YieldT_co], _YieldR_co],
|
||||
) -> Generator[_YieldR_co, None, None]:
|
||||
for item in gen:
|
||||
yield mapper(item)
|
||||
|
||||
|
||||
_NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent
|
||||
|
||||
|
||||
def _attach_execution_metadata(
|
||||
extra_metadata: dict[NodeRunMetadataKey, Any],
|
||||
) -> Callable[[_NodeOrInNodeEvent], _NodeOrInNodeEvent]:
|
||||
def _execution_metadata_mapper(e: NodeEvent | InNodeEvent) -> NodeEvent | InNodeEvent:
|
||||
if not isinstance(e, RunCompletedEvent):
|
||||
return e
|
||||
run_result = e.run_result
|
||||
if run_result.metadata is None:
|
||||
run_result.metadata = {}
|
||||
for k, v in extra_metadata.items():
|
||||
run_result.metadata[k] = v
|
||||
return e
|
||||
|
||||
return _execution_metadata_mapper
|
||||
|
||||
|
||||
def _attach_execution_metadata_based_on_node_config(
|
||||
node_config: dict,
|
||||
) -> Callable[[_NodeOrInNodeEvent], _NodeOrInNodeEvent] | None:
|
||||
in_loop = node_config.get("isInLoop", False)
|
||||
in_iteration = node_config.get("isInIteration", False)
|
||||
if in_loop:
|
||||
loop_id = node_config.get("loop_id")
|
||||
if loop_id is None:
|
||||
raise Exception("invalid graph")
|
||||
return _attach_execution_metadata(
|
||||
{
|
||||
NodeRunMetadataKey.LOOP_ID: loop_id,
|
||||
}
|
||||
)
|
||||
elif in_iteration:
|
||||
iteration_id = node_config.get("iteration_id")
|
||||
if iteration_id is None:
|
||||
raise Exception("invalid graph")
|
||||
return _attach_execution_metadata(
|
||||
{
|
||||
NodeRunMetadataKey.ITERATION_ID: iteration_id,
|
||||
}
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
@ -114,6 +114,10 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
||||
return cast(Variable, result)
|
||||
|
||||
|
||||
def infer_segment_type_from_value(value: Any, /) -> SegmentType:
|
||||
return build_segment(value).value_type
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
if value is None:
|
||||
return NoneSegment()
|
||||
|
@ -602,6 +602,14 @@ class InstalledApp(Base):
|
||||
return tenant
|
||||
|
||||
|
||||
class ConversationSource(StrEnum):
|
||||
"""This enumeration is designed for use with `Conversation.from_source`."""
|
||||
|
||||
# NOTE(QuantumGhost): The enumeration members may not cover all possible cases.
|
||||
API = "api"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
__table_args__ = (
|
||||
@ -623,7 +631,14 @@ class Conversation(Base):
|
||||
system_instruction = db.Column(db.Text)
|
||||
system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
status = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# The `invoke_from` records how the conversation is created.
|
||||
#
|
||||
# Its value corresponds to the members of `InvokeFrom`.
|
||||
# (api/core/app/entities/app_invoke_entities.py)
|
||||
invoke_from = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# ref: ConversationSource.
|
||||
from_source = db.Column(db.String(255), nullable=False)
|
||||
from_end_user_id = db.Column(StringUUID)
|
||||
from_account_id = db.Column(StringUUID)
|
||||
|
@ -135,6 +135,8 @@ class Workflow(Base):
|
||||
"conversation_variables", db.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
|
||||
VERSION_DRAFT = "draft"
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
@ -356,6 +358,10 @@ class Workflow(Base):
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def version_from_datetime(d: datetime) -> str:
|
||||
return str(d)
|
||||
|
||||
|
||||
class WorkflowRunStatus(StrEnum):
|
||||
"""
|
||||
@ -823,7 +829,7 @@ def _naive_utc_datetime():
|
||||
|
||||
class WorkflowDraftVariable(Base):
|
||||
@staticmethod
|
||||
def unique_columns() -> list[str]:
|
||||
def unique_app_id_node_id_name() -> list[str]:
|
||||
return [
|
||||
"app_id",
|
||||
"node_id",
|
||||
@ -831,7 +837,7 @@ class WorkflowDraftVariable(Base):
|
||||
]
|
||||
|
||||
__tablename__ = "workflow_draft_variables"
|
||||
__table_args__ = (UniqueConstraint(*unique_columns()),)
|
||||
__table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),)
|
||||
|
||||
# id is the unique identifier of a draft variable.
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
@ -996,10 +1002,11 @@ class WorkflowDraftVariable(Base):
|
||||
name: str,
|
||||
value: Segment,
|
||||
visible: bool = True,
|
||||
editable: bool = True,
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value)
|
||||
variable.visible = visible
|
||||
variable.editable = True
|
||||
variable.editable = editable
|
||||
return variable
|
||||
|
||||
@property
|
||||
|
@ -149,6 +149,7 @@ dev = [
|
||||
"types-tqdm~=4.67.0",
|
||||
"types-ujson~=5.10.0",
|
||||
"boto3-stubs>=1.38.20",
|
||||
"hypothesis>=6.131.15",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
@ -1,4 +1,16 @@
|
||||
import os
|
||||
import random
|
||||
import secrets
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app_factory import create_app
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin, db
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
# Getting the absolute path of the current file's directory
|
||||
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
@ -17,3 +29,61 @@ def _load_env() -> None:
|
||||
|
||||
|
||||
_load_env()
|
||||
|
||||
_CACHED_APP = create_app()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return _CACHED_APP
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def setup_account(request) -> Generator[Account, None, None]:
|
||||
"""`dify_setup` completes the setup process for the Dify application.
|
||||
|
||||
It creates `Account` and `Tenant`, and inserts a `DifySetup` record into the database.
|
||||
|
||||
Most tests in the `controllers` package may require dify has been successfully setup.
|
||||
"""
|
||||
with _CACHED_APP.test_request_context():
|
||||
rand_suffix = random.randint(int(1e6), int(1e7))
|
||||
name = f"test-user-{rand_suffix}"
|
||||
email = f"{name}@example.com"
|
||||
RegisterService.setup(
|
||||
email=email,
|
||||
name=name,
|
||||
password=secrets.token_hex(16),
|
||||
ip_address="localhost",
|
||||
)
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
account = session.query(Account).filter_by(email=email).one()
|
||||
|
||||
yield account
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
db.session.query(DifySetup).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_req_ctx():
|
||||
with _CACHED_APP.test_request_context():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_header(setup_account) -> dict[str, str]:
|
||||
token = AccountService.get_account_jwt_token(setup_account)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client() -> Generator[FlaskClient, None, None]:
|
||||
with _CACHED_APP.test_client() as client:
|
||||
yield client
|
||||
|
@ -0,0 +1,46 @@
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
from controllers.console.app import workflow_draft_variable as draft_variable_api
|
||||
from controllers.console.app import wraps
|
||||
from factories.variable_factory import build_segment
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
|
||||
def _get_mock_srv_class() -> type[WorkflowDraftVariableService]:
|
||||
return mock.create_autospec(WorkflowDraftVariableService)
|
||||
|
||||
|
||||
class TestWorkflowDraftNodeVariableListApi:
|
||||
def test_get(self, test_client, auth_header, monkeypatch):
|
||||
srv_class = _get_mock_srv_class()
|
||||
mock_app_model: App = App()
|
||||
mock_app_model.id = str(uuid.uuid4())
|
||||
test_node_id = "test_node_id"
|
||||
mock_app_model.mode = AppMode.ADVANCED_CHAT
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
|
||||
monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
var1 = WorkflowDraftVariable.create_node_variable(
|
||||
app_id="test_app_1",
|
||||
node_id="test_node_1",
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
)
|
||||
srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True)
|
||||
srv_class.return_value = srv_instance
|
||||
srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1])
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_dict = response.json
|
||||
assert isinstance(response_dict, dict)
|
||||
assert "items" in response_dict
|
||||
assert len(response_dict["items"]) == 1
|
0
api/tests/integration_tests/services/__init__.py
Normal file
0
api/tests/integration_tests/services/__init__.py
Normal file
@ -0,0 +1,142 @@
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from factories.variable_factory import build_segment
|
||||
from models import db
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
_test_app_id: str
|
||||
_session: Session
|
||||
_node2_id = "test_node_2"
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._session: Session = db.session
|
||||
sys_var = WorkflowDraftVariable.create_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.create_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node2_vars = [
|
||||
WorkflowDraftVariable.create_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node2_id,
|
||||
name="int_var",
|
||||
value=build_segment(1),
|
||||
visible=False,
|
||||
),
|
||||
WorkflowDraftVariable.create_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node2_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
node1_var = WorkflowDraftVariable.create_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id="node_1",
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
)
|
||||
_variables = list(node2_vars)
|
||||
_variables.extend(
|
||||
[
|
||||
node1_var,
|
||||
sys_var,
|
||||
conv_var,
|
||||
]
|
||||
)
|
||||
|
||||
db.session.add_all(_variables)
|
||||
db.session.flush()
|
||||
self._variable_ids = [v.id for v in _variables]
|
||||
self._node1_str_var_id = node1_var.id
|
||||
self._sys_var_id = sys_var.id
|
||||
self._conv_var_id = conv_var.id
|
||||
self._node2_var_ids = [v.id for v in node2_vars]
|
||||
|
||||
def _get_test_srv(self) -> WorkflowDraftVariableService:
|
||||
return WorkflowDraftVariableService(session=self._session)
|
||||
|
||||
def tearDown(self):
|
||||
self._session.rollback()
|
||||
|
||||
def test_list_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2)
|
||||
assert var_list.total == 5
|
||||
assert len(var_list.variables) == 2
|
||||
page1_var_ids = {v.id for v in var_list.variables}
|
||||
assert page1_var_ids.issubset(self._variable_ids)
|
||||
|
||||
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2)
|
||||
assert var_list_2.total is None
|
||||
assert len(var_list_2.variables) == 2
|
||||
page2_var_ids = {v.id for v in var_list_2.variables}
|
||||
assert page2_var_ids.isdisjoint(page1_var_ids)
|
||||
assert page2_var_ids.issubset(self._variable_ids)
|
||||
|
||||
def test_get_node_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var")
|
||||
assert node_var.id == self._node1_str_var_id
|
||||
assert node_var.name == "str_var"
|
||||
assert node_var.get_value() == build_segment("str_value")
|
||||
|
||||
def test_get_system_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
sys_var = srv.get_system_variable(self._test_app_id, "sys_var")
|
||||
assert sys_var.id == self._sys_var_id
|
||||
assert sys_var.name == "sys_var"
|
||||
assert sys_var.get_value() == build_segment("sys_value")
|
||||
|
||||
def test_get_conversation_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var")
|
||||
assert conv_var.id == self._conv_var_id
|
||||
assert conv_var.name == "conv_var"
|
||||
assert conv_var.get_value() == build_segment("conv_value")
|
||||
|
||||
def test_delete_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
srv.delete_node_variables(self._test_app_id, self._node2_id)
|
||||
node2_var_count = (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == self._test_app_id,
|
||||
WorkflowDraftVariable.node_id == self._node2_id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
assert node2_var_count == 0
|
||||
|
||||
def test_delete_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_1_var = (
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
|
||||
)
|
||||
srv.delete_variable(node_1_var)
|
||||
exists = bool(
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
|
||||
)
|
||||
assert exists is False
|
||||
|
||||
def test__list_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
||||
assert len(node_vars) == 2
|
||||
assert {v.id for v in node_vars} == set(self._node2_var_ids)
|
@ -1,7 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from hypothesis import given
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
@ -10,6 +14,7 @@ from core.variables import (
|
||||
IntegerVariable,
|
||||
ObjectSegment,
|
||||
SecretVariable,
|
||||
SegmentType,
|
||||
StringVariable,
|
||||
)
|
||||
from core.variables.exc import VariableError
|
||||
@ -163,3 +168,103 @@ def test_array_none_variable():
|
||||
var = variable_factory.build_segment([None, None, None, None])
|
||||
assert isinstance(var, ArrayAnySegment)
|
||||
assert var.value == [None, None, None, None]
|
||||
|
||||
|
||||
@st.composite
|
||||
def _generate_file(draw) -> File:
|
||||
file_id = draw(st.text(min_size=1, max_size=10))
|
||||
tenant_id = draw(st.text(min_size=1, max_size=10))
|
||||
file_type, mime_type, extension = draw(
|
||||
st.sampled_from(
|
||||
[
|
||||
(FileType.IMAGE, "image/png", ".png"),
|
||||
(FileType.VIDEO, "video/mp4", ".mp4"),
|
||||
(FileType.DOCUMENT, "text/plain", ".txt"),
|
||||
(FileType.AUDIO, "audio/mpeg", ".mp3"),
|
||||
]
|
||||
)
|
||||
)
|
||||
filename = "test-file"
|
||||
size = draw(st.integers(min_value=0, max_value=1024 * 1024))
|
||||
|
||||
transfer_method = draw(st.sampled_from(list(FileTransferMethod)))
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = "https://test.example.com/test-file"
|
||||
file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
related_id=None,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
)
|
||||
else:
|
||||
relation_id = draw(st.uuids(version=4))
|
||||
|
||||
file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
related_id=str(relation_id),
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
)
|
||||
return file
|
||||
|
||||
|
||||
def _scalar_value() -> st.SearchStrategy[int | float | str | File]:
|
||||
return st.one_of(
|
||||
st.none(),
|
||||
st.integers(),
|
||||
st.floats(),
|
||||
st.text(),
|
||||
_generate_file(),
|
||||
)
|
||||
|
||||
|
||||
@given(_scalar_value())
|
||||
def test_build_segment_and_extract_values_for_scalar_types(value):
|
||||
seg = variable_factory.build_segment(value)
|
||||
assert seg.value == value
|
||||
|
||||
|
||||
@given(st.lists(_scalar_value()))
|
||||
def test_build_segment_and_extract_values_for_array_types(values):
|
||||
seg = variable_factory.build_segment(values)
|
||||
assert seg.value == values
|
||||
|
||||
|
||||
def test_build_segment_type_for_scalar():
|
||||
@dataclass(frozen=True)
|
||||
class TestCase:
|
||||
value: int | float | str | File
|
||||
expected_type: SegmentType
|
||||
|
||||
file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://test.example.com/test-file.png",
|
||||
filename="test-file",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=1000,
|
||||
)
|
||||
cases = [
|
||||
TestCase(0, SegmentType.NUMBER),
|
||||
TestCase(0.0, SegmentType.NUMBER),
|
||||
TestCase("", SegmentType.STRING),
|
||||
TestCase(file, SegmentType.FILE),
|
||||
]
|
||||
|
||||
for idx, c in enumerate(cases, 1):
|
||||
segment = variable_factory.build_segment(c.value)
|
||||
assert segment.value_type == c.expected_type, f"test case {idx} failed."
|
||||
|
25
api/tests/unit_tests/core/file/test_models.py
Normal file
25
api/tests/unit_tests/core/file/test_models.py
Normal file
@ -0,0 +1,25 @@
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
|
||||
|
||||
def test_file():
|
||||
file = File(
|
||||
id="test-file",
|
||||
tenant_id="test-tenant-id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="test-related-id",
|
||||
filename="image.png",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=67,
|
||||
storage_key="test-storage-key",
|
||||
url="https://example.com/image.png",
|
||||
)
|
||||
assert file.tenant_id == "test-tenant-id"
|
||||
assert file.type == FileType.IMAGE
|
||||
assert file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||
assert file.related_id == "test-related-id"
|
||||
assert file.filename == "image.png"
|
||||
assert file.extension == ".png"
|
||||
assert file.mime_type == "image/png"
|
||||
assert file.size == 67
|
0
api/tests/unit_tests/models/__init__.py
Normal file
0
api/tests/unit_tests/models/__init__.py
Normal file
24
api/uv.lock
generated
24
api/uv.lock
generated
@ -1295,6 +1295,7 @@ dev = [
|
||||
{ name = "coverage" },
|
||||
{ name = "dotenv-linter" },
|
||||
{ name = "faker" },
|
||||
{ name = "hypothesis" },
|
||||
{ name = "lxml-stubs" },
|
||||
{ name = "mypy" },
|
||||
{ name = "pytest" },
|
||||
@ -1466,6 +1467,7 @@ dev = [
|
||||
{ name = "coverage", specifier = "~=7.2.4" },
|
||||
{ name = "dotenv-linter", specifier = "~=0.5.0" },
|
||||
{ name = "faker", specifier = "~=32.1.0" },
|
||||
{ name = "hypothesis", specifier = ">=6.131.15" },
|
||||
{ name = "lxml-stubs", specifier = "~=0.5.1" },
|
||||
{ name = "mypy", specifier = "~=1.15.0" },
|
||||
{ name = "pytest", specifier = "~=8.3.2" },
|
||||
@ -2562,6 +2564,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hypothesis"
|
||||
version = "6.131.15"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "attrs" },
|
||||
{ name = "sortedcontainers" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/6f/1e291f80627f3e043b19a86f9f6b172b910e3575577917d3122a6558410d/hypothesis-6.131.15.tar.gz", hash = "sha256:11849998ae5eecc8c586c6c98e47677fcc02d97475065f62768cfffbcc15ef7a", size = 436596, upload_time = "2025-05-07T23:04:25.127Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/c7/78597bcec48e1585ea9029deb2bf2341516e90dd615a3db498413d68a4cc/hypothesis-6.131.15-py3-none-any.whl", hash = "sha256:e02e67e9f3cfd4cd4a67ccc03bf7431beccc1a084c5e90029799ddd36ce006d7", size = 501128, upload_time = "2025-05-07T23:04:22.045Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
@ -5250,6 +5265,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/37/c3/6eeb6034408dac0fa653d126c9204ade96b819c936e136c5e8a6897eee9c/socksio-1.0.0-py3-none-any.whl", hash = "sha256:95dc1f15f9b34e8d7b16f06d74b8ccf48f609af32ab33c608d08761c5dcbb1f3", size = 12763, upload-time = "2020-04-17T15:50:31.878Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sortedcontainers"
|
||||
version = "2.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload_time = "2021-05-16T22:03:42.897Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload_time = "2021-05-16T22:03:41.177Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "soupsieve"
|
||||
version = "2.7"
|
||||
|
Loading…
x
Reference in New Issue
Block a user