diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 26cd3dd90b..fa4130b762 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -129,9 +129,6 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - rag_pipeline_variables_list = args.get("rag_pipeline_variables") or [] - rag_pipeline_variables = [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list] - rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, @@ -140,7 +137,7 @@ class DraftRagPipelineApi(Resource): account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, - rag_pipeline_variables=rag_pipeline_variables, + rag_pipeline_variables=args.get("rag_pipeline_variables") or [], ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 002833d786..69a786e2f5 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -82,7 +82,7 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: if not mapping.get("variable"): raise VariableError("missing variable") - return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["variable"]]) + return mapping["variable"] def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: @@ -123,44 +123,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = result.model_copy(update={"selector": selector}) return cast(Variable, result) -def _build_rag_pipeline_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: - """ - This factory function is used to create the rag pipeline variable, - not support the File type. - """ - if (type := mapping.get("type")) is None: - raise VariableError("missing type") - if (value := mapping.get("value")) is None: - raise VariableError("missing value") - # FIXME: using Any here, fix it later - result: Any - match type: - case SegmentType.STRING: - result = StringVariable.model_validate(mapping) - case SegmentType.SECRET: - result = SecretVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, int): - result = IntegerVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, float): - result = FloatVariable.model_validate(mapping) - case SegmentType.NUMBER if not isinstance(value, float | int): - raise VariableError(f"invalid number value {value}") - case SegmentType.OBJECT if isinstance(value, dict): - result = ObjectVariable.model_validate(mapping) - case SegmentType.ARRAY_STRING if isinstance(value, list): - result = ArrayStringVariable.model_validate(mapping) - case SegmentType.ARRAY_NUMBER if isinstance(value, list): - result = ArrayNumberVariable.model_validate(mapping) - case SegmentType.ARRAY_OBJECT if isinstance(value, list): - result = ArrayObjectVariable.model_validate(mapping) - case _: - raise VariableError(f"not supported type {type}") - if result.size > dify_config.MAX_VARIABLE_SIZE: - raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") - if not result.selector: - result = result.model_copy(update={"selector": selector}) - return cast(Variable, result) - def build_segment(value: Any, /) -> Segment: if value is None: return NoneSegment() diff --git a/api/models/workflow.py b/api/models/workflow.py index 4ab59b26a6..f04cafe3ed 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,7 +2,7 @@ import json from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, Optional, Self, Union +from typing import TYPE_CHECKING, Any, List, Optional, Self, Union from uuid import uuid4 if TYPE_CHECKING: @@ -331,6 +331,7 @@ class Workflow(Base): "features": self.features_dict, "environment_variables": [var.model_dump(mode="json") for var in environment_variables], "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], + "rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables], } return result @@ -358,13 +359,13 @@ class Workflow(Base): self._rag_pipeline_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) - results = [variable_factory.build_pipeline_variable_from_mapping(v) for v in variables_dict.values()] + results = [v for v in variables_dict.values()] return results @rag_pipeline_variables.setter - def rag_pipeline_variables(self, values: Sequence[Variable]) -> None: + def rag_pipeline_variables(self, values: List[dict]) -> None: self._rag_pipeline_variables = json.dumps( - {item.name: item.model_dump() for item in values}, + {item["variable"]: item for item in values}, ensure_ascii=False, ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index d2fc4d8100..63b5c9983c 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -201,7 +201,7 @@ class RagPipelineService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], - rag_pipeline_variables: Sequence[Variable], + rag_pipeline_variables: list, ) -> Workflow: """ Sync draft workflow diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 1c6dac55be..19c7d37f6e 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -578,9 +578,6 @@ class RagPipelineDslService: variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) - rag_pipeline_variables = [ - variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list - ] rag_pipeline_service = RagPipelineService() current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) @@ -610,6 +607,7 @@ class RagPipelineDslService: account=account, environment_variables=environment_variables, conversation_variables=conversation_variables, + rag_pipeline_variables=rag_pipeline_variables_list, ) return pipeline