Update resampler.py (#45)

- Update resampler.py (585890d29b440e43c32ad8fedb5a8579aa244be1)
- Update resampler.py (dd98ddb3845e288e3f7e2abf0d25f3eedbdf5d1c)


Co-authored-by: Ekaterina Aidova <katuni4ka@users.noreply.huggingface.co>
Cherrytest 2025-05-14 01:20:26 +00:00
parent df106b958d
commit 1fb2f12e58
16 changed files with 119 additions and 303679 deletions

.gitattributes (vendored): 3 changes

@@ -49,3 +49,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*jpg filter=lfs diff=lfs merge=lfs -text
*gif filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
assets/Skiing.mp4 filter=lfs diff=lfs merge=lfs -text
tokenizer.json filter=lfs diff=lfs merge=lfs -text

README.md: 159 changes

@@ -1,60 +1,85 @@
---
frameworks:
- Pytorch
license: other
tasks:
- any-to-any
pipeline_tag: any-to-any
datasets:
- openbmb/RLAIF-V-Dataset
library_name: transformers
language:
- multilingual
tags:
- minicpm-o
- omni
- vision
- ocr
- multi-image
- video
- custom_code
- audio
- speech
- voice cloning
- live Streaming
- realtime speech conversation
- asr
- tts
---
<h1>A GPT-4o level vision, speech, and multimodal streaming MLLM usable on end-side devices</h1>
<h1>A GPT-4o Level MLLM for Vision, Speech and Multimodal Live Streaming on Your Phone</h1>
[GitHub](https://github.com/OpenBMB/MiniCPM-o) | [Online Demo](https://minicpm-omni-webdemo-us.modelbest.cn)
[GitHub](https://github.com/OpenBMB/MiniCPM-o) | [Online Demo](https://minicpm-omni-webdemo-us.modelbest.cn) | [Technical Blog](https://openbmb.notion.site/MiniCPM-o-2-6-A-GPT-4o-Level-MLLM-for-Vision-Speech-and-Multimodal-Live-Streaming-on-Your-Phone-185ede1b7a558042b5d5e45e6b237da9)
### News
* [2025.03.01] 🚀🚀🚀 RLAIF-V, the alignment technique of MiniCPM-o, is accepted by CVPR 2025! The [code](https://github.com/RLHF-V/RLAIF-V), [dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset), and [paper](https://arxiv.org/abs/2405.17220) are open-sourced!
* [2025.01.24] 📢📢📢 MiniCPM-o 2.6 technical report is released! [See Here](https://openbmb.notion.site/MiniCPM-o-2-6-A-GPT-4o-Level-MLLM-for-Vision-Speech-and-Multimodal-Live-Streaming-on-Your-Phone-185ede1b7a558042b5d5e45e6b237da9).
* [2025.01.19] ⭐️⭐️⭐️ MiniCPM-o tops GitHub Trending and reaches top-2 on Hugging Face Trending!
## MiniCPM-o 2.6
MiniCPM-o 2.6 is the latest and most capable model in the MiniCPM-o series. The model is built on SigLip-400M, Whisper-medium-300M, ChatTTS-200M, and Qwen2.5-7B, with 8B parameters in total, and is trained and run in an end-to-end fashion. Compared with MiniCPM-V 2.6, it delivers a significant performance improvement and adds new capabilities for real-time speech conversation and multimodal streaming interaction. Key features of MiniCPM-o 2.6 include:
**MiniCPM-o 2.6** is the latest and most capable model in the MiniCPM-o series. The model is built in an end-to-end fashion based on SigLip-400M, Whisper-medium-300M, ChatTTS-200M, and Qwen2.5-7B with a total of 8B parameters. It exhibits a significant performance improvement over MiniCPM-V 2.6, and introduces new features for real-time speech conversation and multimodal live streaming. Notable features of MiniCPM-o 2.6 include:
- 🔥 **Leading Visual Capability.**
MiniCPM-o 2.6 achieves an average score of 70.2 on OpenCompass, a comprehensive evaluation over 8 popular benchmarks. **With only 8B parameters, it surpasses widely used proprietary models like GPT-4o-202405, Gemini 1.5 Pro, and Claude 3.5 Sonnet** for single image understanding. It also **outperforms GPT-4V and Claude 3.5 Sonnet** in multi-image and video understanding, and shows promising in-context learning capability.
- 🎙 **State-of-the-art Speech Capability.** MiniCPM-o 2.6 supports **bilingual real-time speech conversation with configurable voices** in English and Chinese. It **outperforms GPT-4o-realtime on audio understanding tasks** such as ASR and STT translation, and shows **state-of-the-art performance on speech conversation in both semantic and acoustic evaluations in the open-source community**. It also allows for fun features such as emotion/speed/style control, end-to-end voice cloning, role play, etc.
- 🎬 **Strong Multimodal Live Streaming Capability.** As a new feature, MiniCPM-o 2.6 can **accept continuous video and audio streams independent of user queries, and support real-time speech interaction**. It **outperforms GPT-4o-202408 and Claude 3.5 Sonnet and shows state-of-the-art performance in the open-source community on StreamingBench**, a comprehensive benchmark for real-time video understanding, omni-source (video & audio) understanding, and multimodal contextual understanding.
- 💪 **Strong OCR Capability and Others.**
Advancing the popular visual capabilities of the MiniCPM-V series, MiniCPM-o 2.6 can process images with any aspect ratio and up to 1.8 million pixels (e.g., 1344x1344). It achieves **state-of-the-art performance on OCRBench for models under 25B, surpassing proprietary models such as GPT-4o-202405**.
Based on the latest [RLAIF-V](https://github.com/RLHF-V/RLAIF-V/) and [VisCPM](https://github.com/OpenBMB/VisCPM) techniques, it features **trustworthy behaviors**, outperforming GPT-4o and Claude 3.5 Sonnet on MMHal-Bench, and supports **multilingual capabilities** in more than 30 languages.
- 🔥 **Leading Visual Capability.**
MiniCPM-o 2.6 achieves an average score of 70.2 on OpenCompass (a comprehensive evaluation over 8 popular multimodal benchmarks). **At only the 8B scale, it surpasses mainstream proprietary multimodal models such as GPT-4o-202405, Gemini 1.5 Pro, and Claude 3.5 Sonnet in single-image understanding.** It also **outperforms GPT-4V and Claude 3.5 Sonnet** in multi-image and video understanding, and exhibits excellent in-context learning capability.
- 🚀 **Superior Efficiency.**
In addition to its friendly size, MiniCPM-o 2.6 also shows **state-of-the-art token density** (i.e., number of pixels encoded into each visual token). **It produces only 640 tokens when processing a 1.8M pixel image, which is 75% fewer than most models**. This directly improves the inference speed, first-token latency, memory usage, and power consumption. As a result, MiniCPM-o 2.6 can efficiently support **multimodal live streaming** on end-side devices such as iPad.
- 🎙 **Superior Speech Capability.**
MiniCPM-o 2.6 **supports real-time bilingual (Chinese and English) speech conversation with configurable voices**. It **outperforms GPT-4o-realtime** on audio understanding tasks such as ASR and STT translation, and achieves **the best speech-generation performance among open-source models** in both semantic and acoustic evaluations of speech conversation. It also supports advanced capabilities such as emotion/speed/style control, voice cloning, and role play.
- 🎬 **Strong Multimodal Live Streaming Capability.**
As a new feature, MiniCPM-o 2.6 can **accept continuous video and audio streams and interact with users through real-time speech**. On StreamingBench (a comprehensive benchmark for real-time video understanding, omni-source video/audio understanding, and multimodal contextual understanding), MiniCPM-o 2.6 achieves the highest score among open-source models and **surpasses GPT-4o-realtime and Claude 3.5 Sonnet**.
- 💪 **Strong OCR Capability and More.**
MiniCPM-o 2.6 further improves the many visual understanding capabilities of MiniCPM-V 2.6. It can process images with any aspect ratio and up to 1.8 million pixels (e.g., 1344x1344), and achieves **the best performance on OCRBench among models under 25B, surpassing proprietary models such as GPT-4o-202405**. Building on the latest [RLHF-V](https://rlhf-v.github.io/), [RLAIF-V](https://github.com/RLHF-V/RLAIF-V/), and [VisCPM](https://github.com/OpenBMB/VisCPM) techniques, it exhibits **trustworthy multimodal behavior**, surpassing GPT-4o and Claude 3.5 on MMHal-Bench, and supports **multiple languages** including English, Chinese, German, French, Italian, and Korean.
- 🚀 **Superior Efficiency.**
Beyond its user-friendly size, MiniCPM-o 2.6 also shows **state-of-the-art visual token density** (i.e., the number of pixels encoded into each visual token). **It needs only 640 tokens to process a 1.8-million-pixel image, 75% fewer than most models.** This improves inference speed, first-token latency, memory usage, and power consumption, so MiniCPM-o 2.6 can support efficient **multimodal live streaming** on end-side devices such as the iPad.
- 💫 **Easy Usage.**
MiniCPM-o 2.6 can be easily used in various ways: (1) [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpm-omni/examples/llava/README-minicpmo2.6.md) support for efficient CPU inference on local devices, (2) [int4](https://huggingface.co/openbmb/MiniCPM-o-2_6-int4) and [GGUF](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) format quantized models in 16 sizes, (3) [vLLM](#efficient-inference-with-llamacpp-ollama-vllm) support for high-throughput and memory-efficient inference, (4) fine-tuning on new domains and tasks with [LLaMA-Factory](./docs/llamafactory_train.md), (5) quick local WebUI demo setup with [Gradio](#chat-with-our-demo-on-gradio), and (6) online web demo on [server](https://minicpm-omni-webdemo-us.modelbest.cn/).
- 💫 **Easy to Use.**
MiniCPM-o 2.6 can be used easily in various ways: (1) [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpm-omni/examples/llava/README-minicpmo2.6.md) support for efficient CPU inference on local devices, (2) quantized models in [int4](https://huggingface.co/openbmb/MiniCPM-o-2_6-int4) and [GGUF](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) formats in 16 sizes, (3) [vLLM](#vllm-部署-) support for high-throughput and memory-efficient inference, (4) fine-tuning on new domains and tasks with [LLaMA-Factory](./docs/llamafactory_train.md), (5) quick local WebUI demo setup with [Gradio](#本地-webui-demo-), and (6) an [online demo](https://minicpm-omni-webdemo-us.modelbest.cn/).
**Model Architecture.**
**Model Architecture.**
- **End-to-end Omni-modal Architecture.** Encoder/decoder modules of different modalities are connected and trained in an **end-to-end** fashion to fully exploit rich multimodal knowledge.
- **Omni-modal Streaming Mechanism.** (1) We convert the offline encoders/decoders of different modalities into online modules for **streaming inputs/outputs**. (2) For the LLM backbone, we design a **time-division multiplexed omni-modal streaming mechanism** that splits and repacks parallel modality streams into a sequence of periodic time slices.
- **Configurable Voice Design.** We design a system prompt structure that combines a traditional text system prompt with **an audio system prompt that specifies the model's voice**. The model can thus flexibly control its voice style at inference time through text or audio samples, enabling advanced capabilities such as voice cloning and voice creation.
- **End-to-end Omni-modal Architecture.** Different modality encoders/decoders are connected and trained in an **end-to-end** fashion to fully exploit rich multimodal knowledge.
- **Omni-modal Live Streaming Mechanism.** (1) We change the offline modality encoders/decoders into online ones for **streaming inputs/outputs**. (2) We devise a **time-division multiplexing (TDM) mechanism** for omni-modality streaming processing in the LLM backbone. It divides parallel omni-modality streams into sequential information within small periodic time slices (see the sketch after the architecture figure below).
- **Configurable Speech Modeling Design.** We devise a multimodal system prompt, including a traditional text system prompt and **a new audio system prompt that determines the assistant voice**. This enables flexible voice configuration at inference time, and also facilitates end-to-end voice cloning and description-based voice creation.
<div align="center">
<img src="./assets/minicpm-o-26-framework.png" , width=80%>
<img src="https://github.com/OpenBMB/MiniCPM-o/raw/main/assets/minicpm-o-26-framework-v2.png" , width=100%>
</div>
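To make the time-division multiplexing idea described in the architecture bullets concrete, here is a small illustrative sketch (not the model's actual implementation; all names are invented) of how parallel per-modality streams can be repacked into a single sequence of periodic time slices before reaching the LLM backbone:

```python
from itertools import zip_longest
from typing import Dict, List

def tdm_multiplex(streams: Dict[str, List[str]]) -> List[List[str]]:
    """Interleave parallel modality streams into periodic time slices.

    Each input list holds one chunk per time slice (e.g. one second of
    video frames or audio features); the output is one combined slice
    per step, in arrival order.
    """
    slices = []
    for step in zip_longest(*streams.values(), fillvalue=None):
        slices.append([chunk for chunk in step if chunk is not None])
    return slices

# Two parallel streams arriving in real time, repacked slice by slice.
video = ["v0", "v1", "v2"]
audio = ["a0", "a1", "a2", "a3"]
print(tdm_multiplex({"video": video, "audio": audio}))
# [['v0', 'a0'], ['v1', 'a1'], ['v2', 'a2'], ['a3']]
```

Each output slice bundles whatever chunks arrived during that period, so a single sequential backbone can consume otherwise parallel streams.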
### Evaluation <!-- omit in toc -->
### Evaluation <!-- omit in toc -->
<div align="center">
<img src="./assets/radar.jpg", width=70%>
<img src="https://github.com/OpenBMB/MiniCPM-o/raw/main/assets/radar.jpg" width=90% />
</div>
<details>
<summary>Click to view detailed evaluation results of visual understanding capabilities.</summary>
#### Visual understanding results
**Image Understanding:**
**Image Understanding:**
<div align="center">
<table style="margin: 0px auto;">
@@ -364,13 +389,18 @@ MiniCPM-o 2.6 can be used easily in various ways: (1) [llama.cpp](https://git
</tbody>
</table>
</div>
* We evaluate these benchmarks using chain-of-thought prompting; for MME, chain-of-thought is only applied to the Cognition tasks.
+ Token Density: the number of pixels encoded into each visual token at maximum resolution, i.e., # pixels at maximum resolution / # visual tokens.
* We evaluate this benchmark using chain-of-thought prompting. Specifically, for MME, we used this technique only for the Cognition set.
Note: the Token Density of proprietary models is estimated from their API pricing scheme.
**Multi-image and Video Understanding:**
<sup>+</sup> Token Density: number of pixels encoded into each visual token at maximum resolution, i.e., # pixels at maximum resolution / # visual tokens.
Note: For proprietary models, we calculate token density based on the image encoding charging strategy defined in the official API documentation, which provides an upper-bound estimation.
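As a quick sanity check of the token-density figures quoted above (pure arithmetic on numbers taken from this README; the helper function is only illustrative):

```python
# Token density = pixels at maximum resolution / number of visual tokens.
def token_density(max_pixels: int, num_visual_tokens: int) -> float:
    return max_pixels / num_visual_tokens

# MiniCPM-o 2.6: a ~1.8M-pixel image (e.g. 1344x1344) is encoded into 640 visual tokens.
print(token_density(1344 * 1344, 640))  # ~2822 pixels per visual token
```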
**Multi-image and Video Understanding:**
<details>
<summary>click to view</summary>
<div align="center">
<table style="margin: 0px auto;">
@@ -467,15 +497,14 @@ MiniCPM-o 2.6 can be used easily in various ways: (1) [llama.cpp](https://git
</table>
</div>
* Evaluation results of the officially released model weights.
* We evaluate officially released checkpoints by ourselves.
</details>
<details>
<summary>Click to view detailed evaluation results of audio understanding and speech generation capabilities.</summary>
#### Audio understanding and speech conversation results.
**Audio Understanding:**
**Audio Understanding:**
<div align="center">
<table style="margin: 0px auto;">
@@ -597,9 +626,9 @@ MiniCPM-o 2.6 can be used easily in various ways: (1) [llama.cpp](https://git
</tbody>
</table>
</div>
* Evaluation results of the officially released model weights.<br><br>
* We evaluate officially released checkpoints by ourselves.<br><br>
**Speech Generation:**
**Speech Generation:**
<div align="center">
<table style="margin: 0px auto;">
@@ -718,9 +747,9 @@ MiniCPM-o 2.6 can be used easily in various ways: (1) [llama.cpp](https://git
</tbody>
</table>
</div>
All results are based on <a href="https://github.com/OpenBMB/UltraEval-Audio" target="_blank">AudioEvals</a>.<br><br>
All results are from AudioEvals, and the evaluation methods along with further details can be found in <a href="https://github.com/OpenBMB/UltraEval-Audio" target="_blank">UltraEval-Audio</a>.<br><br>
**Voice Cloning:**
**End-to-end Voice Cloning**
<div align="center">
<table style="margin: 0px auto;">
@@ -765,12 +794,10 @@ MiniCPM-o 2.6 can be used easily in various ways: (1) [llama.cpp](https://git
</table>
</div>
</details>
<details>
<summary>Click to view detailed evaluation results of multimodal live streaming capabilities.</summary>
#### Multimodal live streaming results.
**Multimodal Live Streaming:** StreamingBench scores
**Multimodal Live Streaming:** results on StreamingBench
<table style="margin: 0px auto;">
<thead>
@@ -897,12 +924,11 @@ MiniCPM-o 2.6 can be used easily in various ways: (1) [llama.cpp](https://git
</tbody>
</table>
</details>
### Examples <!-- omit in toc -->
### Examples <!-- omit in toc -->
The following examples were recorded with MiniCPM-o 2.6 deployed on an iPad Pro and in the web demo.
We deploy MiniCPM-o 2.6 on end devices. The demo video is a raw-speed recording on an iPad Pro and a web demo.
<div align="center">
<a href="https://youtu.be/JFJg9KZ_iZk"><img src="https://github.com/OpenBMB/MiniCPM-o/raw/main/assets/o-2dot6-demo-video-preview.png", width=70%></a>
@@ -912,15 +938,20 @@ MiniCPM-o 2.6 can be used easily in various ways: (1) [llama.cpp](https://git
<div style="display: flex; flex-direction: column; align-items: center;">
<img src="./assets/minicpmo2_6_math_intersect.png" alt="math" style="margin-bottom: 5px;">
<img src="./assets/minicpmo2_6_diagram_train_NN.png" alt="diagram" style="margin-bottom: 5px;">
<img src="./assets/minicpmo2_6_multi-image_bike.png" alt="bike" style="margin-bottom: 5px;">
<img src="https://github.com/OpenBMB/MiniCPM-o/raw/main/assets/minicpmo2_6/minicpmo2_6_math_intersect.png" alt="math" style="margin-bottom: 5px;">
<img src="https://github.com/OpenBMB/MiniCPM-o/raw/main/assets/minicpmo2_6/minicpmo2_6_diagram_train_NN.png" alt="diagram" style="margin-bottom: 5px;">
<img src="https://github.com/OpenBMB/MiniCPM-o/raw/main/assets/minicpmo2_6/minicpmo2_6_multi-image_bike.png" alt="bike" style="margin-bottom: 5px;">
</div>
## Online Demo
Click here to try the online demo of [MiniCPM-o 2.6](https://minicpm-omni-webdemo-us.modelbest.cn).
## Usage
Inference using Huggingface transformers on NVIDIA GPUs. Please ensure that **transformers==4.44.2** is installed, as other versions may have compatibility issues. We are investigating this issue. Requirements are tested on Python 3.10:
Inference using Huggingface transformers on NVIDIA GPUs. Please ensure that `transformers==4.44.2` is installed, as other versions may have compatibility issues. We are investigating this issue. Requirements are tested on Python 3.10:
```
Pillow==10.1.0
torch==2.3.1
@@ -1192,10 +1223,12 @@ print(res)
#### Speech Conversation as an AI Assistant
An enhanced feature of `MiniCPM-o-2.6` is acting as an AI assistant, but only with a limited choice of voices. In this mode, `MiniCPM-o-2.6` sounds **less human-like and more like a voice assistant**, and follows instructions more closely. For the demo, we suggest using `assistant_default_female_voice` or `assistant_male_voice`. Other voices may work, but are not as stable as the default voices.
An enhanced feature of `MiniCPM-o-2.6` is acting as an AI assistant, but only with a limited choice of voices. In this mode, `MiniCPM-o-2.6` sounds **less human-like and more like a voice assistant**, and follows instructions more closely. For the demo, we suggest using `assistant_female_voice`, `assistant_male_voice`, or `assistant_default_female_voice`. Other voices may work, but are not as stable as the default voices.
*Please note that `assistant_female_voice` and `assistant_male_voice` are more stable but sound robotic, while `assistant_default_female_voice` is more human-like but less stable; its voice often changes across multiple turns. We suggest trying the stable voices `assistant_female_voice` and `assistant_male_voice`.*
```python
ref_audio, _ = librosa.load('./assets/input_examples/assistant_default_female_voice.wav', sr=16000, mono=True) # or use `./assets/input_examples/assistant_male_voice.wav`
ref_audio, _ = librosa.load('./assets/input_examples/assistant_female_voice.wav', sr=16000, mono=True) # or use `./assets/input_examples/assistant_male_voice.wav`
sys_prompt = model.get_sys_prompt(ref_audio=ref_audio, mode='audio_assistant', language='en')
user_question = {'role': 'user', 'content': [librosa.load('xxx.wav', sr=16000, mono=True)[0]]} # load the user's audio question
@@ -1426,11 +1459,13 @@ print(answer)
Please see [GitHub](https://github.com/OpenBMB/MiniCPM-o) for more details about usage.
## llama.cpp <a id="llamacpp"></a>
## Inference with llama.cpp<a id="llamacpp"></a>
MiniCPM-o 2.6 (vision-only mode) can run with llama.cpp. See our fork of [llama.cpp](https://github.com/OpenBMB/llama.cpp/tree/minicpm-omni) and the [readme](https://github.com/OpenBMB/llama.cpp/blob/minicpm-omni/examples/llava/README-minicpmo2.6.md) for more details.
## Int4 quantized version
Int4 quantized version with lower GPU memory usage (9GB): [MiniCPM-o-2_6-int4](https://modelscope.cn/models/OpenBMB/MiniCPM-o-2_6-int4).
## Int4 quantized version
Download the int4 quantized version for lower GPU memory usage (7GB): [MiniCPM-o-2_6-int4](https://huggingface.co/openbmb/MiniCPM-o-2_6-int4).
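A minimal loading sketch for the quantized checkpoint, assuming it loads through the same `AutoModel` remote-code path as the full model (argument details may differ; see the linked model card for the authoritative snippet):

```python
from transformers import AutoModel, AutoTokenizer

# Assumption: the int4 repository loads like the full model via trust_remote_code.
model_id = "openbmb/MiniCPM-o-2_6-int4"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model.eval()
```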
## License

Several binary files changed (contents not shown), including:

BIN assets/input_examples/assistant_female_voice.wav (Stored with Git LFS, normal file)


@@ -42,7 +42,7 @@ from transformers import AutoProcessor
from transformers import BertTokenizerFast
from transformers import LlamaConfig
from transformers import LlamaModel
from transformers import LogitsWarper
from transformers import LogitsProcessor
from transformers import PreTrainedModel
from transformers import Qwen2ForCausalLM
from transformers import Qwen2PreTrainedModel
@@ -184,7 +184,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
args=(),
init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
)
vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32)
vocos = Vocos(feature_extractor, backbone, head).to(self.device).eval().to(torch.float32)
vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
return vocos
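This hunk (and several below) replaces hard-coded `"cuda"` targets with `self.device`, so the checkpoint also runs on CPU or other accelerators. A standalone sketch of the same pattern, with illustrative names not taken from this repository:

```python
import torch
from torch import nn

class DeviceAwareModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    @property
    def device(self) -> torch.device:
        # Derive the target device from the module's own parameters
        # instead of assuming "cuda".
        return next(self.parameters()).device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x.to(self.device))

m = DeviceAwareModule()            # works unchanged on CPU...
# m = DeviceAwareModule().cuda()   # ...or on GPU, without editing the code
print(m(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```

Deriving the device from the module's own parameters keeps tensor placement correct no matter where the model has been moved.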
@@ -584,6 +584,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
if self.config.chunk_input:
for i in range(bs):
if not audio_embeddings[i]:
continue
audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
device=input_embeddings.device, dtype=input_embeddings.dtype
)
@@ -1090,7 +1092,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
else:
logger.error("Invalid content type:", c)
cur_contents = "".join(cur_msgs) if omni_input else "\n".join(omni_input)
cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
if not self.is_first and self.new_user_msg and msg["role"] == "user": # new user add im_start
if self.llm_generated:
if self.llm_generate_completed:
@@ -1205,7 +1207,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()
input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(self.device)
spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
@@ -1309,7 +1311,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[
"input_ids"
].cuda()
].to(self.device)
return tts_input_ids
def _build_streaming_mask(self, tts_tokens_len):
@@ -1340,7 +1342,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
gen_text = text.split("<|tts_eos|>")[0]
tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long)
tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long)
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
logits_warpers, logits_processors = gen_logits(
@@ -1637,7 +1639,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
tts_input_ids = self.tts_processor.text_tokenizer(
tts_text, return_tensors="pt", add_special_tokens=False
)["input_ids"].cuda()
)["input_ids"].to(self.device)
text_input_ids = tts_input_ids[:, begin:end]
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
@@ -1746,7 +1748,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
if end > begin:
tts_input_ids = self.tts_processor.text_tokenizer(
tts_text, return_tensors="pt", add_special_tokens=False
)["input_ids"].cuda()
)["input_ids"].to(self.device)
text_input_ids = tts_input_ids[:, begin:end]
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
@@ -2917,7 +2919,7 @@ class ConditionalChatTTS(PreTrainedModel):
force_no_stop=False,
min_new_token=10,
max_new_token=50,
logits_warpers: List[LogitsWarper] = [],
logits_warpers: List[LogitsProcessor] = [],
logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
show_tqdm=False,
):
@@ -2935,7 +2937,7 @@ class ConditionalChatTTS(PreTrainedModel):
eos_token (Union[int, torch.Tensor]): End of sequence token.
streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
logits_warpers (List[LogitsProcessor], optional): List of logits processors. Defaults to [].
logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
show_tqdm (bool, optional): Whether to show progress bar. Defaults to True.
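The `LogitsWarper` to `LogitsProcessor` rename tracks newer `transformers` releases, where the separate `LogitsWarper` class is deprecated and warpers are expressed through the `LogitsProcessor` interface instead. A minimal temperature warper written against that interface, as a sketch rather than one of the processors this repository actually uses:

```python
import torch
from transformers import LogitsProcessor

class TemperatureProcessor(LogitsProcessor):
    """Scale logits by 1/temperature, the classic 'warper' behaviour."""

    def __init__(self, temperature: float = 0.7):
        self.temperature = temperature

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores / self.temperature

# Usage: pass it wherever a list of logits processors/warpers is expected.
processor = TemperatureProcessor(0.7)
scores = torch.randn(1, 32000)
warped = processor(torch.zeros(1, 1, dtype=torch.long), scores)
```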


@@ -16,7 +16,7 @@
import warnings
from functools import partial
from typing import Optional
from typing import Tuple
from typing import Tuple, List
import numpy as np
import torch

File diff suppressed because it is too large.