diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index daddbbba8..b9e1fec8e 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -66,12 +66,12 @@ class DefaultEmbedding(Base): import torch if not DefaultEmbedding._model: try: - DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available()) except Exception: model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", - local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False) DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 29925505f..a2c3902d6 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -65,12 +65,12 @@ class DefaultRerank(Base): if not DefaultRerank._model: try: DefaultRerank._model = FlagReranker( - os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available()) except Exception: model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), - re.sub(r"^[a-zA-Z]+/", "", model_name)), + re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False) DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available()) self._model = DefaultRerank._model @@ -130,7 +130,7 @@ class YoudaoRerank(DefaultRerank): logger.info("LOADING BCE...") YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( get_home_cache_dir(), - re.sub(r"^[a-zA-Z]+/", "", model_name))) + re.sub(r"^[a-zA-Z0-9]+/", "", model_name))) except Exception: YoudaoRerank._model = RerankerModel( model_name_or_path=model_name.replace(