From d8fca4301764f32a964d0788486c5ee3deee526a Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 19 Dec 2024 17:27:09 +0800 Subject: [PATCH] Make fast embed and default embed mutually exclusive. (#4121) ### What problem does this PR solve? ### Type of change - [x] Performance Improvement --- rag/llm/embedding_model.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 140a35d67..2ab95b337 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -251,11 +251,8 @@ class OllamaEmbed(Base): return np.array(res["embedding"]), 128 -class FastEmbed(Base): - _model = None - _model_name = "" - _model_lock = threading.Lock() - +class FastEmbed(DefaultEmbedding): + def __init__( self, key: str | None = None, @@ -267,17 +264,17 @@ class FastEmbed(Base): if not settings.LIGHTEN and not FastEmbed._model: with FastEmbed._model_lock: from fastembed import TextEmbedding - if not FastEmbed._model or model_name != FastEmbed._model_name: + if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name: try: - FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) - FastEmbed._model_name = model_name + DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) + DefaultEmbedding._model_name = model_name except Exception: cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False) - FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) - self._model = FastEmbed._model + DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) + self._model = DefaultEmbedding._model self._model_name = model_name def encode(self, texts: list):