diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 140a35d67..2ab95b337 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -251,11 +251,8 @@ class OllamaEmbed(Base):
         return np.array(res["embedding"]), 128
 
 
-class FastEmbed(Base):
-    _model = None
-    _model_name = ""
-    _model_lock = threading.Lock()
-
+class FastEmbed(DefaultEmbedding):
+
     def __init__(
         self,
         key: str | None = None,
@@ -267,17 +264,17 @@ class FastEmbed(Base):
         if not settings.LIGHTEN and not FastEmbed._model:
             with FastEmbed._model_lock:
                 from fastembed import TextEmbedding
-                if not FastEmbed._model or model_name != FastEmbed._model_name:
+                if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                     try:
-                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
-                        FastEmbed._model_name = model_name
+                        DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+                        DefaultEmbedding._model_name = model_name
                     except Exception:
                         cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
                                                       local_dir=os.path.join(get_home_cache_dir(),
                                                                              re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
                                                       local_dir_use_symlinks=False)
-                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
-            self._model = FastEmbed._model
+                        DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+            self._model = DefaultEmbedding._model
             self._model_name = model_name
 
     def encode(self, texts: list):
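
For context on the diff above: it drops `FastEmbed`'s own class-level cache (`_model`, `_model_name`, `_model_lock`) and makes the class inherit from `DefaultEmbedding`, so both wrappers read and write one shared cached model per process. Below is a minimal sketch of that sharing pattern; `FakeTextEmbedding` and the simplified `__init__` are placeholders so the snippet runs without `fastembed` installed, not the repository's actual code.

```python
import threading


class DefaultEmbedding:
    # Class-level cache shared with subclasses: at most one loaded model per process.
    _model = None
    _model_name = ""
    _model_lock = threading.Lock()


class FakeTextEmbedding:
    """Placeholder for fastembed.TextEmbedding, only so the sketch is runnable."""

    def __init__(self, model_name: str):
        self.model_name = model_name


class FastEmbed(DefaultEmbedding):
    def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5"):
        # Writes target DefaultEmbedding's attributes, so DefaultEmbedding and
        # FastEmbed instances end up holding the same cached model object.
        with DefaultEmbedding._model_lock:
            if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                DefaultEmbedding._model = FakeTextEmbedding(model_name)
                DefaultEmbedding._model_name = model_name
        self._model = DefaultEmbedding._model
        self._model_name = model_name


a = FastEmbed()
b = FastEmbed()
assert a._model is b._model  # both instances reuse the single cached model
```

The practical effect of reusing the parent's class attributes is that a process which has already loaded a model through `DefaultEmbedding` will not load a second copy when `FastEmbed` is constructed with the same model name.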