feat(api): Implement VariableLoader in DraftVarLoader.

Use DraftVarLoader to load required variables when single stepping a
node.
This commit is contained in:
QuantumGhost 2025-05-26 14:20:56 +08:00
parent 141f9b4d51
commit defe8fea63
3 changed files with 110 additions and 49 deletions

View File

@ -2,12 +2,13 @@ import logging
import time import time
import uuid import uuid
from collections.abc import Callable, Generator, Mapping, Sequence from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, TypeAlias, TypeVar, cast from typing import Any, Optional, TypeAlias, cast
from configs import dify_config from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File from core.file.models import File
from core.variables import Variable
from core.workflow.callbacks import WorkflowCallback from core.workflow.callbacks import WorkflowCallback
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.node_entities import NodeRunMetadataKey
@ -22,13 +23,27 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.variable_loader import VariableLoader
from factories import file_factory from factories import file_factory
from libs import gen_utils
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowType, WorkflowType,
) )
class _DummyVariableLoader(VariableLoader):
"""A dummy implementation of VariableLoader that does not load any variables.
Serves as a placeholder when no variable loading is needed.
"""
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
return []
_DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -122,6 +137,7 @@ class WorkflowEntry:
user_id: str, user_id: str,
user_inputs: dict, user_inputs: dict,
conversation_variables: dict | None = None, conversation_variables: dict | None = None,
variable_loader: VariableLoader = _DUMMY_VARIABLE_LOADER,
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
""" """
Single step run workflow node Single step run workflow node
@ -190,6 +206,19 @@ class WorkflowEntry:
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}
# Loading missing variable from draft var here, and set it into
# variable_pool.
variables_to_load: list[list[str]] = []
for key, selector in variable_mapping.items():
trimmed_key = key.removeprefix(f"{node_id}.")
if trimmed_key in user_inputs:
continue
if variable_pool.get(selector) is None:
variables_to_load.append(list(selector))
loaded = variable_loader.load_variables(variables_to_load)
for var in loaded:
variable_pool.add(var.selector, var.value)
cls.mapping_user_inputs_to_variable_pool( cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping, variable_mapping=variable_mapping,
user_inputs=user_inputs, user_inputs=user_inputs,
@ -204,7 +233,7 @@ class WorkflowEntry:
except Exception as e: except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
if metadata_attacher: if metadata_attacher:
generator = _wrap_generator(generator, metadata_attacher) generator = gen_utils.map_(generator, metadata_attacher)
return node_instance, generator return node_instance, generator
@classmethod @classmethod
@ -391,18 +420,6 @@ class WorkflowEntry:
variable_pool.add([variable_node_id] + variable_key_list, input_value) variable_pool.add([variable_node_id] + variable_key_list, input_value)
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
_YieldR_co = TypeVar("_YieldR_co", covariant=True)
def _wrap_generator(
gen: Generator[_YieldT_co, None, None],
mapper: Callable[[_YieldT_co], _YieldR_co],
) -> Generator[_YieldR_co, None, None]:
for item in gen:
yield mapper(item)
_NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent _NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent

View File

@ -3,17 +3,19 @@ import logging
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any from typing import Any
from sqlalchemy import orm from sqlalchemy import Engine, orm
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import and_, or_
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.constants import is_dummy_output_variable from core.variables import Segment, Variable
from core.variables import Segment
from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.variable_loader import VariableLoader
from factories import variable_factory from factories import variable_factory
from factories.variable_factory import build_segment, segment_to_variable
from models.workflow import WorkflowDraftVariable, is_system_variable_editable from models.workflow import WorkflowDraftVariable, is_system_variable_editable
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -25,6 +27,36 @@ class WorkflowDraftVariableList:
total: int | None = None total: int | None = None
class DraftVarLoader(VariableLoader):
# This implements the VariableLoader interface for loading draft variables.
#
# ref: core.workflow.variable_loader.VariableLoader
def __init__(self, engine: Engine, app_id: str) -> None:
self._engine = engine
self._app_id = app_id
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
if not selectors:
return []
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
variables = []
for draft_var in draft_vars:
segment = build_segment(
draft_var.value,
)
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
variables.append(variable)
return variables
class WorkflowDraftVariableService: class WorkflowDraftVariableService:
_session: Session _session: Session
@ -34,6 +66,28 @@ class WorkflowDraftVariableService:
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
def get_draft_variables_by_selectors(
self,
app_id: str,
selectors: Sequence[list[str]],
) -> list[WorkflowDraftVariable]:
ors = []
for selector in selectors:
node_id, name = selector
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
# NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as
# each expression includes conditions on both `node_id` and `name` (which are covered by the unique index),
# PostgreSQL can efficiently retrieve the results using a bitmap index scan.
#
# Alternatively, a `SELECT` statement could be constructed for each selector and
# combined using `UNION` to fetch all rows.
# Benchmarking indicates that both approaches yield comparable performance.
variables = (
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all()
)
return variables
def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]): def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]):
variable_builder = _DraftVariableBuilder(app_id=app_id) variable_builder = _DraftVariableBuilder(app_id=app_id)
variable_builder.build(node_id=node_id, node_type=node_type, output=output) variable_builder.build(node_id=node_id, node_type=node_type, output=output)

