mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-05 16:25:12 +08:00
feat(api): Implement VariableLoader in DraftVarLoader
.
Use DraftVarLoader to load required variables when single stepping a node.
This commit is contained in:
parent
141f9b4d51
commit
defe8fea63
@ -2,12 +2,13 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, TypeAlias, TypeVar, cast
|
||||
from typing import Any, Optional, TypeAlias, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.variables import Variable
|
||||
from core.workflow.callbacks import WorkflowCallback
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
@ -22,13 +23,27 @@ from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from factories import file_factory
|
||||
from libs import gen_utils
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
|
||||
class _DummyVariableLoader(VariableLoader):
|
||||
"""A dummy implementation of VariableLoader that does not load any variables.
|
||||
Serves as a placeholder when no variable loading is needed.
|
||||
"""
|
||||
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
return []
|
||||
|
||||
|
||||
_DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -122,6 +137,7 @@ class WorkflowEntry:
|
||||
user_id: str,
|
||||
user_inputs: dict,
|
||||
conversation_variables: dict | None = None,
|
||||
variable_loader: VariableLoader = _DUMMY_VARIABLE_LOADER,
|
||||
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
@ -190,6 +206,19 @@ class WorkflowEntry:
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
# Loading missing variable from draft var here, and set it into
|
||||
# variable_pool.
|
||||
variables_to_load: list[list[str]] = []
|
||||
for key, selector in variable_mapping.items():
|
||||
trimmed_key = key.removeprefix(f"{node_id}.")
|
||||
if trimmed_key in user_inputs:
|
||||
continue
|
||||
if variable_pool.get(selector) is None:
|
||||
variables_to_load.append(list(selector))
|
||||
loaded = variable_loader.load_variables(variables_to_load)
|
||||
for var in loaded:
|
||||
variable_pool.add(var.selector, var.value)
|
||||
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
@ -204,7 +233,7 @@ class WorkflowEntry:
|
||||
except Exception as e:
|
||||
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
||||
if metadata_attacher:
|
||||
generator = _wrap_generator(generator, metadata_attacher)
|
||||
generator = gen_utils.map_(generator, metadata_attacher)
|
||||
return node_instance, generator
|
||||
|
||||
@classmethod
|
||||
@ -391,18 +420,6 @@ class WorkflowEntry:
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
|
||||
|
||||
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
|
||||
_YieldR_co = TypeVar("_YieldR_co", covariant=True)
|
||||
|
||||
|
||||
def _wrap_generator(
|
||||
gen: Generator[_YieldT_co, None, None],
|
||||
mapper: Callable[[_YieldT_co], _YieldR_co],
|
||||
) -> Generator[_YieldR_co, None, None]:
|
||||
for item in gen:
|
||||
yield mapper(item)
|
||||
|
||||
|
||||
_NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent
|
||||
|
||||
|
||||
|
@ -3,17 +3,19 @@ import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import orm
|
||||
from sqlalchemy import Engine, orm
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import and_, or_
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.constants import is_dummy_output_variable
|
||||
from core.variables import Segment
|
||||
from core.variables import Segment, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from factories import variable_factory
|
||||
from factories.variable_factory import build_segment, segment_to_variable
|
||||
from models.workflow import WorkflowDraftVariable, is_system_variable_editable
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@ -25,6 +27,36 @@ class WorkflowDraftVariableList:
|
||||
total: int | None = None
|
||||
|
||||
|
||||
class DraftVarLoader(VariableLoader):
|
||||
# This implements the VariableLoader interface for loading draft variables.
|
||||
#
|
||||
# ref: core.workflow.variable_loader.VariableLoader
|
||||
def __init__(self, engine: Engine, app_id: str) -> None:
|
||||
self._engine = engine
|
||||
self._app_id = app_id
|
||||
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
if not selectors:
|
||||
return []
|
||||
with Session(bind=self._engine, expire_on_commit=False) as session:
|
||||
srv = WorkflowDraftVariableService(session)
|
||||
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
|
||||
variables = []
|
||||
for draft_var in draft_vars:
|
||||
segment = build_segment(
|
||||
draft_var.value,
|
||||
)
|
||||
variable = segment_to_variable(
|
||||
segment=segment,
|
||||
selector=draft_var.get_selector(),
|
||||
id=draft_var.id,
|
||||
name=draft_var.name,
|
||||
description=draft_var.description,
|
||||
)
|
||||
variables.append(variable)
|
||||
return variables
|
||||
|
||||
|
||||
class WorkflowDraftVariableService:
|
||||
_session: Session
|
||||
|
||||
@ -34,6 +66,28 @@ class WorkflowDraftVariableService:
|
||||
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
|
||||
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
|
||||
|
||||
def get_draft_variables_by_selectors(
|
||||
self,
|
||||
app_id: str,
|
||||
selectors: Sequence[list[str]],
|
||||
) -> list[WorkflowDraftVariable]:
|
||||
ors = []
|
||||
for selector in selectors:
|
||||
node_id, name = selector
|
||||
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
|
||||
|
||||
# NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as
|
||||
# each expression includes conditions on both `node_id` and `name` (which are covered by the unique index),
|
||||
# PostgreSQL can efficiently retrieve the results using a bitmap index scan.
|
||||
#
|
||||
# Alternatively, a `SELECT` statement could be constructed for each selector and
|
||||
# combined using `UNION` to fetch all rows.
|
||||
# Benchmarking indicates that both approaches yield comparable performance.
|
||||
variables = (
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all()
|
||||
)
|
||||
return variables
|
||||
|
||||
def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]):
|
||||
variable_builder = _DraftVariableBuilder(app_id=app_id)
|
||||
variable_builder.build(node_id=node_id, node_type=node_type, output=output)
|
||||
|
@ -3,7 +3,6 @@ import logging
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from inspect import isgenerator
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
@ -28,6 +27,7 @@ from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_M
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs import gen_utils
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
@ -42,7 +42,11 @@ from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
from .workflow_draft_variable_service import WorkflowDraftVariableService, should_save_output_variables_for_draft
|
||||
from .workflow_draft_variable_service import (
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
should_save_output_variables_for_draft,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
@ -310,26 +314,28 @@ class WorkflowService:
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
# conv_vars = common_helpers.get_conversation_variables()
|
||||
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
# TODO(QuantumGhost): We may get rid of the `list_conversation_variables`
|
||||
# here, and rely on `DraftVarLoader` to load conversation variables.
|
||||
with Session(bind=db.engine) as session:
|
||||
# TODO(QunatumGhost): inject conversation variables
|
||||
# to variable pool.
|
||||
draft_var_srv = WorkflowDraftVariableService(session)
|
||||
|
||||
conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id)
|
||||
conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables}
|
||||
|
||||
variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id)
|
||||
run = WorkflowEntry.single_step_run(
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
user_id=account.id,
|
||||
conversation_variables=conv_var_mapping,
|
||||
variable_loader=variable_loader,
|
||||
)
|
||||
|
||||
# run draft workflow node
|
||||
start_at = time.perf_counter()
|
||||
node_execution = self._handle_node_run_result(
|
||||
invoke_node_fn=lambda: WorkflowEntry.single_step_run(
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
user_id=account.id,
|
||||
conversation_variables=conv_var_mapping,
|
||||
),
|
||||
invoke_node_fn=lambda: run,
|
||||
start_at=start_at,
|
||||
node_id=node_id,
|
||||
)
|
||||
@ -403,7 +409,7 @@ class WorkflowService:
|
||||
) -> NodeExecution:
|
||||
try:
|
||||
node_instance, generator = invoke_node_fn()
|
||||
generator = _inspect_generator(generator)
|
||||
generator = gen_utils.inspect(generator, logging.getLogger(__name__))
|
||||
|
||||
node_run_result: NodeRunResult | None = None
|
||||
for event in generator:
|
||||
@ -610,19 +616,3 @@ class WorkflowService:
|
||||
|
||||
session.delete(workflow)
|
||||
return True
|
||||
|
||||
|
||||
def _inspect_generator(gen: Generator[Any] | Any) -> Any:
|
||||
if not isgenerator(gen):
|
||||
return gen
|
||||
|
||||
def wrapper():
|
||||
for item in gen:
|
||||
logging.getLogger(__name__).info(
|
||||
"received generator item, type=%s, value=%s",
|
||||
type(item),
|
||||
item,
|
||||
)
|
||||
yield item
|
||||
|
||||
return wrapper()
|
||||
|
Loading…
x
Reference in New Issue
Block a user