Test APIs and fix bugs (#41)

KevinHuSh 2024-01-22 19:51:38 +08:00 committed by GitHub
parent 484e5abc1f
commit 34b2ab3b2f
11 changed files with 46 additions and 27 deletions

View File

@@ -214,7 +214,7 @@ def retrieval_test():
     question = req["question"]
     kb_id = req["kb_id"]
     doc_ids = req.get("doc_ids", [])
-    similarity_threshold = float(req.get("similarity_threshold", 0.4))
+    similarity_threshold = float(req.get("similarity_threshold", 0.2))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top", 1024))
     try:
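Note: the server-side default for similarity_threshold drops from 0.4 to 0.2, so looser matches are returned unless the caller overrides it. A standalone illustration of how such a threshold gates retrieval hits (the hit dicts and their similarity field are hypothetical):

# Standalone illustration: keep only hits whose similarity clears the
# threshold. The hit structure here is made up for the example.
def filter_hits(hits, similarity_threshold=0.2):
    return [h for h in hits if h["similarity"] >= similarity_threshold]

hits = [{"id": "a", "similarity": 0.35}, {"id": "b", "similarity": 0.25}]
print(len(filter_hits(hits)))                             # 2 at the new 0.2 default
print(len(filter_hits(hits, similarity_threshold=0.4)))   # 0 at the old default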

View File

@@ -170,7 +170,7 @@ def chat(dialog, messages, **kwargs):
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")

-    model_config = TenantLLMService.get_api_key(dialog.tenant_id, LLMType.CHAT.value, dialog.llm_id)
+    model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
     if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))

     question = messages[-1]["content"]
@@ -186,10 +186,10 @@ def chat(dialog, messages, **kwargs):
     kwargs["knowledge"] = "\n".join(knowledges)
     gen_conf = dialog.llm_setting[dialog.llm_setting_type]
     msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
-    used_token_count = message_fit_in(msg, int(llm.max_tokens * 0.97))
+    used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
-    mdl = ChatModel[model_config.llm_factory](model_config["api_key"], dialog.llm_id)
+    mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
     answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)

     answer = retrievaler.insert_citations(answer,
@@ -198,4 +198,6 @@ def chat(dialog, messages, **kwargs):
                                           embd_mdl,
                                           tkweight=1-dialog.vector_similarity_weight,
                                           vtweight=dialog.vector_similarity_weight)
+    for c in kbinfos["chunks"]:
+        if c.get("vector"):del c["vector"]
     return {"answer": answer, "retrieval": kbinfos}

View File

@@ -11,7 +11,8 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
+# limitations under the License
+#
 #
 import base64
 import pathlib
@@ -65,7 +66,7 @@ def upload():
         while MINIO.obj_exist(kb_id, location):
             location += "_"
         blob = request.files['file'].read()
-        MINIO.put(kb_id, filename, blob)
+        MINIO.put(kb_id, location, blob)
         doc = DocumentService.insert({
             "id": get_uuid(),
             "kb_id": kb.id,
@@ -188,7 +189,10 @@ def rm():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(retmsg="Document not found!")
-        ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id))
+        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+        if not tenant_id:
+            return get_data_error_result(retmsg="Tenant not found!")
+        ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
         DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
         if not DocumentService.delete_by_id(req["doc_id"]):
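Chunks are indexed per tenant rather than per knowledge base, so deleting by doc.kb_id targeted the wrong index; the fix resolves the owning tenant first and errors out if it is missing. A hedged sketch of that ordering (the index naming scheme here is a guess, and the collaborators are passed in as stubs):

# Sketch: resolve the tenant before touching the per-tenant index;
# abort with an error payload if the tenant cannot be found.
def index_name(tenant_id):
    return f"ragflow_{tenant_id}"  # hypothetical naming scheme

def rm_doc_chunks(doc_id, get_tenant_id, delete_by_query):
    tenant_id = get_tenant_id(doc_id)
    if not tenant_id:
        return {"retmsg": "Tenant not found!"}
    delete_by_query({"match": {"doc_id": doc_id}}, idxnm=index_name(tenant_id))
    return {"retmsg": "success"}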

View File

@@ -75,7 +75,7 @@ def list():
         llms = LLMService.get_all()
         llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
         for m in llms:
-            m["available"] = m.llm_name in mdlnms
+            m["available"] = m["llm_name"] in mdlnms

         res = {}
         for m in llms:
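After to_dict(), each m is a plain dict, so m.llm_name raises AttributeError; key access is required. A two-line reproduction:

# After to_dict() the rows are plain dicts; m.llm_name would raise
# AttributeError, so subscript access is the fix.
llms = [{"llm_name": "gpt-3.5-turbo"}, {"llm_name": "qwen-turbo"}]
mdlnms = {"gpt-3.5-turbo"}
for m in llms:
    m["available"] = m["llm_name"] in mdlnms
print([m["llm_name"] for m in llms if m["available"]])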

View File

@@ -469,7 +469,7 @@ class Knowledgebase(DataBaseModel):
     doc_num = IntegerField(default=0)
     token_num = IntegerField(default=0)
     chunk_num = IntegerField(default=0)
-    similarity_threshold = FloatField(default=0.4)
+    similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)

     parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
@@ -521,7 +521,7 @@ class Dialog(DataBaseModel):
     prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好！我是您的助手小樱，长得可爱又善良，can I help you?",
                                                    "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容！"})
-    similarity_threshold = FloatField(default=0.4)
+    similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
     top_n = IntegerField(default=6)

View File

@@ -63,7 +63,7 @@ class TenantLLMService(CommonService):
         model_config = cls.get_api_key(tenant_id, mdlnm)
         if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
-        model_config = model_config[0].to_dict()
+        model_config = model_config.to_dict()
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel: return
             return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
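get_api_key evidently now returns a single record (or None) rather than a list, so model_config[0] fails; the fix calls to_dict() on the object directly. A sketch under that assumption:

# Sketch, assuming get_api_key now returns one record (or None)
# instead of a list of candidates.
class Record:
    def __init__(self, **kw):
        self.__dict__.update(kw)
    def to_dict(self):
        return dict(self.__dict__)

model_config = Record(llm_factory="OpenAI", api_key="sk-...", llm_name="gpt-3.5-turbo")
model_config = model_config.to_dict()  # no [0]: it is not a list anymore
print(model_config["llm_factory"])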

View File

@@ -143,7 +143,7 @@ def filename_type(filename):
     if re.match(r".*\.pdf$", filename):
         return FileType.PDF.value

-    if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
+    if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
         return FileType.DOC.value

     if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):

View File

@@ -19,31 +19,39 @@ import os


 class Base(ABC):
+    def __init__(self, key, model_name):
+        pass
+
     def chat(self, system, history, gen_conf):
         raise NotImplementedError("Please implement encode method!")


 class GptTurbo(Base):
-    def __init__(self):
-        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+    def __init__(self, key, model_name="gpt-3.5-turbo"):
+        self.client = OpenAI(api_key=key)
+        self.model_name = model_name

     def chat(self, system, history, gen_conf):
         history.insert(0, {"role": "system", "content": system})
         res = self.client.chat.completions.create(
-            model="gpt-3.5-turbo",
+            model=self.model_name,
             messages=history,
             **gen_conf)
         return res.choices[0].message.content.strip()


+from dashscope import Generation
 class QWenChat(Base):
+    def __init__(self, key, model_name=Generation.Models.qwen_turbo):
+        import dashscope
+        dashscope.api_key = key
+        self.model_name = model_name
+
     def chat(self, system, history, gen_conf):
         from http import HTTPStatus
-        from dashscope import Generation
-        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
         history.insert(0, {"role": "system", "content": system})
         response = Generation.call(
-            Generation.Models.qwen_turbo,
+            self.model_name,
             messages=history,
             result_format='message'
         )
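Both chat backends now receive the API key and model name through __init__ instead of reading environment variables, so a per-tenant key from the database can be injected via the ChatModel factory mapping. A runnable sketch of that factory pattern; the EchoChat backend is invented here so the example needs no network access or real key:

from abc import ABC

class Base(ABC):
    def __init__(self, key, model_name):
        pass
    def chat(self, system, history, gen_conf):
        raise NotImplementedError("Please implement chat method!")

# Invented stand-in backend so the sketch runs offline.
class EchoChat(Base):
    def __init__(self, key, model_name="echo-1"):
        self.key, self.model_name = key, model_name
    def chat(self, system, history, gen_conf):
        return f"[{self.model_name}] {history[-1]['content']}"

ChatModel = {"Echo": EchoChat}  # factory name -> class, as in the diff
mdl = ChatModel["Echo"]("per-tenant-api-key", "echo-1")
print(mdl.chat("You are helpful.", [{"role": "user", "content": "hi"}], {}))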

View File

@@ -28,6 +28,8 @@ class Base(ABC):
         raise NotImplementedError("Please implement encode method!")

     def image2base64(self, image):
+        if isinstance(image, bytes):
+            return base64.b64encode(image).decode("utf-8")
         if isinstance(image, BytesIO):
             return base64.b64encode(image.getvalue()).decode("utf-8")
         buffered = BytesIO()
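image2base64 now short-circuits on raw bytes before the BytesIO branch (the original also encodes PIL images via a buffer further down). A trimmed, runnable version of the two branches shown:

import base64
from io import BytesIO

# Trimmed to the two branches visible in the diff; the original falls
# through to a PIL-image path after these checks.
def image2base64(image):
    if isinstance(image, bytes):
        return base64.b64encode(image).decode("utf-8")
    if isinstance(image, BytesIO):
        return base64.b64encode(image.getvalue()).decode("utf-8")
    raise TypeError("unsupported image type")

print(image2base64(b"\x89PNG"))
print(image2base64(BytesIO(b"\x89PNG")))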
@@ -59,7 +61,7 @@ class Base(ABC):

 class GptV4(Base):
     def __init__(self, key, model_name="gpt-4-vision-preview"):
-        self.client = OpenAI(key)
+        self.client = OpenAI(api_key=key)
         self.model_name = model_name

     def describe(self, image, max_tokens=300):

View File

@@ -187,9 +187,10 @@ class Dealer:
             if len(t) < 5: continue
             idx.append(i)
             pieces_.append(t)
+        es_logger.info("{} => {}".format(answer, pieces_))
         if not pieces_: return answer

-        ans_v = embd_mdl.encode(pieces_)
+        ans_v, c = embd_mdl.encode(pieces_)
         assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
             len(ans_v[0]), len(chunk_v[0]))
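encode() on the embedding models returns a (vectors, token_count) pair, so the old single assignment bound the whole tuple and broke the dimension check; the fix unpacks it. A reproduction with a stub encoder:

# Stub encoder mirroring the (vectors, token_count) return shape.
def encode(texts):
    return [[0.0, 1.0] for _ in texts], sum(len(t) for t in texts)

pieces_ = ["first piece", "second piece"]
ans_v, c = encode(pieces_)  # fixed: unpack vectors and token count
assert len(ans_v[0]) == 2   # dimensions now line up with chunk vectors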
@@ -219,7 +220,7 @@ class Dealer:
             Dealer.trans2floats(
                 sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
         if not ins_embd:
-            return []
+            return [], [], []
         ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
         sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                         ins_embd,
@@ -235,6 +236,8 @@ class Dealer:
     def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
+        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
+        if not question: return ranks
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
                "question": question, "vector": True,
                "similarity": similarity_threshold}
@@ -243,7 +246,7 @@ class Dealer:
         sim, tsim, vsim = self.rerank(
             sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
         idx = np.argsort(sim * -1)
-        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
         dim = len(sres.query_vector)
         start_idx = (page - 1) * page_size
         for i in idx:

View File

@@ -78,6 +78,7 @@ def chuck_doc(name, binary, cvmdl=None):
         field = TextChunker.Fields()
         field.text_chunks = [(txt, binary)]
         field.table_chunks = []
+        return field

     return TextChunker()(binary)
@@ -161,9 +162,9 @@ def build(row, cvmdl):
     doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
     output_buffer = BytesIO()
     docs = []
-    md5 = hashlib.md5()
     for txt, img in obj.text_chunks:
         d = copy.deepcopy(doc)
+        md5 = hashlib.md5()
         md5.update((txt + str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         d["content_ltks"] = huqie.qie(txt)
@@ -186,6 +187,7 @@ def build(row, cvmdl):
         for i, txt in enumerate(arr):
             d = copy.deepcopy(doc)
             d["content_ltks"] = huqie.qie(txt)
+            md5 = hashlib.md5()
             md5.update((txt + str(d["doc_id"])).encode("utf-8"))
             d["_id"] = md5.hexdigest()
             if not img:
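hashlib hashers are cumulative, so reusing one md5 across chunks made each _id depend on every preceding chunk; re-processing a document then produced different ids and duplicate index entries. Creating a fresh hasher per chunk keeps ids stable, as this reproduction shows:

import hashlib

texts = ["alpha", "beta"]

# Buggy: one cumulative hasher; "beta"'s id depends on "alpha".
md5 = hashlib.md5()
cumulative = []
for t in texts:
    md5.update(t.encode("utf-8"))
    cumulative.append(md5.hexdigest())

# Fixed: a fresh hasher per chunk gives stable, order-independent ids.
stable = [hashlib.md5(t.encode("utf-8")).hexdigest() for t in texts]
print(cumulative[1] != stable[1])  # True: the cumulative id drifted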
@@ -226,9 +228,6 @@ def embedding(docs, mdl):

 def main(comm, mod):
-    global model
-    from rag.llm import HuEmbedding
-    model = HuEmbedding()
     tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
     tm = findMaxTm(tm_fnm)
     rows = collect(comm, mod, tm)
@@ -260,13 +259,14 @@ def main(comm, mod):
         set_progress(r["id"], random.randint(70, 95) / 100.,
                      "Finished embedding! Start to build index!")
         init_kb(r)
+        chunk_count = len(set([c["_id"] for c in cks]))
         es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
         if es_r:
             set_progress(r["id"], -1, "Index failure!")
             cron_logger.error(str(es_r))
         else:
             set_progress(r["id"], 1., "Done!")
-            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
+            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, chunk_count, timer()-st_tm)
             cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
         tmf.write(str(r["update_time"]) + "\n")