diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index c4529704e..b373eda1d 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -196,8 +196,8 @@ def tts(): tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id) def stream_audio(): try: - for chunk in tts_mdl.tts(text): - yield chunk + for chunk in tts_mdl.tts(text): + yield chunk except Exception as e: yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e), "data": {"answer": "**ERROR**: "+str(e)}}, diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 018cc50a8..14afdb91f 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import binascii import os import json import re @@ -120,6 +121,9 @@ def chat(dialog, messages, stream=True, **kwargs): prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) + tts_mdl = None + if prompt_config.get("tts"): + tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) # try to use sql if field mapping is good to go if field_map: chat_logger.info("Use SQL to retrieval:{}".format(questions[-1])) @@ -168,7 +172,8 @@ def chat(dialog, messages, stream=True, **kwargs): "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) if not knowledges and prompt_config.get("empty_response"): - yield {"answer": prompt_config["empty_response"], "reference": kbinfos} + empty_res = prompt_config["empty_response"] + yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)} return {"answer": prompt_config["empty_response"], "reference": kbinfos} kwargs["knowledge"] = "\n".join(knowledges) @@ -214,16 +219,26 @@ def chat(dialog, messages, stream=True, **kwargs): return {"answer": answer, "reference": refs, "prompt": prompt} if stream: + last_ans = "" answer = "" for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf): answer = ans - yield {"answer": answer, "reference": {}} + delta_ans = ans[len(last_ans):] + if num_tokens_from_string(delta_ans) < 12: + continue + last_ans = answer + yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} + delta_ans = answer[len(last_ans):] + if delta_ans: + yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} yield decorate_answer(answer) else: answer = chat_mdl.chat(prompt, msg[1:], gen_conf) chat_logger.info("User: {}|Assistant: {}".format( msg[-1]["content"], answer)) - yield decorate_answer(answer) + res = decorate_answer(answer) + res["audio_binary"] = tts(tts_mdl, answer) + yield res def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): @@ -392,3 +407,12 @@ def rewrite(tenant_id, llm_id, question): """ ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8}) return ans + + +def tts(tts_mdl, text): + return + if not tts_mdl or not text: return + bin = b"" + for chunk in tts_mdl.tts(text): + bin += chunk + return binascii.hexlify(bin).decode("utf-8") \ No newline at end of file