Add bce-embedding and fastembed (#383)

### What problem does this PR solve?


Issue link: #326

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
KevinHuSh 2024-04-16 16:42:19 +08:00 committed by GitHub
parent a7be5d4e8b
commit 890561703b
15 changed files with 99 additions and 37 deletions

View File

@@ -55,6 +55,8 @@
 ## 📌 Latest Features
+- 2024-04-16 Add the embedding model 'bce-embedding-base_v1' from [QAnything](https://github.com/netease-youdao/QAnything).
+- 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed), which is designed for lightweight, fast embedding.
 - 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
 - 2024-04-10 Add a new layout recognition model for analyzing Laws documents.
 - 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.
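
As a point of reference for what the FastEmbed integration wraps, here is a minimal usage sketch of the library on its own; the model name is only illustrative and is not something this PR configures:

```python
# Minimal fastembed sketch (assumes `pip install fastembed`); the model below
# is an illustrative choice and is downloaded on first use.
from fastembed import TextEmbedding

model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")
texts = ["RAGFlow chunks and indexes documents.", "FastEmbed is light and fast."]
vectors = list(model.embed(texts))  # embed() yields one numpy array per text
print(len(vectors), vectors[0].shape)
```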

View File

@@ -55,6 +55,8 @@
 ## 📌 Latest Features
+- 2024-04-16 Add the embedding model 'bce-embedding-base_v1' from [QAnything](https://github.com/netease-youdao/QAnything).
+- 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed), which is designed for lightweight, fast embedding.
 - 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
 - 2024-04-10 Add a new layout recognition model for the 'Laws' method.
 - 2024-04-08 Support local LLM deployment via [Ollama](./docs/ollama.md).

View File

@@ -55,6 +55,8 @@
 ## 📌 Latest Features
+- 2024-04-16 Add the embedding model [bce-embedding-base_v1 from QAnything](https://github.com/netease-youdao/QAnything).
+- 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed), designed for lightweight, fast embedding.
 - 2024-04-11 Support local LLM deployment via [Xinference](./docs/xinference.md).
 - 2024-04-10 Add an underlying model for Laws layout analysis.
 - 2024-04-08 Support local LLM deployment via [Ollama](./docs/ollama.md).

View File

@@ -252,7 +252,7 @@ def retrieval_test():
         return get_data_error_result(retmsg="Knowledgebase not found!")
     embd_mdl = TenantLLMService.model_instance(
-        kb.tenant_id, LLMType.EMBEDDING.value)
+        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
     ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
                                   vector_similarity_weight, top, doc_ids)
     for c in ranks["chunks"]:
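
The point of passing `llm_name=kb.embd_id` is that query-time embeddings must come from the same model that embedded the knowledge base's chunks at index time; otherwise the vectors are not comparable. A hedged sketch of the invariant this enforces:

```python
# Sketch of the rule behind the change above: resolve the embedder from the
# knowledge base's own embd_id rather than the tenant-wide default.
def embedder_for_kb(kb, tenant_default: str) -> str:
    # kb.embd_id is the model the KB was indexed with (field used in this PR);
    # falling back to tenant_default could silently mix vector spaces.
    return kb.embd_id or tenant_default
```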

View File

@@ -15,6 +15,7 @@
 #
 import base64
+import os
 import pathlib
 import re

View File

@@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
 def factories():
     try:
         fac = LLMFactoriesService.get_all()
-        return get_json_result(data=[f.to_dict() for f in fac])
+        return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]])
     except Exception as e:
         return server_error_response(e)
@@ -174,7 +174,7 @@ def list():
     llms = [m.to_dict()
             for m in llms if m.status == StatusEnum.VALID.value]
     for m in llms:
-        m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
+        m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"]
     llm_set = set([m["llm_name"] for m in llms])
     for o in objs:
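
These two endpoints encode the same policy from opposite ends: QAnything and FastEmbed ship with the server and need no API key, so they are hidden from the user-configurable factory list but always reported as available. A hedged restatement of the rule:

```python
# Sketch of the availability policy above; names are taken from this diff.
KEY_FREE = {"QAnything", "FastEmbed"}  # bundled local factories, no API key

def is_available(llm: dict, configured_factories: set) -> bool:
    return (llm["fid"] in configured_factories              # tenant added a key
            or llm["llm_name"].lower() == "flag-embedding"  # built-in default
            or llm["fid"] in KEY_FREE)                      # always usable locally
```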

View File

@@ -18,7 +18,7 @@ import time
 import uuid
 from api.db import LLMType, UserTenantRole
-from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
+from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
 from api.db.services import UserService
 from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
 from api.db.services.user_service import TenantService, UserTenantService
@@ -114,12 +114,16 @@ factory_infos = [{
     "logo": "",
     "tags": "TEXT EMBEDDING",
     "status": "1",
-},
-{
+}, {
     "name": "Xinference",
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
     "status": "1",
+},{
+    "name": "QAnything",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
 },
 # {
 #     "name": "文心一言",
@@ -254,12 +258,6 @@ def init_llm_factory():
     "tags": "LLM,CHAT,",
     "max_tokens": 7900,
     "model_type": LLMType.CHAT.value
-}, {
-    "fid": factory_infos[4]["name"],
-    "llm_name": "flag-embedding",
-    "tags": "TEXT EMBEDDING,",
-    "max_tokens": 128 * 1000,
-    "model_type": LLMType.EMBEDDING.value
 }, {
     "fid": factory_infos[4]["name"],
     "llm_name": "moonshot-v1-32k",
@@ -325,6 +323,14 @@ def init_llm_factory():
     "max_tokens": 2147483648,
     "model_type": LLMType.EMBEDDING.value
 },
+# ------------------------ QAnything -----------------------
+{
+    "fid": factory_infos[7]["name"],
+    "llm_name": "maidalun1020/bce-embedding-base_v1",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 512,
+    "model_type": LLMType.EMBEDDING.value
+},
 ]
 for info in factory_infos:
     try:
@@ -337,8 +343,10 @@ def init_llm_factory():
     except Exception as e:
         pass
-    LLMFactoriesService.filter_delete([LLMFactories.name=="Local"])
-    LLMService.filter_delete([LLM.fid=="Local"])
+    LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
+    LLMService.filter_delete([LLM.fid == "Local"])
+    LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
+    TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
     """
     drop table llm;
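
One fragility worth noting in the seed data: `factory_infos[7]["name"]` binds the bce-embedding row to QAnything purely by list position, so inserting a factory earlier in the list would silently re-point existing model rows. A hedged sketch of a name-based lookup as an alternative (hypothetical helper, not part of this PR):

```python
from api.db import LLMType  # repo import, as in the diff above

def factory_by_name(factory_infos: list, name: str) -> dict:
    """Hypothetical helper: resolve a seed factory by name, not list index."""
    matches = [f for f in factory_infos if f["name"] == name]
    assert len(matches) == 1, f"expected exactly one factory named {name!r}"
    return matches[0]

bce_row = {
    "fid": factory_by_name(factory_infos, "QAnything")["name"],
    "llm_name": "maidalun1020/bce-embedding-base_v1",
    "tags": "TEXT EMBEDDING,",
    "max_tokens": 512,
    "model_type": LLMType.EMBEDDING.value,
}
```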

View File

@@ -80,8 +80,12 @@ def chat(dialog, messages, **kwargs):
             raise LookupError("LLM(%s) not found" % dialog.llm_id)
         max_tokens = 1024
     else: max_tokens = llm[0].max_tokens
+    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
+    embd_nms = list(set([kb.embd_id for kb in kbs]))
+    assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
     questions = [m["content"] for m in messages if m["role"] == "user"]
-    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
     chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
     prompt_config = dialog.prompt_config
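
Since a dialog can retrieve from several knowledge bases at once, their chunks must live in one comparable vector space; the new assert rejects mixed setups up front instead of returning silently wrong similarities. A hedged restatement:

```python
# Sketch of the invariant asserted above: one embedding model across all
# knowledge bases attached to a dialog.
def single_embedding_model(embd_ids: list) -> str:
    unique = set(embd_ids)  # e.g. {kb.embd_id for each attached KB}
    assert len(unique) == 1, "Knowledge bases use different embedding models."
    return unique.pop()
```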

View File

@@ -66,7 +66,7 @@ class TenantLLMService(CommonService):
             raise LookupError("Tenant not found")
         if llm_type == LLMType.EMBEDDING.value:
-            mdlnm = tenant.embd_id
+            mdlnm = tenant.embd_id if not llm_name else llm_name
         elif llm_type == LLMType.SPEECH2TEXT.value:
             mdlnm = tenant.asr_id
         elif llm_type == LLMType.IMAGE2TEXT.value:
@@ -77,9 +77,14 @@ class TenantLLMService(CommonService):
             assert False, "LLM type error"
         model_config = cls.get_api_key(tenant_id, mdlnm)
+        if model_config: model_config = model_config.to_dict()
         if not model_config:
-            raise LookupError("Model({}) not authorized".format(mdlnm))
-        model_config = model_config.to_dict()
+            if llm_type == LLMType.EMBEDDING.value:
+                llm = LLMService.query(llm_name=llm_name)
+                if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
+                    model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
+            if not model_config: raise LookupError("Model({}) not authorized".format(mdlnm))
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
                 return
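
This is where the key-free factories are actually honored: when no tenant credential exists for an embedding model, a synthetic empty-credential config is built for QAnything and FastEmbed instead of raising. A hedged sketch of just that branch:

```python
# Sketch of the fallback above; field names mirror the diff. Returns None for
# factories that genuinely require an API key, so the caller can still raise.
def keyfree_config(factory: str, llm_name: str):
    if factory in ("QAnything", "FastEmbed"):  # bundled, no credentials needed
        return {"llm_factory": factory, "api_key": "",
                "llm_name": llm_name, "api_base": ""}
    return None
```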

View File

@@ -41,7 +41,7 @@ class TaskService(CommonService):
         Document.size,
         Knowledgebase.tenant_id,
         Knowledgebase.language,
-        Tenant.embd_id,
+        Knowledgebase.embd_id,
         Tenant.img2txt_id,
         Tenant.asr_id,
         cls.model.update_time]

View File

@@ -24,8 +24,8 @@ EmbeddingModel = {
     "Xinference": XinferenceEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
-    "Moonshot": HuEmbedding,
-    "FastEmbed": FastEmbed
+    "FastEmbed": FastEmbed,
+    "QAnything": QAnythingEmbed
 }
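
The registry maps a factory name (as stored in `model_config["llm_factory"]`) to its embedder class; dropping the bogus `"Moonshot": HuEmbedding` entry pairs with the seed-data cleanup above. A hedged dispatch sketch:

```python
# Sketch of how the registry is consumed; the constructor kwargs follow the
# QAnythingEmbed signature added in this PR.
factory = "QAnything"  # e.g. model_config["llm_factory"]
if factory in EmbeddingModel:
    mdl = EmbeddingModel[factory](
        key="", model_name="maidalun1020/bce-embedding-base_v1")
```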

View File

@@ -20,7 +20,6 @@ from abc import ABC
 from ollama import Client
 import dashscope
 from openai import OpenAI
-from fastembed import TextEmbedding
 from FlagEmbedding import FlagModel
 import torch
 import numpy as np
@@ -28,6 +27,7 @@ import numpy as np
 from api.utils.file_utils import get_project_base_directory
 from rag.utils import num_tokens_from_string
 try:
     flag_model = FlagModel(os.path.join(
         get_project_base_directory(),
@@ -82,8 +82,10 @@ class HuEmbedding(Base):
 class OpenAIEmbed(Base):
-    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url="https://api.openai.com/v1"
+    def __init__(self, key, model_name="text-embedding-ada-002",
+                 base_url="https://api.openai.com/v1"):
+        if not base_url:
+            base_url = "https://api.openai.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
@@ -183,10 +185,12 @@ class FastEmbed(Base):
             threads: Optional[int] = None,
             **kwargs,
     ):
+        from fastembed import TextEmbedding
         self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

     def encode(self, texts: list, batch_size=32):
-        # Using the internal tokenizer to encode the texts and get the total number of tokens
+        # Using the internal tokenizer to encode the texts and get the total
+        # number of tokens
         encodings = self._model.model.tokenizer.encode_batch(texts)
         total_tokens = sum(len(e) for e in encodings)
@@ -195,7 +199,8 @@ class FastEmbed(Base):
         return np.array(embeddings), total_tokens

     def encode_queries(self, text: str):
-        # Using the internal tokenizer to encode the texts and get the total number of tokens
+        # Using the internal tokenizer to encode the texts and get the total
+        # number of tokens
         encoding = self._model.model.tokenizer.encode(text)
         embedding = next(self._model.query_embed(text)).tolist()
@@ -218,3 +223,33 @@ class XinferenceEmbed(Base):
             model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens
+
+
+class QAnythingEmbed(Base):
+    _client = None
+
+    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
+        from BCEmbedding import EmbeddingModel as qanthing
+        if not QAnythingEmbed._client:
+            try:
+                print("LOADING BCE...")
+                QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
+                    get_project_base_directory(),
+                    "rag/res/bce-embedding-base_v1"))
+            except Exception as e:
+                QAnythingEmbed._client = qanthing(
+                    model_name_or_path=model_name.replace(
+                        "maidalun1020", "InfiniFlow"))
+
+    def encode(self, texts: list, batch_size=10):
+        res = []
+        token_count = 0
+        for t in texts:
+            token_count += num_tokens_from_string(t)
+        for i in range(0, len(texts), batch_size):
+            embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
+            res.extend(embds)
+        return np.array(res), token_count
+
+    def encode_queries(self, text):
+        embds = QAnythingEmbed._client.encode([text])
+        return np.array(embds[0]), num_tokens_from_string(text)
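
For completeness, a minimal, hedged usage sketch of the new embedder: per the constructor above, it first tries a bundled copy under `rag/res/bce-embedding-base_v1` and otherwise lets BCEmbedding fetch the `InfiniFlow` mirror of the model.

```python
import numpy as np
from rag.llm.embedding_model import QAnythingEmbed  # class added in this PR

mdl = QAnythingEmbed()  # key is unused; the model is loaded once (singleton)
vecs, n_tokens = mdl.encode(["What is RAGFlow?", "A bce-embedding demo."])
print(vecs.shape, n_tokens)  # e.g. (2, 768) for bce-embedding-base_v1
qvec, _ = mdl.encode_queries("What is RAGFlow?")
assert isinstance(qvec, np.ndarray)
```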

View File

@@ -46,7 +46,7 @@ class Dealer:
             "k": topk,
             "similarity": sim,
             "num_candidates": topk * 2,
-            "query_vector": list(qv)
+            "query_vector": [float(v) for v in qv]
         }

     def search(self, req, idxnm, emb_mdl=None):
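
The cast looks cosmetic but is not: `qv` comes back from the embedder as a numpy array, and `list(qv)` yields `numpy.float32` items, which the standard JSON encoder (and hence, presumably, the Elasticsearch request body) rejects. A minimal repro:

```python
import json
import numpy as np

qv = np.array([0.1, 0.2], dtype=np.float32)
try:
    json.dumps({"query_vector": list(qv)})        # items are numpy.float32
except TypeError as e:
    print("fails:", e)                            # not JSON serializable
print(json.dumps({"query_vector": [float(v) for v in qv]}))  # plain floats: OK
```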

View File

@@ -244,8 +244,9 @@ def main(comm, mod):
     for _, r in rows.iterrows():
         callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
         try:
-            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
+            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
         except Exception as e:
+            traceback.print_exc()
             callback(prog=-1, msg=str(e))
             continue
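
This closes the loop with the task_service change above: each task row now carries its knowledge base's `embd_id` and language, so chunks are embedded at index time by exactly the model the KB will later query with. A brief sketch of the pairing:

```python
# Sketch: the same model name must resolve the embedder on both sides.
# Index side (this file): per-task row fields, from the task_service hunk.
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING,
                     llm_name=r["embd_id"], lang=r["language"])
# Query side (dialog_service hunk): LLMBundle(..., embd_nms[0]), where
# embd_nms[0] == kb.embd_id for every KB attached to the dialog.
```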

View File

@@ -132,3 +132,5 @@ xpinyin==0.7.6
 xxhash==3.4.1
 yarl==1.9.4
 zhipuai==2.0.1
+BCEmbedding
+loguru==0.7.2