mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 23:45:59 +08:00
fix end node bug
This commit is contained in:
parent
42899fb3be
commit
85d319719c
@ -255,7 +255,8 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueTextChunkEvent(
|
QueueTextChunkEvent(
|
||||||
text=event.chunk_content
|
text=event.chunk_content,
|
||||||
|
from_variable_selector=event.from_variable_selector
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||||
|
@ -150,6 +150,8 @@ class QueueTextChunkEvent(AppQueueEvent):
|
|||||||
"""
|
"""
|
||||||
event: QueueEvent = QueueEvent.TEXT_CHUNK
|
event: QueueEvent = QueueEvent.TEXT_CHUNK
|
||||||
text: str
|
text: str
|
||||||
|
from_variable_selector: Optional[list[str]] = None
|
||||||
|
"""from variable selector"""
|
||||||
|
|
||||||
|
|
||||||
class QueueAgentMessageEvent(AppQueueEvent):
|
class QueueAgentMessageEvent(AppQueueEvent):
|
||||||
|
@ -203,7 +203,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||||||
return files
|
return files
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_file_var_from_value(self, value: dict | list) -> Optional[dict]:
|
def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
Get file var from value
|
Get file var from value
|
||||||
:param value: variable value
|
:param value: variable value
|
||||||
|
@ -62,14 +62,13 @@ class EndStreamGeneratorRouter:
|
|||||||
if node_id != 'sys' and node_id in node_id_config_mapping:
|
if node_id != 'sys' and node_id in node_id_config_mapping:
|
||||||
node = node_id_config_mapping[node_id]
|
node = node_id_config_mapping[node_id]
|
||||||
node_type = node.get('data', {}).get('type')
|
node_type = node.get('data', {}).get('type')
|
||||||
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
|
if (
|
||||||
|
variable_selector.value_selector not in value_selectors
|
||||||
|
and node_type == NodeType.LLM.value
|
||||||
|
and variable_selector.value_selector[1] == 'text'
|
||||||
|
):
|
||||||
value_selectors.append(variable_selector.value_selector)
|
value_selectors.append(variable_selector.value_selector)
|
||||||
|
|
||||||
# remove duplicates
|
|
||||||
value_selector_tuples = [tuple(item) for item in value_selectors]
|
|
||||||
unique_value_selector_tuples = list(set(value_selector_tuples))
|
|
||||||
value_selectors = [list(item) for item in unique_value_selector_tuples]
|
|
||||||
|
|
||||||
return value_selectors
|
return value_selectors
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -18,9 +18,13 @@ class EndStreamProcessor(StreamProcessor):
|
|||||||
|
|
||||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||||
super().__init__(graph, variable_pool)
|
super().__init__(graph, variable_pool)
|
||||||
self.stream_param = graph.end_stream_param
|
self.end_stream_param = graph.end_stream_param
|
||||||
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
|
self.route_position = {}
|
||||||
|
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
|
||||||
|
self.route_position[end_node_id] = 0
|
||||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||||
|
self.has_outputed = False
|
||||||
|
self.outputed_node_ids = set()
|
||||||
|
|
||||||
def process(self,
|
def process(self,
|
||||||
generator: Generator[GraphEngineEvent, None, None]
|
generator: Generator[GraphEngineEvent, None, None]
|
||||||
@ -32,6 +36,15 @@ class EndStreamProcessor(StreamProcessor):
|
|||||||
|
|
||||||
yield event
|
yield event
|
||||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||||
|
if event.in_iteration_id:
|
||||||
|
if self.has_outputed and event.node_id not in self.outputed_node_ids:
|
||||||
|
event.chunk_content = '\n' + event.chunk_content
|
||||||
|
|
||||||
|
self.outputed_node_ids.add(event.node_id)
|
||||||
|
self.has_outputed = True
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
|
||||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||||
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
|
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||||
event.route_node_state.node_id
|
event.route_node_state.node_id
|
||||||
@ -42,23 +55,97 @@ class EndStreamProcessor(StreamProcessor):
|
|||||||
event.route_node_state.node_id
|
event.route_node_state.node_id
|
||||||
] = stream_out_end_node_ids
|
] = stream_out_end_node_ids
|
||||||
|
|
||||||
for _ in stream_out_end_node_ids:
|
if stream_out_end_node_ids:
|
||||||
|
if self.has_outputed and event.node_id not in self.outputed_node_ids:
|
||||||
|
event.chunk_content = '\n' + event.chunk_content
|
||||||
|
|
||||||
|
self.outputed_node_ids.add(event.node_id)
|
||||||
|
self.has_outputed = True
|
||||||
yield event
|
yield event
|
||||||
elif isinstance(event, NodeRunSucceededEvent):
|
elif isinstance(event, NodeRunSucceededEvent):
|
||||||
yield event
|
yield event
|
||||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||||
|
# update self.route_position after all stream event finished
|
||||||
|
for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
|
||||||
|
self.route_position[end_node_id] += 1
|
||||||
|
|
||||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
||||||
|
|
||||||
# remove unreachable nodes
|
# remove unreachable nodes
|
||||||
self._remove_unreachable_nodes(event)
|
self._remove_unreachable_nodes(event)
|
||||||
|
|
||||||
|
# generate stream outputs
|
||||||
|
yield from self._generate_stream_outputs_when_node_finished(event)
|
||||||
else:
|
else:
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy()
|
self.route_position = {}
|
||||||
|
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
|
||||||
|
self.route_position[end_node_id] = 0
|
||||||
self.rest_node_ids = self.graph.node_ids.copy()
|
self.rest_node_ids = self.graph.node_ids.copy()
|
||||||
self.current_stream_chunk_generating_node_ids = {}
|
self.current_stream_chunk_generating_node_ids = {}
|
||||||
|
|
||||||
|
def _generate_stream_outputs_when_node_finished(self,
|
||||||
|
event: NodeRunSucceededEvent
|
||||||
|
) -> Generator[GraphEngineEvent, None, None]:
|
||||||
|
"""
|
||||||
|
Generate stream outputs.
|
||||||
|
:param event: node run succeeded event
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for end_node_id, position in self.route_position.items():
|
||||||
|
# all depends on end node id not in rest node ids
|
||||||
|
if (event.route_node_state.node_id != end_node_id
|
||||||
|
and (end_node_id not in self.rest_node_ids
|
||||||
|
or not all(dep_id not in self.rest_node_ids
|
||||||
|
for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
|
||||||
|
continue
|
||||||
|
|
||||||
|
route_position = self.route_position[end_node_id]
|
||||||
|
|
||||||
|
position = 0
|
||||||
|
value_selectors = []
|
||||||
|
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
|
||||||
|
if position >= route_position:
|
||||||
|
value_selectors.append(current_value_selectors)
|
||||||
|
|
||||||
|
position += 1
|
||||||
|
|
||||||
|
for value_selector in value_selectors:
|
||||||
|
if not value_selector:
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = self.variable_pool.get(
|
||||||
|
value_selector
|
||||||
|
)
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
text = value.markdown
|
||||||
|
|
||||||
|
if text:
|
||||||
|
current_node_id = value_selector[0]
|
||||||
|
if self.has_outputed and current_node_id not in self.outputed_node_ids:
|
||||||
|
text = '\n' + text
|
||||||
|
|
||||||
|
self.outputed_node_ids.add(current_node_id)
|
||||||
|
self.has_outputed = True
|
||||||
|
yield NodeRunStreamChunkEvent(
|
||||||
|
id=event.id,
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
node_data=event.node_data,
|
||||||
|
chunk_content=text,
|
||||||
|
from_variable_selector=value_selector,
|
||||||
|
route_node_state=event.route_node_state,
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.route_position[end_node_id] += 1
|
||||||
|
|
||||||
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Is stream out support
|
Is stream out support
|
||||||
@ -73,14 +160,30 @@ class EndStreamProcessor(StreamProcessor):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
stream_out_end_node_ids = []
|
stream_out_end_node_ids = []
|
||||||
for end_node_id, variable_selectors in self.end_streamed_variable_selectors.items():
|
for end_node_id, route_position in self.route_position.items():
|
||||||
if end_node_id not in self.rest_node_ids:
|
if end_node_id not in self.rest_node_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# all depends on end node id not in rest node ids
|
# all depends on end node id not in rest node ids
|
||||||
if all(dep_id not in self.rest_node_ids
|
if all(dep_id not in self.rest_node_ids
|
||||||
for dep_id in self.stream_param.end_dependencies[end_node_id]):
|
for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
|
||||||
if stream_output_value_selector not in variable_selectors:
|
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
position = 0
|
||||||
|
value_selector = None
|
||||||
|
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
|
||||||
|
if position == route_position:
|
||||||
|
value_selector = current_value_selectors
|
||||||
|
break
|
||||||
|
|
||||||
|
position += 1
|
||||||
|
|
||||||
|
if not value_selector:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# check chunk node id is before current node id or equal to current node id
|
||||||
|
if value_selector != stream_output_value_selector:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
stream_out_end_node_ids.append(end_node_id)
|
stream_out_end_node_ids.append(end_node_id)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user