Mirror of https://www.modelscope.cn/OpenBMB/MiniCPM-o-2_6.git (synced 2025-08-16 13:46:00 +08:00)
Replace the inplace operation

commit 161f4b7db3
parent 2710956711
@@ -377,6 +377,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
 
+        new_vllm_embedding = vllm_embedding.clone()
+
         vision_hidden_states = [
             i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
         ]
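Note on the hunk above: `clone()` produces a differentiable copy, so later writes can target the copy while `vllm_embedding` itself stays untouched, and gradients still flow from the clone back to its source. A minimal sketch of that behavior (names are illustrative, not from the repo):

import torch

src = torch.randn(3, requires_grad=True)
copy = src.clone()        # differentiable copy, not a detached one
copy[0] = 0.0             # mutating the copy leaves `src` intact
copy.sum().backward()
print(src.grad)           # tensor([0., 1., 1.]): the overwritten slot gets no gradient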
@@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                             [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                         ).to(vllm_embedding.device)
 
-                        cur_vllm_emb.scatter_(
+                        new_vllm_embedding[i] = cur_vllm_emb.scatter(
                             0,
                             image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                             cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                         )
                     elif self.training:
-                        cur_vllm_emb += cur_vs_hs[0].mean() * 0
+                        new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0
+
 
-        return vllm_embedding, vision_hidden_states
+        return new_vllm_embedding, vision_hidden_states
 
     def get_audio_embedding_streaming(self, data):
         r"""
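Why the out-of-place rewrite matters: `Tensor.scatter_` mutates `cur_vllm_emb`, a view into `vllm_embedding`, in place. If autograd saved that tensor for the backward pass, or if it is a leaf that requires grad, the in-place write either raises a RuntimeError or corrupts gradients; out-of-place `scatter` into the cloned buffer avoids both. A hedged sketch of the failure and the fix, assuming plain PyTorch rather than the repo's code:

import torch

emb = torch.randn(4, 8, requires_grad=True)
index = torch.zeros(1, 8, dtype=torch.long)   # overwrite row 0
src = torch.ones(1, 8)

# In-place on a leaf that requires grad fails immediately:
# emb.scatter_(0, index, src)
# -> RuntimeError: a leaf Variable that requires grad is being used in an in-place operation

# Out-of-place, mirroring the patch: clone first, then scatter without mutation.
new_emb = emb.clone().scatter(0, index, src)
new_emb.sum().backward()                      # gradients still reach `emb`
print(emb.grad[0].sum().item())               # 0.0: the scattered row contributes no gradient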
@@ -595,7 +598,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         elif self.training:
             for i in range(bs):
                 # dummy audio_embeddings
-                input_embeddings += audio_embeddings[0].mean() * 0
+                input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0
 
         return input_embeddings
 
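Note on the `* 0` pattern: adding `audio_embeddings[0].mean() * 0` contributes nothing numerically, but it keeps the audio branch in the autograd graph during training, a common way to stop distributed wrappers (e.g. DDP) from erroring on parameters that received no gradient. The switch from `+=` to `x = x + ...` avoids mutating a tensor autograd may have saved. A small sketch, again assuming plain PyTorch rather than the repo's modules:

import torch

w = torch.randn(4, requires_grad=True)
emb = torch.exp(w)               # exp() saves its output for backward
dummy = torch.randn(4, requires_grad=True)

# emb += dummy.mean() * 0       # in-place edit of a saved tensor makes the later
#                               # backward() raise: "one of the variables needed for
#                               # gradient computation has been modified by an inplace operation"

emb = emb + dummy.mean() * 0     # out-of-place: builds a new tensor, the graph stays valid
emb.sum().backward()
print(dummy.grad)                # all zeros, but defined, so the parameter counts as used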
@@ -668,9 +671,9 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
             mode:
                 "default": default system prompt and not refer to any task
                 "omni": input video and audio simultaneously
-                "audio_assistant": Default voice-only mode, the model will use the ref_audio's voice to reply user as a helpful assistant.
-                "audio_roleplay": Roleplay voice-only model, the model will use the ref_audio's voice to reply, and also role-play the character based on the audio prompt.
-                "voice_cloning": TTS mode, the model will clone the voice of ref_audio
+                "audio_assistant": Default voice-only mode, the model will use the ref_audio's voice to reply user's question as a helpful assistant.
+                "audio_roleplay": Roleplay voice-only mode, the model will use the ref_audio's voice to reply, and also role-play the character based on the audio prompt.
+                "voice_cloning": TTS mode, the model will clone the voice of ref_audio.
             language: prompts language, the model has the ability to automatically select the response language
                 based on the question language
         Returns:
@@ -751,7 +754,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         input_ids=None,
         pixel_values=None,
         tgt_sizes=None,
-        audio_features=None,
+        audio_features=[],
         audio_feature_lens=None,
         image_bound=None,
         audio_bounds=None,
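On `audio_features=[]`: presumably this lets downstream code take `len(audio_features)` without a `None` guard. The usual Python caveat applies: a mutable default is created once and shared across calls, which is only safe if the function never mutates it. An illustrative sketch (not the repo's code):

def collect_bad(x, acc=[]):        # one shared list for every call
    acc.append(x)
    return acc

def collect_ok(x, acc=None):       # fresh list per call
    acc = [] if acc is None else acc
    acc.append(x)
    return acc

print(collect_bad(1), collect_bad(2))   # [1, 2] [1, 2] -- state leaks between calls
print(collect_ok(1), collect_ok(2))     # [1] [2]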
@@ -2982,7 +2985,7 @@ class ConditionalChatTTS(PreTrainedModel):
             inputs_embeds = torch.stack(code_emb, 3).sum(3)
 
             position_ids = torch.tensor(
-                [past_key_values[0][0].shape[2] + 1], dtype=torch.long, device=self.device
+                [past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device
             ).unsqueeze(0)
 
             cache_position = position_ids.clone()
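The position fix removes an off-by-one: with a KV cache holding L past tokens (`past_key_values[0][0].shape[2] == L` under the usual `[batch, heads, seq_len, head_dim]` layout, which is an assumption here), those tokens occupy positions 0..L-1, so the token being decoded sits at position L, not L + 1. A toy check:

import torch

past_len = 5                                      # 5 cached tokens at positions 0..4
position_ids = torch.tensor([past_len], dtype=torch.long).unsqueeze(0)
print(position_ids)                               # tensor([[5]]): next position is L, not L + 1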