diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index b2cfa23aa2..4225ad7c44 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -2,12 +2,13 @@ import logging import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence -from typing import Any, Optional, TypeAlias, TypeVar, cast +from typing import Any, Optional, TypeAlias, cast from configs import dify_config from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File +from core.variables import Variable from core.workflow.callbacks import WorkflowCallback from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey @@ -22,13 +23,27 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.variable_loader import VariableLoader from factories import file_factory +from libs import gen_utils from models.enums import UserFrom from models.workflow import ( Workflow, WorkflowType, ) + +class _DummyVariableLoader(VariableLoader): + """A dummy implementation of VariableLoader that does not load any variables. + Serves as a placeholder when no variable loading is needed. + """ + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + return [] + + +_DUMMY_VARIABLE_LOADER = _DummyVariableLoader() + logger = logging.getLogger(__name__) @@ -122,6 +137,7 @@ class WorkflowEntry: user_id: str, user_inputs: dict, conversation_variables: dict | None = None, + variable_loader: VariableLoader = _DUMMY_VARIABLE_LOADER, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -190,6 +206,19 @@ class WorkflowEntry: except NotImplementedError: variable_mapping = {} + # Loading missing variable from draft var here, and set it into + # variable_pool. + variables_to_load: list[list[str]] = [] + for key, selector in variable_mapping.items(): + trimmed_key = key.removeprefix(f"{node_id}.") + if trimmed_key in user_inputs: + continue + if variable_pool.get(selector) is None: + variables_to_load.append(list(selector)) + loaded = variable_loader.load_variables(variables_to_load) + for var in loaded: + variable_pool.add(var.selector, var.value) + cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, @@ -204,7 +233,7 @@ class WorkflowEntry: except Exception as e: raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) if metadata_attacher: - generator = _wrap_generator(generator, metadata_attacher) + generator = gen_utils.map_(generator, metadata_attacher) return node_instance, generator @classmethod @@ -391,18 +420,6 @@ class WorkflowEntry: variable_pool.add([variable_node_id] + variable_key_list, input_value) -_YieldT_co = TypeVar("_YieldT_co", covariant=True) -_YieldR_co = TypeVar("_YieldR_co", covariant=True) - - -def _wrap_generator( - gen: Generator[_YieldT_co, None, None], - mapper: Callable[[_YieldT_co], _YieldR_co], -) -> Generator[_YieldR_co, None, None]: - for item in gen: - yield mapper(item) - - _NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 15ff43b6e5..a37efd9000 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -3,17 +3,19 @@ import logging from collections.abc import Mapping, Sequence from typing import Any -from sqlalchemy import orm +from sqlalchemy import Engine, orm from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import and_, or_ from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.constants import is_dummy_output_variable -from core.variables import Segment +from core.variables import Segment, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes import NodeType +from core.workflow.variable_loader import VariableLoader from factories import variable_factory +from factories.variable_factory import build_segment, segment_to_variable from models.workflow import WorkflowDraftVariable, is_system_variable_editable _logger = logging.getLogger(__name__) @@ -25,6 +27,36 @@ class WorkflowDraftVariableList: total: int | None = None +class DraftVarLoader(VariableLoader): + # This implements the VariableLoader interface for loading draft variables. + # + # ref: core.workflow.variable_loader.VariableLoader + def __init__(self, engine: Engine, app_id: str) -> None: + self._engine = engine + self._app_id = app_id + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + if not selectors: + return [] + with Session(bind=self._engine, expire_on_commit=False) as session: + srv = WorkflowDraftVariableService(session) + draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors) + variables = [] + for draft_var in draft_vars: + segment = build_segment( + draft_var.value, + ) + variable = segment_to_variable( + segment=segment, + selector=draft_var.get_selector(), + id=draft_var.id, + name=draft_var.name, + description=draft_var.description, + ) + variables.append(variable) + return variables + + class WorkflowDraftVariableService: _session: Session @@ -34,6 +66,28 @@ class WorkflowDraftVariableService: def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() + def get_draft_variables_by_selectors( + self, + app_id: str, + selectors: Sequence[list[str]], + ) -> list[WorkflowDraftVariable]: + ors = [] + for selector in selectors: + node_id, name = selector + ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) + + # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as + # each expression includes conditions on both `node_id` and `name` (which are covered by the unique index), + # PostgreSQL can efficiently retrieve the results using a bitmap index scan. + # + # Alternatively, a `SELECT` statement could be constructed for each selector and + # combined using `UNION` to fetch all rows. + # Benchmarking indicates that both approaches yield comparable performance. + variables = ( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all() + ) + return variables + def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]): variable_builder = _DraftVariableBuilder(app_id=app_id) variable_builder.build(node_id=node_id, node_type=node_type, output=output) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b44f2706a8..23e19235b6 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,7 +3,6 @@ import logging import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime -from inspect import isgenerator from typing import Any, Optional from uuid import uuid4 @@ -28,6 +27,7 @@ from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_M from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db +from libs import gen_utils from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -42,7 +42,11 @@ from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from .workflow_draft_variable_service import WorkflowDraftVariableService, should_save_output_variables_for_draft +from .workflow_draft_variable_service import ( + DraftVarLoader, + WorkflowDraftVariableService, + should_save_output_variables_for_draft, +) class WorkflowService: @@ -310,26 +314,28 @@ class WorkflowService: if not draft_workflow: raise ValueError("Workflow not initialized") - # conv_vars = common_helpers.get_conversation_variables() - - # run draft workflow node - start_at = time.perf_counter() + # TODO(QuantumGhost): We may get rid of the `list_conversation_variables` + # here, and rely on `DraftVarLoader` to load conversation variables. with Session(bind=db.engine) as session: - # TODO(QunatumGhost): inject conversation variables - # to variable pool. draft_var_srv = WorkflowDraftVariableService(session) conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id) conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables} + variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id) + run = WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + conversation_variables=conv_var_mapping, + variable_loader=variable_loader, + ) + + # run draft workflow node + start_at = time.perf_counter() node_execution = self._handle_node_run_result( - invoke_node_fn=lambda: WorkflowEntry.single_step_run( - workflow=draft_workflow, - node_id=node_id, - user_inputs=user_inputs, - user_id=account.id, - conversation_variables=conv_var_mapping, - ), + invoke_node_fn=lambda: run, start_at=start_at, node_id=node_id, ) @@ -403,7 +409,7 @@ class WorkflowService: ) -> NodeExecution: try: node_instance, generator = invoke_node_fn() - generator = _inspect_generator(generator) + generator = gen_utils.inspect(generator, logging.getLogger(__name__)) node_run_result: NodeRunResult | None = None for event in generator: @@ -610,19 +616,3 @@ class WorkflowService: session.delete(workflow) return True - - -def _inspect_generator(gen: Generator[Any] | Any) -> Any: - if not isgenerator(gen): - return gen - - def wrapper(): - for item in gen: - logging.getLogger(__name__).info( - "received generator item, type=%s, value=%s", - type(item), - item, - ) - yield item - - return wrapper()