From 673a28e4927a8970a0c97f0916e1256fdc65b1b8 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Fri, 17 May 2024 20:03:00 +0800 Subject: [PATCH] fix bug of chat without stream (#830) ### What problem does this PR solve? `chat()` is a generator function, so the non-stream code paths that used its return value directly (`ans = chat(dia, msg, **req)`) received a generator object instead of an answer dict. `chat()` now always yields its result, and the non-stream callers read the first yielded answer. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/api_app.py | 11 +++++++---- api/apps/conversation_app.py | 11 +++++++---- api/db/services/dialog_service.py | 8 +++----- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 0c0b191b2..3c257e4cc 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -222,10 +222,13 @@ def completion(): resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp else: - ans = chat(dia, msg, **req) - fillin_conv(ans) - API4ConversationService.append_message(conv.id, conv.to_dict()) - return get_json_result(data=ans) + answer = None + for ans in chat(dia, msg, **req): + answer = ans + fillin_conv(ans) + API4ConversationService.append_message(conv.id, conv.to_dict()) + break + return get_json_result(data=answer) except Exception as e: return server_error_response(e) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 6d06a05f8..1400d6623 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -162,10 +162,13 @@ def completion(): return resp else: - ans = chat(dia, msg, **req) - fillin_conv(ans) - ConversationService.update_by_id(conv.id, conv.to_dict()) - return get_json_result(data=ans) + answer = None + for ans in chat(dia, msg, **req): + answer = ans + fillin_conv(ans) + ConversationService.update_by_id(conv.id, conv.to_dict()) + break + return get_json_result(data=answer) except Exception as e: return server_error_response(e) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 94c2285c7..321d01bb6 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -84,8 +84,7 @@ 
def chat(dialog, messages, stream=True, **kwargs): kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) embd_nms = list(set([kb.embd_id for kb in kbs])) if len(embd_nms) != 1: - if stream: - yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} + yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} questions = [m["content"] for m in messages if m["role"] == "user"] @@ -126,8 +125,7 @@ def chat(dialog, messages, stream=True, **kwargs): "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) if not knowledges and prompt_config.get("empty_response"): - if stream: - yield {"answer": prompt_config["empty_response"], "reference": kbinfos} + yield {"answer": prompt_config["empty_response"], "reference": kbinfos} return {"answer": prompt_config["empty_response"], "reference": kbinfos} kwargs["knowledge"] = "\n".join(knowledges) @@ -177,7 +175,7 @@ def chat(dialog, messages, stream=True, **kwargs): **kwargs), msg, gen_conf) chat_logger.info("User: {}|Assistant: {}".format( msg[-1]["content"], answer)) - return decorate_answer(answer) + yield decorate_answer(answer) def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):