diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index c57e7f843..6caa28b0c 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -31,6 +31,7 @@ from rag.utils import num_tokens_from_string, truncate
 import json
 
 
+
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))
 
@@ -86,6 +87,57 @@ class DefaultRerank(Base):
                                                       local_dir_use_symlinks=False)
                         DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
         self._model = DefaultRerank._model
+        self._dynamic_batch_size = 8
+        self._min_batch_size = 1
+
+
+    def torch_empty_cache(self):
+        try:
+            import torch
+            torch.cuda.empty_cache()
+        except Exception as e:
+            print(f"Error emptying cache: {e}")
+
+    def _process_batch(self, pairs, max_batch_size=None):
+        """Template method: subclasses provide _compute_batch_scores for the actual scoring."""
+        old_dynamic_batch_size = self._dynamic_batch_size
+        if max_batch_size is not None:
+            self._dynamic_batch_size = max_batch_size
+        res = []
+        i = 0
+        while i < len(pairs):
+            current_batch = self._dynamic_batch_size
+            max_retries = 5
+            retry_count = 0
+            while retry_count < max_retries:
+                try:
+                    # delegate the per-batch scoring to the subclass implementation
+                    batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
+                    res.extend(batch_scores)
+                    i += current_batch
+                    self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
+                    break
+                except RuntimeError as e:
+                    if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
+                        current_batch = max(current_batch // 2, self._min_batch_size)
+                        self.torch_empty_cache()
+                        retry_count += 1
+                    else:
+                        raise
+            if retry_count >= max_retries:
+                raise RuntimeError("Reached the retry limit but still cannot process the batch; please check your GPU memory.")
+            self.torch_empty_cache()
+
+        self._dynamic_batch_size = old_dynamic_batch_size
+        return np.array(res)
+
+
+    def _compute_batch_scores(self, batch_pairs, max_length=None):
+        if max_length is None:
+            max_length = self._model.max_length
+        scores = self._model.compute_score(batch_pairs, max_length=max_length)
+        scores = sigmoid(np.array(scores)).tolist()
+        return scores
 
     def similarity(self, query: str, texts: list):
         pairs = [(query, truncate(t, 2048)) for t in texts]
@@ -93,14 +145,7 @@ class DefaultRerank(Base):
         for _, t in pairs:
             token_count += num_tokens_from_string(t)
         batch_size = 4096
-        res = []
-        for i in range(0, len(pairs), batch_size):
-            scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
-            scores = sigmoid(np.array(scores)).tolist()
-            if isinstance(scores, float):
-                res.append(scores)
-            else:
-                res.extend(scores)
+        res = self._process_batch(pairs, max_batch_size=batch_size)
         return np.array(res), token_count
 
 
@@ -155,14 +200,7 @@ class YoudaoRerank(DefaultRerank):
         for _, t in pairs:
             token_count += num_tokens_from_string(t)
         batch_size = 8
-        res = []
-        for i in range(0, len(pairs), batch_size):
-            scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
-            scores = sigmoid(np.array(scores)).tolist()
-            if isinstance(scores, float):
-                res.append(scores)
-            else:
-                res.extend(scores)
+        res = self._process_batch(pairs, max_batch_size=batch_size)
         return np.array(res), token_count
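For context, here is a minimal standalone sketch of the adaptive batching strategy that `_process_batch` implements: start at the requested batch size, halve on a CUDA out-of-memory error down to a floor, give up after a bounded number of retries, and grow the batch back after a success. `process_batch` and `fake_scorer` below are illustrative stand-ins written for this note, not code from this repository, and the loop is a simplified restatement rather than the exact implementation in the patch.

```python
import numpy as np


def fake_scorer(batch):
    # Stand-in for _compute_batch_scores: pretend the GPU can only fit 2 pairs at once.
    if len(batch) > 2:
        raise RuntimeError("CUDA out of memory")
    return [0.5] * len(batch)


def process_batch(pairs, scorer, start_batch=8, min_batch=1, max_retries=5):
    # Simplified restatement of the adaptive loop in DefaultRerank._process_batch.
    res, i, batch = [], 0, start_batch
    while i < len(pairs):
        current = batch
        for _ in range(max_retries):
            try:
                res.extend(scorer(pairs[i:i + current]))
                i += current
                batch = min(batch * 2, start_batch)  # grow back after a success
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) and current > min_batch:
                    current = max(current // 2, min_batch)  # halve and retry
                else:
                    raise
        else:
            raise RuntimeError("retry limit reached without a successful batch")
    return np.array(res)


if __name__ == "__main__":
    pairs = [("query", f"text {n}") for n in range(10)]
    # Scores all 10 pairs, settling on batches of 2 after two simulated OOM failures.
    print(process_batch(pairs, fake_scorer))
```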