Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2e4c601d1d | |||
| 525ebc65a3 | |||
| 564b98b2d5 | |||
| 19b6bf7590 | |||
| 86bff0fc82 | |||
| dd7b78f2c0 | |||
| 29577d57f8 | |||
| 4509becfde | |||
| cd79a647fa | |||
| 96d605b9de | |||
| a9b03a768c | |||
| 77f847fcba | |||
| d3cc88722c | |||
| ec2acec8a1 | |||
| 13605c5a0e | |||
| afa63e6195 | |||
| eae0a29908 | |||
| 35895982d7 | |||
| f7f1b78c4d | |||
| 38d61cdf03 | |||
| 1565e83efe | |||
| 61b36d4e56 | |||
| b1584aec7c | |||
| 4457617953 | |||
| 5510503182 | |||
| fb46aad9a5 | |||
| e4e049624c | |||
| abf01b9bf3 | |||
| 4f4a5b9f6c | |||
| 79c0cf68dd | |||
| 75cfa3e9b8 | |||
| 5611bd08a0 | |||
| 66205135fc | |||
| 364eff6840 | |||
| 6d10932b09 | |||
| 68af4fe502 | |||
| ee3649c1b3 | |||
| 82d77d445c | |||
| 8f95d13073 | |||
| df38f0a167 | |||
| 9adfaf6996 | |||
| 46cfce0c97 | |||
| da700f264e | |||
| 9da570d409 | |||
| 9374524c47 | |||
| ec6d30e996 |
+4
-1
@@ -1,4 +1,7 @@
|
||||
launch.json
|
||||
__pycache__
|
||||
voxcpm.egg-info
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
./pretrained_models/
|
||||
app_local.py
|
||||
models/
|
||||
@@ -1,5 +1,9 @@
|
||||
<h2 align="center">VoxCPM2: Tokenizer-Free TTS for Multilingual Speech Generation, Creative Voice Design, and True-to-Life Cloning</h2>
|
||||
|
||||
<p align="center">
|
||||
<b>English</b> | <a href="./README_zh.md">中文</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/OpenBMB/VoxCPM/"><img src="https://img.shields.io/badge/Project%20Page-GitHub-blue" alt="Project Page"></a>
|
||||
<a href="https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo"><img src="https://img.shields.io/badge/Live%20Playground-Demo-orange" alt="Live Playground"></a>
|
||||
@@ -42,16 +46,16 @@ VoxCPM is a **tokenizer-free** Text-to-Speech system that directly generates con
|
||||
- 🎙️ **Ultimate Cloning** — Reproduce every vocal nuance: provide both reference audio and its transcript, and the model continues seamlessly from the reference, faithfully preserving every vocal detail — timbre, rhythm, emotion, and style (same as VoxCPM1.5)
|
||||
- 🔊 **48kHz High-Quality Audio** — Accepts 16kHz reference audio and directly outputs 48kHz studio-quality audio via AudioVAE V2's asymmetric encode/decode design, with built-in super-resolution — no external upsampler needed
|
||||
- 🧠 **Context-Aware Synthesis** — Automatically infers appropriate prosody and expressiveness from text content
|
||||
- ⚡ **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-VLLM](https://github.com/a710128/nanovllm-voxcpm)
|
||||
- ⚡ **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) or [vLLM-Omni](https://github.com/vllm-project/vllm-omni) — official vLLM omni-modal serving for VoxCPM2 with PagedAttention and an OpenAI-compatible API
|
||||
- 📜 **Fully Open-Source & Commercial-Ready** — Weights and code released under the [Apache-2.0](LICENSE) license, free for commercial use
|
||||
|
||||
<details>
|
||||
|
||||
<summary><b>🌍 Supported Languages (30)</b></summary>
|
||||
<br>
|
||||
Arabic, Burmese, Chinese, Danish, Dutch, English, Finnish, French, German, Greek, Hebrew, Hindi, Indonesian, Italian, Japanese, Khmer, Korean, Lao, Malay, Norwegian, Polish, Portuguese, Russian, Spanish, Swahili, Swedish, Tagalog, Thai, Turkish, Vietnamese
|
||||
|
||||
Chinese Dialect: 四川话, 粤语, 吴语, 东北话, 河南话, 陕西话, 山东话, 天津话, 闽南话
|
||||
</details>
|
||||
|
||||
|
||||
### News
|
||||
|
||||
@@ -88,7 +92,7 @@ Chinese Dialect: 四川话, 粤语, 吴语, 东北话, 河南话, 陕西话, 山
|
||||
pip install voxcpm
|
||||
```
|
||||
|
||||
> **Requirements:** Python ≥ 3.10, PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
|
||||
> **Requirements:** Python ≥ 3.10 (<3.13), PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
|
||||
|
||||
### Python API
|
||||
|
||||
@@ -99,7 +103,7 @@ from voxcpm import VoxCPM
|
||||
import soundfile as sf
|
||||
|
||||
model = VoxCPM.from_pretrained(
|
||||
"openbmb/VoxCPM2"
|
||||
"openbmb/VoxCPM2",
|
||||
load_denoiser=False,
|
||||
)
|
||||
|
||||
@@ -112,6 +116,28 @@ sf.write("demo.wav", wav, model.tts_model.sample_rate)
|
||||
print("saved: demo.wav")
|
||||
```
|
||||
|
||||
If you prefer downloading from ModelScope first, you can use:
|
||||
|
||||
```bash
|
||||
pip install modelscope
|
||||
```
|
||||
|
||||
```python
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # specify the local directory to save the model
|
||||
|
||||
from voxcpm import VoxCPM
|
||||
import soundfile as sf
|
||||
model = VoxCPM.from_pretrained("./pretrained_models/VoxCPM2", load_denoiser=False)
|
||||
|
||||
wav = model.generate(
|
||||
text="VoxCPM2 is the current recommended release for realistic multilingual speech synthesis.",
|
||||
cfg_value=2.0,
|
||||
inference_timesteps=10,
|
||||
)
|
||||
sf.write("demo.wav", wav, model.tts_model.sample_rate)
|
||||
```
|
||||
|
||||
#### 🎨 Voice Design
|
||||
|
||||
Create a voice from a natural-language description — no reference audio needed. **Format:** put the description in parentheses at the start of `text` (e.g. `"(your voice description)The text to synthesize."`):
|
||||
@@ -132,13 +158,13 @@ Upload a reference audio. The model clones the timbre, and you can still use con
|
||||
```python
|
||||
wav = model.generate(
|
||||
text="This is a cloned voice generated by VoxCPM2.",
|
||||
reference_wav_path="speaker.wav",
|
||||
reference_wav_path="path/to/voice.wav",
|
||||
)
|
||||
sf.write("clone.wav", wav, model.tts_model.sample_rate)
|
||||
|
||||
wav = model.generate(
|
||||
text="(slightly faster, cheerful tone)This is a cloned voice with style control.",
|
||||
reference_wav_path="speaker.wav",
|
||||
reference_wav_path="path/to/voice.wav",
|
||||
cfg_value=2.0,
|
||||
inference_timesteps=10,
|
||||
)
|
||||
@@ -152,9 +178,9 @@ Provide both the reference audio and its exact transcript for audio-continuation
|
||||
```python
|
||||
wav = model.generate(
|
||||
text="This is an ultimate cloning demonstration using VoxCPM2.",
|
||||
prompt_wav_path="speaker_reference.wav",
|
||||
prompt_wav_path="path/to/voice.wav",
|
||||
prompt_text="The transcript of the reference audio.",
|
||||
reference_wav_path="speaker_reference.wav",
|
||||
reference_wav_path="path/to/voice.wav", # optional, for better similarity
|
||||
)
|
||||
sf.write("hifi_clone.wav", wav, model.tts_model.sample_rate)
|
||||
```
|
||||
@@ -200,6 +226,7 @@ voxcpm clone \
|
||||
--text "This is a voice cloning demo." \
|
||||
--prompt-audio path/to/voice.wav \
|
||||
--prompt-text "reference transcript" \
|
||||
--reference-audio path/to/voice.wav \
|
||||
--output out.wav # --reference-audio is optional, for better similarity
|
||||
|
||||
# Batch processing
|
||||
@@ -212,7 +239,7 @@ voxcpm --help
|
||||
### Web Demo
|
||||
|
||||
```bash
|
||||
python app.py # then open http://localhost:7860
|
||||
python app.py --port 8808 # then open in browser: http://localhost:8808
|
||||
```
|
||||
|
||||
### 🚢 Production Deployment (Nano-vLLM)
|
||||
@@ -235,6 +262,32 @@ server.stop()
|
||||
|
||||
> **RTF as low as ~0.13 on NVIDIA RTX 4090** (vs ~0.3 with the standard PyTorch implementation), with support for batched concurrent requests and a FastAPI HTTP server. See the [Nano-vLLM-VoxCPM repo](https://github.com/a710128/nanovllm-voxcpm) for deployment details.
|
||||
|
||||
### 🏭 Production Serving (vLLM-Omni)
|
||||
|
||||
For production multi-tenant deployments, use [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — the official vLLM project's omni-modal extension with native **VoxCPM2** support. It provides a PagedAttention KV cache, continuous batching, and a drop-in **OpenAI-compatible** `/v1/audio/speech` endpoint.
|
||||
|
||||
```bash
|
||||
# Install from source (latest main — vllm-omni is rapidly evolving)
|
||||
uv pip install vllm==0.19.0 --torch-backend=auto
|
||||
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
See the [vLLM-Omni installation guide](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/) for other platforms (ROCm, XPU, MUSA, NPU) and Docker images.
|
||||
|
||||
```bash
|
||||
# Launch an OpenAI-compatible TTS server (--omni enables omni-modal serving)
|
||||
vllm serve openbmb/VoxCPM2 --omni --port 8000
|
||||
|
||||
# Call it from any OpenAI client
|
||||
curl http://localhost:8000/v1/audio/speech \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model":"openbmb/VoxCPM2","input":"Hello from VoxCPM2 on vLLM-Omni!","voice":"default"}' \
|
||||
--output out.wav
|
||||
```
|
||||
|
||||
> Built on the upstream vLLM scheduler, with batched concurrent requests, streaming chunk delivery, and multi-GPU deployment out of the box. See the [VoxCPM2 example](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2) for full deployment recipes.
|
||||
|
||||
> **Full parameter reference, multi-scenario examples, and voice cloning tips →** [Quick Start Guide](https://voxcpm.readthedocs.io/en/latest/quickstart.html) | [Usage Guide](https://voxcpm.readthedocs.io/en/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/en/latest/cookbook.html)
|
||||
|
||||
---
|
||||
@@ -388,10 +441,54 @@ VoxCPM2 achieves state-of-the-art or comparable results on public zero-shot and
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
### Internal 30-Language ASR Benchmark
|
||||
|
||||
We additionally run an internal multilingual intelligibility benchmark with **30 languages × 500 samples**. ASR transcription is evaluated via **Gemini 3.1 Flash Lite API**.
|
||||
|
||||
<details>
|
||||
<summary><b>Internal 30-Language ASR Benchmark (click to expand)</b></summary>
|
||||
|
||||
| Language | Metric | VoxCPM2 | Fish S2-Pro |
|
||||
|---|---:|---:|---:|
|
||||
| ar (Arabic) | CER | 1.23% | 0.30% |
|
||||
| da (Danish) | WER | 2.70% | 3.52% |
|
||||
| de (German) | WER | 0.96% | 0.64% |
|
||||
| el (Greek) | WER | 3.17% | 4.61% |
|
||||
| en (English) | WER | 0.42% | 1.03% |
|
||||
| es (Spanish) | WER | 1.33% | 0.64% |
|
||||
| fi (Finnish) | WER | 2.24% | 2.80% |
|
||||
| fr (French) | WER | 2.16% | 2.34% |
|
||||
| he (Hebrew) | CER | 2.98% | 15.27% |
|
||||
| hi (Hindi) | CER | 0.79% | 0.91% |
|
||||
| id (Indonesian) | WER | 1.36% | 1.68% |
|
||||
| it (Italian) | WER | 1.65% | 1.08% |
|
||||
| ja (Japanese) | CER | 2.40% | 1.82% |
|
||||
| km (Khmer) | CER | 2.05% | 75.15% |
|
||||
| ko (Korean) | CER | 0.95% | 0.29% |
|
||||
| lo (Lao) | CER | 1.90% | 87.40% |
|
||||
| ms (Malay) | WER | 1.75% | 1.41% |
|
||||
| my (Burmese) | CER | 1.42% | 85.27% |
|
||||
| nl (Dutch) | WER | 1.25% | 1.68% |
|
||||
| no (Norwegian) | WER | 2.49% | 3.76% |
|
||||
| pl (Polish) | WER | 1.90% | 1.65% |
|
||||
| pt (Portuguese) | WER | 1.48% | 1.49% |
|
||||
| ru (Russian) | WER | 0.90% | 0.86% |
|
||||
| sv (Swedish) | WER | 2.22% | 2.63% |
|
||||
| sw (Swahili) | CER | 1.07% | 2.02% |
|
||||
| th (Thai) | CER | 0.94% | 1.92% |
|
||||
| tl (Tagalog) | WER | 2.63% | 4.00% |
|
||||
| tr (Turkish) | WER | 1.65% | 1.65% |
|
||||
| vi (Vietnamese) | WER | 1.56% | 5.56% |
|
||||
| zh (Chinese) | CER | 0.92% | 1.02% |
|
||||
| Average (30 languages) | | **1.68%** | - |
|
||||
|
||||
</details>
|
||||
|
||||
### InstructTTSEval
|
||||
|
||||
<details>
|
||||
<summary><b>Instruction-Guided Voice Design Results</b></summary>
|
||||
<summary><b>Instruction-Guided Voice Design Results (click to expand)</b></summary>
|
||||
|
||||
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
||||
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
||||
@@ -457,11 +554,13 @@ Full documentation: **[voxcpm.readthedocs.io](https://voxcpm.readthedocs.io/en/l
|
||||
| Project | Description |
|
||||
|---|---|
|
||||
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | High-throughput and Fast GPU serving |
|
||||
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | Official vLLM omni-modal serving for VoxCPM2 — PagedAttention, OpenAI-compatible API |
|
||||
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF: CPU, CUDA, Vulkan inference |
|
||||
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX export for CPU inference |
|
||||
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine backend |
|
||||
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust re-implementation |
|
||||
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI node-based workflows |
|
||||
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | Feature-complete ComfyUI workflow for VoxCPM 2 with multi-speaker generation, LoRA, and auto-ASR |
|
||||
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS extension |
|
||||
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | Browser-based TTS extension |
|
||||
|
||||
|
||||
+618
@@ -0,0 +1,618 @@
|
||||
<h2 align="center">VoxCPM2:基于连续表征的多语言语音合成、创意音色设计与高保真声音克隆</h2>
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md">English</a> | <b>中文</b>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/OpenBMB/VoxCPM/"><img src="https://img.shields.io/badge/Project%20Page-GitHub-blue" alt="Project Page"></a>
|
||||
<a href="https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo"><img src="https://img.shields.io/badge/Live%20Playground-Demo-orange" alt="Live Playground"></a>
|
||||
<a href="https://voxcpm.readthedocs.io/zh-cn/latest/"><img src="https://img.shields.io/badge/Docs-ReadTheDocs-8CA1AF" alt="Documentation"></a>
|
||||
<a href="https://huggingface.co/openbmb/VoxCPM2"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-VoxCPM2-yellow" alt="Hugging Face"></a>
|
||||
<a href="https://modelscope.cn/models/OpenBMB/VoxCPM2"><img src="https://img.shields.io/badge/ModelScope-VoxCPM2-purple" alt="ModelScope"></a>
|
||||
<a href="https://openbmb.github.io/voxcpm2-demopage/"><img src="https://img.shields.io/badge/DemoPage-Audio%20Samples-red" alt="Demo Page"></a>
|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
<img src="assets/voxcpm_logo.png" alt="VoxCPM Logo" width="35%">
|
||||
<br><br>
|
||||
<a href="https://trendshift.io/repositories/17704" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17704" alt="OpenBMB%2FVoxCPM | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<p align="center">
|
||||
👋 欢迎加入社区,参与讨论与交流!
|
||||
<br>
|
||||
<a href="./assets/feishu-group.png" style="display:inline-block;vertical-align:middle; margin-left: 10px;">
|
||||
<img src="./assets/feishu-logo.png" width="16" height="16" style="vertical-align:middle;"> 飞书群
|
||||
</a>
|
||||
|
|
||||
<a href="https://discord.gg/KZUx7tVNwz" style="display:inline-block;vertical-align:middle;">
|
||||
<img src="./assets/discord-logo.png" width="16" height="16" style="vertical-align:middle;"> Discord
|
||||
</a>
|
||||
</p>
|
||||
|
||||
VoxCPM 是一个**无离散音频分词器**(Tokenizer-Free)的语音合成系统,通过端到端的**扩散自回归架构**直接生成连续语音表征,绕过对音频的离散编码步骤,实现高度自然且富有表现力的语音合成。
|
||||
|
||||
**VoxCPM2** 是最新的版本 — 基于 [MiniCPM-4](https://github.com/OpenBMB/MiniCPM) 基座构建,总计 **20亿** 参数,在超过 **200万小时** 的多语种音频数据上训练,支持 **30种全球语言+9种中文方言**、**音色设计**、**可控声音克隆**,原生输出 **48kHz** 高质量音频。
|
||||
|
||||
### ✨ 核心特性
|
||||
|
||||
- 🌍 **30种语言语音合成** — 直接输入原始文本即可合成(支持语言详见下文),无需额外语言标签
|
||||
- 🎨 **音色设计** — 用自然语言描述(性别、年龄、音色、情绪、语速……)凭空创建全新音色,无需参考音频
|
||||
- 🎛️ **可控声音克隆** — 从参考音频片段克隆任意声音,可叠加风格指令控制情绪、语速和表现力,同时保持原始音色
|
||||
- 🎙️ **极致克隆** — 提供参考音频及其文本内容,模型接着参考音频进行无缝续写,从而精准还原声音细节特征(与 VoxCPM1.5 一致)
|
||||
- 🔊 **48kHz 高质量音频** — 输入 16kHz 参考音频,通过 AudioVAE V2 的非对称编解码设计直接输出 48kHz 高质量音频,内置超分能力
|
||||
- 🧠 **语境感知合成** — 根据文本内容自动推断合适的韵律和表现力
|
||||
- ⚡ **实时流式合成** — 在 NVIDIA RTX 4090 上 RTF 低至 ~0.3,通过 [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) 或 [vLLM-Omni](https://github.com/vllm-project/vllm-omni)(官方 vLLM 全模态服务,原生支持 VoxCPM2,提供 PagedAttention 与 OpenAI 兼容 API)加速后可达 ~0.13
|
||||
- 📜 **完全开源,商用就绪** — 权重和代码基于 [Apache-2.0](LICENSE) 协议发布,免费商用
|
||||
|
||||
<details>
<summary><b>🌍 支持的语言(30种)</b></summary>
|
||||
<br>
|
||||
阿拉伯语、缅甸语、中文、丹麦语、荷兰语、英语、芬兰语、法语、德语、希腊语、希伯来语、印地语、印尼语、意大利语、日语、高棉语、韩语、老挝语、马来语、挪威语、波兰语、葡萄牙语、俄语、西班牙语、斯瓦希里语、瑞典语、菲律宾语、泰语、土耳其语、越南语
|
||||
|
||||
中国方言:四川话、粤语、吴语、东北话、河南话、陕西话、山东话、天津话、闽南话
</details>
|
||||
|
||||
|
||||
### 最新动态
|
||||
|
||||
* **[2026.04]** 🔥 发布 **VoxCPM2** — 20亿参数,30种语言,音色设计与可控声音克隆,48kHz 音频输出 | [使用文档](https://voxcpm.readthedocs.io/zh-cn/latest/) | [在线体验](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) | [官网体验](https://voxcpm.modelbest.cn/) (适用国内访问)
|
||||
* **[2025.12]** 🎉 开源 **VoxCPM1.5** [模型权重](https://huggingface.co/openbmb/VoxCPM1.5),支持 SFT 和 LoRA 微调。(**🏆 GitHub Trending #1**)
|
||||
* **[2025.09]** 🔥 发布 VoxCPM [技术报告](https://arxiv.org/abs/2509.24650)。
|
||||
* **[2025.09]** 🎉 开源 **VoxCPM-0.5B** [模型权重](https://huggingface.co/openbmb/VoxCPM-0.5B) (**🏆 HuggingFace Trending #1**)
|
||||
|
||||
---
|
||||
|
||||
## 目录
|
||||
|
||||
- [快速开始](#-快速开始)
|
||||
- [安装](#安装)
|
||||
- [Python API](#python-api)
|
||||
- [命令行使用](#命令行使用)
|
||||
- [Web Demo](#web-demo)
|
||||
- [生产部署](#-生产部署nano-vllm)
|
||||
- [模型与版本](#-模型与版本)
|
||||
- [性能评测](#-性能评测)
|
||||
- [微调](#%EF%B8%8F-微调)
|
||||
- [文档](#-文档)
|
||||
- [生态与社区](#-生态与社区)
|
||||
- [风险与局限性](#%EF%B8%8F-风险与局限性)
|
||||
- [引用](#-引用)
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 安装
|
||||
|
||||
```sh
|
||||
pip install voxcpm
|
||||
```
|
||||
|
||||
> **环境要求:** Python ≥ 3.10 (<3.13),PyTorch ≥ 2.5.0,CUDA ≥ 12.0。详见 [快速开始文档](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html)。
|
||||
|
||||
### Python API
|
||||
|
||||
#### 🗣️ 文本转语音
|
||||
|
||||
```python
|
||||
from voxcpm import VoxCPM
|
||||
import soundfile as sf
|
||||
|
||||
model = VoxCPM.from_pretrained(
|
||||
"openbmb/VoxCPM2",
|
||||
load_denoiser=False,
|
||||
)
|
||||
|
||||
wav = model.generate(
|
||||
text="VoxCPM2 是目前推荐使用的多语言语音合成版本。",
|
||||
cfg_value=2.0,
|
||||
inference_timesteps=10,
|
||||
)
|
||||
sf.write("demo.wav", wav, model.tts_model.sample_rate)
|
||||
print("已保存: demo.wav")
|
||||
```
|
||||
|
||||
如果你希望先从 ModelScope 下载模型到本地(适用于国内网络访问),可以使用:
|
||||
|
||||
```bash
|
||||
pip install modelscope
|
||||
```
|
||||
|
||||
```python
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # 指定模型保存的本地路径
|
||||
|
||||
from voxcpm import VoxCPM
|
||||
import soundfile as sf
|
||||
model = VoxCPM.from_pretrained('./pretrained_models/VoxCPM2', load_denoiser=False)
|
||||
|
||||
wav = model.generate(
|
||||
text="VoxCPM2 是目前推荐使用的多语言语音合成版本。",
|
||||
cfg_value=2.0,
|
||||
inference_timesteps=10,
|
||||
)
|
||||
sf.write("demo.wav", wav, model.tts_model.sample_rate)
|
||||
```
|
||||
|
||||
#### 🎨 音色设计
|
||||
|
||||
用自然语言描述创建全新音色,无需参考音频。**格式:** 在 `text` 开头用括号写入音色描述(如 `"(音色描述)要合成的文本。"`):
|
||||
|
||||
```python
|
||||
wav = model.generate(
|
||||
text="(年轻女性,声音温柔甜美)你好,欢迎使用VoxCPM2!",
|
||||
cfg_value=2.0,
|
||||
inference_timesteps=10,
|
||||
)
|
||||
sf.write("voice_design.wav", wav, model.tts_model.sample_rate)
|
||||
```
|
||||
|
||||
#### 🎛️ 可控声音克隆
|
||||
|
||||
上传一段参考音频,模型克隆其音色,同时可以使用控制指令调节语速、情绪或风格。
|
||||
|
||||
```python
|
||||
wav = model.generate(
|
||||
text="这是VoxCPM2生成的克隆语音。",
|
||||
reference_wav_path="path/to/voice.wav",
|
||||
)
|
||||
sf.write("clone.wav", wav, model.tts_model.sample_rate)
|
||||
|
||||
wav = model.generate(
|
||||
text="(稍快一点,欢快的语气)这是带风格控制的克隆语音。",
|
||||
reference_wav_path="path/to/voice.wav",
|
||||
cfg_value=2.0,
|
||||
inference_timesteps=10,
|
||||
)
|
||||
sf.write("controllable_clone.wav", wav, model.tts_model.sample_rate)
|
||||
```
|
||||
|
||||
#### 🎙️ 极致克隆
|
||||
|
||||
提供参考音频及其精确文本转录,实现基于音频续写的高保真克隆。为获得最高克隆相似度,可将同一音频同时传给 `reference_wav_path` 和 `prompt_wav_path`:
|
||||
|
||||
```python
|
||||
wav = model.generate(
|
||||
text="这是使用VoxCPM2的极致克隆演示。",
|
||||
prompt_wav_path="path/to/voice.wav",
|
||||
prompt_text="参考音频的文本转录。",
|
||||
reference_wav_path="path/to/voice.wav", # 可选,提升相似度
|
||||
)
|
||||
sf.write("hifi_clone.wav", wav, model.tts_model.sample_rate)
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><b>🔄 流式 API</b></summary>
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
|
||||
chunks = []
|
||||
for chunk in model.generate_streaming(
|
||||
text="使用VoxCPM进行流式语音合成非常简单!",
|
||||
):
|
||||
chunks.append(chunk)
|
||||
wav = np.concatenate(chunks)
|
||||
sf.write("streaming.wav", wav, model.tts_model.sample_rate)
|
||||
```
|
||||
</details>
|
||||
|
||||
### 命令行使用
|
||||
|
||||
```bash
|
||||
# 音色设计(无需参考音频)
|
||||
voxcpm design \
|
||||
--text "VoxCPM2带来全新语音合成体验。" \
|
||||
--output out.wav
|
||||
|
||||
# 可控声音克隆(带风格控制)
|
||||
voxcpm design \
|
||||
--text "VoxCPM2带来全新语音合成体验。" \
|
||||
--control "年轻女声,温暖温柔,略带微笑" \
|
||||
--output out.wav
|
||||
|
||||
# 声音克隆(参考音频)
|
||||
voxcpm clone \
|
||||
--text "这是一个声音克隆的演示。" \
|
||||
--reference-audio path/to/voice.wav \
|
||||
--output out.wav
|
||||
|
||||
# 极致克隆(提示音频 + 转录文本)
|
||||
voxcpm clone \
|
||||
--text "这是一个声音克隆的演示。" \
|
||||
--prompt-audio path/to/voice.wav \
|
||||
--prompt-text "参考音频转录文本" \
|
||||
--reference-audio path/to/voice.wav \
|
||||
--output out.wav
|
||||
|
||||
# 批量处理
|
||||
voxcpm batch --input examples/input.txt --output-dir outs
|
||||
|
||||
# 帮助
|
||||
voxcpm --help
|
||||
```
|
||||
|
||||
### Web Demo
|
||||
|
||||
```bash
|
||||
python app.py --port 8808 # 然后在浏览器打开 http://localhost:8808
|
||||
```
|
||||
|
||||
### 🚢 生产部署(Nano-vLLM)
|
||||
|
||||
如需高吞吐量部署,使用 [**Nano-vLLM-VoxCPM**](https://github.com/a710128/nanovllm-voxcpm) — 基于 Nano-vLLM 构建的专用推理引擎,支持并发请求和异步 API。
|
||||
|
||||
```bash
|
||||
pip install nano-vllm-voxcpm
|
||||
```
|
||||
|
||||
```python
|
||||
from nanovllm_voxcpm import VoxCPM
|
||||
import numpy as np, soundfile as sf
|
||||
|
||||
server = VoxCPM.from_pretrained(model="/path/to/VoxCPM", devices=[0])
|
||||
chunks = list(server.generate(target_text="你好,我来自VoxCPM!"))
|
||||
sf.write("out.wav", np.concatenate(chunks), 48000)
|
||||
server.stop()
|
||||
```
|
||||
|
||||
> **在 NVIDIA RTX 4090 上 RTF 低至 ~0.13**(标准 PyTorch 实现约 ~0.3),支持批量并发请求和 FastAPI HTTP 服务。详见 [Nano-vLLM-VoxCPM 仓库](https://github.com/a710128/nanovllm-voxcpm)。
|
||||
|
||||
### 🏭 生产环境部署(vLLM-Omni)
|
||||
|
||||
如需生产级多租户部署,使用 [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — 官方 vLLM 项目的全模态扩展,原生支持 **VoxCPM2**。具备 PagedAttention KV 缓存、连续批处理,以及与 OpenAI 完全兼容的 `/v1/audio/speech` 接口。
|
||||
|
||||
```bash
|
||||
# 从源码安装(最新 main 分支 —— vllm-omni 正在快速迭代)
|
||||
uv pip install vllm==0.19.0 --torch-backend=auto
|
||||
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
其他平台(ROCm、XPU、MUSA、NPU)与 Docker 镜像请参考 [vLLM-Omni 安装文档](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/)。
|
||||
|
||||
```bash
|
||||
# 启动 OpenAI 兼容的 TTS 服务(--omni 启用全模态服务)
|
||||
vllm serve openbmb/VoxCPM2 --omni --port 8000
|
||||
|
||||
# 任意 OpenAI 客户端均可调用
|
||||
curl http://localhost:8000/v1/audio/speech \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model":"openbmb/VoxCPM2","input":"你好,欢迎使用 VoxCPM2 on vLLM-Omni!","voice":"default"}' \
|
||||
--output out.wav
|
||||
```
|
||||
|
||||
> 基于上游 vLLM 调度器构建,开箱即用支持批量并发、流式分块输出和多 GPU 部署。完整示例见 [VoxCPM2 部署样例](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2)。
|
||||
|
||||
> **完整参数说明、多场景示例与声音克隆技巧 →** [快速开始指南](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html) | [使用指南](https://voxcpm.readthedocs.io/zh-cn/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/zh-cn/latest/cookbook.html)
|
||||
|
||||
---
|
||||
|
||||
## 📦 模型与版本
|
||||
|
||||
| | **VoxCPM2** | **VoxCPM1.5** | **VoxCPM-0.5B** |
|
||||
|---|:---:|:---:|:---:|
|
||||
| **状态** | 🟢 最新版本 | 稳定版 | 旧版 |
|
||||
| **主模型参数量** | 2B | 0.6B | 0.5B |
|
||||
| **音频采样率** | 48kHz | 44.1kHz | 16kHz |
|
||||
| **LM处理码率** | 6.25Hz | 6.25Hz | 12.5Hz |
|
||||
| **语言支持数量** | 30 | 2(中文、英文) | 2(中文、英文) |
|
||||
| **克隆模式** | 隔离参考音频(无需文本) & 音频续写 | 仅音频续写 | 仅音频续写 |
|
||||
| **音色设计** | ✅ | — | — |
|
||||
| **可控声音克隆** | ✅ | — | — |
|
||||
| **SFT / LoRA** | ✅ | ✅ | ✅ |
|
||||
| **RTF (RTX 4090)** | ~0.30 | ~0.15 | ~0.17 |
|
||||
| **RTF Nano-vLLM (RTX 4090)** | ~0.13 | ~0.08 | ~0.10 |
|
||||
| **显存占用** | ~8 GB | ~6 GB | ~5 GB |
|
||||
| **模型权重** | [🤗 HF](https://huggingface.co/openbmb/VoxCPM2) / [MS](https://modelscope.cn/models/OpenBMB/VoxCPM2) | [🤗 HF](https://huggingface.co/openbmb/VoxCPM1.5) / [MS](https://modelscope.cn/models/OpenBMB/VoxCPM1.5) | [🤗 HF](https://huggingface.co/openbmb/VoxCPM-0.5B) / [MS](https://modelscope.cn/models/OpenBMB/VoxCPM-0.5B) |
|
||||
| **技术报告** | 即将发布 | — | [arXiv](https://arxiv.org/abs/2509.24650) [ICLR 2026](https://openreview.net/forum?id=h5KLpGoqzC) |
|
||||
| **Demo 页面** | [音频示例](https://openbmb.github.io/voxcpm2-demopage) | — | [音频示例](https://openbmb.github.io/VoxCPM-demopage) |
|
||||
|
||||
VoxCPM2 采用**连续音频表征、扩散自回归**范式,模型在 **AudioVAE** 的连续隐空间中通过四阶段处理:**LocEnc → TSLM → RALM → LocDiT**,实现丰富的表现力语音合成和 48kHz 原生音频输出。
|
||||
|
||||
<div align="center">
|
||||
<img src="assets/voxcpm_model.png" alt="VoxCPM2 模型架构" width="90%">
|
||||
</div>
|
||||
|
||||
> 完整架构细节、VoxCPM2 升级内容和模型对比表见 [架构设计文档](https://voxcpm.readthedocs.io/zh-cn/latest/models/architecture.html)。
|
||||
|
||||
---
|
||||
|
||||
## 📊 性能评测
|
||||
|
||||
VoxCPM2 在公开的零样本和可控 TTS 基准测试中取得了 SOTA 或可比的结果。
|
||||
|
||||
### Seed-TTS-eval
|
||||
|
||||
<details>
|
||||
<summary><b>Seed-TTS-eval WER(⬇)&SIM(⬆) 结果(点击展开)</b></summary>
|
||||
|
||||
| Model | Parameters | Open-Source | test-EN | | test-ZH | | test-Hard | |
|
||||
|------|------|------|:------------:|:--:|:------------:|:--:|:-------------:|:--:|
|
||||
| | | | WER/%⬇ | SIM/%⬆| CER/%⬇| SIM/%⬆ | CER/%⬇ | SIM/%⬆ |
|
||||
| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
|
||||
| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
|
||||
| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
|
||||
| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
|
||||
| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
|
||||
| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
|
||||
| F5-TTS | 0.3B | ✅ | 2.00 | 67.0 | 1.53 | 76.0 | 8.67 | 71.3 |
|
||||
| MaskGCT | 1B | ✅ | 2.62 | 71.7 | 2.27 | 77.4 | - | - |
|
||||
| CosyVoice | 0.3B | ✅ | 4.29 | 60.9 | 3.63 | 72.3 | 11.75 | 70.9 |
|
||||
| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | 6.83 | 72.4 |
|
||||
| SparkTTS | 0.5B | ✅ | 3.14 | 57.3 | 1.54 | 66.0 | - | - |
|
||||
| FireRedTTS | 0.5B | ✅ | 3.82 | 46.0 | 1.51 | 63.5 | 17.45 | 62.1 |
|
||||
| FireRedTTS-2 | 1.5B | ✅ | 1.95 | 66.5 | 1.14 | 73.6 | - | - |
|
||||
| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | 74.7 |
|
||||
| Qwen3-Omni | 30B-A3B | ✅ | 1.39 | - | 1.07 | - | - | - |
|
||||
| OpenAudio-s1-mini | 0.5B | ✅ | 1.94 | 55.0 | 1.18 | 68.5 | 23.37 | 64.3 |
|
||||
| IndexTTS2 | 1.5B | ✅ | 2.23 | 70.6 | 1.03 | 76.5 | 7.12 | 75.5 |
|
||||
| VibeVoice | 1.5B | ✅ | 3.04 | 68.9 | 1.16 | 74.4 | - | - |
|
||||
| HiggsAudio-v2 | 3B | ✅ | 2.44 | 67.7 | 1.50 | 74.0 | 55.07 | 65.6 |
|
||||
| VoxCPM-0.5B | 0.6B | ✅ | 1.85 | 72.9 | 0.93 | 77.2 | 8.87 | 73.0 |
|
||||
| VoxCPM1.5 | 0.8B | ✅ | 2.12 | 71.4 | 1.18 | 77.0 | 7.74 | 73.1 |
|
||||
| MOSS-TTS | | ✅ | 1.85 | 73.4 | 1.20 | 78.8 | - | - |
|
||||
| Qwen3-TTS | 1.7B | ✅ | 1.23 | 71.7 | 1.22 | 77.0 | 6.76 | 74.8 |
|
||||
| FishAudio S2 | 4B | ✅ | 0.99 | - | 0.54 | - | 5.99 | - |
|
||||
| LongCat-Audio-DiT | 3.5B | ✅ | 1.50 | 78.6 | 1.09 | 81.8 | 6.04 | 79.7 |
|
||||
| **VoxCPM2** | 2B | ✅ | 1.84 | 75.3 | 0.97| 79.5| 8.13 | 75.3 |
|
||||
</details>
|
||||
|
||||
|
||||
### CV3-eval
|
||||
<details>
|
||||
<summary><b>CV3-eval 多语言 WER/CER(⬇) 结果(点击展开)</b></summary>
|
||||
|
||||
| Model | zh | en | hard-zh | hard-en | ja | ko | de | es | fr | it | ru |
|
||||
|-------|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
|
||||
| CosyVoice2 | 4.08 | 6.32 | 12.58| 11.96| 9.13 | 19.7 |- | - | - | - | - |
|
||||
| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 10.55 | 7.57 | 5.69 | 6.43 | 4.47 | 11.8 | 10.5 | 6.64 |
|
||||
| Fish Audio S2 | 2.65 | 2.43 | 9.10 | 4.40 | 3.96 | 2.76 | 2.22 | 2.00 | 6.26 | 2.04 | 2.78 |
|
||||
| **VoxCPM2** | 3.65 | 5.00 | 8.55 | 8.48 | 5.96 | 5.69 | 4.77 | 3.80 | 9.85 | 4.25 | 5.21 |
|
||||
</details>
|
||||
|
||||
### MiniMax-Multilingual-Test
|
||||
|
||||
<details>
|
||||
<summary><b>Minimax-MLS-test WER(⬇) 结果(点击展开)</b></summary>
|
||||
|
||||
| Language | Minimax | ElevenLabs | Qwen3-TTS | FishAudio S2 | **VoxCPM2** |
|
||||
|----------|:-------:|:----------:|:--------------------:|:------------:|:-----------:|
|
||||
| Arabic | **1.665** | 1.666 | – | 3.500 | 13.046 |
|
||||
| Cantonese | 34.111 | 51.513 | – | **30.670** | 38.584 |
|
||||
| Chinese | 2.252 | 16.026 | 0.928 | **0.730** | 1.136 |
|
||||
| Czech | 3.875 | **2.108** | – | 2.840 | 24.132 |
|
||||
| Dutch | 1.143 | **0.803** | – | 0.990 | 0.913 |
|
||||
| English | 2.164 | 2.339 | **0.934** | 1.620 | 2.289 |
|
||||
| Finnish | 4.666 | 2.964 | – | 3.330 | **2.632** |
|
||||
| French | 4.099 | 5.216 | **2.858** | 3.050 | 4.534 |
|
||||
| German | 1.906 | 0.572 | 1.235 | **0.550** | 0.679 |
|
||||
| Greek | 2.016 | **0.991** | – | 5.740 | 2.844 |
|
||||
| Hindi | 6.962 | **5.827** | – | 14.640 | 19.699 |
|
||||
| Indonesian | 1.237 | **1.059** | – | 1.460 | 1.084 |
|
||||
| Italian | 1.543 | 1.743 | **0.948** | 1.270 | 1.563 |
|
||||
| Japanese | 3.519 | 10.646 | 3.823 | **2.760** | 4.628 |
|
||||
| Korean | 1.747 | 1.865 | 1.755 | **1.180** | 1.962 |
|
||||
| Polish | 1.415 | **0.766** | – | 1.260 | 1.141 |
|
||||
| Portuguese | 1.877 | 1.331 | 1.526 | **1.140** | 1.938 |
|
||||
| Romanian | 2.878 | **1.347** | – | 10.740 | 21.577 |
|
||||
| Russian | 4.281 | 3.878 | 3.212 | **2.400** | 3.634 |
|
||||
| Spanish | 1.029 | 1.084 | 1.126 | **0.910** | 1.438 |
|
||||
| Thai | 2.701 | 73.936 | – | 4.230 | 2.961 |
|
||||
| Turkish | 1.52 | 0.699 | – | 0.870 | 0.817 |
|
||||
| Ukrainian | 1.082 | **0.997** | – | 2.300 | 6.316 |
|
||||
| Vietnamese | **0.88** | 73.415 | – | 7.410 | 3.307 |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Minimax-MLS-test SIM(⬆) 结果(点击展开)</b></summary>
|
||||
|
||||
| Language | Minimax | ElevenLabs | Qwen3-TTS | FishAudio S2 | **VoxCPM2** |
|
||||
|----------|:-------:|:----------:|:--------------------:|:------------:|:-----------:|
|
||||
| Arabic | 73.6 | 70.6 | – | 75.0 | **79.1** |
|
||||
| Cantonese | 77.8 | 67.0 | – | 80.5 | **83.5** |
|
||||
| Chinese | 78.0 | 67.7 | 79.9 | 81.6 | **82.5** |
|
||||
| Czech | 79.6 | 68.5 | – | **79.8** | 78.3 |
|
||||
| Dutch | 73.8 | 68.0 | – | 73.0 | **80.8** |
|
||||
| English | 75.6 | 61.3 | 77.5 | 79.7 | **85.4** |
|
||||
| Finnish | 83.5 | 75.9 | – | 81.9 | **89.0** |
|
||||
| French | 62.8 | 53.5 | 62.8 | 69.8 | **73.5** |
|
||||
| German | 73.3 | 61.4 | 77.5 | 76.7 | **80.3** |
|
||||
| Greek | 82.6 | 73.3 | – | 79.5 | **86.0** |
|
||||
| Hindi | 81.8 | 73.0 | – | 82.1 | **85.6** |
|
||||
| Indonesian | 72.9 | 66.0 | – | 76.3 | **80.0** |
|
||||
| Italian | 69.9 | 57.9 | 81.7 | 74.7 | **78.0** |
|
||||
| Japanese | 77.6 | 73.8 | 78.8 | 79.6 | **82.8** |
|
||||
| Korean | 77.6 | 70.0 | 79.9 | 81.7 | **83.3** |
|
||||
| Polish | 80.2 | 72.9 | – | 81.9 | **88.4** |
|
||||
| Portuguese | 80.5 | 71.1 | 81.7 | 78.1 | **83.7** |
|
||||
| Romanian | **80.9** | 69.9 | – | 73.3 | 79.7 |
|
||||
| Russian | 76.1 | 67.6 | 79.2 | 79.0 | **81.1** |
|
||||
| Spanish | 76.2 | 61.5 | 81.4 | 77.6 | **83.1** |
|
||||
| Thai | 80.0 | 58.8 | – | 78.6 | **84.0** |
|
||||
| Turkish | 77.9 | 59.6 | – | 83.5 | **87.1** |
|
||||
| Ukrainian | 73.0 | 64.7 | – | 74.7 | **79.8** |
|
||||
| Vietnamese | 74.3 | 36.9 | – | 74.0 | **80.6** |
|
||||
|
||||
</details>
|
||||
|
||||
### Internal 30-Language ASR Benchmark
|
||||
|
||||
我们额外进行了内部多语言可懂度评测:**30 语种 × 500 样本**,ASR 转写评估使用 **Gemini 3.1 Flash Lite API**。
|
||||
|
||||
<details>
|
||||
<summary><b>内部30语种评测集ASR结果(点击展开)</b></summary>
|
||||
|
||||
| 语言 | 指标 | VoxCPM2 | Fish S2-Pro |
|
||||
|---|---:|---:|---:|
|
||||
| ar (阿拉伯语) | CER | 1.23% | 0.30% |
|
||||
| da (丹麦语) | WER | 2.70% | 3.52% |
|
||||
| de (德语) | WER | 0.96% | 0.64% |
|
||||
| el (希腊语) | WER | 3.17% | 4.61% |
|
||||
| en (英语) | WER | 0.42% | 1.03% |
|
||||
| es (西班牙语) | WER | 1.33% | 0.64% |
|
||||
| fi (芬兰语) | WER | 2.24% | 2.80% |
|
||||
| fr (法语) | WER | 2.16% | 2.34% |
|
||||
| he (希伯来语) | CER | 2.98% | 15.27% |
|
||||
| hi (印地语) | CER | 0.79% | 0.91% |
|
||||
| id (印尼语) | WER | 1.36% | 1.68% |
|
||||
| it (意大利语) | WER | 1.65% | 1.08% |
|
||||
| ja (日语) | CER | 2.40% | 1.82% |
|
||||
| km (高棉语) | CER | 2.05% | 75.15% |
|
||||
| ko (韩语) | CER | 0.95% | 0.29% |
|
||||
| lo (老挝语) | CER | 1.90% | 87.40% |
|
||||
| ms (马来语) | WER | 1.75% | 1.41% |
|
||||
| my (缅甸语) | CER | 1.42% | 85.27% |
|
||||
| nl (荷兰语) | WER | 1.25% | 1.68% |
|
||||
| no (挪威语) | WER | 2.49% | 3.76% |
|
||||
| pl (波兰语) | WER | 1.90% | 1.65% |
|
||||
| pt (葡萄牙语) | WER | 1.48% | 1.49% |
|
||||
| ru (俄语) | WER | 0.90% | 0.86% |
|
||||
| sv (瑞典语) | WER | 2.22% | 2.63% |
|
||||
| sw (斯瓦希里语) | CER | 1.07% | 2.02% |
|
||||
| th (泰语) | CER | 0.94% | 1.92% |
|
||||
| tl (菲律宾语) | WER | 2.63% | 4.00% |
|
||||
| tr (土耳其语) | WER | 1.65% | 1.65% |
|
||||
| vi (越南语) | WER | 1.56% | 5.56% |
|
||||
| zh (中文) | CER | 0.92% | 1.02% |
|
||||
| 平均(30 语种) | | **1.68%** | - |
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
### InstructTTSEval
|
||||
|
||||
<details>
|
||||
<summary><b>指令驱动音色设计结果(点击展开)</b></summary>
|
||||
|
||||
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
||||
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
||||
| | APS⬆| DSD⬆ | RP⬆| APS⬆ | DSD⬆ | RP⬆ |
|
||||
| Hume | – | – | – | 83.0 | 75.3 | 54.3 |
|
||||
| VoxInstruct | 47.5 | 52.3 | 42.6 | 54.9 | 57.0 | 39.3 |
|
||||
| Parler-tts-mini | – | – | – | 63.4 | 48.7 | 28.6 |
|
||||
| Parler-tts-large | – | – | – | 60.0 | 45.9 | 31.2 |
|
||||
| PromptTTS | – | – | – | 64.3 | 47.2 | 31.4 |
|
||||
| PromptStyle | – | – | – | 57.4 | 46.4 | 30.9 |
|
||||
| VoiceSculptor | 75.7 | 64.7 | 61.5 | – | – | – |
|
||||
| Mimo-Audio-7B-Instruct | 75.7 | 74.3 | 61.5 | 80.6 | 77.6 | 59.5 |
|
||||
| Qwen3TTS-12Hz-1.7B-VD | **85.2** | **81.1** | **65.1** | 82.9 | 82.4 | 68.4 |
|
||||
| **VoxCPM2** | **85.2** | 71.5 | 60.8 | **84.2** | **83.2** | **71.4** |
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## ⚙️ 微调
|
||||
|
||||
VoxCPM 支持**全参数微调(SFT)** 和 **LoRA 微调**。仅需 **5-10分钟** 的音频数据,即可适配特定说话人、语言或领域。
|
||||
|
||||
```bash
|
||||
# LoRA 微调(参数高效,推荐)
|
||||
python scripts/train_voxcpm_finetune.py \
|
||||
--config_path conf/voxcpm_v2/voxcpm_finetune_lora.yaml
|
||||
|
||||
# 全参数微调
|
||||
python scripts/train_voxcpm_finetune.py \
|
||||
--config_path conf/voxcpm_v2/voxcpm_finetune_all.yaml
|
||||
|
||||
# WebUI 训练与推理
|
||||
python lora_ft_webui.py # 然后打开 http://localhost:7860
|
||||
```
|
||||
|
||||
> **完整指南 →** [微调文档](https://voxcpm.readthedocs.io/zh-cn/latest/finetuning/finetune.html)(数据准备、配置、训练、LoRA 热切换、常见问题)
|
||||
|
||||
---
|
||||
|
||||
## 📚 文档
|
||||
|
||||
完整文档:**[voxcpm.readthedocs.io](https://voxcpm.readthedocs.io/zh-cn/latest/)**
|
||||
|
||||
| 主题 | 链接 |
|
||||
|---|---|
|
||||
| 快速开始与安装 | [快速开始](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html) |
|
||||
| 使用指南与 Cookbook | [使用指南](https://voxcpm.readthedocs.io/zh-cn/latest/usage_guide.html) |
|
||||
| VoxCPM 系列模型 | [模型列表](https://voxcpm.readthedocs.io/zh-cn/latest/models/version_history.html) |
|
||||
| 微调(SFT & LoRA) | [微调指南](https://voxcpm.readthedocs.io/zh-cn/latest/finetuning/finetune.html) |
|
||||
| 常见问题 | [FAQ](https://voxcpm.readthedocs.io/zh-cn/latest/faq.html) |
|
||||
|
||||
---
|
||||
|
||||
## 🌟 生态与社区
|
||||
|
||||
| 项目 | 说明 |
|
||||
|---|---|
|
||||
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | 高吞吐快速 GPU 推理引擎 |
|
||||
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | 官方 vLLM 全模态服务(原生支持 VoxCPM2)— PagedAttention、OpenAI 兼容 API |
|
||||
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF:CPU、CUDA、Vulkan 推理 |
|
||||
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX 导出,支持 CPU 推理 |
|
||||
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine 后端 |
|
||||
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust 重新实现 |
|
||||
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI 节点工作流 |
|
||||
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | 面向 VoxCPM 2 的功能更完整的 ComfyUI 工作流,支持多说话人、LoRA 和自动 ASR |
|
||||
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS 扩展 |
|
||||
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | 浏览器端 TTS 扩展 |
|
||||
|
||||
> 完整生态见[文档](https://voxcpm.readthedocs.io/zh-cn/latest/)。社区项目非 OpenBMB 官方维护。做了什么有趣的东西?[提 Issue 或 PR](https://github.com/OpenBMB/VoxCPM/issues) 把它加进来!
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ 风险与局限性
|
||||
|
||||
- **滥用风险:** VoxCPM 的声音克隆能力可生成高度逼真的合成语音。**严禁**将 VoxCPM 用于冒充他人、欺诈或虚假信息传播。我们强烈建议对所有 AI 生成的内容进行明确标注。
|
||||
- **可控生成稳定性:** 音色设计和可控声音克隆的结果可能因生成次数而异 — 建议尝试生成 1~3 次以获得理想的音色或风格。我们正在积极提升可控性的一致性。
|
||||
- **语言覆盖:** VoxCPM2 官方支持 30 种语言。对于未列入的语言,欢迎直接测试或使用自有数据进行微调。我们计划在未来版本中扩展语言覆盖。
|
||||
- **使用说明:** 本模型基于 Apache-2.0 协议发布。用于生产部署时,我们建议针对具体场景进行充分的测试和安全评估。
|
||||
|
||||
---
|
||||
|
||||
## 📖 引用
|
||||
|
||||
如果 VoxCPM 对您有帮助,请考虑引用我们的工作并为仓库加星 ⭐!
|
||||
|
||||
```bib
|
||||
@article{voxcpm2_2026,
|
||||
title = {VoxCPM2: Tokenizer-Free TTS for Multilingual Speech Generation, Creative Voice Design, and True-to-Life Cloning},
|
||||
author = {VoxCPM Team},
|
||||
journal = {GitHub},
|
||||
year = {2026},
|
||||
}
|
||||
|
||||
@article{voxcpm2025,
|
||||
title = {VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation
|
||||
and True-to-Life Voice Cloning},
|
||||
author = {Zhou, Yixuan and Zeng, Guoyang and Liu, Xin and Li, Xiang and
|
||||
Yu, Renjie and Wang, Ziyang and Ye, Runchuan and Sun, Weiyue and
|
||||
Gui, Jiancheng and Li, Kehan and Wu, Zhiyong and Liu, Zhiyuan},
|
||||
journal = {arXiv preprint arXiv:2509.24650},
|
||||
year = {2025},
|
||||
}
|
||||
```
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
VoxCPM 模型权重和代码基于 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
## 🙏 致谢
|
||||
|
||||
- [DiTAR](https://arxiv.org/abs/2502.03930) 扩散自回归骨干架构
|
||||
- [MiniCPM-4](https://github.com/OpenBMB/MiniCPM) 语言模型基座
|
||||
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) 基于 Flow Matching 的 LocDiT 实现
|
||||
- [DAC](https://github.com/descriptinc/descript-audio-codec) Audio VAE 骨干
|
||||
- 感谢所有社区用户试用 VoxCPM、反馈问题、分享想法和贡献——你们的支持让项目持续进步
|
||||
|
||||
## 机构
|
||||
|
||||
<p>
|
||||
<a href="https://modelbest.cn/"><img src="assets/modelbest_logo.png" width="28px"> 面壁智能</a>
|
||||
|
||||
<a href="https://github.com/thuhcsi"><img src="assets/thuhcsi_logo.png" width="28px"> 清华大学人机交互实验室</a>
|
||||
</p>
|
||||
|
||||
## ⭐ Star 历史
|
||||
|
||||
[![Star History Chart](https://api.star-history.com/svg?repos=OpenBMB/VoxCPM&type=Date)](https://star-history.com/#OpenBMB/VoxCPM&Date)
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import logging
|
||||
import numpy as np
|
||||
@@ -9,8 +10,6 @@ from funasr import AutoModel
|
||||
from pathlib import Path
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
|
||||
|
||||
import voxcpm
|
||||
|
||||
@@ -55,7 +54,7 @@ _EXAMPLES_FOOTER_EN = (
|
||||
)
|
||||
|
||||
_USAGE_INSTRUCTIONS_ZH = (
|
||||
"**VoxCPM2 — 三种语音生成方式:**\n\n"
|
||||
"**三种语音生成方式:**\n\n"
|
||||
"🎨 **声音设计(Voice Design)** \n"
|
||||
"无需参考音频。在 **Control Instruction** 中描述目标音色特征"
|
||||
"(性别、年龄、语气、情绪、语速等),VoxCPM2 即可为你从零创造独一无二的声音。\n\n"
|
||||
@@ -66,6 +65,8 @@ _USAGE_INSTRUCTIONS_ZH = (
|
||||
"开启 **极致克隆模式** 并提供参考音频的文字内容(可自动识别)。"
|
||||
"模型会将参考音频视为已说出的前文,以**音频续写**的方式完整还原参考音频中的所有声音细节。"
|
||||
"注意:该模式与可控克隆模式互斥,将禁用Control Instruction。\n\n"
|
||||
"目前支持的方言包括:\n"
|
||||
"「四川话、粤语、吴语、东北话、河南话、陕西话、山东话、天津话、闽南话」"
|
||||
)
|
||||
|
||||
_EXAMPLES_FOOTER_ZH = (
|
||||
@@ -221,11 +222,11 @@ _APP_THEME = gr.themes.Soft(
|
||||
# ---------- Model ----------
|
||||
|
||||
class VoxCPMDemo:
|
||||
def __init__(self, model_dir: Optional[str] = None) -> None:
|
||||
def __init__(self, model_id: str = "openbmb/VoxCPM2") -> None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Running on device: {self.device}")
|
||||
logger.info(f"运行在设备上: {self.device}")
|
||||
|
||||
self.asr_model_id = "iic/SenseVoiceSmall"
|
||||
self.asr_model_id = "./models/iic/SenseVoiceSmall"
|
||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||
model=self.asr_model_id,
|
||||
disable_update=True,
|
||||
@@ -234,36 +235,13 @@ class VoxCPMDemo:
|
||||
)
|
||||
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.explicit_model_dir = model_dir
|
||||
|
||||
def _resolve_model_dir(self) -> str:
|
||||
if self.explicit_model_dir and os.path.isdir(self.explicit_model_dir):
|
||||
return self.explicit_model_dir
|
||||
env_model_dir = os.environ.get("VOXCPM_MODEL_DIR", "").strip()
|
||||
if env_model_dir and os.path.isdir(env_model_dir):
|
||||
return env_model_dir
|
||||
repo_id = os.environ.get("HF_REPO_ID", "").strip()
|
||||
if len(repo_id) > 0:
|
||||
target_dir = os.path.join("models", repo_id.replace("/", "__"))
|
||||
if not os.path.isdir(target_dir):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
logger.info(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
|
||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||
except Exception as e:
|
||||
logger.warning(f"HF download failed: {e}. Falling back to 'models'.")
|
||||
return "models"
|
||||
return target_dir
|
||||
return "models"
|
||||
self._model_id = model_id
|
||||
|
||||
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
||||
if self.voxcpm_model is not None:
|
||||
return self.voxcpm_model
|
||||
logger.info("Model not loaded, initializing...")
|
||||
model_dir = self._resolve_model_dir()
|
||||
logger.info(f"Using model dir: {model_dir}")
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=True)
|
||||
logger.info(f"Loading model: {self._model_id}")
|
||||
self.voxcpm_model = voxcpm.VoxCPM.from_pretrained(self._model_id, optimize=True,zipenhancer_model_id="./models/iic/speech_zipenhancer_ans_multiloss_16k_base")
|
||||
logger.info("Model loaded successfully.")
|
||||
return self.voxcpm_model
|
||||
|
||||
@@ -315,6 +293,9 @@ class VoxCPMDemo:
|
||||
raise ValueError("Please input text to synthesize.")
|
||||
|
||||
control = (control_instruction or "").strip()
|
||||
# Strip any parentheses (half-width/full-width) from control text to avoid
|
||||
# breaking the "(control)text" prompt format expected by the model.
|
||||
control = re.sub(r"[()()]", "", control).strip()
|
||||
final_text = f"({control}){text}" if control else text
|
||||
|
||||
audio_path = reference_wav_path_input if reference_wav_path_input else None
|
||||
@@ -507,9 +488,9 @@ def run_demo(
|
||||
server_name: str = "0.0.0.0",
|
||||
server_port: int = 8808,
|
||||
show_error: bool = True,
|
||||
model_dir: Optional[str] = None,
|
||||
model_id: str = "./models/OpenBMB/VoxCPM2",
|
||||
):
|
||||
demo = VoxCPMDemo(model_dir=model_dir)
|
||||
demo = VoxCPMDemo(model_id=model_id)
|
||||
interface = create_demo_interface(demo)
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
||||
server_name=server_name,
|
||||
@@ -524,7 +505,10 @@ def run_demo(
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-dir", type=str, default=None, help="Path to VoxCPM2 checkpoint directory")
|
||||
parser.add_argument("--port", type=int, default=8808, help="Server port")
|
||||
parser.add_argument(
|
||||
"--model-id", type=str, default="./models/OpenBMB/VoxCPM2",
|
||||
help="本地路径或HuggingFace仓库ID(默认:./models/openbmb/VoxCPM2)",
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8808, help="服务端口")
|
||||
args = parser.parse_args()
|
||||
run_demo(model_dir=args.model_dir, server_port=args.port)
|
||||
run_demo(model_id=args.model_id, server_port=args.port)
|
||||
-280
@@ -1,280 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
from typing import Optional, Tuple
|
||||
from funasr import AutoModel
|
||||
from pathlib import Path
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
|
||||
|
||||
import voxcpm
|
||||
|
||||
|
||||
class VoxCPMDemo:
|
||||
def __init__(self) -> None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"🚀 Running on device: {self.device}", file=sys.stderr)
|
||||
|
||||
# ASR model for prompt text recognition
|
||||
self.asr_model_id = "iic/SenseVoiceSmall"
|
||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||
model=self.asr_model_id,
|
||||
disable_update=True,
|
||||
log_level='DEBUG',
|
||||
device="cuda:0" if self.device == "cuda" else "cpu",
|
||||
)
|
||||
|
||||
# TTS model (lazy init)
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.default_local_model_dir = "./models/VoxCPM1.5"
|
||||
|
||||
# ---------- Model helpers ----------
|
||||
def _resolve_model_dir(self) -> str:
|
||||
"""
|
||||
Resolve model directory:
|
||||
1) Use local checkpoint directory if exists
|
||||
2) If HF_REPO_ID env is set, download into models/{repo}
|
||||
3) Fallback to 'models'
|
||||
"""
|
||||
if os.path.isdir(self.default_local_model_dir):
|
||||
return self.default_local_model_dir
|
||||
|
||||
repo_id = os.environ.get("HF_REPO_ID", "").strip()
|
||||
if len(repo_id) > 0:
|
||||
target_dir = os.path.join("models", repo_id.replace("/", "__"))
|
||||
if not os.path.isdir(target_dir):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
|
||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||
except Exception as e:
|
||||
print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr)
|
||||
return "models"
|
||||
return target_dir
|
||||
return "models"
|
||||
|
||||
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
||||
if self.voxcpm_model is not None:
|
||||
return self.voxcpm_model
|
||||
print("Model not loaded, initializing...", file=sys.stderr)
|
||||
model_dir = self._resolve_model_dir()
|
||||
print(f"Using model dir: {model_dir}", file=sys.stderr)
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
||||
print("Model loaded successfully.", file=sys.stderr)
|
||||
return self.voxcpm_model
|
||||
|
||||
# ---------- Functional endpoints ----------
|
||||
def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
|
||||
if prompt_wav is None:
|
||||
return ""
|
||||
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
||||
text = res[0]["text"].split('|>')[-1]
|
||||
return text
|
||||
|
||||
def generate_tts_audio(
|
||||
self,
|
||||
text_input: str,
|
||||
prompt_wav_path_input: Optional[str] = None,
|
||||
prompt_text_input: Optional[str] = None,
|
||||
cfg_value_input: float = 2.0,
|
||||
inference_timesteps_input: int = 10,
|
||||
do_normalize: bool = True,
|
||||
denoise: bool = True,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
"""
|
||||
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
|
||||
Returns (sample_rate, waveform_numpy)
|
||||
"""
|
||||
current_model = self.get_or_load_voxcpm()
|
||||
|
||||
text = (text_input or "").strip()
|
||||
if len(text) == 0:
|
||||
raise ValueError("Please input text to synthesize.")
|
||||
|
||||
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
|
||||
prompt_text = prompt_text_input if prompt_text_input else None
|
||||
|
||||
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
|
||||
wav = current_model.generate(
|
||||
text=text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
cfg_value=float(cfg_value_input),
|
||||
inference_timesteps=int(inference_timesteps_input),
|
||||
normalize=do_normalize,
|
||||
denoise=denoise,
|
||||
)
|
||||
return (current_model.tts_model.sample_rate, wav)
|
||||
|
||||
|
||||
# ---------- UI Builders ----------
|
||||
|
||||
_APP_THEME = gr.themes.Soft(
|
||||
primary_hue="blue",
|
||||
secondary_hue="gray",
|
||||
neutral_hue="slate",
|
||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
|
||||
)
|
||||
|
||||
_CUSTOM_CSS = """
|
||||
.logo-container {
|
||||
text-align: center;
|
||||
margin: 0.5rem 0 1rem 0;
|
||||
}
|
||||
.logo-container img {
|
||||
height: 80px;
|
||||
width: auto;
|
||||
max-width: 200px;
|
||||
display: inline-block;
|
||||
}
|
||||
/* Bold accordion labels */
|
||||
#acc_quick details > summary,
|
||||
#acc_tips details > summary {
|
||||
font-weight: 600 !important;
|
||||
font-size: 1.1em !important;
|
||||
}
|
||||
/* Bold labels for specific checkboxes */
|
||||
#chk_denoise label,
|
||||
#chk_denoise span,
|
||||
#chk_normalize label,
|
||||
#chk_normalize span {
|
||||
font-weight: 600;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def create_demo_interface(demo: VoxCPMDemo):
|
||||
"""Build the Gradio UI for VoxCPM demo."""
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
|
||||
|
||||
with gr.Blocks() as interface:
|
||||
# Header logo
|
||||
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
|
||||
|
||||
# Quick Start
|
||||
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
|
||||
gr.Markdown("""
|
||||
### How to Use |使用说明
|
||||
1. **(Optional) Provide a Voice Prompt** - Upload or record an audio clip to provide the desired voice characteristics for synthesis.
|
||||
**(可选)提供参考声音** - 上传或录制一段音频,为声音合成提供音色、语调和情感等个性化特征
|
||||
2. **(Optional) Enter prompt text** - If you provided a voice prompt, enter the corresponding transcript here (auto-recognition available).
|
||||
**(可选项)输入参考文本** - 如果提供了参考语音,请输入其对应的文本内容(支持自动识别)。
|
||||
3. **Enter target text** - Type the text you want the model to speak.
|
||||
**输入目标文本** - 输入您希望模型朗读的文字内容。
|
||||
4. **Generate Speech** - Click the "Generate" button to create your audio.
|
||||
**生成语音** - 点击"生成"按钮,即可为您创造出音频。
|
||||
""")
|
||||
|
||||
# Pro Tips
|
||||
with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
|
||||
gr.Markdown("""
|
||||
### Prompt Speech Enhancement|参考语音降噪
|
||||
- **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
|
||||
**启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz,限制克隆上限。
|
||||
- **Disable** to preserve the original audio's all information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
|
||||
**禁用**:保留原始音频的全部信息,包括背景环境声,最高支持44.1kHz的音频复刻。
|
||||
|
||||
### Text Normalization|文本正则化
|
||||
- **Enable** to process general text with an external WeTextProcessing component.
|
||||
**启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理。
|
||||
- **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input (For Chinese, phonemes are converted using pinyin, {ni3}{hao3}; For English, phonemes are converted using CMUDict, {HH AH0 L OW1}), try it!
|
||||
**禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3};英文转CMUDict:{HH AH0 L OW1})和公式符号合成,尝试一下!
|
||||
|
||||
### CFG Value|CFG 值
|
||||
- **Lower CFG** if the voice prompt sounds strained or expressive, or instability occurs with long text input.
|
||||
**调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题。
|
||||
- **Higher CFG** for better adherence to the prompt speech style or input text, or instability occurs with too short text input.
|
||||
**调高**:为更好地贴合提示音频的风格或输入文本, 或者极短文本输入出现稳定性问题。
|
||||
|
||||
### Inference Timesteps|推理时间步
|
||||
- **Lower** for faster synthesis speed.
|
||||
**调低**:合成速度更快。
|
||||
- **Higher** for better synthesis quality.
|
||||
**调高**:合成质量更佳。
|
||||
""")
|
||||
|
||||
# Main controls
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
prompt_wav = gr.Audio(
|
||||
sources=["upload", 'microphone'],
|
||||
type="filepath",
|
||||
label="Prompt Speech (Optional, or let VoxCPM improvise)",
|
||||
value="./examples/example.wav",
|
||||
)
|
||||
DoDenoisePromptAudio = gr.Checkbox(
|
||||
value=False,
|
||||
label="Prompt Speech Enhancement",
|
||||
elem_id="chk_denoise",
|
||||
info="We use ZipEnhancer model to denoise the prompt audio."
|
||||
)
|
||||
with gr.Row():
|
||||
prompt_text = gr.Textbox(
|
||||
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
|
||||
label="Prompt Text",
|
||||
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
|
||||
)
|
||||
run_btn = gr.Button("Generate Speech", variant="primary")
|
||||
|
||||
with gr.Column():
|
||||
cfg_value = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=3.0,
|
||||
value=2.0,
|
||||
step=0.1,
|
||||
label="CFG Value (Guidance Scale)",
|
||||
info="Higher values increase adherence to prompt, lower values allow more creativity"
|
||||
)
|
||||
inference_timesteps = gr.Slider(
|
||||
minimum=4,
|
||||
maximum=30,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Inference Timesteps",
|
||||
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
|
||||
)
|
||||
with gr.Row():
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
)
|
||||
with gr.Row():
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="We use wetext library to normalize the input text."
|
||||
)
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
# Wiring
|
||||
run_btn.click(
|
||||
fn=demo.generate_tts_audio,
|
||||
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
|
||||
outputs=[audio_output],
|
||||
show_progress=True,
|
||||
api_name="generate",
|
||||
)
|
||||
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
|
||||
demo = VoxCPMDemo()
|
||||
interface = create_demo_interface(demo)
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
||||
server_name=server_name,
|
||||
server_port=server_port,
|
||||
show_error=show_error,
|
||||
theme=_APP_THEME,
|
||||
css=_CUSTOM_CSS,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_demo()
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 9.5 KiB |
@@ -1,7 +1,8 @@
|
||||
pretrained_path: /path/to/VoxCPM2/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: null
|
||||
sample_rate: 48000
|
||||
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||
out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||
batch_size: 2
|
||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||
num_workers: 8
|
||||
@@ -14,6 +15,7 @@ weight_decay: 0.01
|
||||
warmup_steps: 100
|
||||
max_steps: 1000
|
||||
max_batch_tokens: 8192
|
||||
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
|
||||
save_path: /path/to/checkpoints/finetune_all
|
||||
tensorboard: /path/to/logs/finetune_all
|
||||
lambdas:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
pretrained_path: /path/to/VoxCPM2/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: null
|
||||
sample_rate: 48000
|
||||
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||
out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||
batch_size: 2
|
||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||
num_workers: 8
|
||||
@@ -14,6 +15,7 @@ weight_decay: 0.01
|
||||
warmup_steps: 100
|
||||
max_steps: 1000
|
||||
max_batch_tokens: 8192
|
||||
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
|
||||
save_path: /path/to/checkpoints/finetune_lora
|
||||
tensorboard: /path/to/logs/finetune_lora
|
||||
lambdas:
|
||||
|
||||
+105
-18
@@ -14,8 +14,10 @@ from typing import Optional
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root / "src"))
|
||||
|
||||
# Default pretrained model path relative to this repo
|
||||
default_pretrained_path = str(project_root / "models" / "openbmb__VoxCPM1.5")
|
||||
# Default pretrained model path: prefer VoxCPM2 if it exists, fallback to VoxCPM1.5
|
||||
_v2_path = project_root / "models" / "openbmb__VoxCPM2"
|
||||
_v15_path = project_root / "models" / "openbmb__VoxCPM1.5"
|
||||
default_pretrained_path = str(_v2_path if _v2_path.exists() else _v15_path)
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
@@ -99,6 +101,24 @@ def get_timestamp_str():
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
def detect_sample_rate(pretrained_path: str) -> Optional[int]:
    """Read ``audio_vae_config.sample_rate`` from the model's ``config.json``.

    This is the AudioVAE *encoder* input rate, which is the correct rate for
    resampling training data.

    Args:
        pretrained_path: Directory expected to contain ``config.json``.

    Returns:
        The sample rate as an ``int``, or ``None`` when detection fails for
        any reason (blank path, missing or unreadable file, malformed JSON,
        missing or non-numeric ``audio_vae_config.sample_rate``).
    """
    # A blank path would make os.path.join resolve to "./config.json" and
    # silently pick up an unrelated file from the working directory.
    if not pretrained_path:
        return None

    config_file = os.path.join(pretrained_path, "config.json")
    if not os.path.isfile(config_file):
        return None

    try:
        with open(config_file, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        return int(cfg["audio_vae_config"]["sample_rate"])
    except (KeyError, TypeError, ValueError, OSError) as e:
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
        # TypeError covers a null / non-dict "audio_vae_config", a null
        # "sample_rate", or a top-level JSON value that is not an object.
        # OSError covers files that exist but cannot be read.
        print(f"Warning: failed to detect sample_rate from {config_file}: {e}", file=sys.stderr)
        return None
|
||||
|
||||
|
||||
def get_or_load_asr_model():
|
||||
global asr_model
|
||||
if asr_model is None:
|
||||
@@ -261,27 +281,48 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
||||
|
||||
# 加载模型
|
||||
lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None
|
||||
try:
|
||||
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||
load_model(base_model_path)
|
||||
if lora_selection and lora_selection != "None":
|
||||
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
||||
load_model(base_model_path, lora_to_load)
|
||||
if lora_to_load:
|
||||
print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
|
||||
print(error_msg, file=sys.stderr)
|
||||
return None, error_msg
|
||||
lora_just_loaded = lora_to_load
|
||||
else:
|
||||
lora_just_loaded = None
|
||||
|
||||
# Handle LoRA hot-swapping
|
||||
assert current_model is not None, "Model must be loaded before inference"
|
||||
if lora_selection and lora_selection != "None":
|
||||
full_lora_path = os.path.join("lora", lora_selection)
|
||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||
try:
|
||||
current_model.load_lora(full_lora_path)
|
||||
current_model.set_lora_enabled(True)
|
||||
except Exception as e:
|
||||
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
||||
return None, f"Error loading LoRA: {e}"
|
||||
|
||||
if lora_just_loaded != lora_selection:
|
||||
new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path)
|
||||
current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None
|
||||
new_r = new_lora_config.r if new_lora_config else None
|
||||
|
||||
if new_r is not None and current_r is not None and new_r != current_r:
|
||||
print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr)
|
||||
reload_base = (
|
||||
new_base_model if new_base_model and os.path.exists(new_base_model)
|
||||
else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path)
|
||||
)
|
||||
try:
|
||||
load_model(reload_base, lora_selection)
|
||||
except Exception as e:
|
||||
return None, f"Failed to reload model for LoRA rank change: {e}"
|
||||
else:
|
||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||
try:
|
||||
current_model.load_lora(full_lora_path)
|
||||
except Exception as e:
|
||||
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
||||
return None, f"Error loading LoRA: {e}"
|
||||
current_model.set_lora_enabled(True)
|
||||
else:
|
||||
print("Disabling LoRA", file=sys.stderr)
|
||||
current_model.set_lora_enabled(False)
|
||||
@@ -350,6 +391,7 @@ def start_training(
|
||||
warmup_steps=100,
|
||||
max_steps=None,
|
||||
sample_rate=44100,
|
||||
max_grad_norm=1.0,
|
||||
# LoRA advanced
|
||||
enable_lm=True,
|
||||
enable_dit=True,
|
||||
@@ -377,15 +419,39 @@ def start_training(
|
||||
os.makedirs(checkpoints_dir, exist_ok=True)
|
||||
os.makedirs(logs_dir, exist_ok=True)
|
||||
|
||||
# Auto-detect sample_rate from model config.json to prevent mismatch
|
||||
detected_sr = detect_sample_rate(pretrained_path)
|
||||
if detected_sr is not None:
|
||||
if int(sample_rate) != detected_sr:
|
||||
training_log += (
|
||||
f"[Auto-fix] sample_rate changed from {int(sample_rate)} to {detected_sr} "
|
||||
f"(read from {pretrained_path}/config.json audio_vae_config.sample_rate)\n"
|
||||
)
|
||||
sample_rate = detected_sr
|
||||
|
||||
# Create config dictionary
|
||||
# Resolve max_steps default
|
||||
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
||||
|
||||
# Auto-detect out_sample_rate from model config
|
||||
out_sample_rate = 0
|
||||
config_file = os.path.join(pretrained_path, "config.json")
|
||||
if os.path.isfile(config_file):
|
||||
try:
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
out_sr = cfg.get("audio_vae_config", {}).get("out_sample_rate")
|
||||
if out_sr:
|
||||
out_sample_rate = int(out_sr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
config = {
|
||||
"pretrained_path": pretrained_path,
|
||||
"train_manifest": train_manifest,
|
||||
"val_manifest": val_manifest,
|
||||
"sample_rate": int(sample_rate),
|
||||
"out_sample_rate": out_sample_rate,
|
||||
"batch_size": int(batch_size),
|
||||
"grad_accum_steps": int(grad_accum_steps),
|
||||
"num_workers": int(num_workers),
|
||||
@@ -397,6 +463,7 @@ def start_training(
|
||||
"weight_decay": float(weight_decay),
|
||||
"warmup_steps": int(warmup_steps),
|
||||
"max_steps": resolved_max_steps,
|
||||
"max_grad_norm": float(max_grad_norm),
|
||||
"save_path": checkpoints_dir,
|
||||
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
||||
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||
@@ -904,17 +971,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
with gr.Row():
|
||||
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
|
||||
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
|
||||
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
|
||||
max_grad_norm = gr.Number(label="梯度裁剪 (max_grad_norm, 0=关闭)", value=1.0)
|
||||
with gr.Row():
|
||||
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
|
||||
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
|
||||
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
|
||||
with gr.Row():
|
||||
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
|
||||
dropout = gr.Number(label="LoRA Dropout", value=0.0)
|
||||
|
||||
gr.Markdown("#### 分发选项 (Distribution)")
|
||||
with gr.Row():
|
||||
hf_model_id = gr.Textbox(
|
||||
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
|
||||
label="HuggingFace Model ID (e.g., openbmb/VoxCPM2)", value=""
|
||||
)
|
||||
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
||||
|
||||
@@ -929,6 +998,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
def on_pretrained_path_change(path):
    """Refresh the sample-rate field whenever the pretrained model path changes.

    Args:
        path: The pretrained model path currently entered in the UI.

    Returns:
        A ``gr.update`` carrying the detected sample rate as the new value,
        or an empty ``gr.update()`` (no change) when nothing was detected.
    """
    detected = detect_sample_rate(path)
    # Leave the field untouched unless detection actually produced a rate.
    return gr.update() if detected is None else gr.update(value=detected)
|
||||
|
||||
train_pretrained_path.change(
|
||||
on_pretrained_path_change,
|
||||
inputs=[train_pretrained_path],
|
||||
outputs=[sample_rate],
|
||||
)
|
||||
|
||||
start_btn.click(
|
||||
start_training,
|
||||
inputs=[
|
||||
@@ -951,6 +1033,7 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
max_grad_norm,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
@@ -1109,12 +1192,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
"warmup_steps": "warmup_steps",
|
||||
"max_steps": "最大步数 (max_steps)",
|
||||
"sample_rate": "采样率 (sample_rate)",
|
||||
"max_grad_norm": "梯度裁剪 (max_grad_norm, 0=关闭)",
|
||||
"enable_lm": "启用 LoRA LM (enable_lm)",
|
||||
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
||||
"enable_proj": "启用投影 (enable_proj)",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard 路径 (可选)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
|
||||
"distribute": "分发模式 (distribute)",
|
||||
}
|
||||
else:
|
||||
@@ -1127,12 +1211,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
"warmup_steps": "Warmup Steps",
|
||||
"max_steps": "Max Steps",
|
||||
"sample_rate": "Sample Rate",
|
||||
"max_grad_norm": "Max Grad Norm (0=disabled)",
|
||||
"enable_lm": "Enable LoRA LM",
|
||||
"enable_dit": "Enable LoRA DIT",
|
||||
"enable_proj": "Enable Projection",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard Path (Optional)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
|
||||
"distribute": "Distribute Mode",
|
||||
}
|
||||
|
||||
@@ -1162,11 +1247,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
gr.update(label=adv["warmup_steps"]),
|
||||
gr.update(label=adv["max_steps"]),
|
||||
gr.update(label=adv["sample_rate"]),
|
||||
gr.update(label=adv["max_grad_norm"]),
|
||||
gr.update(label=adv["tensorboard_path"]),
|
||||
gr.update(label=adv["enable_lm"]),
|
||||
gr.update(label=adv["enable_dit"]),
|
||||
gr.update(label=adv["enable_proj"]),
|
||||
gr.update(label=adv["dropout"]),
|
||||
gr.update(label=adv["tensorboard_path"]),
|
||||
# Distribution options
|
||||
gr.update(label=adv["hf_model_id"]),
|
||||
gr.update(label=adv["distribute"]),
|
||||
@@ -1213,11 +1299,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
max_grad_norm,
|
||||
tensorboard_path,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
dropout,
|
||||
tensorboard_path,
|
||||
# distribution outputs
|
||||
hf_model_id,
|
||||
distribute,
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
模型下载脚本
|
||||
|
||||
"""
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
def download(repo_id: str, local_dir: str) -> str:
    """Download a model repository snapshot via ``snapshot_download``.

    The snapshot is placed under ``<local_dir>/<repo_id>`` so that several
    repositories can share one parent directory without colliding.

    Args:
        repo_id (str): ``owner/name`` of the repository, for example
            ``'stabilityai/sdxl-turbo'``.
        local_dir (str or Path): Parent directory for downloaded files.

    Returns:
        str: The local path reported by ``snapshot_download``.
    """
    model_dir = snapshot_download(
        repo_id,
        repo_type='model',
        local_dir=f"{local_dir}/{repo_id}",
    )
    # Return the path: the previous version computed it and then dropped it,
    # so callers always received None despite the documented return value.
    return model_dir
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download("OpenBMB/VoxCPM2", "./models")
|
||||
download("iic/SenseVoiceSmall", "./models")
|
||||
download('iic/speech_zipenhancer_ans_multiloss_16k_base',"./models")
|
||||
+1
-2
@@ -47,8 +47,7 @@ dependencies = [
|
||||
"funasr",
|
||||
"spaces",
|
||||
"argbind",
|
||||
"safetensors"
|
||||
|
||||
"safetensors",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Unit checks for pick_runtime_dtype / get_dtype consistency.
|
||||
|
||||
Loads src/voxcpm/model/utils.py directly to avoid the heavy voxcpm package
|
||||
init. Run with: `python scripts/test_pick_runtime_dtype.py`.
|
||||
"""
|
||||
import importlib.util
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
UTILS = str(REPO_ROOT / "src" / "voxcpm" / "model" / "utils.py")
|
||||
spec = importlib.util.spec_from_file_location("voxcpm_utils", UTILS)
|
||||
utils = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(utils)
|
||||
|
||||
_LOW_PRECISION_DTYPES = utils._LOW_PRECISION_DTYPES
|
||||
_VALID_DTYPE_OVERRIDES = utils._VALID_DTYPE_OVERRIDES
|
||||
get_dtype = utils.get_dtype
|
||||
pick_runtime_dtype = utils.pick_runtime_dtype
|
||||
|
||||
|
||||
def expect(actual, expected, label):
    """Compare ``actual`` with ``expected``, print one result line, return the verdict.

    Args:
        actual: Value produced by the code under test.
        expected: Value it should equal.
        label: Short description printed alongside the result.

    Returns:
        bool: ``True`` when the two values compare equal, else ``False``.
    """
    matched = actual == expected
    if matched:
        status = "OK "
    else:
        status = "FAIL"
    print(f"[{status}] {label}: got={actual!r} expected={expected!r}")
    return matched
|
||||
|
||||
|
||||
def expect_raises(fn, exc_type, label):
    """Call ``fn`` and check that it raises ``exc_type``.

    Args:
        fn: Zero-argument callable to invoke.
        exc_type: Exception class that ``fn`` is expected to raise.
        label: Short description printed alongside the result.

    Returns:
        bool: ``True`` only when ``fn`` raised ``exc_type``; ``False`` when it
        raised some other ``Exception`` or returned normally.
    """
    try:
        fn()
    except exc_type as caught:
        passed = True
        report = f"[OK ] {label}: raised {exc_type.__name__}: {caught}"
    except Exception as other:
        passed = False
        report = f"[FAIL] {label}: raised {type(other).__name__} not {exc_type.__name__}: {other}"
    else:
        passed = False
        report = f"[FAIL] {label}: no exception raised"
    print(report)
    return passed
|
||||
|
||||
|
||||
results = []
|
||||
|
||||
print("=== override set sanity ===")
|
||||
results.append(expect("half" not in _VALID_DTYPE_OVERRIDES, True, "half removed from _VALID_DTYPE_OVERRIDES"))
|
||||
results.append(expect("half" not in _LOW_PRECISION_DTYPES, True, "half removed from _LOW_PRECISION_DTYPES"))
|
||||
|
||||
print("\n=== every accepted override parses through get_dtype ===")
|
||||
for dt in sorted(_VALID_DTYPE_OVERRIDES):
|
||||
try:
|
||||
torch_dtype = get_dtype(dt)
|
||||
print(f"[OK ] get_dtype({dt!r}) -> {torch_dtype}")
|
||||
results.append(True)
|
||||
except Exception as e:
|
||||
print(f"[FAIL] get_dtype({dt!r}) raised: {e}")
|
||||
results.append(False)
|
||||
|
||||
print("\n=== pick_runtime_dtype: non-mps is a no-op ===")
|
||||
results.append(expect(pick_runtime_dtype("cuda", "bfloat16"), "bfloat16", "cuda/bf16 untouched"))
|
||||
results.append(expect(pick_runtime_dtype("cpu", "float16"), "float16", "cpu/fp16 untouched"))
|
||||
results.append(expect(pick_runtime_dtype("cuda", "float32"), "float32", "cuda/fp32 untouched"))
|
||||
|
||||
print("\n=== pick_runtime_dtype: mps forces fp32 for low-precision ===")
|
||||
os.environ.pop("VOXCPM_MPS_DTYPE", None)
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "mps/bf16 -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "bf16"), "float32", "mps/bf16-alias -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "float16"), "float32", "mps/fp16 -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "fp16"), "float32", "mps/fp16-alias -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "float32"), "float32", "mps/fp32 stays"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "fp32"), "fp32", "mps/fp32-alias stays"))
|
||||
|
||||
print("\n=== pick_runtime_dtype: VOXCPM_MPS_DTYPE override ===")
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "bfloat16"
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "bfloat16", "override bf16 honored"))
|
||||
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "FP16"
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "fp16", "override is case-insensitive"))
|
||||
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = " float32 "
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "override is whitespace-trimmed"))
|
||||
|
||||
print("\n=== pick_runtime_dtype: 'half' is no longer a valid override ===")
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "half"
|
||||
results.append(
|
||||
expect_raises(
|
||||
lambda: pick_runtime_dtype("mps", "bfloat16"),
|
||||
ValueError,
|
||||
"override=half now rejected (was the bug)",
|
||||
)
|
||||
)
|
||||
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "garbage"
|
||||
results.append(
|
||||
expect_raises(
|
||||
lambda: pick_runtime_dtype("mps", "bfloat16"),
|
||||
ValueError,
|
||||
"override=garbage still rejected",
|
||||
)
|
||||
)
|
||||
|
||||
os.environ.pop("VOXCPM_MPS_DTYPE", None)
|
||||
|
||||
print("\n=== summary ===")
|
||||
passed = sum(results)
|
||||
total = len(results)
|
||||
print(f"{passed}/{total} passed")
|
||||
sys.exit(0 if passed == total else 1)
|
||||
@@ -30,7 +30,8 @@ except ImportError:
|
||||
import json
|
||||
|
||||
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
from voxcpm.model.voxcpm import LoRAConfig as LoRAConfigV1
|
||||
from voxcpm.model.voxcpm2 import LoRAConfig as LoRAConfigV2
|
||||
from voxcpm.training import (
|
||||
Accelerator,
|
||||
BatchProcessor,
|
||||
@@ -46,6 +47,7 @@ def train(
|
||||
train_manifest: str,
|
||||
val_manifest: str = "",
|
||||
sample_rate: int = 16_000,
|
||||
out_sample_rate: int = 0, # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||
batch_size: int = 1,
|
||||
grad_accum_steps: int = 1,
|
||||
num_workers: int = 2,
|
||||
@@ -63,6 +65,7 @@ def train(
|
||||
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||
lora: dict = None,
|
||||
config_path: str = "",
|
||||
max_grad_norm: float = 0.0, # gradient clipping; 0 = disabled (backward compat)
|
||||
# Distribution options (for LoRA checkpoints)
|
||||
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||
@@ -91,6 +94,7 @@ def train(
|
||||
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
||||
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
||||
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
||||
LoRAConfig = LoRAConfigV2 if _arch == "voxcpm2" else LoRAConfigV1
|
||||
if accelerator.rank == 0:
|
||||
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
||||
base_model = _model_cls.from_local(
|
||||
@@ -98,6 +102,12 @@ def train(
|
||||
)
|
||||
tokenizer = base_model.text_tokenizer
|
||||
|
||||
expected_sr = base_model.audio_vae.sample_rate
|
||||
assert sample_rate == expected_sr, (
|
||||
f"sample_rate mismatch: config says {sample_rate}, but the AudioVAE encoder expects {expected_sr}. "
|
||||
f"Please set sample_rate: {expected_sr} in your training config. "
|
||||
)
|
||||
|
||||
train_ds, val_ds = load_audio_text_datasets(
|
||||
train_manifest=train_manifest,
|
||||
val_manifest=val_manifest,
|
||||
@@ -170,8 +180,12 @@ def train(
|
||||
dataset_cnt=dataset_cnt,
|
||||
device=accelerator.device,
|
||||
)
|
||||
# Save audio_vae for audio generation
|
||||
# Save audio_vae and output sample rate for audio generation.
|
||||
# Prefer model's actual output rate; fall back to YAML out_sample_rate or encode rate.
|
||||
audio_vae_for_gen = base_model.audio_vae
|
||||
out_sr = base_model.sample_rate # decoder output rate (e.g. 48000 for V2)
|
||||
if out_sr == 0 and out_sample_rate > 0:
|
||||
out_sr = out_sample_rate
|
||||
del base_model.audio_vae
|
||||
model = accelerator.prepare_model(base_model)
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
@@ -304,8 +318,8 @@ def train(
|
||||
scaler = getattr(accelerator, "scaler", None)
|
||||
if scaler is not None:
|
||||
scaler.unscale_(optimizer)
|
||||
# Use large max_norm to only compute grad_norm without actual clipping
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)
|
||||
effective_max_norm = max_grad_norm if max_grad_norm > 0 else 1e9
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=effective_max_norm)
|
||||
|
||||
accelerator.step(optimizer)
|
||||
accelerator.update()
|
||||
@@ -333,6 +347,7 @@ def train(
|
||||
val_ds=val_ds,
|
||||
audio_vae=audio_vae_for_gen,
|
||||
sample_rate=sample_rate,
|
||||
out_sample_rate=out_sr,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
@@ -359,6 +374,7 @@ def validate(
|
||||
val_ds=None,
|
||||
audio_vae=None,
|
||||
sample_rate=22050,
|
||||
out_sample_rate=0,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
valid_interval=1000,
|
||||
@@ -424,6 +440,7 @@ def validate(
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate,
|
||||
out_sample_rate=out_sample_rate,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
@@ -526,6 +543,7 @@ def generate_sample_audio(
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate=22050,
|
||||
out_sample_rate=0,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
pretrained_path=None,
|
||||
@@ -540,6 +558,10 @@ def generate_sample_audio(
|
||||
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
|
||||
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
# Determine the correct output sample rate for generated audio.
|
||||
# out_sample_rate is the decoder output rate (e.g. 48kHz for V2);
|
||||
# sample_rate is the encoder input rate (e.g. 16kHz for V2).
|
||||
gen_sr = out_sample_rate if out_sample_rate > 0 else sample_rate
|
||||
|
||||
for i in range(num_samples):
|
||||
sample = val_ds[i]
|
||||
@@ -596,10 +618,10 @@ def generate_sample_audio(
|
||||
gen_audio_np = normalize_audio(gen_audio_np)
|
||||
|
||||
tag = f"val_sample_{i}"
|
||||
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
|
||||
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
|
||||
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=gen_sr)
|
||||
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/gen_sr:.2f}s")
|
||||
|
||||
# Log reference audio
|
||||
# Log reference audio (at encoder input rate, which is what val_ds provides)
|
||||
if ref_audio_np is not None:
|
||||
writer.add_audio(
|
||||
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
||||
@@ -607,9 +629,9 @@ def generate_sample_audio(
|
||||
|
||||
# Generate mel spectrogram figure
|
||||
try:
|
||||
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
|
||||
mel_gen = compute_mel_spectrogram(gen_audio_np, gen_sr)
|
||||
mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
|
||||
fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref)
|
||||
fig = create_mel_figure(gen_audio_np, mel_gen, gen_sr, step, ref_audio_np, mel_ref)
|
||||
writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
|
||||
log(f"[Audio] Created mel spectrogram figure for sample {i}")
|
||||
except Exception as e:
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
+60
-6
@@ -11,11 +11,6 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
|
||||
DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2"
|
||||
|
||||
# -----------------------------
|
||||
@@ -173,7 +168,9 @@ def validate_batch_args(args, parser):
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def load_model(args) -> VoxCPM:
|
||||
def load_model(args):
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
print("Loading VoxCPM model...", file=sys.stderr)
|
||||
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||
@@ -209,6 +206,7 @@ def load_model(args) -> VoxCPM:
|
||||
zipenhancer_model_path=zipenhancer_path,
|
||||
enable_denoiser=not args.no_denoiser,
|
||||
optimize=not args.no_optimize,
|
||||
device=args.device,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
@@ -227,6 +225,7 @@ def load_model(args) -> VoxCPM:
|
||||
cache_dir=args.cache_dir,
|
||||
local_files_only=args.local_files_only,
|
||||
optimize=not args.no_optimize,
|
||||
device=args.device,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
@@ -264,6 +263,8 @@ def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None
|
||||
and (args.prompt_audio is not None or args.reference_audio is not None),
|
||||
)
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||
|
||||
duration = len(audio_array) / model.tts_model.sample_rate
|
||||
@@ -286,7 +287,27 @@ def cmd_clone(args, parser):
|
||||
)
|
||||
|
||||
|
||||
def cmd_validate(args, parser):
    """Run the ``validate`` subcommand: check a training manifest and report.

    Args:
        args: Parsed CLI namespace; reads ``manifest``, ``sample_rate``,
            ``max_samples`` and ``verbose``.
        parser: The argument parser, passed to ``require_file_exists``.

    Exits the process with status 1 when the validation result is not valid.
    """
    # Imported here rather than at module import time.
    from voxcpm.training.validate import print_validation_report, validate_manifest

    manifest_path = str(require_file_exists(args.manifest, parser, "manifest file"))
    report = validate_manifest(
        manifest_path=manifest_path,
        sample_rate=args.sample_rate,
        max_samples=args.max_samples,
        verbose=args.verbose,
    )
    print_validation_report(report, manifest_path)
    if report.is_valid:
        return
    sys.exit(1)
|
||||
|
||||
|
||||
def cmd_batch(args, parser):
|
||||
import soundfile as sf
|
||||
|
||||
input_file = require_file_exists(args.input, parser, "input file")
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -403,6 +424,12 @@ def _add_model_args(parser):
|
||||
default=DEFAULT_HF_MODEL_ID,
|
||||
help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="auto",
|
||||
help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-dir", type=str, help="Cache directory for Hub downloads"
|
||||
)
|
||||
@@ -524,6 +551,30 @@ Examples:
|
||||
_add_model_args(batch_parser)
|
||||
_add_lora_args(batch_parser)
|
||||
|
||||
# Validate subcommand
|
||||
validate_parser = subparsers.add_parser(
|
||||
"validate",
|
||||
help="Validate a training data manifest (JSONL) before fine-tuning",
|
||||
)
|
||||
validate_parser.add_argument(
|
||||
"--manifest", "-m", required=True, help="Path to JSONL training manifest"
|
||||
)
|
||||
validate_parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16_000,
|
||||
help="Expected audio sample rate in Hz (default: 16000)",
|
||||
)
|
||||
validate_parser.add_argument(
|
||||
"--max-samples",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Maximum number of samples to validate (0 = all, default: 0)",
|
||||
)
|
||||
validate_parser.add_argument(
|
||||
"--verbose", "-v", action="store_true", help="Print per-sample progress"
|
||||
)
|
||||
|
||||
# Legacy root arguments
|
||||
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
|
||||
parser.add_argument(
|
||||
@@ -576,6 +627,9 @@ def main():
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "validate":
|
||||
return cmd_validate(args, parser)
|
||||
|
||||
validate_ranges(args, parser)
|
||||
|
||||
if args.command == "design":
|
||||
|
||||
+33
-6
@@ -8,6 +8,7 @@ from typing import Generator, Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||
from .model.voxcpm2 import VoxCPM2Model
|
||||
from .model.utils import next_and_close
|
||||
|
||||
|
||||
class VoxCPM:
|
||||
@@ -17,6 +18,7 @@ class VoxCPM:
|
||||
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser: bool = True,
|
||||
optimize: bool = True,
|
||||
device: str | None = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
@@ -30,6 +32,9 @@ class VoxCPM:
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM
|
||||
will choose automatically (preferring CUDA, then MPS, then CPU).
|
||||
If set explicitly, that device is used or a clear error is raised.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
@@ -56,10 +61,20 @@ class VoxCPM:
|
||||
arch = config.get("architecture", "voxcpm").lower()
|
||||
|
||||
if arch == "voxcpm2":
|
||||
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
self.tts_model = VoxCPM2Model.from_local(
|
||||
voxcpm_model_path,
|
||||
optimize=optimize,
|
||||
device=device,
|
||||
lora_config=lora_config,
|
||||
)
|
||||
print("Loaded VoxCPM2Model", file=sys.stderr)
|
||||
elif arch == "voxcpm":
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
self.tts_model = VoxCPMModel.from_local(
|
||||
voxcpm_model_path,
|
||||
optimize=optimize,
|
||||
device=device,
|
||||
lora_config=lora_config,
|
||||
)
|
||||
print("Loaded VoxCPMModel", file=sys.stderr)
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {arch}")
|
||||
@@ -94,6 +109,7 @@ class VoxCPM:
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
device: str | None = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
@@ -109,6 +125,9 @@ class VoxCPM:
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
device: Runtime device. Use ``None``/``"auto"`` for automatic
|
||||
fallback, or an explicit value such as ``"cpu"``, ``"mps"``,
|
||||
``"cuda"``, or ``"cuda:0"``.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created with
|
||||
enable_lm=True and enable_dit=True.
|
||||
@@ -130,7 +149,7 @@ class VoxCPM:
|
||||
if not repo_id:
|
||||
raise ValueError("You must provide hf_model_id")
|
||||
|
||||
# Load from local path if provided
|
||||
# 从本地路径加载(如果提供)
|
||||
if os.path.isdir(repo_id):
|
||||
local_path = repo_id
|
||||
else:
|
||||
@@ -146,13 +165,14 @@ class VoxCPM:
|
||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||
enable_denoiser=load_denoiser,
|
||||
optimize=optimize,
|
||||
device=device,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def generate(self, *args, **kwargs) -> np.ndarray:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
@@ -200,7 +220,7 @@ class VoxCPM:
|
||||
Yields audio chunks for each generation step if ``streaming=True``,
|
||||
otherwise yields a single array containing the final audio.
|
||||
"""
|
||||
if not text.strip() or not isinstance(text, str):
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
raise ValueError("target text must be a non-empty string")
|
||||
|
||||
if prompt_wav_path is not None:
|
||||
@@ -273,7 +293,14 @@ class VoxCPM:
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
for wav, _, _ in generate_result:
|
||||
if streaming:
|
||||
try:
|
||||
for wav, _, _ in generate_result:
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
finally:
|
||||
generate_result.close()
|
||||
else:
|
||||
wav, _, _ = next_and_close(generate_result)
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
|
||||
finally:
|
||||
|
||||
+111
-1
@@ -1,7 +1,25 @@
|
||||
from typing import List
|
||||
import os
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"}
|
||||
_VALID_DTYPE_OVERRIDES = {
|
||||
"bfloat16", "bf16",
|
||||
"float16", "fp16",
|
||||
"float32", "fp32",
|
||||
}
|
||||
|
||||
|
||||
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
|
||||
# Explicitly close partially-consumed generators so inference_mode cleanup
|
||||
# does not get deferred to Python's GC/finalizer path.
|
||||
def next_and_close(gen):
    """Return the first item of ``gen`` and close the generator.

    The generator is closed whether or not ``next`` succeeds, so its cleanup
    runs immediately instead of being left to garbage collection.

    Args:
        gen: A generator (anything with ``__next__`` and ``close``).

    Returns:
        The first value yielded by ``gen``.

    Raises:
        StopIteration: If ``gen`` yields nothing.
    """
    try:
        first = next(gen)
    finally:
        gen.close()
    return first
|
||||
|
||||
|
||||
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
|
||||
@@ -119,3 +137,95 @@ def get_dtype(dtype: str):
|
||||
return torch.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def _has_mps() -> bool:
|
||||
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
|
||||
|
||||
def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
    """Choose the dtype name to run with on the already-resolved device.

    For any device other than ``"mps"`` the configured dtype is returned
    unchanged. On MPS, low-precision dtypes (bfloat16/float16 and their
    aliases) are replaced by ``"float32"``, because they cause enough
    numerical drift in the diffusion AR loop to glitch the output and send
    the badcase detector into endless retries.

    The ``VOXCPM_MPS_DTYPE`` environment variable overrides the MPS choice
    (e.g. ``bfloat16``) for experimenting with future MPS improvements. It is
    trimmed and lower-cased, then returned as-is when valid.

    Args:
        device: Resolved runtime device name.
        configured_dtype: Dtype name from the model configuration.

    Returns:
        The dtype name to use at runtime.

    Raises:
        ValueError: If ``VOXCPM_MPS_DTYPE`` is set to an unsupported value.
    """
    on_mps = device == "mps"
    if not on_mps:
        return configured_dtype

    env_choice = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
    if not env_choice:
        # No override: only demote the known low-precision dtypes.
        requested = (configured_dtype or "").lower()
        return "float32" if requested in _LOW_PRECISION_DTYPES else configured_dtype

    if env_choice in _VALID_DTYPE_OVERRIDES:
        return env_choice

    raise ValueError(
        f"VOXCPM_MPS_DTYPE='{env_choice}' is not one of "
        f"{sorted(_VALID_DTYPE_OVERRIDES)}"
    )
|
||||
|
||||
|
||||
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
    """Pick a runtime device automatically.

    The preferred device wins when it is available; otherwise the choice
    falls back through CUDA, then MPS, then CPU.

    Args:
        preferred_device: Desired device name; ``None`` or empty is treated
            as ``"cuda"``. Surrounding whitespace and case are ignored.

    Returns:
        The selected device string.
    """
    wanted = (preferred_device or "cuda").strip().lower()
    cuda_ok = torch.cuda.is_available()

    # Honor the preference first, when it can be satisfied.
    if wanted.startswith("cuda"):
        if cuda_ok:
            return wanted
    elif wanted == "cpu":
        return "cpu"
    elif wanted == "mps" and _has_mps():
        return "mps"

    # Preference unavailable or unrecognized: CUDA -> MPS -> CPU.
    if cuda_ok:
        return "cuda"
    return "mps" if _has_mps() else "cpu"
|
||||
|
||||
|
||||
def resolve_runtime_device(device: Optional[str], configured_device: str = "cuda") -> str:
    """Resolve the device the model will actually run on.

    Semantics:
    - ``device`` is ``None`` or ``"auto"``: defer to ``auto_select_device``
      with ``configured_device`` as the preference.
    - otherwise: treat it as an explicit user choice, validate its spelling
      and availability, and return it normalized (trimmed, lower-cased).

    Args:
        device: Requested device, or ``None``/``"auto"`` for automatic choice.
        configured_device: Preference used for automatic selection.

    Returns:
        The resolved device string: ``"cpu"``, ``"mps"``, ``"cuda"`` or
        ``"cuda:N"``.

    Raises:
        ValueError: If the requested device is unavailable, or is not one of
            the supported spellings.
    """
    explicit = None if device is None else device.strip().lower()

    if explicit is None or explicit == "auto":
        return auto_select_device(configured_device)

    if explicit == "cpu":
        return "cpu"

    if explicit == "mps":
        if not _has_mps():
            raise ValueError(
                "Requested device 'mps', but MPS is not available. "
                "Use device='auto' for automatic fallback."
            )
        return "mps"

    # Accept only "cuda" or "cuda:<digits>". A bare startswith("cuda") check
    # would let typos such as "cudafoo" or "cuda:x" through as valid devices.
    family, has_index, index = explicit.partition(":")
    is_cuda_spec = family == "cuda" and (
        not has_index or (index.isascii() and index.isdigit())
    )
    if is_cuda_spec:
        if not torch.cuda.is_available():
            raise ValueError(
                f"Requested device '{device}', but CUDA is not available. "
                "Use device='auto' for automatic fallback."
            )
        return explicit

    raise ValueError(
        f"Unsupported device '{device}'. Supported values are 'auto', 'cpu', 'mps', "
        "'cuda', or indexed CUDA devices like 'cuda:0'."
    )
|
||||
|
||||
+37
-17
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||
from .utils import (
|
||||
get_dtype,
|
||||
mask_multichar_chinese_tokens,
|
||||
next_and_close,
|
||||
pick_runtime_dtype,
|
||||
resolve_runtime_device,
|
||||
)
|
||||
|
||||
|
||||
class VoxCPMEncoderConfig(BaseModel):
|
||||
@@ -109,18 +115,22 @@ class VoxCPMModel(nn.Module):
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAE,
|
||||
lora_config: LoRAConfig = None,
|
||||
device: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
if not torch.cuda.is_available():
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
self.device = resolve_runtime_device(device, config.device)
|
||||
self.config.device = self.device
|
||||
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||
if resolved_dtype != self.config.dtype:
|
||||
print(
|
||||
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
self.config.dtype = resolved_dtype
|
||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||
|
||||
# Text-Semantic LM
|
||||
@@ -227,6 +237,7 @@ class VoxCPMModel(nn.Module):
|
||||
self.residual_lm.forward_step = torch.compile(
|
||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
self._feat_encoder_raw = self.feat_encoder
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(
|
||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||
@@ -337,7 +348,7 @@ class VoxCPMModel(nn.Module):
|
||||
return get_dtype(self.config.dtype)
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
@@ -463,7 +474,7 @@ class VoxCPMModel(nn.Module):
|
||||
yield decode_audio
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
latent_pred, pred_audio_feat = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -571,7 +582,7 @@ class VoxCPMModel(nn.Module):
|
||||
return merged_cache
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
@@ -690,7 +701,7 @@ class VoxCPMModel(nn.Module):
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
latent_pred, pred_audio_feat = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -713,7 +724,7 @@ class VoxCPMModel(nn.Module):
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return next(self._inference(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._inference(*args, streaming=True, **kwargs)
|
||||
@@ -755,7 +766,8 @@ class VoxCPMModel(nn.Module):
|
||||
"""
|
||||
B, T, P, D = feat.shape
|
||||
|
||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
|
||||
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
@@ -845,8 +857,16 @@ class VoxCPMModel(nn.Module):
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
def from_local(
|
||||
cls,
|
||||
path: str,
|
||||
optimize: bool = True,
|
||||
training: bool = False,
|
||||
device: str | None = None,
|
||||
lora_config: LoRAConfig = None,
|
||||
):
|
||||
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
|
||||
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
||||
@@ -868,7 +888,7 @@ class VoxCPMModel(nn.Module):
|
||||
raise FileNotFoundError(
|
||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||
)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
@@ -950,7 +970,7 @@ class VoxCPMModel(nn.Module):
|
||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||
state_dict = load_file(str(safetensors_file), device=device)
|
||||
elif ckpt_file and ckpt_file.exists():
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
else:
|
||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||
|
||||
+106
-74
@@ -45,28 +45,17 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||
from .utils import (
|
||||
get_dtype,
|
||||
mask_multichar_chinese_tokens,
|
||||
next_and_close,
|
||||
pick_runtime_dtype,
|
||||
resolve_runtime_device,
|
||||
)
|
||||
|
||||
|
||||
def _trim_audio_silence_vad(
|
||||
audio: torch.Tensor,
|
||||
sample_rate: int,
|
||||
max_silence_ms: float = 200.0,
|
||||
top_db: float = 35.0,
|
||||
) -> torch.Tensor:
|
||||
"""使用能量阈值(VAD 方式)截取首尾静音及尾部长段伪静音,首尾各最多保留 max_silence_ms 毫秒静音。
|
||||
|
||||
会同时截掉末尾的长段伪静音(低能量但非完全静音的段落,如长时间底噪)。
|
||||
|
||||
Args:
|
||||
audio: (1, T) 的音频 tensor
|
||||
sample_rate: 采样率
|
||||
max_silence_ms: 首尾允许保留的最大静音长度(毫秒)
|
||||
top_db: 低于参考电平多少 dB 视为静音
|
||||
|
||||
Returns:
|
||||
截取后的 (1, T') tensor
|
||||
"""
|
||||
# Simple energy-threshold (VAD-style) silence trimmer for audio; it is not applied by default.
|
||||
def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
|
||||
if audio.numel() == 0:
|
||||
return audio
|
||||
y = audio.squeeze(0).numpy()
|
||||
@@ -85,7 +74,7 @@ def _trim_audio_silence_vad(
|
||||
except Exception:
|
||||
start, end = 0, n
|
||||
|
||||
# 用逐帧 RMS 找「最后一段有持续能量的位置」,截掉末尾长伪静音(低能量底噪等)
|
||||
# Find the last frame with continuous energy, trim the long pseudo-silence at the end (low energy background noise, etc.)
|
||||
n_frames = max(0, (n - frame_length) // hop_length + 1)
|
||||
last_voice_frame = -1
|
||||
for j in range(n_frames):
|
||||
@@ -168,18 +157,22 @@ class VoxCPM2Model(nn.Module):
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAEV2,
|
||||
lora_config: LoRAConfig = None,
|
||||
device: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
if not torch.cuda.is_available():
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
self.device = resolve_runtime_device(device, config.device)
|
||||
self.config.device = self.device
|
||||
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||
if resolved_dtype != self.config.dtype:
|
||||
print(
|
||||
f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
self.config.dtype = resolved_dtype
|
||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||
|
||||
# Text-Semantic LM
|
||||
@@ -246,6 +239,7 @@ class VoxCPM2Model(nn.Module):
|
||||
# Audio VAE
|
||||
self.audio_vae = audio_vae
|
||||
self.chunk_size = audio_vae.chunk_size
|
||||
self._decode_chunk_size = getattr(audio_vae, "decode_chunk_size", audio_vae.chunk_size)
|
||||
self._encode_sample_rate = audio_vae.sample_rate
|
||||
self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)
|
||||
|
||||
@@ -291,6 +285,7 @@ class VoxCPM2Model(nn.Module):
|
||||
self.residual_lm.forward_step = torch.compile(
|
||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
self._feat_encoder_raw = self.feat_encoder
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(
|
||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||
@@ -382,11 +377,7 @@ class VoxCPM2Model(nn.Module):
|
||||
mu=dit_hidden,
|
||||
patch_size=self.patch_size,
|
||||
cond=feat_cond_for_sample,
|
||||
n_timesteps=(
|
||||
self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||
else 10
|
||||
),
|
||||
n_timesteps=10,
|
||||
)
|
||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
@@ -402,19 +393,26 @@ class VoxCPM2Model(nn.Module):
|
||||
def _dtype(self):
|
||||
return get_dtype(self.config.dtype)
|
||||
|
||||
def _encode_wav(self, wav_path: str, padding_mode: str = "right") -> torch.Tensor:
|
||||
def _encode_wav(
|
||||
self,
|
||||
wav_path: str,
|
||||
padding_mode: str = "right",
|
||||
trim_silence_vad: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Load, trim, pad and VAE-encode an audio file.
|
||||
|
||||
Args:
|
||||
wav_path: path to the audio file.
|
||||
padding_mode: "right" (default) or "left" padding for alignment.
|
||||
trim_silence_vad: whether to apply VAD-based silence trimming.
|
||||
|
||||
Returns:
|
||||
audio_feat: (T, P, D) tensor of latent patches.
|
||||
"""
|
||||
audio, _ = librosa.load(wav_path, sr=self._encode_sample_rate, mono=True)
|
||||
audio = torch.from_numpy(audio).unsqueeze(0)
|
||||
audio = _trim_audio_silence_vad(audio, self._encode_sample_rate, max_silence_ms=200.0)
|
||||
if trim_silence_vad:
|
||||
audio = _trim_audio_silence_vad(audio, self._encode_sample_rate, max_silence_ms=200.0)
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
if audio.size(1) % patch_len != 0:
|
||||
padding_size = patch_len - audio.size(1) % patch_len
|
||||
@@ -456,7 +454,7 @@ class VoxCPM2Model(nn.Module):
|
||||
return tokens, feats, t_mask, a_mask
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
@@ -475,6 +473,7 @@ class VoxCPM2Model(nn.Module):
|
||||
retry_badcase: bool = False,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,
|
||||
trim_silence_vad: bool = False,
|
||||
streaming: bool = False,
|
||||
streaming_prefix_len: int = 4,
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
@@ -495,8 +494,12 @@ class VoxCPM2Model(nn.Module):
|
||||
)
|
||||
text_length = text_token.shape[0]
|
||||
|
||||
ref_feat = self._encode_wav(reference_wav_path, padding_mode="right")
|
||||
prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left")
|
||||
ref_feat = self._encode_wav(
|
||||
reference_wav_path,
|
||||
padding_mode="right",
|
||||
trim_silence_vad=trim_silence_vad,
|
||||
)
|
||||
prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left", trim_silence_vad=trim_silence_vad)
|
||||
prompt_audio_length = prompt_feat.size(0)
|
||||
|
||||
ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)
|
||||
@@ -538,7 +541,11 @@ class VoxCPM2Model(nn.Module):
|
||||
)
|
||||
text_length = text_token.shape[0]
|
||||
|
||||
ref_feat = self._encode_wav(reference_wav_path, padding_mode="right")
|
||||
ref_feat = self._encode_wav(
|
||||
reference_wav_path,
|
||||
padding_mode="right",
|
||||
trim_silence_vad=trim_silence_vad,
|
||||
)
|
||||
ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)
|
||||
|
||||
text_pad_feat = torch.zeros(
|
||||
@@ -595,7 +602,7 @@ class VoxCPM2Model(nn.Module):
|
||||
)
|
||||
text_length = text_token.shape[0]
|
||||
|
||||
prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left")
|
||||
prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left", trim_silence_vad=trim_silence_vad)
|
||||
prompt_audio_length = prompt_feat.size(0)
|
||||
prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
|
||||
text_pad_feat = torch.zeros(
|
||||
@@ -640,14 +647,14 @@ class VoxCPM2Model(nn.Module):
|
||||
streaming_prefix_len=streaming_prefix_len,
|
||||
)
|
||||
if streaming:
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
for latent_pred, _ in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
with self.audio_vae.streaming_decode() as vae_dec:
|
||||
for latent_pred, _, _ctx in inference_result:
|
||||
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -663,10 +670,9 @@ class VoxCPM2Model(nn.Module):
|
||||
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
has_continuation = bool(prompt_wav_path)
|
||||
if has_continuation:
|
||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
if context_len > 0:
|
||||
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||
else:
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
@@ -677,6 +683,7 @@ class VoxCPM2Model(nn.Module):
|
||||
prompt_text: str = None,
|
||||
prompt_wav_path: str = None,
|
||||
reference_wav_path: str = None,
|
||||
trim_silence_vad: bool = False,
|
||||
):
|
||||
"""
|
||||
Build prompt cache for subsequent generation.
|
||||
@@ -693,6 +700,8 @@ class VoxCPM2Model(nn.Module):
|
||||
Must be paired with ``prompt_text``.
|
||||
reference_wav_path: reference audio path for voice cloning
|
||||
(structurally isolated via ref_audio tokens).
|
||||
trim_silence_vad: whether to apply VAD-based silence trimming
|
||||
before encoding prompt/reference audio.
|
||||
|
||||
Returns:
|
||||
prompt_cache: dict used by ``_generate_with_prompt_cache``.
|
||||
@@ -705,11 +714,19 @@ class VoxCPM2Model(nn.Module):
|
||||
cache = {}
|
||||
|
||||
if reference_wav_path:
|
||||
cache["ref_audio_feat"] = self._encode_wav(reference_wav_path, padding_mode="right")
|
||||
cache["ref_audio_feat"] = self._encode_wav(
|
||||
reference_wav_path,
|
||||
padding_mode="right",
|
||||
trim_silence_vad=trim_silence_vad,
|
||||
)
|
||||
|
||||
if prompt_wav_path and prompt_text is not None:
|
||||
cache["prompt_text"] = prompt_text
|
||||
cache["audio_feat"] = self._encode_wav(prompt_wav_path, padding_mode="left")
|
||||
cache["audio_feat"] = self._encode_wav(
|
||||
prompt_wav_path,
|
||||
padding_mode="left",
|
||||
trim_silence_vad=trim_silence_vad,
|
||||
)
|
||||
|
||||
has_ref = "ref_audio_feat" in cache
|
||||
has_prompt = "audio_feat" in cache
|
||||
@@ -755,7 +772,7 @@ class VoxCPM2Model(nn.Module):
|
||||
return merged
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
@@ -917,14 +934,14 @@ class VoxCPM2Model(nn.Module):
|
||||
streaming_prefix_len=streaming_prefix_len,
|
||||
)
|
||||
if streaming:
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
for latent_pred, pred_audio_feat in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
with self.audio_vae.streaming_decode() as vae_dec:
|
||||
for latent_pred, pred_audio_feat, _ctx in inference_result:
|
||||
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -939,18 +956,20 @@ class VoxCPM2Model(nn.Module):
|
||||
break
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
if mode in ("continuation", "ref_continuation"):
|
||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
if context_len > 0:
|
||||
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||
else:
|
||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return next(self._inference(*args, streaming=False, **kwargs))
|
||||
feat_pred, generated_feat, _ = next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||
return feat_pred, generated_feat
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._inference(*args, streaming=True, **kwargs)
|
||||
for feat_pred, pred_feat_seq, _ in self._inference(*args, streaming=True, **kwargs):
|
||||
yield feat_pred, pred_feat_seq
|
||||
|
||||
@torch.inference_mode()
|
||||
def _inference(
|
||||
@@ -989,7 +1008,8 @@ class VoxCPM2Model(nn.Module):
|
||||
"""
|
||||
B, T, P, D = feat.shape
|
||||
|
||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
|
||||
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
@@ -1009,6 +1029,7 @@ class VoxCPM2Model(nn.Module):
|
||||
# trailing audio patches as initial context so the VAE can decode smoothly.
|
||||
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
|
||||
has_continuation_audio = feat_mask[0, -1].item() == 1
|
||||
context_len = 0
|
||||
if has_continuation_audio:
|
||||
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
|
||||
context_len = min(streaming_prefix_len - 1, len(audio_indices))
|
||||
@@ -1058,11 +1079,13 @@ class VoxCPM2Model(nn.Module):
|
||||
prefix_feat_cond = pred_feat
|
||||
|
||||
if streaming:
|
||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
# Yield only the newest patch latent for stateful VAE decode
|
||||
feat_pred = rearrange(pred_feat.unsqueeze(1), "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
yield feat_pred, pred_feat_seq
|
||||
yield feat_pred, pred_feat_seq, context_len
|
||||
|
||||
if len(pred_feat_seq) > streaming_prefix_len:
|
||||
pred_feat_seq = pred_feat_seq[-streaming_prefix_len:]
|
||||
|
||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||
if i > min_len and stop_flag == 1:
|
||||
@@ -1081,11 +1104,20 @@ class VoxCPM2Model(nn.Module):
|
||||
if not streaming:
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu()
|
||||
yield feat_pred, generated_feat, context_len
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
def from_local(
|
||||
cls,
|
||||
path: str,
|
||||
optimize: bool = True,
|
||||
training: bool = False,
|
||||
device: str | None = None,
|
||||
lora_config: LoRAConfig = None,
|
||||
):
|
||||
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
|
||||
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||
audio_vae = AudioVAEV2(config=audio_vae_config) if audio_vae_config else AudioVAEV2()
|
||||
@@ -1107,7 +1139,7 @@ class VoxCPM2Model(nn.Module):
|
||||
raise FileNotFoundError(
|
||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||
)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
@@ -1189,7 +1221,7 @@ class VoxCPM2Model(nn.Module):
|
||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||
state_dict = load_file(str(safetensors_file), device=device)
|
||||
elif ckpt_file and ckpt_file.exists():
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
else:
|
||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||
|
||||
@@ -436,6 +436,7 @@ class AudioVAE(nn.Module):
|
||||
self.out_sample_rate = out_sample_rate
|
||||
self.sr_bin_boundaries = sr_bin_boundaries
|
||||
self.chunk_size = math.prod(encoder_rates)
|
||||
self.decode_chunk_size = math.prod(decoder_rates)
|
||||
|
||||
def preprocess(self, audio_data, sample_rate):
|
||||
if sample_rate is None:
|
||||
@@ -471,6 +472,20 @@ class AudioVAE(nn.Module):
|
||||
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
|
||||
return self.decoder(z, sr_cond)
|
||||
|
||||
def streaming_decode(self):
    """Create a stateful, chunk-by-chunk decoder bound to this VAE.

    The returned :class:`StreamingVAEDecoder` is meant to be used as a
    context manager. Inside the ``with`` block every ``decode_chunk`` call
    decodes only the newly supplied latent patch, while the causal
    convolution state is carried over internally, so the overlapping
    re-decode used previously is not needed.

    Example::

        with vae.streaming_decode() as dec:
            for patch in patches:
                audio_chunk = dec.decode_chunk(patch)

    Returns:
        A new ``StreamingVAEDecoder`` wrapping ``self``.
    """
    decoder = StreamingVAEDecoder(self)
    return decoder
|
||||
|
||||
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
||||
"""
|
||||
Args:
|
||||
@@ -484,3 +499,82 @@ class AudioVAE(nn.Module):
|
||||
|
||||
audio_data = self.preprocess(audio_data, sample_rate)
|
||||
return self.encoder(audio_data)["mu"]
|
||||
|
||||
|
||||
class StreamingVAEDecoder:
    """Stateful streaming wrapper for :class:`AudioVAE`.

    While the context manager is active, the ``forward`` of every
    ``CausalConv1d`` / ``CausalTransposeConv1d`` found in ``vae.decoder`` is
    temporarily replaced by a version that keeps the trailing input samples
    of the previous call and prepends them to the next one. Each
    ``decode_chunk`` call therefore only has to process the new latent
    patch; no overlapping re-decode is needed.

    On exit the temporary ``forward`` overrides are removed and all cached
    state is dropped.
    """

    def __init__(self, vae: "AudioVAE"):
        self._vae = vae
        # id(module) -> tensor holding the trailing input samples of the last call
        self._states: dict = {}
        # (module, had_instance_forward, previous_forward) for every patched module
        self._originals: list = []

    # -- context manager --------------------------------------------------
    def __enter__(self):
        self._states.clear()
        try:
            self._install()
        except BaseException:
            # Never leave the decoder half-patched if installation fails.
            self._restore()
            raise
        return self

    def __exit__(self, *exc):
        self._restore()
        self._states.clear()

    # -- public API --------------------------------------------------------
    def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor:
        """Decode a single latent chunk and return the audio waveform.

        This simply calls ``vae.decode``; the streaming behaviour comes from
        the patched convolution layers, so it is only stateful inside the
        ``with`` block.
        """
        return self._vae.decode(z_chunk)

    # -- internals ---------------------------------------------------------
    def _install(self):
        """Patch every causal (transposed) convolution of the decoder."""
        for mod in self._vae.decoder.modules():
            if isinstance(mod, CausalConv1d):
                # Name-mangled private attributes of CausalConv1d; presumably the
                # amount of left padding its own forward applies (confirm there).
                pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding
                if pad > 0:
                    self._patch_causal_conv(mod, pad)
            elif isinstance(mod, CausalTransposeConv1d):
                trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding
                # Number of past input frames whose kernel still overlaps new output.
                ctx = (mod.kernel_size[0] - 1) // mod.stride[0]
                if ctx > 0:
                    self._patch_transpose_conv(mod, ctx, trim)

    @staticmethod
    def _next_state(prev, x, size):
        """Return the last ``size`` time steps of the stream that ends with ``x``.

        When ``x`` alone is shorter than ``size`` the missing steps are taken
        from ``prev`` (the previous state), or from zeros if there is none, so
        the state always has exactly ``size`` steps.
        """
        if x.shape[-1] >= size:
            return x[:, :, -size:].detach()
        if prev is None:
            prev = torch.zeros(x.shape[0], x.shape[1], size, device=x.device, dtype=x.dtype)
        return torch.cat([prev, x], dim=-1)[:, :, -size:].detach()

    def _remember(self, mod):
        """Record how to undo the ``forward`` override of ``mod``."""
        had_override = "forward" in vars(mod)
        self._originals.append((mod, had_override, mod.forward))

    def _patch_causal_conv(self, mod, pad_size):
        """Replace ``mod.forward`` with a version that carries ``pad_size`` past samples."""
        states = self._states
        key = id(mod)
        next_state = self._next_state
        self._remember(mod)

        def fwd(x):
            prev = states.get(key)
            # First call: zero left-padding; afterwards: real history.
            x_pad = F.pad(x, (pad_size, 0)) if prev is None else torch.cat([prev, x], dim=-1)
            states[key] = next_state(prev, x, pad_size)
            return nn.Conv1d.forward(mod, x_pad)

        mod.forward = fwd

    def _patch_transpose_conv(self, mod, ctx, trim):
        """Replace ``mod.forward`` with a version that carries ``ctx`` past input frames.

        The output produced by the prepended context frames (``ctx * stride``
        samples) is dropped on the left and ``trim`` samples are dropped on
        the right, so each call returns only the samples for the new input.
        """
        states = self._states
        key = id(mod)
        next_state = self._next_state
        self._remember(mod)

        def fwd(x):
            prev = states.get(key)
            x_full = F.pad(x, (ctx, 0)) if prev is None else torch.cat([prev, x], dim=-1)
            # Keep exactly ``ctx`` frames even when the chunk is shorter than
            # ``ctx``; otherwise the fixed left trim below would cut into the
            # output that belongs to the new frames.
            states[key] = next_state(prev, x, ctx)
            out = nn.ConvTranspose1d.forward(mod, x_full)
            left = ctx * mod.stride[0]
            return out[..., left:-trim] if trim > 0 else out[..., left:]

        mod.forward = fwd

    def _restore(self):
        """Undo every ``forward`` override installed by this decoder."""
        for mod, had_override, orig in reversed(self._originals):
            if had_override:
                mod.forward = orig
            else:
                # Drop the instance-level override so the class method is used
                # again and no bound-method self reference is left behind.
                vars(mod).pop("forward", None)
        self._originals.clear()
|
||||
|
||||
@@ -225,7 +225,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
|
||||
if tgt_mask is not None:
|
||||
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
|
||||
loss = (weights * losses).sum() / torch.sum(tgt_mask)
|
||||
loss = (weights * losses).sum() / torch.clamp(torch.sum(tgt_mask), min=1.0)
|
||||
else:
|
||||
loss = losses.mean()
|
||||
|
||||
|
||||
@@ -196,7 +196,9 @@ class MiniCPMAttention(nn.Module):
|
||||
key_cache[:, :, position_id, :] = key_states
|
||||
value_cache[:, :, position_id, :] = value_states
|
||||
|
||||
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
|
||||
# Use an explicit broadcastable mask shape for SDPA. A 1D mask can
|
||||
# trigger a CPU-side dimension bug in some PyTorch versions.
|
||||
attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1)
|
||||
|
||||
# ref: https://github.com/pytorch/pytorch/issues/163597
|
||||
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
||||
|
||||
@@ -15,6 +15,7 @@ from .data import (
|
||||
BatchProcessor,
|
||||
)
|
||||
from .state import TrainingState
|
||||
from .validate import validate_manifest, ValidationResult
|
||||
|
||||
__all__ = [
|
||||
"Accelerator",
|
||||
@@ -24,4 +25,6 @@ __all__ = [
|
||||
"TrainingState",
|
||||
"load_audio_text_datasets",
|
||||
"build_dataloader",
|
||||
"validate_manifest",
|
||||
"ValidationResult",
|
||||
]
|
||||
|
||||
+53
-11
@@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_REF_AUDIO_COLUMN = "ref_audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
|
||||
|
||||
@@ -21,6 +22,7 @@ def load_audio_text_datasets(
|
||||
val_manifest: str = "",
|
||||
text_column: str = DEFAULT_TEXT_COLUMN,
|
||||
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
||||
ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN,
|
||||
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
||||
sample_rate: int = 16_000,
|
||||
num_proc: int = 1,
|
||||
@@ -34,14 +36,19 @@ def load_audio_text_datasets(
|
||||
def prepare(ds: Dataset) -> Dataset:
|
||||
if audio_column not in ds.column_names:
|
||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# but for length calculation we might need raw path or duration if available.
|
||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||
if audio_column != DEFAULT_AUDIO_COLUMN:
|
||||
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
||||
if text_column != DEFAULT_TEXT_COLUMN:
|
||||
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
||||
|
||||
# ref_audio is optional — cast to Audio if the column exists
|
||||
ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN
|
||||
if ref_col in ds.column_names:
|
||||
ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate))
|
||||
if ref_col != DEFAULT_REF_AUDIO_COLUMN:
|
||||
ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN)
|
||||
|
||||
if dataset_id_column and dataset_id_column in ds.column_names:
|
||||
if dataset_id_column != DEFAULT_ID_COLUMN:
|
||||
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
||||
@@ -67,11 +74,11 @@ def compute_sample_lengths(
|
||||
- 音频长度:
|
||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||
t_seq = ceil(t_vae / patch_size)
|
||||
- 序列总长约为: text_len + t_seq + 2
|
||||
- 无 ref_audio: text_len + t_seq + 2
|
||||
- 有 ref_audio: text_len + t_seq + ref_seq + 4
|
||||
|
||||
Optimized: Use batch column access instead of iterating item by item.
|
||||
"""
|
||||
# Batch access columns - much faster than per-item access
|
||||
text_ids_list = ds["text_ids"]
|
||||
text_lens = [len(t) for t in text_ids_list]
|
||||
|
||||
@@ -79,18 +86,35 @@ def compute_sample_lengths(
|
||||
if has_duration:
|
||||
durations = ds["duration"]
|
||||
else:
|
||||
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
|
||||
durations = []
|
||||
for i in range(len(ds)):
|
||||
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||
|
||||
# Vectorized length computation
|
||||
has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names
|
||||
if has_ref_audio:
|
||||
ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None
|
||||
|
||||
lengths = []
|
||||
for text_len, duration in zip(text_lens, durations):
|
||||
for i, (text_len, duration) in enumerate(zip(text_lens, durations)):
|
||||
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
||||
t_seq = math.ceil(t_vae / patch_size)
|
||||
total_len = text_len + t_seq + 2
|
||||
|
||||
ref_seq = 0
|
||||
if has_ref_audio:
|
||||
# Estimate ref_audio length; ref_audio is None for samples without it
|
||||
if ref_duration_col:
|
||||
ref_dur = ds[i].get(ref_duration_col)
|
||||
else:
|
||||
ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN)
|
||||
ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None
|
||||
if ref_dur is not None and float(ref_dur) > 0:
|
||||
ref_vae = math.ceil(float(ref_dur) * audio_vae_fps)
|
||||
ref_seq = math.ceil(ref_vae / patch_size)
|
||||
|
||||
# +2 for 101/102; +2 more for 103/104 when ref_audio present
|
||||
overhead = 4 if ref_seq > 0 else 2
|
||||
total_len = text_len + t_seq + ref_seq + overhead
|
||||
lengths.append(total_len)
|
||||
|
||||
return lengths
|
||||
@@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
PyTorch-friendly samples.
|
||||
"""
|
||||
|
||||
_SENTINEL = [-100.0]
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.dataset = dataset
|
||||
self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
@@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
def __getitem__(self, idx: int):
|
||||
item = self.dataset[idx]
|
||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||
return {
|
||||
sample = {
|
||||
"text_ids": item["text_ids"],
|
||||
"audio_array": audio["array"],
|
||||
"audio_sampling_rate": audio["sampling_rate"],
|
||||
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
|
||||
"is_prompt": item.get("is_prompt", False),
|
||||
}
|
||||
if self.has_ref_audio:
|
||||
ref = item.get(DEFAULT_REF_AUDIO_COLUMN)
|
||||
sample["ref_audio_array"] = ref["array"] if ref else self._SENTINEL
|
||||
return sample
|
||||
|
||||
@staticmethod
|
||||
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
|
||||
@@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
|
||||
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
|
||||
|
||||
return {
|
||||
result = {
|
||||
"text_tokens": text_padded,
|
||||
"audio_tokens": audio_padded,
|
||||
"task_ids": task_ids,
|
||||
@@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
"is_prompts": is_prompts,
|
||||
}
|
||||
|
||||
if "ref_audio_array" in batch[0]:
|
||||
ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch]
|
||||
result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class BatchProcessor:
|
||||
"""
|
||||
@@ -184,12 +221,17 @@ class BatchProcessor:
|
||||
task_ids = batch["task_ids"].to(self.device)
|
||||
dataset_ids = batch["dataset_ids"].to(self.device)
|
||||
|
||||
ref_audio_tokens = None
|
||||
if "ref_audio_tokens" in batch:
|
||||
ref_audio_tokens = batch["ref_audio_tokens"].to(self.device)
|
||||
|
||||
packed = self.packer(
|
||||
audio_tokens=audio_tokens,
|
||||
text_tokens=text_tokens,
|
||||
task_ids=task_ids,
|
||||
dataset_ids=dataset_ids,
|
||||
is_prompts=batch["is_prompts"],
|
||||
ref_audio_tokens=ref_audio_tokens,
|
||||
)
|
||||
return packed
|
||||
|
||||
|
||||
+124
-14
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -14,7 +14,6 @@ class AudioFeatureProcessingPacker:
|
||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||
self.audio_start_id = 101
|
||||
self.audio_end_id = 102
|
||||
# unused now
|
||||
self.audio_prompt_start_id = 103
|
||||
self.audio_prompt_end_id = 104
|
||||
self.text_eos_token_id = 2
|
||||
@@ -78,11 +77,16 @@ class AudioFeatureProcessingPacker:
|
||||
task_ids: torch.Tensor,
|
||||
dataset_ids: torch.Tensor,
|
||||
is_prompts: List[bool],
|
||||
ref_audio_tokens: Optional[torch.Tensor] = None,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Padding-based batching: each sample in the input batch is processed
|
||||
independently and then padded to a common length (capped by ``max_len``).
|
||||
The result tensors all have shape [B, T, ...].
|
||||
|
||||
If ``ref_audio_tokens`` is provided (same batch dim as ``audio_tokens``),
|
||||
samples whose unpadded ref_audio length > 0 will be processed with the
|
||||
reference-audio path (tokens 103/104 prepended, loss only on target audio).
|
||||
"""
|
||||
device = audio_tokens.device
|
||||
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
|
||||
@@ -101,23 +105,43 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||
|
||||
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
|
||||
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
|
||||
ref_iter = ref_audio_tokens if ref_audio_tokens is not None else [None] * audio_tokens.size(0)
|
||||
|
||||
for audio_token, text_token, task_id, dataset_idx, is_prompt, ref_token in zip(
|
||||
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts, ref_iter
|
||||
):
|
||||
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
|
||||
unpad_text_token = self.unpad_text_tokens(text_token)
|
||||
usage = self.id_to_task[task_id]
|
||||
|
||||
(
|
||||
packed_text,
|
||||
audio_feat,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
|
||||
has_ref = False
|
||||
if ref_token is not None:
|
||||
unpad_ref_token = self.unpad_audio_tokens(ref_token).to(torch.float32)
|
||||
if unpad_ref_token.numel() > 0:
|
||||
has_ref = True
|
||||
|
||||
if has_ref:
|
||||
(
|
||||
packed_text,
|
||||
audio_feat,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
) = self.process_tts_data_with_ref(unpad_ref_token, unpad_audio_token, unpad_text_token)
|
||||
else:
|
||||
(
|
||||
packed_text,
|
||||
audio_feat,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
|
||||
|
||||
audio_duration_consumed[dataset_idx] += audio_duration
|
||||
text_token_consumed[dataset_idx] += text_token_count
|
||||
@@ -294,3 +318,89 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
)
|
||||
|
||||
def process_tts_data_with_ref(
    self,
    ref_audio_token: torch.Tensor,
    target_audio_token: torch.Tensor,
    text_token: torch.Tensor,
):
    """
    Build a training sequence with reference audio prepended:

        [103, ref_feats, 104, text, 101, target_feats, 102]

    Loss is computed only on the target audio segment.

    Args:
        ref_audio_token: Unpadded reference audio, passed to
            ``self.extract_audio_feats``.
        target_audio_token: Unpadded target audio, passed to
            ``self.extract_audio_feats``.
        text_token: 1-D tensor of text token ids; its device is used for
            every tensor created here.

    Returns:
        An 8-tuple ``(text_token_info, audio_feat_info, text_mask,
        audio_mask, loss_mask, labels, duration, txt_len)`` where
        ``duration`` is the sum of the reference and target durations
        reported by ``extract_audio_feats`` and ``txt_len`` is
        ``len(text_token)``.

    NOTE(review): ``extract_audio_feats`` is defined outside this view; it
    is assumed to return ``(features with a leading dim of 1, duration)``.
    """
    device = text_token.device
    txt_len = len(text_token)

    ref_feats, ref_duration = self.extract_audio_feats(ref_audio_token)
    ref_feats = ref_feats.squeeze(0)  # [R, P, D]
    ref_len = ref_feats.shape[0]

    tgt_feats, tgt_duration = self.extract_audio_feats(target_audio_token)
    tgt_feats = tgt_feats.squeeze(0)  # [A, P, D]
    tgt_len = tgt_feats.shape[0]

    feat_shape = (self.patch_size, ref_feats.size(-1))

    def _tok(ids):
        return torch.tensor(ids, dtype=torch.int32, device=device)

    # Integer segments are created directly on the target device so no
    # CPU tensor has to be cast and transferred for each mask.
    def _zeros(n):
        return torch.zeros(n, dtype=torch.int32, device=device)

    def _ones(n):
        return torch.ones(n, dtype=torch.int32, device=device)

    # -- text token track --
    # [103, 0×R, 104, text_ids, 101, 0×A, 102]
    text_token_info = torch.cat([
        _tok([self.audio_prompt_start_id]),
        _zeros(ref_len),
        _tok([self.audio_prompt_end_id]),
        text_token,
        _tok([self.audio_start_id]),
        _zeros(tgt_len),
        _tok([self.audio_end_id]),
    ])

    # -- audio feature track --
    zero_1 = torch.zeros((1,) + feat_shape, dtype=torch.float32, device=device)
    zero_txt = torch.zeros((txt_len,) + feat_shape, dtype=torch.float32, device=device)
    audio_feat_info = torch.cat([
        zero_1, ref_feats, zero_1,  # 103, ref, 104
        zero_txt,                   # text
        zero_1, tgt_feats, zero_1,  # 101, target, 102
    ], dim=0)

    # -- masks --
    # text_mask is 1 on special tokens and text positions, 0 on audio frames.
    text_mask = torch.cat([
        _ones(1), _zeros(ref_len), _ones(1),
        _ones(txt_len),
        _ones(1), _zeros(tgt_len), _ones(1),
    ])

    # audio_mask is the complement: 1 only on reference/target audio frames.
    audio_mask = torch.cat([
        _zeros(1), _ones(ref_len), _zeros(1),
        _zeros(txt_len),
        _zeros(1), _ones(tgt_len), _zeros(1),
    ])

    loss_mask = torch.cat([
        _zeros(1 + ref_len + 1),  # ref part: no loss
        _zeros(txt_len),          # text: no loss
        _zeros(1),                # 101: no loss
        _ones(tgt_len),           # target audio: LOSS
        _zeros(1),                # 102: no loss
    ])

    total_len = 1 + ref_len + 1 + txt_len + 1 + tgt_len + 1
    labels = torch.zeros(total_len, dtype=torch.int32, device=device)
    labels[-2] = 1  # stop label at last target audio position

    return (
        text_token_info,
        audio_feat_info,
        text_mask,
        audio_mask,
        loss_mask,
        labels,
        ref_duration + tgt_duration,
        txt_len,
    )
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
Pre-flight validation for VoxCPM training data manifests.
|
||||
|
||||
Validates JSONL manifest files before starting expensive fine-tuning jobs,
|
||||
catching format issues, missing files, and data quality problems early.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """Counters, messages and statistics gathered while validating a manifest."""

    # Number of non-blank manifest lines that were examined.
    total_samples: int = 0
    # Number of examined lines that produced no error.
    valid_samples: int = 0
    # Problems that make the manifest unusable.
    errors: List[str] = field(default_factory=list)
    # Non-fatal observations.
    warnings: List[str] = field(default_factory=list)
    # Duration in seconds of each audio file whose duration could be read.
    audio_durations: List[float] = field(default_factory=list)
    # Character count of each non-empty text field.
    text_lengths: List[int] = field(default_factory=list)
    # Number of samples whose ref_audio file exists on disk.
    has_ref_audio: int = 0

    @property
    def is_valid(self) -> bool:
        """Return True when no error was recorded and at least one sample is valid."""
        if self.errors:
            return False
        return self.valid_samples > 0
|
||||
|
||||
|
||||
def _check_audio_file(audio_path: str, sample_rate: int) -> Optional[str]:
|
||||
"""Check if an audio file exists, is readable, and matches expected sample rate.
|
||||
|
||||
Returns an error message, or None if the file is valid.
|
||||
"""
|
||||
if not os.path.isfile(audio_path):
|
||||
return f"Audio file not found: {audio_path}"
|
||||
try:
|
||||
import soundfile as sf
|
||||
|
||||
info = sf.info(audio_path)
|
||||
if info.frames == 0:
|
||||
return f"Audio file is empty: {audio_path}"
|
||||
if info.samplerate != sample_rate:
|
||||
return (
|
||||
f"Sample rate mismatch in {audio_path}: "
|
||||
f"expected {sample_rate} Hz, got {info.samplerate} Hz"
|
||||
)
|
||||
return None
|
||||
except ImportError:
|
||||
# soundfile not available; just check existence
|
||||
return None
|
||||
except Exception as e:
|
||||
return f"Cannot read audio file {audio_path}: {e}"
|
||||
|
||||
|
||||
def _get_audio_duration(audio_path: str) -> Optional[float]:
|
||||
"""Get audio duration in seconds. Returns None if unavailable."""
|
||||
try:
|
||||
import soundfile as sf
|
||||
|
||||
info = sf.info(audio_path)
|
||||
return info.duration
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def validate_manifest(
    manifest_path: str,
    sample_rate: int = 16_000,
    max_samples: int = 0,
    verbose: bool = False,
) -> ValidationResult:
    """Validate a JSONL training manifest file.

    Checks:
        1. File exists and is readable
        2. Each line is valid JSON
        3. Required columns present (text, audio)
        4. Audio files exist and are readable
        5. Text content is non-empty
        6. Collects duration and text length statistics
        7. Validates optional ref_audio column

    Args:
        manifest_path: Path to the JSONL manifest file.
        sample_rate: Expected audio sample rate in Hz. It is passed to
            ``_check_audio_file``; any message that helper returns is
            recorded as an error for the sample.
        max_samples: Maximum number of manifest lines to validate (0 = all).
            Blank lines inside that window are skipped, not counted as samples.
        verbose: Print per-sample progress.

    Returns:
        ValidationResult with errors, warnings, and statistics. Only the
        first five audio errors and empty-text warnings are listed
        individually; the rest are summarised in one trailing message each.
    """
    result = ValidationResult()
    path = Path(manifest_path)

    if not path.exists():
        result.errors.append(f"Manifest file not found: {manifest_path}")
        return result

    if not path.is_file():
        result.errors.append(f"Manifest path is not a file: {manifest_path}")
        return result

    manifest_dir = path.parent

    def _resolve(raw):
        """Return a usable path string for an audio field, or None if it has none."""
        if isinstance(raw, dict):
            # HuggingFace Audio format with {"path": ..., "array": ...}
            raw = raw.get("path", "")
        if not isinstance(raw, str) or not raw:
            return None
        # Resolve relative paths against manifest directory
        if not os.path.isabs(raw):
            raw = str(manifest_dir / raw)
        return raw

    try:
        with open(path, "r", encoding="utf-8") as f:
            lines = f.readlines()
    except Exception as e:
        result.errors.append(f"Cannot read manifest file: {e}")
        return result

    if not lines:
        result.errors.append("Manifest file is empty")
        return result

    samples_to_check = len(lines)
    if max_samples > 0:
        samples_to_check = min(samples_to_check, max_samples)

    missing_audio_count = 0
    empty_text_count = 0

    for i, line in enumerate(lines[:samples_to_check]):
        line = line.strip()
        if not line:
            continue

        result.total_samples += 1

        # Check JSON validity
        try:
            entry = json.loads(line)
        except json.JSONDecodeError as e:
            result.errors.append(f"Line {i + 1}: Invalid JSON — {e}")
            continue

        if not isinstance(entry, dict):
            result.errors.append(f"Line {i + 1}: Expected JSON object, got {type(entry).__name__}")
            continue

        # Check required columns
        has_error = False

        if "text" not in entry:
            result.errors.append(f"Line {i + 1}: Missing required column 'text'")
            has_error = True

        if "audio" not in entry:
            result.errors.append(f"Line {i + 1}: Missing required column 'audio'")
            has_error = True

        if has_error:
            continue

        # Validate text (an empty text is a warning, not an error)
        text = entry["text"]
        if not isinstance(text, str) or not text.strip():
            empty_text_count += 1
            if empty_text_count <= 5:
                result.warnings.append(f"Line {i + 1}: Empty or non-string text")
        else:
            result.text_lengths.append(len(text))

        # Validate audio path
        audio_path = _resolve(entry["audio"])
        if audio_path is not None:
            audio_error = _check_audio_file(audio_path, sample_rate)
            if audio_error:
                missing_audio_count += 1
                if missing_audio_count <= 5:
                    result.errors.append(f"Line {i + 1}: {audio_error}")
                # Mark the sample invalid even when its message is truncated.
                has_error = True
            else:
                duration = _get_audio_duration(audio_path)
                if duration is not None:
                    result.audio_durations.append(duration)
                    if duration < 0.3:
                        result.warnings.append(
                            f"Line {i + 1}: Very short audio ({duration:.2f}s)"
                        )
                    elif duration > 30.0:
                        result.warnings.append(
                            f"Line {i + 1}: Very long audio ({duration:.1f}s), may cause OOM"
                        )
        else:
            result.errors.append(f"Line {i + 1}: Invalid audio path")
            has_error = True

        # Validate optional ref_audio (a missing file is only a warning)
        if "ref_audio" in entry:
            ref_path = _resolve(entry["ref_audio"])
            if ref_path is not None:
                if os.path.isfile(ref_path):
                    result.has_ref_audio += 1
                else:
                    result.warnings.append(
                        f"Line {i + 1}: ref_audio file not found: {ref_path}"
                    )

        if not has_error:
            result.valid_samples += 1

        if verbose and (i + 1) % 100 == 0:
            print(f" Validated {i + 1}/{samples_to_check} samples...", file=sys.stderr)

    # A manifest made only of blank lines would otherwise fail with no
    # explanation, because is_valid needs at least one valid sample.
    if result.total_samples == 0:
        result.errors.append("Manifest contains no samples (only blank lines were read)")

    # Summarize truncated errors
    if missing_audio_count > 5:
        result.errors.append(
            f"... and {missing_audio_count - 5} more missing audio files "
            f"({missing_audio_count} total)"
        )
    if empty_text_count > 5:
        result.warnings.append(
            f"... and {empty_text_count - 5} more empty text entries "
            f"({empty_text_count} total)"
        )

    return result
|
||||
|
||||
|
||||
def print_validation_report(result: "ValidationResult", manifest_path: str) -> None:
    """Print a human-readable validation report to stderr.

    Args:
        result: The validation outcome to report. The attributes read are
            ``valid_samples``, ``total_samples``, ``has_ref_audio``,
            ``audio_durations``, ``text_lengths``, ``errors``, ``warnings``
            and ``is_valid``.
        manifest_path: Path shown in the report header.

    At most 20 errors and 10 warnings are listed; the remainder is
    summarised as a count.
    """
    import statistics

    rule = "=" * 60

    def emit(line: str) -> None:
        # sys.stderr is looked up on every call so stream redirection is honoured.
        print(line, file=sys.stderr)

    emit(f"\n{rule}")
    emit(" VoxCPM Training Data Validation Report")
    emit(rule)
    emit(f" Manifest : {manifest_path}")
    emit(f" Samples : {result.valid_samples}/{result.total_samples} valid")

    if result.has_ref_audio > 0:
        emit(f" Ref Audio: {result.has_ref_audio} samples with reference audio")

    # Audio duration statistics
    if result.audio_durations:
        durations = sorted(result.audio_durations)
        total_seconds = sum(durations)
        emit("\n Audio Duration Statistics:")
        emit(f" Total : {total_seconds / 3600:.2f} hours")
        emit(f" Range : {durations[0]:.2f}s — {durations[-1]:.1f}s")
        emit(f" Mean : {total_seconds / len(durations):.2f}s")
        # statistics.median averages the two middle values for even-length
        # data instead of picking the upper-middle element.
        emit(f" Median : {statistics.median(durations):.2f}s")

    # Text length statistics
    if result.text_lengths:
        lengths = sorted(result.text_lengths)
        emit("\n Text Length Statistics (characters):")
        emit(f" Range : {lengths[0]} — {lengths[-1]}")
        emit(f" Mean : {sum(lengths) / len(lengths):.0f}")

    # Errors
    if result.errors:
        emit(f"\n ERRORS ({len(result.errors)}):")
        for err in result.errors[:20]:
            emit(f" x {err}")
        if len(result.errors) > 20:
            emit(f" ... ({len(result.errors) - 20} more errors omitted)")

    # Warnings
    if result.warnings:
        emit(f"\n WARNINGS ({len(result.warnings)}):")
        for warn in result.warnings[:10]:
            emit(f" ! {warn}")
        if len(result.warnings) > 10:
            emit(f" ... ({len(result.warnings) - 10} more warnings omitted)")

    # Summary
    emit(f"\n{rule}")
    if result.is_valid:
        emit(" PASSED: Manifest is valid for training.")
    else:
        emit(" FAILED: Fix errors above before starting training.")
    emit(f"{rule}\n")
|
||||
@@ -58,6 +58,7 @@ def test_parser_defaults_to_voxcpm2():
|
||||
parser = cli._build_parser()
|
||||
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
|
||||
assert args.hf_model_id == "openbmb/VoxCPM2"
|
||||
assert args.device == "auto"
|
||||
assert args.no_optimize is False
|
||||
|
||||
|
||||
@@ -85,6 +86,7 @@ def test_load_model_respects_no_optimize_for_local_model(monkeypatch):
|
||||
|
||||
cli.load_model(args)
|
||||
|
||||
assert calls["kwargs"]["device"] == "auto"
|
||||
assert calls["kwargs"]["optimize"] is False
|
||||
|
||||
|
||||
@@ -110,6 +112,7 @@ def test_load_model_defaults_optimize_for_hf(monkeypatch):
|
||||
|
||||
cli.load_model(args)
|
||||
|
||||
assert calls["kwargs"]["device"] == "auto"
|
||||
assert calls["kwargs"]["optimize"] is True
|
||||
|
||||
|
||||
@@ -136,9 +139,37 @@ def test_load_model_respects_no_optimize_for_hf(monkeypatch):
|
||||
|
||||
cli.load_model(args)
|
||||
|
||||
assert calls["kwargs"]["device"] == "auto"
|
||||
assert calls["kwargs"]["optimize"] is False
|
||||
|
||||
|
||||
def test_load_model_passes_explicit_device_to_hf(monkeypatch):
    """An explicit ``--device`` value must reach ``VoxCPM.from_pretrained`` unchanged."""
    captured = {}

    class FakeVoxCPM:
        @classmethod
        def from_pretrained(cls, **kwargs):
            # Record the keyword arguments load_model forwards to the model.
            captured["kwargs"] = kwargs
            return DummyModel()

    monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)

    argv = ["design", "--text", "hello", "--output", "out.wav", "--device", "mps"]
    args = cli._build_parser().parse_args(argv)

    cli.load_model(args)

    assert captured["kwargs"]["device"] == "mps"
|
||||
|
||||
|
||||
def test_design_subcommand_applies_control(monkeypatch, tmp_path):
|
||||
dummy_model = DummyModel()
|
||||
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
|
||||
|
||||
def _load_module(name: str, path: Path):
|
||||
spec = importlib.util.spec_from_file_location(name, path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
sys.modules[name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def bootstrap_repo_modules(monkeypatch):
    """Load the two VoxCPM model modules against stubbed-out dependencies.

    Stub packages and modules are registered in ``sys.modules`` through
    *monkeypatch*, then ``utils.py``, ``voxcpm.py`` and ``voxcpm2.py`` are
    executed from the source tree.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture; every ``sys.modules``
            entry touched here is restored at test teardown.

    Returns:
        A ``(VoxCPMModel, VoxCPM2Model)`` tuple of the loaded classes.
    """
    # Namespace-like packages pointing at the real source directories.
    for name, path in [
        ("voxcpm", SRC / "voxcpm"),
        ("voxcpm.model", SRC / "voxcpm" / "model"),
        ("voxcpm.modules", SRC / "voxcpm" / "modules"),
    ]:
        pkg = types.ModuleType(name)
        pkg.__path__ = [str(path)]
        monkeypatch.setitem(sys.modules, name, pkg)

    hh = types.ModuleType("huggingface_hub")
    hh.snapshot_download = lambda *a, **k: "/tmp/fake"
    monkeypatch.setitem(sys.modules, "huggingface_hub", hh)

    pydantic = types.ModuleType("pydantic")

    class BaseModel:
        @classmethod
        def model_rebuild(cls):
            return None

        @classmethod
        def model_validate_json(cls, s):
            return cls()

        def model_dump(self):
            return {}

    pydantic.BaseModel = BaseModel
    monkeypatch.setitem(sys.modules, "pydantic", pydantic)

    torchaudio = types.ModuleType("torchaudio")
    monkeypatch.setitem(sys.modules, "torchaudio", torchaudio)

    librosa = types.ModuleType("librosa")
    librosa.effects = types.SimpleNamespace(trim=lambda *a, **k: (None, (0, 0)))
    monkeypatch.setitem(sys.modules, "librosa", librosa)

    einops = types.ModuleType("einops")
    einops.rearrange = lambda x, *a, **k: x
    monkeypatch.setitem(sys.modules, "einops", einops)

    tqdm_pkg = types.ModuleType("tqdm")
    tqdm_pkg.__path__ = ["/nonexistent"]
    tqdm_pkg.tqdm = lambda x, *a, **k: x
    monkeypatch.setitem(sys.modules, "tqdm", tqdm_pkg)

    tqdm_auto = types.ModuleType("tqdm.auto")
    tqdm_auto.tqdm = lambda x, *a, **k: x
    monkeypatch.setitem(sys.modules, "tqdm.auto", tqdm_auto)

    transformers = types.ModuleType("transformers")

    class LlamaTokenizerFast:
        pass

    class PreTrainedTokenizer:
        pass

    transformers.LlamaTokenizerFast = LlamaTokenizerFast
    transformers.PreTrainedTokenizer = PreTrainedTokenizer
    monkeypatch.setitem(sys.modules, "transformers", transformers)

    # Internal modules are replaced by empty placeholder classes/functions.
    internal_mods = {
        "voxcpm.modules.audiovae": ["AudioVAE", "AudioVAEConfig", "AudioVAEV2", "AudioVAEConfigV2"],
        "voxcpm.modules.layers": ["ScalarQuantizationLayer"],
        "voxcpm.modules.locdit": ["CfmConfig", "UnifiedCFM", "VoxCPMLocDiT", "VoxCPMLocDiTV2"],
        "voxcpm.modules.locenc": ["VoxCPMLocEnc"],
        "voxcpm.modules.minicpm4": ["MiniCPM4Config", "MiniCPMModel"],
        "voxcpm.modules.layers.lora": ["apply_lora_to_named_linear_modules", "LoRALinear"],
    }
    for modname, names in internal_mods.items():
        module = types.ModuleType(modname)
        for name in names:
            if name == "apply_lora_to_named_linear_modules":
                setattr(module, name, lambda *a, **k: None)
            else:
                setattr(module, name, type(name, (), {}))
        monkeypatch.setitem(sys.modules, modname, module)

    model_dir = SRC / "voxcpm" / "model"
    loaded = {}
    for stem in ("utils", "voxcpm", "voxcpm2"):
        modname = f"voxcpm.model.{stem}"
        # _load_module writes to sys.modules directly. Reserving the slot via
        # monkeypatch first records the previous state, so the stub-backed
        # module is removed (or the prior entry restored) at teardown instead
        # of leaking into later tests.
        monkeypatch.setitem(sys.modules, modname, types.ModuleType(modname))
        loaded[stem] = _load_module(modname, model_dir / f"{stem}.py")

    return loaded["voxcpm"].VoxCPMModel, loaded["voxcpm2"].VoxCPM2Model
|
||||
|
||||
|
||||
class DummyModel:
    """Bare stand-in passed as ``self`` to ``load_lora_weights`` in these tests."""

    # Device string exposed as a class attribute.
    device = "cpu"

    def named_parameters(self):
        """Return an empty parameter list."""
        return list()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("module_name", ["v1", "v2"])
def test_load_lora_weights_accepts_tensor_only_legacy_checkpoints(monkeypatch, tmp_path, module_name):
    """A checkpoint holding only tensors loads; its unknown key is reported as skipped."""
    model_classes = dict(zip(("v1", "v2"), bootstrap_repo_modules(monkeypatch)))
    cls = model_classes[module_name]

    ckpt_path = tmp_path / "lora_weights.ckpt"
    payload = {"state_dict": {"fake": torch.zeros(1)}}
    torch.save(payload, ckpt_path)

    loaded, skipped = cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")

    assert loaded == []
    assert skipped == ["fake"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("module_name", ["v1", "v2"])
def test_load_lora_weights_rejects_malicious_pickle_payloads(monkeypatch, tmp_path, module_name):
    """A checkpoint carrying an executable pickle payload is refused and never runs."""
    model_classes = dict(zip(("v1", "v2"), bootstrap_repo_modules(monkeypatch)))
    cls = model_classes[module_name]

    ckpt_path = tmp_path / "lora_weights.ckpt"
    # The payload would create this file if it were ever executed.
    marker_path = tmp_path / f"{module_name}-marker.txt"

    class Exploit:
        def __reduce__(self):
            import pathlib

            return (pathlib.Path.write_text, (marker_path, f"{module_name} executed\n"))

    poisoned = {"state_dict": {"fake": torch.zeros(1)}, "boom": Exploit()}
    torch.save(poisoned, ckpt_path)

    with pytest.raises(Exception, match="Weights only load failed"):
        cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")

    assert not marker_path.exists()
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
UTILS_PATH = ROOT / "src" / "voxcpm" / "model" / "utils.py"
|
||||
|
||||
transformers_stub = types.ModuleType("transformers")
|
||||
transformers_stub.PreTrainedTokenizer = object
|
||||
sys.modules.setdefault("transformers", transformers_stub)
|
||||
|
||||
spec = importlib.util.spec_from_file_location("voxcpm.model.utils", UTILS_PATH)
|
||||
utils = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(utils)
|
||||
|
||||
|
||||
def test_resolve_runtime_device_auto_falls_back_to_cpu(monkeypatch):
    """With no requested device and neither CUDA nor MPS available, the result is "cpu"."""
    monkeypatch.setattr(utils, "_has_mps", lambda: False)
    monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)

    resolved = utils.resolve_runtime_device(None, "cuda")

    assert resolved == "cpu"
|
||||
|
||||
|
||||
def test_resolve_runtime_device_auto_uses_mps_when_available(monkeypatch):
    """Requesting "auto" without CUDA but with MPS available resolves to "mps"."""
    monkeypatch.setattr(utils, "_has_mps", lambda: True)
    monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)

    resolved = utils.resolve_runtime_device("auto", "cuda")

    assert resolved == "mps"
|
||||
|
||||
|
||||
def test_resolve_runtime_device_respects_explicit_cpu(monkeypatch):
    """An explicit "cpu" request is honoured even when CUDA and MPS are both available."""
    monkeypatch.setattr(utils, "_has_mps", lambda: True)
    monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: True)

    resolved = utils.resolve_runtime_device("cpu", "cuda")

    assert resolved == "cpu"
|
||||
|
||||
|
||||
def test_resolve_runtime_device_rejects_unavailable_explicit_cuda(monkeypatch):
    """Explicitly asking for "cuda:0" when CUDA is unavailable raises ValueError."""
    monkeypatch.setattr(utils, "_has_mps", lambda: True)
    monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)

    expected_message = "CUDA is not available"
    with pytest.raises(ValueError, match=expected_message):
        utils.resolve_runtime_device("cuda:0", "cuda")
|
||||
@@ -0,0 +1,252 @@
|
||||
"""Tests for the training data validation module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
# Stub voxcpm package so imports work without full dependencies
|
||||
pkg = types.ModuleType("voxcpm")
|
||||
pkg.__path__ = [str(ROOT / "src" / "voxcpm")]
|
||||
sys.modules.setdefault("voxcpm", pkg)
|
||||
|
||||
training_pkg = types.ModuleType("voxcpm.training")
|
||||
training_pkg.__path__ = [str(ROOT / "src" / "voxcpm" / "training")]
|
||||
sys.modules.setdefault("voxcpm.training", training_pkg)
|
||||
|
||||
from voxcpm.training.validate import ValidationResult, validate_manifest
|
||||
|
||||
|
||||
@pytest.fixture
def tmp_dir():
    """Yield a fresh temporary directory as a ``Path``; it is deleted after the test."""
    workdir = tempfile.TemporaryDirectory()
    try:
        yield Path(workdir.name)
    finally:
        workdir.cleanup()
|
||||
|
||||
|
||||
def _create_wav(path: Path, duration_s: float = 1.0, sr: int = 16000):
    """Create a minimal valid WAV file containing silence.

    Args:
        path: Destination file; created or overwritten.
        duration_s: Clip length in seconds.
        sr: Sample rate in Hz.

    When ``soundfile`` and ``numpy`` are importable the file is written by
    ``soundfile`` from an all-zero float32 array.  Otherwise the
    standard-library ``wave`` module writes a mono 16-bit PCM file of the
    same length.
    """
    samples = int(duration_s * sr)

    # Only the optional imports are guarded, so a genuine write failure is
    # never mistaken for "soundfile is not installed".
    try:
        import numpy as np
        import soundfile as sf
    except ImportError:
        import wave

        with wave.open(str(path), "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)  # 16-bit PCM
            wav.setframerate(sr)
            wav.writeframes(b"\x00" * (samples * 2))
        return

    sf.write(str(path), np.zeros(samples, dtype=np.float32), sr)
|
||||
|
||||
|
||||
def _write_manifest(path: Path, entries: list[dict]):
    """Write *entries* to *path* as JSON Lines: one UTF-8 JSON object per line."""
    with open(path, "w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in entries
        )
|
||||
|
||||
|
||||
class TestValidateManifest:
    """Tests for ``validate_manifest`` and ``ValidationResult``."""

    def test_valid_manifest(self, tmp_dir):
        """Two rows that point at existing audio files validate without errors."""
        audio1 = tmp_dir / "audio1.wav"
        audio2 = tmp_dir / "audio2.wav"
        _create_wav(audio1, 2.0)
        _create_wav(audio2, 3.0)

        manifest = tmp_dir / "train.jsonl"
        _write_manifest(
            manifest,
            [
                {"text": "Hello world", "audio": str(audio1)},
                {"text": "Goodbye world", "audio": str(audio2)},
            ],
        )

        result = validate_manifest(str(manifest))
        assert result.total_samples == 2
        assert result.valid_samples == 2
        assert result.is_valid
        assert len(result.errors) == 0

    def test_missing_manifest(self):
        """A manifest path that does not exist yields a "not found" error."""
        result = validate_manifest("/nonexistent/path.jsonl")
        assert not result.is_valid
        assert any("not found" in e for e in result.errors)

    def test_empty_manifest(self, tmp_dir):
        """An empty manifest file is not valid."""
        manifest = tmp_dir / "empty.jsonl"
        manifest.write_text("")
        result = validate_manifest(str(manifest))
        assert not result.is_valid

    def test_invalid_json(self, tmp_dir):
        """Each unparsable line is reported as an "Invalid JSON" error."""
        manifest = tmp_dir / "bad.jsonl"
        manifest.write_text("not json\n{bad json}\n")
        result = validate_manifest(str(manifest))
        assert len(result.errors) >= 2
        assert any("Invalid JSON" in e for e in result.errors)

    def test_missing_columns(self, tmp_dir):
        """Rows lacking ``audio`` or ``text`` produce errors naming the missing key."""
        manifest = tmp_dir / "missing.jsonl"
        _write_manifest(
            manifest,
            [
                {"text": "hello"},  # missing audio
                {"audio": "test.wav"},  # missing text
            ],
        )
        result = validate_manifest(str(manifest))
        assert len(result.errors) >= 2
        assert any("'audio'" in e for e in result.errors)
        assert any("'text'" in e for e in result.errors)

    def test_missing_audio_file(self, tmp_dir):
        """A row whose audio file does not exist makes the manifest invalid."""
        manifest = tmp_dir / "missing_audio.jsonl"
        _write_manifest(
            manifest,
            [{"text": "hello", "audio": "/nonexistent/audio.wav"}],
        )
        result = validate_manifest(str(manifest))
        assert not result.is_valid
        assert any("not found" in e for e in result.errors)

    def test_empty_text_warning(self, tmp_dir):
        """An empty ``text`` field produces an "Empty" warning."""
        audio = tmp_dir / "audio.wav"
        _create_wav(audio)
        manifest = tmp_dir / "empty_text.jsonl"
        _write_manifest(
            manifest,
            [{"text": "", "audio": str(audio)}],
        )
        result = validate_manifest(str(manifest))
        assert len(result.warnings) > 0
        assert any("Empty" in w for w in result.warnings)

    def test_relative_audio_path(self, tmp_dir):
        """A bare file name is accepted when the audio sits next to the manifest."""
        audio = tmp_dir / "audio.wav"
        _create_wav(audio)
        manifest = tmp_dir / "rel.jsonl"
        _write_manifest(
            manifest,
            [{"text": "hello", "audio": "audio.wav"}],
        )
        result = validate_manifest(str(manifest))
        assert result.valid_samples == 1
        assert result.is_valid

    def test_max_samples_limit(self, tmp_dir):
        """``max_samples`` caps the number of rows that are counted."""
        audio = tmp_dir / "audio.wav"
        _create_wav(audio)
        manifest = tmp_dir / "many.jsonl"
        _write_manifest(
            manifest,
            [{"text": f"sample {i}", "audio": str(audio)} for i in range(100)],
        )
        result = validate_manifest(str(manifest), max_samples=10)
        assert result.total_samples == 10

    def test_ref_audio_counted(self, tmp_dir):
        """A row with an existing ``ref_audio`` file is counted in ``has_ref_audio``."""
        audio = tmp_dir / "audio.wav"
        ref = tmp_dir / "ref.wav"
        _create_wav(audio)
        _create_wav(ref)
        manifest = tmp_dir / "ref.jsonl"
        _write_manifest(
            manifest,
            [{"text": "hello", "audio": str(audio), "ref_audio": str(ref)}],
        )
        result = validate_manifest(str(manifest))
        assert result.has_ref_audio == 1

    def test_validation_result_properties(self):
        """``is_valid`` is false when there are errors or when no samples were seen."""
        r = ValidationResult(total_samples=5, valid_samples=5)
        assert r.is_valid

        r2 = ValidationResult(total_samples=5, valid_samples=5, errors=["err"])
        assert not r2.is_valid

        r3 = ValidationResult(total_samples=0, valid_samples=0)
        assert not r3.is_valid

    def test_invalid_audio_not_counted_as_valid(self, tmp_dir):
        """A row with a bad audio path must not increment valid_samples."""
        manifest = tmp_dir / "bad_audio.jsonl"
        _write_manifest(
            manifest,
            [{"text": "hello", "audio": "/nonexistent/audio.wav"}],
        )
        result = validate_manifest(str(manifest))
        assert result.total_samples == 1
        assert result.valid_samples == 0
        assert not result.is_valid
        assert any("not found" in e for e in result.errors)

    def test_sample_rate_mismatch(self, tmp_dir):
        """A file with a different sample rate should be reported as an error."""
        # Skip (rather than fail) when the optional audio stack is absent.
        sf = pytest.importorskip("soundfile")
        np = pytest.importorskip("numpy")

        audio = tmp_dir / "audio_8k.wav"
        samples = np.zeros(8000, dtype=np.float32)
        sf.write(str(audio), samples, 8000)

        manifest = tmp_dir / "sr_mismatch.jsonl"
        _write_manifest(manifest, [{"text": "hello", "audio": str(audio)}])

        result = validate_manifest(str(manifest), sample_rate=16000)
        assert result.valid_samples == 0
        assert not result.is_valid
        # The case-insensitive check also covers "Sample rate mismatch".
        assert any("sample rate" in e.lower() for e in result.errors)

    def test_mixed_ref_audio_warns_for_each_missing(self, tmp_dir):
        """Missing ref_audio entries should each generate a warning independently."""
        audio = tmp_dir / "audio.wav"
        ref_good = tmp_dir / "ref_good.wav"
        _create_wav(audio)
        _create_wav(ref_good)

        manifest = tmp_dir / "mixed_ref.jsonl"
        _write_manifest(
            manifest,
            [
                {"text": "row1", "audio": str(audio), "ref_audio": str(ref_good)},
                {"text": "row2", "audio": str(audio), "ref_audio": "/nonexistent/ref.wav"},
            ],
        )
        result = validate_manifest(str(manifest))
        assert result.has_ref_audio == 1
        assert any("ref_audio file not found" in w for w in result.warnings)

    def test_cli_validate_exit_code(self, tmp_dir):
        """validate subcommand must exit 1 on validation error (missing audio)."""
        import subprocess

        manifest = tmp_dir / "bad.jsonl"
        _write_manifest(manifest, [{"text": "hi", "audio": "/nonexistent/x.wav"}])

        # The child interpreter does not inherit this module's sys.modules
        # stubs, so put the source tree on its import path explicitly; this
        # lets ``-m voxcpm.cli`` resolve even when the package is not installed.
        env = dict(os.environ)
        src_dir = str(ROOT / "src")
        inherited = env.get("PYTHONPATH")
        env["PYTHONPATH"] = (
            os.pathsep.join([src_dir, inherited]) if inherited else src_dir
        )

        proc = subprocess.run(
            [sys.executable, "-m", "voxcpm.cli", "validate", "--manifest", str(manifest)],
            capture_output=True,
            text=True,
            env=env,
        )
        assert proc.returncode == 1, f"Expected exit 1, got {proc.returncode}"
        assert "FAILED" in proc.stderr or "Audio file not found" in proc.stderr
|
||||
Reference in New Issue
Block a user