From d545633a6ca05b3e9ea6d25362425e5195531e45 Mon Sep 17 00:00:00 2001
From: liuhua <10215101452@stu.ecnu.edu.cn>
Date: Thu, 19 Sep 2024 16:55:18 +0800
Subject: [PATCH] OpenAITTS (#2493)

### What problem does this PR solve?

OpenAITTS

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecnu.edu.cn>
Co-authored-by: Kevin Hu
---
 conf/llm_factories.json |  6 ++++
 rag/llm/__init__.py     |  3 +-
 rag/llm/tts_model.py    | 62 ++++++++++++++++++++++++++++++-----------
 3 files changed, 54 insertions(+), 17 deletions(-)

diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 848ca3b8a..6aece225e 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -77,6 +77,12 @@
                 "tags": "LLM,CHAT,IMAGE2TEXT",
                 "max_tokens": 765,
                 "model_type": "image2text"
+            },
+            {
+                "llm_name": "tts-1",
+                "tags": "TTS",
+                "max_tokens": 2048,
+                "model_type": "tts"
             }
         ]
     },
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 8868fce58..fbcea43b7 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -138,5 +138,6 @@ Seq2txtModel = {
 
 TTSModel = {
     "Fish Audio": FishAudioTTS,
-    "Tongyi-Qianwen": QwenTTS
+    "Tongyi-Qianwen": QwenTTS,
+    "OpenAI":OpenAITTS
 }
\ No newline at end of file
diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py
index 8bf930caa..4ffbc521c 100644
--- a/rag/llm/tts_model.py
+++ b/rag/llm/tts_model.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+import requests
 from typing import Annotated, Literal
 from abc import ABC
 import httpx
@@ -23,6 +24,8 @@ from rag.utils import num_tokens_from_string
 import json
 import re
 import time
+
+
 class ServeReferenceAudio(BaseModel):
     audio: bytes
     text: str
@@ -52,7 +55,7 @@ class Base(ABC):
 
     def tts(self, audio):
         pass
-    
+
     def normalize_text(self, text):
         return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
 
@@ -78,13 +81,13 @@ class FishAudioTTS(Base):
         with httpx.Client() as client:
             try:
                 with client.stream(
-                        method="POST",
-                        url=self.base_url,
-                        content=ormsgpack.packb(
-                            request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
-                        ),
-                        headers=self.headers,
-                        timeout=None,
+                    method="POST",
+                    url=self.base_url,
+                    content=ormsgpack.packb(
+                        request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
+                    ),
+                    headers=self.headers,
+                    timeout=None,
                 ) as response:
                     if response.status_code == HTTPStatus.OK:
                         for chunk in response.iter_bytes():
@@ -101,7 +104,7 @@ class QwenTTS(Base):
     def __init__(self, key, model_name, base_url=""):
         import dashscope
 
-        
+
         self.model_name = model_name
         dashscope.api_key = key
@@ -109,11 +112,11 @@
         from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
         from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
         from collections import deque
-        
+
         class Callback(ResultCallback):
             def __init__(self) -> None:
                 self.dque = deque()
-                
+
             def _run(self):
                 while True:
                     if not self.dque:
@@ -144,13 +147,40 @@
         text = self.normalize_text(text)
         callback = Callback()
         SpeechSynthesizer.call(model=self.model_name,
-                                  text=text,
-                                  callback=callback,
-                                  format="mp3")
+                               text=text,
+                               callback=callback,
+                               format="mp3")
         try:
             for data in callback._run():
                 yield data
             yield num_tokens_from_string(text)
-            
+
         except Exception as e:
-            raise RuntimeError(f"**ERROR**: {e}")
\ No newline at end of file
+            raise RuntimeError(f"**ERROR**: {e}")
+
+
+class OpenAITTS(Base):
+    def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
+        self.api_key = key
+        self.model_name = model_name
+        self.base_url = base_url
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json"
+        }
+
+    def tts(self, text, voice="alloy"):
+        text = self.normalize_text(text)
+        payload = {
+            "model": self.model_name,
+            "voice": voice,
+            "input": text
+        }
+
+        response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
+
+        if response.status_code != 200:
+            raise Exception(f"**Error**: {response.status_code}, {response.text}")
+        for chunk in response.iter_content(chunk_size=1024):
+            if chunk:
+                yield chunk