llm configuation refine and trievalTest API refine (#40)

This commit is contained in:
KevinHuSh 2024-01-19 19:51:57 +08:00 committed by GitHub
parent f3dd131403
commit 484e5abc1f
39 changed files with 160 additions and 121 deletions

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
from flask import request
from flask_login import login_required, current_user
@ -177,6 +178,7 @@ def create():
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", [])
d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
@ -223,7 +225,7 @@ def retrieval_test():
embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value)
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
vector_similarity_weight, top, doc_ids)
vector_similarity_weight, top, doc_ids)
return get_json_result(data=ranks)
except Exception as e:
@ -231,4 +233,3 @@ def retrieval_test():
return get_json_result(data=False, retmsg=f'Index not found!',
retcode=RetCode.DATA_ERROR)
return server_error_response(e)

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,22 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import tiktoken
from flask import request
from flask_login import login_required, current_user
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService
from api.db import StatusEnum, LLMType
from api.db.services.kb_service import KnowledgebaseService
from api.db import LLMType
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.user_service import TenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.query import EsQueryer
from rag.utils import num_tokens_from_string, encoder
@ -142,6 +136,27 @@ def message_fit_in(msg, max_length=4000):
return max_length, msg
@manager.route('/completion', methods=['POST'])
@login_required
@validate_request("dialog_id", "messages")
def completion():
req = request.json
msg = []
for m in req["messages"]:
if m["role"] == "system":continue
if m["role"] == "assistant" and not msg:continue
msg.append({"role": m["role"], "content": m["content"]})
try:
e, dia = DialogService.get_by_id(req["dialog_id"])
if not e:
return get_data_error_result(retmsg="Dialog not found!")
del req["dialog_id"]
del req["messages"]
return get_json_result(data=chat(dia, msg, **req))
except Exception as e:
return server_error_response(e)
def chat(dialog, messages, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
llm = LLMService.query(llm_name=dialog.llm_id)
@ -156,7 +171,7 @@ def chat(dialog, messages, **kwargs):
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
model_config = TenantLLMService.get_api_key(dialog.tenant_id, LLMType.CHAT.value, dialog.llm_id)
if not model_config: raise LookupError("LLM(%s) API key not found"%dialog.llm_id)
if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
question = messages[-1]["content"]
embd_mdl = TenantLLMService.model_instance(
@ -183,25 +198,4 @@ def chat(dialog, messages, **kwargs):
embd_mdl,
tkweight=1-dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
return {"answer": answer, "retrieval": kbinfos}
@manager.route('/completion', methods=['POST'])
@login_required
@validate_request("dialog_id", "messages")
def completion():
req = request.json
msg = []
for m in req["messages"]:
if m["role"] == "system":continue
if m["role"] == "assistant" and not msg:continue
msg.append({"role": m["role"], "content": m["content"]})
try:
e, dia = DialogService.get_by_id(req["dialog_id"])
if not e:
return get_data_error_result(retmsg="Dialog not found!")
del req["dialog_id"]
del req["messages"]
return get_json_result(data=chat(dia, msg, **req))
except Exception as e:
return server_error_response(e)
return {"answer": answer, "retrieval": kbinfos}

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -71,18 +71,12 @@ def my_llms():
def list():
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
objs = [o.to_dict() for o in objs if o.api_key]
fct = {}
for o in objs:
if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
llms = LLMService.get_all()
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = False
if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
m["available"] = True
m["available"] = m.llm_name in mdlnms
res = {}
for m in llms:
if m["fid"] not in res: res[m["fid"]] = []

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from flask import request, session, redirect, url_for
from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import login_required, current_user, login_user, logout_user
from api.db.db_models import TenantLLM
from api.db.services.llm_service import TenantLLMService
from api.db.services.llm_service import TenantLLMService, LLMService
from api.utils.api_utils import server_error_response, validate_request
from api.utils import get_uuid, get_format_time, decrypt, download_img
from api.db import UserTenantRole, LLMType
@ -185,8 +187,6 @@ def rollback_user_registration(user_id):
def user_register(user_id, user):
user_id = get_uuid()
user["id"] = user_id
tenant = {
"id": user_id,
@ -203,12 +203,14 @@ def user_register(user_id, user):
"invited_by": user_id,
"role": UserTenantRole.OWNER
}
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
tenant_llm = []
for llm in LLMService.query(fid="Infiniflow"):
tenant_llm.append({"tenant_id": user_id, "llm_factory": "Infiniflow", "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": "infiniflow API Key"})
if not UserService.save(**user):return
TenantService.save(**tenant)
UserTenantService.save(**usr_tenant)
TenantLLMService.save(**tenant_llm)
TenantLLMService.insert_many(tenant_llm)
return UserService.query(email=user["email"])
@ -218,6 +220,9 @@ def user_add():
req = request.json
if UserService.query(email=req["email"]):
return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
return get_json_result(data=False, retmsg=f'Invaliad e-mail: {req["email"]}!',
retcode=RetCode.OPERATING_ERROR)
user_dict = {
"access_token": get_uuid(),

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -426,8 +426,8 @@ class LLMFactories(DataBaseModel):
class LLM(DataBaseModel):
# defautlt LLMs for every users
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
# LLMs dictionary
llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
max_tokens = IntegerField(default=0)
@ -448,6 +448,7 @@ class TenantLLM(DataBaseModel):
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
api_key = CharField(max_length=255, null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base")
used_tokens = IntegerField(default=0)
def __str__(self):
return self.llm_name
@ -468,8 +469,8 @@ class Knowledgebase(DataBaseModel):
doc_num = IntegerField(default=0)
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
#similarity_threshold = FloatField(default=0.4)
#vector_similarity_weight = FloatField(default=0.3)
similarity_threshold = FloatField(default=0.4)
vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -46,6 +46,11 @@ def init_llm_factory():
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "Infiniflow",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "智普AI",
"logo": "",
@ -130,6 +135,30 @@ def init_llm_factory():
"tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value
},{
"fid": factory_infos[2]["name"],
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[2]["name"],
"llm_name": "text-embedding-ada-002",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
},{
"fid": factory_infos[2]["name"],
"llm_name": "whisper-1",
"tags": "SPEECH2TEXT",
"max_tokens": 25*1024*1024,
"model_type": LLMType.SPEECH2TEXT.value
},{
"fid": factory_infos[2]["name"],
"llm_name": "gpt-4-vision-preview",
"tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value
},
]
for info in factory_infos:

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.db.services.user_service import TenantService
from rag.llm import EmbeddingModel, CvModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
@ -34,40 +35,39 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type, model_name=""):
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
if objs and len(objs)>0 and objs[0].llm_name:
return objs[0]
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
(cls.model.tenant_id == tenant_id),
((cls.model.model_type == model_type) | (cls.model.llm_name == model_name)),
(LLM.status == StatusEnum.VALID)
)
if not objs:return
def get_api_key(cls, tenant_id, model_name):
objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
if not objs: return
return objs[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory==LLMFactories.name)).where(cls.model.tenant_id==tenant_id).dicts()
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
cls.model.tenant_id == tenant_id).dicts()
return list(objs)
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type):
model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING.value)
if not model_config:
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else:
model_config = model_config[0].to_dict()
if llm_type == LLMType.EMBEDDING:
e,tenant = TenantService.get_by_id(tenant_id)
if not e: raise LookupError("Tenant not found")
if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id
elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id
else: assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
model_config = model_config[0].to_dict()
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT:
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,7 +19,7 @@ from .cv_model import *
EmbeddingModel = {
"local": HuEmbedding,
"Infiniflow": HuEmbedding,
"OpenAI": OpenAIEmbed,
"通义千问": QWenEmbed,
}
@ -27,12 +27,14 @@ EmbeddingModel = {
CvModel = {
"OpenAI": GptV4,
"Infiniflow": GptV4,
"通义千问": QWenCV,
}
ChatModel = {
"OpenAI": GptTurbo,
"Infiniflow": GptTurbo,
"通义千问": QWenChat,
}

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -24,6 +24,9 @@ import numpy as np
from rag.utils import num_tokens_from_string
flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
class Base(ABC):
def __init__(self, key, model_name):
@ -47,9 +50,7 @@ class HuEmbedding(Base):
^_-
"""
self.model = FlagModel("BAAI/bge-large-zh-v1.5",
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
self.model = flag_model
def encode(self, texts: list, batch_size=32):

View File

@ -42,7 +42,7 @@ class EsQueryer:
def question(self, txt, tbl="qa", min_match="60%"):
txt = re.sub(
r"[ \t,,。??/`!&]+",
r"[ \r\n\t,,。??/`!&]+",
" ",
huqie.tradi2simp(
huqie.strQ2B(

View File

@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import fitz
import xgboost as xgb
from io import BytesIO
import torch
@ -1527,8 +1528,6 @@ class HuParser:
return "\n\n".join(res)
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm))
self.lefted_chars = []
self.mean_height = []
self.mean_width = []
@ -1536,13 +1535,26 @@ class HuParser:
self.garbages = {}
self.page_cum_height = [0]
self.page_layout = []
self.page_images = [p.to_image(
resolution=72 * zoomin).annotated for i, p in enumerate(self.pdf.pages[:299])]
try:
self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm))
self.page_images = [p.to_image(resolution=72*zoomin).annotated for i,p in enumerate(self.pdf.pages[:299])]
self.page_chars = [[c for c in self.pdf.pages[i].chars if self._has_color(c)] for i in range(len(self.page_images))]
except Exception as e:
self.pdf = fitz.open(fnm) if isinstance(fnm, str) else fitz.open(stream=fnm, filetype="pdf")
self.page_images = []
self.page_chars = []
mat = fitz.Matrix(zoomin, zoomin)
for page in self.pdf:
pix = page.getPixmap(matrix = mat)
img = Image.frombytes("RGB", [pix.width, pix.height],
pix.samples)
self.page_images.append(img)
self.page_chars.append([])
logging.info("Images converted.")
logging.info("Table processed.")
for i, img in enumerate(self.page_images):
chars = [c for c in self.pdf.pages[i].chars if self._has_color(c)]
chars = self.page_chars[i]
self.mean_height.append(
np.median(sorted([c["height"] for c in chars])) if chars else 0
)

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.