Add bce-embedding and fastembed (#383)

### What problem does this PR solve?


Issue link:#326

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh 2024-04-16 16:42:19 +08:00 committed by GitHub
parent a7be5d4e8b
commit 890561703b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 99 additions and 37 deletions

View File

@ -55,6 +55,8 @@
## 📌 Latest Features
- 2024-04-16 Add an embedding model 'bce-embedding-base_v1' from [QAnything](https://github.com/netease-youdao/QAnything).
- 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed) is designed for light and speeding embedding.
- 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
- 2024-04-10 Add a new layout recognization model for analyzing Laws documentation.
- 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.

View File

@ -55,6 +55,8 @@
## 📌 最新の機能
- 2024-04-16 [QAnything](https://github.com/netease-youdao/QAnything) から埋め込みモデル「bce-embedding-base_v1」を追加します。
- 2024-04-16 [FastEmbed](https://github.com/qdrant/fastembed) は、軽量かつ高速な埋め込み用に設計されています。
- 2024-04-11 ローカル LLM デプロイメント用に [Xinference](./docs/xinference.md) をサポートします。
- 2024-04-10 メソッド「Laws」に新しいレイアウト認識モデルを追加します。
- 2024-04-08 [Ollama](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。

View File

@ -55,6 +55,8 @@
## 📌 新增功能
- 2024-04-16 添加嵌入模型 [QAnything的bce-embedding-base_v1](https://github.com/netease-youdao/QAnything) 。
- 2024-04-16 添加 [FastEmbed](https://github.com/qdrant/fastembed) 专为轻型和高速嵌入而设计。
- 2024-04-11 支持用 [Xinference](./docs/xinference.md) 本地化部署大模型。
- 2024-04-10 为Laws版面分析增加了底层模型。
- 2024-04-08 支持用 [Ollama](./docs/ollama.md) 本地化部署大模型。

View File

@ -252,7 +252,7 @@ def retrieval_test():
return get_data_error_result(retmsg="Knowledgebase not found!")
embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value)
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
vector_similarity_weight, top, doc_ids)
for c in ranks["chunks"]:

View File

@ -15,6 +15,7 @@
#
import base64
import os
import pathlib
import re

View File

@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
def factories():
try:
fac = LLMFactoriesService.get_all()
return get_json_result(data=[f.to_dict() for f in fac])
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]])
except Exception as e:
return server_error_response(e)
@ -174,7 +174,7 @@ def list():
llms = [m.to_dict()
for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"]
llm_set = set([m["llm_name"] for m in llms])
for o in objs:

View File

@ -18,7 +18,7 @@ import time
import uuid
from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
from api.db.services import UserService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
@ -114,12 +114,16 @@ factory_infos = [{
"logo": "",
"tags": "TEXT EMBEDDING",
"status": "1",
},
{
}, {
"name": "Xinference",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "QAnything",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},
# {
# "name": "文心一言",
@ -254,12 +258,6 @@ def init_llm_factory():
"tags": "LLM,CHAT,",
"max_tokens": 7900,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[4]["name"],
"llm_name": "flag-embedding",
"tags": "TEXT EMBEDDING,",
"max_tokens": 128 * 1000,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[4]["name"],
"llm_name": "moonshot-v1-32k",
@ -325,6 +323,14 @@ def init_llm_factory():
"max_tokens": 2147483648,
"model_type": LLMType.EMBEDDING.value
},
# ------------------------ QAnything -----------------------
{
"fid": factory_infos[7]["name"],
"llm_name": "maidalun1020/bce-embedding-base_v1",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
},
]
for info in factory_infos:
try:
@ -337,8 +343,10 @@ def init_llm_factory():
except Exception as e:
pass
LLMFactoriesService.filter_delete([LLMFactories.name=="Local"])
LLMService.filter_delete([LLM.fid=="Local"])
LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
LLMService.filter_delete([LLM.fid == "Local"])
LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
"""
drop table llm;

View File

@ -80,8 +80,12 @@ def chat(dialog, messages, **kwargs):
raise LookupError("LLM(%s) not found" % dialog.llm_id)
max_tokens = 1024
else: max_tokens = llm[0].max_tokens
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
questions = [m["content"] for m in messages if m["role"] == "user"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
prompt_config = dialog.prompt_config

View File

@ -66,7 +66,7 @@ class TenantLLMService(CommonService):
raise LookupError("Tenant not found")
if llm_type == LLMType.EMBEDDING.value:
mdlnm = tenant.embd_id
mdlnm = tenant.embd_id if not llm_name else llm_name
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value:
@ -77,9 +77,14 @@ class TenantLLMService(CommonService):
assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
if model_config: model_config = model_config.to_dict()
if not model_config:
raise LookupError("Model({}) not authorized".format(mdlnm))
model_config = model_config.to_dict()
if llm_type == LLMType.EMBEDDING.value:
llm = LLMService.query(llm_name=llm_name)
if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
if not model_config: raise LookupError("Model({}) not authorized".format(mdlnm))
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return

View File

@ -41,7 +41,7 @@ class TaskService(CommonService):
Document.size,
Knowledgebase.tenant_id,
Knowledgebase.language,
Tenant.embd_id,
Knowledgebase.embd_id,
Tenant.img2txt_id,
Tenant.asr_id,
cls.model.update_time]

View File

@ -24,8 +24,8 @@ EmbeddingModel = {
"Xinference": XinferenceEmbed,
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
"ZHIPU-AI": ZhipuEmbed,
"Moonshot": HuEmbedding,
"FastEmbed": FastEmbed
"FastEmbed": FastEmbed,
"QAnything": QAnythingEmbed
}

View File

@ -20,7 +20,6 @@ from abc import ABC
from ollama import Client
import dashscope
from openai import OpenAI
from fastembed import TextEmbedding
from FlagEmbedding import FlagModel
import torch
import numpy as np
@ -28,6 +27,7 @@ import numpy as np
from api.utils.file_utils import get_project_base_directory
from rag.utils import num_tokens_from_string
try:
flag_model = FlagModel(os.path.join(
get_project_base_directory(),
@ -82,8 +82,10 @@ class HuEmbedding(Base):
class OpenAIEmbed(Base):
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
def __init__(self, key, model_name="text-embedding-ada-002",
base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
@ -183,10 +185,12 @@ class FastEmbed(Base):
threads: Optional[int] = None,
**kwargs,
):
from fastembed import TextEmbedding
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
def encode(self, texts: list, batch_size=32):
# Using the internal tokenizer to encode the texts and get the total number of tokens
# Using the internal tokenizer to encode the texts and get the total
# number of tokens
encodings = self._model.model.tokenizer.encode_batch(texts)
total_tokens = sum(len(e) for e in encodings)
@ -195,7 +199,8 @@ class FastEmbed(Base):
return np.array(embeddings), total_tokens
def encode_queries(self, text: str):
# Using the internal tokenizer to encode the texts and get the total number of tokens
# Using the internal tokenizer to encode the texts and get the total
# number of tokens
encoding = self._model.model.tokenizer.encode(text)
embedding = next(self._model.query_embed(text)).tolist()
@ -218,3 +223,33 @@ class XinferenceEmbed(Base):
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
class QAnythingEmbed(Base):
_client = None
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
from BCEmbedding import EmbeddingModel as qanthing
if not QAnythingEmbed._client:
try:
print("LOADING BCE...")
QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
get_project_base_directory(),
"rag/res/bce-embedding-base_v1"))
except Exception as e:
QAnythingEmbed._client = qanthing(
model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow"))
def encode(self, texts: list, batch_size=10):
res = []
token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)
for i in range(0, len(texts), batch_size):
embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
res.extend(embds)
return np.array(res), token_count
def encode_queries(self, text):
embds = QAnythingEmbed._client.encode([text])
return np.array(embds[0]), num_tokens_from_string(text)

View File

@ -46,7 +46,7 @@ class Dealer:
"k": topk,
"similarity": sim,
"num_candidates": topk * 2,
"query_vector": list(qv)
"query_vector": [float(v) for v in qv]
}
def search(self, req, idxnm, emb_mdl=None):

View File

@ -244,8 +244,9 @@ def main(comm, mod):
for _, r in rows.iterrows():
callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
try:
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
except Exception as e:
traceback.print_stack(e)
callback(prog=-1, msg=str(e))
continue

View File

@ -132,3 +132,5 @@ xpinyin==0.7.6
xxhash==3.4.1
yarl==1.9.4
zhipuai==2.0.1
BCEmbedding
loguru==0.7.2