mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-01 20:10:41 +08:00
Feat: add TTS support for SILICONFLOW. (#6264)
### What problem does this PR solve? #6244 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
41e112294b
commit
c6e1a2ca8a
@ -143,6 +143,7 @@ from .tts_model import (
|
|||||||
SparkTTS,
|
SparkTTS,
|
||||||
XinferenceTTS,
|
XinferenceTTS,
|
||||||
GPUStackTTS,
|
GPUStackTTS,
|
||||||
|
SILICONFLOWTTS
|
||||||
)
|
)
|
||||||
|
|
||||||
EmbeddingModel = {
|
EmbeddingModel = {
|
||||||
@ -278,4 +279,5 @@ TTSModel = {
|
|||||||
"XunFei Spark": SparkTTS,
|
"XunFei Spark": SparkTTS,
|
||||||
"Xinference": XinferenceTTS,
|
"Xinference": XinferenceTTS,
|
||||||
"GPUStack": GPUStackTTS,
|
"GPUStack": GPUStackTTS,
|
||||||
|
"SILICONFLOW": SILICONFLOWTTS,
|
||||||
}
|
}
|
||||||
|
@ -356,6 +356,7 @@ class OllamaTTS(Base):
|
|||||||
if chunk:
|
if chunk:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
class GPUStackTTS:
|
class GPUStackTTS:
|
||||||
def __init__(self, key, model_name, **kwargs):
|
def __init__(self, key, model_name, **kwargs):
|
||||||
self.base_url = kwargs.get("base_url", None)
|
self.base_url = kwargs.get("base_url", None)
|
||||||
@ -386,4 +387,38 @@ class GPUStackTTS:
|
|||||||
|
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
for chunk in response.iter_content(chunk_size=1024):
|
||||||
if chunk:
|
if chunk:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
class SILICONFLOWTTS(Base):
|
||||||
|
def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
|
||||||
|
if not base_url:
|
||||||
|
base_url = "https://api.siliconflow.cn/v1"
|
||||||
|
self.api_key = key
|
||||||
|
self.model_name = model_name
|
||||||
|
self.base_url = base_url
|
||||||
|
self.headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
def tts(self, text, voice="anna"):
|
||||||
|
text = self.normalize_text(text)
|
||||||
|
payload = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"input": text,
|
||||||
|
"voice": f"{self.model_name}:{voice}",
|
||||||
|
"response_format": "mp3",
|
||||||
|
"sample_rate": 123,
|
||||||
|
"stream": True,
|
||||||
|
"speed": 1,
|
||||||
|
"gain": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||||
|
for chunk in response.iter_content():
|
||||||
|
if chunk:
|
||||||
|
yield chunk
|
||||||
|
Loading…
x
Reference in New Issue
Block a user