This commit is contained in:
jyong 2025-05-20 15:18:33 +08:00
parent ba52bf27c1
commit a64df507f6
5 changed files with 9 additions and 51 deletions

View File

@ -129,9 +129,6 @@ class DraftRagPipelineApi(Resource):
conversation_variables = [ conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
] ]
rag_pipeline_variables_list = args.get("rag_pipeline_variables") or []
rag_pipeline_variables = [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list]
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow( workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline, pipeline=pipeline,
@ -140,7 +137,7 @@ class DraftRagPipelineApi(Resource):
account=current_user, account=current_user,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables, rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
) )
except WorkflowHashNotEqualError: except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync() raise DraftWorkflowNotSync()

View File

@ -82,7 +82,7 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if not mapping.get("variable"): if not mapping.get("variable"):
raise VariableError("missing variable") raise VariableError("missing variable")
return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["variable"]]) return mapping["variable"]
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
@ -123,44 +123,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = result.model_copy(update={"selector": selector}) result = result.model_copy(update={"selector": selector})
return cast(Variable, result) return cast(Variable, result)
def _build_rag_pipeline_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
"""
This factory function is used to create the rag pipeline variable,
not support the File type.
"""
if (type := mapping.get("type")) is None:
raise VariableError("missing type")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
# FIXME: using Any here, fix it later
result: Any
match type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
result = ArrayStringVariable.model_validate(mapping)
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case _:
raise VariableError(f"not supported type {type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
return cast(Variable, result)
def build_segment(value: Any, /) -> Segment: def build_segment(value: Any, /) -> Segment:
if value is None: if value is None:
return NoneSegment() return NoneSegment()

View File

@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Self, Union from typing import TYPE_CHECKING, Any, List, Optional, Self, Union
from uuid import uuid4 from uuid import uuid4
if TYPE_CHECKING: if TYPE_CHECKING:
@ -331,6 +331,7 @@ class Workflow(Base):
"features": self.features_dict, "features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables], "environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables],
} }
return result return result
@ -358,13 +359,13 @@ class Workflow(Base):
self._rag_pipeline_variables = "{}" self._rag_pipeline_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
results = [variable_factory.build_pipeline_variable_from_mapping(v) for v in variables_dict.values()] results = [v for v in variables_dict.values()]
return results return results
@rag_pipeline_variables.setter @rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: Sequence[Variable]) -> None: def rag_pipeline_variables(self, values: List[dict]) -> None:
self._rag_pipeline_variables = json.dumps( self._rag_pipeline_variables = json.dumps(
{item.name: item.model_dump() for item in values}, {item["variable"]: item for item in values},
ensure_ascii=False, ensure_ascii=False,
) )

View File

@ -201,7 +201,7 @@ class RagPipelineService:
account: Account, account: Account,
environment_variables: Sequence[Variable], environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[Variable],
rag_pipeline_variables: Sequence[Variable], rag_pipeline_variables: list,
) -> Workflow: ) -> Workflow:
""" """
Sync draft workflow Sync draft workflow

View File

@ -578,9 +578,6 @@ class RagPipelineDslService:
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
] ]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
rag_pipeline_variables = [
variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list
]
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@ -610,6 +607,7 @@ class RagPipelineDslService:
account=account, account=account,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,
) )
return pipeline return pipeline