infinity: Update embedding_model.py (#1109)

### What problem does this PR solve?

I implemented infinity, a fast vector embeddings engine. 

### Type of change


- [x] Performance Improvement
- [X] Other (please describe):
This commit is contained in:
Michael Feil 2024-06-10 17:23:58 -07:00 committed by GitHub
parent f900e432f3
commit 68a698655a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,6 +26,7 @@ import dashscope
from openai import OpenAI from openai import OpenAI
from FlagEmbedding import FlagModel from FlagEmbedding import FlagModel
import torch import torch
import asyncio
import numpy as np import numpy as np
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
@ -305,3 +306,43 @@ class JinaEmbed(Base):
def encode_queries(self, text): def encode_queries(self, text):
embds, cnt = self.encode([text]) embds, cnt = self.encode([text])
return np.array(embds[0]), cnt return np.array(embds[0]), cnt
class InfinityEmbed(Base):
_model = None
def __init__(
self,
model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
engine_kwargs: dict = {},
key = None,
):
from infinity_emb import EngineArgs
from infinity_emb.engine import AsyncEngineArray
self._default_model = model_names[0]
self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])
async def _embed(self, sentences: list[str], model_name: str = ""):
if not model_name:
model_name = self._default_model
engine = self.engine_array[model_name]
was_already_running = engine.is_running
if not was_already_running:
await engine.astart()
embeddings, usage = await engine.embed(sentences=sentences)
if not was_already_running:
await engine.astop()
return embeddings, usage
def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
# Using the internal tokenizer to encode the texts and get the total
# number of tokens
embeddings, usage = asyncio.run(self._embed(texts, model_name))
return np.array(embeddings), usage
def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
# Using the internal tokenizer to encode the texts and get the total
# number of tokens
return self.encode([text])