This commit is contained in:
jyong 2025-05-20 15:18:33 +08:00
parent ba52bf27c1
commit a64df507f6
5 changed files with 9 additions and 51 deletions

View File

@ -129,9 +129,6 @@ class DraftRagPipelineApi(Resource):
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_variables_list = args.get("rag_pipeline_variables") or []
rag_pipeline_variables = [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list]
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
@ -140,7 +137,7 @@ class DraftRagPipelineApi(Resource):
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables,
rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()

View File

@ -82,7 +82,7 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if not mapping.get("variable"):
raise VariableError("missing variable")
return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["variable"]])
return mapping["variable"]
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
@ -123,44 +123,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = result.model_copy(update={"selector": selector})
return cast(Variable, result)
def _build_rag_pipeline_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
"""
This factory function is used to create the rag pipeline variable,
not support the File type.
"""
if (type := mapping.get("type")) is None:
raise VariableError("missing type")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
# FIXME: using Any here, fix it later
result: Any
match type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
result = ArrayStringVariable.model_validate(mapping)
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case _:
raise VariableError(f"not supported type {type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
return cast(Variable, result)
def build_segment(value: Any, /) -> Segment:
if value is None:
return NoneSegment()

View File

@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from typing import TYPE_CHECKING, Any, List, Optional, Self, Union
from uuid import uuid4
if TYPE_CHECKING:
@ -331,6 +331,7 @@ class Workflow(Base):
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables],
}
return result
@ -358,13 +359,13 @@ class Workflow(Base):
self._rag_pipeline_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
results = [variable_factory.build_pipeline_variable_from_mapping(v) for v in variables_dict.values()]
results = [v for v in variables_dict.values()]
return results
@rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: Sequence[Variable]) -> None:
def rag_pipeline_variables(self, values: List[dict]) -> None:
self._rag_pipeline_variables = json.dumps(
{item.name: item.model_dump() for item in values},
{item["variable"]: item for item in values},
ensure_ascii=False,
)

View File

@ -201,7 +201,7 @@ class RagPipelineService:
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
rag_pipeline_variables: Sequence[Variable],
rag_pipeline_variables: list,
) -> Workflow:
"""
Sync draft workflow

View File

@ -578,9 +578,6 @@ class RagPipelineDslService:
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
rag_pipeline_variables = [
variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list
]
rag_pipeline_service = RagPipelineService()
current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@ -610,6 +607,7 @@ class RagPipelineDslService:
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,
)
return pipeline