feat: FastEmbed embedding support (#291)

### Description

Following up on https://github.com/infiniflow/ragflow/pull/275, this PR
adds support for FastEmbed model configurations.

The options are not exhaustive. You can find the full list
[here](https://qdrant.github.io/fastembed/examples/Supported_Models/).

P.S. I ran into OOM issues when building the image.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: KevinHuSh <kevinhu.sh@gmail.com>
This commit is contained in:
Anush 2024-04-15 13:28:06 +05:30 committed by GitHub
parent e5a5b820a8
commit 826ad6a33a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 91 additions and 1 deletions

View File

@ -109,6 +109,11 @@ factory_infos = [{
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
}, {
"name": "FastEmbed",
"logo": "",
"tags": "TEXT EMBEDDING",
"status": "1",
},
{
"name": "Xinference",
@ -268,6 +273,58 @@ def init_llm_factory():
"max_tokens": 128 * 1000,
"model_type": LLMType.CHAT.value
},
# ------------------------ FastEmbed -----------------------
{
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-small-en-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-small-zh-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-base-en-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "BAAI/bge-large-en-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "sentence-transformers/all-MiniLM-L6-v2",
"tags": "TEXT EMBEDDING,",
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "nomic-ai/nomic-embed-text-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 8192,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "jinaai/jina-embeddings-v2-small-en",
"tags": "TEXT EMBEDDING,",
"max_tokens": 2147483648,
"model_type": LLMType.EMBEDDING.value
}, {
"fid": factory_infos[5]["name"],
"llm_name": "jinaai/jina-embeddings-v2-base-en",
"tags": "TEXT EMBEDDING,",
"max_tokens": 2147483648,
"model_type": LLMType.EMBEDDING.value
},
]
for info in factory_infos:
try:

View File

@ -24,7 +24,8 @@ EmbeddingModel = {
"Xinference": XinferenceEmbed,
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
"ZHIPU-AI": ZhipuEmbed,
"Moonshot": HuEmbedding
"Moonshot": HuEmbedding,
"FastEmbed": FastEmbed
}

View File

@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
from zhipuai import ZhipuAI
import os
from abc import ABC
from ollama import Client
import dashscope
from openai import OpenAI
from fastembed import TextEmbedding
from FlagEmbedding import FlagModel
import torch
import numpy as np
@ -172,6 +174,34 @@ class OllamaEmbed(Base):
return np.array(res["embedding"]), 128
class FastEmbed(Base):
    """Embedding backend powered by the local ``fastembed`` library.

    Runs ONNX embedding models on the local machine, so no API key or
    network access is required. Exposes the same ``encode`` /
    ``encode_queries`` contract as the other ``Base`` subclasses in this
    module: each returns ``(numpy array of embeddings, token count)``.
    """

    def __init__(
        self,
        key: Optional[str] = None,
        model_name: str = "BAAI/bge-small-en-v1.5",
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        **kwargs,
    ):
        # ``key`` exists only to keep the constructor signature uniform with
        # the API-based embedding classes; FastEmbed is local and ignores it.
        # ``cache_dir``/``threads`` are forwarded positionally to
        # TextEmbedding (model download cache location and ONNX thread count).
        self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

    def encode(self, texts: list, batch_size=32):
        # Using the internal tokenizer to encode the texts and get the total number of tokens.
        # NOTE(review): this reaches into fastembed's private ``model.tokenizer``
        # attribute — presumably a HuggingFace ``tokenizers`` object whose
        # ``encode_batch`` returns Encoding objects; verify against the pinned
        # fastembed version (0.2.6) before upgrading.
        encodings = self._model.model.tokenizer.encode_batch(texts)
        total_tokens = sum(len(e) for e in encodings)
        # ``embed`` yields numpy vectors lazily; materialize them as lists so
        # ``np.array`` builds a single 2-D (num_texts, dim) matrix.
        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
        return np.array(embeddings), total_tokens

    def encode_queries(self, text: str):
        # Tokenize the single query to report its token count alongside the vector.
        encoding = self._model.model.tokenizer.encode(text)
        # ``query_embed`` applies the model's query-side prompt formatting
        # (where applicable) and yields one vector; take it with ``next``.
        embedding = next(self._model.query_embed(text)).tolist()
        return np.array(embedding), len(encoding.ids)
class XinferenceEmbed(Base):
def __init__(self, key, model_name="", base_url=""):
self.client = OpenAI(api_key="xxx", base_url=base_url)
@ -187,3 +217,4 @@ class XinferenceEmbed(Base):
res = self.client.embeddings.create(input=[text],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens

View File

@ -27,6 +27,7 @@ elasticsearch==8.12.1
elasticsearch-dsl==8.12.0
et-xmlfile==1.1.0
filelock==3.13.1
fastembed==0.2.6
FlagEmbedding==1.2.5
Flask==3.0.2
Flask-Cors==4.0.0