mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 05:55:52 +08:00
r2
This commit is contained in:
parent
c5a2f43ceb
commit
ba52bf27c1
@ -93,7 +93,7 @@ class DraftRagPipelineApi(Resource):
|
|||||||
parser.add_argument("hash", type=str, required=False, location="json")
|
parser.add_argument("hash", type=str, required=False, location="json")
|
||||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
||||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||||
parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json")
|
parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
elif "text/plain" in content_type:
|
elif "text/plain" in content_type:
|
||||||
try:
|
try:
|
||||||
@ -101,8 +101,8 @@ class DraftRagPipelineApi(Resource):
|
|||||||
if "graph" not in data or "features" not in data:
|
if "graph" not in data or "features" not in data:
|
||||||
raise ValueError("graph or features not found in data")
|
raise ValueError("graph or features not found in data")
|
||||||
|
|
||||||
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
|
if not isinstance(data.get("graph"), dict):
|
||||||
raise ValueError("graph or features is not a dict")
|
raise ValueError("graph is not a dict")
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
"graph": data.get("graph"),
|
"graph": data.get("graph"),
|
||||||
@ -129,11 +129,9 @@ class DraftRagPipelineApi(Resource):
|
|||||||
conversation_variables = [
|
conversation_variables = [
|
||||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||||
]
|
]
|
||||||
rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {}
|
rag_pipeline_variables_list = args.get("rag_pipeline_variables") or []
|
||||||
rag_pipeline_variables = {
|
rag_pipeline_variables = [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list]
|
||||||
k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
|
|
||||||
for k, v in rag_pipeline_variables_list.items()
|
|
||||||
}
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
@ -634,12 +632,15 @@ class RagPipelineSecondStepApi(Resource):
|
|||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
datasource_provider = request.args.get("datasource_provider", required=True, type=str)
|
node_id = request.args.get("node_id", required=True, type=str)
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
return rag_pipeline_service.get_second_step_parameters(
|
variables = rag_pipeline_service.get_second_step_parameters(
|
||||||
pipeline=pipeline, datasource_provider=datasource_provider
|
pipeline=pipeline, node_id=node_id
|
||||||
)
|
)
|
||||||
|
return {
|
||||||
|
"variables": variables,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class RagPipelineWorkflowRunListApi(Resource):
|
class RagPipelineWorkflowRunListApi(Resource):
|
||||||
@ -785,3 +786,7 @@ api.add_resource(
|
|||||||
DatasourceListApi,
|
DatasourceListApi,
|
||||||
"/rag/pipelines/datasource-plugins",
|
"/rag/pipelines/datasource-plugins",
|
||||||
)
|
)
|
||||||
|
api.add_resource(
|
||||||
|
RagPipelineSecondStepApi,
|
||||||
|
"/rag/pipelines/<uuid:pipeline_id>/workflows/processing/paramters",
|
||||||
|
)
|
||||||
|
@ -4,7 +4,6 @@ from typing import Any, Optional, TextIO, Union
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
|
|
||||||
from core.ops.entities.trace_entity import TraceTaskName
|
from core.ops.entities.trace_entity import TraceTaskName
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
@ -114,35 +113,6 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||||||
color=self.color,
|
color=self.color,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_datasource_end(
|
|
||||||
self,
|
|
||||||
datasource_name: str,
|
|
||||||
datasource_inputs: Mapping[str, Any],
|
|
||||||
datasource_outputs: Iterable[DatasourceInvokeMessage] | str,
|
|
||||||
message_id: Optional[str] = None,
|
|
||||||
timer: Optional[Any] = None,
|
|
||||||
trace_manager: Optional[TraceQueueManager] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Run on datasource end."""
|
|
||||||
if dify_config.DEBUG:
|
|
||||||
print_text("\n[on_datasource_end]\n", color=self.color)
|
|
||||||
print_text("Datasource: " + datasource_name + "\n", color=self.color)
|
|
||||||
print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color)
|
|
||||||
print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color)
|
|
||||||
print_text("\n")
|
|
||||||
|
|
||||||
if trace_manager:
|
|
||||||
trace_manager.add_trace_task(
|
|
||||||
TraceTask(
|
|
||||||
TraceTaskName.DATASOURCE_TRACE,
|
|
||||||
message_id=message_id,
|
|
||||||
datasource_name=datasource_name,
|
|
||||||
datasource_inputs=datasource_inputs,
|
|
||||||
datasource_outputs=datasource_outputs,
|
|
||||||
timer=timer,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ignore_agent(self) -> bool:
|
def ignore_agent(self) -> bool:
|
||||||
"""Whether to ignore agent callbacks."""
|
"""Whether to ignore agent callbacks."""
|
||||||
|
@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity):
|
|||||||
|
|
||||||
|
|
||||||
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
|
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
|
||||||
datasources: list[DatasourceEntity] = Field(default_factory=list)
|
datasources: list[DatasourceEntity] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class DatasourceInvokeMeta(BaseModel):
|
class DatasourceInvokeMeta(BaseModel):
|
||||||
|
@ -127,7 +127,7 @@ class GeneralStructureChunk(BaseModel):
|
|||||||
General Structure Chunk.
|
General Structure Chunk.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
general_chunk: list[str]
|
general_chunks: list[str]
|
||||||
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,9 +80,9 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
|
|||||||
|
|
||||||
|
|
||||||
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||||
if not mapping.get("name"):
|
if not mapping.get("variable"):
|
||||||
raise VariableError("missing name")
|
raise VariableError("missing variable")
|
||||||
return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["name"]])
|
return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["variable"]])
|
||||||
|
|
||||||
|
|
||||||
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
||||||
@ -123,6 +123,43 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||||||
result = result.model_copy(update={"selector": selector})
|
result = result.model_copy(update={"selector": selector})
|
||||||
return cast(Variable, result)
|
return cast(Variable, result)
|
||||||
|
|
||||||
|
def _build_rag_pipeline_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
||||||
|
"""
|
||||||
|
This factory function is used to create the rag pipeline variable,
|
||||||
|
not support the File type.
|
||||||
|
"""
|
||||||
|
if (type := mapping.get("type")) is None:
|
||||||
|
raise VariableError("missing type")
|
||||||
|
if (value := mapping.get("value")) is None:
|
||||||
|
raise VariableError("missing value")
|
||||||
|
# FIXME: using Any here, fix it later
|
||||||
|
result: Any
|
||||||
|
match type:
|
||||||
|
case SegmentType.STRING:
|
||||||
|
result = StringVariable.model_validate(mapping)
|
||||||
|
case SegmentType.SECRET:
|
||||||
|
result = SecretVariable.model_validate(mapping)
|
||||||
|
case SegmentType.NUMBER if isinstance(value, int):
|
||||||
|
result = IntegerVariable.model_validate(mapping)
|
||||||
|
case SegmentType.NUMBER if isinstance(value, float):
|
||||||
|
result = FloatVariable.model_validate(mapping)
|
||||||
|
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||||
|
raise VariableError(f"invalid number value {value}")
|
||||||
|
case SegmentType.OBJECT if isinstance(value, dict):
|
||||||
|
result = ObjectVariable.model_validate(mapping)
|
||||||
|
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||||
|
result = ArrayStringVariable.model_validate(mapping)
|
||||||
|
case SegmentType.ARRAY_NUMBER if isinstance(value, list):
|
||||||
|
result = ArrayNumberVariable.model_validate(mapping)
|
||||||
|
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||||
|
result = ArrayObjectVariable.model_validate(mapping)
|
||||||
|
case _:
|
||||||
|
raise VariableError(f"not supported type {type}")
|
||||||
|
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||||
|
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
|
||||||
|
if not result.selector:
|
||||||
|
result = result.model_copy(update={"selector": selector})
|
||||||
|
return cast(Variable, result)
|
||||||
|
|
||||||
def build_segment(value: Any, /) -> Segment:
|
def build_segment(value: Any, /) -> Segment:
|
||||||
if value is None:
|
if value is None:
|
||||||
|
@ -42,9 +42,19 @@ conversation_variable_fields = {
|
|||||||
|
|
||||||
pipeline_variable_fields = {
|
pipeline_variable_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"name": fields.String,
|
"label": fields.String,
|
||||||
"value_type": fields.String(attribute="value_type.value"),
|
"variable": fields.String,
|
||||||
"value": fields.Raw,
|
"type": fields.String(attribute="type.value"),
|
||||||
|
"belong_to_node_id": fields.String,
|
||||||
|
"max_length": fields.Integer,
|
||||||
|
"required": fields.Boolean,
|
||||||
|
"default_value": fields.Raw,
|
||||||
|
"options": fields.List(fields.String),
|
||||||
|
"placeholder": fields.String,
|
||||||
|
"tooltips": fields.String,
|
||||||
|
"allowed_file_types": fields.List(fields.String),
|
||||||
|
"allow_file_extension": fields.List(fields.String),
|
||||||
|
"allow_file_upload_methods": fields.List(fields.String),
|
||||||
}
|
}
|
||||||
|
|
||||||
workflow_fields = {
|
workflow_fields = {
|
||||||
@ -62,6 +72,7 @@ workflow_fields = {
|
|||||||
"tool_published": fields.Boolean,
|
"tool_published": fields.Boolean,
|
||||||
"environment_variables": fields.List(EnvironmentVariableField()),
|
"environment_variables": fields.List(EnvironmentVariableField()),
|
||||||
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
|
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
|
||||||
|
"rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)),
|
||||||
}
|
}
|
||||||
|
|
||||||
workflow_partial_fields = {
|
workflow_partial_fields = {
|
||||||
|
@ -352,21 +352,19 @@ class Workflow(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]:
|
def rag_pipeline_variables(self) -> Sequence[Variable]:
|
||||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||||
if self._rag_pipeline_variables is None:
|
if self._rag_pipeline_variables is None:
|
||||||
self._rag_pipeline_variables = "{}"
|
self._rag_pipeline_variables = "{}"
|
||||||
|
|
||||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||||
results = {}
|
results = [variable_factory.build_pipeline_variable_from_mapping(v) for v in variables_dict.values()]
|
||||||
for k, v in variables_dict.items():
|
|
||||||
results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@rag_pipeline_variables.setter
|
@rag_pipeline_variables.setter
|
||||||
def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
|
def rag_pipeline_variables(self, values: Sequence[Variable]) -> None:
|
||||||
self._rag_pipeline_variables = json.dumps(
|
self._rag_pipeline_variables = json.dumps(
|
||||||
{k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
|
{item.name: item.model_dump() for item in values},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -201,7 +201,7 @@ class RagPipelineService:
|
|||||||
account: Account,
|
account: Account,
|
||||||
environment_variables: Sequence[Variable],
|
environment_variables: Sequence[Variable],
|
||||||
conversation_variables: Sequence[Variable],
|
conversation_variables: Sequence[Variable],
|
||||||
rag_pipeline_variables: dict[str, Sequence[Variable]],
|
rag_pipeline_variables: Sequence[Variable],
|
||||||
) -> Workflow:
|
) -> Workflow:
|
||||||
"""
|
"""
|
||||||
Sync draft workflow
|
Sync draft workflow
|
||||||
@ -552,7 +552,7 @@ class RagPipelineService:
|
|||||||
|
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> dict:
|
def get_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
|
||||||
"""
|
"""
|
||||||
Get second step parameters of rag pipeline
|
Get second step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
@ -562,13 +562,15 @@ class RagPipelineService:
|
|||||||
raise ValueError("Workflow not initialized")
|
raise ValueError("Workflow not initialized")
|
||||||
|
|
||||||
# get second step node
|
# get second step node
|
||||||
pipeline_variables = workflow.pipeline_variables
|
rag_pipeline_variables = workflow.rag_pipeline_variables
|
||||||
if not pipeline_variables:
|
if not rag_pipeline_variables:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# get datasource provider
|
# get datasource provider
|
||||||
datasource_provider_variables = pipeline_variables.get(datasource_provider, [])
|
datasource_provider_variables = [item for item in rag_pipeline_variables
|
||||||
shared_variables = pipeline_variables.get("shared", [])
|
if item.get("belong_to_node_id") == node_id
|
||||||
return datasource_provider_variables + shared_variables
|
or item.get("belong_to_node_id") == "shared"]
|
||||||
|
return datasource_provider_variables
|
||||||
|
|
||||||
def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
|
def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user