mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 17:39:06 +08:00
add tts api (#2107)
### What problem does this PR solve? add tts api - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen <cenzhedong2@126.com> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
parent
2da4e7aa46
commit
b88c3897b9
@ -15,8 +15,10 @@
|
|||||||
#
|
#
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from flask import request, Response
|
from flask import request, Response
|
||||||
from flask_login import login_required
|
from flask_login import login_required,current_user
|
||||||
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
||||||
|
from api.db.services.llm_service import LLMBundle, TenantService
|
||||||
|
from api.db import LLMType
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
@ -176,6 +178,38 @@ def completion():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/tts', methods=['POST'])
|
||||||
|
@login_required
|
||||||
|
def tts():
|
||||||
|
req = request.json
|
||||||
|
text = req["text"]
|
||||||
|
|
||||||
|
tenants = TenantService.get_by_user_id(current_user.id)
|
||||||
|
if not tenants:
|
||||||
|
return get_data_error_result(retmsg="Tenant not found!")
|
||||||
|
|
||||||
|
tts_id = tenants[0]["tts_id"]
|
||||||
|
if not tts_id:
|
||||||
|
return get_data_error_result(retmsg="No default TTS model is set")
|
||||||
|
|
||||||
|
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
||||||
|
def stream_audio():
|
||||||
|
try:
|
||||||
|
for chunk in tts_mdl(text):
|
||||||
|
yield chunk
|
||||||
|
except Exception as e:
|
||||||
|
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
||||||
|
"data": {"answer": "**ERROR**: "+str(e)}},
|
||||||
|
ensure_ascii=False).encode('utf-8')
|
||||||
|
|
||||||
|
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||||
|
resp.headers.add_header("Cache-Control", "no-cache")
|
||||||
|
resp.headers.add_header("Connection", "keep-alive")
|
||||||
|
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/delete_msg', methods=['POST'])
|
@manager.route('/delete_msg', methods=['POST'])
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("conversation_id", "message_id")
|
@validate_request("conversation_id", "message_id")
|
||||||
|
@ -96,6 +96,7 @@ class TenantService(CommonService):
|
|||||||
cls.model.rerank_id,
|
cls.model.rerank_id,
|
||||||
cls.model.asr_id,
|
cls.model.asr_id,
|
||||||
cls.model.img2txt_id,
|
cls.model.img2txt_id,
|
||||||
|
cls.model.tts_id,
|
||||||
cls.model.parser_ids,
|
cls.model.parser_ids,
|
||||||
UserTenant.role]
|
UserTenant.role]
|
||||||
return list(cls.model.select(*fields)
|
return list(cls.model.select(*fields)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user