From 64edf3d7235efd248e3be313992778d71935997e Mon Sep 17 00:00:00 2001
From: Hongji Zhu
Date: Wed, 22 Jan 2025 11:40:42 +0800
Subject: [PATCH] support audio finetuning

---
 modeling_minicpmo.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/modeling_minicpmo.py b/modeling_minicpmo.py
index a37b06c..924de65 100644
--- a/modeling_minicpmo.py
+++ b/modeling_minicpmo.py
@@ -466,7 +466,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             return []
 
-    def get_audio_embedding(self, data, chunk_length=-1):
+    def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
         r"""
         Extract full audio embeddings with optional chunk-based attention.
 
@@ -484,6 +484,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
+        dtype = self.apm.embed_positions.weight.dtype
+        device = self.apm.embed_positions.weight.device
         wavforms = data.get("audio_features", [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
 
@@ -544,6 +546,17 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                     idx += 1
                 final_audio_embeds.append(target_audio_embeds)
             return final_audio_embeds
+        elif self.training and dummy:
+            dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
+            audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
+
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+            return [audio_embeds]
+
         else:
             return []
 
@@ -576,7 +589,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                 audio_start_pos = 0
                 for bound in audio_bounds[i]:
                     audio_len = bound[1] - bound[0]
-                    input_embeddings[0, bound[0] : bound[1]] = audio_embs[
+                    input_embeddings[i, bound[0] : bound[1]] = audio_embs[
                         audio_start_pos : audio_start_pos + audio_len, :
                     ]
                     audio_start_pos += audio_len
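
Note (reviewer annotation, not part of the commit): the new dummy branch
exists because a text-only training batch gives the audio tower (apm,
audio_projection_layer, audio_avg_pooler) nothing to do. When
get_audio_embedding returns [], those parameters never enter the autograd
graph, and distributed wrappers such as PyTorch DDP with default settings
raise unused-parameter errors. Pushing a zero mel spectrogram of shape
(1, 80, 100) through the same encode/project/pool path keeps every audio
parameter "used" on every step.

Below is a minimal, self-contained sketch of that trick. All names and
shapes here (ToyAudioTower, the toy encoder, hidden sizes) are
hypothetical stand-ins; only the control flow mirrors the patch, not the
real MiniCPM-o modules.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the patch's audio tower (apm +
    # audio_projection_layer + audio_avg_pooler); only the control
    # flow mirrors the patch, not the real architecture.
    class ToyAudioTower(nn.Module):
        def __init__(self, n_mels=80, hidden=64):
            super().__init__()
            self.encoder = nn.Conv1d(n_mels, hidden, kernel_size=3, padding=1)
            self.proj = nn.Linear(hidden, hidden)
            self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)

        def forward(self, mels):                         # mels: (bs, 80, frames)
            states = self.encoder(mels).transpose(1, 2)  # (bs, frames, hidden)
            embeds = self.proj(states)
            embeds = embeds.transpose(1, 2)              # pool over time
            embeds = self.avg_pooler(embeds)
            return embeds.transpose(1, 2)                # (bs, frames // 2, hidden)

        def get_audio_embedding(self, data, dummy=True):
            wavforms = data.get("audio_features", [])
            if len(wavforms) > 0:
                return [self.forward(wavforms)]
            elif self.training and dummy:
                # No audio in this batch: push a zero spectrogram through
                # the tower so its parameters still enter the autograd
                # graph (otherwise DDP reports unused parameters).
                p = self.proj.weight
                dummy_wavs = torch.zeros((1, 80, 100), device=p.device, dtype=p.dtype)
                return [self.forward(dummy_wavs)]
            else:
                return []

    tower = ToyAudioTower().train()
    embeds = tower.get_audio_embedding({})   # text-only batch, no audio key
    loss = embeds[0].sum() * 0.0             # fold in at zero weight
    loss.backward()
    assert all(p.grad is not None for p in tower.parameters())

The last lines show the expected usage: the caller folds the dummy output
into the loss at zero weight, so the numbers are unchanged while the
gradient path stays alive.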
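
A second note on the one-line fix in the final hunk: the old code
scattered every sample's audio embeddings into batch row 0
(input_embeddings[0, ...]), so with batch size > 1 sample 0 was
overwritten by other samples' audio while the remaining rows kept their
placeholder embeddings. A toy reproduction of the patched behavior, with
every shape and value hypothetical:

    import torch

    # Batch of 2 sequences, sequence length 6, hidden size 4.
    bs, seq_len, hidden = 2, 6, 4
    input_embeddings = torch.zeros(bs, seq_len, hidden)
    audio_bounds = [[(1, 3)], [(2, 5)]]              # (start, end) per sample
    per_sample_audio = [torch.ones(2, hidden),       # sample 0: 2 positions
                        2 * torch.ones(3, hidden)]   # sample 1: 3 positions

    for i in range(bs):
        audio_embs = per_sample_audio[i]
        audio_start_pos = 0
        for start, end in audio_bounds[i]:
            audio_len = end - start
            # Pre-patch this row index was hard-coded to 0, so sample 1's
            # audio overwrote sample 0 and row 1 kept its placeholders.
            input_embeddings[i, start:end] = audio_embs[
                audio_start_pos : audio_start_pos + audio_len, :
            ]
            audio_start_pos += audio_len

    assert input_embeddings[1, 2:5].eq(2).all()  # each sample keeps its own audio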