From fa54cd5f5cf7ae0d518fe8740a76a99440d4e4c9 Mon Sep 17 00:00:00 2001 From: roc king Date: Wed, 13 Nov 2024 14:10:16 +0800 Subject: [PATCH] =?UTF-8?q?exstract=20model=20dir=20from=20model=E2=80=98s?= =?UTF-8?q?=20full=20name=20(#3368)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? When model’s group name contains 0-9,we can't find downloaded model,because we do not correctly exstract model dir's name from model‘s full name ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: 王志鹏 Co-authored-by: Kevin Hu --- rag/llm/embedding_model.py | 4 ++-- rag/llm/rerank_model.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index daddbbba8..b9e1fec8e 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -66,12 +66,12 @@ class DefaultEmbedding(Base): import torch if not DefaultEmbedding._model: try: - DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available()) except Exception: model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", - local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False) DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 29925505f..a2c3902d6 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -65,12 +65,12 @@ class DefaultRerank(Base): if not DefaultRerank._model: try: DefaultRerank._model = FlagReranker( - os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available()) except Exception: model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), - re.sub(r"^[a-zA-Z]+/", "", model_name)), + re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False) DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available()) self._model = DefaultRerank._model @@ -130,7 +130,7 @@ class YoudaoRerank(DefaultRerank): logger.info("LOADING BCE...") YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( get_home_cache_dir(), - re.sub(r"^[a-zA-Z]+/", "", model_name))) + re.sub(r"^[a-zA-Z0-9]+/", "", model_name))) except Exception: YoudaoRerank._model = RerankerModel( model_name_or_path=model_name.replace(