mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 22:55:58 +08:00
fix test
This commit is contained in:
parent
90e518b05b
commit
f67a88f44d
@ -7,7 +7,6 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
|||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk
|
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
@ -19,15 +18,6 @@ class WorkflowStreamGenerateNodes(BaseModel):
|
|||||||
stream_node_ids: list[str]
|
stream_node_ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
class ChatflowStreamGenerateRoute(BaseModel):
|
|
||||||
"""
|
|
||||||
ChatflowStreamGenerateRoute entity
|
|
||||||
"""
|
|
||||||
answer_node_id: str
|
|
||||||
generate_route: list[GenerateRouteChunk]
|
|
||||||
current_route_position: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class NodeExecutionInfo(BaseModel):
|
class NodeExecutionInfo(BaseModel):
|
||||||
"""
|
"""
|
||||||
NodeExecutionInfo entity
|
NodeExecutionInfo entity
|
||||||
|
@ -5,6 +5,8 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
|
||||||
|
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
|
||||||
|
|
||||||
|
|
||||||
class GraphEdge(BaseModel):
|
class GraphEdge(BaseModel):
|
||||||
@ -39,6 +41,10 @@ class Graph(BaseModel):
|
|||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="graph node parallel mapping (node id: parallel id)"
|
description="graph node parallel mapping (node id: parallel id)"
|
||||||
)
|
)
|
||||||
|
answer_stream_generate_routes: dict[str, AnswerStreamGenerateRoute] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="answer stream generate routes"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init(cls,
|
def init(cls,
|
||||||
@ -142,6 +148,12 @@ class Graph(BaseModel):
|
|||||||
node_parallel_mapping=node_parallel_mapping
|
node_parallel_mapping=node_parallel_mapping
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# init answer stream generate routes
|
||||||
|
answer_stream_generate_routes = AnswerStreamOutputManager.init_stream_generate_routes(
|
||||||
|
node_id_config_mapping=node_id_config_mapping,
|
||||||
|
edge_mapping=edge_mapping
|
||||||
|
)
|
||||||
|
|
||||||
# init graph
|
# init graph
|
||||||
graph = cls(
|
graph = cls(
|
||||||
root_node_id=root_node_id,
|
root_node_id=root_node_id,
|
||||||
@ -149,7 +161,8 @@ class Graph(BaseModel):
|
|||||||
node_id_config_mapping=node_id_config_mapping,
|
node_id_config_mapping=node_id_config_mapping,
|
||||||
edge_mapping=edge_mapping,
|
edge_mapping=edge_mapping,
|
||||||
parallel_mapping=parallel_mapping,
|
parallel_mapping=parallel_mapping,
|
||||||
node_parallel_mapping=node_parallel_mapping
|
node_parallel_mapping=node_parallel_mapping,
|
||||||
|
answer_stream_generate_routes=answer_stream_generate_routes
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
@ -28,8 +28,8 @@ from core.workflow.graph_engine.entities.graph import Graph
|
|||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.nodes import node_classes
|
|
||||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||||
|
from core.workflow.nodes.node_mapping import node_classes
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||||
|
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
from core.workflow.entities.node_entities import NodeType
|
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
|
||||||
from core.workflow.nodes.code.code_node import CodeNode
|
|
||||||
from core.workflow.nodes.end.end_node import EndNode
|
|
||||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
|
||||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
|
||||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
|
||||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
|
||||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
|
||||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
|
||||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
|
||||||
from core.workflow.nodes.start.start_node import StartNode
|
|
||||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
|
||||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
|
||||||
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
|
||||||
|
|
||||||
node_classes = {
|
|
||||||
NodeType.START: StartNode,
|
|
||||||
NodeType.END: EndNode,
|
|
||||||
NodeType.ANSWER: AnswerNode,
|
|
||||||
NodeType.LLM: LLMNode,
|
|
||||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
|
||||||
NodeType.IF_ELSE: IfElseNode,
|
|
||||||
NodeType.CODE: CodeNode,
|
|
||||||
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
|
||||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
|
||||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
|
||||||
NodeType.TOOL: ToolNode,
|
|
||||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
|
||||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
|
||||||
NodeType.ITERATION: IterationNode,
|
|
||||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
|
||||||
}
|
|
@ -2,9 +2,9 @@ import json
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
|
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
|
||||||
from core.workflow.nodes.answer.entities import (
|
from core.workflow.nodes.answer.entities import (
|
||||||
AnswerNodeData,
|
AnswerNodeData,
|
||||||
GenerateRouteChunk,
|
GenerateRouteChunk,
|
||||||
@ -29,11 +29,11 @@ class AnswerNode(BaseNode):
|
|||||||
node_data = cast(AnswerNodeData, node_data)
|
node_data = cast(AnswerNodeData, node_data)
|
||||||
|
|
||||||
# generate routes
|
# generate routes
|
||||||
generate_routes = self.extract_generate_route_from_node_data(node_data)
|
generate_routes = AnswerStreamOutputManager.extract_generate_route_from_node_data(node_data)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
for part in generate_routes:
|
for part in generate_routes:
|
||||||
if part.type == "var":
|
if part.type == GenerateRouteChunk.ChunkType.VAR:
|
||||||
part = cast(VarGenerateRouteChunk, part)
|
part = cast(VarGenerateRouteChunk, part)
|
||||||
value_selector = part.value_selector
|
value_selector = part.value_selector
|
||||||
value = self.graph_runtime_state.variable_pool.get_variable_value(
|
value = self.graph_runtime_state.variable_pool.get_variable_value(
|
||||||
@ -72,67 +72,6 @@ class AnswerNode(BaseNode):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
|
|
||||||
"""
|
|
||||||
Extract generate route selectors
|
|
||||||
:param config: node config
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
|
||||||
node_data = cast(cls._node_data_cls, node_data)
|
|
||||||
|
|
||||||
return cls.extract_generate_route_from_node_data(node_data)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
|
|
||||||
"""
|
|
||||||
Extract generate route from node data
|
|
||||||
:param node_data: node data object
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
|
||||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
|
||||||
|
|
||||||
value_selector_mapping = {
|
|
||||||
variable_selector.variable: variable_selector.value_selector
|
|
||||||
for variable_selector in variable_selectors
|
|
||||||
}
|
|
||||||
|
|
||||||
variable_keys = list(value_selector_mapping.keys())
|
|
||||||
|
|
||||||
# format answer template
|
|
||||||
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
|
|
||||||
template_variable_keys = template_parser.variable_keys
|
|
||||||
|
|
||||||
# Take the intersection of variable_keys and template_variable_keys
|
|
||||||
variable_keys = list(set(variable_keys) & set(template_variable_keys))
|
|
||||||
|
|
||||||
template = node_data.answer
|
|
||||||
for var in variable_keys:
|
|
||||||
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
|
||||||
|
|
||||||
generate_routes = []
|
|
||||||
for part in template.split('Ω'):
|
|
||||||
if part:
|
|
||||||
if cls._is_variable(part, variable_keys):
|
|
||||||
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
|
|
||||||
value_selector = value_selector_mapping[var_key]
|
|
||||||
generate_routes.append(VarGenerateRouteChunk(
|
|
||||||
value_selector=value_selector
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
generate_routes.append(TextGenerateRouteChunk(
|
|
||||||
text=part
|
|
||||||
))
|
|
||||||
|
|
||||||
return generate_routes
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _is_variable(cls, part, variable_keys):
|
|
||||||
cleaned_part = part.replace('{{', '').replace('}}', '')
|
|
||||||
return part.startswith('{{') and cleaned_part in variable_keys
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
@ -141,7 +80,7 @@ class AnswerNode(BaseNode):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
node_data = node_data
|
node_data = node_data
|
||||||
node_data = cast(cls._node_data_cls, node_data)
|
node_data = cast(AnswerNodeData, node_data)
|
||||||
|
|
||||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||||
|
160
api/core/workflow/nodes/answer/answer_stream_output_manager.py
Normal file
160
api/core/workflow/nodes/answer/answer_stream_output_manager.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
|
from core.workflow.entities.node_entities import NodeType
|
||||||
|
from core.workflow.nodes.answer.entities import (
|
||||||
|
AnswerNodeData,
|
||||||
|
AnswerStreamGenerateRoute,
|
||||||
|
GenerateRouteChunk,
|
||||||
|
TextGenerateRouteChunk,
|
||||||
|
VarGenerateRouteChunk,
|
||||||
|
)
|
||||||
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerStreamOutputManager:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def init_stream_generate_routes(cls,
|
||||||
|
node_id_config_mapping: dict[str, dict],
|
||||||
|
edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
|
||||||
|
) -> dict[str, AnswerStreamGenerateRoute]:
|
||||||
|
"""
|
||||||
|
Get stream generate routes.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# parse stream output node value selectors of answer nodes
|
||||||
|
stream_generate_routes = {}
|
||||||
|
for node_id, node_config in node_id_config_mapping.items():
|
||||||
|
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# get generate route for stream output
|
||||||
|
generate_route = cls._extract_generate_route_selectors(node_config)
|
||||||
|
streaming_node_ids = cls._get_streaming_node_ids(
|
||||||
|
target_node_id=node_id,
|
||||||
|
node_id_config_mapping=node_id_config_mapping,
|
||||||
|
edge_mapping=edge_mapping
|
||||||
|
)
|
||||||
|
|
||||||
|
if not streaming_node_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for streaming_node_id in streaming_node_ids:
|
||||||
|
stream_generate_routes[streaming_node_id] = AnswerStreamGenerateRoute(
|
||||||
|
answer_node_id=node_id,
|
||||||
|
generate_route=generate_route
|
||||||
|
)
|
||||||
|
|
||||||
|
return stream_generate_routes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
|
||||||
|
"""
|
||||||
|
Extract generate route from node data
|
||||||
|
:param node_data: node data object
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||||
|
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||||
|
|
||||||
|
value_selector_mapping = {
|
||||||
|
variable_selector.variable: variable_selector.value_selector
|
||||||
|
for variable_selector in variable_selectors
|
||||||
|
}
|
||||||
|
|
||||||
|
variable_keys = list(value_selector_mapping.keys())
|
||||||
|
|
||||||
|
# format answer template
|
||||||
|
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
|
||||||
|
template_variable_keys = template_parser.variable_keys
|
||||||
|
|
||||||
|
# Take the intersection of variable_keys and template_variable_keys
|
||||||
|
variable_keys = list(set(variable_keys) & set(template_variable_keys))
|
||||||
|
|
||||||
|
template = node_data.answer
|
||||||
|
for var in variable_keys:
|
||||||
|
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||||
|
|
||||||
|
generate_routes: list[GenerateRouteChunk] = []
|
||||||
|
for part in template.split('Ω'):
|
||||||
|
if part:
|
||||||
|
if cls._is_variable(part, variable_keys):
|
||||||
|
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
|
||||||
|
value_selector = value_selector_mapping[var_key]
|
||||||
|
generate_routes.append(VarGenerateRouteChunk(
|
||||||
|
value_selector=value_selector
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
generate_routes.append(TextGenerateRouteChunk(
|
||||||
|
text=part
|
||||||
|
))
|
||||||
|
|
||||||
|
return generate_routes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_streaming_node_ids(cls,
|
||||||
|
target_node_id: str,
|
||||||
|
node_id_config_mapping: dict[str, dict],
|
||||||
|
edge_mapping: dict[str, list["GraphEdge"]]) -> list[str]: # type: ignore[name-defined]
|
||||||
|
"""
|
||||||
|
Get answer stream node IDs.
|
||||||
|
:param target_node_id: target node ID
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# fetch all ingoing edges from source node
|
||||||
|
ingoing_graph_edges = []
|
||||||
|
for graph_edges in edge_mapping.values():
|
||||||
|
for graph_edge in graph_edges:
|
||||||
|
if graph_edge.target_node_id == target_node_id:
|
||||||
|
ingoing_graph_edges.append(graph_edge)
|
||||||
|
|
||||||
|
if not ingoing_graph_edges:
|
||||||
|
return []
|
||||||
|
|
||||||
|
streaming_node_ids = []
|
||||||
|
for ingoing_graph_edge in ingoing_graph_edges:
|
||||||
|
source_node_id = ingoing_graph_edge.source_node_id
|
||||||
|
source_node = node_id_config_mapping.get(source_node_id)
|
||||||
|
if not source_node:
|
||||||
|
continue
|
||||||
|
|
||||||
|
node_type = source_node.get('data', {}).get('type')
|
||||||
|
|
||||||
|
if node_type in [
|
||||||
|
NodeType.ANSWER.value,
|
||||||
|
NodeType.IF_ELSE.value,
|
||||||
|
NodeType.QUESTION_CLASSIFIER.value,
|
||||||
|
NodeType.ITERATION.value,
|
||||||
|
NodeType.LOOP.value
|
||||||
|
]:
|
||||||
|
# Current node (answer nodes / multi-branch nodes / iteration nodes) cannot be stream node.
|
||||||
|
streaming_node_ids.append(target_node_id)
|
||||||
|
elif node_type == NodeType.START.value:
|
||||||
|
# Current node is START node, can be stream node.
|
||||||
|
streaming_node_ids.append(source_node_id)
|
||||||
|
else:
|
||||||
|
# Find the stream node forward.
|
||||||
|
sub_streaming_node_ids = cls._get_streaming_node_ids(
|
||||||
|
target_node_id=source_node_id,
|
||||||
|
node_id_config_mapping=node_id_config_mapping,
|
||||||
|
edge_mapping=edge_mapping
|
||||||
|
)
|
||||||
|
|
||||||
|
if sub_streaming_node_ids:
|
||||||
|
streaming_node_ids.extend(sub_streaming_node_ids)
|
||||||
|
|
||||||
|
return streaming_node_ids
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
|
||||||
|
"""
|
||||||
|
Extract generate route selectors
|
||||||
|
:param config: node config
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
node_data = AnswerNodeData(**config.get("data", {}))
|
||||||
|
return cls.extract_generate_route_from_node_data(node_data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _is_variable(cls, part, variable_keys):
|
||||||
|
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||||
|
return part.startswith('{{') and cleaned_part in variable_keys
|
@ -1,5 +1,6 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
|
||||||
@ -8,27 +9,44 @@ class AnswerNodeData(BaseNodeData):
|
|||||||
"""
|
"""
|
||||||
Answer Node Data.
|
Answer Node Data.
|
||||||
"""
|
"""
|
||||||
answer: str
|
answer: str = Field(..., description="answer template string")
|
||||||
|
|
||||||
|
|
||||||
class GenerateRouteChunk(BaseModel):
|
class GenerateRouteChunk(BaseModel):
|
||||||
"""
|
"""
|
||||||
Generate Route Chunk.
|
Generate Route Chunk.
|
||||||
"""
|
"""
|
||||||
type: str
|
|
||||||
|
class ChunkType(Enum):
|
||||||
|
VAR = "var"
|
||||||
|
TEXT = "text"
|
||||||
|
|
||||||
|
type: ChunkType = Field(..., description="generate route chunk type")
|
||||||
|
|
||||||
|
|
||||||
class VarGenerateRouteChunk(GenerateRouteChunk):
|
class VarGenerateRouteChunk(GenerateRouteChunk):
|
||||||
"""
|
"""
|
||||||
Var Generate Route Chunk.
|
Var Generate Route Chunk.
|
||||||
"""
|
"""
|
||||||
type: str = "var"
|
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
|
||||||
value_selector: list[str]
|
"""generate route chunk type"""
|
||||||
|
value_selector: list[str] = Field(..., description="value selector")
|
||||||
|
|
||||||
|
|
||||||
class TextGenerateRouteChunk(GenerateRouteChunk):
|
class TextGenerateRouteChunk(GenerateRouteChunk):
|
||||||
"""
|
"""
|
||||||
Text Generate Route Chunk.
|
Text Generate Route Chunk.
|
||||||
"""
|
"""
|
||||||
type: str = "text"
|
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
|
||||||
text: str
|
"""generate route chunk type"""
|
||||||
|
text: str = Field(..., description="text")
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerStreamGenerateRoute(BaseModel):
|
||||||
|
"""
|
||||||
|
ChatflowStreamGenerateRoute entity
|
||||||
|
"""
|
||||||
|
answer_node_id: str = Field(..., description="answer node ID")
|
||||||
|
generate_route: list[GenerateRouteChunk] = Field(..., description="answer stream generate route")
|
||||||
|
current_route_position: int = 0
|
||||||
|
"""current generate route position"""
|
||||||
|
33
api/core/workflow/nodes/node_mapping.py
Normal file
33
api/core/workflow/nodes/node_mapping.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from core.workflow.entities.node_entities import NodeType
|
||||||
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
|
from core.workflow.nodes.code.code_node import CodeNode
|
||||||
|
from core.workflow.nodes.end.end_node import EndNode
|
||||||
|
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||||
|
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||||
|
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||||
|
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||||
|
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||||
|
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||||
|
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||||
|
from core.workflow.nodes.start.start_node import StartNode
|
||||||
|
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||||
|
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||||
|
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
||||||
|
|
||||||
|
node_classes = {
|
||||||
|
NodeType.START: StartNode,
|
||||||
|
NodeType.END: EndNode,
|
||||||
|
NodeType.ANSWER: AnswerNode,
|
||||||
|
NodeType.LLM: LLMNode,
|
||||||
|
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||||
|
NodeType.IF_ELSE: IfElseNode,
|
||||||
|
NodeType.CODE: CodeNode,
|
||||||
|
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
||||||
|
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||||
|
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||||
|
NodeType.TOOL: ToolNode,
|
||||||
|
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||||
|
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||||
|
NodeType.ITERATION: IterationNode,
|
||||||
|
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
||||||
|
}
|
@ -17,10 +17,10 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
|
|||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||||
from core.workflow.nodes import node_classes
|
|
||||||
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
||||||
from core.workflow.nodes.iteration.entities import IterationState
|
from core.workflow.nodes.iteration.entities import IterationState
|
||||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||||
|
from core.workflow.nodes.node_mapping import node_classes
|
||||||
from core.workflow.nodes.start.start_node import StartNode
|
from core.workflow.nodes.start.start_node import StartNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
|
@ -8,7 +8,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
|||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
from core.workflow.nodes import node_classes
|
from core.workflow.nodes.node_mapping import node_classes
|
||||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -50,6 +50,8 @@ def test_init():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer",
|
"id": "answer",
|
||||||
},
|
},
|
||||||
@ -68,6 +70,8 @@ def test_init():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer2",
|
"id": "answer2",
|
||||||
}
|
}
|
||||||
@ -141,6 +145,8 @@ def test__init_iteration_graph():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer",
|
"id": "answer",
|
||||||
},
|
},
|
||||||
@ -167,6 +173,8 @@ def test__init_iteration_graph():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer-in-iteration",
|
"id": "answer-in-iteration",
|
||||||
"parentId": "iteration",
|
"parentId": "iteration",
|
||||||
@ -270,6 +278,8 @@ def test_parallels_graph():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer",
|
"id": "answer",
|
||||||
},
|
},
|
||||||
@ -350,6 +360,8 @@ def test_parallels_graph2():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer",
|
"id": "answer",
|
||||||
},
|
},
|
||||||
@ -422,6 +434,8 @@ def test_parallels_graph3():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer",
|
"id": "answer",
|
||||||
},
|
},
|
||||||
@ -538,6 +552,8 @@ def test_parallels_graph4():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer",
|
"id": "answer",
|
||||||
},
|
},
|
||||||
@ -673,6 +689,8 @@ def test_parallels_graph5():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer",
|
"id": "answer",
|
||||||
},
|
},
|
||||||
@ -816,6 +834,8 @@ def test_parallels_graph6():
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "answer",
|
"id": "answer",
|
||||||
},
|
},
|
||||||
|
@ -9,6 +9,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
BaseNodeEvent,
|
BaseNodeEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
GraphRunStartedEvent,
|
GraphRunStartedEvent,
|
||||||
|
GraphRunSucceededEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
@ -187,12 +188,19 @@ def test_run_branch(mock_close, mock_remove):
|
|||||||
"data": {
|
"data": {
|
||||||
"title": "Start",
|
"title": "Start",
|
||||||
"type": "start",
|
"type": "start",
|
||||||
"variables": []
|
"variables": [{
|
||||||
|
"label": "uid",
|
||||||
|
"max_length": 48,
|
||||||
|
"options": [],
|
||||||
|
"required": True,
|
||||||
|
"type": "text-input",
|
||||||
|
"variable": "uid"
|
||||||
|
}]
|
||||||
},
|
},
|
||||||
"id": "start"
|
"id": "start"
|
||||||
}, {
|
}, {
|
||||||
"data": {
|
"data": {
|
||||||
"answer": "1",
|
"answer": "1 {{#start.uid#}}",
|
||||||
"title": "Answer",
|
"title": "Answer",
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
"variables": []
|
"variables": []
|
||||||
@ -261,7 +269,9 @@ def test_run_branch(mock_close, mock_remove):
|
|||||||
SystemVariable.FILES: [],
|
SystemVariable.FILES: [],
|
||||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariable.USER_ID: 'aaa'
|
||||||
}, user_inputs={})
|
}, user_inputs={
|
||||||
|
"uid": "takato"
|
||||||
|
})
|
||||||
|
|
||||||
graph_engine = GraphEngine(
|
graph_engine = GraphEngine(
|
||||||
tenant_id="111",
|
tenant_id="111",
|
||||||
@ -286,7 +296,7 @@ def test_run_branch(mock_close, mock_remove):
|
|||||||
with app.app_context():
|
with app.app_context():
|
||||||
generator = graph_engine.run()
|
generator = graph_engine.run()
|
||||||
for item in generator:
|
for item in generator:
|
||||||
# print(type(item), item)
|
print(type(item), item)
|
||||||
items.append(item)
|
items.append(item)
|
||||||
|
|
||||||
assert len(items) == 8
|
assert len(items) == 8
|
||||||
@ -294,5 +304,7 @@ def test_run_branch(mock_close, mock_remove):
|
|||||||
assert items[4].route_node_state.node_id == 'if-else-1'
|
assert items[4].route_node_state.node_id == 'if-else-1'
|
||||||
assert items[5].route_node_state.node_id == 'answer-1'
|
assert items[5].route_node_state.node_id == 'answer-1'
|
||||||
assert items[6].route_node_state.node_id == 'answer-1'
|
assert items[6].route_node_state.node_id == 'answer-1'
|
||||||
|
assert items[6].route_node_state.node_run_result.outputs['answer'] == '1 takato'
|
||||||
|
assert isinstance(items[7], GraphRunSucceededEvent)
|
||||||
|
|
||||||
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))
|
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user