From d9c2a128a5bff7b55f5b39339a8a765eb312367a Mon Sep 17 00:00:00 2001 From: liuhua <10215101452@stu.ecnu.edu.cn> Date: Tue, 24 Sep 2024 12:15:12 +0800 Subject: [PATCH] SparkTTS (#2535) ### What problem does this PR solve? SparkTTS ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn> --- api/apps/llm_app.py | 5 +- rag/llm/__init__.py | 3 +- rag/llm/tts_model.py | 124 +++++++++++++++++- requirements.txt | 2 + web/src/locales/en.ts | 6 + web/src/locales/zh-traditional.ts | 6 + web/src/locales/zh.ts | 6 + .../setting-model/spark-modal/index.tsx | 70 ++++++++-- 8 files changed, 200 insertions(+), 22 deletions(-) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 6f26fdcf3..30a6a297d 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -161,7 +161,10 @@ def add_llm(): elif factory =="XunFei Spark": llm_name = req["llm_name"] - api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx") + if req["model_type"] == "chat": + api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx") + elif req["model_type"] == "tts": + api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"]) elif factory == "BaiduYiyan": llm_name = req["llm_name"] diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index fbcea43b7..84c1adf01 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -139,5 +139,6 @@ Seq2txtModel = { TTSModel = { "Fish Audio": FishAudioTTS, "Tongyi-Qianwen": QwenTTS, - "OpenAI":OpenAITTS + "OpenAI":OpenAITTS, + "XunFei Spark":SparkTTS } \ No newline at end of file diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index 1af100723..bfdb8762c 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -14,16 +14,30 @@ # limitations under the License. 
class SparkTTS:
    """Text-to-speech client for the XunFei Spark WebSocket TTS endpoint.

    The credentials are passed as a JSON string holding ``spark_app_id``,
    ``spark_api_secret`` and ``spark_api_key``; ``model_name`` is sent to the
    service as the voice name (``vcn``).
    """

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    # Endpoint and the values that are signed for it.
    # NOTE(review): the signed host differs from the hostname in the URL; this
    # is reproduced exactly from the original implementation -- confirm against
    # the provider's documentation before changing either value.
    _URL = 'wss://tts-api.xfyun.cn/v2/tts'
    _SIGNED_HOST = "ws-api.xfyun.cn"
    _REQUEST_LINE = "GET /v2/tts HTTP/1.1"

    def __init__(self, key, model_name, base_url=""):
        """Parse the JSON credential blob and remember the voice name.

        Args:
            key: JSON string with ``spark_app_id``, ``spark_api_secret`` and
                ``spark_api_key``. Missing entries fall back to placeholders.
            model_name: Voice identifier sent as ``vcn`` in every request.
            base_url: Accepted for signature compatibility; not used.
        """
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Kept only so existing attribute access keeps working; tts() now
        # buffers audio per call so concurrent or abandoned calls cannot leak
        # stale chunks into each other.
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Build the signed WebSocket URL used to open a TTS session.

        Returns:
            The endpoint URL with ``authorization``, ``date`` and ``host``
            query parameters. The authorization value is the base64 form of a
            header string carrying an HMAC-SHA256 signature (keyed with the
            API secret) over the host, date and request line.
        """
        # RFC 1123 date taken straight from the epoch clock; this avoids the
        # naive local-time round trip, which is ambiguous around DST changes.
        date = format_date_time(time.time())

        signature_origin = "\n".join((
            "host: " + self._SIGNED_HOST,
            "date: " + date,
            self._REQUEST_LINE,
        ))
        digest = hmac.new(self.APISecret.encode('utf-8'),
                          signature_origin.encode('utf-8'),
                          digestmod=hashlib.sha256).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(
            authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        query = {
            "authorization": authorization,
            "date": date,
            "host": self._SIGNED_HOST,
        }
        return self._URL + '?' + urlencode(query)

    def tts(self, text):
        """Synthesize ``text`` and yield the resulting audio chunks.

        The whole WebSocket exchange runs to completion before the first chunk
        is yielded (``run_forever`` blocks), so failures surface before any
        audio is handed to the caller.

        Args:
            text: The text to synthesize.

        Yields:
            ``bytes`` objects with the decoded audio payload of each frame.

        Raises:
            Exception: If the service answers with a non-zero ``code``, if the
                connection fails before the last frame arrives, or if no audio
                was received at all.
        """
        request = json.dumps({
            "common": {"app_id": self.APPID},
            "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000",
                         "vcn": self.model_name, "tte": "utf8"},
            # The whole text goes out in a single frame marked as the last one.
            "data": {"status": self.STATUS_LAST_FRAME,
                     "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')},
        })

        # Per-call state shared with the callbacks below.
        chunks = []
        errors = []
        state = {"finished": False}
        last_frame = self.STATUS_LAST_FRAME

        def on_open(ws):
            ws.send(request)

        def on_message(ws, message):
            payload = json.loads(message)
            code = payload.get("code")
            # Check the status code first: error frames are not guaranteed to
            # carry a "data" object, so it must not be touched before this.
            if code != 0:
                errors.append(
                    f"sid:{payload.get('sid')} call error:{payload.get('message')} code:{code}")
                ws.close()
                return
            data = payload.get("data") or {}
            audio = data.get("audio")
            if audio:
                chunks.append(base64.b64decode(audio))
            if data.get("status") == last_frame:
                state["finished"] = True
                ws.close()

        def on_error(ws, error):
            # Record instead of raising: an exception raised inside a callback
            # does not reliably reach the caller of tts(). Errors reported
            # after the final frame are teardown noise and are ignored.
            if not state["finished"]:
                errors.append(str(error))

        ws = websocket.WebSocketApp(self.create_url(), on_open=on_open,
                                    on_error=on_error, on_message=on_message)
        # NOTE(review): certificate verification is disabled here, as in the
        # original code, although the URL carries signed credentials. Consider
        # enabling verification once the deployment CA store is confirmed.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

        if errors:
            raise Exception(errors[0])
        if not chunks:
            raise Exception(
                f"Fail to access model({self.model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
        yield from chunks
100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -529,6 +529,12 @@ export default { SparkModelNameMessage: '请选择星火模型!', addSparkAPIPassword: '星火 APIPassword', SparkAPIPasswordMessage: '请输入 APIPassword', + addSparkAPPID: '星火 APPID', + SparkAPPIDMessage: '请输入 APPID', + addSparkAPISecret: '星火 APISecret', + SparkAPISecretMessage: '请输入 APISecret', + addSparkAPIKey: '星火 APIKey', + SparkAPIKeyMessage: '请输入 APIKey', yiyanModelNameMessage: '请输入模型名称', addyiyanAK: '一言 API KEY', yiyanAKMessage: '请输入 API KEY', diff --git a/web/src/pages/user-setting/setting-model/spark-modal/index.tsx b/web/src/pages/user-setting/setting-model/spark-modal/index.tsx index 942c524f1..59be63301 100644 --- a/web/src/pages/user-setting/setting-model/spark-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/spark-modal/index.tsx @@ -7,6 +7,9 @@ import omit from 'lodash/omit'; type FieldType = IAddLlmRequestBody & { vision: boolean; spark_api_password: string; + spark_app_id: string; + spark_api_secret: string; + spark_api_key: string; }; const { Option } = Select; @@ -63,28 +66,67 @@ const SparkModal = ({ > label={t('modelName')} name="llm_name" - initialValue={'Spark-Max'} rules={[{ required: true, message: t('SparkModelNameMessage') }]} > - + - - label={t('addSparkAPIPassword')} - name="spark_api_password" - rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]} - > - + + {({ getFieldValue }) => + getFieldValue('model_type') === 'chat' && ( + + label={t('addSparkAPIPassword')} + name="spark_api_password" + rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]} + > + + + ) + } + + + {({ getFieldValue }) => + getFieldValue('model_type') === 'tts' && ( + + label={t('addSparkAPPID')} + name="spark_app_id" + rules={[{ required: true, message: t('SparkAPPIDMessage') }]} + > + + + ) + } + + + {({ getFieldValue }) => + getFieldValue('model_type') === 'tts' && ( + + label={t('addSparkAPISecret')} + name="spark_api_secret" + rules={[{ required: true, 
message: t('SparkAPISecretMessage') }]} + > + + + ) + } + + + {({ getFieldValue }) => + getFieldValue('model_type') === 'tts' && ( + + label={t('addSparkAPIKey')} + name="spark_api_key" + rules={[{ required: true, message: t('SparkAPIKeyMessage') }]} + > + + + ) + }