Refactor ask decorator (#4116)

### What problem does this PR solve?

Refactor the answer-decoration path used by `ask`/`chat` in `dialog_service.py`: rename abbreviated variables to descriptive ones (`retr` → `retriever`, `embd_nms` → `embedding_list`, `user_promt` → `user_prompt`), replace the coarse elapsed-time log in `chat()` with per-stage timing checkpoints, and translate the Chinese SQL-generation prompts in `use_sql()` into English, plus matching renames and comments in `llm_service.py`.

### Type of change

- [x] Refactoring

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Jin Hai 2024-12-19 18:13:33 +08:00 committed by GitHub
parent 478da3118c
commit 213218a094
2 changed files with 141 additions and 102 deletions

api/db/services/dialog_service.py

@@ -23,7 +23,7 @@ from copy import deepcopy
from timeit import default_timer as timer
import datetime
from datetime import timedelta
from api.db import LLMType, ParserType,StatusEnum
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -41,14 +41,14 @@ class DialogService(CommonService):
@classmethod
@DB.connection_context()
def get_list(cls, tenant_id,
page_number, items_per_page, orderby, desc, id , name):
page_number, items_per_page, orderby, desc, id, name):
chats = cls.model.select()
if id:
chats = chats.where(cls.model.id == id)
if name:
chats = chats.where(cls.model.name == name)
chats = chats.where(
(cls.model.tenant_id == tenant_id)
(cls.model.tenant_id == tenant_id)
& (cls.model.status == StatusEnum.VALID.value)
)
if desc:
@@ -137,25 +137,37 @@ def kb_prompt(kbinfos, max_tokens):
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
st = timer()
llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
chat_start_ts = timer()
# Get llm model name and model provider name
llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
# Get llm model instance by model name and provider name
llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)
if not llm:
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
# Model is registered by the tenant rather than being a system built-in
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
max_tokens = 8192
else:
max_tokens = llm[0].max_tokens
check_llm_ts = timer()
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
embedding_list = list(set([kb.embd_id for kb in kbs]))
if len(embedding_list) != 1:
yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
embedding_model_name = embedding_list[0]
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
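The two-step model lookup in this hunk (system registry first, then the tenant's own registrations) is the core of the resolution logic; below is a minimal condensed sketch of the same control flow, assuming the `LLMService`/`TenantLLMService` imports already present in this file — the `resolve_llm` helper name is hypothetical, not part of the commit:

```python
DEFAULT_MAX_TOKENS = 8192  # fallback when a tenant-registered model has no known limit

def resolve_llm(dialog):
    llm_name, provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
    # 1) system-wide registry; filter by provider only if "name@provider" was given
    llm = (LLMService.query(llm_name=llm_name) if not provider
           else LLMService.query(llm_name=llm_name, fid=provider))
    if llm:
        return llm[0], llm[0].max_tokens
    # 2) the tenant's own model registrations
    llm = (TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_name) if not provider
           else TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_name, llm_factory=provider))
    if not llm:
        raise LookupError("LLM(%s) not found" % dialog.llm_id)
    return llm[0], DEFAULT_MAX_TOKENS
```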
@@ -165,15 +177,21 @@ def chat(dialog, messages, stream=True, **kwargs):
if "doc_ids" in m:
attachments.extend(m["doc_ids"])
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
create_retriever_ts = timer()
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embd_nms[0])
raise LookupError("Embedding model(%s) not found" % embedding_model_name)
bind_embedding_ts = timer()
if llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
bind_llm_ts = timer()
prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
tts_mdl = None
@@ -200,32 +218,35 @@ def chat(dialog, messages, stream=True, **kwargs):
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
else:
questions = questions[-1:]
refineQ_tm = timer()
keyword_tm = timer()
refine_question_ts = timer()
rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
for _ in range(len(questions) // 2):
questions.append(questions[-1])
bind_reranker_ts = timer()
generate_keyword_ts = bind_reranker_ts
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
else:
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
keyword_tm = timer()
generate_keyword_ts = timer()
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
retrieval_ts = timer()
knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_tm = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
@@ -249,17 +270,20 @@ def chat(dialog, messages, stream=True, **kwargs):
max_tokens - used_token_count)
def decorate_answer(answer):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
finish_chat_ts = timer()
refs = []
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer, idx = retr.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
answer, idx = retriever.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -274,10 +298,20 @@ def chat(dialog, messages, stream=True, **kwargs):
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
done_tm = timer()
prompt += "\n\n### Elapsed\n - Refine Question: %.1f ms\n - Keywords: %.1f ms\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % (
(refineQ_tm - st) * 1000, (keyword_tm - refineQ_tm) * 1000, (retrieval_tm - keyword_tm) * 1000,
(done_tm - retrieval_tm) * 1000)
finish_chat_ts = timer()
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
prompt = f"{prompt} ### Elapsed\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
return {"answer": answer, "reference": refs, "prompt": prompt}
if stream:
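All of the `*_ts = timer()` checkpoints introduced above follow one pattern: record a timestamp after each stage, then report pairwise deltas in milliseconds when the answer is decorated. A minimal sketch of that pattern — the `elapsed_report` helper and its labels are illustrative, not part of the commit:

```python
from timeit import default_timer as timer

def elapsed_report(checkpoints):
    """checkpoints: ordered list of (label, timestamp) pairs."""
    total = (checkpoints[-1][1] - checkpoints[0][1]) * 1000
    lines = [" - Total: %.1fms" % total]
    for (_, prev), (label, now) in zip(checkpoints, checkpoints[1:]):
        lines.append(" - %s: %.1fms" % (label, (now - prev) * 1000))
    return "### Elapsed\n" + "\n".join(lines)

# Usage:
# ts = [("start", timer())]
# ... run a stage ...
# ts.append(("Retrieval", timer()))
# prompt += "\n\n" + elapsed_report(ts)
```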
@@ -304,15 +338,15 @@ def chat(dialog, messages, stream=True, **kwargs):
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构根据用户的问题列表写出最后一个问题对应的SQL。"
user_promt = """
表名{}
数据库表字段说明如下
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
user_prompt = """
Table name: {};
The database table fields are as follows:
{}
问题如下
The questions are as follows:
{}
请写出SQL, 且只要SQL不要有其他说明及文字
Please write the SQL, only SQL, without any other explanations or text.
""".format(
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
@@ -321,10 +355,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
tried_times = 0
def get_table():
nonlocal sys_prompt, user_promt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
nonlocal sys_prompt, user_prompt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
"temperature": 0.06})
logging.debug(f"{question} ==> {user_promt} get SQL: {sql}")
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql)
@@ -352,21 +386,23 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
if tbl is None:
return None
if tbl.get("error") and tried_times <= 2:
user_promt = """
表名{}
数据库表字段说明如下
user_prompt = """
Table name: {};
The database table fields are as follows:
{}
The questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
问题如下
Error issued by database as follows:
{}
你上一次给出的错误SQL如下
{}
后台报错如下
{}
请纠正SQL中的错误再写一遍且只要SQL不要有其他说明及文字
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
@@ -381,21 +417,21 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
docid_idx = set([ii for ii, c in enumerate(
tbl["columns"]) if c["name"] == "doc_id"])
docnm_idx = set([ii for ii, c in enumerate(
doc_name_idx = set([ii for ii, c in enumerate(
tbl["columns"]) if c["name"] == "docnm_kwd"])
clmn_idx = [ii for ii in range(
len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
column_idx = [ii for ii in range(
len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
# compose markdown table
clmns = "|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"],
tbl["columns"][i]["name"])) for i in
clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
# compose Markdown table
columns = "|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"],
tbl["columns"][i]["name"])) for i in
column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
("|------|" if docid_idx and docid_idx else "")
rows = ["|" +
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
"|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
"|" for r in tbl["rows"]]
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
if quota:
@@ -404,24 +440,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
if not docid_idx or not docnm_idx:
if not docid_idx or not doc_name_idx:
logging.warning("SQL missing field: " + sql)
return {
"answer": "\n".join([clmns, line, rows]),
"answer": "\n".join([columns, line, rows]),
"reference": {"chunks": [], "doc_aggs": []},
"prompt": sys_prompt
}
docid_idx = list(docid_idx)[0]
docnm_idx = list(docnm_idx)[0]
doc_name_idx = list(doc_name_idx)[0]
doc_aggs = {}
for r in tbl["rows"]:
if r[docid_idx] not in doc_aggs:
doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0}
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
doc_aggs[r[docid_idx]]["count"] += 1
return {
"answer": "\n".join([clmns, line, rows]),
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
"answer": "\n".join([columns, line, rows]),
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
doc_aggs.items()]},
"prompt": sys_prompt
@@ -492,7 +528,7 @@ Requirements:
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
if kwd.find("**ERROR**") >=0:
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
@@ -605,16 +641,16 @@ def tts(tts_mdl, text):
def ask(question, kb_ids, tenant_id):
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
embedding_list = list(set([kb.embd_id for kb in kbs]))
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """
Role: You're a smart assistant. Your name is Miss R.
@@ -636,14 +672,14 @@ def ask(question, kb_ids, tenant_id):
def decorate_answer(answer):
nonlocal knowledges, kbinfos, prompt
answer, idx = retr.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=0.7,
vtweight=0.3)
answer, idx = retriever.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=0.7,
vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -664,4 +700,3 @@ def ask(question, kb_ids, tenant_id):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
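This tail is the "ask decorator" shape the PR refactors: stream the accumulated partial answer as it grows, then decorate the final text exactly once with citations and document references and yield it as the last item. A minimal sketch with hypothetical `chat_streamly`/`decorate` stand-ins:

```python
def ask_sketch(chat_streamly, decorate, system_prompt, messages, gen_conf):
    answer = ""
    for delta in chat_streamly(system_prompt, messages, gen_conf):
        answer = delta  # the model yields the accumulated text so far
        yield {"answer": answer, "reference": {}}
    yield decorate(answer)  # last item carries citations and doc references
```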

api/db/services/llm_service.py

@@ -72,10 +72,12 @@ class TenantLLMService(CommonService):
return model_name, None
if len(arr) > 2:
return "@".join(arr[0:-1]), arr[-1]
# model name must be xxx@yyy
try:
fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
fact = set([f["name"] for f in fact])
if arr[-1] not in fact:
model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
model_providers = set([f["name"] for f in model_factories])
if arr[-1] not in model_providers:
return model_name, None
return arr[0], arr[-1]
except Exception as e:
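The contract of `split_model_name_and_factory` after this hunk: a `name@provider` suffix is honored only when the suffix matches a provider listed in `conf/llm_factories.json`; otherwise the `@` is treated as part of the model name. A condensed sketch — the helper name and relative path are illustrative:

```python
import json

def split_model_name(model_name, factories_path="conf/llm_factories.json"):
    arr = model_name.split("@")
    if len(arr) < 2:
        return model_name, None             # no provider suffix at all
    if len(arr) > 2:
        return "@".join(arr[:-1]), arr[-1]  # the name itself may contain '@'
    try:
        infos = json.load(open(factories_path, "r"))["factory_llm_infos"]
        providers = set(f["name"] for f in infos)
        if arr[-1] not in providers:
            return model_name, None         # suffix is not a known provider
        return arr[0], arr[-1]
    except Exception:
        return model_name, None

# e.g. split_model_name("deepseek-chat@DeepSeek") -> ("deepseek-chat", "DeepSeek")
```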
@@ -113,11 +115,11 @@ class TenantLLMService(CommonService):
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
if not model_config:
if mdlnm == "flag-embedding":
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
"llm_name": llm_name, "api_base": ""}
"llm_name": llm_name, "api_base": ""}
else:
if not mdlnm:
raise LookupError(f"Type of {llm_type} model is not set.")
@@ -200,8 +202,8 @@ class TenantLLMService(CommonService):
return num
else:
tenant_llm = tenant_llms[0]
num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens) \
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name) \
.execute()
except Exception:
logging.exception("TenantLLMService.increase_usage got exception")
@@ -231,7 +233,7 @@ class LLMBundle(object):
for lm in LLMService.query(llm_name=llm_name):
self.max_length = lm.max_tokens
break
def encode(self, texts: list):
embeddings, used_tokens = self.mdl.encode(texts)
if not TenantLLMService.increase_usage(
@@ -274,11 +276,11 @@ class LLMBundle(object):
def tts(self, text):
for chunk in self.mdl.tts(text):
if isinstance(chunk,int):
if isinstance(chunk, int):
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, chunk, self.llm_name):
logging.error(
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
self.tenant_id, self.llm_type, chunk, self.llm_name):
logging.error(
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
return
yield chunk
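`tts` relies on a sentinel convention: the underlying model generator yields content chunks, and a bare `int` chunk carries the token count for usage accounting, ending the stream. A minimal sketch with a hypothetical `forward_with_accounting` helper:

```python
def forward_with_accounting(chunks, record_usage):
    for chunk in chunks:
        if isinstance(chunk, int):      # usage sentinel arrives last
            record_usage(chunk)         # e.g. TenantLLMService.increase_usage(...)
            return
        yield chunk

# list(forward_with_accounting(iter(["hel", "lo", 5]), print))
# -> ["hel", "lo"], and prints 5
```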
@@ -287,7 +289,8 @@ class LLMBundle(object):
if isinstance(txt, int) and not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error(
"LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
"LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
used_tokens))
return txt
def chat_streamly(self, system, history, gen_conf):
@@ -296,6 +299,7 @@ class LLMBundle(object):
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, txt, self.llm_name):
logging.error(
"LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
"LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
txt))
return
yield txt