diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 891eccf52..4dc9ed3bb 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -104,18 +104,24 @@ "max_tokens": 2048, "model_type": "embedding" }, + { + "llm_name": "sambert-zhide-v1", + "tags": "TTS", + "max_tokens": 2048, + "model_type": "tts" + }, + { + "llm_name": "sambert-zhiru-v1", + "tags": "TTS", + "max_tokens": 2048, + "model_type": "tts" + }, { "llm_name": "text-embedding-v3", "tags": "TEXT EMBEDDING,8K", "max_tokens": 8192, "model_type": "embedding" }, - { - "llm_name": "paraformer-realtime-8k-v1", - "tags": "SPEECH2TEXT", - "max_tokens": 26214400, - "model_type": "speech2text" - }, { "llm_name": "qwen-vl-max", "tags": "LLM,CHAT,IMAGE2TEXT", diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index c3b30b1c9..8868fce58 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -137,5 +137,6 @@ Seq2txtModel = { } TTSModel = { - "Fish Audio": FishAudioTTS + "Fish Audio": FishAudioTTS, + "Tongyi-Qianwen": QwenTTS } \ No newline at end of file diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index c477313a9..8bf930caa 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -22,7 +22,7 @@ from pydantic import BaseModel, conint from rag.utils import num_tokens_from_string import json import re - +import time class ServeReferenceAudio(BaseModel): audio: bytes text: str @@ -96,3 +96,61 @@ class FishAudioTTS(Base): except httpx.HTTPStatusError as e: raise RuntimeError(f"**ERROR**: {e}") + + +class QwenTTS(Base): + def __init__(self, key, model_name, base_url=""): + import dashscope + + self.model_name = model_name + dashscope.api_key = key + + def tts(self, text): + from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse + from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult + from collections import deque + + class Callback(ResultCallback): + def __init__(self) -> None: + self.dque = deque() + + def _run(self): + while True: + if not self.dque: + time.sleep(0) + continue + val = self.dque.popleft() + if val: + yield val + else: + break + + def on_open(self): + pass + + def on_complete(self): + self.dque.append(None) + + def on_error(self, response: SpeechSynthesisResponse): + raise RuntimeError(str(response)) + + def on_close(self): + pass + + def on_event(self, result: SpeechSynthesisResult): + if result.get_audio_frame() is not None: + self.dque.append(result.get_audio_frame()) + + text = self.normalize_text(text) + callback = Callback() + SpeechSynthesizer.call(model=self.model_name, + text=text, + callback=callback, + format="mp3") + try: + for data in callback._run(): + yield data + yield num_tokens_from_string(text) + + except Exception as e: + raise RuntimeError(f"**ERROR**: {e}") \ No newline at end of file