fix: enhance type hints and improve audio message handling in TTS publisher… (#11947)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-22 10:41:06 +08:00 committed by GitHub
parent 90f093eb67
commit dd0e81d094
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 36 additions and 29 deletions

View File

@ -4,14 +4,17 @@ import logging
import queue import queue
import re import re
import threading import threading
from collections.abc import Iterable
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentMessageEvent, QueueAgentMessageEvent,
QueueLLMChunkEvent, QueueLLMChunkEvent,
QueueNodeSucceededEvent, QueueNodeSucceededEvent,
QueueTextChunkEvent, QueueTextChunkEvent,
WorkflowQueueMessage,
) )
from core.model_manager import ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -21,7 +24,7 @@ class AudioTrunk:
self.status = status self.status = status
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
if not text_content or text_content.isspace(): if not text_content or text_content.isspace():
return return
return model_instance.invoke_tts( return model_instance.invoke_tts(
@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
) )
def _process_future(future_queue, audio_queue): def _process_future(
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
audio_queue: queue.Queue[AudioTrunk],
):
while True: while True:
try: try:
future = future_queue.get() future = future_queue.get()
if future is None: if future is None:
break break
for audio in future.result(): invoke_result = future.result()
if not invoke_result:
continue
for audio in invoke_result:
audio_base64 = base64.b64encode(bytes(audio)) audio_base64 = base64.b64encode(bytes(audio))
audio_queue.put(AudioTrunk("responding", audio=audio_base64)) audio_queue.put(AudioTrunk("responding", audio=audio_base64))
except Exception as e: except Exception as e:
@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.msg_text = "" self.msg_text = ""
self._audio_queue = queue.Queue() self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
self._msg_queue = queue.Queue() self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self.match = re.compile(r"[。.!?]") self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager() self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance( self.model_instance = self.model_manager.get_default_model_instance(
@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
self._runtime_thread = threading.Thread(target=self._runtime).start() self._runtime_thread = threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message): def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
try:
self._msg_queue.put(message) self._msg_queue.put(message)
except Exception as e:
self.logger.warning(e)
def _runtime(self): def _runtime(self):
future_queue = queue.Queue() future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start() threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
while True: while True:
try: try:
@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
break break
future_queue.put(None) future_queue.put(None)
def check_and_get_audio(self) -> AudioTrunk | None: def check_and_get_audio(self):
try: try:
if self._last_audio_event and self._last_audio_event.status == "finish": if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor: if self.executor:

View File

@ -197,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream_response=stream_response, stream_response=stream_response,
) )
def _listen_audio_msg(self, publisher, task_id: str): def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher: if not publisher:
return None return None
audio_msg: AudioTrunk = publisher.check_and_get_audio() audio_msg = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish": if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None return None
@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True: while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response: if audio_response:
yield audio_response yield audio_response
else: else:
@ -511,7 +511,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# only publish tts message at text chunk streaming # only publish tts message at text chunk streaming
if tts_publisher: if tts_publisher:
tts_publisher.publish(message=queue_message) tts_publisher.publish(queue_message)
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._message_to_stream_response( yield self._message_to_stream_response(

View File

@ -1,7 +1,6 @@
import queue import queue
import time import time
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator
from enum import Enum from enum import Enum
from typing import Any from typing import Any
@ -11,9 +10,11 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
MessageQueueMessage,
QueueErrorEvent, QueueErrorEvent,
QueuePingEvent, QueuePingEvent,
QueueStopEvent, QueueStopEvent,
WorkflowQueueMessage,
) )
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
@ -37,11 +38,11 @@ class AppQueueManager:
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
) )
q = queue.Queue() q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q self._q = q
def listen(self) -> Generator: def listen(self):
""" """
Listen to queue Listen to queue
:return: :return:

View File

@ -171,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listen_audio_msg(self, publisher, task_id: str): def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher: if not publisher:
return None return None
audio_msg: AudioTrunk = publisher.check_and_get_audio() audio_msg = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish": if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None return None
@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True: while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response: if audio_response:
yield audio_response yield audio_response
else: else:
@ -421,7 +421,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# only publish tts message at text chunk streaming # only publish tts message at text chunk streaming
if tts_publisher: if tts_publisher:
tts_publisher.publish(message=queue_message) tts_publisher.publish(queue_message)
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response( yield self._text_chunk_to_stream_response(

View File

@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response, stream_response=stream_response,
) )
def _listen_audio_msg(self, publisher, task_id: str): def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if publisher is None: if publisher is None:
return None return None
audio_msg: AudioTrunk = publisher.check_and_get_audio() audio_msg = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish": if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore') # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None return None