mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 01:35:59 +08:00
add stream chat with TTS (#2228)
### What problem does this PR solve? ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
07de36ec86
commit
abc32803cc
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import binascii
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
@ -120,6 +121,9 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
|
|
||||||
prompt_config = dialog.prompt_config
|
prompt_config = dialog.prompt_config
|
||||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||||
|
tts_mdl = None
|
||||||
|
if prompt_config.get("tts"):
|
||||||
|
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||||
# try to use sql if field mapping is good to go
|
# try to use sql if field mapping is good to go
|
||||||
if field_map:
|
if field_map:
|
||||||
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
||||||
@ -168,7 +172,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||||
|
|
||||||
if not knowledges and prompt_config.get("empty_response"):
|
if not knowledges and prompt_config.get("empty_response"):
|
||||||
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
empty_res = prompt_config["empty_response"]
|
||||||
|
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
|
||||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||||
|
|
||||||
kwargs["knowledge"] = "\n".join(knowledges)
|
kwargs["knowledge"] = "\n".join(knowledges)
|
||||||
@ -214,16 +219,26 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
return {"answer": answer, "reference": refs, "prompt": prompt}
|
return {"answer": answer, "reference": refs, "prompt": prompt}
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
last_ans = ""
|
||||||
answer = ""
|
answer = ""
|
||||||
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
||||||
answer = ans
|
answer = ans
|
||||||
yield {"answer": answer, "reference": {}}
|
delta_ans = ans[len(last_ans):]
|
||||||
|
if num_tokens_from_string(delta_ans) < 12:
|
||||||
|
continue
|
||||||
|
last_ans = answer
|
||||||
|
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||||
|
delta_ans = answer[len(last_ans):]
|
||||||
|
if delta_ans:
|
||||||
|
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||||
yield decorate_answer(answer)
|
yield decorate_answer(answer)
|
||||||
else:
|
else:
|
||||||
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
||||||
chat_logger.info("User: {}|Assistant: {}".format(
|
chat_logger.info("User: {}|Assistant: {}".format(
|
||||||
msg[-1]["content"], answer))
|
msg[-1]["content"], answer))
|
||||||
yield decorate_answer(answer)
|
res = decorate_answer(answer)
|
||||||
|
res["audio_binary"] = tts(tts_mdl, answer)
|
||||||
|
yield res
|
||||||
|
|
||||||
|
|
||||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||||
@ -392,3 +407,12 @@ def rewrite(tenant_id, llm_id, question):
|
|||||||
"""
|
"""
|
||||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
|
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def tts(tts_mdl, text):
|
||||||
|
return
|
||||||
|
if not tts_mdl or not text: return
|
||||||
|
bin = b""
|
||||||
|
for chunk in tts_mdl.tts(text):
|
||||||
|
bin += chunk
|
||||||
|
return binascii.hexlify(bin).decode("utf-8")
|
Loading…
x
Reference in New Issue
Block a user