diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py
index a77e355b5..42fde4f9f 100644
--- a/api/apps/dialog_app.py
+++ b/api/apps/dialog_app.py
@@ -58,11 +58,6 @@ def set_dialog():
     if not prompt_config["system"]:
         prompt_config["system"] = default_prompt["system"]
-    # if len(prompt_config["parameters"]) < 1:
-    #     prompt_config["parameters"] = default_prompt["parameters"]
-    # for p in prompt_config["parameters"]:
-    #     if p["key"] == "knowledge":break
-    # else: prompt_config["parameters"].append(default_prompt["parameters"][0])
     for p in prompt_config["parameters"]:
         if p["optional"]:
@@ -75,23 +70,19 @@ def set_dialog():
         e, tenant = TenantService.get_by_id(current_user.id)
         if not e:
             return get_data_error_result(message="Tenant not found!")
-        kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids"))
+        kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids", []))
         embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs]  # remove vendor suffix for comparison
         embd_count = len(set(embd_ids))
-        if embd_count != 1:
+        if embd_count > 1:
             return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"')
 
         llm_id = req.get("llm_id", tenant.llm_id)
         if not dialog_id:
-            if not req.get("kb_ids"):
-                return get_data_error_result(
-                    message="Fail! Please select knowledgebase!")
-
             dia = {
                 "id": get_uuid(),
                 "tenant_id": current_user.id,
                 "name": name,
-                "kb_ids": req["kb_ids"],
+                "kb_ids": req.get("kb_ids", []),
                 "description": description,
                 "llm_id": llm_id,
                 "llm_setting": llm_setting,
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 3b68ca1b4..937626764 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -170,8 +170,40 @@ def label_question(question, kbs):
     return tags
 
 
+def chat_solo(dialog, messages, stream=True):
+    if llm_id2llm_type(dialog.llm_id) == "image2text":
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
+    else:
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+
+    prompt_config = dialog.prompt_config
+    tts_mdl = None
+    if prompt_config.get("tts"):
+        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
+    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
+           for m in messages if m["role"] != "system"]
+    if stream:
+        last_ans = ""
+        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
+            answer = ans
+            delta_ans = ans[len(last_ans):]
+            if num_tokens_from_string(delta_ans) < 16:
+                continue
+            last_ans = answer
+            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
+    else:
+        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
+        user_content = msg[-1].get("content", "[content not available]")
+        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
+        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
+
+
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
+    if not dialog.kb_ids:
+        for ans in chat_solo(dialog, messages, stream):
+            yield ans
+        return
 
     chat_start_ts = timer()
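
The streaming branch of chat_solo re-yields the full answer each time, but only once the newly accumulated delta reaches roughly 16 tokens, so TTS and the client are not flooded with tiny fragments. Below is a minimal, self-contained sketch of that delta-batching pattern; rough_token_count and fake_stream are illustrative stand-ins (assumptions, not RAGFlow APIs) for num_tokens_from_string and chat_mdl.chat_streamly, and the threshold is lowered to 4 for the demo.

    # Sketch only, not part of the patch. Stand-ins below are assumptions:
    # rough_token_count approximates num_tokens_from_string, and fake_stream
    # mimics chat_mdl.chat_streamly, which yields the full answer-so-far.

    def rough_token_count(text):
        # Crude whitespace tokenizer, good enough to illustrate the gate.
        return len(text.split())

    def fake_stream():
        # Each yield is the complete answer generated so far, growing over time.
        full = ""
        for word in ("streaming answers arrive as ever longer prefixes of "
                     "the final text so each delta is a suffix").split():
            full += word + " "
            yield full

    def batched_deltas(stream, min_tokens=4):
        # Mirror of chat_solo's streaming loop: hold the output back until the
        # pending delta is at least min_tokens, then emit answer plus delta.
        last_ans = ""
        for ans in stream:
            delta_ans = ans[len(last_ans):]
            if rough_token_count(delta_ans) < min_tokens:
                continue  # keep accumulating; nothing is emitted yet
            last_ans = ans
            yield ans, delta_ans

    for answer, delta in batched_deltas(fake_stream()):
        print(repr(delta))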