diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 2ab95b337..785650a57 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -61,11 +61,11 @@ class DefaultEmbedding(Base): ^_- """ - if not settings.LIGHTEN and not DefaultEmbedding._model: + if not settings.LIGHTEN: with DefaultEmbedding._model_lock: from FlagEmbedding import FlagModel import torch - if not DefaultEmbedding._model: + if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name: try: DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", @@ -261,7 +261,7 @@ class FastEmbed(DefaultEmbedding): threads: int | None = None, **kwargs, ): - if not settings.LIGHTEN and not FastEmbed._model: + if not settings.LIGHTEN: with FastEmbed._model_lock: from fastembed import TextEmbedding if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name: