diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 1a21a1a7c..ef86d763c 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -23,7 +23,7 @@ from copy import deepcopy from timeit import default_timer as timer import datetime from datetime import timedelta -from api.db import LLMType, ParserType,StatusEnum +from api.db import LLMType, ParserType, StatusEnum from api.db.db_models import Dialog, DB from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -41,14 +41,14 @@ class DialogService(CommonService): @classmethod @DB.connection_context() def get_list(cls, tenant_id, - page_number, items_per_page, orderby, desc, id , name): + page_number, items_per_page, orderby, desc, id, name): chats = cls.model.select() if id: chats = chats.where(cls.model.id == id) if name: chats = chats.where(cls.model.name == name) chats = chats.where( - (cls.model.tenant_id == tenant_id) + (cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) ) if desc: @@ -137,25 +137,37 @@ def kb_prompt(kbinfos, max_tokens): def chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." - st = timer() - llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id) - llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid) + + chat_start_ts = timer() + + # Get llm model name and model provider name + llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id) + + # Get llm model instance by model and provide name + llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider) + if not llm: - llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \ - TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid) + # Model name is provided by tenant, but not system built-in + llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \ + TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider) if not llm: raise LookupError("LLM(%s) not found" % dialog.llm_id) max_tokens = 8192 else: max_tokens = llm[0].max_tokens + + check_llm_ts = timer() + kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) - embd_nms = list(set([kb.embd_id for kb in kbs])) - if len(embd_nms) != 1: + embedding_list = list(set([kb.embd_id for kb in kbs])) + if len(embedding_list) != 1: yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} - is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) - retr = settings.retrievaler if not is_kg else settings.kg_retrievaler + embedding_model_name = embedding_list[0] + + is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs]) + retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler questions = [m["content"] for m in messages if m["role"] == "user"][-3:] attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None @@ -165,15 +177,21 @@ def chat(dialog, messages, stream=True, **kwargs): if "doc_ids" in m: attachments.extend(m["doc_ids"]) - embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0]) + create_retriever_ts = timer() + + embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name) if not embd_mdl: - raise LookupError("Embedding model(%s) not found" % embd_nms[0]) + raise LookupError("Embedding model(%s) not found" % embedding_model_name) + + bind_embedding_ts = timer() if llm_id2llm_type(dialog.llm_id) == "image2text": chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + bind_llm_ts = timer() + prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) tts_mdl = None @@ -200,32 +218,35 @@ def chat(dialog, messages, stream=True, **kwargs): questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] else: questions = questions[-1:] - refineQ_tm = timer() - keyword_tm = timer() + + refine_question_ts = timer() rerank_mdl = None if dialog.rerank_id: rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) - for _ in range(len(questions) // 2): - questions.append(questions[-1]) + bind_reranker_ts = timer() + generate_keyword_ts = bind_reranker_ts + if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} else: if prompt_config.get("keyword", False): questions[-1] += keyword_extraction(chat_mdl, questions[-1]) - keyword_tm = timer() + generate_keyword_ts = timer() tenant_ids = list(set([kb.tenant_id for kb in kbs])) - kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, - dialog.similarity_threshold, - dialog.vector_similarity_weight, - doc_ids=attachments, - top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl) + kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, + dialog.similarity_threshold, + dialog.vector_similarity_weight, + doc_ids=attachments, + top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl) + + retrieval_ts = timer() + knowledges = kb_prompt(kbinfos, max_tokens) logging.debug( "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) - retrieval_tm = timer() if not knowledges and prompt_config.get("empty_response"): empty_res = prompt_config["empty_response"] @@ -249,17 +270,20 @@ def chat(dialog, messages, stream=True, **kwargs): max_tokens - used_token_count) def decorate_answer(answer): - nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm + nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts + + finish_chat_ts = timer() + refs = [] if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): - answer, idx = retr.insert_citations(answer, - [ck["content_ltks"] - for ck in kbinfos["chunks"]], - [ck["vector"] - for ck in kbinfos["chunks"]], - embd_mdl, - tkweight=1 - dialog.vector_similarity_weight, - vtweight=dialog.vector_similarity_weight) + answer, idx = retriever.insert_citations(answer, + [ck["content_ltks"] + for ck in kbinfos["chunks"]], + [ck["vector"] + for ck in kbinfos["chunks"]], + embd_mdl, + tkweight=1 - dialog.vector_similarity_weight, + vtweight=dialog.vector_similarity_weight) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [ d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] @@ -274,10 +298,20 @@ def chat(dialog, messages, stream=True, **kwargs): if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'" - done_tm = timer() - prompt += "\n\n### Elapsed\n - Refine Question: %.1f ms\n - Keywords: %.1f ms\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % ( - (refineQ_tm - st) * 1000, (keyword_tm - refineQ_tm) * 1000, (retrieval_tm - keyword_tm) * 1000, - (done_tm - retrieval_tm) * 1000) + finish_chat_ts = timer() + + total_time_cost = (finish_chat_ts - chat_start_ts) * 1000 + check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000 + create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000 + bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000 + bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000 + refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000 + bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000 + generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000 + retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000 + generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000 + + prompt = f"{prompt} ### Elapsed\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms" return {"answer": answer, "reference": refs, "prompt": prompt} if stream: @@ -304,15 +338,15 @@ def chat(dialog, messages, stream=True, **kwargs): def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): - sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。" - user_promt = """ -表名:{}; -数据库表字段说明如下: + sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question." + user_prompt = """ +Table name: {}; +Table of database fields are as follows: {} -问题如下: +Question are as follows: {} -请写出SQL, 且只要SQL,不要有其他说明及文字。 +Please write the SQL, only SQL, without any other explanations or text. """.format( index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), @@ -321,10 +355,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): tried_times = 0 def get_table(): - nonlocal sys_prompt, user_promt, question, tried_times - sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], { + nonlocal sys_prompt, user_prompt, question, tried_times + sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], { "temperature": 0.06}) - logging.debug(f"{question} ==> {user_promt} get SQL: {sql}") + logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}") sql = re.sub(r"[\r\n]+", " ", sql.lower()) sql = re.sub(r".*select ", "select ", sql.lower()) sql = re.sub(r" +", " ", sql) @@ -352,21 +386,23 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): if tbl is None: return None if tbl.get("error") and tried_times <= 2: - user_promt = """ - 表名:{}; - 数据库表字段说明如下: + user_prompt = """ + Table name: {}; + Table of database fields are as follows: + {} + + Question are as follows: + {} + Please write the SQL, only SQL, without any other explanations or text. + + + The SQL error you provided last time is as follows: {} - 问题如下: + Error issued by database as follows: {} - 你上一次给出的错误SQL如下: - {} - - 后台报错如下: - {} - - 请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。 + Please correct the error and write SQL again, only SQL, without any other explanations or text. """.format( index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), @@ -381,21 +417,21 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): docid_idx = set([ii for ii, c in enumerate( tbl["columns"]) if c["name"] == "doc_id"]) - docnm_idx = set([ii for ii, c in enumerate( + doc_name_idx = set([ii for ii, c in enumerate( tbl["columns"]) if c["name"] == "docnm_kwd"]) - clmn_idx = [ii for ii in range( - len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)] + column_idx = [ii for ii in range( + len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] - # compose markdown table - clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], - tbl["columns"][i]["name"])) for i in - clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|") + # compose Markdown table + columns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], + tbl["columns"][i]["name"])) for i in + column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") - line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \ + line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \ ("|------|" if docid_idx and docid_idx else "") rows = ["|" + - "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]] rows = [r for r in rows if re.sub(r"[ |]+", "", r)] if quota: @@ -404,24 +440,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) - if not docid_idx or not docnm_idx: + if not docid_idx or not doc_name_idx: logging.warning("SQL missing field: " + sql) return { - "answer": "\n".join([clmns, line, rows]), + "answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt } docid_idx = list(docid_idx)[0] - docnm_idx = list(docnm_idx)[0] + doc_name_idx = list(doc_name_idx)[0] doc_aggs = {} for r in tbl["rows"]: if r[docid_idx] not in doc_aggs: - doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0} + doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0} doc_aggs[r[docid_idx]]["count"] += 1 return { - "answer": "\n".join([clmns, line, rows]), - "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]], + "answer": "\n".join([columns, line, rows]), + "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]], "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}, "prompt": sys_prompt @@ -492,7 +528,7 @@ Requirements: kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] - if kwd.find("**ERROR**") >=0: + if kwd.find("**ERROR**") >= 0: return "" return kwd @@ -605,16 +641,16 @@ def tts(tts_mdl, text): def ask(question, kb_ids, tenant_id): kbs = KnowledgebaseService.get_by_ids(kb_ids) - embd_nms = list(set([kb.embd_id for kb in kbs])) + embedding_list = list(set([kb.embd_id for kb in kbs])) - is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) - retr = settings.retrievaler if not is_kg else settings.kg_retrievaler + is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs]) + retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler - embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0]) + embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) max_tokens = chat_mdl.max_length tenant_ids = list(set([kb.tenant_id for kb in kbs])) - kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False) + kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False) knowledges = kb_prompt(kbinfos, max_tokens) prompt = """ Role: You're a smart assistant. Your name is Miss R. @@ -636,14 +672,14 @@ def ask(question, kb_ids, tenant_id): def decorate_answer(answer): nonlocal knowledges, kbinfos, prompt - answer, idx = retr.insert_citations(answer, - [ck["content_ltks"] - for ck in kbinfos["chunks"]], - [ck["vector"] - for ck in kbinfos["chunks"]], - embd_mdl, - tkweight=0.7, - vtweight=0.3) + answer, idx = retriever.insert_citations(answer, + [ck["content_ltks"] + for ck in kbinfos["chunks"]], + [ck["vector"] + for ck in kbinfos["chunks"]], + embd_mdl, + tkweight=0.7, + vtweight=0.3) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [ d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] @@ -664,4 +700,3 @@ def ask(question, kb_ids, tenant_id): answer = ans yield {"answer": answer, "reference": {}} yield decorate_answer(answer) - diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 2d47a93ef..a5fe633c5 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -72,10 +72,12 @@ class TenantLLMService(CommonService): return model_name, None if len(arr) > 2: return "@".join(arr[0:-1]), arr[-1] + + # model name must be xxx@yyy try: - fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"] - fact = set([f["name"] for f in fact]) - if arr[-1] not in fact: + model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"] + model_providers = set([f["name"] for f in model_factories]) + if arr[-1] not in model_providers: return model_name, None return arr[0], arr[-1] except Exception as e: @@ -113,11 +115,11 @@ class TenantLLMService(CommonService): if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]: - model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""} + model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""} if not model_config: if mdlnm == "flag-embedding": model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", - "llm_name": llm_name, "api_base": ""} + "llm_name": llm_name, "api_base": ""} else: if not mdlnm: raise LookupError(f"Type of {llm_type} model is not set.") @@ -200,8 +202,8 @@ class TenantLLMService(CommonService): return num else: tenant_llm = tenant_llms[0] - num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\ - .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\ + num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens) \ + .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name) \ .execute() except Exception: logging.exception("TenantLLMService.increase_usage got exception") @@ -231,7 +233,7 @@ class LLMBundle(object): for lm in LLMService.query(llm_name=llm_name): self.max_length = lm.max_tokens break - + def encode(self, texts: list): embeddings, used_tokens = self.mdl.encode(texts) if not TenantLLMService.increase_usage( @@ -274,11 +276,11 @@ class LLMBundle(object): def tts(self, text): for chunk in self.mdl.tts(text): - if isinstance(chunk,int): + if isinstance(chunk, int): if not TenantLLMService.increase_usage( - self.tenant_id, self.llm_type, chunk, self.llm_name): - logging.error( - "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id)) + self.tenant_id, self.llm_type, chunk, self.llm_name): + logging.error( + "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id)) return yield chunk @@ -287,7 +289,8 @@ class LLMBundle(object): if isinstance(txt, int) and not TenantLLMService.increase_usage( self.tenant_id, self.llm_type, used_tokens, self.llm_name): logging.error( - "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) + "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, + used_tokens)) return txt def chat_streamly(self, system, history, gen_conf): @@ -296,6 +299,7 @@ class LLMBundle(object): if not TenantLLMService.increase_usage( self.tenant_id, self.llm_type, txt, self.llm_name): logging.error( - "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt)) + "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, + txt)) return yield txt