Mirror of https://www.modelscope.cn/OpenBMB/MiniCPM-o-2_6.git, synced 2025-08-14 20:56:03 +08:00
LogitsProcessor and LogitsWarper were originally created as two separate classes, with LogitsWarper being a LogitsProcessor intended to be applied only during sample-based decoding. In practice, the two interfaces are identical.
LogitsWarper has since been removed from Hugging Face Transformers, so the LogitsWarper references in modeling_minicpmo.py are now replaced with LogitsProcessor.
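For context, both interfaces share the same __call__(input_ids, scores) -> scores contract, which is why this is a drop-in swap. Below is a minimal sketch of a warper-style class written against LogitsProcessor; the TemperatureScaler name and the values used are illustrative, not part of this commit:

    import torch
    from transformers import LogitsProcessor, LogitsProcessorList

    class TemperatureScaler(LogitsProcessor):
        """Warper-style sampling tweak expressed as a plain LogitsProcessor."""

        def __init__(self, temperature: float = 0.7):
            self.temperature = temperature

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # Identical signature to the removed LogitsWarper.
            return scores / self.temperature

    # LogitsProcessorList applies processors in order during generation.
    processors = LogitsProcessorList([TemperatureScaler(0.7)])
    dummy_ids = torch.zeros((1, 1), dtype=torch.long)
    dummy_scores = torch.randn(1, 32000)  # fake vocabulary logits
    scaled = processors(dummy_ids, dummy_scores)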
This commit is contained in:
parent df106b958d
commit b4f5409365
@@ -42,7 +42,7 @@ from transformers import AutoProcessor
 from transformers import BertTokenizerFast
 from transformers import LlamaConfig
 from transformers import LlamaModel
-from transformers import LogitsWarper
+from transformers import LogitsProcessor
 from transformers import PreTrainedModel
 from transformers import Qwen2ForCausalLM
 from transformers import Qwen2PreTrainedModel
@@ -184,7 +184,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
             args=(),
             init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
         )
-        vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32)
+        vocos = Vocos(feature_extractor, backbone, head).to(self.device).eval().to(torch.float32)
         vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
         return vocos
 
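Note: replacing a hard-coded .to("cuda") with .to(self.device), as in the hunk above, is the pattern this commit applies throughout the file. PreTrainedModel exposes a device property that reflects where the weights actually live, so the code also runs on CPU or other accelerators. A minimal sketch of the idea (SmallNet is illustrative, not from this repo):

    import torch
    import torch.nn as nn

    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Follow the parameters' device rather than assuming CUDA exists.
            device = next(self.parameters()).device
            return self.proj(x.to(device))

    net = SmallNet()               # stays on CPU unless the caller moves it
    out = net(torch.randn(2, 16))  # works with or without a GPU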
@@ -584,6 +584,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
 
         if self.config.chunk_input:
             for i in range(bs):
+                if not audio_embeddings[i]:
+                    continue
                 audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
                     device=input_embeddings.device, dtype=input_embeddings.dtype
                 )
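The added guard matters because torch.cat raises on an empty sequence, so a sample whose audio_embeddings[i] list is empty would previously have crashed here (assuming, as the surrounding code suggests, that audio_embeddings[i] is a possibly empty list of tensors). A quick reproduction:

    import torch

    try:
        torch.cat([], dim=0)
    except RuntimeError as err:
        print(err)  # "torch.cat(): expected a non-empty list of Tensors"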
@@ -1090,7 +1092,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                 else:
                     logger.error("Invalid content type:", c)
 
-            cur_contents = "".join(cur_msgs) if omni_input else "\n".join(omni_input)
+            cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
             if not self.is_first and self.new_user_msg and msg["role"] == "user":  # new user add im_start
                 if self.llm_generated:
                     if self.llm_generate_completed:
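For the fix above: the fallback branch runs exactly when omni_input is falsy, so joining omni_input itself was either a crash or a silent no-op. A hedged illustration (the concrete values are hypothetical; omni_input's real type is not visible in this hunk):

    cur_msgs = ["chunk A", "chunk B"]
    omni_input = False
    # Old: "\n".join(omni_input) -> TypeError for a bool,
    #      or "" for an empty list, dropping cur_msgs entirely.
    cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
    print(cur_contents)  # "chunk A\nchunk B"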
@@ -1205,7 +1207,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
 
         terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
         generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
-        input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()
+        input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(self.device)
 
         spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
         spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
@@ -1309,7 +1311,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
         tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[
             "input_ids"
-        ].cuda()
+        ].to(self.device)
         return tts_input_ids
 
     def _build_streaming_mask(self, tts_tokens_len):
@@ -1340,7 +1342,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         gen_text = text.split("<|tts_eos|>")[0]
         tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
         tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
-        tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long)
+        tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long)
         streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
 
         logits_warpers, logits_processors = gen_logits(
@@ -1637,7 +1639,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
 
             tts_input_ids = self.tts_processor.text_tokenizer(
                 tts_text, return_tensors="pt", add_special_tokens=False
-            )["input_ids"].cuda()
+            )["input_ids"].to(self.device)
             text_input_ids = tts_input_ids[:, begin:end]
             streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
             position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
@@ -1746,7 +1748,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
             if end > begin:
                 tts_input_ids = self.tts_processor.text_tokenizer(
                     tts_text, return_tensors="pt", add_special_tokens=False
-                )["input_ids"].cuda()
+                )["input_ids"].to(self.device)
                 text_input_ids = tts_input_ids[:, begin:end]
                 streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
                 position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
@@ -2917,7 +2919,7 @@ class ConditionalChatTTS(PreTrainedModel):
         force_no_stop=False,
         min_new_token=10,
         max_new_token=50,
-        logits_warpers: List[LogitsWarper] = [],
+        logits_warpers: List[LogitsProcessor] = [],
         logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
         show_tqdm=False,
     ):
@@ -2935,7 +2937,7 @@ class ConditionalChatTTS(PreTrainedModel):
             eos_token (Union[int, torch.Tensor]): End of sequence token.
             streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
             max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
-            logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
+            logits_warpers (List[LogitsProcessor], optional): List of logits processors. Defaults to [].
             logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
             show_tqdm (bool, optional): Whether to show progress bar. Defaults to True.
 