mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 04:36:01 +08:00
SparkTTS (#2535)
### What problem does this PR solve?

SparkTTS

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
parent
38e3475714
commit
d9c2a128a5
@ -161,7 +161,10 @@ def add_llm():
|
||||
|
||||
elif factory =="XunFei Spark":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
|
||||
if req["model_type"] == "chat":
|
||||
api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
|
||||
elif req["model_type"] == "tts":
|
||||
api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"])
|
||||
|
||||
elif factory == "BaiduYiyan":
|
||||
llm_name = req["llm_name"]
|
||||
|
@ -139,5 +139,6 @@ Seq2txtModel = {
|
||||
TTSModel = {
|
||||
"Fish Audio": FishAudioTTS,
|
||||
"Tongyi-Qianwen": QwenTTS,
|
||||
"OpenAI":OpenAITTS
|
||||
"OpenAI":OpenAITTS,
|
||||
"XunFei Spark":SparkTTS
|
||||
}
|
@ -14,16 +14,30 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import requests
|
||||
from typing import Annotated, Literal
|
||||
import _thread as thread
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import queue
|
||||
import re
|
||||
import ssl
|
||||
import time
|
||||
from abc import ABC
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from typing import Annotated, Literal
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import httpx
|
||||
import ormsgpack
|
||||
import requests
|
||||
import websocket
|
||||
from pydantic import BaseModel, conint
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
|
||||
@ -161,7 +175,7 @@ class QwenTTS(Base):
|
||||
|
||||
class OpenAITTS(Base):
|
||||
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
|
||||
if not base_url: base_url="https://api.openai.com/v1"
|
||||
if not base_url: base_url = "https://api.openai.com/v1"
|
||||
self.api_key = key
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
@ -185,3 +199,101 @@ class OpenAITTS(Base):
|
||||
for chunk in response.iter_content():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
|
||||
class SparkTTS:
    """Text-to-speech client for the XunFei Spark WebSocket TTS endpoint.

    ``key`` is a JSON string carrying ``spark_app_id``, ``spark_api_secret``
    and ``spark_api_key``; ``model_name`` is sent to the service as the
    ``vcn`` business argument.
    """

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        """Parse the JSON credential bundle and remember the model name.

        Args:
            key: JSON object (as a string) with the three Spark credentials.
                Missing entries fall back to placeholder values.
            model_name: Value sent as ``vcn`` in every request.
            base_url: Accepted for signature compatibility; not used.
        """
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Kept so existing attribute access keeps working. tts() buffers audio
        # per call instead, so an abandoned generator cannot leak chunks into
        # the next request.
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Build the signed ``wss://`` URL used to open the TTS connection.

        The request line, host and current date are signed with HMAC-SHA256
        using the API secret; the base64-encoded authorization string, the
        date and the host are appended as query parameters.

        Returns:
            The full WebSocket URL including the query string.
        """
        url = 'wss://tts-api.xfyun.cn/v2/tts'
        # NOTE(review): the signed host differs from the host in the URL; this
        # is reproduced unchanged from the original code -- confirm against
        # the vendor documentation before altering it.
        signed_host = "ws-api.xfyun.cn"
        date = format_date_time(mktime(datetime.now().timetuple()))

        signature_origin = f"host: {signed_host}\ndate: {date}\nGET /v2/tts HTTP/1.1"
        digest = hmac.new(
            self.APISecret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        query = {
            "authorization": authorization,
            "date": date,
            "host": signed_host,
        }
        return url + '?' + urlencode(query)

    @staticmethod
    def _decode_frame(raw):
        """Decode one JSON frame received from the service.

        The ``code`` field is checked before ``data`` is touched, so an error
        frame that carries no ``data`` is reported as the service error rather
        than failing on the missing audio.

        Args:
            raw: JSON text of a single WebSocket message.

        Returns:
            ``(audio_bytes, is_last)`` where ``is_last`` is True when the
            frame's ``data.status`` equals 2.

        Raises:
            Exception: if the frame's ``code`` is non-zero.
        """
        frame = json.loads(raw)
        code = frame["code"]
        if code != 0:
            raise Exception(f"sid:{frame.get('sid')} call error:{frame.get('message')} code:{code}")
        data = frame.get("data") or {}
        audio = base64.b64decode(data.get("audio") or "")
        return audio, data.get("status") == 2

    def tts(self, text):
        """Synthesize ``text`` and yield the resulting audio chunks.

        The whole exchange runs inside ``run_forever()``; chunks are yielded
        once the connection has finished. Failures seen by the WebSocket
        callbacks are recorded and raised here, so they reach the caller
        regardless of how the WebSocket library treats exceptions raised
        inside callbacks.

        Args:
            text: Text to synthesize.

        Yields:
            Decoded audio bytes, one item per frame received.

        Raises:
            Exception: on a service error frame, on a transport error, or
                when the connection ended without delivering any audio.
        """
        business_args = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
        data_args = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
        request = json.dumps({"common": {"app_id": self.APPID},
                              "business": business_args,
                              "data": data_args})

        chunks = []
        outcome = {"frame_error": None, "transport_error": None, "finished": False}

        def on_open(ws):
            ws.send(request)

        def on_message(ws, message):
            try:
                audio, is_last = self._decode_frame(message)
            except Exception as exc:
                # Remember the failure and stop; it is raised after
                # run_forever() returns.
                outcome["frame_error"] = exc
                ws.close()
                return
            # Store the audio before closing so the final frame is never lost.
            chunks.append(audio)
            if is_last:
                outcome["finished"] = True
                ws.close()

        def on_error(ws, error):
            # Errors reported after the final frame (e.g. while closing) are
            # not failures of the synthesis itself.
            if outcome["finished"] or outcome["transport_error"] is not None:
                return
            outcome["transport_error"] = error if isinstance(error, BaseException) else Exception(error)

        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(self.create_url(), on_open=on_open, on_error=on_error,
                                    on_message=on_message)
        # NOTE(review): certificate verification is disabled, as in the
        # original code; consider enabling it so the endpoint is authenticated.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

        if outcome["frame_error"] is not None:
            raise outcome["frame_error"]
        if not chunks:
            raise Exception(
                f"Fail to access model({self.model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.") from outcome["transport_error"]
        if outcome["transport_error"] is not None and not outcome["finished"]:
            raise Exception(outcome["transport_error"])

        yield from chunks
|
||||
|
@ -94,6 +94,8 @@ vertexai==1.64.0
|
||||
volcengine==1.0.146
|
||||
voyageai==0.2.3
|
||||
webdriver_manager==4.0.1
|
||||
websocket==0.2.1
|
||||
websocket-client==1.8.0
|
||||
Werkzeug==3.0.3
|
||||
wikipedia==1.4.0
|
||||
word2number==1.1
|
||||
|
@ -551,6 +551,12 @@ The above is the content you need to summarize.`,
|
||||
SparkModelNameMessage: 'Please select Spark model',
|
||||
addSparkAPIPassword: 'Spark APIPassword',
|
||||
SparkAPIPasswordMessage: 'please input your APIPassword',
|
||||
addSparkAPPID: 'Spark APPID',
|
||||
SparkAPPIDMessage: 'please input your APPID',
|
||||
addSparkAPISecret: 'Spark APISecret',
|
||||
SparkAPISecretMessage: 'please input your APISecret',
|
||||
addSparkAPIKey: 'Spark APIKey',
|
||||
SparkAPIKeyMessage: 'please input your APIKey',
|
||||
yiyanModelNameMessage: 'Please input model name',
|
||||
addyiyanAK: 'yiyan API KEY',
|
||||
yiyanAKMessage: 'Please input your API KEY',
|
||||
|
@ -512,6 +512,12 @@ export default {
|
||||
SparkModelNameMessage: '請選擇星火模型!',
|
||||
addSparkAPIPassword: '星火 APIPassword',
|
||||
SparkAPIPasswordMessage: '請輸入 APIPassword',
|
||||
addSparkAPPID: '星火 APPID',
|
||||
SparkAPPIDMessage: '請輸入 APPID',
|
||||
addSparkAPISecret: '星火 APISecret',
|
||||
SparkAPISecretMessage: '請輸入 APISecret',
|
||||
addSparkAPIKey: '星火 APIKey',
|
||||
SparkAPIKeyMessage: '請輸入 APIKey',
|
||||
yiyanModelNameMessage: '輸入模型名稱',
|
||||
addyiyanAK: '一言 API KEY',
|
||||
yiyanAKMessage: '請輸入 API KEY',
|
||||
|
@ -529,6 +529,12 @@ export default {
|
||||
SparkModelNameMessage: '请选择星火模型!',
|
||||
addSparkAPIPassword: '星火 APIPassword',
|
||||
SparkAPIPasswordMessage: '请输入 APIPassword',
|
||||
addSparkAPPID: '星火 APPID',
|
||||
SparkAPPIDMessage: '请输入 APPID',
|
||||
addSparkAPISecret: '星火 APISecret',
|
||||
SparkAPISecretMessage: '请输入 APISecret',
|
||||
addSparkAPIKey: '星火 APIKey',
|
||||
SparkAPIKeyMessage: '请输入 APIKey',
|
||||
yiyanModelNameMessage: '请输入模型名称',
|
||||
addyiyanAK: '一言 API KEY',
|
||||
yiyanAKMessage: '请输入 API KEY',
|
||||
|
@ -7,6 +7,9 @@ import omit from 'lodash/omit';
|
||||
type FieldType = IAddLlmRequestBody & {
|
||||
vision: boolean;
|
||||
spark_api_password: string;
|
||||
spark_app_id: string;
|
||||
spark_api_secret: string;
|
||||
spark_api_key: string;
|
||||
};
|
||||
|
||||
const { Option } = Select;
|
||||
@ -63,22 +66,19 @@ const SparkModal = ({
|
||||
>
|
||||
<Select placeholder={t('modelTypeMessage')}>
|
||||
<Option value="chat">chat</Option>
|
||||
<Option value="tts">tts</Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Form.Item<FieldType>
|
||||
label={t('modelName')}
|
||||
name="llm_name"
|
||||
initialValue={'Spark-Max'}
|
||||
rules={[{ required: true, message: t('SparkModelNameMessage') }]}
|
||||
>
|
||||
<Select placeholder={t('modelTypeMessage')}>
|
||||
<Option value="Spark-Max">Spark-Max</Option>
|
||||
<Option value="Spark-Lite">Spark-Lite</Option>
|
||||
<Option value="Spark-Pro">Spark-Pro</Option>
|
||||
<Option value="Spark-Pro-128K">Spark-Pro-128K</Option>
|
||||
<Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option>
|
||||
</Select>
|
||||
<Input placeholder={t('modelNameMessage')} />
|
||||
</Form.Item>
|
||||
<Form.Item noStyle dependencies={['model_type']}>
|
||||
{({ getFieldValue }) =>
|
||||
getFieldValue('model_type') === 'chat' && (
|
||||
<Form.Item<FieldType>
|
||||
label={t('addSparkAPIPassword')}
|
||||
name="spark_api_password"
|
||||
@ -86,6 +86,48 @@ const SparkModal = ({
|
||||
>
|
||||
<Input placeholder={t('SparkAPIPasswordMessage')} />
|
||||
</Form.Item>
|
||||
)
|
||||
}
|
||||
</Form.Item>
|
||||
<Form.Item noStyle dependencies={['model_type']}>
|
||||
{({ getFieldValue }) =>
|
||||
getFieldValue('model_type') === 'tts' && (
|
||||
<Form.Item<FieldType>
|
||||
label={t('addSparkAPPID')}
|
||||
name="spark_app_id"
|
||||
rules={[{ required: true, message: t('SparkAPPIDMessage') }]}
|
||||
>
|
||||
<Input placeholder={t('SparkAPPIDMessage')} />
|
||||
</Form.Item>
|
||||
)
|
||||
}
|
||||
</Form.Item>
|
||||
<Form.Item noStyle dependencies={['model_type']}>
|
||||
{({ getFieldValue }) =>
|
||||
getFieldValue('model_type') === 'tts' && (
|
||||
<Form.Item<FieldType>
|
||||
label={t('addSparkAPISecret')}
|
||||
name="spark_api_secret"
|
||||
rules={[{ required: true, message: t('SparkAPISecretMessage') }]}
|
||||
>
|
||||
<Input placeholder={t('SparkAPISecretMessage')} />
|
||||
</Form.Item>
|
||||
)
|
||||
}
|
||||
</Form.Item>
|
||||
<Form.Item noStyle dependencies={['model_type']}>
|
||||
{({ getFieldValue }) =>
|
||||
getFieldValue('model_type') === 'tts' && (
|
||||
<Form.Item<FieldType>
|
||||
label={t('addSparkAPIKey')}
|
||||
name="spark_api_key"
|
||||
rules={[{ required: true, message: t('SparkAPIKeyMessage') }]}
|
||||
>
|
||||
<Input placeholder={t('SparkAPIKeyMessage')} />
|
||||
</Form.Item>
|
||||
)
|
||||
}
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
);
|
||||
|
Loading…
x
Reference in New Issue
Block a user