Fix fastembed reloading issue. (#4117)
### What problem does this PR solve?

Fix the fastembed reloading issue: the embedding model was reloaded on every `FastEmbed` instantiation because the class-level cache was never populated, and a change of model name was never detected.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
parent 8939206531
commit 7474348394
```diff
@@ -47,6 +47,7 @@ class Base(ABC):
 class DefaultEmbedding(Base):
     _model = None
+    _model_name = ""
     _model_lock = threading.Lock()
 
     def __init__(self, key, model_name, **kwargs):
         """
```
```diff
@@ -69,6 +70,7 @@ class DefaultEmbedding(Base):
                         DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
                                                             query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                             use_fp16=torch.cuda.is_available())
+                        DefaultEmbedding._model_name = model_name
                     except Exception:
                         model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
                                                       local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
```
```diff
@@ -77,6 +79,7 @@ class DefaultEmbedding(Base):
                                                             query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                             use_fp16=torch.cuda.is_available())
         self._model = DefaultEmbedding._model
+        self._model_name = DefaultEmbedding._model_name
 
     def encode(self, texts: list):
         batch_size = 16
```
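The three hunks above serve one change: `DefaultEmbedding` now records, at class level, which model its cached singleton was built from, so a later instance can tell whether the cache actually matches the model it was asked for. A minimal, self-contained sketch of the pattern follows; the `load_model` stub is hypothetical and stands in for the expensive model construction (`FlagModel` in the real code):

```python
import threading


def load_model(name: str) -> object:
    # Hypothetical stand-in for the expensive model construction.
    return ("model", name)


class CachedEmbedding:
    _model = None            # shared across all instances
    _model_name = ""         # which model the cache currently holds
    _model_lock = threading.Lock()

    def __init__(self, model_name: str):
        if CachedEmbedding._model is None or model_name != CachedEmbedding._model_name:
            with CachedEmbedding._model_lock:
                # Re-check under the lock: another thread may have loaded
                # the model while this one was waiting.
                if CachedEmbedding._model is None or model_name != CachedEmbedding._model_name:
                    CachedEmbedding._model = load_model(model_name)
                    CachedEmbedding._model_name = model_name
        self._model = CachedEmbedding._model
        self._model_name = CachedEmbedding._model_name
```

Two instances created with the same name share one load; asking for a different name replaces the cached model instead of silently reusing the old one.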
```diff
@@ -250,6 +253,8 @@ class OllamaEmbed(Base):
 
 class FastEmbed(Base):
     _model = None
+    _model_name = ""
+    _model_lock = threading.Lock()
 
     def __init__(
         self,
```
```diff
@@ -260,8 +265,20 @@ class FastEmbed(Base):
         **kwargs,
     ):
         if not settings.LIGHTEN and not FastEmbed._model:
-            from fastembed import TextEmbedding
-            self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+            with FastEmbed._model_lock:
+                from fastembed import TextEmbedding
+                if not FastEmbed._model or model_name != FastEmbed._model_name:
+                    try:
+                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+                        FastEmbed._model_name = model_name
+                    except Exception:
+                        cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
+                                                      local_dir=os.path.join(get_home_cache_dir(),
+                                                                             re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
+                                                      local_dir_use_symlinks=False)
+                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+        self._model = FastEmbed._model
+        self._model_name = model_name
 
     def encode(self, texts: list):
         # Using the internal tokenizer to encode the texts and get the total
```
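This last hunk is the heart of the fix. The old code assigned the loaded `TextEmbedding` to `self._model`, an instance attribute, so the class-level `FastEmbed._model` cache stayed `None` and every new `FastEmbed` instance reloaded the model from scratch. A minimal sketch of the difference, with `object()` standing in for the expensive load (illustration only, not ragflow code):

```python
class Broken:
    _model = None  # intended as a shared cache

    def __init__(self):
        if not Broken._model:
            # Writes an *instance* attribute; Broken._model stays None,
            # so this branch runs for every instance (the reloading issue).
            self._model = object()


class Fixed:
    _model = None

    def __init__(self):
        if not Fixed._model:
            Fixed._model = object()  # populate the shared cache once
        self._model = Fixed._model


a, b = Broken(), Broken()
assert a._model is not b._model   # two separate "loads"

c, d = Fixed(), Fixed()
assert c._model is d._model       # loaded once, shared afterwards
```

The committed version additionally serializes the load behind `_model_lock` and re-creates the model when a different `model_name` is requested, matching the `DefaultEmbedding` pattern above.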