Support chat solo. (#5218)

### What problem does this PR solve?

#5216

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu 2025-02-21 12:24:02 +08:00 committed by GitHub
parent c54ec09519
commit f5d63bb7df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 12 deletions

View File

@ -58,11 +58,6 @@ def set_dialog():
if not prompt_config["system"]:
prompt_config["system"] = default_prompt["system"]
# if len(prompt_config["parameters"]) < 1:
# prompt_config["parameters"] = default_prompt["parameters"]
# for p in prompt_config["parameters"]:
# if p["key"] == "knowledge":break
# else: prompt_config["parameters"].append(default_prompt["parameters"][0])
for p in prompt_config["parameters"]:
if p["optional"]:
@ -75,23 +70,19 @@ def set_dialog():
e, tenant = TenantService.get_by_id(current_user.id)
if not e:
return get_data_error_result(message="Tenant not found!")
kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids"))
kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids", []))
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = len(set(embd_ids))
if embd_count != 1:
if embd_count > 1:
return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"')
llm_id = req.get("llm_id", tenant.llm_id)
if not dialog_id:
if not req.get("kb_ids"):
return get_data_error_result(
message="Fail! Please select knowledgebase!")
dia = {
"id": get_uuid(),
"tenant_id": current_user.id,
"name": name,
"kb_ids": req["kb_ids"],
"kb_ids": req.get("kb_ids", []),
"description": description,
"llm_id": llm_id,
"llm_setting": llm_setting,

View File

@ -170,8 +170,40 @@ def label_question(question, kbs):
return tags
def chat_solo(dialog, messages, stream=True):
    """Answer a conversation using only the dialog's LLM.

    Every yielded payload carries an empty ``reference`` and an empty
    ``prompt``.

    Args:
        dialog: Object providing ``llm_id``, ``tenant_id``, ``prompt_config``
            (a mapping) and ``llm_setting``.
        messages: Sequence of ``{"role": ..., "content": ...}`` mappings.
            Messages whose role is ``"system"`` are dropped; the system
            prompt is taken from ``dialog.prompt_config`` instead.
        stream: When true, yield partial answers as they grow; otherwise
            yield a single final answer.

    Yields:
        dict with keys ``answer``, ``reference``, ``audio_binary``,
        ``prompt`` and ``created_at``.
    """
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)

    # Strip "##<digits>$$" markers from earlier turns before re-sending them
    # (presumably citation markers inserted into previous answers -- confirm).
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
           for m in messages if m["role"] != "system"]
    system_prompt = prompt_config.get("system", "")

    if stream:
        answer = ""
        last_ans = ""
        delta_ans = ""
        # Each streamed item is treated as the cumulative answer so far; the
        # delta is whatever was appended since the last payload we emitted.
        for ans in chat_mdl.chat_streamly(system_prompt, msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans):]
            # Batch small increments so TTS is not invoked on tiny fragments.
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
            delta_ans = ""
        # Flush the tail: without this, the final (<16 token) increment -- or
        # an entire short answer -- would never reach the caller.
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(system_prompt, msg, dialog.llm_setting)
        # Only used for the debug log; do not fail if no message survived
        # the system-role filter above.
        user_content = msg[-1].get("content", "[content not available]") if msg else "[content not available]"
        logging.debug("User: %s|Assistant: %s", user_content, answer)
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids:
for ans in chat_solo(dialog, messages, stream):
yield ans
return
chat_start_ts = timer()