mirror of
https://www.modelscope.cn/OpenBMB/MiniCPM-o-2_6-int4.git
synced 2025-04-19 07:59:31 +08:00
update
This commit is contained in:
parent
583218859c
commit
8995e34672
@ -377,6 +377,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
else:
|
||||
vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
|
||||
|
||||
new_vllm_embedding = vllm_embedding.clone()
|
||||
|
||||
vision_hidden_states = [
|
||||
i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
|
||||
]
|
||||
@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
[torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
|
||||
).to(vllm_embedding.device)
|
||||
|
||||
cur_vllm_emb.scatter_(
|
||||
new_vllm_embedding[i] = cur_vllm_emb.scatter(
|
||||
0,
|
||||
image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
|
||||
cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
|
||||
)
|
||||
elif self.training:
|
||||
cur_vllm_emb += cur_vs_hs[0].mean() * 0
|
||||
|
||||
return vllm_embedding, vision_hidden_states
|
||||
elif self.training:
|
||||
new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0
|
||||
|
||||
return new_vllm_embedding, vision_hidden_states
|
||||
|
||||
def get_audio_embedding_streaming(self, data):
|
||||
r"""
|
||||
@ -463,7 +466,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_audio_embedding(self, data, chunk_length=-1):
|
||||
def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
|
||||
r"""
|
||||
Extract full audio embeddings with optional chunk-based attention.
|
||||
|
||||
@ -481,6 +484,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
Returns:
|
||||
List[List[torch.Tensor]]: audio embeddings
|
||||
"""
|
||||
dtype = self.apm.embed_positions.weight.dtype
|
||||
device = self.apm.embed_positions.weight.device
|
||||
|
||||
wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance
|
||||
audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]]
|
||||
@ -541,6 +546,17 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
idx += 1
|
||||
final_audio_embeds.append(target_audio_embeds)
|
||||
return final_audio_embeds
|
||||
elif self.training and dummy:
|
||||
dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
|
||||
audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
|
||||
|
||||
audio_embeds = self.audio_projection_layer(audio_states)
|
||||
|
||||
audio_embeds = audio_embeds.transpose(1, 2)
|
||||
audio_embeds = self.audio_avg_pooler(audio_embeds)
|
||||
audio_embeds = audio_embeds.transpose(1, 2)
|
||||
return [audio_embeds]
|
||||
|
||||
else:
|
||||
return []
|
||||
|
||||
@ -573,7 +589,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
audio_start_pos = 0
|
||||
for bound in audio_bounds[i]:
|
||||
audio_len = bound[1] - bound[0]
|
||||
input_embeddings[0, bound[0] : bound[1]] = audio_embs[
|
||||
input_embeddings[i, bound[0] : bound[1]] = audio_embs[
|
||||
audio_start_pos : audio_start_pos + audio_len, :
|
||||
]
|
||||
audio_start_pos += audio_len
|
||||
@ -595,7 +611,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
elif self.training:
|
||||
for i in range(bs):
|
||||
# dummy audio_embeddings
|
||||
input_embeddings += audio_embeddings[0].mean() * 0
|
||||
input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0
|
||||
|
||||
return input_embeddings
|
||||
|
||||
@ -751,7 +767,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
input_ids=None,
|
||||
pixel_values=None,
|
||||
tgt_sizes=None,
|
||||
audio_features=None,
|
||||
audio_features=[],
|
||||
audio_feature_lens=None,
|
||||
image_bound=None,
|
||||
audio_bounds=None,
|
||||
@ -2982,7 +2998,7 @@ class ConditionalChatTTS(PreTrainedModel):
|
||||
inputs_embeds = torch.stack(code_emb, 3).sum(3)
|
||||
|
||||
position_ids = torch.tensor(
|
||||
[past_key_values[0][0].shape[2] + 1], dtype=torch.long, device=self.device
|
||||
[past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device
|
||||
).unsqueeze(0)
|
||||
|
||||
cache_position = position_ids.clone()
|
||||
|
Loading…
x
Reference in New Issue
Block a user