mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 06:49:00 +08:00
add rerank model (#969)
### What problem does this PR solve? feat: add rerank models to the project #724 #162 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
e1f0644deb
commit
614defec21
@ -257,8 +257,15 @@ def retrieval_test():
|
||||
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
|
||||
vector_similarity_weight, top, doc_ids)
|
||||
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = TenantLLMService.model_instance(
|
||||
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
||||
|
||||
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl)
|
||||
for c in ranks["chunks"]:
|
||||
if "vector" in c:
|
||||
del c["vector"]
|
||||
|
@ -33,6 +33,9 @@ def set_dialog():
|
||||
name = req.get("name", "New Dialog")
|
||||
description = req.get("description", "A helpful Dialog")
|
||||
top_n = req.get("top_n", 6)
|
||||
top_k = req.get("top_k", 1024)
|
||||
rerank_id = req.get("rerank_id", "")
|
||||
if not rerank_id: req["rerank_id"] = ""
|
||||
similarity_threshold = req.get("similarity_threshold", 0.1)
|
||||
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
|
||||
llm_setting = req.get("llm_setting", {})
|
||||
@ -83,6 +86,8 @@ def set_dialog():
|
||||
"llm_setting": llm_setting,
|
||||
"prompt_config": prompt_config,
|
||||
"top_n": top_n,
|
||||
"top_k": top_k,
|
||||
"rerank_id": rerank_id,
|
||||
"similarity_threshold": similarity_threshold,
|
||||
"vector_similarity_weight": vector_similarity_weight
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
||||
from api.db import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.llm import EmbeddingModel, ChatModel
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel
|
||||
|
||||
|
||||
@manager.route('/factories', methods=['GET'])
|
||||
@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
|
||||
def factories():
|
||||
try:
|
||||
fac = LLMFactoriesService.get_all()
|
||||
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
|
||||
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -64,6 +64,16 @@ def set_api_key():
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
||||
e)
|
||||
elif llm.model_type == LLMType.RERANK:
|
||||
mdl = RerankModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||
try:
|
||||
m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
||||
if len(arr[0]) == 0 or tc == 0:
|
||||
raise Exception("Fail")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
||||
e)
|
||||
|
||||
if msg:
|
||||
return get_data_error_result(retmsg=msg)
|
||||
@ -199,7 +209,7 @@ def list_app():
|
||||
llms = [m.to_dict()
|
||||
for m in llms if m.status == StatusEnum.VALID.value]
|
||||
for m in llms:
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"]
|
||||
|
||||
llm_set = set([m["llm_name"] for m in llms])
|
||||
for o in objs:
|
||||
|
@ -26,8 +26,9 @@ from api.db.services.llm_service import TenantLLMService, LLMService
|
||||
from api.utils.api_utils import server_error_response, validate_request
|
||||
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
|
||||
from api.db import UserTenantRole, LLMType, FileType
|
||||
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
|
||||
LLM_FACTORY, LLM_BASE_URL
|
||||
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \
|
||||
API_KEY, \
|
||||
LLM_FACTORY, LLM_BASE_URL, RERANK_MDL
|
||||
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.settings import stat_logger
|
||||
@ -288,7 +289,8 @@ def user_register(user_id, user):
|
||||
"embd_id": EMBEDDING_MDL,
|
||||
"asr_id": ASR_MDL,
|
||||
"parser_ids": PARSERS,
|
||||
"img2txt_id": IMAGE2TEXT_MDL
|
||||
"img2txt_id": IMAGE2TEXT_MDL,
|
||||
"rerank_id": RERANK_MDL
|
||||
}
|
||||
usr_tenant = {
|
||||
"tenant_id": user_id,
|
||||
|
@ -54,6 +54,7 @@ class LLMType(StrEnum):
|
||||
EMBEDDING = 'embedding'
|
||||
SPEECH2TEXT = 'speech2text'
|
||||
IMAGE2TEXT = 'image2text'
|
||||
RERANK = 'rerank'
|
||||
|
||||
|
||||
class ChatStyle(StrEnum):
|
||||
|
@ -437,6 +437,10 @@ class Tenant(DataBaseModel):
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default image to text model ID")
|
||||
rerank_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default rerank model ID")
|
||||
parser_ids = CharField(
|
||||
max_length=256,
|
||||
null=False,
|
||||
@ -771,11 +775,16 @@ class Dialog(DataBaseModel):
|
||||
similarity_threshold = FloatField(default=0.2)
|
||||
vector_similarity_weight = FloatField(default=0.3)
|
||||
top_n = IntegerField(default=6)
|
||||
top_k = IntegerField(default=1024)
|
||||
do_refer = CharField(
|
||||
max_length=1,
|
||||
null=False,
|
||||
help_text="it needs to insert reference index into answer or not",
|
||||
default="1")
|
||||
rerank_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
help_text="default rerank model ID")
|
||||
|
||||
kb_ids = JSONField(null=False, default=[])
|
||||
status = CharField(
|
||||
@ -825,11 +834,29 @@ class API4Conversation(DataBaseModel):
|
||||
|
||||
|
||||
def migrate_db():
|
||||
try:
|
||||
with DB.transaction():
|
||||
migrator = MySQLMigrator(DB)
|
||||
migrate(
|
||||
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
@ -142,7 +142,17 @@ factory_infos = [{
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
},
|
||||
},{
|
||||
"name": "Jina",
|
||||
"logo": "",
|
||||
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
|
||||
"status": "1",
|
||||
},{
|
||||
"name": "BAAI",
|
||||
"logo": "",
|
||||
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
|
||||
"status": "1",
|
||||
}
|
||||
# {
|
||||
# "name": "文心一言",
|
||||
# "logo": "",
|
||||
@ -367,6 +377,13 @@ def init_llm_factory():
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[7]["name"],
|
||||
"llm_name": "maidalun1020/bce-reranker-base_v1",
|
||||
"tags": "RE-RANK, 8K",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
# ------------------------ DeepSeek -----------------------
|
||||
{
|
||||
"fid": factory_infos[8]["name"],
|
||||
@ -440,6 +457,85 @@ def init_llm_factory():
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
# ------------------------ Jina -----------------------
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-reranker-v1-base-en",
|
||||
"tags": "RE-RANK,8k",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-reranker-v1-turbo-en",
|
||||
"tags": "RE-RANK,8k",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-reranker-v1-tiny-en",
|
||||
"tags": "RE-RANK,8k",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-colbert-v1-en",
|
||||
"tags": "RE-RANK,8k",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-en",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-de",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-es",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-code",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[11]["name"],
|
||||
"llm_name": "jina-embeddings-v2-base-zh",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
# ------------------------ BAAI -----------------------
|
||||
{
|
||||
"fid": factory_infos[12]["name"],
|
||||
"llm_name": "BAAI/bge-large-zh-v1.5",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 1024,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
{
|
||||
"fid": factory_infos[12]["name"],
|
||||
"llm_name": "BAAI/bge-reranker-v2-m3",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 16385,
|
||||
"model_type": LLMType.RERANK.value
|
||||
},
|
||||
]
|
||||
for info in factory_infos:
|
||||
try:
|
||||
|
@ -115,11 +115,14 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
else:
|
||||
rerank_mdl = None
|
||||
if dialog.rerank_id:
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
||||
top=1024, aggs=False)
|
||||
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
chat_logger.info(
|
||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
@ -130,7 +133,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
kwargs["knowledge"] = "\n".join(knowledges)
|
||||
gen_conf = dialog.llm_setting
|
||||
|
||||
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
||||
msg.extend([{"role": m["role"], "content": m["content"]}
|
||||
for m in messages if m["role"] != "system"])
|
||||
|
@ -15,7 +15,7 @@
|
||||
#
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.settings import database_logger
|
||||
from rag.llm import EmbeddingModel, CvModel, ChatModel
|
||||
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import DB, UserTenant
|
||||
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
||||
@ -73,21 +73,25 @@ class TenantLLMService(CommonService):
|
||||
mdlnm = tenant.img2txt_id
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.RERANK:
|
||||
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
if model_config: model_config = model_config.to_dict()
|
||||
if not model_config:
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||
llm = LLMService.query(llm_name=llm_name)
|
||||
if llm and llm[0].fid in ["Youdao", "FastEmbed", "DeepSeek"]:
|
||||
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
||||
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
|
||||
if not model_config:
|
||||
if llm_name == "flag-embedding":
|
||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
|
||||
"llm_name": llm_name, "api_base": ""}
|
||||
else:
|
||||
if not mdlnm:
|
||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||
raise LookupError("Model({}) not authorized".format(mdlnm))
|
||||
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
@ -96,6 +100,12 @@ class TenantLLMService(CommonService):
|
||||
return EmbeddingModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return
|
||||
return RerankModel[model_config["llm_factory"]](
|
||||
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return
|
||||
@ -125,14 +135,20 @@ class TenantLLMService(CommonService):
|
||||
mdlnm = tenant.img2txt_id
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.RERANK:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
num = 0
|
||||
for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
|
||||
num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
|
||||
.execute()
|
||||
try:
|
||||
for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
|
||||
num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
|
||||
.execute()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@ -176,6 +192,14 @@ class LLMBundle(object):
|
||||
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
||||
return emd, used_tokens
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
sim, used_tokens = self.mdl.similarity(query, texts)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens):
|
||||
database_logger.error(
|
||||
"Can't update token usage for {}/RERANK".format(self.tenant_id))
|
||||
return sim, used_tokens
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
txt, used_tokens = self.mdl.describe(image, max_tokens)
|
||||
if not TenantLLMService.increase_usage(
|
||||
|
@ -93,6 +93,7 @@ class TenantService(CommonService):
|
||||
cls.model.name,
|
||||
cls.model.llm_id,
|
||||
cls.model.embd_id,
|
||||
cls.model.rerank_id,
|
||||
cls.model.asr_id,
|
||||
cls.model.img2txt_id,
|
||||
cls.model.parser_ids,
|
||||
|
@ -89,9 +89,22 @@ default_llm = {
|
||||
},
|
||||
"DeepSeek": {
|
||||
"chat_model": "deepseek-chat",
|
||||
"embedding_model": "",
|
||||
"image2text_model": "",
|
||||
"asr_model": "",
|
||||
},
|
||||
"VolcEngine": {
|
||||
"chat_model": "",
|
||||
"embedding_model": "",
|
||||
"image2text_model": "",
|
||||
"asr_model": "",
|
||||
},
|
||||
"BAAI": {
|
||||
"chat_model": "",
|
||||
"embedding_model": "BAAI/bge-large-zh-v1.5",
|
||||
"image2text_model": "",
|
||||
"asr_model": "",
|
||||
"rerank_model": "BAAI/bge-reranker-v2-m3",
|
||||
}
|
||||
}
|
||||
LLM = get_base_config("user_default_llm", {})
|
||||
@ -104,7 +117,8 @@ if LLM_FACTORY not in default_llm:
|
||||
f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
|
||||
LLM_FACTORY = "Tongyi-Qianwen"
|
||||
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
|
||||
EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
|
||||
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
|
||||
RERANK_MDL = default_llm["BAAI"]["rerank_model"]
|
||||
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
||||
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
||||
|
||||
|
@ -16,18 +16,19 @@
|
||||
from .embedding_model import *
|
||||
from .chat_model import *
|
||||
from .cv_model import *
|
||||
from .rerank_model import *
|
||||
|
||||
|
||||
EmbeddingModel = {
|
||||
"Ollama": OllamaEmbed,
|
||||
"OpenAI": OpenAIEmbed,
|
||||
"Xinference": XinferenceEmbed,
|
||||
"Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed,
|
||||
"Tongyi-Qianwen": DefaultEmbedding,#QWenEmbed,
|
||||
"ZHIPU-AI": ZhipuEmbed,
|
||||
"FastEmbed": FastEmbed,
|
||||
"Youdao": YoudaoEmbed,
|
||||
"DeepSeek": DefaultEmbedding,
|
||||
"BaiChuan": BaiChuanEmbed
|
||||
"BaiChuan": BaiChuanEmbed,
|
||||
"BAAI": DefaultEmbedding
|
||||
}
|
||||
|
||||
|
||||
@ -52,3 +53,9 @@ ChatModel = {
|
||||
"BaiChuan": BaiChuanChat
|
||||
}
|
||||
|
||||
|
||||
RerankModel = {
|
||||
"BAAI": DefaultRerank,
|
||||
"Jina": JinaRerank,
|
||||
"Youdao": YoudaoRerank,
|
||||
}
|
||||
|
@ -13,8 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from zhipuai import ZhipuAI
|
||||
import os
|
||||
@ -26,21 +28,9 @@ from FlagEmbedding import FlagModel
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
|
||||
try:
|
||||
flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
except Exception as e:
|
||||
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
||||
local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
|
||||
local_dir_use_symlinks=False)
|
||||
flag_model = FlagModel(model_dir,
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name):
|
||||
@ -54,7 +44,9 @@ class Base(ABC):
|
||||
|
||||
|
||||
class DefaultEmbedding(Base):
|
||||
def __init__(self, *args, **kwargs):
|
||||
_model = None
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
@ -66,7 +58,18 @@ class DefaultEmbedding(Base):
|
||||
^_-
|
||||
|
||||
"""
|
||||
self.model = flag_model
|
||||
if not DefaultEmbedding._model:
|
||||
try:
|
||||
self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
except Exception as e:
|
||||
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
||||
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
||||
local_dir_use_symlinks=False)
|
||||
self._model = FlagModel(model_dir,
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
@ -75,12 +78,12 @@ class DefaultEmbedding(Base):
|
||||
token_count += num_tokens_from_string(t)
|
||||
res = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
|
||||
res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
|
||||
return np.array(res), token_count
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
token_count = num_tokens_from_string(text)
|
||||
return self.model.encode_queries([text]).tolist()[0], token_count
|
||||
return self._model.encode_queries([text]).tolist()[0], token_count
|
||||
|
||||
|
||||
class OpenAIEmbed(Base):
|
||||
@ -189,16 +192,19 @@ class OllamaEmbed(Base):
|
||||
|
||||
|
||||
class FastEmbed(Base):
|
||||
_model = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: Optional[str] = None,
|
||||
model_name: str = "BAAI/bge-small-en-v1.5",
|
||||
cache_dir: Optional[str] = None,
|
||||
threads: Optional[int] = None,
|
||||
**kwargs,
|
||||
self,
|
||||
key: Optional[str] = None,
|
||||
model_name: str = "BAAI/bge-small-en-v1.5",
|
||||
cache_dir: Optional[str] = None,
|
||||
threads: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
from fastembed import TextEmbedding
|
||||
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||
if not FastEmbed._model:
|
||||
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
# Using the internal tokenizer to encode the texts and get the total
|
||||
@ -265,3 +271,29 @@ class YoudaoEmbed(Base):
|
||||
def encode_queries(self, text):
|
||||
embds = YoudaoEmbed._client.encode([text])
|
||||
return np.array(embds[0]), num_tokens_from_string(text)
|
||||
|
||||
|
||||
class JinaEmbed(Base):
|
||||
def __init__(self, key, model_name="jina-embeddings-v2-base-zh",
|
||||
base_url="https://api.jina.ai/v1/embeddings"):
|
||||
|
||||
self.base_url = "https://api.jina.ai/v1/embeddings"
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list, batch_size=None):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"input": texts,
|
||||
'encoding_type': 'float'
|
||||
}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
|
||||
|
||||
def encode_queries(self, text):
|
||||
embds, cnt = self.encode([text])
|
||||
return np.array(embds[0]), cnt
|
113
rag/llm/rerank_model.py
Normal file
113
rag/llm/rerank_model.py
Normal file
@ -0,0 +1,113 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
import requests
|
||||
import torch
|
||||
from FlagEmbedding import FlagReranker
|
||||
from huggingface_hub import snapshot_download
|
||||
import os
|
||||
from abc import ABC
|
||||
import numpy as np
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name):
|
||||
pass
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
|
||||
class DefaultRerank(Base):
|
||||
_model = None
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not DefaultRerank._model:
|
||||
try:
|
||||
self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
||||
use_fp16=torch.cuda.is_available())
|
||||
except Exception as e:
|
||||
self._model = snapshot_download(repo_id=model_name,
|
||||
local_dir=os.path.join(get_home_cache_dir(),
|
||||
re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
||||
local_dir_use_symlinks=False)
|
||||
self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name),
|
||||
use_fp16=torch.cuda.is_available())
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
pairs = [(query,truncate(t, 2048)) for t in texts]
|
||||
token_count = 0
|
||||
for _, t in pairs:
|
||||
token_count += num_tokens_from_string(t)
|
||||
batch_size = 32
|
||||
res = []
|
||||
for i in range(0, len(pairs), batch_size):
|
||||
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
|
||||
res.extend(scores)
|
||||
return np.array(res), token_count
|
||||
|
||||
|
||||
class JinaRerank(Base):
|
||||
def __init__(self, key, model_name="jina-reranker-v1-base-en",
|
||||
base_url="https://api.jina.ai/v1/rerank"):
|
||||
self.base_url = "https://api.jina.ai/v1/rerank"
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
self.model_name = model_name
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": texts,
|
||||
"top_n": len(texts)
|
||||
}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
|
||||
|
||||
|
||||
class YoudaoRerank(DefaultRerank):
|
||||
_model = None
|
||||
|
||||
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
||||
from BCEmbedding import RerankerModel
|
||||
if not YoudaoRerank._model:
|
||||
try:
|
||||
print("LOADING BCE...")
|
||||
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
|
||||
get_home_cache_dir(),
|
||||
re.sub(r"^[a-zA-Z]+/", "", model_name)))
|
||||
except Exception as e:
|
||||
YoudaoRerank._model = RerankerModel(
|
||||
model_name_or_path=model_name.replace(
|
||||
"maidalun1020", "InfiniFlow"))
|
||||
|
@ -54,7 +54,8 @@ class EsQueryer:
|
||||
if not self.isChinese(txt):
|
||||
tks = rag_tokenizer.tokenize(txt).split(" ")
|
||||
tks_w = self.tw.weights(tks)
|
||||
q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
|
||||
tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w]
|
||||
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
||||
for i in range(1, len(tks_w)):
|
||||
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
||||
if not q:
|
||||
@ -136,7 +137,11 @@ class EsQueryer:
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
import numpy as np
|
||||
sims = CosineSimilarity([avec], bvecs)
|
||||
tksim = self.token_similarity(atks, btkss)
|
||||
return np.array(sims[0]) * vtweight + \
|
||||
np.array(tksim) * tkweight, tksim, sims[0]
|
||||
|
||||
def token_similarity(self, atks, btkss):
|
||||
def toDict(tks):
|
||||
d = {}
|
||||
if isinstance(tks, str):
|
||||
@ -149,9 +154,7 @@ class EsQueryer:
|
||||
|
||||
atks = toDict(atks)
|
||||
btkss = [toDict(tks) for tks in btkss]
|
||||
tksim = [self.similarity(atks, btks) for btks in btkss]
|
||||
return np.array(sims[0]) * vtweight + \
|
||||
np.array(tksim) * tkweight, tksim, sims[0]
|
||||
return [self.similarity(atks, btks) for btks in btkss]
|
||||
|
||||
def similarity(self, qtwt, dtwt):
|
||||
if isinstance(dtwt, type("")):
|
||||
|
@ -241,11 +241,14 @@ class RagTokenizer:
|
||||
|
||||
return self.score_(res[::-1])
|
||||
|
||||
def english_normalize_(self, tks):
|
||||
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
|
||||
|
||||
def tokenize(self, line):
|
||||
line = self._strQ2B(line).lower()
|
||||
line = self._tradi2simp(line)
|
||||
zh_num = len([1 for c in line if is_chinese(c)])
|
||||
if zh_num < len(line) * 0.2:
|
||||
if zh_num == 0:
|
||||
return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
|
||||
|
||||
arr = re.split(self.SPLIT_CHAR, line)
|
||||
@ -293,7 +296,7 @@ class RagTokenizer:
|
||||
|
||||
i = e + 1
|
||||
|
||||
res = " ".join(res)
|
||||
res = " ".join(self.english_normalize_(res))
|
||||
if self.DEBUG:
|
||||
print("[TKS]", self.merge_(res))
|
||||
return self.merge_(res)
|
||||
@ -336,7 +339,7 @@ class RagTokenizer:
|
||||
|
||||
res.append(stk)
|
||||
|
||||
return " ".join(res)
|
||||
return " ".join(self.english_normalize_(res))
|
||||
|
||||
|
||||
def is_chinese(s):
|
||||
|
@ -71,8 +71,8 @@ class Dealer:
|
||||
|
||||
s = Search()
|
||||
pg = int(req.get("page", 1)) - 1
|
||||
ps = int(req.get("size", 1000))
|
||||
topk = int(req.get("topk", 1024))
|
||||
ps = int(req.get("size", topk))
|
||||
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
||||
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
|
||||
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
|
||||
@ -311,6 +311,26 @@ class Dealer:
|
||||
ins_tw, tkweight, vtweight)
|
||||
return sim, tksim, vtsim
|
||||
|
||||
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
|
||||
vtweight=0.7, cfield="content_ltks"):
|
||||
_, keywords = self.qryr.question(query)
|
||||
|
||||
for i in sres.ids:
|
||||
if isinstance(sres.field[i].get("important_kwd", []), str):
|
||||
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
|
||||
ins_tw = []
|
||||
for i in sres.ids:
|
||||
content_ltks = sres.field[i][cfield].split(" ")
|
||||
title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t]
|
||||
important_kwd = sres.field[i].get("important_kwd", [])
|
||||
tks = content_ltks + title_tks + important_kwd
|
||||
ins_tw.append(tks)
|
||||
|
||||
tksim = self.qryr.token_similarity(keywords, ins_tw)
|
||||
vtsim,_ = rerank_mdl.similarity(" ".join(keywords), [rmSpace(" ".join(tks)) for tks in ins_tw])
|
||||
|
||||
return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim
|
||||
|
||||
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
||||
return self.qryr.hybrid_similarity(ans_embd,
|
||||
ins_embd,
|
||||
@ -318,17 +338,22 @@ class Dealer:
|
||||
rag_tokenizer.tokenize(inst).split(" "))
|
||||
|
||||
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
|
||||
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
|
||||
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None):
|
||||
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
||||
if not question:
|
||||
return ranks
|
||||
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
|
||||
"question": question, "vector": True, "topk": top,
|
||||
"similarity": similarity_threshold}
|
||||
"similarity": similarity_threshold,
|
||||
"available_int": 1}
|
||||
sres = self.search(req, index_name(tenant_id), embd_mdl)
|
||||
|
||||
sim, tsim, vsim = self.rerank(
|
||||
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
|
||||
if rerank_mdl:
|
||||
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
|
||||
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
|
||||
else:
|
||||
sim, tsim, vsim = self.rerank(
|
||||
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
|
||||
idx = np.argsort(sim * -1)
|
||||
|
||||
dim = len(sres.query_vector)
|
||||
|
Loading…
x
Reference in New Issue
Block a user