fix: missing iterator in task pipeline (#4948)

This commit is contained in:
Yeuoly 2024-06-05 15:10:20 +08:00 committed by GitHub
parent 63addc9258
commit 80a87f36ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -107,8 +107,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
usage=LLMUsage.empty_usage()
)
self._stream_generate_routes = self._get_stream_generate_routes()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@ -410,6 +410,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
ingoing_edges.append(edge)
if not ingoing_edges:
# check if it's the first node in the iteration
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
if not target_node:
return []
node_iteration_id = target_node.get('data', {}).get('iteration_id')
# get iteration start node id
for node in nodes:
if node.get('id') == node_iteration_id:
if node.get('data', {}).get('start_node_id') == target_node_id:
return [target_node_id]
return []
start_node_ids = []
@ -514,6 +526,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
@ -525,6 +538,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if route_chunk_node_id == 'sys':
# system variable
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
continue
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
iterator = iteration_state.inputs
if not iterator:
continue
iterator_selector = iterator.get('iterator_selector', [])
if value_selector[1] == 'index':
value = iteration_state.current_index
elif value_selector[1] == 'item':
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
iterator_selector) else None
else:
# check chunk node id is before current node id or equal to current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
@ -554,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
value = value.get(key)
if value:
if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)