mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 04:59:01 +08:00
OpenAITTS (#2493)
### What problem does this PR solve? OpenAITTS ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
parent
af0b4b0828
commit
d545633a6c
@ -77,6 +77,12 @@
|
|||||||
"tags": "LLM,CHAT,IMAGE2TEXT",
|
"tags": "LLM,CHAT,IMAGE2TEXT",
|
||||||
"max_tokens": 765,
|
"max_tokens": 765,
|
||||||
"model_type": "image2text"
|
"model_type": "image2text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "tts-1",
|
||||||
|
"tags": "TTS",
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"model_type": "tts"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -138,5 +138,6 @@ Seq2txtModel = {
|
|||||||
|
|
||||||
TTSModel = {
|
TTSModel = {
|
||||||
"Fish Audio": FishAudioTTS,
|
"Fish Audio": FishAudioTTS,
|
||||||
"Tongyi-Qianwen": QwenTTS
|
"Tongyi-Qianwen": QwenTTS,
|
||||||
|
"OpenAI":OpenAITTS
|
||||||
}
|
}
|
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import requests
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
import httpx
|
import httpx
|
||||||
@ -23,6 +24,8 @@ from rag.utils import num_tokens_from_string
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
class ServeReferenceAudio(BaseModel):
|
class ServeReferenceAudio(BaseModel):
|
||||||
audio: bytes
|
audio: bytes
|
||||||
text: str
|
text: str
|
||||||
@ -52,7 +55,7 @@ class Base(ABC):
|
|||||||
|
|
||||||
def tts(self, audio):
|
def tts(self, audio):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def normalize_text(self, text):
|
def normalize_text(self, text):
|
||||||
return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
|
return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
|
||||||
|
|
||||||
@ -78,13 +81,13 @@ class FishAudioTTS(Base):
|
|||||||
with httpx.Client() as client:
|
with httpx.Client() as client:
|
||||||
try:
|
try:
|
||||||
with client.stream(
|
with client.stream(
|
||||||
method="POST",
|
method="POST",
|
||||||
url=self.base_url,
|
url=self.base_url,
|
||||||
content=ormsgpack.packb(
|
content=ormsgpack.packb(
|
||||||
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
||||||
),
|
),
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
) as response:
|
) as response:
|
||||||
if response.status_code == HTTPStatus.OK:
|
if response.status_code == HTTPStatus.OK:
|
||||||
for chunk in response.iter_bytes():
|
for chunk in response.iter_bytes():
|
||||||
@ -101,7 +104,7 @@ class FishAudioTTS(Base):
|
|||||||
class QwenTTS(Base):
|
class QwenTTS(Base):
|
||||||
def __init__(self, key, model_name, base_url=""):
|
def __init__(self, key, model_name, base_url=""):
|
||||||
import dashscope
|
import dashscope
|
||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
dashscope.api_key = key
|
dashscope.api_key = key
|
||||||
|
|
||||||
@ -109,11 +112,11 @@ class QwenTTS(Base):
|
|||||||
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
||||||
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
|
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
class Callback(ResultCallback):
|
class Callback(ResultCallback):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.dque = deque()
|
self.dque = deque()
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
while True:
|
while True:
|
||||||
if not self.dque:
|
if not self.dque:
|
||||||
@ -144,13 +147,40 @@ class QwenTTS(Base):
|
|||||||
text = self.normalize_text(text)
|
text = self.normalize_text(text)
|
||||||
callback = Callback()
|
callback = Callback()
|
||||||
SpeechSynthesizer.call(model=self.model_name,
|
SpeechSynthesizer.call(model=self.model_name,
|
||||||
text=text,
|
text=text,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
format="mp3")
|
format="mp3")
|
||||||
try:
|
try:
|
||||||
for data in callback._run():
|
for data in callback._run():
|
||||||
yield data
|
yield data
|
||||||
yield num_tokens_from_string(text)
|
yield num_tokens_from_string(text)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"**ERROR**: {e}")
|
raise RuntimeError(f"**ERROR**: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAITTS(Base):
    """Text-to-speech client for an OpenAI-style ``/audio/speech`` endpoint.

    The request is sent with ``requests`` in streaming mode and the audio is
    handed back to the caller chunk by chunk as it arrives.
    """

    def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
        """Store the credentials and build the request headers.

        Args:
            key: API key, sent as a ``Bearer`` token in ``Authorization``.
            model_name: Model identifier placed in the request payload.
            base_url: Root URL of the API. An empty/``None`` value falls back
                to the public OpenAI endpoint.
        """
        # Sibling providers in this module use ``base_url=""`` as their
        # default; without this fallback an empty value would yield the
        # relative URL "/audio/speech" and the request would fail.
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def tts(self, text, voice="alloy"):
        """Synthesize ``text`` and yield the resulting audio as byte chunks.

        This is a generator: nothing is sent until the first item is
        requested, and errors surface at that point rather than at call time.

        Args:
            text: Text to speak. It is first passed through
                ``self.normalize_text``.
            voice: Voice name forwarded in the ``voice`` field of the payload.

        Yields:
            Non-empty ``bytes`` chunks of the response body (up to 1024 bytes
            each).

        Raises:
            Exception: If the server answers with a status other than 200;
                the message carries the status code and response text.
        """
        text = self.normalize_text(text)
        payload = {
            "model": self.model_name,
            "voice": voice,
            "input": text
        }

        # Strip a trailing slash so "https://host/v1/" does not become
        # "https://host/v1//audio/speech".
        url = f"{self.base_url.rstrip('/')}/audio/speech"

        # With stream=True the connection stays open until the body is fully
        # consumed or the response is closed. The context manager guarantees
        # it is released on the error path and when the consumer stops
        # iterating early (generator close).
        with requests.post(url, headers=self.headers, json=payload, stream=True) as response:
            if response.status_code != 200:
                raise Exception(f"**Error**: {response.status_code}, {response.text}")
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
|
||||||
|
Loading…
x
Reference in New Issue
Block a user