mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 11:24:00 +08:00
Refactor ask decorator (#4116)
### What problem does this PR solve? Refactor ask decorator ### Type of change - [x] Refactoring --------- Signed-off-by: jinhai <haijin.chn@gmail.com> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
parent
478da3118c
commit
213218a094
@ -23,7 +23,7 @@ from copy import deepcopy
|
||||
from timeit import default_timer as timer
|
||||
import datetime
|
||||
from datetime import timedelta
|
||||
from api.db import LLMType, ParserType,StatusEnum
|
||||
from api.db import LLMType, ParserType, StatusEnum
|
||||
from api.db.db_models import Dialog, DB
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -41,14 +41,14 @@ class DialogService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, tenant_id,
|
||||
page_number, items_per_page, orderby, desc, id , name):
|
||||
page_number, items_per_page, orderby, desc, id, name):
|
||||
chats = cls.model.select()
|
||||
if id:
|
||||
chats = chats.where(cls.model.id == id)
|
||||
if name:
|
||||
chats = chats.where(cls.model.name == name)
|
||||
chats = chats.where(
|
||||
(cls.model.tenant_id == tenant_id)
|
||||
(cls.model.tenant_id == tenant_id)
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if desc:
|
||||
@ -137,25 +137,37 @@ def kb_prompt(kbinfos, max_tokens):
|
||||
|
||||
def chat(dialog, messages, stream=True, **kwargs):
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
st = timer()
|
||||
llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
|
||||
llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
|
||||
|
||||
chat_start_ts = timer()
|
||||
|
||||
# Get llm model name and model provider name
|
||||
llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
|
||||
|
||||
# Get llm model instance by model and provide name
|
||||
llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)
|
||||
|
||||
if not llm:
|
||||
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
|
||||
TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
|
||||
# Model name is provided by tenant, but not system built-in
|
||||
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
|
||||
TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
|
||||
if not llm:
|
||||
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
||||
max_tokens = 8192
|
||||
else:
|
||||
max_tokens = llm[0].max_tokens
|
||||
|
||||
check_llm_ts = timer()
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||
if len(embd_nms) != 1:
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
if len(embedding_list) != 1:
|
||||
yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
||||
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
||||
|
||||
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
|
||||
embedding_model_name = embedding_list[0]
|
||||
|
||||
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
|
||||
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
||||
@ -165,15 +177,21 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if "doc_ids" in m:
|
||||
attachments.extend(m["doc_ids"])
|
||||
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
||||
create_retriever_ts = timer()
|
||||
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
|
||||
if not embd_mdl:
|
||||
raise LookupError("Embedding model(%s) not found" % embd_nms[0])
|
||||
raise LookupError("Embedding model(%s) not found" % embedding_model_name)
|
||||
|
||||
bind_embedding_ts = timer()
|
||||
|
||||
if llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
bind_llm_ts = timer()
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
tts_mdl = None
|
||||
@ -200,32 +218,35 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||||
else:
|
||||
questions = questions[-1:]
|
||||
refineQ_tm = timer()
|
||||
keyword_tm = timer()
|
||||
|
||||
refine_question_ts = timer()
|
||||
|
||||
rerank_mdl = None
|
||||
if dialog.rerank_id:
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||
|
||||
for _ in range(len(questions) // 2):
|
||||
questions.append(questions[-1])
|
||||
bind_reranker_ts = timer()
|
||||
generate_keyword_ts = bind_reranker_ts
|
||||
|
||||
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
else:
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
keyword_tm = timer()
|
||||
generate_keyword_ts = timer()
|
||||
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=attachments,
|
||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
|
||||
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=attachments,
|
||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
|
||||
|
||||
retrieval_ts = timer()
|
||||
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
logging.debug(
|
||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
retrieval_tm = timer()
|
||||
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
@ -249,17 +270,20 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
max_tokens - used_token_count)
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
|
||||
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
|
||||
|
||||
finish_chat_ts = timer()
|
||||
|
||||
refs = []
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
answer, idx = retr.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
answer, idx = retriever.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
@ -274,10 +298,20 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
|
||||
done_tm = timer()
|
||||
prompt += "\n\n### Elapsed\n - Refine Question: %.1f ms\n - Keywords: %.1f ms\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % (
|
||||
(refineQ_tm - st) * 1000, (keyword_tm - refineQ_tm) * 1000, (retrieval_tm - keyword_tm) * 1000,
|
||||
(done_tm - retrieval_tm) * 1000)
|
||||
finish_chat_ts = timer()
|
||||
|
||||
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
|
||||
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
|
||||
create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
|
||||
bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
|
||||
bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
|
||||
refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
|
||||
bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
|
||||
generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
|
||||
retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
|
||||
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
|
||||
|
||||
prompt = f"{prompt} ### Elapsed\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
|
||||
return {"answer": answer, "reference": refs, "prompt": prompt}
|
||||
|
||||
if stream:
|
||||
@ -304,15 +338,15 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
|
||||
user_promt = """
|
||||
表名:{};
|
||||
数据库表字段说明如下:
|
||||
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
{}
|
||||
|
||||
问题如下:
|
||||
Question are as follows:
|
||||
{}
|
||||
请写出SQL, 且只要SQL,不要有其他说明及文字。
|
||||
Please write the SQL, only SQL, without any other explanations or text.
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
|
||||
@ -321,10 +355,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
tried_times = 0
|
||||
|
||||
def get_table():
|
||||
nonlocal sys_prompt, user_promt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
|
||||
nonlocal sys_prompt, user_prompt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
|
||||
"temperature": 0.06})
|
||||
logging.debug(f"{question} ==> {user_promt} get SQL: {sql}")
|
||||
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
sql = re.sub(r".*select ", "select ", sql.lower())
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
@ -352,21 +386,23 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
if tbl is None:
|
||||
return None
|
||||
if tbl.get("error") and tried_times <= 2:
|
||||
user_promt = """
|
||||
表名:{};
|
||||
数据库表字段说明如下:
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
{}
|
||||
|
||||
Question are as follows:
|
||||
{}
|
||||
Please write the SQL, only SQL, without any other explanations or text.
|
||||
|
||||
|
||||
The SQL error you provided last time is as follows:
|
||||
{}
|
||||
|
||||
问题如下:
|
||||
Error issued by database as follows:
|
||||
{}
|
||||
|
||||
你上一次给出的错误SQL如下:
|
||||
{}
|
||||
|
||||
后台报错如下:
|
||||
{}
|
||||
|
||||
请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
|
||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
|
||||
@ -381,21 +417,21 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(
|
||||
tbl["columns"]) if c["name"] == "doc_id"])
|
||||
docnm_idx = set([ii for ii, c in enumerate(
|
||||
doc_name_idx = set([ii for ii, c in enumerate(
|
||||
tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
clmn_idx = [ii for ii in range(
|
||||
len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
|
||||
column_idx = [ii for ii in range(
|
||||
len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
|
||||
|
||||
# compose markdown table
|
||||
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
||||
tbl["columns"][i]["name"])) for i in
|
||||
clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
# compose Markdown table
|
||||
columns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
||||
tbl["columns"][i]["name"])) for i in
|
||||
column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
|
||||
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
|
||||
("|------|" if docid_idx and docid_idx else "")
|
||||
|
||||
rows = ["|" +
|
||||
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
||||
"|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
|
||||
"|" for r in tbl["rows"]]
|
||||
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
|
||||
if quota:
|
||||
@ -404,24 +440,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
||||
|
||||
if not docid_idx or not docnm_idx:
|
||||
if not docid_idx or not doc_name_idx:
|
||||
logging.warning("SQL missing field: " + sql)
|
||||
return {
|
||||
"answer": "\n".join([clmns, line, rows]),
|
||||
"answer": "\n".join([columns, line, rows]),
|
||||
"reference": {"chunks": [], "doc_aggs": []},
|
||||
"prompt": sys_prompt
|
||||
}
|
||||
|
||||
docid_idx = list(docid_idx)[0]
|
||||
docnm_idx = list(docnm_idx)[0]
|
||||
doc_name_idx = list(doc_name_idx)[0]
|
||||
doc_aggs = {}
|
||||
for r in tbl["rows"]:
|
||||
if r[docid_idx] not in doc_aggs:
|
||||
doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0}
|
||||
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
|
||||
doc_aggs[r[docid_idx]]["count"] += 1
|
||||
return {
|
||||
"answer": "\n".join([clmns, line, rows]),
|
||||
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
|
||||
"answer": "\n".join([columns, line, rows]),
|
||||
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
|
||||
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
|
||||
doc_aggs.items()]},
|
||||
"prompt": sys_prompt
|
||||
@ -492,7 +528,7 @@ Requirements:
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
if kwd.find("**ERROR**") >=0:
|
||||
if kwd.find("**ERROR**") >= 0:
|
||||
return ""
|
||||
return kwd
|
||||
|
||||
@ -605,16 +641,16 @@ def tts(tts_mdl, text):
|
||||
|
||||
def ask(question, kb_ids, tenant_id):
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
|
||||
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
|
||||
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
|
||||
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||
max_tokens = chat_mdl.max_length
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
|
||||
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
prompt = """
|
||||
Role: You're a smart assistant. Your name is Miss R.
|
||||
@ -636,14 +672,14 @@ def ask(question, kb_ids, tenant_id):
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal knowledges, kbinfos, prompt
|
||||
answer, idx = retr.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=0.7,
|
||||
vtweight=0.3)
|
||||
answer, idx = retriever.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=0.7,
|
||||
vtweight=0.3)
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
@ -664,4 +700,3 @@ def ask(question, kb_ids, tenant_id):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
|
||||
|
@ -72,10 +72,12 @@ class TenantLLMService(CommonService):
|
||||
return model_name, None
|
||||
if len(arr) > 2:
|
||||
return "@".join(arr[0:-1]), arr[-1]
|
||||
|
||||
# model name must be xxx@yyy
|
||||
try:
|
||||
fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
|
||||
fact = set([f["name"] for f in fact])
|
||||
if arr[-1] not in fact:
|
||||
model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
|
||||
model_providers = set([f["name"] for f in model_factories])
|
||||
if arr[-1] not in model_providers:
|
||||
return model_name, None
|
||||
return arr[0], arr[-1]
|
||||
except Exception as e:
|
||||
@ -113,11 +115,11 @@ class TenantLLMService(CommonService):
|
||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
||||
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
|
||||
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
|
||||
if not model_config:
|
||||
if mdlnm == "flag-embedding":
|
||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
|
||||
"llm_name": llm_name, "api_base": ""}
|
||||
"llm_name": llm_name, "api_base": ""}
|
||||
else:
|
||||
if not mdlnm:
|
||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||
@ -200,8 +202,8 @@ class TenantLLMService(CommonService):
|
||||
return num
|
||||
else:
|
||||
tenant_llm = tenant_llms[0]
|
||||
num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
|
||||
num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens) \
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name) \
|
||||
.execute()
|
||||
except Exception:
|
||||
logging.exception("TenantLLMService.increase_usage got exception")
|
||||
@ -231,7 +233,7 @@ class LLMBundle(object):
|
||||
for lm in LLMService.query(llm_name=llm_name):
|
||||
self.max_length = lm.max_tokens
|
||||
break
|
||||
|
||||
|
||||
def encode(self, texts: list):
|
||||
embeddings, used_tokens = self.mdl.encode(texts)
|
||||
if not TenantLLMService.increase_usage(
|
||||
@ -274,11 +276,11 @@ class LLMBundle(object):
|
||||
|
||||
def tts(self, text):
|
||||
for chunk in self.mdl.tts(text):
|
||||
if isinstance(chunk,int):
|
||||
if isinstance(chunk, int):
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, chunk, self.llm_name):
|
||||
logging.error(
|
||||
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
||||
self.tenant_id, self.llm_type, chunk, self.llm_name):
|
||||
logging.error(
|
||||
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
||||
return
|
||||
yield chunk
|
||||
|
||||
@ -287,7 +289,8 @@ class LLMBundle(object):
|
||||
if isinstance(txt, int) and not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error(
|
||||
"LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
"LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
|
||||
used_tokens))
|
||||
return txt
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
@ -296,6 +299,7 @@ class LLMBundle(object):
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, txt, self.llm_name):
|
||||
logging.error(
|
||||
"LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
|
||||
"LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
|
||||
txt))
|
||||
return
|
||||
yield txt
|
||||
|
Loading…
x
Reference in New Issue
Block a user