From 9f7d187ab31b857692db78b8bcc0ff966d9095f3 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Mon, 9 Sep 2024 12:08:50 +0800
Subject: [PATCH] add elapsed time of conversation (#2316)

### What problem does this PR solve?

#2315

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 api/db/services/dialog_service.py | 95 ++++++++++++++++++++++++++-----
 1 file changed, 80 insertions(+), 15 deletions(-)

diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 14afdb91f..7f9fc84b4 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -18,7 +18,7 @@ import os
 import json
 import re
 from copy import deepcopy
-
+from timeit import default_timer as timer
 from api.db import LLMType, ParserType
 from api.db.db_models import Dialog, Conversation
 from api.db.services.common_service import CommonService
@@ -88,6 +88,7 @@ def llm_id2llm_type(llm_id):
 
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
+    st = timer()
     llm = LLMService.query(llm_name=dialog.llm_id)
     if not llm:
         llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
@@ -158,25 +159,16 @@ def chat(dialog, messages, stream=True, **kwargs):
                                     doc_ids=attachments,
                                     top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
         knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-        #self-rag
-        if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
-            questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
-            kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
-                                         dialog.similarity_threshold,
-                                         dialog.vector_similarity_weight,
-                                         doc_ids=attachments,
-                                         top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
-            knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-
     chat_logger.info(
         "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
+    retrieval_tm = timer()
 
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
         yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}
 
-    kwargs["knowledge"] = "\n".join(knowledges)
+    kwargs["knowledge"] = "\n------\n".join(knowledges)
     gen_conf = dialog.llm_setting
 
     msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
@@ -192,7 +184,7 @@ def chat(dialog, messages, stream=True, **kwargs):
             max_tokens - used_token_count)
 
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt
+        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
         refs = []
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             answer, idx = retr.insert_citations(answer,
@@ -216,7 +208,9 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-        return {"answer": answer, "reference": refs, "prompt": prompt}
+        done_tm = timer()
+        prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000)
+        return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", "
", prompt)}
 
     if stream:
         last_ans = ""
@@ -415,4 +409,75 @@ def tts(tts_mdl, text):
     bin = b""
     for chunk in tts_mdl.tts(text):
         bin += chunk
-    return binascii.hexlify(bin).decode("utf-8")
\ No newline at end of file
+    return binascii.hexlify(bin).decode("utf-8")
+
+
+def ask(question, kb_ids, tenant_id):
+    kbs = KnowledgebaseService.get_by_ids(kb_ids)
+    embd_nms = list(set([kb.embd_id for kb in kbs]))
+
+    is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
+    retr = retrievaler if not is_kg else kg_retrievaler
+
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
+    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
+    max_tokens = chat_mdl.max_length
+
+    kbinfos = retr.retrieval(question, embd_mdl, tenant_id, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
+    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+
+    used_token_count = 0
+    for i, c in enumerate(knowledges):
+        used_token_count += num_tokens_from_string(c)
+        if max_tokens * 0.97 < used_token_count:
+            knowledges = knowledges[:i]
+            break
+
+    prompt = """
+    Role: You're a smart assistant. Your name is Miss R.
+    Task: Summarize the information from knowledge bases and answer user's question.
+    Requirements and restriction:
+      - DO NOT make things up, especially for numbers.
+      - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
+      - Answer with markdown format text.
+      - Answer in language of user's question.
+      - DO NOT make things up, especially for numbers.
+
+    ### Information from knowledge bases
+    %s
+
+    The above is information from knowledge bases.
+
+    """%"\n".join(knowledges)
+    msg = [{"role": "user", "content": question}]
+
+    def decorate_answer(answer):
+        nonlocal knowledges, kbinfos, prompt
+        answer, idx = retr.insert_citations(answer,
+                                            [ck["content_ltks"]
+                                             for ck in kbinfos["chunks"]],
+                                            [ck["vector"]
+                                             for ck in kbinfos["chunks"]],
+                                            embd_mdl,
+                                            tkweight=0.7,
+                                            vtweight=0.3)
+        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+        recall_docs = [
+            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+        if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+        kbinfos["doc_aggs"] = recall_docs
+        refs = deepcopy(kbinfos)
+        for c in refs["chunks"]:
+            if c.get("vector"):
+                del c["vector"]
+
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        return {"answer": answer, "reference": refs}
+
+    answer = ""
+    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
+        answer = ans
+        yield {"answer": answer, "reference": {}}
+    yield decorate_answer(answer)
+