dify/api/services/workflow_draft_variable_service.py

409 lines
16 KiB
Python

import dataclasses
import logging
from collections.abc import Mapping, Sequence
from typing import Any
from sqlalchemy import Engine, orm
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import and_, or_
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import Segment, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from core.workflow.variable_loader import VariableLoader
from factories import variable_factory
from factories.variable_factory import build_segment, segment_to_variable
from models.workflow import WorkflowDraftVariable, is_system_variable_editable
_logger = logging.getLogger(__name__)
@dataclasses.dataclass(frozen=True)
class WorkflowDraftVariableList:
variables: list[WorkflowDraftVariable]
total: int | None = None
class DraftVarLoader(VariableLoader):
# This implements the VariableLoader interface for loading draft variables.
#
# ref: core.workflow.variable_loader.VariableLoader
def __init__(self, engine: Engine, app_id: str) -> None:
self._engine = engine
self._app_id = app_id
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
if not selectors:
return []
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
variables = []
for draft_var in draft_vars:
segment = build_segment(
draft_var.value,
)
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
variables.append(variable)
return variables
class WorkflowDraftVariableService:
_session: Session
def __init__(self, session: Session) -> None:
self._session = session
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
def get_draft_variables_by_selectors(
self,
app_id: str,
selectors: Sequence[list[str]],
) -> list[WorkflowDraftVariable]:
ors = []
for selector in selectors:
node_id, name = selector
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
# NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as
# each expression includes conditions on both `node_id` and `name` (which are covered by the unique index),
# PostgreSQL can efficiently retrieve the results using a bitmap index scan.
#
# Alternatively, a `SELECT` statement could be constructed for each selector and
# combined using `UNION` to fetch all rows.
# Benchmarking indicates that both approaches yield comparable performance.
variables = (
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all()
)
return variables
def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]):
variable_builder = _DraftVariableBuilder(app_id=app_id)
variable_builder.build(node_id=node_id, node_type=node_type, output=output)
draft_variables = variable_builder.get_variables()
# draft_variables = _build_variables_from_output_mapping(app_id, node_id, node_type, output)
if not draft_variables:
return
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
#
# 1. The variable saving process involves writing multiple rows to the
# `workflow_draft_variables` table. Batch insertion significantly improves performance.
# 2. Using the ORM would require either:
#
# a. Checking for the existence of each variable before insertion,
# resulting in 2n SQL statements for n variables and potential concurrency issues.
# b. Attempting insertion first, then updating if a unique index violation occurs,
# which still results in n to 2n SQL statements.
#
# Both approaches are inefficient and suboptimal.
# 3. We do not need to retrieve the results of the SQL execution or populate ORM
# model instances with the returned values.
# 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all
# variables in a single SQL statement, avoiding the issues above.
#
# For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific
# insert operations instead of the ORM layer.
if node_type == NodeType.CODE:
# Clear existing variable for code node.
self._session.query(WorkflowDraftVariable).filter(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
).delete(synchronize_session=False)
stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_variables])
stmt = stmt.on_conflict_do_update(
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
set_={
"updated_at": stmt.excluded.updated_at,
"last_edited_at": stmt.excluded.last_edited_at,
"description": stmt.excluded.description,
"value_type": stmt.excluded.value_type,
"value": stmt.excluded.value,
"visible": stmt.excluded.visible,
"editable": stmt.excluded.editable,
},
)
self._session.execute(stmt)
def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList:
criteria = WorkflowDraftVariable.app_id == app_id
total = None
query = self._session.query(WorkflowDraftVariable).filter(criteria)
if page == 1:
total = query.count()
variables = (
# Do not load the `value` field.
query.options(orm.defer(WorkflowDraftVariable.value))
.order_by(WorkflowDraftVariable.id.desc())
.limit(limit)
.offset((page - 1) * limit)
.all()
)
return WorkflowDraftVariableList(variables=variables, total=total)
def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
criteria = (
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
)
query = self._session.query(WorkflowDraftVariable).filter(*criteria)
variables = query.order_by(WorkflowDraftVariable.id.desc()).all()
return WorkflowDraftVariableList(variables=variables)
def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, node_id)
def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID)
def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID)
def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name)
def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name)
def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id, node_id, name)
def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
variable = (
self._session.query(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.name == name,
)
.first()
)
return variable
def update_variable(
self,
variable: WorkflowDraftVariable,
name: str | None = None,
value: Segment | None = None,
) -> WorkflowDraftVariable:
if name is not None:
variable.set_name(name)
if value is not None:
variable.set_value(value)
self._session.flush()
return variable
def delete_variable(self, variable: WorkflowDraftVariable):
self._session.delete(variable)
def delete_workflow_variables(self, app_id: str):
(
self._session.query(WorkflowDraftVariable)
.filter(WorkflowDraftVariable.app_id == app_id)
.delete(synchronize_session=False)
)
def delete_node_variables(self, app_id: str, node_id: str):
return self._delete_node_variables(app_id, node_id)
def _delete_node_variables(self, app_id: str, node_id: str):
self._session.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
).delete()
def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
d: dict[str, Any] = {
"app_id": model.app_id,
"last_edited_at": None,
"node_id": model.node_id,
"name": model.name,
"selector": model.selector,
"value_type": model.value_type,
"value": model.value,
}
if model.visible is not None:
d["visible"] = model.visible
if model.editable is not None:
d["editable"] = model.editable
if model.created_at is not None:
d["created_at"] = model.created_at
if model.updated_at is not None:
d["updated_at"] = model.updated_at
if model.description is not None:
d["description"] = model.description
return d
def should_save_output_variables_for_draft(
invoke_from: InvokeFrom, loop_id: str | None, iteration_id: str | None
) -> bool:
# Only save output variables for debugging execution of workflow.
if invoke_from != InvokeFrom.DEBUGGER:
return False
# Currently we do not save output variables for nodes inside loop or iteration.
if loop_id is not None:
return False
if iteration_id is not None:
return False
return True
class _DraftVariableBuilder:
_app_id: str
_draft_vars: list[WorkflowDraftVariable]
def __init__(self, app_id: str):
self._app_id = app_id
self._draft_vars: list[WorkflowDraftVariable] = []
def _build_from_variable_assigner_mapping(self, node_id: str, output: Mapping[str, Any]):
updated_variables = output.get("updated_variables", [])
for item in updated_variables:
selector = item.get("selector")
if selector is None:
continue
if len(selector) < MIN_SELECTORS_LENGTH:
raise Exception("selector too short")
# NOTE(QuantumGhost): only the following two kinds of variable could be updated by
# VariableAssigner: ConversationVariable and iteration variable.
# We only save conversation variable here.
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
name = item.get("name")
if name is None:
continue
new_value = item["new_value"]
value_type = item.get("type")
if value_type is None:
continue
var_seg = variable_factory.build_segment(new_value)
if var_seg.value_type != value_type:
raise Exception("value_type mismatch!")
self._draft_vars.append(
WorkflowDraftVariable.new_conversation_variable(
app_id=self._app_id,
name=name,
value=var_seg,
)
)
def _build_variables_from_start_mapping(
self,
node_id: str,
output: Mapping[str, Any],
):
original_node_id = node_id
for name, value in output.items():
value_seg = variable_factory.build_segment(value)
node_id, name = self._normalize_variable_for_start_node(node_id, name)
if node_id != SYSTEM_VARIABLE_NODE_ID:
self._draft_vars.append(
WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
node_id=original_node_id,
name=name,
value=value_seg,
visible=False,
editable=False,
)
)
else:
self._draft_vars.append(
WorkflowDraftVariable.new_sys_variable(
app_id=self._app_id,
name=name,
value=value_seg,
editable=self._should_variable_be_editable(node_id, name),
)
)
@staticmethod
def _normalize_variable_for_start_node(node_id: str, name: str) -> tuple[str, str]:
if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."):
return node_id, name
node_id, name_ = name.split(".", maxsplit=1)
return node_id, name_
def _build_variables_from_mapping(
self,
node_id: str,
node_type: NodeType,
output: Mapping[str, Any],
):
for name, value in output.items():
value_seg = variable_factory.build_segment(value)
self._draft_vars.append(
WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
node_id=node_id,
name=name,
value=value_seg,
visible=self._should_variable_be_visible(node_type, node_id, name),
)
)
def build(
self,
node_id: str,
node_type: NodeType,
output: Mapping[str, Any],
):
if node_type == NodeType.VARIABLE_ASSIGNER:
self._build_from_variable_assigner_mapping(node_id, output)
elif node_type == NodeType.START:
self._build_variables_from_start_mapping(node_id, output)
else:
self._build_variables_from_mapping(node_id, node_type, output)
def get_variables(self) -> Sequence[WorkflowDraftVariable]:
return self._draft_vars
@staticmethod
def _should_variable_be_editable(node_id: str, name: str) -> bool:
if node_id in (CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID):
return False
if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name):
return False
return True
@staticmethod
def _should_variable_be_visible(node_type: NodeType, node_id: str, name: str) -> bool:
if node_type in (NodeType.IF_ELSE, NodeType.START):
return False
if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name):
return False
return True
# @staticmethod
# def _normalize_variable(node_type: NodeType, node_id: str, name: str) -> tuple[str, str]:
# if node_type != NodeType.START:
# return node_id, name
#
# # TODO(QuantumGhost): need special handling for dummy output variable in
# # `Start` node.
# if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."):
# return node_id, name
# logging.getLogger(__name__).info(
# "Normalizing variable: node_type=%s, node_id=%s, name=%s",
# node_type,
# node_id,
# name,
# )
# node_id, name_ = name.split(".", maxsplit=1)
# return node_id, name_