LogitsProcessor and LogitsWarper were originally created as two separate classes, with LogitsWarper intended to be a LogitsProcessor applied only during sample-based decoding. In practice, the two classes had identical interfaces and behavior.

LogitsWarper was deprecated and subsequently removed from the Hugging Face transformers library. All references to the LogitsWarper class in modeling_minicpmo.py (the import, type annotations, and docstrings) have therefore been replaced with LogitsProcessor.
This commit is contained in:
DennisHuang 2025-05-12 03:02:50 +00:00
parent df106b958d
commit b4f5409365

View File

@ -42,7 +42,7 @@ from transformers import AutoProcessor
from transformers import BertTokenizerFast
from transformers import LlamaConfig
from transformers import LlamaModel
from transformers import LogitsWarper
from transformers import LogitsProcessor
from transformers import PreTrainedModel
from transformers import Qwen2ForCausalLM
from transformers import Qwen2PreTrainedModel
@ -184,7 +184,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
args=(),
init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
)
vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32)
vocos = Vocos(feature_extractor, backbone, head).to(self.device).eval().to(torch.float32)
vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
return vocos
@ -584,6 +584,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
if self.config.chunk_input:
for i in range(bs):
if not audio_embeddings[i]:
continue
audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
device=input_embeddings.device, dtype=input_embeddings.dtype
)
@ -1090,7 +1092,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
else:
logger.error("Invalid content type:", c)
cur_contents = "".join(cur_msgs) if omni_input else "\n".join(omni_input)
cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
if not self.is_first and self.new_user_msg and msg["role"] == "user": # new user add im_start
if self.llm_generated:
if self.llm_generate_completed:
@ -1205,7 +1207,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()
input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(self.device)
spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
@ -1309,7 +1311,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[
"input_ids"
].cuda()
].to(self.device)
return tts_input_ids
def _build_streaming_mask(self, tts_tokens_len):
@ -1340,7 +1342,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
gen_text = text.split("<|tts_eos|>")[0]
tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long)
tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long)
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
logits_warpers, logits_processors = gen_logits(
@ -1637,7 +1639,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
tts_input_ids = self.tts_processor.text_tokenizer(
tts_text, return_tensors="pt", add_special_tokens=False
)["input_ids"].cuda()
)["input_ids"].to(self.device)
text_input_ids = tts_input_ids[:, begin:end]
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
@ -1746,7 +1748,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
if end > begin:
tts_input_ids = self.tts_processor.text_tokenizer(
tts_text, return_tensors="pt", add_special_tokens=False
)["input_ids"].cuda()
)["input_ids"].to(self.device)
text_input_ids = tts_input_ids[:, begin:end]
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
@ -2917,7 +2919,7 @@ class ConditionalChatTTS(PreTrainedModel):
force_no_stop=False,
min_new_token=10,
max_new_token=50,
logits_warpers: List[LogitsWarper] = [],
logits_warpers: List[LogitsProcessor] = [],
logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
show_tqdm=False,
):
@ -2935,7 +2937,7 @@ class ConditionalChatTTS(PreTrainedModel):
eos_token (Union[int, torch.Tensor]): End of sequence token.
streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
logits_warpers (List[LogitsProcessor], optional): List of logits processor. Defaults to [].
logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
show_tqdm (bool, optional): Whether to show progress bar. Defaults to True.