mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-13 22:35:53 +08:00
SparkTTS (#2535)
### What problem does this PR solve? SparkTTS ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
parent
38e3475714
commit
d9c2a128a5
@ -161,7 +161,10 @@ def add_llm():
|
|||||||
|
|
||||||
elif factory =="XunFei Spark":
|
elif factory =="XunFei Spark":
|
||||||
llm_name = req["llm_name"]
|
llm_name = req["llm_name"]
|
||||||
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
|
if req["model_type"] == "chat":
|
||||||
|
api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
|
||||||
|
elif req["model_type"] == "tts":
|
||||||
|
api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"])
|
||||||
|
|
||||||
elif factory == "BaiduYiyan":
|
elif factory == "BaiduYiyan":
|
||||||
llm_name = req["llm_name"]
|
llm_name = req["llm_name"]
|
||||||
|
@ -139,5 +139,6 @@ Seq2txtModel = {
|
|||||||
TTSModel = {
|
TTSModel = {
|
||||||
"Fish Audio": FishAudioTTS,
|
"Fish Audio": FishAudioTTS,
|
||||||
"Tongyi-Qianwen": QwenTTS,
|
"Tongyi-Qianwen": QwenTTS,
|
||||||
"OpenAI":OpenAITTS
|
"OpenAI":OpenAITTS,
|
||||||
|
"XunFei Spark":SparkTTS
|
||||||
}
|
}
|
@ -14,16 +14,30 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
import requests
|
import _thread as thread
|
||||||
from typing import Annotated, Literal
|
import base64
|
||||||
|
import datetime
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
import re
|
||||||
|
import ssl
|
||||||
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from datetime import datetime
|
||||||
|
from time import mktime
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
from wsgiref.handlers import format_date_time
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import ormsgpack
|
import ormsgpack
|
||||||
|
import requests
|
||||||
|
import websocket
|
||||||
from pydantic import BaseModel, conint
|
from pydantic import BaseModel, conint
|
||||||
|
|
||||||
from rag.utils import num_tokens_from_string
|
from rag.utils import num_tokens_from_string
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
class ServeReferenceAudio(BaseModel):
|
class ServeReferenceAudio(BaseModel):
|
||||||
@ -161,7 +175,7 @@ class QwenTTS(Base):
|
|||||||
|
|
||||||
class OpenAITTS(Base):
|
class OpenAITTS(Base):
|
||||||
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
|
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
|
||||||
if not base_url: base_url="https://api.openai.com/v1"
|
if not base_url: base_url = "https://api.openai.com/v1"
|
||||||
self.api_key = key
|
self.api_key = key
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
@ -185,3 +199,101 @@ class OpenAITTS(Base):
|
|||||||
for chunk in response.iter_content():
|
for chunk in response.iter_content():
|
||||||
if chunk:
|
if chunk:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
class SparkTTS:
    """Text-to-speech client for the XunFei Spark WebSocket TTS endpoint.

    The credential ``key`` is a JSON string carrying ``spark_app_id``,
    ``spark_api_secret`` and ``spark_api_key``.  ``model_name`` is sent to the
    service as the ``vcn`` business argument.  ``base_url`` is accepted for
    signature compatibility with the other TTS classes but is not used.

    The class relies on the module-level ``websocket`` import providing
    ``WebSocketApp`` (presumably the ``websocket-client`` package -- confirm
    that no other distribution shadows the ``websocket`` module).
    """

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        """Parse the JSON credential blob and remember the voice name.

        Args:
            key: JSON string with ``spark_app_id``, ``spark_api_secret`` and
                ``spark_api_key``; missing entries fall back to placeholders.
            model_name: Value sent as the ``vcn`` business argument.
            base_url: Unused; kept so the constructor matches sibling classes.

        Raises:
            json.JSONDecodeError: If ``key`` is not valid JSON.
        """
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Kept only for backward compatibility with code that may read this
        # attribute; ``tts`` buffers audio per call and no longer shares it.
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Build the signed ``wss://`` URL used to open the TTS connection.

        The signature is an HMAC-SHA256 over the ``host``, ``date`` and
        request-line fields, keyed with the API secret; it is embedded in a
        base64-encoded ``authorization`` query parameter together with the
        API key.

        Returns:
            The endpoint URL with ``authorization``, ``date`` and ``host``
            query parameters appended.
        """
        url = 'wss://tts-api.xfyun.cn/v2/tts'
        # NOTE(review): the signed host differs from the URL host; this is
        # preserved from the original code -- confirm against the service docs.
        host = "ws-api.xfyun.cn"
        date = format_date_time(mktime(datetime.now().timetuple()))

        signature_origin = f"host: {host}\ndate: {date}\nGET /v2/tts HTTP/1.1"
        digest = hmac.new(
            self.APISecret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode('utf-8')

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')

        query = {
            "authorization": authorization,
            "date": date,
            "host": host,
        }
        return url + '?' + urlencode(query)

    @staticmethod
    def _parse_message(message):
        """Decode one server frame into ``(audio_bytes, is_last_frame)``.

        The error code is checked before the payload is touched, so a
        service-side failure is reported with the service's own message even
        when the frame carries no ``data`` section.

        Raises:
            Exception: If the frame reports a non-zero ``code``.
        """
        message = json.loads(message)
        code = message["code"]
        if code != 0:
            raise Exception(
                f"sid:{message.get('sid')} call error:{message.get('message')} code:{code}")
        data = message["data"]
        audio = base64.b64decode(data["audio"])
        return audio, data["status"] == SparkTTS.STATUS_LAST_FRAME

    def tts(self, text):
        """Synthesize ``text`` and yield the returned audio chunks.

        The whole exchange runs inside ``run_forever`` (which blocks until the
        connection closes), so chunks are buffered per call and yielded once
        the connection has finished.

        Args:
            text: The text to synthesize; sent base64-encoded as UTF-8.

        Yields:
            ``bytes`` audio chunks in the order the service sent them.

        Raises:
            Exception: If the service reports an error, the connection fails,
                or the connection closes without delivering any audio.
        """
        # presumably "lame" selects MP3 output -- confirm against service docs
        business_args = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000",
                         "vcn": self.model_name, "tte": "utf8"}
        data = {"status": self.STATUS_LAST_FRAME,
                "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
        request = json.dumps({"common": {"app_id": self.APPID},
                              "business": business_args,
                              "data": data})

        # Per-call state: nothing leaks between calls or between instances.
        chunks = []
        errors = []

        def on_open(ws):
            # Single-frame request: everything is sent as soon as we connect.
            ws.send(request)

        def on_message(ws, message):
            try:
                audio, is_last = self._parse_message(message)
            except Exception as e:
                # Record instead of raising: exceptions raised inside
                # callbacks are handled by the websocket library, not by us.
                errors.append(e)
                ws.close()
                return
            chunks.append(audio)
            if is_last:
                ws.close()

        def on_error(ws, error):
            errors.append(error if isinstance(error, Exception) else Exception(error))

        websocket.enableTrace(False)
        ws_app = websocket.WebSocketApp(self.create_url(), on_open=on_open,
                                        on_error=on_error, on_message=on_message)
        # NOTE(review): certificate verification is disabled here (behaviour
        # kept from the original); consider enabling it, since the signed
        # credentials travel in the request URL.
        ws_app.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

        if errors:
            raise errors[0]
        if not chunks:
            raise Exception(
                f"Fail to access model({self.model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
        yield from chunks
|
||||||
|
@ -94,6 +94,8 @@ vertexai==1.64.0
|
|||||||
volcengine==1.0.146
|
volcengine==1.0.146
|
||||||
voyageai==0.2.3
|
voyageai==0.2.3
|
||||||
webdriver_manager==4.0.1
|
webdriver_manager==4.0.1
|
||||||
|
websocket==0.2.1
|
||||||
|
websocket-client==1.8.0
|
||||||
Werkzeug==3.0.3
|
Werkzeug==3.0.3
|
||||||
wikipedia==1.4.0
|
wikipedia==1.4.0
|
||||||
word2number==1.1
|
word2number==1.1
|
||||||
|
@ -551,6 +551,12 @@ The above is the content you need to summarize.`,
|
|||||||
SparkModelNameMessage: 'Please select Spark model',
|
SparkModelNameMessage: 'Please select Spark model',
|
||||||
addSparkAPIPassword: 'Spark APIPassword',
|
addSparkAPIPassword: 'Spark APIPassword',
|
||||||
SparkAPIPasswordMessage: 'please input your APIPassword',
|
SparkAPIPasswordMessage: 'please input your APIPassword',
|
||||||
|
addSparkAPPID: 'Spark APPID',
|
||||||
|
SparkAPPIDMessage: 'please input your APPID',
|
||||||
|
addSparkAPISecret: 'Spark APISecret',
|
||||||
|
SparkAPISecretMessage: 'please input your APISecret',
|
||||||
|
addSparkAPIKey: 'Spark APIKey',
|
||||||
|
SparkAPIKeyMessage: 'please input your APIKey',
|
||||||
yiyanModelNameMessage: 'Please input model name',
|
yiyanModelNameMessage: 'Please input model name',
|
||||||
addyiyanAK: 'yiyan API KEY',
|
addyiyanAK: 'yiyan API KEY',
|
||||||
yiyanAKMessage: 'Please input your API KEY',
|
yiyanAKMessage: 'Please input your API KEY',
|
||||||
|
@ -512,6 +512,12 @@ export default {
|
|||||||
SparkModelNameMessage: '請選擇星火模型!',
|
SparkModelNameMessage: '請選擇星火模型!',
|
||||||
addSparkAPIPassword: '星火 APIPassword',
|
addSparkAPIPassword: '星火 APIPassword',
|
||||||
SparkAPIPasswordMessage: '請輸入 APIPassword',
|
SparkAPIPasswordMessage: '請輸入 APIPassword',
|
||||||
|
addSparkAPPID: '星火 APPID',
|
||||||
|
SparkAPPIDMessage: '請輸入 APPID',
|
||||||
|
addSparkAPISecret: '星火 APISecret',
|
||||||
|
SparkAPISecretMessage: '請輸入 APISecret',
|
||||||
|
addSparkAPIKey: '星火 APIKey',
|
||||||
|
SparkAPIKeyMessage: '請輸入 APIKey',
|
||||||
yiyanModelNameMessage: '輸入模型名稱',
|
yiyanModelNameMessage: '輸入模型名稱',
|
||||||
addyiyanAK: '一言 API KEY',
|
addyiyanAK: '一言 API KEY',
|
||||||
yiyanAKMessage: '請輸入 API KEY',
|
yiyanAKMessage: '請輸入 API KEY',
|
||||||
|
@ -529,6 +529,12 @@ export default {
|
|||||||
SparkModelNameMessage: '请选择星火模型!',
|
SparkModelNameMessage: '请选择星火模型!',
|
||||||
addSparkAPIPassword: '星火 APIPassword',
|
addSparkAPIPassword: '星火 APIPassword',
|
||||||
SparkAPIPasswordMessage: '请输入 APIPassword',
|
SparkAPIPasswordMessage: '请输入 APIPassword',
|
||||||
|
addSparkAPPID: '星火 APPID',
|
||||||
|
SparkAPPIDMessage: '请输入 APPID',
|
||||||
|
addSparkAPISecret: '星火 APISecret',
|
||||||
|
SparkAPISecretMessage: '请输入 APISecret',
|
||||||
|
addSparkAPIKey: '星火 APIKey',
|
||||||
|
SparkAPIKeyMessage: '请输入 APIKey',
|
||||||
yiyanModelNameMessage: '请输入模型名称',
|
yiyanModelNameMessage: '请输入模型名称',
|
||||||
addyiyanAK: '一言 API KEY',
|
addyiyanAK: '一言 API KEY',
|
||||||
yiyanAKMessage: '请输入 API KEY',
|
yiyanAKMessage: '请输入 API KEY',
|
||||||
|
@ -7,6 +7,9 @@ import omit from 'lodash/omit';
|
|||||||
type FieldType = IAddLlmRequestBody & {
|
type FieldType = IAddLlmRequestBody & {
|
||||||
vision: boolean;
|
vision: boolean;
|
||||||
spark_api_password: string;
|
spark_api_password: string;
|
||||||
|
spark_app_id: string;
|
||||||
|
spark_api_secret: string;
|
||||||
|
spark_api_key: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const { Option } = Select;
|
const { Option } = Select;
|
||||||
@ -63,28 +66,67 @@ const SparkModal = ({
|
|||||||
>
|
>
|
||||||
<Select placeholder={t('modelTypeMessage')}>
|
<Select placeholder={t('modelTypeMessage')}>
|
||||||
<Option value="chat">chat</Option>
|
<Option value="chat">chat</Option>
|
||||||
|
<Option value="tts">tts</Option>
|
||||||
</Select>
|
</Select>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item<FieldType>
|
<Form.Item<FieldType>
|
||||||
label={t('modelName')}
|
label={t('modelName')}
|
||||||
name="llm_name"
|
name="llm_name"
|
||||||
initialValue={'Spark-Max'}
|
|
||||||
rules={[{ required: true, message: t('SparkModelNameMessage') }]}
|
rules={[{ required: true, message: t('SparkModelNameMessage') }]}
|
||||||
>
|
>
|
||||||
<Select placeholder={t('modelTypeMessage')}>
|
<Input placeholder={t('modelNameMessage')} />
|
||||||
<Option value="Spark-Max">Spark-Max</Option>
|
|
||||||
<Option value="Spark-Lite">Spark-Lite</Option>
|
|
||||||
<Option value="Spark-Pro">Spark-Pro</Option>
|
|
||||||
<Option value="Spark-Pro-128K">Spark-Pro-128K</Option>
|
|
||||||
<Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option>
|
|
||||||
</Select>
|
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item<FieldType>
|
<Form.Item noStyle dependencies={['model_type']}>
|
||||||
label={t('addSparkAPIPassword')}
|
{({ getFieldValue }) =>
|
||||||
name="spark_api_password"
|
getFieldValue('model_type') === 'chat' && (
|
||||||
rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
|
<Form.Item<FieldType>
|
||||||
>
|
label={t('addSparkAPIPassword')}
|
||||||
<Input placeholder={t('SparkAPIPasswordMessage')} />
|
name="spark_api_password"
|
||||||
|
rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
|
||||||
|
>
|
||||||
|
<Input placeholder={t('SparkAPIPasswordMessage')} />
|
||||||
|
</Form.Item>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item noStyle dependencies={['model_type']}>
|
||||||
|
{({ getFieldValue }) =>
|
||||||
|
getFieldValue('model_type') === 'tts' && (
|
||||||
|
<Form.Item<FieldType>
|
||||||
|
label={t('addSparkAPPID')}
|
||||||
|
name="spark_app_id"
|
||||||
|
rules={[{ required: true, message: t('SparkAPPIDMessage') }]}
|
||||||
|
>
|
||||||
|
<Input placeholder={t('SparkAPPIDMessage')} />
|
||||||
|
</Form.Item>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item noStyle dependencies={['model_type']}>
|
||||||
|
{({ getFieldValue }) =>
|
||||||
|
getFieldValue('model_type') === 'tts' && (
|
||||||
|
<Form.Item<FieldType>
|
||||||
|
label={t('addSparkAPISecret')}
|
||||||
|
name="spark_api_secret"
|
||||||
|
rules={[{ required: true, message: t('SparkAPISecretMessage') }]}
|
||||||
|
>
|
||||||
|
<Input placeholder={t('SparkAPISecretMessage')} />
|
||||||
|
</Form.Item>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item noStyle dependencies={['model_type']}>
|
||||||
|
{({ getFieldValue }) =>
|
||||||
|
getFieldValue('model_type') === 'tts' && (
|
||||||
|
<Form.Item<FieldType>
|
||||||
|
label={t('addSparkAPIKey')}
|
||||||
|
name="spark_api_key"
|
||||||
|
rules={[{ required: true, message: t('SparkAPIKeyMessage') }]}
|
||||||
|
>
|
||||||
|
<Input placeholder={t('SparkAPIKeyMessage')} />
|
||||||
|
</Form.Item>
|
||||||
|
)
|
||||||
|
}
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</Form>
|
</Form>
|
||||||
</Modal>
|
</Modal>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user