View File

@ -3,7 +3,6 @@ import logging
import time import time
from collections.abc import Callable, Generator, Sequence from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from inspect import isgenerator
from typing import Any, Optional from typing import Any, Optional
from uuid import uuid4 from uuid import uuid4
@ -28,6 +27,7 @@ from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_M
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from libs import gen_utils
from models.account import Account from models.account import Account
from models.model import App, AppMode from models.model import App, AppMode
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
@ -42,7 +42,11 @@ from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter from services.workflow.workflow_converter import WorkflowConverter
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .workflow_draft_variable_service import WorkflowDraftVariableService, should_save_output_variables_for_draft from .workflow_draft_variable_service import (
DraftVarLoader,
WorkflowDraftVariableService,
should_save_output_variables_for_draft,
)
class WorkflowService: class WorkflowService:
@ -310,26 +314,28 @@ class WorkflowService:
if not draft_workflow: if not draft_workflow:
raise ValueError("Workflow not initialized") raise ValueError("Workflow not initialized")
# conv_vars = common_helpers.get_conversation_variables() # TODO(QuantumGhost): We may get rid of the `list_conversation_variables`
# here, and rely on `DraftVarLoader` to load conversation variables.
# run draft workflow node
start_at = time.perf_counter()
with Session(bind=db.engine) as session: with Session(bind=db.engine) as session:
# TODO(QunatumGhost): inject conversation variables
# to variable pool.
draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv = WorkflowDraftVariableService(session)
conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id) conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id)
conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables} conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables}
node_execution = self._handle_node_run_result( variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id)
invoke_node_fn=lambda: WorkflowEntry.single_step_run( run = WorkflowEntry.single_step_run(
workflow=draft_workflow, workflow=draft_workflow,
node_id=node_id, node_id=node_id,
user_inputs=user_inputs, user_inputs=user_inputs,
user_id=account.id, user_id=account.id,
conversation_variables=conv_var_mapping, conversation_variables=conv_var_mapping,
), variable_loader=variable_loader,
)
# run draft workflow node
start_at = time.perf_counter()
node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: run,
start_at=start_at, start_at=start_at,
node_id=node_id, node_id=node_id,
) )
@ -403,7 +409,7 @@ class WorkflowService:
) -> NodeExecution: ) -> NodeExecution:
try: try:
node_instance, generator = invoke_node_fn() node_instance, generator = invoke_node_fn()
generator = _inspect_generator(generator) generator = gen_utils.inspect(generator, logging.getLogger(__name__))
node_run_result: NodeRunResult | None = None node_run_result: NodeRunResult | None = None
for event in generator: for event in generator:
@ -610,19 +616,3 @@ class WorkflowService:
session.delete(workflow) session.delete(workflow)
return True return True
def _inspect_generator(gen: Generator[Any] | Any) -> Any:
if not isgenerator(gen):
return gen
def wrapper():
for item in gen:
logging.getLogger(__name__).info(
"received generator item, type=%s, value=%s",
type(item),
item,
)
yield item
return wrapper()