From b4f540936555db0d07084e84497dda2cc10b5b16 Mon Sep 17 00:00:00 2001 From: DennisHuang Date: Mon, 12 May 2025 03:02:50 +0000 Subject: [PATCH] =?UTF-8?q?LogitsProcessor=C2=A0and=C2=A0LogitsWarper?= =?UTF-8?q?=C2=A0were=20originally=20created=20as=20two=20separate=20class?= =?UTF-8?q?es,=20with=C2=A0LogitsWarper=C2=A0being=20a=C2=A0LogitsProcesso?= =?UTF-8?q?r=C2=A0that=20should=20only=20be=20applied=20to=20sample-based?= =?UTF-8?q?=20decoding=20methods.=20In=20practice,=20they=20are=20exactly?= =?UTF-8?q?=20the=20same.=20LogitsWarper=20was=20removed=20from=20HuggingF?= =?UTF-8?q?ace.=20The=20LogitsWarper=20class=20in=20modeling=5Fminicpmo.py?= =?UTF-8?q?=20has=20now=20been=20replaced=20with=20LogitsProcessor.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- modeling_minicpmo.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/modeling_minicpmo.py b/modeling_minicpmo.py index 112fca9..388f4a3 100644 --- a/modeling_minicpmo.py +++ b/modeling_minicpmo.py @@ -42,7 +42,7 @@ from transformers import AutoProcessor from transformers import BertTokenizerFast from transformers import LlamaConfig from transformers import LlamaModel -from transformers import LogitsWarper +from transformers import LogitsProcessor from transformers import PreTrainedModel from transformers import Qwen2ForCausalLM from transformers import Qwen2PreTrainedModel @@ -184,7 +184,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): args=(), init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}}, ) - vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32) + vocos = Vocos(feature_extractor, backbone, head).to(self.device).eval().to(torch.float32) vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True)) return vocos @@ -584,6 +584,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel): if self.config.chunk_input: for i in range(bs): + if 
not audio_embeddings[i]: + continue audio_embs = torch.cat(audio_embeddings[i], dim=0).to( device=input_embeddings.device, dtype=input_embeddings.dtype ) @@ -1090,7 +1092,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): else: logger.error("Invalid content type:", c) - cur_contents = "".join(cur_msgs) if omni_input else "\n".join(omni_input) + cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs) if not self.is_first and self.new_user_msg and msg["role"] == "user": # new user add im_start if self.llm_generated: if self.llm_generate_completed: @@ -1205,7 +1207,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>" - input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda() + input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(self.device) spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0] spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0] @@ -1309,7 +1311,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[ "input_ids" - ].cuda() + ].to(self.device) return tts_input_ids def _build_streaming_mask(self, tts_tokens_len): @@ -1340,7 +1342,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): gen_text = text.split("<|tts_eos|>")[0] tts_text, tts_token_lens = self.prepare_tts_text(gen_text) tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False) - tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long) + tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long) streaming_tts_text_mask = 
self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) logits_warpers, logits_processors = gen_logits( @@ -1637,7 +1639,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): tts_input_ids = self.tts_processor.text_tokenizer( tts_text, return_tensors="pt", add_special_tokens=False - )["input_ids"].cuda() + )["input_ids"].to(self.device) text_input_ids = tts_input_ids[:, begin:end] streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) @@ -1746,7 +1748,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): if end > begin: tts_input_ids = self.tts_processor.text_tokenizer( tts_text, return_tensors="pt", add_special_tokens=False - )["input_ids"].cuda() + )["input_ids"].to(self.device) text_input_ids = tts_input_ids[:, begin:end] streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) @@ -2917,7 +2919,7 @@ class ConditionalChatTTS(PreTrainedModel): force_no_stop=False, min_new_token=10, max_new_token=50, - logits_warpers: List[LogitsWarper] = [], + logits_warpers: List[LogitsProcessor] = [], logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [], show_tqdm=False, ): @@ -2935,7 +2937,7 @@ class ConditionalChatTTS(PreTrainedModel): eos_token (Union[int, torch.Tensor]): End of sequence token. streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None. max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50. - logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to []. + logits_warpers (List[LogitsProcessor], optional): List of logits processors. Defaults to []. logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. 
show_tqdm (bool, optional): Whether to show progress bar. Defaults to True.