add support for TTS model (#2095)

### What problem does this PR solve?

Add support for TTS (text-to-speech) models.
#1853

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
黄腾 2024-08-26 15:19:43 +08:00 committed by GitHub
parent c3e344b0f1
commit 6b7c028578
23 changed files with 338 additions and 7 deletions


@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
import requests
import ast
@ -142,6 +142,10 @@ def add_llm():
llm_name = req["llm_name"]
api_key = '{' + f'"yiyan_ak": "{req.get("yiyan_ak", "")}", ' \
f'"yiyan_sk": "{req.get("yiyan_sk", "")}"' + '}'
elif factory == "Fish Audio":
llm_name = req["llm_name"]
api_key = '{' + f'"fish_audio_ak": "{req.get("fish_audio_ak", "")}", ' \
f'"fish_audio_refid": "{req.get("fish_audio_refid", "59cb5986671546eaa6ca8ae6f29f6d22")}"' + '}'
else:
llm_name = req["llm_name"]
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
@ -215,6 +219,15 @@ def add_llm():
pass
except Exception as e:
msg += f"\nFail to access model({llm['llm_name']})." + str(e)
elif llm["model_type"] == LLMType.TTS:
mdl = TTSModel[factory](
key=llm["api_key"], model_name=llm["llm_name"], base_url=llm["api_base"]
)
try:
for resp in mdl.transcription("Hello~ Ragflower!"):
pass
except RuntimeError as e:
msg += f"\nFail to access model({llm['llm_name']})." + str(e)
else:
# TODO: check other type of models
pass
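
With the Fish Audio branch above, `add_llm` packs the access key and reference ID into a single JSON `api_key` string and then probes the model with a short `transcription()` call. Below is a minimal registration sketch; the route, host, and auth header are assumptions not shown in this diff, while the body fields match what the handler reads via `req.get()`.

```python
import requests

# Sketch only: endpoint path, host, and auth are assumptions.
payload = {
    "llm_factory": "Fish Audio",
    "llm_name": "fish-audio-tts",          # display name for the TTS model
    "model_type": "tts",
    "fish_audio_ak": "<your Fish Audio API key>",
    # omit "fish_audio_refid" to let the handler fall back to its default reference id
    "api_base": "",                        # empty -> FishAudioTTS uses https://api.fish.audio/v1/tts
}
resp = requests.post("http://<ragflow-host>/v1/llm/add_llm", json=payload,
                     headers={"Authorization": "<session token>"})
print(resp.json())
```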


@ -410,7 +410,7 @@ def tenant_info():
@manager.route("/set_tenant_info", methods=["POST"])
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id", "tts_id")
def set_tenant_info():
req = request.json
try:
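
Since `tts_id` is now part of `@validate_request`, clients calling `set_tenant_info` must include it. An illustrative request body follows; the field list comes from the decorator above and all values are placeholders.

```python
# Illustrative POST body for set_tenant_info after this change.
tenant_info = {
    "tenant_id": "<tenant id>",
    "llm_id": "<default chat model>",
    "embd_id": "<default embedding model>",
    "img2txt_id": "<default image2text model>",
    "asr_id": "<default speech2text model>",
    "tts_id": "<default TTS model>",  # new in this PR
}
```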


@ -55,6 +55,7 @@ class LLMType(StrEnum):
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
RERANK = 'rerank'
TTS = 'tts'
class ChatStyle(StrEnum):


@ -449,6 +449,11 @@ class Tenant(DataBaseModel):
null=False,
help_text="default rerank model ID",
index=True)
tts_id = CharField(
max_length=256,
null=True,
help_text="default tts model ID",
index=True)
parser_ids = CharField(
max_length=256,
null=False,
@ -958,6 +963,13 @@ def migrate_db():
)
except Exception as e:
pass
try:
migrate(
migrator.add_column("tenant","tts_id",
CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
)
except Exception as e:
pass
try:
migrate(
migrator.add_column('api_4_conversation', 'source',


@ -15,7 +15,7 @@
#
from api.db.services.user_service import TenantService
from api.settings import database_logger
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
from api.db.db_models import LLMFactories, LLM, TenantLLM
@ -75,6 +75,8 @@ class TenantLLMService(CommonService):
mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
mdlnm = tenant.rerank_id if not llm_name else llm_name
elif llm_type == LLMType.TTS:
mdlnm = tenant.tts_id if not llm_name else llm_name
else:
assert False, "LLM type error"
@ -127,6 +129,14 @@ class TenantLLMService(CommonService):
model_config["api_key"], model_config["llm_name"], lang,
base_url=model_config["api_base"]
)
if llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return
return TTSModel[model_config["llm_factory"]](
model_config["api_key"],
model_config["llm_name"],
base_url=model_config["api_base"],
)
@classmethod
@DB.connection_context()
@ -144,7 +154,9 @@ class TenantLLMService(CommonService):
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
mdlnm = tenant.rerank_id if not llm_name else llm_name
elif llm_type == LLMType.TTS:
mdlnm = tenant.tts_id if not llm_name else llm_name
else:
assert False, "LLM type error"
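
On the retrieval side, `model_instance` can now hand back a tenant's TTS model just like the other model types. A minimal sketch, assuming `model_instance(tenant_id, llm_type, llm_name=None)` keeps its existing signature, a database context is available, and a Fish Audio TTS model has been configured; it returns `None` when the factory is not registered in `TTSModel`.

```python
from api.db import LLMType
from api.db.services.llm_service import TenantLLMService

# Sketch only: "<tenant_id>" is a placeholder and a running RAGFlow database is assumed.
tts_mdl = TenantLLMService.model_instance("<tenant_id>", LLMType.TTS)
if tts_mdl is not None:  # None means the configured factory is not in TTSModel
    audio = bytearray()
    for chunk in tts_mdl.transcription("Hello from RAGFlow!"):
        if isinstance(chunk, (bytes, bytearray)):  # the last yielded item is a token count (int)
            audio.extend(chunk)
    # `audio` now holds the streamed MP3 payload from Fish Audio
```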


@ -3214,6 +3214,13 @@
"tags": "LLM", "tags": "LLM",
"status": "1", "status": "1",
"llm": [] "llm": []
},
{
"name": "Fish Audio",
"logo": "",
"tags": "TTS",
"status": "1",
"llm": []
}
]
}


@ -18,6 +18,7 @@ from .chat_model import *
from .cv_model import *
from .rerank_model import *
from .sequence2txt_model import *
from .tts_model import *
EmbeddingModel = {
"Ollama": OllamaEmbed,
@ -129,3 +130,7 @@ Seq2txtModel = {
"Azure-OpenAI": AzureSeq2txt, "Azure-OpenAI": AzureSeq2txt,
"Xinference": XinferenceSeq2txt "Xinference": XinferenceSeq2txt
} }
TTSModel = {
"Fish Audio": FishAudioTTS
}
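
Like the other registries in this module, `TTSModel` maps a factory's display name to its implementation class, which is how both `add_llm` and `TenantLLMService.model_instance` construct TTS models generically. A small illustration; the key and model name are placeholders.

```python
import json

# Registry lookup, mirroring how add_llm / model_instance use TTSModel.
factory = "Fish Audio"
if factory in TTSModel:
    tts_cls = TTSModel[factory]  # -> FishAudioTTS
    tts = tts_cls(
        key=json.dumps({"fish_audio_ak": "<api key>"}),  # placeholder credentials
        model_name="fish-audio-tts",                     # display name only
        base_url="",                                     # empty -> class default endpoint
    )
```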

rag/llm/tts_model.py (new file, 94 lines)

@ -0,0 +1,94 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Annotated, Literal
from abc import ABC
import httpx
import ormsgpack
from pydantic import BaseModel, conint
from rag.utils import num_tokens_from_string
import json
class ServeReferenceAudio(BaseModel):
audio: bytes
text: str
class ServeTTSRequest(BaseModel):
text: str
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
# Audio format
format: Literal["wav", "pcm", "mp3"] = "mp3"
mp3_bitrate: Literal[64, 128, 192] = 128
# References audios for in-context learning
references: list[ServeReferenceAudio] = []
# Reference id
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
# Normalize text for en & zh, this increase stability for numbers
normalize: bool = True
# Balance mode will reduce latency to 300ms, but may decrease stability
latency: Literal["normal", "balanced"] = "normal"
class Base(ABC):
def __init__(self, key, model_name, base_url):
pass
def transcription(self, audio):
pass
class FishAudioTTS(Base):
def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
if not base_url:
base_url = "https://api.fish.audio/v1/tts"
key = json.loads(key)
self.headers = {
"api-key": key.get("fish_audio_ak"),
"content-type": "application/msgpack",
}
self.ref_id = key.get("fish_audio_refid")
self.base_url = base_url
def transcription(self, text):
from http import HTTPStatus
request = ServeTTSRequest(text=text, reference_id=self.ref_id)
with httpx.Client() as client:
try:
with client.stream(
method="POST",
url=self.base_url,
content=ormsgpack.packb(
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
),
headers=self.headers,
timeout=None,
) as response:
if response.status_code == HTTPStatus.OK:
for chunk in response.iter_bytes():
yield chunk
else:
response.raise_for_status()
yield num_tokens_from_string(text)
except httpx.HTTPStatusError as e:
raise RuntimeError(f"**ERROR**: {e}")
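
A minimal standalone sketch of driving `FishAudioTTS` directly, assuming a valid Fish Audio API key and network access; the key layout mirrors what `add_llm` stores, and the final yielded item is a token count rather than audio bytes.

```python
import json

# Sketch only: key and reference id values are placeholders.
key = json.dumps({
    "fish_audio_ak": "<your Fish Audio API key>",
    "fish_audio_refid": "59cb5986671546eaa6ca8ae6f29f6d22",  # default reference voice used by add_llm
})
tts = FishAudioTTS(key=key, model_name="fish-audio-tts")  # model_name is only a label here

with open("hello.mp3", "wb") as f:
    for chunk in tts.transcription("Hello from RAGFlow!"):
        if isinstance(chunk, bytes):  # the final yield is num_tokens_from_string(text), an int
            f.write(chunk)
```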


@ -47,6 +47,7 @@ openai==1.12.0
opencv_python==4.9.0.80
opencv_python_headless==4.9.0.80
openpyxl==3.1.2
ormsgpack==1.5.0
pandas==2.2.2
pdfplumber==0.10.4
peewee==3.17.1


@ -74,6 +74,7 @@ ollama==0.1.9
openai==1.12.0
opencv-python==4.9.0.80
openpyxl==3.1.2
ormsgpack==1.5.0
packaging==23.2
pandas==2.2.1
pdfminer.six==20221105


@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718" class="h-6 dark:invert"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill: rgb(177, 179, 180); fill-rule: evenodd;"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill: rgb(0, 0, 0); fill-rule: evenodd;"></path></svg>



@ -48,6 +48,7 @@ export enum LlmModelType {
Image2text = 'image2text',
Speech2text = 'speech2text',
Rerank = 'rerank',
TTS = 'tts',
}
export enum KnowledgeSearchParams {


@ -87,6 +87,7 @@ export const useSelectLlmOptionsByModelType = () => {
LlmModelType.Speech2text,
),
[LlmModelType.Rerank]: groupOptionsByModelType(LlmModelType.Rerank),
[LlmModelType.TTS]: groupOptionsByModelType(LlmModelType.TTS),
};
};


@ -71,6 +71,7 @@ export interface ITenantInfo {
tenant_id: string;
chat_id: string;
speech2text_id: string;
tts_id: string;
}
export interface IChunk {


@ -490,6 +490,9 @@ The above is the content you need to summarize.`,
'The default ASR model all the newly created knowledgebase will use. Use this model to translate voices to corresponding text.',
rerankModel: 'Rerank Model',
rerankModelTip: `The default rerank model is used to rerank chunks retrieved by users' questions.`,
ttsModel: 'TTS Model',
ttsModelTip:
'The default TTS model will be used to generate speech during conversations upon request.',
workspace: 'Workspace',
upgrade: 'Upgrade',
addLlmTitle: 'Add LLM',
@ -502,6 +505,7 @@ The above is the content you need to summarize.`,
baseUrlNameMessage: 'Please input your base url!',
vision: 'Does it support Vision?',
ollamaLink: 'How to integrate {{name}}',
FishAudioLink: 'How to use FishAudio',
volcModelNameMessage: 'Please input your model name!',
addEndpointID: 'EndpointID of the model',
endpointIDMessage: 'Please input your EndpointID of the model',
@ -533,6 +537,13 @@ The above is the content you need to summarize.`,
yiyanAKMessage: 'Please input your API KEY',
addyiyanSK: 'yiyan Secret KEY',
yiyanSKMessage: 'Please input your Secret KEY',
FishAudioModelNameMessage:
'Please give your speech synthesis model a name',
addFishAudioAK: 'Fish Audio API KEY',
addFishAudioAKMessage: 'Please input your API KEY',
addFishAudioRefID: 'FishAudio Reference ID',
addFishAudioRefIDMessage:
'Please input the Reference ID (leave blank to use the default model).',
},
message: {
registered: 'Registered!',


@ -443,6 +443,8 @@ export default {
systemModelSettings: '系統模型設置',
chatModel: '聊天模型',
chatModelTip: '所有新創建的知識庫都會使用默認的聊天LLM。',
ttsModel: '語音合成模型',
ttsModelTip: '默認的tts模型會被用於在對話過程中請求語音生成時使用。',
embeddingModel: '嵌入模型',
embeddingModelTip: '所有新創建的知識庫都將使用的默認嵌入模型。',
img2txtModel: 'img2Txt模型',
@ -465,6 +467,7 @@ export default {
modelTypeMessage: '請輸入模型類型!',
baseUrlNameMessage: '請輸入基礎 Url',
ollamaLink: '如何集成 {{name}}',
FishAudioLink: '如何使用Fish Audio',
volcModelNameMessage: '請輸入模型名稱!',
addEndpointID: '模型 EndpointID',
endpointIDMessage: '請輸入模型對應的EndpointID',
@ -496,6 +499,10 @@ export default {
yiyanAKMessage: '請輸入 API KEY',
addyiyanSK: '一言 Secret KEY',
yiyanSKMessage: '請輸入 Secret KEY',
addFishAudioAK: 'Fish Audio API KEY',
addFishAudioAKMessage: '請輸入 API KEY',
addFishAudioRefID: 'FishAudio Reference ID',
addFishAudioRefIDMessage: '請輸入引用模型的ID留空表示使用默認模型',
},
message: {
registered: '註冊成功',


@ -460,6 +460,8 @@ export default {
systemModelSettings: '系统模型设置',
chatModel: '聊天模型',
chatModelTip: '所有新创建的知识库都会使用默认的聊天LLM。',
ttsModel: 'TTS模型',
ttsModelTip: '默认的tts模型会被用于在对话过程中请求语音生成时使用',
embeddingModel: '嵌入模型',
embeddingModelTip: '所有新创建的知识库都将使用的默认嵌入模型。',
img2txtModel: 'Img2txt模型',
@ -482,6 +484,7 @@ export default {
modelTypeMessage: '请输入模型类型!',
baseUrlNameMessage: '请输入基础 Url',
ollamaLink: '如何集成 {{name}}',
FishAudioLink: '如何使用Fish Audio',
volcModelNameMessage: '请输入模型名称!',
addEndpointID: '模型 EndpointID',
endpointIDMessage: '请输入模型对应的EndpointID',
@ -513,6 +516,10 @@ export default {
yiyanAKMessage: '请输入 API KEY',
addyiyanSK: '一言 Secret KEY',
yiyanSKMessage: '请输入 Secret KEY',
addFishAudioAK: 'Fish Audio API KEY',
FishAudioAKMessage: '请输入 API KEY',
addFishAudioRefID: 'FishAudio Reference ID',
FishAudioRefIDMessage: '请输入引用模型的ID留空表示使用默认模型',
},
message: {
registered: '注册成功',


@ -35,6 +35,7 @@ export const IconMap = {
'Tencent Hunyuan': 'hunyuan',
'XunFei Spark': 'spark',
BaiduYiyan: 'yiyan',
'Fish Audio': 'fish-audio',
};
export const BedrockRegionList = [


@ -0,0 +1,101 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
import { Flex, Form, Input, Modal, Select, Space } from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
fish_audio_ak: string;
fish_audio_refid: string;
};
const { Option } = Select;
const FishAudioModal = ({
visible,
hideModal,
onOk,
loading,
llmFactory,
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
const [form] = Form.useForm<FieldType>();
const { t } = useTranslate('setting');
const handleOk = async () => {
const values = await form.validateFields();
const modelType = values.model_type;
const data = {
...omit(values),
model_type: modelType,
llm_factory: llmFactory,
};
console.info(data);
onOk?.(data);
};
return (
<Modal
title={t('addLlmTitle', { name: llmFactory })}
open={visible}
onOk={handleOk}
onCancel={hideModal}
okButtonProps={{ loading }}
footer={(originNode: React.ReactNode) => {
return (
<Flex justify={'space-between'}>
<a href={`https://fish.audio`} target="_blank" rel="noreferrer">
{t('FishAudioLink')}
</a>
<Space>{originNode}</Space>
</Flex>
);
}}
confirmLoading={loading}
>
<Form
name="basic"
style={{ maxWidth: 600 }}
autoComplete="off"
layout={'vertical'}
form={form}
>
<Form.Item<FieldType>
label={t('modelType')}
name="model_type"
initialValue={'tts'}
rules={[{ required: true, message: t('modelTypeMessage') }]}
>
<Select placeholder={t('modelTypeMessage')}>
<Option value="tts">tts</Option>
</Select>
</Form.Item>
<Form.Item<FieldType>
label={t('modelName')}
name="llm_name"
rules={[{ required: true, message: t('FishAudioModelNameMessage') }]}
>
<Input placeholder={t('FishAudioModelNameMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('addFishAudioAK')}
name="FishAudio_ak"
rules={[{ required: true, message: t('FishAudioAKMessage') }]}
>
<Input placeholder={t('FishAudioAKMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('addFishAudioRefID')}
name="FishAudio_refid"
rules={[{ required: false, message: t('FishAudioRefIDMessage') }]}
>
<Input placeholder={t('FishAudioRefIDMessage')} />
</Form.Item>
</Form>
</Modal>
);
};
export default FishAudioModal;


@ -244,6 +244,33 @@ export const useSubmityiyan = () => {
};
};
export const useSubmitFishAudio = () => {
const { addLlm, loading } = useAddLlm();
const {
visible: FishAudioAddingVisible,
hideModal: hideFishAudioAddingModal,
showModal: showFishAudioAddingModal,
} = useSetModalState();
const onFishAudioAddingOk = useCallback(
async (payload: IAddLlmRequestBody) => {
const ret = await addLlm(payload);
if (ret === 0) {
hideFishAudioAddingModal();
}
},
[hideFishAudioAddingModal, addLlm],
);
return {
FishAudioAddingLoading: loading,
onFishAudioAddingOk,
FishAudioAddingVisible,
hideFishAudioAddingModal,
showFishAudioAddingModal,
};
};
export const useSubmitBedrock = () => {
const { addLlm, loading } = useAddLlm();
const {


@ -30,10 +30,12 @@ import { isLocalLlmFactory } from '../utils';
import ApiKeyModal from './api-key-modal';
import BedrockModal from './bedrock-modal';
import { IconMap } from './constant';
import FishAudioModal from './fish-audio-modal';
import {
useHandleDeleteLlm,
useSubmitApiKey,
useSubmitBedrock,
useSubmitFishAudio,
useSubmitHunyuan,
useSubmitOllama,
useSubmitSpark,
@ -98,7 +100,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {
item.name === 'VolcEngine' ||
item.name === 'Tencent Hunyuan' ||
item.name === 'XunFei Spark' ||
item.name === 'BaiduYiyan' ||
item.name === 'Fish Audio'
? t('addTheModel')
: 'API-Key'}
<SettingOutlined />
@ -196,6 +199,14 @@ const UserSettingModel = () => {
yiyanAddingLoading,
} = useSubmityiyan();
const {
FishAudioAddingVisible,
hideFishAudioAddingModal,
showFishAudioAddingModal,
onFishAudioAddingOk,
FishAudioAddingLoading,
} = useSubmitFishAudio();
const {
bedrockAddingLoading,
onBedrockAddingOk,
@ -211,6 +222,7 @@ const UserSettingModel = () => {
'Tencent Hunyuan': showHunyuanAddingModal,
'XunFei Spark': showSparkAddingModal,
BaiduYiyan: showyiyanAddingModal,
'Fish Audio': showFishAudioAddingModal,
}),
[
showBedrockAddingModal,
@ -218,6 +230,7 @@ const UserSettingModel = () => {
showHunyuanAddingModal,
showSparkAddingModal,
showyiyanAddingModal,
showFishAudioAddingModal,
],
);
@ -350,6 +363,13 @@ const UserSettingModel = () => {
loading={yiyanAddingLoading}
llmFactory={'BaiduYiyan'}
></YiyanModal>
<FishAudioModal
visible={FishAudioAddingVisible}
hideModal={hideFishAudioAddingModal}
onOk={onFishAudioAddingOk}
loading={FishAudioAddingLoading}
llmFactory={'Fish Audio'}
></FishAudioModal>
<BedrockModal
visible={bedrockAddingVisible}
hideModal={hideBedrockAddingModal}


@ -82,9 +82,9 @@ const SparkModal = ({
<Form.Item<FieldType>
label={t('addSparkAPIPassword')}
name="spark_api_password"
rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
>
<Input placeholder={t('SparkAPIPasswordMessage')} />
</Form.Item>
</Form>
</Modal>


@ -83,6 +83,13 @@ const SystemModelSettingModal = ({
>
<Select options={allOptions[LlmModelType.Rerank]} />
</Form.Item>
<Form.Item
label={t('ttsModel')}
name="tts_id"
tooltip={t('ttsModelTip')}
>
<Select options={allOptions[LlmModelType.TTS]} />
</Form.Item>
</Form>
</Modal>
);