diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index c67b897f81..26cd3dd90b 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -93,7 +93,7 @@ class DraftRagPipelineApi(Resource):
             parser.add_argument("hash", type=str, required=False, location="json")
             parser.add_argument("environment_variables", type=list, required=False, location="json")
             parser.add_argument("conversation_variables", type=list, required=False, location="json")
-            parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json")
+            parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
             args = parser.parse_args()
         elif "text/plain" in content_type:
             try:
@@ -101,8 +101,8 @@ class DraftRagPipelineApi(Resource):
                 if "graph" not in data or "features" not in data:
                     raise ValueError("graph or features not found in data")
 
-                if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
-                    raise ValueError("graph or features is not a dict")
+                if not isinstance(data.get("graph"), dict):
+                    raise ValueError("graph is not a dict")
 
                 args = {
                     "graph": data.get("graph"),
@@ -129,11 +129,10 @@ class DraftRagPipelineApi(Resource):
         conversation_variables = [
            variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
         ]
-        rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {}
-        rag_pipeline_variables = {
-            k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
-            for k, v in rag_pipeline_variables_list.items()
-        }
+        rag_pipeline_variables_list = args.get("rag_pipeline_variables") or []
+        rag_pipeline_variables = [
+            variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list
+        ]
         rag_pipeline_service = RagPipelineService()
         workflow = rag_pipeline_service.sync_draft_workflow(
             pipeline=pipeline,
@@ -634,12 +633,17 @@ class RagPipelineSecondStepApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-        datasource_provider = request.args.get("datasource_provider", required=True, type=str)
+        node_id = request.args.get("node_id", type=str)
+        if not node_id:
+            raise ValueError("node_id is required")
 
         rag_pipeline_service = RagPipelineService()
-        return rag_pipeline_service.get_second_step_parameters(
-            pipeline=pipeline, datasource_provider=datasource_provider
-        )
+        variables = rag_pipeline_service.get_second_step_parameters(
+            pipeline=pipeline, node_id=node_id
+        )
+        return {
+            "variables": variables,
+        }
 
 
 class RagPipelineWorkflowRunListApi(Resource):
@@ -785,3 +789,7 @@ api.add_resource(
     DatasourceListApi,
     "/rag/pipelines/datasource-plugins",
 )
+api.add_resource(
+    RagPipelineSecondStepApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/processing/parameters",
+)
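Note on the controller change above: `rag_pipeline_variables` is now a flat list of variable mappings rather than a dict grouped by datasource provider, and the second-step endpoint selects variables by `node_id`. A minimal sketch of a draft-sync payload under the new list shape (illustrative only; the variable names and node ids are invented, and the keys follow `build_pipeline_variable_from_mapping` plus the `pipeline_variable_fields` serializer later in this diff):

    import json

    # "variable" is the selector key required by the factory; "type" and
    # "value" follow the SegmentType string conventions ("number", "secret", ...).
    draft_payload = {
        "graph": {"nodes": [], "edges": []},
        "features": {},
        "rag_pipeline_variables": [
            {"variable": "chunk_size", "type": "number", "value": 512, "belong_to_node_id": "shared"},
            {"variable": "api_key", "type": "secret", "value": "sk-xxx", "belong_to_node_id": "node-1"},
        ],
    }
    print(json.dumps(draft_payload, ensure_ascii=False))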
diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py
index 21bc39d440..1063e66c59 100644
--- a/api/core/callback_handler/agent_tool_callback_handler.py
+++ b/api/core/callback_handler/agent_tool_callback_handler.py
@@ -4,7 +4,6 @@ from typing import Any, Optional, TextIO, Union
 from pydantic import BaseModel
 
 from configs import dify_config
-from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.entities.tool_entities import ToolInvokeMessage
@@ -114,35 +113,6 @@ class DifyAgentCallbackHandler(BaseModel):
             color=self.color,
         )
 
-    def on_datasource_end(
-        self,
-        datasource_name: str,
-        datasource_inputs: Mapping[str, Any],
-        datasource_outputs: Iterable[DatasourceInvokeMessage] | str,
-        message_id: Optional[str] = None,
-        timer: Optional[Any] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
-    ) -> None:
-        """Run on datasource end."""
-        if dify_config.DEBUG:
-            print_text("\n[on_datasource_end]\n", color=self.color)
-            print_text("Datasource: " + datasource_name + "\n", color=self.color)
-            print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color)
-            print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color)
-            print_text("\n")
-
-        if trace_manager:
-            trace_manager.add_trace_task(
-                TraceTask(
-                    TraceTaskName.DATASOURCE_TRACE,
-                    message_id=message_id,
-                    datasource_name=datasource_name,
-                    datasource_inputs=datasource_inputs,
-                    datasource_outputs=datasource_outputs,
-                    timer=timer,
-                )
-            )
-
     @property
     def ignore_agent(self) -> bool:
         """Whether to ignore agent callbacks."""
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index e1bcbc323b..25d7c1c352 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity):
 
 
 class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
-    datasources: list[DatasourceEntity] = Field(default_factory=list)
+    datasources: list[DatasourceEntity] = Field(default_factory=list)
 
 
 class DatasourceInvokeMeta(BaseModel):
diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py
index 05661a6cc8..6b2c91a8a0 100644
--- a/api/core/workflow/nodes/knowledge_index/entities.py
+++ b/api/core/workflow/nodes/knowledge_index/entities.py
@@ -127,7 +127,7 @@ class GeneralStructureChunk(BaseModel):
     General Structure Chunk.
     """
 
-    general_chunk: list[str]
+    general_chunks: list[str]
     data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
 
 
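The `general_chunk` to `general_chunks` rename above is a breaking change for any payload serialized under the old field name. A standalone sketch (a simplified stand-in model, not the real Dify entities) of how validation behaves after the rename:

    from pydantic import BaseModel, ValidationError

    class GeneralStructureChunk(BaseModel):
        general_chunks: list[str]  # renamed from general_chunk

    # Old-style payloads now fail validation with a missing-field error...
    try:
        GeneralStructureChunk.model_validate({"general_chunk": ["a", "b"]})
    except ValidationError as exc:
        print(exc.error_count(), "validation error")

    # ...while the new field name round-trips cleanly.
    print(GeneralStructureChunk.model_validate({"general_chunks": ["a", "b"]}))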
+ """ + if (type := mapping.get("type")) is None: + raise VariableError("missing type") + if (value := mapping.get("value")) is None: + raise VariableError("missing value") + # FIXME: using Any here, fix it later + result: Any + match type: + case SegmentType.STRING: + result = StringVariable.model_validate(mapping) + case SegmentType.SECRET: + result = SecretVariable.model_validate(mapping) + case SegmentType.NUMBER if isinstance(value, int): + result = IntegerVariable.model_validate(mapping) + case SegmentType.NUMBER if isinstance(value, float): + result = FloatVariable.model_validate(mapping) + case SegmentType.NUMBER if not isinstance(value, float | int): + raise VariableError(f"invalid number value {value}") + case SegmentType.OBJECT if isinstance(value, dict): + result = ObjectVariable.model_validate(mapping) + case SegmentType.ARRAY_STRING if isinstance(value, list): + result = ArrayStringVariable.model_validate(mapping) + case SegmentType.ARRAY_NUMBER if isinstance(value, list): + result = ArrayNumberVariable.model_validate(mapping) + case SegmentType.ARRAY_OBJECT if isinstance(value, list): + result = ArrayObjectVariable.model_validate(mapping) + case _: + raise VariableError(f"not supported type {type}") + if result.size > dify_config.MAX_VARIABLE_SIZE: + raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") + if not result.selector: + result = result.model_copy(update={"selector": selector}) + return cast(Variable, result) def build_segment(value: Any, /) -> Segment: if value is None: diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index a37ae7856d..0733192c4f 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -42,9 +42,19 @@ conversation_variable_fields = { pipeline_variable_fields = { "id": fields.String, - "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), - "value": fields.Raw, + "label": fields.String, + "variable": fields.String, + "type": fields.String(attribute="type.value"), + "belong_to_node_id": fields.String, + "max_length": fields.Integer, + "required": fields.Boolean, + "default_value": fields.Raw, + "options": fields.List(fields.String), + "placeholder": fields.String, + "tooltips": fields.String, + "allowed_file_types": fields.List(fields.String), + "allow_file_extension": fields.List(fields.String), + "allow_file_upload_methods": fields.List(fields.String), } workflow_fields = { @@ -62,6 +72,7 @@ workflow_fields = { "tool_published": fields.Boolean, "environment_variables": fields.List(EnvironmentVariableField()), "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), + "rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)), } workflow_partial_fields = { diff --git a/api/models/workflow.py b/api/models/workflow.py index 5cb413b6a6..4ab59b26a6 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -352,21 +352,19 @@ class Workflow(Base): ) @property - def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]: + def rag_pipeline_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._conversation_variables` when instance created. 
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 5cb413b6a6..4ab59b26a6 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -352,21 +352,19 @@ class Workflow(Base):
     )
 
     @property
-    def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]:
+    def rag_pipeline_variables(self) -> Sequence[Variable]:
         # TODO: find some way to init `self._rag_pipeline_variables` when instance created.
         if self._rag_pipeline_variables is None:
             self._rag_pipeline_variables = "{}"
 
         variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
-        results = {}
-        for k, v in variables_dict.items():
-            results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
+        results = [variable_factory.build_pipeline_variable_from_mapping(v) for v in variables_dict.values()]
         return results
 
     @rag_pipeline_variables.setter
-    def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
+    def rag_pipeline_variables(self, values: Sequence[Variable]) -> None:
         self._rag_pipeline_variables = json.dumps(
-            {k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
+            {item.name: item.model_dump() for item in values},
             ensure_ascii=False,
         )
 
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index f380bc32d7..d2fc4d8100 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -201,7 +201,7 @@ class RagPipelineService:
         account: Account,
         environment_variables: Sequence[Variable],
         conversation_variables: Sequence[Variable],
-        rag_pipeline_variables: dict[str, Sequence[Variable]],
+        rag_pipeline_variables: Sequence[Variable],
     ) -> Workflow:
         """
         Sync draft workflow
@@ -552,7 +552,7 @@ class RagPipelineService:
 
         return workflow
 
-    def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> dict:
+    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list:
         """
         Get second step parameters of rag pipeline
         """
@@ -562,13 +562,16 @@ class RagPipelineService:
         if not workflow:
             raise ValueError("Workflow not initialized")
         # get second step node
-        pipeline_variables = workflow.pipeline_variables
-        if not pipeline_variables:
-            return {}
+        rag_pipeline_variables = workflow.rag_pipeline_variables
+        if not rag_pipeline_variables:
+            return []
+
         # get datasource provider
-        datasource_provider_variables = pipeline_variables.get(datasource_provider, [])
-        shared_variables = pipeline_variables.get("shared", [])
-        return datasource_provider_variables + shared_variables
+        datasource_provider_variables = [
+            item for item in rag_pipeline_variables
+            if item.get("belong_to_node_id") in (node_id, "shared")
+        ]
+        return datasource_provider_variables
 
     def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
         """
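The model and service changes above form a round trip: the setter stores variables as a JSON object keyed by name, the property rebuilds a flat list, and the service filters that list by node. A standalone sketch of the storage shape (plain dicts stand in for `Variable.model_dump()` output; the names and node ids are invented):

    import json

    variables = [
        {"name": "chunk_size", "value": 512, "belong_to_node_id": "shared"},
        {"name": "api_key", "value": "***", "belong_to_node_id": "node-1"},
    ]

    # Setter side: keyed by name, so two variables sharing a name collapse to one.
    stored = json.dumps({v["name"]: v for v in variables}, ensure_ascii=False)

    # Property side: dict values preserve insertion order.
    restored = list(json.loads(stored).values())

    # Service side: a node sees its own variables plus the "shared" ones.
    node_vars = [v for v in restored if v.get("belong_to_node_id") in ("node-1", "shared")]
    print([v["name"] for v in node_vars])  # ['chunk_size', 'api_key']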