From 74743483945319743bd2dafa6150d270b49bbc04 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Thu, 19 Dec 2024 16:18:18 +0800
Subject: [PATCH] Fix fastembed reloading issue. (#4117)

### What problem does this PR solve?

`FastEmbed` caches its `TextEmbedding` instance in a class attribute but never records which model the instance was loaded for, so once any instance exists, requests for a different `model_name` silently reuse the stale model; initialization is also not guarded against concurrent construction. This patch records the cached model's name in both `FastEmbed` and `DefaultEmbedding`, reloads `FastEmbed` when the requested name differs from the cached one, serializes loading behind a class-level lock, and falls back to a local `snapshot_download` of `BAAI/bge-small-en-v1.5` when the direct load fails.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 rag/llm/embedding_model.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index d4fdc0956..140a35d67 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -47,6 +47,7 @@ class Base(ABC):
 class DefaultEmbedding(Base):
     _model = None
+    _model_name = ""
     _model_lock = threading.Lock()
 
     def __init__(self, key, model_name, **kwargs):
         """
@@ -69,6 +70,7 @@ class DefaultEmbedding(Base):
                         DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
                                                             query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                             use_fp16=torch.cuda.is_available())
+                        DefaultEmbedding._model_name = model_name
                     except Exception:
                         model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
                                                       local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
@@ -77,6 +79,7 @@ class DefaultEmbedding(Base):
                                                             query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                             use_fp16=torch.cuda.is_available())
         self._model = DefaultEmbedding._model
+        self._model_name = DefaultEmbedding._model_name
 
     def encode(self, texts: list):
         batch_size = 16
@@ -250,6 +253,8 @@ class OllamaEmbed(Base):
 
 class FastEmbed(Base):
     _model = None
+    _model_name = ""
+    _model_lock = threading.Lock()
 
     def __init__(
         self,
@@ -260,8 +265,20 @@ class FastEmbed(Base):
         **kwargs,
     ):
         if not settings.LIGHTEN and not FastEmbed._model:
-            from fastembed import TextEmbedding
-            self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+            with FastEmbed._model_lock:
+                from fastembed import TextEmbedding
+                if not FastEmbed._model or model_name != FastEmbed._model_name:
+                    try:
+                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+                        FastEmbed._model_name = model_name
+                    except Exception:
+                        cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
+                                                      local_dir=os.path.join(get_home_cache_dir(),
+                                                                             re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
+                                                      local_dir_use_symlinks=False)
+                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+        self._model = FastEmbed._model
+        self._model_name = model_name
 
     def encode(self, texts: list):
         # Using the internal tokenizer to encode the texts and get the total
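
As a review aid, the sketch below isolates the caching pattern the patch applies to `FastEmbed`: one process-wide model instance, invalidated when the requested model name changes, with double-checked locking around the load. This is a minimal illustration under stated assumptions, not RAGFlow code; `CachedEmbedder` and `_load` are hypothetical stand-ins for `FastEmbed` and its `TextEmbedding`/`snapshot_download` loading logic, and unlike the patch it repeats the name comparison on the unlocked fast path.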
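```python
import threading


class CachedEmbedder:
    """Minimal sketch (not RAGFlow code) of the pattern this patch adopts:
    a single shared model instance, reloaded when the requested model name
    changes, with loading serialized behind a class-level lock."""

    _model = None
    _model_name = ""
    _model_lock = threading.Lock()

    def __init__(self, model_name: str):
        # Unlocked fast path: skip the lock entirely when the cached
        # instance already matches the requested model.
        if CachedEmbedder._model is None or model_name != CachedEmbedder._model_name:
            with CachedEmbedder._model_lock:
                # Re-check under the lock: another thread may have loaded
                # the model while this one was waiting.
                if CachedEmbedder._model is None or model_name != CachedEmbedder._model_name:
                    CachedEmbedder._model = self._load(model_name)
                    CachedEmbedder._model_name = model_name
        self._model = CachedEmbedder._model
        self._model_name = model_name

    @staticmethod
    def _load(model_name: str):
        # Hypothetical stand-in for TextEmbedding(...) and the
        # snapshot_download fallback; returns a placeholder object so
        # the sketch runs without network access.
        return object()
```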
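The re-check under the lock is what makes first use safe: two threads that race past the unlocked check serialize on `_model_lock`, and the loser finds the model already loaded instead of constructing a second one. In the sketch, `CachedEmbedder("m1")` followed by another `CachedEmbedder("m1")` shares one instance, while `CachedEmbedder("m2")` replaces it. The patch adds the same `_model_name` bookkeeping to `DefaultEmbedding`, though in the hunks shown only `FastEmbed` reloads when the name changes.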