mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-19 12:39:59 +08:00
Feat: add TTS support for SILICONFLOW. (#6264)
### What problem does this PR solve? #6244 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
41e112294b
commit
c6e1a2ca8a
@ -143,6 +143,7 @@ from .tts_model import (
|
||||
SparkTTS,
|
||||
XinferenceTTS,
|
||||
GPUStackTTS,
|
||||
SILICONFLOWTTS
|
||||
)
|
||||
|
||||
EmbeddingModel = {
|
||||
@ -278,4 +279,5 @@ TTSModel = {
|
||||
"XunFei Spark": SparkTTS,
|
||||
"Xinference": XinferenceTTS,
|
||||
"GPUStack": GPUStackTTS,
|
||||
"SILICONFLOW": SILICONFLOWTTS,
|
||||
}
|
||||
|
@ -356,6 +356,7 @@ class OllamaTTS(Base):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
|
||||
class GPUStackTTS:
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
self.base_url = kwargs.get("base_url", None)
|
||||
@ -386,4 +387,38 @@ class GPUStackTTS:
|
||||
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
yield chunk
|
||||
yield chunk
|
||||
|
||||
|
||||
class SILICONFLOWTTS(Base):
|
||||
def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1"
|
||||
self.api_key = key
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
def tts(self, text, voice="anna"):
|
||||
text = self.normalize_text(text)
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": text,
|
||||
"voice": f"{self.model_name}:{voice}",
|
||||
"response_format": "mp3",
|
||||
"sample_rate": 123,
|
||||
"stream": True,
|
||||
"speed": 1,
|
||||
"gain": 0
|
||||
}
|
||||
|
||||
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
for chunk in response.iter_content():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
Loading…
x
Reference in New Issue
Block a user