mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-05 21:35:11 +08:00
feat(api): Implement VariableLoader in DraftVarLoader
.
Use DraftVarLoader to load required variables when single stepping a node.
This commit is contained in:
parent
141f9b4d51
commit
defe8fea63
@ -2,12 +2,13 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||||
from typing import Any, Optional, TypeAlias, TypeVar, cast
|
from typing import Any, Optional, TypeAlias, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
from core.variables import Variable
|
||||||
from core.workflow.callbacks import WorkflowCallback
|
from core.workflow.callbacks import WorkflowCallback
|
||||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
@ -22,13 +23,27 @@ from core.workflow.nodes import NodeType
|
|||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
|
from core.workflow.variable_loader import VariableLoader
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
|
from libs import gen_utils
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowType,
|
WorkflowType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyVariableLoader(VariableLoader):
|
||||||
|
"""A dummy implementation of VariableLoader that does not load any variables.
|
||||||
|
Serves as a placeholder when no variable loading is needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
_DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -122,6 +137,7 @@ class WorkflowEntry:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
user_inputs: dict,
|
user_inputs: dict,
|
||||||
conversation_variables: dict | None = None,
|
conversation_variables: dict | None = None,
|
||||||
|
variable_loader: VariableLoader = _DUMMY_VARIABLE_LOADER,
|
||||||
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
|
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
|
||||||
"""
|
"""
|
||||||
Single step run workflow node
|
Single step run workflow node
|
||||||
@ -190,6 +206,19 @@ class WorkflowEntry:
|
|||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
variable_mapping = {}
|
variable_mapping = {}
|
||||||
|
|
||||||
|
# Loading missing variable from draft var here, and set it into
|
||||||
|
# variable_pool.
|
||||||
|
variables_to_load: list[list[str]] = []
|
||||||
|
for key, selector in variable_mapping.items():
|
||||||
|
trimmed_key = key.removeprefix(f"{node_id}.")
|
||||||
|
if trimmed_key in user_inputs:
|
||||||
|
continue
|
||||||
|
if variable_pool.get(selector) is None:
|
||||||
|
variables_to_load.append(list(selector))
|
||||||
|
loaded = variable_loader.load_variables(variables_to_load)
|
||||||
|
for var in loaded:
|
||||||
|
variable_pool.add(var.selector, var.value)
|
||||||
|
|
||||||
cls.mapping_user_inputs_to_variable_pool(
|
cls.mapping_user_inputs_to_variable_pool(
|
||||||
variable_mapping=variable_mapping,
|
variable_mapping=variable_mapping,
|
||||||
user_inputs=user_inputs,
|
user_inputs=user_inputs,
|
||||||
@ -204,7 +233,7 @@ class WorkflowEntry:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
||||||
if metadata_attacher:
|
if metadata_attacher:
|
||||||
generator = _wrap_generator(generator, metadata_attacher)
|
generator = gen_utils.map_(generator, metadata_attacher)
|
||||||
return node_instance, generator
|
return node_instance, generator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -391,18 +420,6 @@ class WorkflowEntry:
|
|||||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||||
|
|
||||||
|
|
||||||
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
|
|
||||||
_YieldR_co = TypeVar("_YieldR_co", covariant=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _wrap_generator(
|
|
||||||
gen: Generator[_YieldT_co, None, None],
|
|
||||||
mapper: Callable[[_YieldT_co], _YieldR_co],
|
|
||||||
) -> Generator[_YieldR_co, None, None]:
|
|
||||||
for item in gen:
|
|
||||||
yield mapper(item)
|
|
||||||
|
|
||||||
|
|
||||||
_NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent
|
_NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,17 +3,19 @@ import logging
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import orm
|
from sqlalchemy import Engine, orm
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy.sql.expression import and_, or_
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file.constants import is_dummy_output_variable
|
from core.variables import Segment, Variable
|
||||||
from core.variables import Segment
|
|
||||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
|
from core.workflow.variable_loader import VariableLoader
|
||||||
from factories import variable_factory
|
from factories import variable_factory
|
||||||
|
from factories.variable_factory import build_segment, segment_to_variable
|
||||||
from models.workflow import WorkflowDraftVariable, is_system_variable_editable
|
from models.workflow import WorkflowDraftVariable, is_system_variable_editable
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
@ -25,6 +27,36 @@ class WorkflowDraftVariableList:
|
|||||||
total: int | None = None
|
total: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DraftVarLoader(VariableLoader):
|
||||||
|
# This implements the VariableLoader interface for loading draft variables.
|
||||||
|
#
|
||||||
|
# ref: core.workflow.variable_loader.VariableLoader
|
||||||
|
def __init__(self, engine: Engine, app_id: str) -> None:
|
||||||
|
self._engine = engine
|
||||||
|
self._app_id = app_id
|
||||||
|
|
||||||
|
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||||
|
if not selectors:
|
||||||
|
return []
|
||||||
|
with Session(bind=self._engine, expire_on_commit=False) as session:
|
||||||
|
srv = WorkflowDraftVariableService(session)
|
||||||
|
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
|
||||||
|
variables = []
|
||||||
|
for draft_var in draft_vars:
|
||||||
|
segment = build_segment(
|
||||||
|
draft_var.value,
|
||||||
|
)
|
||||||
|
variable = segment_to_variable(
|
||||||
|
segment=segment,
|
||||||
|
selector=draft_var.get_selector(),
|
||||||
|
id=draft_var.id,
|
||||||
|
name=draft_var.name,
|
||||||
|
description=draft_var.description,
|
||||||
|
)
|
||||||
|
variables.append(variable)
|
||||||
|
return variables
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDraftVariableService:
|
class WorkflowDraftVariableService:
|
||||||
_session: Session
|
_session: Session
|
||||||
|
|
||||||
@ -34,6 +66,28 @@ class WorkflowDraftVariableService:
|
|||||||
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
|
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
|
||||||
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
|
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
|
||||||
|
|
||||||
|
def get_draft_variables_by_selectors(
|
||||||
|
self,
|
||||||
|
app_id: str,
|
||||||
|
selectors: Sequence[list[str]],
|
||||||
|
) -> list[WorkflowDraftVariable]:
|
||||||
|
ors = []
|
||||||
|
for selector in selectors:
|
||||||
|
node_id, name = selector
|
||||||
|
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
|
||||||
|
|
||||||
|
# NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as
|
||||||
|
# each expression includes conditions on both `node_id` and `name` (which are covered by the unique index),
|
||||||
|
# PostgreSQL can efficiently retrieve the results using a bitmap index scan.
|
||||||
|
#
|
||||||
|
# Alternatively, a `SELECT` statement could be constructed for each selector and
|
||||||
|
# combined using `UNION` to fetch all rows.
|
||||||
|
# Benchmarking indicates that both approaches yield comparable performance.
|
||||||
|
variables = (
|
||||||
|
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all()
|
||||||
|
)
|
||||||
|
return variables
|
||||||
|
|
||||||
def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]):
|
def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]):
|
||||||
variable_builder = _DraftVariableBuilder(app_id=app_id)
|
variable_builder = _DraftVariableBuilder(app_id=app_id)
|
||||||
variable_builder.build(node_id=node_id, node_type=node_type, output=output)
|
variable_builder.build(node_id=node_id, node_type=node_type, output=output)
|
||||||
|
@ -3,7 +3,6 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Generator, Sequence
|
from collections.abc import Callable, Generator, Sequence
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from inspect import isgenerator
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@ -28,6 +27,7 @@ from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_M
|
|||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from libs import gen_utils
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import App, AppMode
|
from models.model import App, AppMode
|
||||||
from models.tools import WorkflowToolProvider
|
from models.tools import WorkflowToolProvider
|
||||||
@ -42,7 +42,11 @@ from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
|
|||||||
from services.workflow.workflow_converter import WorkflowConverter
|
from services.workflow.workflow_converter import WorkflowConverter
|
||||||
|
|
||||||
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||||
from .workflow_draft_variable_service import WorkflowDraftVariableService, should_save_output_variables_for_draft
|
from .workflow_draft_variable_service import (
|
||||||
|
DraftVarLoader,
|
||||||
|
WorkflowDraftVariableService,
|
||||||
|
should_save_output_variables_for_draft,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowService:
|
class WorkflowService:
|
||||||
@ -310,26 +314,28 @@ class WorkflowService:
|
|||||||
if not draft_workflow:
|
if not draft_workflow:
|
||||||
raise ValueError("Workflow not initialized")
|
raise ValueError("Workflow not initialized")
|
||||||
|
|
||||||
# conv_vars = common_helpers.get_conversation_variables()
|
# TODO(QuantumGhost): We may get rid of the `list_conversation_variables`
|
||||||
|
# here, and rely on `DraftVarLoader` to load conversation variables.
|
||||||
# run draft workflow node
|
|
||||||
start_at = time.perf_counter()
|
|
||||||
with Session(bind=db.engine) as session:
|
with Session(bind=db.engine) as session:
|
||||||
# TODO(QunatumGhost): inject conversation variables
|
|
||||||
# to variable pool.
|
|
||||||
draft_var_srv = WorkflowDraftVariableService(session)
|
draft_var_srv = WorkflowDraftVariableService(session)
|
||||||
|
|
||||||
conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id)
|
conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id)
|
||||||
conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables}
|
conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables}
|
||||||
|
|
||||||
|
variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id)
|
||||||
|
run = WorkflowEntry.single_step_run(
|
||||||
|
workflow=draft_workflow,
|
||||||
|
node_id=node_id,
|
||||||
|
user_inputs=user_inputs,
|
||||||
|
user_id=account.id,
|
||||||
|
conversation_variables=conv_var_mapping,
|
||||||
|
variable_loader=variable_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
# run draft workflow node
|
||||||
|
start_at = time.perf_counter()
|
||||||
node_execution = self._handle_node_run_result(
|
node_execution = self._handle_node_run_result(
|
||||||
invoke_node_fn=lambda: WorkflowEntry.single_step_run(
|
invoke_node_fn=lambda: run,
|
||||||
workflow=draft_workflow,
|
|
||||||
node_id=node_id,
|
|
||||||
user_inputs=user_inputs,
|
|
||||||
user_id=account.id,
|
|
||||||
conversation_variables=conv_var_mapping,
|
|
||||||
),
|
|
||||||
start_at=start_at,
|
start_at=start_at,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
@ -403,7 +409,7 @@ class WorkflowService:
|
|||||||
) -> NodeExecution:
|
) -> NodeExecution:
|
||||||
try:
|
try:
|
||||||
node_instance, generator = invoke_node_fn()
|
node_instance, generator = invoke_node_fn()
|
||||||
generator = _inspect_generator(generator)
|
generator = gen_utils.inspect(generator, logging.getLogger(__name__))
|
||||||
|
|
||||||
node_run_result: NodeRunResult | None = None
|
node_run_result: NodeRunResult | None = None
|
||||||
for event in generator:
|
for event in generator:
|
||||||
@ -610,19 +616,3 @@ class WorkflowService:
|
|||||||
|
|
||||||
session.delete(workflow)
|
session.delete(workflow)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _inspect_generator(gen: Generator[Any] | Any) -> Any:
|
|
||||||
if not isgenerator(gen):
|
|
||||||
return gen
|
|
||||||
|
|
||||||
def wrapper():
|
|
||||||
for item in gen:
|
|
||||||
logging.getLogger(__name__).info(
|
|
||||||
"received generator item, type=%s, value=%s",
|
|
||||||
type(item),
|
|
||||||
item,
|
|
||||||
)
|
|
||||||
yield item
|
|
||||||
|
|
||||||
return wrapper()
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user