feat: add from_variable_selector for stream chunk / message event (#8228)

This commit is contained in:
takatost 2024-09-10 22:15:50 +08:00 committed by GitHub
parent fdbbdb706f
commit cee0c51dbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 23 additions and 6 deletions

View File

@ -451,7 +451,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(message=queue_message) tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id) yield self._message_to_stream_response(
answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation # published by moderation
yield self._message_replace_to_stream_response(answer=event.text) yield self._message_replace_to_stream_response(answer=event.text)

View File

@ -376,7 +376,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
tts_publisher.publish(message=queue_message) tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text) yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
else: else:
continue continue
@ -412,14 +414,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
db.session.commit() db.session.commit()
db.session.close() db.session.close()
def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
) -> TextChunkStreamResponse:
""" """
Handle completed event. Handle completed event.
:param text: text :param text: text
:return: :return:
""" """
response = TextChunkStreamResponse( response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text) task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
) )
return response return response

View File

@ -90,6 +90,7 @@ class MessageStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE event: StreamEvent = StreamEvent.MESSAGE
id: str id: str
answer: str answer: str
from_variable_selector: Optional[list[str]] = None
class MessageAudioStreamResponse(StreamResponse): class MessageAudioStreamResponse(StreamResponse):
@ -479,6 +480,7 @@ class TextChunkStreamResponse(StreamResponse):
""" """
text: str text: str
from_variable_selector: Optional[list[str]] = None
event: StreamEvent = StreamEvent.TEXT_CHUNK event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data data: Data

View File

@ -153,14 +153,21 @@ class MessageCycleManage:
return None return None
def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse: def _message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
) -> MessageStreamResponse:
""" """
Message to stream response. Message to stream response.
:param answer: answer :param answer: answer
:param message_id: message id :param message_id: message id
:return: :return:
""" """
return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer) return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
)
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
""" """

View File

@ -108,6 +108,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state, route_node_state=event.route_node_state,
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
) )
else: else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk) route_chunk = cast(VarGenerateRouteChunk, route_chunk)