mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 02:55:55 +08:00
add rerank model (#969)
### What problem does this PR solve? feat: add rerank models to the project #724 #162 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
e1f0644deb
commit
614defec21
@ -257,8 +257,15 @@ def retrieval_test():
|
|||||||
|
|
||||||
embd_mdl = TenantLLMService.model_instance(
|
embd_mdl = TenantLLMService.model_instance(
|
||||||
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||||
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
|
|
||||||
vector_similarity_weight, top, doc_ids)
|
rerank_mdl = None
|
||||||
|
if req.get("rerank_id"):
|
||||||
|
rerank_mdl = TenantLLMService.model_instance(
|
||||||
|
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
||||||
|
|
||||||
|
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
|
||||||
|
similarity_threshold, vector_similarity_weight, top,
|
||||||
|
doc_ids, rerank_mdl=rerank_mdl)
|
||||||
for c in ranks["chunks"]:
|
for c in ranks["chunks"]:
|
||||||
if "vector" in c:
|
if "vector" in c:
|
||||||
del c["vector"]
|
del c["vector"]
|
||||||
|
@ -33,6 +33,9 @@ def set_dialog():
|
|||||||
name = req.get("name", "New Dialog")
|
name = req.get("name", "New Dialog")
|
||||||
description = req.get("description", "A helpful Dialog")
|
description = req.get("description", "A helpful Dialog")
|
||||||
top_n = req.get("top_n", 6)
|
top_n = req.get("top_n", 6)
|
||||||
|
top_k = req.get("top_k", 1024)
|
||||||
|
rerank_id = req.get("rerank_id", "")
|
||||||
|
if not rerank_id: req["rerank_id"] = ""
|
||||||
similarity_threshold = req.get("similarity_threshold", 0.1)
|
similarity_threshold = req.get("similarity_threshold", 0.1)
|
||||||
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
|
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
|
||||||
llm_setting = req.get("llm_setting", {})
|
llm_setting = req.get("llm_setting", {})
|
||||||
@ -83,6 +86,8 @@ def set_dialog():
|
|||||||
"llm_setting": llm_setting,
|
"llm_setting": llm_setting,
|
||||||
"prompt_config": prompt_config,
|
"prompt_config": prompt_config,
|
||||||
"top_n": top_n,
|
"top_n": top_n,
|
||||||
|
"top_k": top_k,
|
||||||
|
"rerank_id": rerank_id,
|
||||||
"similarity_threshold": similarity_threshold,
|
"similarity_threshold": similarity_threshold,
|
||||||
"vector_similarity_weight": vector_similarity_weight
|
"vector_similarity_weight": vector_similarity_weight
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
|||||||
from api.db import StatusEnum, LLMType
|
from api.db import StatusEnum, LLMType
|
||||||
from api.db.db_models import TenantLLM
|
from api.db.db_models import TenantLLM
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
from rag.llm import EmbeddingModel, ChatModel
|
from rag.llm import EmbeddingModel, ChatModel, RerankModel
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/factories', methods=['GET'])
|
@manager.route('/factories', methods=['GET'])
|
||||||
@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
|
|||||||
def factories():
|
def factories():
|
||||||
try:
|
try:
|
||||||
fac = LLMFactoriesService.get_all()
|
fac = LLMFactoriesService.get_all()
|
||||||
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
|
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -64,6 +64,16 @@ def set_api_key():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
||||||
e)
|
e)
|
||||||
|
elif llm.model_type == LLMType.RERANK:
|
||||||
|
mdl = RerankModel[factory](
|
||||||
|
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||||
|
try:
|
||||||
|
m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
||||||
|
if len(arr[0]) == 0 or tc == 0:
|
||||||
|
raise Exception("Fail")
|
||||||
|
except Exception as e:
|
||||||
|
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
||||||
|
e)
|
||||||
|
|
||||||
if msg:
|
if msg:
|
||||||
return get_data_error_result(retmsg=msg)
|
return get_data_error_result(retmsg=msg)
|
||||||
@ -199,7 +209,7 @@ def list_app():
|
|||||||
llms = [m.to_dict()
|
llms = [m.to_dict()
|
||||||
for m in llms if m.status == StatusEnum.VALID.value]
|
for m in llms if m.status == StatusEnum.VALID.value]
|
||||||
for m in llms:
|
for m in llms:
|
||||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
|
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"]
|
||||||
|
|
||||||
llm_set = set([m["llm_name"] for m in llms])
|
llm_set = set([m["llm_name"] for m in llms])
|
||||||
for o in objs:
|
for o in objs:
|
||||||
|
@ -26,8 +26,9 @@ from api.db.services.llm_service import TenantLLMService, LLMService
|
|||||||
from api.utils.api_utils import server_error_response, validate_request
|
from api.utils.api_utils import server_error_response, validate_request
|
||||||
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
|
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
|
||||||
from api.db import UserTenantRole, LLMType, FileType
|
from api.db import UserTenantRole, LLMType, FileType
|
||||||
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
|
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \
|
||||||
LLM_FACTORY, LLM_BASE_URL
|
API_KEY, \
|
||||||
|
LLM_FACTORY, LLM_BASE_URL, RERANK_MDL
|
||||||
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.settings import stat_logger
|
from api.settings import stat_logger
|
||||||
@ -288,7 +289,8 @@ def user_register(user_id, user):
|
|||||||
"embd_id": EMBEDDING_MDL,
|
"embd_id": EMBEDDING_MDL,
|
||||||
"asr_id": ASR_MDL,
|
"asr_id": ASR_MDL,
|
||||||
"parser_ids": PARSERS,
|
"parser_ids": PARSERS,
|
||||||
"img2txt_id": IMAGE2TEXT_MDL
|
"img2txt_id": IMAGE2TEXT_MDL,
|
||||||
|
"rerank_id": RERANK_MDL
|
||||||
}
|
}
|
||||||
usr_tenant = {
|
usr_tenant = {
|
||||||
"tenant_id": user_id,
|
"tenant_id": user_id,
|
||||||
|
@ -54,6 +54,7 @@ class LLMType(StrEnum):
|
|||||||
EMBEDDING = 'embedding'
|
EMBEDDING = 'embedding'
|
||||||
SPEECH2TEXT = 'speech2text'
|
SPEECH2TEXT = 'speech2text'
|
||||||
IMAGE2TEXT = 'image2text'
|
IMAGE2TEXT = 'image2text'
|
||||||
|
RERANK = 'rerank'
|
||||||
|
|
||||||
|
|
||||||
class ChatStyle(StrEnum):
|
class ChatStyle(StrEnum):
|
||||||
|
@ -437,6 +437,10 @@ class Tenant(DataBaseModel):
|
|||||||
max_length=128,
|
max_length=128,
|
||||||
null=False,
|
null=False,
|
||||||
help_text="default image to text model ID")
|
help_text="default image to text model ID")
|
||||||
|
rerank_id = CharField(
|
||||||
|
max_length=128,
|
||||||
|
null=False,
|
||||||
|
help_text="default rerank model ID")
|
||||||
parser_ids = CharField(
|
parser_ids = CharField(
|
||||||
max_length=256,
|
max_length=256,
|
||||||
null=False,
|
null=False,
|
||||||
@ -771,11 +775,16 @@ class Dialog(DataBaseModel):
|
|||||||
similarity_threshold = FloatField(default=0.2)
|
similarity_threshold = FloatField(default=0.2)
|
||||||
vector_similarity_weight = FloatField(default=0.3)
|
vector_similarity_weight = FloatField(default=0.3)
|
||||||
top_n = IntegerField(default=6)
|
top_n = IntegerField(default=6)
|
||||||
|
top_k = IntegerField(default=1024)
|
||||||
do_refer = CharField(
|
do_refer = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=False,
|
null=False,
|
||||||
help_text="it needs to insert reference index into answer or not",
|
help_text="it needs to insert reference index into answer or not",
|
||||||
default="1")
|
default="1")
|
||||||
|
rerank_id = CharField(
|
||||||
|
max_length=128,
|
||||||
|
null=False,
|
||||||
|
help_text="default rerank model ID")
|
||||||
|
|
||||||
kb_ids = JSONField(null=False, default=[])
|
kb_ids = JSONField(null=False, default=[])
|
||||||
status = CharField(
|
status = CharField(
|
||||||
@ -825,11 +834,29 @@ class API4Conversation(DataBaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def migrate_db():
|
def migrate_db():
|
||||||
try:
|
|
||||||
with DB.transaction():
|
with DB.transaction():
|
||||||
migrator = MySQLMigrator(DB)
|
migrator = MySQLMigrator(DB)
|
||||||
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
|
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
migrate(
|
||||||
|
migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(
|
||||||
|
migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(
|
||||||
|
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
@ -142,7 +142,17 @@ factory_infos = [{
|
|||||||
"logo": "",
|
"logo": "",
|
||||||
"tags": "LLM,TEXT EMBEDDING",
|
"tags": "LLM,TEXT EMBEDDING",
|
||||||
"status": "1",
|
"status": "1",
|
||||||
},
|
},{
|
||||||
|
"name": "Jina",
|
||||||
|
"logo": "",
|
||||||
|
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
|
||||||
|
"status": "1",
|
||||||
|
},{
|
||||||
|
"name": "BAAI",
|
||||||
|
"logo": "",
|
||||||
|
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
|
||||||
|
"status": "1",
|
||||||
|
}
|
||||||
# {
|
# {
|
||||||
# "name": "文心一言",
|
# "name": "文心一言",
|
||||||
# "logo": "",
|
# "logo": "",
|
||||||
@ -367,6 +377,13 @@ def init_llm_factory():
|
|||||||
"max_tokens": 512,
|
"max_tokens": 512,
|
||||||
"model_type": LLMType.EMBEDDING.value
|
"model_type": LLMType.EMBEDDING.value
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[7]["name"],
|
||||||
|
"llm_name": "maidalun1020/bce-reranker-base_v1",
|
||||||
|
"tags": "RE-RANK, 8K",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.RERANK.value
|
||||||
|
},
|
||||||
# ------------------------ DeepSeek -----------------------
|
# ------------------------ DeepSeek -----------------------
|
||||||
{
|
{
|
||||||
"fid": factory_infos[8]["name"],
|
"fid": factory_infos[8]["name"],
|
||||||
@ -440,6 +457,85 @@ def init_llm_factory():
|
|||||||
"max_tokens": 512,
|
"max_tokens": 512,
|
||||||
"model_type": LLMType.EMBEDDING.value
|
"model_type": LLMType.EMBEDDING.value
|
||||||
},
|
},
|
||||||
|
# ------------------------ Jina -----------------------
|
||||||
|
{
|
||||||
|
"fid": factory_infos[11]["name"],
|
||||||
|
"llm_name": "jina-reranker-v1-base-en",
|
||||||
|
"tags": "RE-RANK,8k",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.RERANK.value
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[11]["name"],
|
||||||
|
"llm_name": "jina-reranker-v1-turbo-en",
|
||||||
|
"tags": "RE-RANK,8k",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.RERANK.value
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[11]["name"],
|
||||||
|
"llm_name": "jina-reranker-v1-tiny-en",
|
||||||
|
"tags": "RE-RANK,8k",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.RERANK.value
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[11]["name"],
|
||||||
|
"llm_name": "jina-colbert-v1-en",
|
||||||
|
"tags": "RE-RANK,8k",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.RERANK.value
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[11]["name"],
|
||||||
|
"llm_name": "jina-embeddings-v2-base-en",
|
||||||
|
"tags": "TEXT EMBEDDING",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.EMBEDDING.value
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[11]["name"],
|
||||||
|
"llm_name": "jina-embeddings-v2-base-de",
|
||||||
|
"tags": "TEXT EMBEDDING",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.EMBEDDING.value
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[11]["name"],
|
||||||
|
"llm_name": "jina-embeddings-v2-base-es",
|
||||||
|
"tags": "TEXT EMBEDDING",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.EMBEDDING.value
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[11]["name"],
|
||||||
|
"llm_name": "jina-embeddings-v2-base-code",
|
||||||
|
"tags": "TEXT EMBEDDING",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.EMBEDDING.value
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[11]["name"],
|
||||||
|
"llm_name": "jina-embeddings-v2-base-zh",
|
||||||
|
"tags": "TEXT EMBEDDING",
|
||||||
|
"max_tokens": 8196,
|
||||||
|
"model_type": LLMType.EMBEDDING.value
|
||||||
|
},
|
||||||
|
# ------------------------ BAAI -----------------------
|
||||||
|
{
|
||||||
|
"fid": factory_infos[12]["name"],
|
||||||
|
"llm_name": "BAAI/bge-large-zh-v1.5",
|
||||||
|
"tags": "TEXT EMBEDDING,",
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"model_type": LLMType.EMBEDDING.value
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fid": factory_infos[12]["name"],
|
||||||
|
"llm_name": "BAAI/bge-reranker-v2-m3",
|
||||||
|
"tags": "LLM,CHAT,",
|
||||||
|
"max_tokens": 16385,
|
||||||
|
"model_type": LLMType.RERANK.value
|
||||||
|
},
|
||||||
]
|
]
|
||||||
for info in factory_infos:
|
for info in factory_infos:
|
||||||
try:
|
try:
|
||||||
|
@ -115,11 +115,14 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
||||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||||
else:
|
else:
|
||||||
|
rerank_mdl = None
|
||||||
|
if dialog.rerank_id:
|
||||||
|
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||||
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||||
dialog.similarity_threshold,
|
dialog.similarity_threshold,
|
||||||
dialog.vector_similarity_weight,
|
dialog.vector_similarity_weight,
|
||||||
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
||||||
top=1024, aggs=False)
|
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
||||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||||
chat_logger.info(
|
chat_logger.info(
|
||||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
from api.db.services.user_service import TenantService
|
from api.db.services.user_service import TenantService
|
||||||
from api.settings import database_logger
|
from api.settings import database_logger
|
||||||
from rag.llm import EmbeddingModel, CvModel, ChatModel
|
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.db_models import DB, UserTenant
|
from api.db.db_models import DB, UserTenant
|
||||||
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
||||||
@ -73,21 +73,25 @@ class TenantLLMService(CommonService):
|
|||||||
mdlnm = tenant.img2txt_id
|
mdlnm = tenant.img2txt_id
|
||||||
elif llm_type == LLMType.CHAT.value:
|
elif llm_type == LLMType.CHAT.value:
|
||||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||||
|
elif llm_type == LLMType.RERANK:
|
||||||
|
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
||||||
else:
|
else:
|
||||||
assert False, "LLM type error"
|
assert False, "LLM type error"
|
||||||
|
|
||||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||||
if model_config: model_config = model_config.to_dict()
|
if model_config: model_config = model_config.to_dict()
|
||||||
if not model_config:
|
if not model_config:
|
||||||
if llm_type == LLMType.EMBEDDING.value:
|
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||||
llm = LLMService.query(llm_name=llm_name)
|
llm = LLMService.query(llm_name=llm_name)
|
||||||
if llm and llm[0].fid in ["Youdao", "FastEmbed", "DeepSeek"]:
|
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
||||||
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
|
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
|
||||||
if not model_config:
|
if not model_config:
|
||||||
if llm_name == "flag-embedding":
|
if llm_name == "flag-embedding":
|
||||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
|
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
|
||||||
"llm_name": llm_name, "api_base": ""}
|
"llm_name": llm_name, "api_base": ""}
|
||||||
else:
|
else:
|
||||||
|
if not mdlnm:
|
||||||
|
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||||
raise LookupError("Model({}) not authorized".format(mdlnm))
|
raise LookupError("Model({}) not authorized".format(mdlnm))
|
||||||
|
|
||||||
if llm_type == LLMType.EMBEDDING.value:
|
if llm_type == LLMType.EMBEDDING.value:
|
||||||
@ -96,6 +100,12 @@ class TenantLLMService(CommonService):
|
|||||||
return EmbeddingModel[model_config["llm_factory"]](
|
return EmbeddingModel[model_config["llm_factory"]](
|
||||||
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||||
|
|
||||||
|
if llm_type == LLMType.RERANK:
|
||||||
|
if model_config["llm_factory"] not in RerankModel:
|
||||||
|
return
|
||||||
|
return RerankModel[model_config["llm_factory"]](
|
||||||
|
model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||||
|
|
||||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||||
if model_config["llm_factory"] not in CvModel:
|
if model_config["llm_factory"] not in CvModel:
|
||||||
return
|
return
|
||||||
@ -125,14 +135,20 @@ class TenantLLMService(CommonService):
|
|||||||
mdlnm = tenant.img2txt_id
|
mdlnm = tenant.img2txt_id
|
||||||
elif llm_type == LLMType.CHAT.value:
|
elif llm_type == LLMType.CHAT.value:
|
||||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||||
|
elif llm_type == LLMType.RERANK:
|
||||||
|
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||||
else:
|
else:
|
||||||
assert False, "LLM type error"
|
assert False, "LLM type error"
|
||||||
|
|
||||||
num = 0
|
num = 0
|
||||||
|
try:
|
||||||
for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
|
for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm):
|
||||||
num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
|
num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\
|
||||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
|
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
|
||||||
.execute()
|
.execute()
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
return num
|
return num
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -176,6 +192,14 @@ class LLMBundle(object):
|
|||||||
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
|
||||||
return emd, used_tokens
|
return emd, used_tokens
|
||||||
|
|
||||||
|
def similarity(self, query: str, texts: list):
|
||||||
|
sim, used_tokens = self.mdl.similarity(query, texts)
|
||||||
|
if not TenantLLMService.increase_usage(
|
||||||
|
self.tenant_id, self.llm_type, used_tokens):
|
||||||
|
database_logger.error(
|
||||||
|
"Can't update token usage for {}/RERANK".format(self.tenant_id))
|
||||||
|
return sim, used_tokens
|
||||||
|
|
||||||
def describe(self, image, max_tokens=300):
|
def describe(self, image, max_tokens=300):
|
||||||
txt, used_tokens = self.mdl.describe(image, max_tokens)
|
txt, used_tokens = self.mdl.describe(image, max_tokens)
|
||||||
if not TenantLLMService.increase_usage(
|
if not TenantLLMService.increase_usage(
|
||||||
|
@ -93,6 +93,7 @@ class TenantService(CommonService):
|
|||||||
cls.model.name,
|
cls.model.name,
|
||||||
cls.model.llm_id,
|
cls.model.llm_id,
|
||||||
cls.model.embd_id,
|
cls.model.embd_id,
|
||||||
|
cls.model.rerank_id,
|
||||||
cls.model.asr_id,
|
cls.model.asr_id,
|
||||||
cls.model.img2txt_id,
|
cls.model.img2txt_id,
|
||||||
cls.model.parser_ids,
|
cls.model.parser_ids,
|
||||||
|
@ -89,9 +89,22 @@ default_llm = {
|
|||||||
},
|
},
|
||||||
"DeepSeek": {
|
"DeepSeek": {
|
||||||
"chat_model": "deepseek-chat",
|
"chat_model": "deepseek-chat",
|
||||||
|
"embedding_model": "",
|
||||||
|
"image2text_model": "",
|
||||||
|
"asr_model": "",
|
||||||
|
},
|
||||||
|
"VolcEngine": {
|
||||||
|
"chat_model": "",
|
||||||
|
"embedding_model": "",
|
||||||
|
"image2text_model": "",
|
||||||
|
"asr_model": "",
|
||||||
|
},
|
||||||
|
"BAAI": {
|
||||||
|
"chat_model": "",
|
||||||
"embedding_model": "BAAI/bge-large-zh-v1.5",
|
"embedding_model": "BAAI/bge-large-zh-v1.5",
|
||||||
"image2text_model": "",
|
"image2text_model": "",
|
||||||
"asr_model": "",
|
"asr_model": "",
|
||||||
|
"rerank_model": "BAAI/bge-reranker-v2-m3",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LLM = get_base_config("user_default_llm", {})
|
LLM = get_base_config("user_default_llm", {})
|
||||||
@ -104,7 +117,8 @@ if LLM_FACTORY not in default_llm:
|
|||||||
f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
|
f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
|
||||||
LLM_FACTORY = "Tongyi-Qianwen"
|
LLM_FACTORY = "Tongyi-Qianwen"
|
||||||
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
|
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
|
||||||
EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
|
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
|
||||||
|
RERANK_MDL = default_llm["BAAI"]["rerank_model"]
|
||||||
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
||||||
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
||||||
|
|
||||||
|
@ -16,18 +16,19 @@
|
|||||||
from .embedding_model import *
|
from .embedding_model import *
|
||||||
from .chat_model import *
|
from .chat_model import *
|
||||||
from .cv_model import *
|
from .cv_model import *
|
||||||
|
from .rerank_model import *
|
||||||
|
|
||||||
|
|
||||||
EmbeddingModel = {
|
EmbeddingModel = {
|
||||||
"Ollama": OllamaEmbed,
|
"Ollama": OllamaEmbed,
|
||||||
"OpenAI": OpenAIEmbed,
|
"OpenAI": OpenAIEmbed,
|
||||||
"Xinference": XinferenceEmbed,
|
"Xinference": XinferenceEmbed,
|
||||||
"Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed,
|
"Tongyi-Qianwen": DefaultEmbedding,#QWenEmbed,
|
||||||
"ZHIPU-AI": ZhipuEmbed,
|
"ZHIPU-AI": ZhipuEmbed,
|
||||||
"FastEmbed": FastEmbed,
|
"FastEmbed": FastEmbed,
|
||||||
"Youdao": YoudaoEmbed,
|
"Youdao": YoudaoEmbed,
|
||||||
"DeepSeek": DefaultEmbedding,
|
"BaiChuan": BaiChuanEmbed,
|
||||||
"BaiChuan": BaiChuanEmbed
|
"BAAI": DefaultEmbedding
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -52,3 +53,9 @@ ChatModel = {
|
|||||||
"BaiChuan": BaiChuanChat
|
"BaiChuan": BaiChuanChat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
RerankModel = {
|
||||||
|
"BAAI": DefaultRerank,
|
||||||
|
"Jina": JinaRerank,
|
||||||
|
"Youdao": YoudaoRerank,
|
||||||
|
}
|
||||||
|
@ -13,8 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import re
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from zhipuai import ZhipuAI
|
from zhipuai import ZhipuAI
|
||||||
import os
|
import os
|
||||||
@ -26,21 +28,9 @@ from FlagEmbedding import FlagModel
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
|
from api.utils.file_utils import get_home_cache_dir
|
||||||
from rag.utils import num_tokens_from_string, truncate
|
from rag.utils import num_tokens_from_string, truncate
|
||||||
|
|
||||||
try:
|
|
||||||
flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
|
|
||||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
|
||||||
use_fp16=torch.cuda.is_available())
|
|
||||||
except Exception as e:
|
|
||||||
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
|
||||||
local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
|
|
||||||
local_dir_use_symlinks=False)
|
|
||||||
flag_model = FlagModel(model_dir,
|
|
||||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
|
||||||
use_fp16=torch.cuda.is_available())
|
|
||||||
|
|
||||||
|
|
||||||
class Base(ABC):
|
class Base(ABC):
|
||||||
def __init__(self, key, model_name):
|
def __init__(self, key, model_name):
|
||||||
@ -54,7 +44,9 @@ class Base(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class DefaultEmbedding(Base):
|
class DefaultEmbedding(Base):
|
||||||
def __init__(self, *args, **kwargs):
|
_model = None
|
||||||
|
|
||||||
|
def __init__(self, key, model_name, **kwargs):
|
||||||
"""
|
"""
|
||||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||||
|
|
||||||
@ -66,7 +58,18 @@ class DefaultEmbedding(Base):
|
|||||||
^_-
|
^_-
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.model = flag_model
|
if not DefaultEmbedding._model:
|
||||||
|
try:
|
||||||
|
self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
||||||
|
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||||
|
use_fp16=torch.cuda.is_available())
|
||||||
|
except Exception as e:
|
||||||
|
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
||||||
|
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
||||||
|
local_dir_use_symlinks=False)
|
||||||
|
self._model = FlagModel(model_dir,
|
||||||
|
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||||
|
use_fp16=torch.cuda.is_available())
|
||||||
|
|
||||||
def encode(self, texts: list, batch_size=32):
|
def encode(self, texts: list, batch_size=32):
|
||||||
texts = [truncate(t, 2048) for t in texts]
|
texts = [truncate(t, 2048) for t in texts]
|
||||||
@ -75,12 +78,12 @@ class DefaultEmbedding(Base):
|
|||||||
token_count += num_tokens_from_string(t)
|
token_count += num_tokens_from_string(t)
|
||||||
res = []
|
res = []
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
|
res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
|
||||||
return np.array(res), token_count
|
return np.array(res), token_count
|
||||||
|
|
||||||
def encode_queries(self, text: str):
|
def encode_queries(self, text: str):
|
||||||
token_count = num_tokens_from_string(text)
|
token_count = num_tokens_from_string(text)
|
||||||
return self.model.encode_queries([text]).tolist()[0], token_count
|
return self._model.encode_queries([text]).tolist()[0], token_count
|
||||||
|
|
||||||
|
|
||||||
class OpenAIEmbed(Base):
|
class OpenAIEmbed(Base):
|
||||||
@ -189,6 +192,8 @@ class OllamaEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class FastEmbed(Base):
|
class FastEmbed(Base):
|
||||||
|
_model = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
key: Optional[str] = None,
|
key: Optional[str] = None,
|
||||||
@ -198,6 +203,7 @@ class FastEmbed(Base):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
from fastembed import TextEmbedding
|
from fastembed import TextEmbedding
|
||||||
|
if not FastEmbed._model:
|
||||||
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||||
|
|
||||||
def encode(self, texts: list, batch_size=32):
|
def encode(self, texts: list, batch_size=32):
|
||||||
@ -265,3 +271,29 @@ class YoudaoEmbed(Base):
|
|||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
embds = YoudaoEmbed._client.encode([text])
|
embds = YoudaoEmbed._client.encode([text])
|
||||||
return np.array(embds[0]), num_tokens_from_string(text)
|
return np.array(embds[0]), num_tokens_from_string(text)
|
||||||
|
|
||||||
|
|
||||||
|
class JinaEmbed(Base):
|
||||||
|
def __init__(self, key, model_name="jina-embeddings-v2-base-zh",
|
||||||
|
base_url="https://api.jina.ai/v1/embeddings"):
|
||||||
|
|
||||||
|
self.base_url = "https://api.jina.ai/v1/embeddings"
|
||||||
|
self.headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {key}"
|
||||||
|
}
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
def encode(self, texts: list, batch_size=None):
|
||||||
|
texts = [truncate(t, 8196) for t in texts]
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"input": texts,
|
||||||
|
'encoding_type': 'float'
|
||||||
|
}
|
||||||
|
res = requests.post(self.base_url, headers=self.headers, json=data)
|
||||||
|
return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
|
||||||
|
|
||||||
|
def encode_queries(self, text):
|
||||||
|
embds, cnt = self.encode([text])
|
||||||
|
return np.array(embds[0]), cnt
|
113
rag/llm/rerank_model.py
Normal file
113
rag/llm/rerank_model.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import re
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
from FlagEmbedding import FlagReranker
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
import os
|
||||||
|
from abc import ABC
|
||||||
|
import numpy as np
|
||||||
|
from api.utils.file_utils import get_home_cache_dir
|
||||||
|
from rag.utils import num_tokens_from_string, truncate
|
||||||
|
|
||||||
|
|
||||||
|
class Base(ABC):
|
||||||
|
def __init__(self, key, model_name):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def similarity(self, query: str, texts: list):
|
||||||
|
raise NotImplementedError("Please implement encode method!")
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultRerank(Base):
|
||||||
|
_model = None
|
||||||
|
|
||||||
|
def __init__(self, key, model_name, **kwargs):
|
||||||
|
"""
|
||||||
|
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||||
|
|
||||||
|
For Linux:
|
||||||
|
export HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
|
||||||
|
For Windows:
|
||||||
|
Good luck
|
||||||
|
^_-
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not DefaultRerank._model:
|
||||||
|
try:
|
||||||
|
self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
||||||
|
use_fp16=torch.cuda.is_available())
|
||||||
|
except Exception as e:
|
||||||
|
self._model = snapshot_download(repo_id=model_name,
|
||||||
|
local_dir=os.path.join(get_home_cache_dir(),
|
||||||
|
re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
||||||
|
local_dir_use_symlinks=False)
|
||||||
|
self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name),
|
||||||
|
use_fp16=torch.cuda.is_available())
|
||||||
|
|
||||||
|
def similarity(self, query: str, texts: list):
|
||||||
|
pairs = [(query,truncate(t, 2048)) for t in texts]
|
||||||
|
token_count = 0
|
||||||
|
for _, t in pairs:
|
||||||
|
token_count += num_tokens_from_string(t)
|
||||||
|
batch_size = 32
|
||||||
|
res = []
|
||||||
|
for i in range(0, len(pairs), batch_size):
|
||||||
|
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
|
||||||
|
res.extend(scores)
|
||||||
|
return np.array(res), token_count
|
||||||
|
|
||||||
|
|
||||||
|
class JinaRerank(Base):
|
||||||
|
def __init__(self, key, model_name="jina-reranker-v1-base-en",
|
||||||
|
base_url="https://api.jina.ai/v1/rerank"):
|
||||||
|
self.base_url = "https://api.jina.ai/v1/rerank"
|
||||||
|
self.headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {key}"
|
||||||
|
}
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
def similarity(self, query: str, texts: list):
|
||||||
|
texts = [truncate(t, 8196) for t in texts]
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"query": query,
|
||||||
|
"documents": texts,
|
||||||
|
"top_n": len(texts)
|
||||||
|
}
|
||||||
|
res = requests.post(self.base_url, headers=self.headers, json=data)
|
||||||
|
return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
|
||||||
|
|
||||||
|
|
||||||
|
class YoudaoRerank(DefaultRerank):
|
||||||
|
_model = None
|
||||||
|
|
||||||
|
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
||||||
|
from BCEmbedding import RerankerModel
|
||||||
|
if not YoudaoRerank._model:
|
||||||
|
try:
|
||||||
|
print("LOADING BCE...")
|
||||||
|
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
|
||||||
|
get_home_cache_dir(),
|
||||||
|
re.sub(r"^[a-zA-Z]+/", "", model_name)))
|
||||||
|
except Exception as e:
|
||||||
|
YoudaoRerank._model = RerankerModel(
|
||||||
|
model_name_or_path=model_name.replace(
|
||||||
|
"maidalun1020", "InfiniFlow"))
|
||||||
|
|
@ -54,7 +54,8 @@ class EsQueryer:
|
|||||||
if not self.isChinese(txt):
|
if not self.isChinese(txt):
|
||||||
tks = rag_tokenizer.tokenize(txt).split(" ")
|
tks = rag_tokenizer.tokenize(txt).split(" ")
|
||||||
tks_w = self.tw.weights(tks)
|
tks_w = self.tw.weights(tks)
|
||||||
q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
|
tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w]
|
||||||
|
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
||||||
for i in range(1, len(tks_w)):
|
for i in range(1, len(tks_w)):
|
||||||
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
||||||
if not q:
|
if not q:
|
||||||
@ -136,7 +137,11 @@ class EsQueryer:
|
|||||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||||
import numpy as np
|
import numpy as np
|
||||||
sims = CosineSimilarity([avec], bvecs)
|
sims = CosineSimilarity([avec], bvecs)
|
||||||
|
tksim = self.token_similarity(atks, btkss)
|
||||||
|
return np.array(sims[0]) * vtweight + \
|
||||||
|
np.array(tksim) * tkweight, tksim, sims[0]
|
||||||
|
|
||||||
|
def token_similarity(self, atks, btkss):
|
||||||
def toDict(tks):
|
def toDict(tks):
|
||||||
d = {}
|
d = {}
|
||||||
if isinstance(tks, str):
|
if isinstance(tks, str):
|
||||||
@ -149,9 +154,7 @@ class EsQueryer:
|
|||||||
|
|
||||||
atks = toDict(atks)
|
atks = toDict(atks)
|
||||||
btkss = [toDict(tks) for tks in btkss]
|
btkss = [toDict(tks) for tks in btkss]
|
||||||
tksim = [self.similarity(atks, btks) for btks in btkss]
|
return [self.similarity(atks, btks) for btks in btkss]
|
||||||
return np.array(sims[0]) * vtweight + \
|
|
||||||
np.array(tksim) * tkweight, tksim, sims[0]
|
|
||||||
|
|
||||||
def similarity(self, qtwt, dtwt):
|
def similarity(self, qtwt, dtwt):
|
||||||
if isinstance(dtwt, type("")):
|
if isinstance(dtwt, type("")):
|
||||||
|
@ -241,11 +241,14 @@ class RagTokenizer:
|
|||||||
|
|
||||||
return self.score_(res[::-1])
|
return self.score_(res[::-1])
|
||||||
|
|
||||||
|
def english_normalize_(self, tks):
|
||||||
|
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
|
||||||
|
|
||||||
def tokenize(self, line):
|
def tokenize(self, line):
|
||||||
line = self._strQ2B(line).lower()
|
line = self._strQ2B(line).lower()
|
||||||
line = self._tradi2simp(line)
|
line = self._tradi2simp(line)
|
||||||
zh_num = len([1 for c in line if is_chinese(c)])
|
zh_num = len([1 for c in line if is_chinese(c)])
|
||||||
if zh_num < len(line) * 0.2:
|
if zh_num == 0:
|
||||||
return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
|
return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])
|
||||||
|
|
||||||
arr = re.split(self.SPLIT_CHAR, line)
|
arr = re.split(self.SPLIT_CHAR, line)
|
||||||
@ -293,7 +296,7 @@ class RagTokenizer:
|
|||||||
|
|
||||||
i = e + 1
|
i = e + 1
|
||||||
|
|
||||||
res = " ".join(res)
|
res = " ".join(self.english_normalize_(res))
|
||||||
if self.DEBUG:
|
if self.DEBUG:
|
||||||
print("[TKS]", self.merge_(res))
|
print("[TKS]", self.merge_(res))
|
||||||
return self.merge_(res)
|
return self.merge_(res)
|
||||||
@ -336,7 +339,7 @@ class RagTokenizer:
|
|||||||
|
|
||||||
res.append(stk)
|
res.append(stk)
|
||||||
|
|
||||||
return " ".join(res)
|
return " ".join(self.english_normalize_(res))
|
||||||
|
|
||||||
|
|
||||||
def is_chinese(s):
|
def is_chinese(s):
|
||||||
|
@ -71,8 +71,8 @@ class Dealer:
|
|||||||
|
|
||||||
s = Search()
|
s = Search()
|
||||||
pg = int(req.get("page", 1)) - 1
|
pg = int(req.get("page", 1)) - 1
|
||||||
ps = int(req.get("size", 1000))
|
|
||||||
topk = int(req.get("topk", 1024))
|
topk = int(req.get("topk", 1024))
|
||||||
|
ps = int(req.get("size", topk))
|
||||||
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
||||||
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
|
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
|
||||||
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
|
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
|
||||||
@ -311,6 +311,26 @@ class Dealer:
|
|||||||
ins_tw, tkweight, vtweight)
|
ins_tw, tkweight, vtweight)
|
||||||
return sim, tksim, vtsim
|
return sim, tksim, vtsim
|
||||||
|
|
||||||
|
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
|
||||||
|
vtweight=0.7, cfield="content_ltks"):
|
||||||
|
_, keywords = self.qryr.question(query)
|
||||||
|
|
||||||
|
for i in sres.ids:
|
||||||
|
if isinstance(sres.field[i].get("important_kwd", []), str):
|
||||||
|
sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
|
||||||
|
ins_tw = []
|
||||||
|
for i in sres.ids:
|
||||||
|
content_ltks = sres.field[i][cfield].split(" ")
|
||||||
|
title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t]
|
||||||
|
important_kwd = sres.field[i].get("important_kwd", [])
|
||||||
|
tks = content_ltks + title_tks + important_kwd
|
||||||
|
ins_tw.append(tks)
|
||||||
|
|
||||||
|
tksim = self.qryr.token_similarity(keywords, ins_tw)
|
||||||
|
vtsim,_ = rerank_mdl.similarity(" ".join(keywords), [rmSpace(" ".join(tks)) for tks in ins_tw])
|
||||||
|
|
||||||
|
return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim
|
||||||
|
|
||||||
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
||||||
return self.qryr.hybrid_similarity(ans_embd,
|
return self.qryr.hybrid_similarity(ans_embd,
|
||||||
ins_embd,
|
ins_embd,
|
||||||
@ -318,15 +338,20 @@ class Dealer:
|
|||||||
rag_tokenizer.tokenize(inst).split(" "))
|
rag_tokenizer.tokenize(inst).split(" "))
|
||||||
|
|
||||||
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
|
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
|
||||||
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
|
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None):
|
||||||
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
||||||
if not question:
|
if not question:
|
||||||
return ranks
|
return ranks
|
||||||
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
|
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
|
||||||
"question": question, "vector": True, "topk": top,
|
"question": question, "vector": True, "topk": top,
|
||||||
"similarity": similarity_threshold}
|
"similarity": similarity_threshold,
|
||||||
|
"available_int": 1}
|
||||||
sres = self.search(req, index_name(tenant_id), embd_mdl)
|
sres = self.search(req, index_name(tenant_id), embd_mdl)
|
||||||
|
|
||||||
|
if rerank_mdl:
|
||||||
|
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
|
||||||
|
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
|
||||||
|
else:
|
||||||
sim, tsim, vsim = self.rerank(
|
sim, tsim, vsim = self.rerank(
|
||||||
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
|
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
|
||||||
idx = np.argsort(sim * -1)
|
idx = np.argsort(sim * -1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user