diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 91c2e77ee..29ec5444f 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -26,6 +26,7 @@ import dashscope
 from openai import OpenAI
 from FlagEmbedding import FlagModel
 import torch
+import asyncio
 import numpy as np
 
 from api.utils.file_utils import get_home_cache_dir
@@ -304,4 +305,44 @@ class JinaEmbed(Base):
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
-        return np.array(embds[0]), cnt
\ No newline at end of file
+        return np.array(embds[0]), cnt
+
+
+class InfinityEmbed(Base):
+    _model = None
+
+    def __init__(
+            self,
+            model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
+            engine_kwargs: dict = {},
+            key=None,
+    ):
+
+        from infinity_emb import EngineArgs
+        from infinity_emb.engine import AsyncEngineArray
+
+        self._default_model = model_names[0]
+        self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path=model_name, **engine_kwargs) for model_name in model_names])
+
+    async def _embed(self, sentences: list[str], model_name: str = ""):
+        if not model_name:
+            model_name = self._default_model
+        engine = self.engine_array[model_name]
+        was_already_running = engine.is_running
+        if not was_already_running:
+            await engine.astart()
+        embeddings, usage = await engine.embed(sentences=sentences)
+        if not was_already_running:
+            await engine.astop()
+        return embeddings, usage
+
+    def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
+        # The engine tokenizes internally and returns the embeddings together
+        # with the total number of tokens consumed.
+        embeddings, usage = asyncio.run(self._embed(texts, model_name))
+        return np.array(embeddings), usage
+
+    def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
+        # Embed a single query and return its vector plus the token usage.
+        embeddings, usage = self.encode([text])
+        return embeddings[0], usage
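
For reference, a minimal usage sketch of the `InfinityEmbed` wrapper added above, assuming `infinity_emb` is installed and the default `BAAI/bge-small-en-v1.5` checkpoint is available; the call pattern mirrors the other embedding wrappers in `rag/llm/embedding_model.py`.

```python
from rag.llm.embedding_model import InfinityEmbed

# One engine per model name; the first entry is used as the default model.
embedder = InfinityEmbed(model_names=("BAAI/bge-small-en-v1.5",))

# Batch embedding: returns an (n_texts, dim) array plus the token usage
# reported by the infinity engine.
vectors, used_tokens = embedder.encode(
    ["What is RAG?", "Retrieval-augmented generation."]
)

# Single-query embedding: returns a 1-D vector plus the token usage.
query_vector, query_tokens = embedder.encode_queries("What is RAG?")
```

Each call to `encode` starts and stops the async engine unless it is already running, so batching texts into a single call avoids repeated model start-up overhead.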