mirror of
https://www.modelscope.cn/OpenBMB/MiniCPM-o-2_6.git
synced 2025-07-31 01:05:59 +08:00
Update the system prompt of assistant mode to make sure the AudioEvals metrics are replicable
This commit is contained in:
parent
4f835638ec
commit
beba53602a
@ -484,8 +484,6 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|||||||
Returns:
|
Returns:
|
||||||
List[List[torch.Tensor]]: audio embeddings
|
List[List[torch.Tensor]]: audio embeddings
|
||||||
"""
|
"""
|
||||||
dtype = self.apm.embed_positions.weight.dtype
|
|
||||||
device = self.apm.embed_positions.weight.device
|
|
||||||
|
|
||||||
wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance
|
wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance
|
||||||
audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]]
|
audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]]
|
||||||
@ -547,6 +545,9 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|||||||
final_audio_embeds.append(target_audio_embeds)
|
final_audio_embeds.append(target_audio_embeds)
|
||||||
return final_audio_embeds
|
return final_audio_embeds
|
||||||
elif self.training and dummy:
|
elif self.training and dummy:
|
||||||
|
dtype = self.apm.embed_positions.weight.dtype
|
||||||
|
device = self.apm.embed_positions.weight.device
|
||||||
|
|
||||||
dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
|
dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
|
||||||
audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
|
audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
|
||||||
|
|
||||||
@ -635,6 +636,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|||||||
return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
|
return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
|
||||||
|
|
||||||
def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
|
def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
|
||||||
|
kwargs.pop("output_hidden_states", None)
|
||||||
|
kwargs.pop("return_dict_in_generate", None)
|
||||||
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
|
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
|
||||||
outputs = self.llm.generate(
|
outputs = self.llm.generate(
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@ -716,8 +719,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|||||||
vc_prompt_prefix = "模仿输入音频中的声音特征。"
|
vc_prompt_prefix = "模仿输入音频中的声音特征。"
|
||||||
vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。"
|
vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。"
|
||||||
else:
|
else:
|
||||||
vc_prompt_prefix = "Clone the voice in the provided audio prompt."
|
vc_prompt_prefix = "Use the voice in the audio prompt to synthesize new content."
|
||||||
vc_prompt_suffix = "As an assistant, you will speak using this voice style."
|
vc_prompt_suffix = "You are a helpful assistant with the above voice style."
|
||||||
|
|
||||||
if ref_audio is not None:
|
if ref_audio is not None:
|
||||||
sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
|
sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
|
||||||
@ -776,6 +779,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
vision_hidden_states=None,
|
vision_hidden_states=None,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
decode_text=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
assert input_ids is not None
|
assert input_ids is not None
|
||||||
@ -814,6 +818,9 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|||||||
|
|
||||||
result = self._decode_text(outputs.sequences, tokenizer)
|
result = self._decode_text(outputs.sequences, tokenizer)
|
||||||
|
|
||||||
|
if decode_text is False:
|
||||||
|
return outputs
|
||||||
|
|
||||||
return result, outputs
|
return result, outputs
|
||||||
|
|
||||||
def chat(
|
def chat(
|
||||||
@ -2671,6 +2678,7 @@ class ConditionalChatTTS(PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = ConditionalChatTTSConfig
|
config_class = ConditionalChatTTSConfig
|
||||||
|
_no_split_modules = []
|
||||||
|
|
||||||
def __init__(self, config: ConditionalChatTTSConfig):
|
def __init__(self, config: ConditionalChatTTSConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user