mirror of
https://www.modelscope.cn/OpenBMB/MiniCPM-o-2_6.git
synced 2025-04-18 15:49:34 +08:00
Update system prompt of assistant mode to make sure AudioEvals metric is replicable
This commit is contained in:
parent
4f835638ec
commit
beba53602a
@ -378,11 +378,11 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
|
||||
|
||||
new_vllm_embedding = vllm_embedding.clone()
|
||||
|
||||
|
||||
vision_hidden_states = [
|
||||
i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
|
||||
]
|
||||
|
||||
|
||||
bs = len(data["input_ids"])
|
||||
for i in range(bs):
|
||||
cur_vs_hs = vision_hidden_states[i]
|
||||
@ -484,9 +484,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
Returns:
|
||||
List[List[torch.Tensor]]: audio embeddings
|
||||
"""
|
||||
dtype = self.apm.embed_positions.weight.dtype
|
||||
device = self.apm.embed_positions.weight.device
|
||||
|
||||
|
||||
wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance
|
||||
audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]]
|
||||
|
||||
@ -547,6 +545,9 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
final_audio_embeds.append(target_audio_embeds)
|
||||
return final_audio_embeds
|
||||
elif self.training and dummy:
|
||||
dtype = self.apm.embed_positions.weight.dtype
|
||||
device = self.apm.embed_positions.weight.device
|
||||
|
||||
dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
|
||||
audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
|
||||
|
||||
@ -635,6 +636,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
|
||||
|
||||
def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
|
||||
kwargs.pop("output_hidden_states", None)
|
||||
kwargs.pop("return_dict_in_generate", None)
|
||||
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
|
||||
outputs = self.llm.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
@ -716,8 +719,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
vc_prompt_prefix = "模仿输入音频中的声音特征。"
|
||||
vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。"
|
||||
else:
|
||||
vc_prompt_prefix = "Clone the voice in the provided audio prompt."
|
||||
vc_prompt_suffix = "As an assistant, you will speak using this voice style."
|
||||
vc_prompt_prefix = "Use the voice in the audio prompt to synthesize new content."
|
||||
vc_prompt_suffix = "You are a helpful assistant with the above voice style."
|
||||
|
||||
if ref_audio is not None:
|
||||
sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
|
||||
@ -776,6 +779,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
tokenizer=None,
|
||||
vision_hidden_states=None,
|
||||
stream=False,
|
||||
decode_text=True,
|
||||
**kwargs,
|
||||
):
|
||||
assert input_ids is not None
|
||||
@ -813,7 +817,10 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
||||
outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
|
||||
|
||||
result = self._decode_text(outputs.sequences, tokenizer)
|
||||
|
||||
|
||||
if decode_text is False:
|
||||
return outputs
|
||||
|
||||
return result, outputs
|
||||
|
||||
def chat(
|
||||
@ -2671,6 +2678,7 @@ class ConditionalChatTTS(PreTrainedModel):
|
||||
"""
|
||||
|
||||
config_class = ConditionalChatTTSConfig
|
||||
_no_split_modules = []
|
||||
|
||||
def __init__(self, config: ConditionalChatTTSConfig):
|
||||
super().__init__(config)
|
||||
|
Loading…
x
Reference in New Issue
Block a user