### What problem does this PR solve?

OpenAITTS

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
liuhua 2024-09-19 16:55:18 +08:00 committed by GitHub
parent af0b4b0828
commit d545633a6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 17 deletions

View File

@ -77,6 +77,12 @@
"tags": "LLM,CHAT,IMAGE2TEXT", "tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 765, "max_tokens": 765,
"model_type": "image2text" "model_type": "image2text"
},
{
"llm_name": "tts-1",
"tags": "TTS",
"max_tokens": 2048,
"model_type": "tts"
} }
] ]
}, },

View File

@ -138,5 +138,6 @@ Seq2txtModel = {
TTSModel = { TTSModel = {
"Fish Audio": FishAudioTTS, "Fish Audio": FishAudioTTS,
"Tongyi-Qianwen": QwenTTS "Tongyi-Qianwen": QwenTTS,
"OpenAI":OpenAITTS
} }

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import requests
from typing import Annotated, Literal from typing import Annotated, Literal
from abc import ABC from abc import ABC
import httpx import httpx
@ -23,6 +24,8 @@ from rag.utils import num_tokens_from_string
import json import json
import re import re
import time import time
class ServeReferenceAudio(BaseModel): class ServeReferenceAudio(BaseModel):
audio: bytes audio: bytes
text: str text: str
@ -52,7 +55,7 @@ class Base(ABC):
def tts(self, audio): def tts(self, audio):
pass pass
def normalize_text(self, text): def normalize_text(self, text):
return re.sub(r'(\*\*|##\d+\$\$|#)', '', text) return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
@ -78,13 +81,13 @@ class FishAudioTTS(Base):
with httpx.Client() as client: with httpx.Client() as client:
try: try:
with client.stream( with client.stream(
method="POST", method="POST",
url=self.base_url, url=self.base_url,
content=ormsgpack.packb( content=ormsgpack.packb(
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
), ),
headers=self.headers, headers=self.headers,
timeout=None, timeout=None,
) as response: ) as response:
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
for chunk in response.iter_bytes(): for chunk in response.iter_bytes():
@ -101,7 +104,7 @@ class FishAudioTTS(Base):
class QwenTTS(Base): class QwenTTS(Base):
def __init__(self, key, model_name, base_url=""): def __init__(self, key, model_name, base_url=""):
import dashscope import dashscope
self.model_name = model_name self.model_name = model_name
dashscope.api_key = key dashscope.api_key = key
@ -109,11 +112,11 @@ class QwenTTS(Base):
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
from collections import deque from collections import deque
class Callback(ResultCallback): class Callback(ResultCallback):
def __init__(self) -> None: def __init__(self) -> None:
self.dque = deque() self.dque = deque()
def _run(self): def _run(self):
while True: while True:
if not self.dque: if not self.dque:
@ -144,13 +147,40 @@ class QwenTTS(Base):
text = self.normalize_text(text) text = self.normalize_text(text)
callback = Callback() callback = Callback()
SpeechSynthesizer.call(model=self.model_name, SpeechSynthesizer.call(model=self.model_name,
text=text, text=text,
callback=callback, callback=callback,
format="mp3") format="mp3")
try: try:
for data in callback._run(): for data in callback._run():
yield data yield data
yield num_tokens_from_string(text) yield num_tokens_from_string(text)
except Exception as e: except Exception as e:
raise RuntimeError(f"**ERROR**: {e}") raise RuntimeError(f"**ERROR**: {e}")
class OpenAITTS(Base):
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
def tts(self, text, voice="alloy"):
text = self.normalize_text(text)
payload = {
"model": self.model_name,
"voice": voice,
"input": text
}
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content(chunk_size=1024):
if chunk:
yield chunk