### What problem does this PR solve?

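Add XunFei Spark text-to-speech support (SparkTTS): a new `SparkTTS` model class registered under the "XunFei Spark" factory, plus the APPID / APISecret / APIKey form fields needed to configure it.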

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
liuhua 2024-09-24 12:15:12 +08:00 committed by GitHub
parent 38e3475714
commit d9c2a128a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 200 additions and 22 deletions

View File

@ -161,7 +161,10 @@ def add_llm():
elif factory =="XunFei Spark":
llm_name = req["llm_name"]
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
if req["model_type"] == "chat":
api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
elif req["model_type"] == "tts":
api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"])
elif factory == "BaiduYiyan":
llm_name = req["llm_name"]

View File

@ -139,5 +139,6 @@ Seq2txtModel = {
TTSModel = {
"Fish Audio": FishAudioTTS,
"Tongyi-Qianwen": QwenTTS,
"OpenAI":OpenAITTS
"OpenAI":OpenAITTS,
"XunFei Spark":SparkTTS
}

View File

@ -14,16 +14,30 @@
# limitations under the License.
#
import requests
from typing import Annotated, Literal
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import queue
import re
import ssl
import time
from abc import ABC
from datetime import datetime
from time import mktime
from typing import Annotated, Literal
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import httpx
import ormsgpack
import requests
import websocket
from pydantic import BaseModel, conint
from rag.utils import num_tokens_from_string
import json
import re
import time
class ServeReferenceAudio(BaseModel):
@ -161,7 +175,7 @@ class QwenTTS(Base):
class OpenAITTS(Base):
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
if not base_url: base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
@ -185,3 +199,101 @@ class OpenAITTS(Base):
for chunk in response.iter_content():
if chunk:
yield chunk
class SparkTTS:
    """Text-to-speech client for the XunFei Spark websocket TTS endpoint.

    Credentials are supplied as a JSON string holding ``spark_app_id``,
    ``spark_api_secret`` and ``spark_api_key``. ``model_name`` is sent to the
    service as the ``vcn`` business argument.
    """

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        """Parse the JSON credential blob and remember the model name.

        Args:
            key: JSON object string with ``spark_app_id``, ``spark_api_secret``
                and ``spark_api_key``; missing entries fall back to
                placeholder values.
            model_name: Value sent as ``vcn`` in every request.
            base_url: Accepted for signature parity with the other TTS
                classes; not used here.

        Raises:
            json.JSONDecodeError: If ``key`` is not valid JSON.
        """
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Kept so existing code that touches this attribute keeps working;
        # tts() now buffers audio in a per-call queue instead (see tts()).
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Build the signed websocket URL used to open a TTS session.

        The signature is an HMAC-SHA256 over the host, date and request line,
        keyed with the API secret, then base64-encoded together with the API
        key into the ``authorization`` query parameter.

        Returns:
            The ``wss://`` URL including ``authorization``, ``date`` and
            ``host`` query parameters.
        """
        url = 'wss://tts-api.xfyun.cn/v2/tts'
        # NOTE(review): the signed host ("ws-api.xfyun.cn") differs from the
        # host in the URL above; kept exactly as in the original code --
        # confirm against the vendor documentation before changing.
        host = "ws-api.xfyun.cn"
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        return url + '?' + urlencode(v)

    def tts(self, text):
        """Synthesize ``text`` and yield the decoded audio chunks.

        The websocket session runs to completion before the first chunk is
        yielded, so chunks are buffered in memory rather than streamed live.
        If the session ends with an error after some audio was received, the
        received audio is still yielded (as before).

        Args:
            text: The text to synthesize.

        Yields:
            ``bytes`` audio chunks in the order they were received.

        Raises:
            Exception: If no audio was received. The message carries the
                server or transport error when one was reported, otherwise a
                generic invalid-credentials hint.
        """
        business_args = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
        data = {"status": self.STATUS_LAST_FRAME, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
        common_args = {"app_id": self.APPID}

        # Per-call buffers: a shared instance queue would leak chunks from an
        # abandoned generator into the next call.
        audio_queue = queue.Queue()
        failures = []

        def on_message(ws, message):
            message = json.loads(message)
            code = message.get("code")
            if code != 0:
                # Error frames may carry no "data" section, so check the code
                # before touching the payload. Exceptions raised inside a
                # callback never reach the caller; record the error instead.
                failures.append(f"sid:{message.get('sid')} call error:{message.get('message')} code:{code}")
                ws.close()
                return
            payload = message.get("data") or {}
            audio = payload.get("audio")
            if audio:
                audio_queue.put(base64.b64decode(audio))
            if payload.get("status") == self.STATUS_LAST_FRAME:
                ws.close()

        def on_error(ws, error):
            failures.append(str(error))

        def on_open(ws):
            ws.send(json.dumps({"common": common_args,
                                "business": business_args,
                                "data": data}))

        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(self.create_url(), on_open=on_open, on_error=on_error,
                                    on_message=on_message)
        # Certificate verification is left enabled (library default): the URL
        # carries the signed credentials, so the peer must be authenticated.
        ws.run_forever()

        # run_forever() has returned, so no callback can still be writing;
        # drain without blocking instead of waiting on a close sentinel that
        # may never arrive.
        if audio_queue.empty():
            detail = "; ".join(failures) or "Invalid APPID, API Secret, or API Key."
            raise Exception(
                f"Fail to access model({self.model_name}) using the provided credentials. **ERROR**: {detail}")
        while not audio_queue.empty():
            yield audio_queue.get_nowait()

View File

@ -94,6 +94,8 @@ vertexai==1.64.0
volcengine==1.0.146
voyageai==0.2.3
webdriver_manager==4.0.1
websocket==0.2.1
websocket-client==1.8.0
Werkzeug==3.0.3
wikipedia==1.4.0
word2number==1.1

View File

@ -551,6 +551,12 @@ The above is the content you need to summarize.`,
SparkModelNameMessage: 'Please select Spark model',
addSparkAPIPassword: 'Spark APIPassword',
SparkAPIPasswordMessage: 'please input your APIPassword',
addSparkAPPID: 'Spark APPID',
SparkAPPIDMessage: 'please input your APPID',
addSparkAPISecret: 'Spark APISecret',
SparkAPISecretMessage: 'please input your APISecret',
addSparkAPIKey: 'Spark APIKey',
SparkAPIKeyMessage: 'please input your APIKey',
yiyanModelNameMessage: 'Please input model name',
addyiyanAK: 'yiyan API KEY',
yiyanAKMessage: 'Please input your API KEY',

View File

@ -512,6 +512,12 @@ export default {
SparkModelNameMessage: '請選擇星火模型!',
addSparkAPIPassword: '星火 APIPassword',
SparkAPIPasswordMessage: '請輸入 APIPassword',
addSparkAPPID: '星火 APPID',
SparkAPPIDMessage: '請輸入 APPID',
addSparkAPISecret: '星火 APISecret',
SparkAPISecretMessage: '請輸入 APISecret',
addSparkAPIKey: '星火 APIKey',
SparkAPIKeyMessage: '請輸入 APIKey',
yiyanModelNameMessage: '輸入模型名稱',
addyiyanAK: '一言 API KEY',
yiyanAKMessage: '請輸入 API KEY',

View File

@ -529,6 +529,12 @@ export default {
SparkModelNameMessage: '请选择星火模型!',
addSparkAPIPassword: '星火 APIPassword',
SparkAPIPasswordMessage: '请输入 APIPassword',
addSparkAPPID: '星火 APPID',
SparkAPPIDMessage: '请输入 APPID',
addSparkAPISecret: '星火 APISecret',
SparkAPISecretMessage: '请输入 APISecret',
addSparkAPIKey: '星火 APIKey',
SparkAPIKeyMessage: '请输入 APIKey',
yiyanModelNameMessage: '请输入模型名称',
addyiyanAK: '一言 API KEY',
yiyanAKMessage: '请输入 API KEY',

View File

@ -7,6 +7,9 @@ import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
vision: boolean;
spark_api_password: string;
spark_app_id: string;
spark_api_secret: string;
spark_api_key: string;
};
const { Option } = Select;
@ -63,22 +66,19 @@ const SparkModal = ({
>
<Select placeholder={t('modelTypeMessage')}>
<Option value="chat">chat</Option>
<Option value="tts">tts</Option>
</Select>
</Form.Item>
<Form.Item<FieldType>
label={t('modelName')}
name="llm_name"
initialValue={'Spark-Max'}
rules={[{ required: true, message: t('SparkModelNameMessage') }]}
>
<Select placeholder={t('modelTypeMessage')}>
<Option value="Spark-Max">Spark-Max</Option>
<Option value="Spark-Lite">Spark-Lite</Option>
<Option value="Spark-Pro">Spark-Pro</Option>
<Option value="Spark-Pro-128K">Spark-Pro-128K</Option>
<Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option>
</Select>
<Input placeholder={t('modelNameMessage')} />
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'chat' && (
<Form.Item<FieldType>
label={t('addSparkAPIPassword')}
name="spark_api_password"
@ -86,6 +86,48 @@ const SparkModal = ({
>
<Input placeholder={t('SparkAPIPasswordMessage')} />
</Form.Item>
)
}
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'tts' && (
<Form.Item<FieldType>
label={t('addSparkAPPID')}
name="spark_app_id"
rules={[{ required: true, message: t('SparkAPPIDMessage') }]}
>
<Input placeholder={t('SparkAPPIDMessage')} />
</Form.Item>
)
}
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'tts' && (
<Form.Item<FieldType>
label={t('addSparkAPISecret')}
name="spark_api_secret"
rules={[{ required: true, message: t('SparkAPISecretMessage') }]}
>
<Input placeholder={t('SparkAPISecretMessage')} />
</Form.Item>
)
}
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'tts' && (
<Form.Item<FieldType>
label={t('addSparkAPIKey')}
name="spark_api_key"
rules={[{ required: true, message: t('SparkAPIKeyMessage') }]}
>
<Input placeholder={t('SparkAPIKeyMessage')} />
</Form.Item>
)
}
</Form.Item>
</Form>
</Modal>
);