mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 20:45:54 +08:00
r2
This commit is contained in:
parent
ba52bf27c1
commit
a64df507f6
@ -129,9 +129,6 @@ class DraftRagPipelineApi(Resource):
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
rag_pipeline_variables_list = args.get("rag_pipeline_variables") or []
|
||||
rag_pipeline_variables = [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list]
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
@ -140,7 +137,7 @@ class DraftRagPipelineApi(Resource):
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
rag_pipeline_variables=rag_pipeline_variables,
|
||||
rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
|
@ -82,7 +82,7 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
|
||||
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
if not mapping.get("variable"):
|
||||
raise VariableError("missing variable")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["variable"]])
|
||||
return mapping["variable"]
|
||||
|
||||
|
||||
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
||||
@ -123,44 +123,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
||||
result = result.model_copy(update={"selector": selector})
|
||||
return cast(Variable, result)
|
||||
|
||||
def _build_rag_pipeline_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
||||
"""
|
||||
This factory function is used to create the rag pipeline variable,
|
||||
not support the File type.
|
||||
"""
|
||||
if (type := mapping.get("type")) is None:
|
||||
raise VariableError("missing type")
|
||||
if (value := mapping.get("value")) is None:
|
||||
raise VariableError("missing value")
|
||||
# FIXME: using Any here, fix it later
|
||||
result: Any
|
||||
match type:
|
||||
case SegmentType.STRING:
|
||||
result = StringVariable.model_validate(mapping)
|
||||
case SegmentType.SECRET:
|
||||
result = SecretVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if isinstance(value, int):
|
||||
result = IntegerVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if isinstance(value, float):
|
||||
result = FloatVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||
raise VariableError(f"invalid number value {value}")
|
||||
case SegmentType.OBJECT if isinstance(value, dict):
|
||||
result = ObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||
result = ArrayStringVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
|
||||
result = ArrayNumberVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||
result = ArrayObjectVariable.model_validate(mapping)
|
||||
case _:
|
||||
raise VariableError(f"not supported type {type}")
|
||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
|
||||
if not result.selector:
|
||||
result = result.model_copy(update={"selector": selector})
|
||||
return cast(Variable, result)
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
if value is None:
|
||||
return NoneSegment()
|
||||
|
@ -2,7 +2,7 @@ import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Self, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Self, Union
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -331,6 +331,7 @@ class Workflow(Base):
|
||||
"features": self.features_dict,
|
||||
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
||||
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
|
||||
"rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables],
|
||||
}
|
||||
return result
|
||||
|
||||
@ -358,13 +359,13 @@ class Workflow(Base):
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||
results = [variable_factory.build_pipeline_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
results = [v for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: Sequence[Variable]) -> None:
|
||||
def rag_pipeline_variables(self, values: List[dict]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{item.name: item.model_dump() for item in values},
|
||||
{item["variable"]: item for item in values},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
@ -201,7 +201,7 @@ class RagPipelineService:
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
rag_pipeline_variables: Sequence[Variable],
|
||||
rag_pipeline_variables: list,
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
|
@ -578,9 +578,6 @@ class RagPipelineDslService:
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
|
||||
rag_pipeline_variables = [
|
||||
variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list
|
||||
]
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
@ -610,6 +607,7 @@ class RagPipelineDslService:
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
rag_pipeline_variables=rag_pipeline_variables_list,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
Loading…
x
Reference in New Issue
Block a user