def full_question(tenant_id, llm_id, messages):
    """Rewrite the latest user turn as a self-contained question.

    The user/assistant turns of ``messages`` are rendered as a transcript and
    sent to the tenant's model together with few-shot examples, asking it to
    produce a complete question that resolves references to earlier turns.

    Args:
        tenant_id: Tenant identifier forwarded to ``LLMBundle``.
        llm_id: Model identifier; when ``llm_id2llm_type`` reports
            ``"image2text"`` the bundle is created with ``LLMType.IMAGE2TEXT``,
            otherwise with ``LLMType.CHAT``.
        messages: Sequence of dicts carrying ``"role"`` and ``"content"``
            keys. Only ``"user"`` and ``"assistant"`` roles are included in
            the transcript.

    Returns:
        The model's answer, or the content of the last message when the
        answer contains the ``**ERROR**`` marker.
    """
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    # Role labels are upper-cased so the real transcript matches the
    # USER:/ASSISTANT: labels used by the few-shot examples in the prompt.
    conv = "\n".join(
        "{}: {}".format(m["role"].upper(), m["content"])
        for m in messages
        if m["role"] in ("user", "assistant")
    )
    prompt = f"""
Role: A helpful assistant
Task: Generate a full user question that would follow the conversation.
Requirements & Restrictions:
  - Text generated MUST be in the same language of the original user's question.
  - If the user's latest question is complete, don't do anything, just return the original question.
  - DON'T generate anything except a refined question.

######################
-Examples-
######################

# Example 1
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
###############
Output: What's the name of Donald Trump's mother?

------------
# Example 2
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
ASSISTANT: Mary Trump.
USER: What's her full name?
###############
Output: What's the full name of Donald Trump's mother Mary Trump?

######################

# Real Data
## Conversation
{conv}
###############
    """
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    # Answers containing the "**ERROR**" marker are treated as failures;
    # fall back to the unrefined last message so the caller still has a query.
    if "**ERROR**" in ans:
        return messages[-1]["content"]
    return ans