fix: enhance type hints and improve audio message handling in TTS pub… (#11947)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Authored by -LAN- on 2024-12-22 10:41:06 +08:00; committed by GitHub
parent 90f093eb67
commit dd0e81d094
5 changed files with 36 additions and 29 deletions
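The heart of the change is the TTS publisher's producer/consumer pipeline: its message and audio queues gain explicit element types, and TTS invocations that return nothing are now skipped instead of breaking the thread that drains the futures. Below is a minimal, self-contained sketch of that shape; the names are simplified stand-ins for illustration, not the actual Dify classes.

import base64
import concurrent.futures
import queue
import threading
from collections.abc import Iterable


class AudioChunk:
    """Stand-in for AudioTrunk: a status plus an optional base64 payload."""

    def __init__(self, status: str, audio: bytes | None = None):
        self.status = status
        self.audio = audio


def fake_tts(text: str) -> Iterable[bytes] | None:
    """Stand-in for ModelInstance.invoke_tts: blank text yields nothing."""
    if not text or text.isspace():
        return None
    return [text.encode("utf-8")]


def drain_futures(
    future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
    audio_queue: queue.Queue[AudioChunk],
) -> None:
    """Consume TTS futures; skip empty results; push base64 chunks downstream."""
    while True:
        future = future_queue.get()
        if future is None:  # sentinel: the producer is finished
            break
        result = future.result()
        if not result:  # the TTS call may return None for blank text
            continue
        for audio in result:
            audio_queue.put(AudioChunk("responding", audio=base64.b64encode(bytes(audio))))
    audio_queue.put(AudioChunk("finish"))


if __name__ == "__main__":
    futures: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
    audio: queue.Queue[AudioChunk] = queue.Queue()
    threading.Thread(target=drain_futures, args=(futures, audio)).start()

    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        for sentence in ["Hello.", "   ", "World."]:
            futures.put(executor.submit(fake_tts, sentence))
        futures.put(None)  # nothing more to come

    while (chunk := audio.get()).status != "finish":
        print(chunk.status, chunk.audio)

The None sentinel on the future queue mirrors how the real publisher signals the end of the stream (future_queue.put(None) in the diff below).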

View File

@@ -4,14 +4,17 @@ import logging
 import queue
 import re
 import threading
+from collections.abc import Iterable
 from core.app.entities.queue_entities import (
+    MessageQueueMessage,
     QueueAgentMessageEvent,
     QueueLLMChunkEvent,
     QueueNodeSucceededEvent,
     QueueTextChunkEvent,
+    WorkflowQueueMessage,
 )
-from core.model_manager import ModelManager
+from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -21,7 +24,7 @@ class AudioTrunk:
         self.status = status
-def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
+def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
     if not text_content or text_content.isspace():
         return
     return model_instance.invoke_tts(
@@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
     )
-def _process_future(future_queue, audio_queue):
+def _process_future(
+    future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
+    audio_queue: queue.Queue[AudioTrunk],
+):
     while True:
         try:
             future = future_queue.get()
             if future is None:
                 break
-            for audio in future.result():
+            invoke_result = future.result()
+            if not invoke_result:
+                continue
+            for audio in invoke_result:
                 audio_base64 = base64.b64encode(bytes(audio))
                 audio_queue.put(AudioTrunk("responding", audio=audio_base64))
         except Exception as e:
@@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
         self.logger = logging.getLogger(__name__)
         self.tenant_id = tenant_id
         self.msg_text = ""
-        self._audio_queue = queue.Queue()
-        self._msg_queue = queue.Queue()
+        self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
+        self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
         self.match = re.compile(r"[。.!?]")
         self.model_manager = ModelManager()
         self.model_instance = self.model_manager.get_default_model_instance(
@@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
         self._runtime_thread = threading.Thread(target=self._runtime).start()
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
-    def publish(self, message):
-        try:
-            self._msg_queue.put(message)
-        except Exception as e:
-            self.logger.warning(e)
+    def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
+        self._msg_queue.put(message)
     def _runtime(self):
-        future_queue = queue.Queue()
+        future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
         threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
         while True:
             try:
@@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
                 break
         future_queue.put(None)
-    def check_and_get_audio(self) -> AudioTrunk | None:
+    def check_and_get_audio(self):
         try:
             if self._last_audio_event and self._last_audio_event.status == "finish":
                 if self.executor:
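Note the trailing / in the new publish signature above: it makes message positional-only, which is why the call sites in the pipelines further down change from tts_publisher.publish(message=queue_message) to tts_publisher.publish(queue_message). A quick standalone illustration of the semantics (not the publisher itself):

def publish(message: str | None, /) -> None:
    """Everything before the / is positional-only."""
    print(message)


publish("hello")  # OK: passed positionally
try:
    publish(message="hello")
except TypeError as exc:
    # publish() got some positional-only arguments passed as keyword arguments: 'message'
    print(exc)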

View File

@@ -197,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             stream_response=stream_response,
         )
-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if not publisher:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
@@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:
@@ -511,7 +511,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 # only publish tts message at text chunk streaming
                 if tts_publisher:
-                    tts_publisher.publish(message=queue_message)
+                    tts_publisher.publish(queue_message)
                 self._task_state.answer += delta_text
                 yield self._message_to_stream_response(
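On the pipeline side, _listen_audio_msg is polled in a loop until the publisher has no audio ready, and because check_and_get_audio no longer declares a return type, callers narrow the result with isinstance before reading .status or .audio. A rough sketch of that polling loop with simplified stand-in types (a plain dict stands in for MessageAudioStreamResponse):

import queue


class AudioChunk:
    """Stand-in for AudioTrunk."""

    def __init__(self, status: str, audio: bytes | None = None):
        self.status = status
        self.audio = audio


class Publisher:
    """Stand-in for the polling surface of AppGeneratorTTSPublisher."""

    def __init__(self) -> None:
        self._audio_queue: queue.Queue[AudioChunk] = queue.Queue()

    def check_and_get_audio(self):
        # Unannotated return, as in the commit: an AudioChunk or None.
        try:
            return self._audio_queue.get_nowait()
        except queue.Empty:
            return None


def listen_audio_msg(publisher: Publisher | None, task_id: str):
    if not publisher:
        return None
    audio_msg = publisher.check_and_get_audio()
    # Narrow the untyped result explicitly before touching .status / .audio.
    if audio_msg and isinstance(audio_msg, AudioChunk) and audio_msg.status != "finish":
        return {"task_id": task_id, "audio": audio_msg.audio}
    return None


def stream(publisher: Publisher, task_id: str):
    for response in ["chunk-1", "chunk-2"]:  # stand-in for the text stream
        # Drain whatever audio is already available before the next text chunk.
        while True:
            audio_response = listen_audio_msg(publisher, task_id=task_id)
            if audio_response:
                yield audio_response
            else:
                break
        yield response


for item in stream(Publisher(), task_id="task-1"):
    print(item)

The same polling pattern repeats in the workflow and Easy-UI pipelines below.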

View File

@@ -1,7 +1,6 @@
 import queue
 import time
 from abc import abstractmethod
-from collections.abc import Generator
 from enum import Enum
 from typing import Any
@@ -11,9 +10,11 @@ from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
+    MessageQueueMessage,
     QueueErrorEvent,
     QueuePingEvent,
     QueueStopEvent,
+    WorkflowQueueMessage,
 )
 from extensions.ext_redis import redis_client
@@ -37,11 +38,11 @@ class AppQueueManager:
             AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
         )
-        q = queue.Queue()
+        q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
         self._q = q
-    def listen(self) -> Generator:
+    def listen(self):
         """
         Listen to queue
         :return:

View File

@@ -171,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if not publisher:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
@@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:
@@ -421,7 +421,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 # only publish tts message at text chunk streaming
                 if tts_publisher:
-                    tts_publisher.publish(message=queue_message)
+                    tts_publisher.publish(queue_message)
                 self._task_state.answer += delta_text
                 yield self._text_chunk_to_stream_response(

View File

@@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             stream_response=stream_response,
         )
-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if publisher is None:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None