add conversation API (#35)

This commit is contained in:
KevinHuSh 2024-01-18 19:28:37 +08:00 committed by GitHub
parent fad2ec7cf3
commit 4a858d33b6
13 changed files with 425 additions and 153 deletions

View File

@ -13,17 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import hashlib
import re
import numpy as np
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from elasticsearch_dsl import Q
from rag.nlp import search, huqie from rag.nlp import search, huqie, retrievaler
from rag.utils import ELASTICSEARCH, rmSpace from rag.utils import ELASTICSEARCH, rmSpace
from api.db import LLMType from api.db import LLMType
from api.db.services import duplicate_name
from api.db.services.kb_service import KnowledgebaseService from api.db.services.kb_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
@ -31,8 +27,9 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode from api.settings import RetCode
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
import hashlib
import re
retrival = search.Dealer(ELASTICSEARCH)
@manager.route('/list', methods=['POST']) @manager.route('/list', methods=['POST'])
@login_required @login_required
@ -45,12 +42,14 @@ def list():
question = req.get("keywords", "") question = req.get("keywords", "")
try: try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
query = { query = {
"doc_ids": [doc_id], "page": page, "size": size, "question": question "doc_ids": [doc_id], "page": page, "size": size, "question": question
} }
if "available_int" in req: query["available_int"] = int(req["available_int"]) if "available_int" in req:
sres = retrival.search(query, search.index_name(tenant_id)) query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id))
res = {"total": sres.total, "chunks": []} res = {"total": sres.total, "chunks": []}
for id in sres.ids: for id in sres.ids:
d = { d = {
@ -79,8 +78,11 @@ def get():
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
if not tenants: if not tenants:
return get_data_error_result(retmsg="Tenant not found!") return get_data_error_result(retmsg="Tenant not found!")
res = ELASTICSEARCH.get(chunk_id, search.index_name(tenants[0].tenant_id)) res = ELASTICSEARCH.get(
if not res.get("found"):return server_error_response("Chunk not found") chunk_id, search.index_name(
tenants[0].tenant_id))
if not res.get("found"):
return server_error_response("Chunk not found")
id = res["_id"] id = res["_id"]
res = res["_source"] res = res["_source"]
res["chunk_id"] = id res["chunk_id"] = id
@ -90,7 +92,8 @@ def get():
k.append(n) k.append(n)
if re.search(r"(_tks|_ltks)", n): if re.search(r"(_tks|_ltks)", n):
res[n] = rmSpace(res[n]) res[n] = rmSpace(res[n])
for n in k: del res[n] for n in k:
del res[n]
return get_json_result(data=res) return get_json_result(data=res)
except Exception as e: except Exception as e:
@ -102,7 +105,8 @@ def get():
@manager.route('/set', methods=['POST']) @manager.route('/set', methods=['POST'])
@login_required @login_required
@validate_request("doc_id", "chunk_id", "content_ltks", "important_kwd", "docnm_kwd") @validate_request("doc_id", "chunk_id", "content_ltks",
"important_kwd")
def set(): def set():
req = request.json req = request.json
d = {"id": req["chunk_id"]} d = {"id": req["chunk_id"]}
@ -110,15 +114,21 @@ def set():
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req["important_kwd"] d["important_kwd"] = req["important_kwd"]
d["important_tks"] = huqie.qie(" ".join(req["important_kwd"])) d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))
if "available_int" in req: d["available_int"] = req["available_int"] if "available_int" in req:
d["available_int"] = req["available_int"]
try: try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") if not tenant_id:
embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value) return get_data_error_result(retmsg="Tenant not found!")
v, c = embd_mdl.encode([req["docnm_kwd"], req["content_ltks"]]) embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
v = 0.1 * v[0] + 0.9 * v[1] v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec"%len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
@ -132,7 +142,8 @@ def switch():
req = request.json req = request.json
try: try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]], if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
search.index_name(tenant_id)): search.index_name(tenant_id)):
return get_data_error_result(retmsg="Index updating failure") return get_data_error_result(retmsg="Index updating failure")
@ -141,10 +152,22 @@ def switch():
return server_error_response(e) return server_error_response(e)
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("chunk_ids")
def rm():
    """Remove the listed chunks from the current user's search index."""
    req = request.json
    try:
        deleted = ELASTICSEARCH.deleteByQuery(
            Q("ids", values=req["chunk_ids"]),
            search.index_name(current_user.id))
        if not deleted:
            return get_data_error_result(retmsg="Index updating failure")
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@manager.route('/create', methods=['POST']) @manager.route('/create', methods=['POST'])
@login_required @login_required
@validate_request("doc_id", "content_ltks", "important_kwd") @validate_request("doc_id", "content_ltks")
def create(): def create():
req = request.json req = request.json
md5 = hashlib.md5() md5 = hashlib.md5()
@ -152,24 +175,27 @@ def create():
chunck_id = md5.hexdigest() chunck_id = md5.hexdigest()
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])} d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])}
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req["important_kwd"] d["important_kwd"] = req.get("important_kwd", [])
d["important_tks"] = huqie.qie(" ".join(req["important_kwd"])) d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: return get_data_error_result(retmsg="Document not found!") if not e:
return get_data_error_result(retmsg="Document not found!")
d["kb_id"] = [doc.kb_id] d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name d["docnm_kwd"] = doc.name
d["doc_id"] = doc.id d["doc_id"] = doc.id
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value) embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value)
v, c = embd_mdl.encode([doc.name, req["content_ltks"]]) v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0) DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
v = 0.1 * v[0] + 0.9 * v[1] v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec"%len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
return get_json_result(data={"chunk_id": chunck_id}) return get_json_result(data={"chunk_id": chunck_id})
except Exception as e: except Exception as e:
@ -194,40 +220,10 @@ def retrieval_test():
if not e: if not e:
return get_data_error_result(retmsg="Knowledgebase not found!") return get_data_error_result(retmsg="Knowledgebase not found!")
embd_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.EMBEDDING.value) embd_mdl = TenantLLMService.model_instance(
sres = retrival.search({"kb_ids": [kb_id], "doc_ids": doc_ids, "size": top, kb.tenant_id, LLMType.EMBEDDING.value)
"question": question, "vector": True, ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
"similarity": similarity_threshold}, vector_similarity_weight, top, doc_ids)
search.index_name(kb.tenant_id),
embd_mdl)
sim, tsim, vsim = retrival.rerank(sres, question, 1-vector_similarity_weight, vector_similarity_weight)
idx = np.argsort(sim*-1)
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
start_idx = (page-1)*size
for i in idx:
ranks["total"] += 1
if sim[i] < similarity_threshold: break
start_idx -= 1
if start_idx >= 0:continue
if len(ranks["chunks"]) == size:continue
id = sres.ids[i]
dnm = sres.field[id]["docnm_kwd"]
d = {
"chunk_id": id,
"content_ltks": sres.field[id]["content_ltks"],
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": dnm,
"kb_id": sres.field[id]["kb_id"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"similarity": sim[i],
"vector_similarity": vsim[i],
"term_similarity": tsim[i]
}
ranks["chunks"].append(d)
if dnm not in ranks["doc_aggs"]:ranks["doc_aggs"][dnm] = 0
ranks["doc_aggs"][dnm] += 1
return get_json_result(data=ranks) return get_json_result(data=ranks)
except Exception as e: except Exception as e:
@ -235,3 +231,4 @@ def retrieval_test():
return get_json_result(data=False, retmsg=f'Index not found!', return get_json_result(data=False, retmsg=f'Index not found!',
retcode=RetCode.DATA_ERROR) retcode=RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)

View File

@ -0,0 +1,207 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import tiktoken
from flask import request
from flask_login import login_required, current_user
from api.db.services.dialog_service import DialogService, ConversationService
from api.db import StatusEnum, LLMType
from api.db.services.kb_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.user_service import TenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.query import EsQueryer
from rag.utils import num_tokens_from_string, encoder
@manager.route('/set', methods=['POST'])
@login_required
@validate_request("dialog_id")
def set():
    """Create or update a conversation.

    When the payload carries "conversation_id", the remaining fields are
    applied as an update to that conversation; otherwise a new conversation
    is created under the given dialog, seeded with the dialog's prologue.
    """
    req = request.json
    conv_id = req.get("conversation_id")
    if conv_id:
        # Update path — the id itself must not be written back as a field.
        del req["conversation_id"]
        try:
            if not ConversationService.update_by_id(conv_id, req):
                return get_data_error_result(retmsg="Conversation not found!")
            e, conv = ConversationService.get_by_id(conv_id)
            if not e:
                return get_data_error_result(
                    retmsg="Fail to update a conversation!")
            return get_json_result(data=conv.to_dict())
        except Exception as e:
            return server_error_response(e)

    try:
        e, dia = DialogService.get_by_id(req["dialog_id"])
        if not e:
            return get_data_error_result(retmsg="Dialog not found")
        # Creation path — start the history with the assistant prologue.
        conv = {
            "id": get_uuid(),
            "dialog_id": req["dialog_id"],
            "name": "New conversation",
            "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
        }
        ConversationService.save(**conv)
        e, conv = ConversationService.get_by_id(conv["id"])
        if not e:
            return get_data_error_result(retmsg="Fail to new a conversation!")
        return get_json_result(data=conv.to_dict())
    except Exception as e:
        return server_error_response(e)
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Fetch a single conversation by the "conversation_id" query arg."""
    conv_id = request.args["conversation_id"]
    try:
        found, conv = ConversationService.get_by_id(conv_id)
        if not found:
            return get_data_error_result(retmsg="Conversation not found!")
        return get_json_result(data=conv.to_dict())
    except Exception as e:
        return server_error_response(e)
@manager.route('/rm', methods=['POST'])
@login_required
def rm():
    """Delete every conversation listed in "conversation_ids"."""
    conv_ids = request.json["conversation_ids"]
    try:
        for conv_id in conv_ids:
            ConversationService.delete_by_id(conv_id)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """List all conversations belonging to the dialog in "dialog_id"."""
    dialog_id = request.args["dialog_id"]
    try:
        convs = ConversationService.query(dialog_id=dialog_id)
        return get_json_result(data=[c.to_dict() for c in convs])
    except Exception as e:
        return server_error_response(e)
def message_fit_in(msg, max_length=4000):
    """Shrink *msg* (a list of {"role", "content"} dicts) until its total
    token count fits within *max_length*.

    Shedding strategy, applied in order until the messages fit:
      1. keep everything if already within budget;
      2. drop assistant turns (keep "system" and "user");
      3. keep only system messages plus the final message;
      4. truncate the dominant content token-wise.

    Returns:
        (token_count, msg): resulting token count and the (possibly
        mutated) message list.
    """
    def count():
        nonlocal msg
        return sum(num_tokens_from_string(m["content"]) for m in msg)

    c = count()
    if c < max_length:
        return c, msg

    # 2. Assistant turns are the most expendable context — drop them first.
    # BUG FIX: entries are dicts, so use m["role"], not m.role.
    msg = [m for m in msg if m["role"] in ["system", "user"]]
    c = count()
    if c < max_length:
        return c, msg

    # 3. Keep system prompts plus the final (latest) message only.
    msg_ = [m for m in msg[:-1] if m["role"] == "system"]
    msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    if len(msg_) == 1:
        # No system prompt survived; truncate the lone message directly
        # (the original indexed msg_[1] here and crashed).
        msg[0]["content"] = encoder.decode(
            encoder.encode(msg[0]["content"])[:max_length])
        return max_length, msg

    # 4. Still too big: clip whichever side dominates the token budget.
    ll = num_tokens_from_string(msg_[0]["content"])
    l = num_tokens_from_string(msg_[-1]["content"])
    if ll / (ll + l) > 0.8:
        # System prompt dominates; clip it to what the last message leaves.
        m = msg_[0]["content"]
        msg[0]["content"] = encoder.decode(encoder.encode(m)[:max_length - l])
        return max_length, msg

    m = msg_[1]["content"]
    msg[1]["content"] = encoder.decode(encoder.encode(m)[:max_length - l])
    return max_length, msg
def chat(dialog, messages, **kwargs):
    """Answer the last user message of *messages* using the dialog's LLM,
    grounding the answer on chunks retrieved from the dialog's knowledge
    bases and inserting citations into the reply.

    Args:
        dialog: Dialog model carrying llm_id, kb_ids, prompt/llm settings.
        messages: chat history; the final entry must be a user message.
        **kwargs: values for the prompt-template parameters.

    Returns:
        {"answer": str, "retrieval": kbinfos} — kbinfos is whatever
        retrievaler.retrieval produced (chunks plus scores).

    Raises:
        LookupError: unknown LLM or missing API key configuration.
        KeyError: a mandatory prompt parameter was not supplied.
    """
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    llm = LLMService.query(llm_name=dialog.llm_id)
    if not llm:
        raise LookupError("LLM(%s) not found" % dialog.llm_id)
    llm = llm[0]

    # Fill optional template parameters with a blank; mandatory ones must
    # be present in kwargs ("knowledge" is injected below).
    prompt_config = dialog.prompt_config
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")

    model_config = TenantLLMService.get_api_key(
        dialog.tenant_id, LLMType.CHAT.value, dialog.llm_id)
    if not model_config:
        raise LookupError("LLM(%s) API key not found" % dialog.llm_id)

    question = messages[-1]["content"]
    embd_mdl = TenantLLMService.model_instance(
        dialog.tenant_id, LLMType.EMBEDDING.value)
    kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
                                    dialog.vector_similarity_weight, top=1024, aggs=False)
    knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
    if not knowledges and prompt_config["empty_response"]:
        # Nothing retrieved: answer with the configured fallback message.
        return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}

    kwargs["knowledge"] = "\n".join(knowledges)
    gen_conf = dialog.llm_setting[dialog.llm_setting_type]
    msg = [{"role": m["role"], "content": m["content"]}
           for m in messages if m["role"] != "system"]
    # BUG FIX: message_fit_in returns (token_count, msg) — unpack it, and
    # use the trimmed message list it hands back.
    used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"], llm.max_tokens - used_token_count)

    # BUG FIX: model_config is a model object (see get_api_key), so read
    # api_key via attribute access, matching .llm_factory above.
    mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
    answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
    answer = retrievaler.insert_citations(answer,
                                          [ck["content_ltks"] for ck in kbinfos["chunks"]],
                                          [ck["vector"] for ck in kbinfos["chunks"]],
                                          embd_mdl,
                                          tkweight=1 - dialog.vector_similarity_weight,
                                          vtweight=dialog.vector_similarity_weight)
    return {"answer": answer, "retrieval": kbinfos}
@manager.route('/completion', methods=['POST'])
@login_required
@validate_request("dialog_id", "messages")
def completion():
    """Run one chat turn: pass the posted history to the dialog's LLM
    (with knowledge-base retrieval) and return the generated answer."""
    req = request.json
    msg = []
    for m in req["messages"]:
        # System prompts are injected by chat() itself, and a leading
        # assistant prologue carries no user intent — skip both.
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append({"role": m["role"], "content": m["content"]})
    try:
        e, dia = DialogService.get_by_id(req["dialog_id"])
        if not e:
            return get_data_error_result(retmsg="Dialog not found!")
        # Remaining req fields become prompt-template parameters.
        del req["dialog_id"]
        del req["messages"]
        return get_json_result(data=chat(dia, msg, **req))
    except Exception as e:
        return server_error_response(e)

View File

@ -13,28 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import hashlib
import re
import numpy as np
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.dialog_service import DialogService from api.db.services.dialog_service import DialogService
from rag.nlp import search, huqie from api.db import StatusEnum
from rag.utils import ELASTICSEARCH, rmSpace
from api.db import LLMType, StatusEnum
from api.db.services import duplicate_name
from api.db.services.kb_service import KnowledgebaseService from api.db.services.kb_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService from api.db.services.user_service import TenantService
from api.db.services.user_service import UserTenantService, TenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.db.services.document_service import DocumentService
from api.settings import RetCode, stat_logger
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
from api.utils.file_utils import filename_type
@manager.route('/set', methods=['POST']) @manager.route('/set', methods=['POST'])
@ -128,6 +116,7 @@ def set():
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/get', methods=['GET']) @manager.route('/get', methods=['GET'])
@login_required @login_required
def get(): def get():
@ -161,3 +150,16 @@ def list():
return get_json_result(data=diags) return get_json_result(data=diags)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("dialog_id")
def rm():
    """Soft-delete a dialog by flipping its status to INVALID."""
    req = request.json
    try:
        updated = DialogService.update_by_id(
            req["dialog_id"], {"status": StatusEnum.INVALID.value})
        if not updated:
            return get_data_error_result(retmsg="Dialog not found!")
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)

View File

@ -271,7 +271,7 @@ def change_parser():
@manager.route('/image/<image_id>', methods=['GET']) @manager.route('/image/<image_id>', methods=['GET'])
@login_required #@login_required
def get_image(image_id): def get_image(image_id):
try: try:
bkt, nm = image_id.split("-") bkt, nm = image_id.split("-")

View File

@ -108,7 +108,7 @@ def rm():
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]): if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.IN_VALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!") if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.INVALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View File

@ -20,7 +20,7 @@ from strenum import StrEnum
class StatusEnum(Enum): class StatusEnum(Enum):
VALID = "1" VALID = "1"
IN_VALID = "0" INVALID = "0"
class UserTenantRole(StrEnum): class UserTenantRole(StrEnum):

View File

@ -430,6 +430,7 @@ class LLM(DataBaseModel):
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id") fid = CharField(max_length=128, null=False, help_text="LLM factory id")
max_tokens = IntegerField(default=0)
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
@ -467,8 +468,8 @@ class Knowledgebase(DataBaseModel):
doc_num = IntegerField(default=0) doc_num = IntegerField(default=0)
token_num = IntegerField(default=0) token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0) chunk_num = IntegerField(default=0)
similarity_threshold = FloatField(default=0.4) #similarity_threshold = FloatField(default=0.4)
vector_similarity_weight = FloatField(default=0.3) #vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID") parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
@ -518,6 +519,11 @@ class Dialog(DataBaseModel):
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?", prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?",
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"}) "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
similarity_threshold = FloatField(default=0.4)
vector_similarity_weight = FloatField(default=0.3)
top_n = IntegerField(default=6)
kb_ids = JSONField(null=False, default=[]) kb_ids = JSONField(null=False, default=[])
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View File

@ -62,61 +62,73 @@ def init_llm_factory():
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo", "llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K", "tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value "model_type": LLMType.CHAT.value
},{ },{
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo-16k-0613", "llm_name": "gpt-3.5-turbo-16k-0613",
"tags": "LLM,CHAT,16k", "tags": "LLM,CHAT,16k",
"max_tokens": 16385,
"model_type": LLMType.CHAT.value "model_type": LLMType.CHAT.value
},{ },{
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "text-embedding-ada-002", "llm_name": "text-embedding-ada-002",
"tags": "TEXT EMBEDDING,8K", "tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value "model_type": LLMType.EMBEDDING.value
},{ },{
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "whisper-1", "llm_name": "whisper-1",
"tags": "SPEECH2TEXT", "tags": "SPEECH2TEXT",
"max_tokens": 25*1024*1024,
"model_type": LLMType.SPEECH2TEXT.value "model_type": LLMType.SPEECH2TEXT.value
},{ },{
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "gpt-4", "llm_name": "gpt-4",
"tags": "LLM,CHAT,8K", "tags": "LLM,CHAT,8K",
"max_tokens": 8191,
"model_type": LLMType.CHAT.value "model_type": LLMType.CHAT.value
},{ },{
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "gpt-4-32k", "llm_name": "gpt-4-32k",
"tags": "LLM,CHAT,32K", "tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": LLMType.CHAT.value "model_type": LLMType.CHAT.value
},{ },{
"fid": factory_infos[0]["name"], "fid": factory_infos[0]["name"],
"llm_name": "gpt-4-vision-preview", "llm_name": "gpt-4-vision-preview",
"tags": "LLM,CHAT,IMAGE2TEXT", "tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value "model_type": LLMType.IMAGE2TEXT.value
},{ },{
"fid": factory_infos[1]["name"], "fid": factory_infos[1]["name"],
"llm_name": "qwen-turbo", "llm_name": "qwen-turbo",
"tags": "LLM,CHAT,8K", "tags": "LLM,CHAT,8K",
"max_tokens": 8191,
"model_type": LLMType.CHAT.value "model_type": LLMType.CHAT.value
},{ },{
"fid": factory_infos[1]["name"], "fid": factory_infos[1]["name"],
"llm_name": "qwen-plus", "llm_name": "qwen-plus",
"tags": "LLM,CHAT,32K", "tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": LLMType.CHAT.value "model_type": LLMType.CHAT.value
},{ },{
"fid": factory_infos[1]["name"], "fid": factory_infos[1]["name"],
"llm_name": "text-embedding-v2", "llm_name": "text-embedding-v2",
"tags": "TEXT EMBEDDING,2K", "tags": "TEXT EMBEDDING,2K",
"max_tokens": 2048,
"model_type": LLMType.EMBEDDING.value "model_type": LLMType.EMBEDDING.value
},{ },{
"fid": factory_infos[1]["name"], "fid": factory_infos[1]["name"],
"llm_name": "paraformer-realtime-8k-v1", "llm_name": "paraformer-realtime-8k-v1",
"tags": "SPEECH2TEXT", "tags": "SPEECH2TEXT",
"max_tokens": 25*1024*1024,
"model_type": LLMType.SPEECH2TEXT.value "model_type": LLMType.SPEECH2TEXT.value
},{ },{
"fid": factory_infos[1]["name"], "fid": factory_infos[1]["name"],
"llm_name": "qwen_vl_chat_v1", "llm_name": "qwen_vl_chat_v1",
"tags": "LLM,CHAT,IMAGE2TEXT", "tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value "model_type": LLMType.IMAGE2TEXT.value
}, },
] ]

View File

@ -34,7 +34,7 @@ class TenantLLMService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_api_key(cls, tenant_id, model_type): def get_api_key(cls, tenant_id, model_type, model_name=""):
objs = cls.query(tenant_id=tenant_id, model_type=model_type) objs = cls.query(tenant_id=tenant_id, model_type=model_type)
if objs and len(objs)>0 and objs[0].llm_name: if objs and len(objs)>0 and objs[0].llm_name:
return objs[0] return objs[0]
@ -42,7 +42,7 @@ class TenantLLMService(CommonService):
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key] fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where( objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
(cls.model.tenant_id == tenant_id), (cls.model.tenant_id == tenant_id),
(cls.model.model_type == model_type), ((cls.model.model_type == model_type) | (cls.model.llm_name == model_name)),
(LLM.status == StatusEnum.VALID) (LLM.status == StatusEnum.VALID)
) )
@ -60,7 +60,7 @@ class TenantLLMService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def model_instance(cls, tenant_id, llm_type): def model_instance(cls, tenant_id, llm_type):
model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING) model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING.value)
if not model_config: if not model_config:
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""} model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else: else:

View File

@ -30,3 +30,9 @@ CvModel = {
"通义千问": QWenCV, "通义千问": QWenCV,
} }
# Maps an LLM factory/vendor display name to its chat-model wrapper class.
ChatModel = {
    "OpenAI": GptTurbo,
    "通义千问": QWenChat,
}

View File

@ -0,0 +1,4 @@
from . import search
from rag.utils import ELASTICSEARCH
retrievaler = search.Dealer(ELASTICSEARCH)

View File

@ -2,7 +2,7 @@
import json import json
import re import re
from elasticsearch_dsl import Q, Search, A from elasticsearch_dsl import Q, Search, A
from typing import List, Optional, Tuple, Dict, Union from typing import List, Optional, Dict, Union
from dataclasses import dataclass from dataclasses import dataclass
from rag.settings import es_logger from rag.settings import es_logger
@ -20,6 +20,8 @@ class Dealer:
self.qryr.flds = [ self.qryr.flds = [
"title_tks^10", "title_tks^10",
"title_sm_tks^5", "title_sm_tks^5",
"important_kwd^30",
"important_tks^20",
"content_ltks^2", "content_ltks^2",
"content_sm_ltks"] "content_sm_ltks"]
self.es = es self.es = es
@ -38,10 +40,10 @@ class Dealer:
def _vector(self, txt, emb_mdl, sim=0.8, topk=10): def _vector(self, txt, emb_mdl, sim=0.8, topk=10):
qv, c = emb_mdl.encode_queries(txt) qv, c = emb_mdl.encode_queries(txt)
return { return {
"field": "q_%d_vec"%len(qv), "field": "q_%d_vec" % len(qv),
"k": topk, "k": topk,
"similarity": sim, "similarity": sim,
"num_candidates": topk*2, "num_candidates": topk * 2,
"query_vector": qv "query_vector": qv
} }
@ -53,14 +55,16 @@ class Dealer:
if req.get("doc_ids"): if req.get("doc_ids"):
bqry.filter.append(Q("terms", doc_id=req["doc_ids"])) bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
if "available_int" in req: if "available_int" in req:
if req["available_int"] == 0: bqry.filter.append(Q("range", available_int={"lt": 1})) if req["available_int"] == 0:
else: bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1}))) bqry.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1})))
bqry.boost = 0.05 bqry.boost = 0.05
s = Search() s = Search()
pg = int(req.get("page", 1)) - 1 pg = int(req.get("page", 1)) - 1
ps = int(req.get("size", 1000)) ps = int(req.get("size", 1000))
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id","img_id", src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id",
"image_id", "doc_id", "q_512_vec", "q_768_vec", "image_id", "doc_id", "q_512_vec", "q_768_vec",
"q_1024_vec", "q_1536_vec", "available_int"]) "q_1024_vec", "q_1536_vec", "available_int"])
@ -171,74 +175,106 @@ class Dealer:
def trans2floats(txt): def trans2floats(txt):
return [float(t) for t in txt.split("\t")] return [float(t) for t in txt.split("\t")]
def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.3, vtweight=0.7):
    """Append citation markers ("##i$j$$") to the sentences of *answer*.

    Each sentence of *answer* is embedded and compared (hybrid token/vector
    similarity) against the retrieved *chunks*; chunk indices whose similarity
    is close to the per-sentence maximum are cited after that sentence.

    Args:
        answer: the generated answer text.
        chunks: chunk texts (parallel to *chunk_v*).
        chunk_v: chunk embedding vectors.
        embd_mdl: embedding model exposing ``encode(list[str])``.
        tkweight: weight of the token-based similarity.
        vtweight: weight of the vector-based similarity.

    Returns:
        *answer* with "##<idx>[$<idx>...]$$" markers inserted, or the original
        *answer* unchanged when no sentence is long enough to match.
    """
    # Split on CJK/Latin sentence enders, keeping the delimiters as pieces.
    pieces = re.split(r"([;。?!\n]|[a-z][.?;!][ \n])", answer)
    # Re-attach the Latin sentence-ending punctuation captured by the split
    # to the preceding sentence.
    for i in range(1, len(pieces)):
        if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
            pieces[i - 1] += pieces[i][0]
            pieces[i] = pieces[i][1:]

    # Keep only pieces long enough to be meaningful sentences.
    idx = []
    pieces_ = []
    for i, t in enumerate(pieces):
        if len(t) < 5:
            continue
        idx.append(i)
        pieces_.append(t)
    if not pieces_:
        return answer

    ans_v = embd_mdl.encode(pieces_)
    assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
        len(ans_v[0]), len(chunk_v[0]))

    chunks_tks = [huqie.qie(ck).split(" ") for ck in chunks]
    cites = {}
    for i, a in enumerate(pieces_):
        sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
                                                        chunk_v,
                                                        huqie.qie(pieces_[i]).split(" "),
                                                        chunks_tks,
                                                        tkweight, vtweight)
        mx = np.max(sim) * 0.99
        if mx < 0.55:  # too far from any chunk: no citation for this sentence
            continue
        # BUGFIX: sort the candidate indices before truncating to 4.
        # A bare list(set(...)) has arbitrary order, so which citations
        # survived the [:4] cut was nondeterministic between runs.
        cites[idx[i]] = [str(c) for c in sorted(
            c for c in range(len(chunk_v)) if sim[c] > mx)][:4]

    res = ""
    for i, p in enumerate(pieces):
        res += p
        if i not in idx:
            continue
        if i not in cites:
            continue
        res += "##%s$$" % "$".join(cites[i])
    return res
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"):
    """Re-score search hits with a hybrid token/vector similarity.

    Args:
        sres: search result holding ``ids``, ``field`` and ``query_vector``.
        query: the original query text.
        tkweight: weight of the token-based similarity.
        vtweight: weight of the vector-based similarity.
        cfield: name of the content field used for token similarity.

    Returns:
        ``(sim, tksim, vtsim)`` — combined, token and vector similarity
        scores, one entry per hit in ``sres.ids``.
    """
    ins_embd = [
        Dealer.trans2floats(
            sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
    if not ins_embd:
        # BUGFIX: return a 3-tuple of empties so callers that unpack
        # "sim, tksim, vtsim = rerank(...)" do not crash on the empty case
        # (previously a bare [] was returned here).
        return [], [], []
    ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
    sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                    ins_embd,
                                                    huqie.qie(query).split(" "),
                                                    ins_tw, tkweight, vtweight)
    return sim, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
    """Tokenize both texts and delegate to the query handler's hybrid scorer.

    Uses the default token/vector weights of ``self.qryr.hybrid_similarity``.
    """
    ans_tks = huqie.qie(ans).split(" ")
    inst_tks = huqie.qie(inst).split(" ")
    return self.qryr.hybrid_similarity(ans_embd, ins_embd, ans_tks, inst_tks)
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
              vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
    """Run a hybrid search over the tenant's index and return one page of re-ranked chunks.

    Args:
        question: query text.
        embd_mdl: embedding model used for the vector part of the search.
        tenant_id: tenant whose index is queried.
        kb_ids: knowledge-base ids to search in.
        page, page_size: 1-based pagination of the returned chunks.
        similarity_threshold: hits scoring below this are dropped.
        vector_similarity_weight: weight of the vector score in re-ranking.
        top: maximum number of raw candidates fetched from ES.
        doc_ids: optional document-id filter.
        aggs: when True, keep scanning past the requested page so that
            "total" reflects every hit above the threshold.

    Returns:
        dict with "total" (hits above threshold), "chunks" (the requested
        page) and "doc_aggs" (per-document chunk counts for the page).
    """
    req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
           "question": question, "vector": True,
           "similarity": similarity_threshold}
    sres = self.search(req, index_name(tenant_id), embd_mdl)

    sim, tsim, vsim = self.rerank(
        sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
    # Descending similarity; np.asarray keeps this working whether rerank
    # hands back a numpy array or a plain (possibly empty) list.
    order = np.argsort(np.asarray(sim) * -1)

    ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
    dim = len(sres.query_vector)
    start_idx = (page - 1) * page_size
    for i in order:
        if sim[i] < similarity_threshold:
            break
        # BUGFIX: count only chunks that pass the threshold — the increment
        # used to run before the check, overcounting "total" by one.
        ranks["total"] += 1
        start_idx -= 1
        if start_idx >= 0:  # still skipping hits that belong to earlier pages
            continue
        if len(ranks["chunks"]) == page_size:
            if aggs:
                continue  # page is full; keep scanning so "total" stays accurate
            break
        chunk_id = sres.ids[i]  # renamed from `id` to avoid shadowing the builtin
        dnm = sres.field[chunk_id]["docnm_kwd"]
        d = {
            "chunk_id": chunk_id,
            "content_ltks": sres.field[chunk_id]["content_ltks"],
            "doc_id": sres.field[chunk_id]["doc_id"],
            "docnm_kwd": dnm,
            "kb_id": sres.field[chunk_id]["kb_id"],
            "important_kwd": sres.field[chunk_id].get("important_kwd", []),
            "img_id": sres.field[chunk_id].get("img_id", ""),
            "similarity": sim[i],
            "vector_similarity": vsim[i],
            "term_similarity": tsim[i],
            # Fall back to a zero vector when the dimension-specific field is absent.
            "vector": self.trans2floats(sres.field[chunk_id].get("q_%d_vec" % dim, "\t".join(["0"] * dim)))
        }
        ranks["chunks"].append(d)
        if dnm not in ranks["doc_aggs"]:
            ranks["doc_aggs"][dnm] = 0
        ranks["doc_aggs"][dnm] += 1
    return ranks

View File

@ -59,8 +59,10 @@ def findMaxTm(fnm):
return m return m
# Build the tokenizer once at import time; constructing it per call would
# re-load the BPE tables on every invocation.
encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")


def num_tokens_from_string(string: str) -> int:
    """Returns the number of tokens in a text string."""
    num_tokens = len(encoder.encode(string))
    return num_tokens