mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 07:19:04 +08:00
infinity: Update embedding_model.py (#1109)
### What problem does this PR solve? I implemented infinity, a fast vector embeddings engine. ### Type of change - [x] Performance Improvement - [X] Other (please describe):
This commit is contained in:
parent
f900e432f3
commit
68a698655a
@ -26,6 +26,7 @@ import dashscope
|
||||
from openai import OpenAI
|
||||
from FlagEmbedding import FlagModel
|
||||
import torch
|
||||
import asyncio
|
||||
import numpy as np
|
||||
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
@ -304,4 +305,44 @@ class JinaEmbed(Base):
|
||||
|
||||
def encode_queries(self, text):
    """Embed a single query string.

    The text is wrapped in a one-element list and handed to ``self.encode``;
    the first (and only) embedding of the returned batch is converted to a
    NumPy array.

    Args:
        text: The query string to embed.

    Returns:
        A ``(vector, cnt)`` tuple where ``vector`` is ``np.array(embds[0])``
        and ``cnt`` is the second value returned by ``self.encode``
        (presumably a token count -- confirm against ``encode``).
    """
    embds, cnt = self.encode([text])
    return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class InfinityEmbed(Base):
    """Embedding model backed by an in-process ``infinity_emb`` engine array.

    One engine is created per entry in ``model_names``; the first entry is
    used whenever a call does not name a model explicitly.
    """

    # Not read or written by any method of this class.
    _model = None

    def __init__(
        self,
        model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
        engine_kwargs: "dict | None" = None,
        key=None,
    ):
        """Build the engine array for the requested models.

        Args:
            model_names: Model identifiers passed to ``EngineArgs`` as
                ``model_name_or_path``. The first one becomes the default.
            engine_kwargs: Extra keyword arguments forwarded to every
                ``EngineArgs``. ``None`` (the default) means no extras.
            key: Accepted but not used by this class (presumably kept for
                signature parity with the other embedders -- confirm).
        """
        # Imported lazily so the module can be loaded without infinity_emb.
        from infinity_emb import EngineArgs
        from infinity_emb.engine import AsyncEngineArray

        # A fresh dict per call: a literal ``{}`` default would be one
        # object shared by every instantiation.
        extra_args = {} if engine_kwargs is None else engine_kwargs

        self._default_model = model_names[0]
        self.engine_array = AsyncEngineArray.from_args(
            [
                EngineArgs(model_name_or_path=model_name, **extra_args)
                for model_name in model_names
            ]
        )

    async def _embed(self, sentences: list[str], model_name: str = ""):
        """Embed ``sentences`` with the engine registered for ``model_name``.

        If the engine is not running yet it is started for the duration of
        this call and stopped again afterwards; an engine that was already
        running is left running.

        Args:
            sentences: Texts to embed.
            model_name: Key into the engine array; an empty value selects
                the default model.

        Returns:
            The ``(embeddings, usage)`` pair produced by ``engine.embed``.
        """
        if not model_name:
            model_name = self._default_model
        engine = self.engine_array[model_name]

        was_already_running = engine.is_running
        if not was_already_running:
            await engine.astart()
        try:
            embeddings, usage = await engine.embed(sentences=sentences)
        finally:
            # Stop the engine we started even when embedding fails, so an
            # exception does not leave it running.
            if not was_already_running:
                await engine.astop()
        return embeddings, usage

    def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
        """Synchronously embed a batch of texts.

        Drives the async ``_embed`` coroutine to completion with
        ``asyncio.run``. NOTE: ``asyncio.run`` cannot be called while
        another event loop is running in the same thread.

        Args:
            texts: Texts to embed.
            model_name: Model to use; empty selects the default model.

        Returns:
            ``(np.array(embeddings), usage)`` where ``usage`` is whatever
            the engine reports (presumably a token count -- confirm).
        """
        embeddings, usage = asyncio.run(self._embed(texts, model_name))
        return np.array(embeddings), usage

    def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
        """Embed a single query string with the default model.

        NOTE(review): this returns the batch array produced by ``encode``
        for the one-element list ``[text]``, whereas
        ``JinaEmbed.encode_queries`` returns only the first vector
        (``embds[0]``). Confirm which form callers expect before changing
        it.

        Args:
            text: The query string to embed.

        Returns:
            The ``(embeddings, usage)`` tuple from ``encode([text])``.
        """
        return self.encode([text])
|
||||
|
Loading…
x
Reference in New Issue
Block a user