Refactor rerank model with dynamic batch processing and memory manage… (#5273)

…ment

### What problem does this PR solve?
Issue:https://github.com/infiniflow/ragflow/issues/5262
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: wenju.li <wenju.li@deepctr.cn>
This commit is contained in:
liwenju0 2025-02-24 11:32:08 +08:00 committed by GitHub
parent 3d605a23fe
commit 569e40544d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -31,6 +31,7 @@ from rag.utils import num_tokens_from_string, truncate
import json
def sigmoid(x):
return 1 / (1 + np.exp(-x))
@ -86,6 +87,57 @@ class DefaultRerank(Base):
local_dir_use_symlinks=False)
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
self._model = DefaultRerank._model
self._dynamic_batch_size = 8
self._min_batch_size = 1
def torch_empty_cache(self):
try:
import torch
torch.cuda.empty_cache()
except Exception as e:
print(f"Error emptying cache: {e}")
def _process_batch(self, pairs, max_batch_size=None):
"""template method for subclass call"""
old_dynamic_batch_size = self._dynamic_batch_size
if max_batch_size is not None:
self._dynamic_batch_size = max_batch_size
res = []
i = 0
while i < len(pairs):
current_batch = self._dynamic_batch_size
max_retries = 5
retry_count = 0
while retry_count < max_retries:
try:
# call subclass implemented batch processing calculation
batch_scores = self._compute_batch_scores(pairs[i:i+current_batch])
res.extend(batch_scores)
i += current_batch
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
break
except RuntimeError as e:
if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
current_batch = max(current_batch // 2, self._min_batch_size)
self.torch_empty_cache()
retry_count += 1
else:
raise
if retry_count >= max_retries:
raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
self.torch_empty_cache()
self._dynamic_batch_size = old_dynamic_batch_size
return np.array(res)
def _compute_batch_scores(self, batch_pairs, max_length=None):
if max_length is None:
max_length = self._model.max_length
scores = self._model.compute_score(batch_pairs, max_length=max_length)
scores = sigmoid(np.array(scores)).tolist()
return scores
def similarity(self, query: str, texts: list):
pairs = [(query, truncate(t, 2048)) for t in texts]
@ -93,14 +145,7 @@ class DefaultRerank(Base):
for _, t in pairs:
token_count += num_tokens_from_string(t)
batch_size = 4096
res = []
for i in range(0, len(pairs), batch_size):
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
scores = sigmoid(np.array(scores)).tolist()
if isinstance(scores, float):
res.append(scores)
else:
res.extend(scores)
res = self._process_batch(pairs, max_batch_size=batch_size)
return np.array(res), token_count
@ -155,14 +200,7 @@ class YoudaoRerank(DefaultRerank):
for _, t in pairs:
token_count += num_tokens_from_string(t)
batch_size = 8
res = []
for i in range(0, len(pairs), batch_size):
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
scores = sigmoid(np.array(scores)).tolist()
if isinstance(scores, float):
res.append(scores)
else:
res.extend(scores)
res = self._process_batch(pairs, max_batch_size=batch_size)
return np.array(res), token_count