mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-16 04:55:58 +08:00
fix: enhance type hints and improve audio message handling in TTS pub… (#11947)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
90f093eb67
commit
dd0e81d094
@ -4,14 +4,17 @@ import logging
|
|||||||
import queue
|
import queue
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
|
MessageQueueMessage,
|
||||||
QueueAgentMessageEvent,
|
QueueAgentMessageEvent,
|
||||||
QueueLLMChunkEvent,
|
QueueLLMChunkEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
|
WorkflowQueueMessage,
|
||||||
)
|
)
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
|
||||||
|
|
||||||
@ -21,7 +24,7 @@ class AudioTrunk:
|
|||||||
self.status = status
|
self.status = status
|
||||||
|
|
||||||
|
|
||||||
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
|
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
|
||||||
if not text_content or text_content.isspace():
|
if not text_content or text_content.isspace():
|
||||||
return
|
return
|
||||||
return model_instance.invoke_tts(
|
return model_instance.invoke_tts(
|
||||||
@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _process_future(future_queue, audio_queue):
|
def _process_future(
|
||||||
|
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
|
||||||
|
audio_queue: queue.Queue[AudioTrunk],
|
||||||
|
):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
future = future_queue.get()
|
future = future_queue.get()
|
||||||
if future is None:
|
if future is None:
|
||||||
break
|
break
|
||||||
for audio in future.result():
|
invoke_result = future.result()
|
||||||
|
if not invoke_result:
|
||||||
|
continue
|
||||||
|
for audio in invoke_result:
|
||||||
audio_base64 = base64.b64encode(bytes(audio))
|
audio_base64 = base64.b64encode(bytes(audio))
|
||||||
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
|
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
|
|||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.msg_text = ""
|
self.msg_text = ""
|
||||||
self._audio_queue = queue.Queue()
|
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
|
||||||
self._msg_queue = queue.Queue()
|
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||||
self.match = re.compile(r"[。.!?]")
|
self.match = re.compile(r"[。.!?]")
|
||||||
self.model_manager = ModelManager()
|
self.model_manager = ModelManager()
|
||||||
self.model_instance = self.model_manager.get_default_model_instance(
|
self.model_instance = self.model_manager.get_default_model_instance(
|
||||||
@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
|
|||||||
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
||||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||||
|
|
||||||
def publish(self, message):
|
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
|
||||||
try:
|
self._msg_queue.put(message)
|
||||||
self._msg_queue.put(message)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(e)
|
|
||||||
|
|
||||||
def _runtime(self):
|
def _runtime(self):
|
||||||
future_queue = queue.Queue()
|
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
|
||||||
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
|
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
|
|||||||
break
|
break
|
||||||
future_queue.put(None)
|
future_queue.put(None)
|
||||||
|
|
||||||
def check_and_get_audio(self) -> AudioTrunk | None:
|
def check_and_get_audio(self):
|
||||||
try:
|
try:
|
||||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||||
if self.executor:
|
if self.executor:
|
||||||
|
@ -197,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
stream_response=stream_response,
|
stream_response=stream_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _listen_audio_msg(self, publisher, task_id: str):
|
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
|
||||||
if not publisher:
|
if not publisher:
|
||||||
return None
|
return None
|
||||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
audio_msg = publisher.check_and_get_audio()
|
||||||
if audio_msg and audio_msg.status != "finish":
|
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
|
||||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
|
|
||||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||||
while True:
|
while True:
|
||||||
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
|
||||||
if audio_response:
|
if audio_response:
|
||||||
yield audio_response
|
yield audio_response
|
||||||
else:
|
else:
|
||||||
@ -511,7 +511,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
|
|
||||||
# only publish tts message at text chunk streaming
|
# only publish tts message at text chunk streaming
|
||||||
if tts_publisher:
|
if tts_publisher:
|
||||||
tts_publisher.publish(message=queue_message)
|
tts_publisher.publish(queue_message)
|
||||||
|
|
||||||
self._task_state.answer += delta_text
|
self._task_state.answer += delta_text
|
||||||
yield self._message_to_stream_response(
|
yield self._message_to_stream_response(
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import queue
|
import queue
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Generator
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -11,9 +10,11 @@ from configs import dify_config
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
AppQueueEvent,
|
AppQueueEvent,
|
||||||
|
MessageQueueMessage,
|
||||||
QueueErrorEvent,
|
QueueErrorEvent,
|
||||||
QueuePingEvent,
|
QueuePingEvent,
|
||||||
QueueStopEvent,
|
QueueStopEvent,
|
||||||
|
WorkflowQueueMessage,
|
||||||
)
|
)
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
@ -37,11 +38,11 @@ class AppQueueManager:
|
|||||||
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
q = queue.Queue()
|
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||||
|
|
||||||
self._q = q
|
self._q = q
|
||||||
|
|
||||||
def listen(self) -> Generator:
|
def listen(self):
|
||||||
"""
|
"""
|
||||||
Listen to queue
|
Listen to queue
|
||||||
:return:
|
:return:
|
||||||
|
@ -171,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
|
|
||||||
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
|
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
|
||||||
|
|
||||||
def _listen_audio_msg(self, publisher, task_id: str):
|
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
|
||||||
if not publisher:
|
if not publisher:
|
||||||
return None
|
return None
|
||||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
audio_msg = publisher.check_and_get_audio()
|
||||||
if audio_msg and audio_msg.status != "finish":
|
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
|
||||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
|
|
||||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||||
while True:
|
while True:
|
||||||
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
|
||||||
if audio_response:
|
if audio_response:
|
||||||
yield audio_response
|
yield audio_response
|
||||||
else:
|
else:
|
||||||
@ -421,7 +421,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
|
|
||||||
# only publish tts message at text chunk streaming
|
# only publish tts message at text chunk streaming
|
||||||
if tts_publisher:
|
if tts_publisher:
|
||||||
tts_publisher.publish(message=queue_message)
|
tts_publisher.publish(queue_message)
|
||||||
|
|
||||||
self._task_state.answer += delta_text
|
self._task_state.answer += delta_text
|
||||||
yield self._text_chunk_to_stream_response(
|
yield self._text_chunk_to_stream_response(
|
||||||
|
@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
stream_response=stream_response,
|
stream_response=stream_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _listen_audio_msg(self, publisher, task_id: str):
|
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
|
||||||
if publisher is None:
|
if publisher is None:
|
||||||
return None
|
return None
|
||||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
audio_msg = publisher.check_and_get_audio()
|
||||||
if audio_msg and audio_msg.status != "finish":
|
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
|
||||||
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
|
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
|
||||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||||
return None
|
return None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user