From 161f4b7db3d4a78b8329e6d86ed8992582bb1be2 Mon Sep 17 00:00:00 2001
From: Hongji Zhu
Date: Mon, 20 Jan 2025 11:56:10 +0800
Subject: [PATCH] Replace the inplace operation

---
 modeling_minicpmo.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/modeling_minicpmo.py b/modeling_minicpmo.py
index 913e0d8..a37b06c 100644
--- a/modeling_minicpmo.py
+++ b/modeling_minicpmo.py
@@ -377,6 +377,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])

+        new_vllm_embedding = vllm_embedding.clone()
+
         vision_hidden_states = [
             i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
         ]
@@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                         [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                     ).to(vllm_embedding.device)

-                    cur_vllm_emb.scatter_(
+                    new_vllm_embedding[i] = cur_vllm_emb.scatter(
                         0,
                         image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                         cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                     )
-                elif self.training:
-                    cur_vllm_emb += cur_vs_hs[0].mean() * 0

-        return vllm_embedding, vision_hidden_states
+                elif self.training:
+                    new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0
+
+        return new_vllm_embedding, vision_hidden_states

     def get_audio_embedding_streaming(self, data):
         r"""
@@ -595,7 +598,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         elif self.training:
             for i in range(bs):
                 # dummy audio_embeddings
-                input_embeddings += audio_embeddings[0].mean() * 0
+                input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0

         return input_embeddings

@@ -668,9 +671,9 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
             mode:
                 "default": default system prompt and not refer to any task
                 "omni": input video and audio simultaneously
-                "audio_assistant": Default voice-only mode, the model will use the ref_audio's voice to reply user as a helpful assistant.
-                "audio_roleplay": Roleplay voice-only model, the model will use the ref_audio's voice to reply, and also role-play the character based on the audio prompt.
-                "voice_cloning": TTS mode, the model will clone the voice of ref_audio
+                "audio_assistant": Default voice-only mode, the model will use the ref_audio's voice to reply user's question as a helpful assistant.
+                "audio_roleplay": Roleplay voice-only mode, the model will use the ref_audio's voice to reply, and also role-play the character based on the audio prompt.
+                "voice_cloning": TTS mode, the model will clone the voice of ref_audio.
             language: prompts language, the model has the ability to automatically select the response language based on the question language

         Returns:
@@ -751,7 +754,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         input_ids=None,
         pixel_values=None,
         tgt_sizes=None,
-        audio_features=None,
+        audio_features=[],
         audio_feature_lens=None,
         image_bound=None,
         audio_bounds=None,
@@ -2982,7 +2985,7 @@ class ConditionalChatTTS(PreTrainedModel):
         inputs_embeds = torch.stack(code_emb, 3).sum(3)

         position_ids = torch.tensor(
-            [past_key_values[0][0].shape[2] + 1], dtype=torch.long, device=self.device
+            [past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device
         ).unsqueeze(0)
         cache_position = position_ids.clone()
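
Note (illustrative, not part of the patch): the core change above swaps the in-place Tensor.scatter_ on a slice of the embedding tensor for an out-of-place Tensor.scatter whose result is written into a clone. A minimal, self-contained sketch of that pattern follows; the shapes and variable names are made up for the example and are not the model's real inputs.

    import torch

    seq_len, hidden = 6, 8
    vllm_embedding = torch.randn(2, seq_len, hidden)  # stand-in for the embed_tokens output
    vision_feats = torch.randn(3, hidden)             # stand-in for one sample's vision features
    image_indices = torch.tensor([1, 2, 3])           # stand-in for positions taken from image_bound

    # Pattern adopted by the patch: scatter out-of-place and write the result
    # into a clone, so the original embedding tensor is never mutated.
    new_vllm_embedding = vllm_embedding.clone()
    cur_emb = vllm_embedding[0]
    new_vllm_embedding[0] = cur_emb.scatter(
        0,
        image_indices.view(-1, 1).repeat(1, hidden),
        vision_feats,
    )

    # The replaced code instead did the equivalent of
    #     cur_emb.scatter_(0, index, src)
    # which mutates a view of vllm_embedding in place; avoiding that mutation is
    # presumably what "Replace the inplace operation" refers to, since in-place
    # writes into autograd-tracked tensors can raise errors in some training setups.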