mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 06:19:03 +08:00
fix(workflow): Take back LLM streaming output after IF-ELSE (#9875)
This commit is contained in:
parent
17cacf258e
commit
72ea3d6b98
@ -130,15 +130,14 @@ class GraphEngine:
|
|||||||
yield GraphRunStartedEvent()
|
yield GraphRunStartedEvent()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
|
|
||||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||||
stream_processor_cls = AnswerStreamProcessor
|
stream_processor = AnswerStreamProcessor(
|
||||||
|
graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
stream_processor_cls = EndStreamProcessor
|
stream_processor = EndStreamProcessor(
|
||||||
|
graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
|
||||||
stream_processor = stream_processor_cls(
|
)
|
||||||
graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
|
|
||||||
)
|
|
||||||
|
|
||||||
# run graph
|
# run graph
|
||||||
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
|
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
|
||||||
|
@ -149,10 +149,10 @@ class AnswerStreamGeneratorRouter:
|
|||||||
source_node_id = edge.source_node_id
|
source_node_id = edge.source_node_id
|
||||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||||
if source_node_type in {
|
if source_node_type in {
|
||||||
NodeType.ANSWER.value,
|
NodeType.ANSWER,
|
||||||
NodeType.IF_ELSE.value,
|
NodeType.IF_ELSE,
|
||||||
NodeType.QUESTION_CLASSIFIER.value,
|
NodeType.QUESTION_CLASSIFIER,
|
||||||
NodeType.ITERATION.value,
|
NodeType.ITERATION,
|
||||||
}:
|
}:
|
||||||
answer_dependencies[answer_node_id].append(source_node_id)
|
answer_dependencies[answer_node_id].append(source_node_id)
|
||||||
else:
|
else:
|
||||||
|
@ -22,7 +22,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||||||
super().__init__(graph, variable_pool)
|
super().__init__(graph, variable_pool)
|
||||||
self.generate_routes = graph.answer_stream_generate_routes
|
self.generate_routes = graph.answer_stream_generate_routes
|
||||||
self.route_position = {}
|
self.route_position = {}
|
||||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
for answer_node_id in self.generate_routes.answer_generate_route:
|
||||||
self.route_position[answer_node_id] = 0
|
self.route_position[answer_node_id] = 0
|
||||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
@ -41,7 +41,6 @@ class StreamProcessor(ABC):
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
unreachable_first_node_ids.append(edge.target_node_id)
|
unreachable_first_node_ids.append(edge.target_node_id)
|
||||||
unreachable_first_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
|
||||||
|
|
||||||
for node_id in unreachable_first_node_ids:
|
for node_id in unreachable_first_node_ids:
|
||||||
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@ -32,7 +33,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk):
|
|||||||
|
|
||||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
|
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
|
||||||
"""generate route chunk type"""
|
"""generate route chunk type"""
|
||||||
value_selector: list[str] = Field(..., description="value selector")
|
value_selector: Sequence[str] = Field(..., description="value selector")
|
||||||
|
|
||||||
|
|
||||||
class TextGenerateRouteChunk(GenerateRouteChunk):
|
class TextGenerateRouteChunk(GenerateRouteChunk):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user