Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2e4c601d1d | |||
| 525ebc65a3 | |||
| 564b98b2d5 | |||
| 19b6bf7590 | |||
| 86bff0fc82 | |||
| dd7b78f2c0 | |||
| 29577d57f8 | |||
| 4509becfde | |||
| cd79a647fa | |||
| 96d605b9de | |||
| a9b03a768c | |||
| 77f847fcba | |||
| d3cc88722c | |||
| ec2acec8a1 | |||
| 13605c5a0e | |||
| afa63e6195 | |||
| eae0a29908 | |||
| 35895982d7 | |||
| f7f1b78c4d | |||
| 38d61cdf03 | |||
| 1565e83efe | |||
| 61b36d4e56 | |||
| b1584aec7c | |||
| 4457617953 | |||
| 5510503182 | |||
| fb46aad9a5 | |||
| e4e049624c | |||
| abf01b9bf3 | |||
| 4f4a5b9f6c | |||
| 79c0cf68dd | |||
| 75cfa3e9b8 | |||
| 5611bd08a0 | |||
| 66205135fc | |||
| 364eff6840 | |||
| 6d10932b09 | |||
| 68af4fe502 | |||
| ee3649c1b3 | |||
| 82d77d445c | |||
| 8f95d13073 |
@@ -2,3 +2,6 @@ launch.json
|
|||||||
__pycache__
|
__pycache__
|
||||||
voxcpm.egg-info
|
voxcpm.egg-info
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
./pretrained_models/
|
||||||
|
app_local.py
|
||||||
|
models/
|
||||||
@@ -46,7 +46,7 @@ VoxCPM is a **tokenizer-free** Text-to-Speech system that directly generates con
|
|||||||
- 🎙️ **Ultimate Cloning** — Reproduce every vocal nuance: provide both reference audio and its transcript, and the model continues seamlessly from the reference, faithfully preserving every vocal detail — timbre, rhythm, emotion, and style (same as VoxCPM1.5)
|
- 🎙️ **Ultimate Cloning** — Reproduce every vocal nuance: provide both reference audio and its transcript, and the model continues seamlessly from the reference, faithfully preserving every vocal detail — timbre, rhythm, emotion, and style (same as VoxCPM1.5)
|
||||||
- 🔊 **48kHz High-Quality Audio** — Accepts 16kHz reference audio and directly outputs 48kHz studio-quality audio via AudioVAE V2's asymmetric encode/decode design, with built-in super-resolution — no external upsampler needed
|
- 🔊 **48kHz High-Quality Audio** — Accepts 16kHz reference audio and directly outputs 48kHz studio-quality audio via AudioVAE V2's asymmetric encode/decode design, with built-in super-resolution — no external upsampler needed
|
||||||
- 🧠 **Context-Aware Synthesis** — Automatically infers appropriate prosody and expressiveness from text content
|
- 🧠 **Context-Aware Synthesis** — Automatically infers appropriate prosody and expressiveness from text content
|
||||||
- ⚡ **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-VLLM](https://github.com/a710128/nanovllm-voxcpm)
|
- ⚡ **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) or [vLLM-Omni](https://github.com/vllm-project/vllm-omni) — official vLLM omni-modal serving for VoxCPM2 with PagedAttention and an OpenAI-compatible API
|
||||||
- 📜 **Fully Open-Source & Commercial-Ready** — Weights and code released under the [Apache-2.0](LICENSE) license, free for commercial use
|
- 📜 **Fully Open-Source & Commercial-Ready** — Weights and code released under the [Apache-2.0](LICENSE) license, free for commercial use
|
||||||
|
|
||||||
|
|
||||||
@@ -92,7 +92,7 @@ Chinese Dialect: 四川话, 粤语, 吴语, 东北话, 河南话, 陕西话, 山
|
|||||||
pip install voxcpm
|
pip install voxcpm
|
||||||
```
|
```
|
||||||
|
|
||||||
> **Requirements:** Python ≥ 3.10, PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
|
> **Requirements:** Python ≥ 3.10 (<3.13), PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
|
||||||
|
|
||||||
### Python API
|
### Python API
|
||||||
|
|
||||||
@@ -123,12 +123,12 @@ pip install modelscope
|
|||||||
```
|
```
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from modelscope.hub.snapshot_download import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # specify the local directory to save the model
|
||||||
|
|
||||||
from voxcpm import VoxCPM
|
from voxcpm import VoxCPM
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
model = VoxCPM.from_pretrained("./pretrained_models/VoxCPM2", load_denoiser=False)
|
||||||
local_model_dir = snapshot_download("OpenBMB/VoxCPM2")
|
|
||||||
model = VoxCPM.from_pretrained(local_model_dir, load_denoiser=False)
|
|
||||||
|
|
||||||
wav = model.generate(
|
wav = model.generate(
|
||||||
text="VoxCPM2 is the current recommended release for realistic multilingual speech synthesis.",
|
text="VoxCPM2 is the current recommended release for realistic multilingual speech synthesis.",
|
||||||
@@ -239,7 +239,7 @@ voxcpm --help
|
|||||||
### Web Demo
|
### Web Demo
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python app.py # then open http://localhost:7860
|
python app.py --port 8808 # then open in browser: http://localhost:8808
|
||||||
```
|
```
|
||||||
|
|
||||||
### 🚢 Production Deployment (Nano-vLLM)
|
### 🚢 Production Deployment (Nano-vLLM)
|
||||||
@@ -262,6 +262,32 @@ server.stop()
|
|||||||
|
|
||||||
> **RTF as low as ~0.13 on NVIDIA RTX 4090** (vs ~0.3 with the standard PyTorch implementation), with support for batched concurrent requests and a FastAPI HTTP server. See the [Nano-vLLM-VoxCPM repo](https://github.com/a710128/nanovllm-voxcpm) for deployment details.
|
> **RTF as low as ~0.13 on NVIDIA RTX 4090** (vs ~0.3 with the standard PyTorch implementation), with support for batched concurrent requests and a FastAPI HTTP server. See the [Nano-vLLM-VoxCPM repo](https://github.com/a710128/nanovllm-voxcpm) for deployment details.
|
||||||
|
|
||||||
|
### 🏭 Production Serving (vLLM-Omni)
|
||||||
|
|
||||||
|
For production multi-tenant deployments, use [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — the official vLLM project's omni-modal extension with native **VoxCPM2** support. It provides a PagedAttention KV cache, continuous batching, and a drop-in **OpenAI-compatible** `/v1/audio/speech` endpoint.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install from source (latest main — vllm-omni is rapidly evolving)
|
||||||
|
uv pip install vllm==0.19.0 --torch-backend=auto
|
||||||
|
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
|
||||||
|
uv pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [vLLM-Omni installation guide](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/) for other platforms (ROCm, XPU, MUSA, NPU) and Docker images.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Launch an OpenAI-compatible TTS server (--omni enables omni-modal serving)
|
||||||
|
vllm serve openbmb/VoxCPM2 --omni --port 8000
|
||||||
|
|
||||||
|
# Call it from any OpenAI client
|
||||||
|
curl http://localhost:8000/v1/audio/speech \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"model":"openbmb/VoxCPM2","input":"Hello from VoxCPM2 on vLLM-Omni!","voice":"default"}' \
|
||||||
|
--output out.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
> Built on the upstream vLLM scheduler, with batched concurrent requests, streaming chunk delivery, and multi-GPU deployment out of the box. See the [VoxCPM2 example](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2) for full deployment recipes.
|
||||||
|
|
||||||
> **Full parameter reference, multi-scenario examples, and voice cloning tips →** [Quick Start Guide](https://voxcpm.readthedocs.io/en/latest/quickstart.html) | [Usage Guide](https://voxcpm.readthedocs.io/en/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/en/latest/cookbook.html)
|
> **Full parameter reference, multi-scenario examples, and voice cloning tips →** [Quick Start Guide](https://voxcpm.readthedocs.io/en/latest/quickstart.html) | [Usage Guide](https://voxcpm.readthedocs.io/en/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/en/latest/cookbook.html)
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -415,10 +441,54 @@ VoxCPM2 achieves state-of-the-art or comparable results on public zero-shot and
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
### Internal 30-Language ASR Benchmark
|
||||||
|
|
||||||
|
We additionally run an internal multilingual intelligibility benchmark covering **30 languages × 500 samples**. The synthesized speech is transcribed with the **Gemini 3.1 Flash Lite API**, and the resulting WER/CER (lower is better) is reported per language.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Internal 30-Language ASR Benchmark (click to expand)</b></summary>
|
||||||
|
|
||||||
|
| Language | Metric | VoxCPM2 | Fish S2-Pro |
|
||||||
|
|---|---:|---:|---:|
|
||||||
|
| ar (Arabic) | CER | 1.23% | 0.30% |
|
||||||
|
| da (Danish) | WER | 2.70% | 3.52% |
|
||||||
|
| de (German) | WER | 0.96% | 0.64% |
|
||||||
|
| el (Greek) | WER | 3.17% | 4.61% |
|
||||||
|
| en (English) | WER | 0.42% | 1.03% |
|
||||||
|
| es (Spanish) | WER | 1.33% | 0.64% |
|
||||||
|
| fi (Finnish) | WER | 2.24% | 2.80% |
|
||||||
|
| fr (French) | WER | 2.16% | 2.34% |
|
||||||
|
| he (Hebrew) | CER | 2.98% | 15.27% |
|
||||||
|
| hi (Hindi) | CER | 0.79% | 0.91% |
|
||||||
|
| id (Indonesian) | WER | 1.36% | 1.68% |
|
||||||
|
| it (Italian) | WER | 1.65% | 1.08% |
|
||||||
|
| ja (Japanese) | CER | 2.40% | 1.82% |
|
||||||
|
| km (Khmer) | CER | 2.05% | 75.15% |
|
||||||
|
| ko (Korean) | CER | 0.95% | 0.29% |
|
||||||
|
| lo (Lao) | CER | 1.90% | 87.40% |
|
||||||
|
| ms (Malay) | WER | 1.75% | 1.41% |
|
||||||
|
| my (Burmese) | CER | 1.42% | 85.27% |
|
||||||
|
| nl (Dutch) | WER | 1.25% | 1.68% |
|
||||||
|
| no (Norwegian) | WER | 2.49% | 3.76% |
|
||||||
|
| pl (Polish) | WER | 1.90% | 1.65% |
|
||||||
|
| pt (Portuguese) | WER | 1.48% | 1.49% |
|
||||||
|
| ru (Russian) | WER | 0.90% | 0.86% |
|
||||||
|
| sv (Swedish) | WER | 2.22% | 2.63% |
|
||||||
|
| sw (Swahili) | CER | 1.07% | 2.02% |
|
||||||
|
| th (Thai) | CER | 0.94% | 1.92% |
|
||||||
|
| tl (Tagalog) | WER | 2.63% | 4.00% |
|
||||||
|
| tr (Turkish) | WER | 1.65% | 1.65% |
|
||||||
|
| vi (Vietnamese) | WER | 1.56% | 5.56% |
|
||||||
|
| zh (Chinese) | CER | 0.92% | 1.02% |
|
||||||
|
| Average (30 languages) | | **1.68%** | - |
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
### InstructTTSEval
|
### InstructTTSEval
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Instruction-Guided Voice Design Results</b></summary>
|
<summary><b>Instruction-Guided Voice Design Results (click to expand)</b></summary>
|
||||||
|
|
||||||
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
||||||
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
||||||
@@ -484,11 +554,13 @@ Full documentation: **[voxcpm.readthedocs.io](https://voxcpm.readthedocs.io/en/l
|
|||||||
| Project | Description |
|
| Project | Description |
|
||||||
|---|---|
|
|---|---|
|
||||||
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | High-throughput and Fast GPU serving |
|
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | High-throughput and Fast GPU serving |
|
||||||
|
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | Official vLLM omni-modal serving for VoxCPM2 — PagedAttention, OpenAI-compatible API |
|
||||||
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF: CPU, CUDA, Vulkan inference |
|
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF: CPU, CUDA, Vulkan inference |
|
||||||
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX export for CPU inference |
|
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX export for CPU inference |
|
||||||
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine backend |
|
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine backend |
|
||||||
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust re-implementation |
|
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust re-implementation |
|
||||||
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI node-based workflows |
|
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI node-based workflows |
|
||||||
|
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | Feature-complete ComfyUI workflow for VoxCPM 2 with multi-speaker generation, LoRA, and auto-ASR |
|
||||||
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS extension |
|
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS extension |
|
||||||
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | Browser-based TTS extension |
|
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | Browser-based TTS extension |
|
||||||
|
|
||||||
|
|||||||
+80
-8
@@ -46,7 +46,7 @@ VoxCPM 是一个**无离散音频分词器**(Tokenizer-Free)的语音合成
|
|||||||
- 🎙️ **极致克隆** — 提供参考音频及其文本内容,模型接着参考音频进行无缝续写,从而精准还原声音细节特征(与 VoxCPM1.5 一致)
|
- 🎙️ **极致克隆** — 提供参考音频及其文本内容,模型接着参考音频进行无缝续写,从而精准还原声音细节特征(与 VoxCPM1.5 一致)
|
||||||
- 🔊 **48kHz 高质量音频** — 输入 16kHz 参考音频,通过 AudioVAE V2 的非对称编解码设计直接输出 48kHz 高质量音频,内置超分能力
|
- 🔊 **48kHz 高质量音频** — 输入 16kHz 参考音频,通过 AudioVAE V2 的非对称编解码设计直接输出 48kHz 高质量音频,内置超分能力
|
||||||
- 🧠 **语境感知合成** — 根据文本内容自动推断合适的韵律和表现力
|
- 🧠 **语境感知合成** — 根据文本内容自动推断合适的韵律和表现力
|
||||||
- ⚡ **实时流式合成** — 在 NVIDIA RTX 4090 上 RTF 低至 ~0.3,通过 [Nano-VLLM](https://github.com/a710128/nanovllm-voxcpm) 加速后可达 ~0.13
|
- ⚡ **实时流式合成** — 在 NVIDIA RTX 4090 上 RTF 低至 ~0.3,通过 [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) 或 [vLLM-Omni](https://github.com/vllm-project/vllm-omni)(官方 vLLM 全模态服务,原生支持 VoxCPM2,提供 PagedAttention 与 OpenAI 兼容 API)加速后可达 ~0.13
|
||||||
- 📜 **完全开源,商用就绪** — 权重和代码基于 [Apache-2.0](LICENSE) 协议发布,免费商用
|
- 📜 **完全开源,商用就绪** — 权重和代码基于 [Apache-2.0](LICENSE) 协议发布,免费商用
|
||||||
|
|
||||||
<summary><b>🌍 支持的语言(30种)</b></summary>
|
<summary><b>🌍 支持的语言(30种)</b></summary>
|
||||||
@@ -91,7 +91,7 @@ VoxCPM 是一个**无离散音频分词器**(Tokenizer-Free)的语音合成
|
|||||||
pip install voxcpm
|
pip install voxcpm
|
||||||
```
|
```
|
||||||
|
|
||||||
> **环境要求:** Python ≥ 3.10,PyTorch ≥ 2.5.0,CUDA ≥ 12.0。详见 [快速开始文档](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html)。
|
> **环境要求:** Python ≥ 3.10 (<3.13),PyTorch ≥ 2.5.0,CUDA ≥ 12.0。详见 [快速开始文档](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html)。
|
||||||
|
|
||||||
### Python API
|
### Python API
|
||||||
|
|
||||||
@@ -122,12 +122,12 @@ pip install modelscope
|
|||||||
```
|
```
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from modelscope.hub.snapshot_download import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # 指定模型保存的本地路径
|
||||||
|
|
||||||
from voxcpm import VoxCPM
|
from voxcpm import VoxCPM
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
model = VoxCPM.from_pretrained('./pretrained_models/VoxCPM2', load_denoiser=False)
|
||||||
local_model_dir = snapshot_download("OpenBMB/VoxCPM2")
|
|
||||||
model = VoxCPM.from_pretrained(local_model_dir, load_denoiser=False)
|
|
||||||
|
|
||||||
wav = model.generate(
|
wav = model.generate(
|
||||||
text="VoxCPM2 是目前推荐使用的多语言语音合成版本。",
|
text="VoxCPM2 是目前推荐使用的多语言语音合成版本。",
|
||||||
@@ -238,7 +238,7 @@ voxcpm --help
|
|||||||
### Web Demo
|
### Web Demo
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python app.py # 然后打开 http://localhost:7860
|
python app.py --port 8808 # 然后在浏览器打开 http://localhost:8808
|
||||||
```
|
```
|
||||||
|
|
||||||
### 🚢 生产部署(Nano-vLLM)
|
### 🚢 生产部署(Nano-vLLM)
|
||||||
@@ -261,6 +261,32 @@ server.stop()
|
|||||||
|
|
||||||
> **在 NVIDIA RTX 4090 上 RTF 低至 ~0.13**(标准 PyTorch 实现约 ~0.3),支持批量并发请求和 FastAPI HTTP 服务。详见 [Nano-vLLM-VoxCPM 仓库](https://github.com/a710128/nanovllm-voxcpm)。
|
> **在 NVIDIA RTX 4090 上 RTF 低至 ~0.13**(标准 PyTorch 实现约 ~0.3),支持批量并发请求和 FastAPI HTTP 服务。详见 [Nano-vLLM-VoxCPM 仓库](https://github.com/a710128/nanovllm-voxcpm)。
|
||||||
|
|
||||||
|
### 🏭 生产环境部署(vLLM-Omni)
|
||||||
|
|
||||||
|
如需生产级多租户部署,使用 [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — 官方 vLLM 项目的全模态扩展,原生支持 **VoxCPM2**。具备 PagedAttention KV 缓存、连续批处理,以及与 OpenAI 完全兼容的 `/v1/audio/speech` 接口。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 从源码安装(最新 main 分支 —— vllm-omni 正在快速迭代)
|
||||||
|
uv pip install vllm==0.19.0 --torch-backend=auto
|
||||||
|
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
|
||||||
|
uv pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
其他平台(ROCm、XPU、MUSA、NPU)与 Docker 镜像请参考 [vLLM-Omni 安装文档](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/)。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 启动 OpenAI 兼容的 TTS 服务(--omni 启用全模态服务)
|
||||||
|
vllm serve openbmb/VoxCPM2 --omni --port 8000
|
||||||
|
|
||||||
|
# 任意 OpenAI 客户端均可调用
|
||||||
|
curl http://localhost:8000/v1/audio/speech \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"model":"openbmb/VoxCPM2","input":"你好,欢迎使用 VoxCPM2 on vLLM-Omni!","voice":"default"}' \
|
||||||
|
--output out.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
> 基于上游 vLLM 调度器构建,开箱即用支持批量并发、流式分块输出和多 GPU 部署。完整示例见 [VoxCPM2 部署样例](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2)。
|
||||||
|
|
||||||
> **完整参数说明、多场景示例与声音克隆技巧 →** [快速开始指南](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html) | [使用指南](https://voxcpm.readthedocs.io/zh-cn/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/zh-cn/latest/cookbook.html)
|
> **完整参数说明、多场景示例与声音克隆技巧 →** [快速开始指南](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html) | [使用指南](https://voxcpm.readthedocs.io/zh-cn/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/zh-cn/latest/cookbook.html)
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -414,10 +440,54 @@ VoxCPM2 在公开的零样本和可控 TTS 基准测试中取得了 SOTA 或可
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
### 内部 30 语种 ASR 评测
|
||||||
|
|
||||||
|
我们额外进行了内部多语言可懂度评测:**30 语种 × 500 样本**,ASR 转写评估使用 **Gemini 3.1 Flash Lite API**。
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>内部30语种评测集ASR结果(点击展开)</b></summary>
|
||||||
|
|
||||||
|
| 语言 | 指标 | VoxCPM2 | Fish S2-Pro |
|
||||||
|
|---|---:|---:|---:|
|
||||||
|
| ar (阿拉伯语) | CER | 1.23% | 0.30% |
|
||||||
|
| da (丹麦语) | WER | 2.70% | 3.52% |
|
||||||
|
| de (德语) | WER | 0.96% | 0.64% |
|
||||||
|
| el (希腊语) | WER | 3.17% | 4.61% |
|
||||||
|
| en (英语) | WER | 0.42% | 1.03% |
|
||||||
|
| es (西班牙语) | WER | 1.33% | 0.64% |
|
||||||
|
| fi (芬兰语) | WER | 2.24% | 2.80% |
|
||||||
|
| fr (法语) | WER | 2.16% | 2.34% |
|
||||||
|
| he (希伯来语) | CER | 2.98% | 15.27% |
|
||||||
|
| hi (印地语) | CER | 0.79% | 0.91% |
|
||||||
|
| id (印尼语) | WER | 1.36% | 1.68% |
|
||||||
|
| it (意大利语) | WER | 1.65% | 1.08% |
|
||||||
|
| ja (日语) | CER | 2.40% | 1.82% |
|
||||||
|
| km (高棉语) | CER | 2.05% | 75.15% |
|
||||||
|
| ko (韩语) | CER | 0.95% | 0.29% |
|
||||||
|
| lo (老挝语) | CER | 1.90% | 87.40% |
|
||||||
|
| ms (马来语) | WER | 1.75% | 1.41% |
|
||||||
|
| my (缅甸语) | CER | 1.42% | 85.27% |
|
||||||
|
| nl (荷兰语) | WER | 1.25% | 1.68% |
|
||||||
|
| no (挪威语) | WER | 2.49% | 3.76% |
|
||||||
|
| pl (波兰语) | WER | 1.90% | 1.65% |
|
||||||
|
| pt (葡萄牙语) | WER | 1.48% | 1.49% |
|
||||||
|
| ru (俄语) | WER | 0.90% | 0.86% |
|
||||||
|
| sv (瑞典语) | WER | 2.22% | 2.63% |
|
||||||
|
| sw (斯瓦希里语) | CER | 1.07% | 2.02% |
|
||||||
|
| th (泰语) | CER | 0.94% | 1.92% |
|
||||||
|
| tl (菲律宾语) | WER | 2.63% | 4.00% |
|
||||||
|
| tr (土耳其语) | WER | 1.65% | 1.65% |
|
||||||
|
| vi (越南语) | WER | 1.56% | 5.56% |
|
||||||
|
| zh (中文) | CER | 0.92% | 1.02% |
|
||||||
|
| 平均(30 语种) | | **1.68%** | - |
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
### InstructTTSEval
|
### InstructTTSEval
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>指令驱动音色设计结果</b></summary>
|
<summary><b>指令驱动音色设计结果(点击展开)</b></summary>
|
||||||
|
|
||||||
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
||||||
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
||||||
@@ -477,11 +547,13 @@ python lora_ft_webui.py # 然后打开 http://localhost:7860
|
|||||||
| 项目 | 说明 |
|
| 项目 | 说明 |
|
||||||
|---|---|
|
|---|---|
|
||||||
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | 高吞吐快速 GPU 推理引擎 |
|
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | 高吞吐快速 GPU 推理引擎 |
|
||||||
|
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | 官方 vLLM 全模态服务(原生支持 VoxCPM2)— PagedAttention、OpenAI 兼容 API |
|
||||||
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF:CPU、CUDA、Vulkan 推理 |
|
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF:CPU、CUDA、Vulkan 推理 |
|
||||||
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX 导出,支持 CPU 推理 |
|
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX 导出,支持 CPU 推理 |
|
||||||
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine 后端 |
|
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine 后端 |
|
||||||
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust 重新实现 |
|
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust 重新实现 |
|
||||||
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI 节点工作流 |
|
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI 节点工作流 |
|
||||||
|
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | 面向 VoxCPM 2 的功能更完整的 ComfyUI 工作流,支持多说话人、LoRA 和自动 ASR |
|
||||||
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS 扩展 |
|
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS 扩展 |
|
||||||
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | 浏览器端 TTS 扩展 |
|
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | 浏览器端 TTS 扩展 |
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -9,8 +10,6 @@ from funasr import AutoModel
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
|
||||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
|
|
||||||
|
|
||||||
import voxcpm
|
import voxcpm
|
||||||
|
|
||||||
@@ -55,7 +54,7 @@ _EXAMPLES_FOOTER_EN = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
_USAGE_INSTRUCTIONS_ZH = (
|
_USAGE_INSTRUCTIONS_ZH = (
|
||||||
"**VoxCPM2 — 三种语音生成方式:**\n\n"
|
"**三种语音生成方式:**\n\n"
|
||||||
"🎨 **声音设计(Voice Design)** \n"
|
"🎨 **声音设计(Voice Design)** \n"
|
||||||
"无需参考音频。在 **Control Instruction** 中描述目标音色特征"
|
"无需参考音频。在 **Control Instruction** 中描述目标音色特征"
|
||||||
"(性别、年龄、语气、情绪、语速等),VoxCPM2 即可为你从零创造独一无二的声音。\n\n"
|
"(性别、年龄、语气、情绪、语速等),VoxCPM2 即可为你从零创造独一无二的声音。\n\n"
|
||||||
@@ -66,6 +65,8 @@ _USAGE_INSTRUCTIONS_ZH = (
|
|||||||
"开启 **极致克隆模式** 并提供参考音频的文字内容(可自动识别)。"
|
"开启 **极致克隆模式** 并提供参考音频的文字内容(可自动识别)。"
|
||||||
"模型会将参考音频视为已说出的前文,以**音频续写**的方式完整还原参考音频中的所有声音细节。"
|
"模型会将参考音频视为已说出的前文,以**音频续写**的方式完整还原参考音频中的所有声音细节。"
|
||||||
"注意:该模式与可控克隆模式互斥,将禁用Control Instruction。\n\n"
|
"注意:该模式与可控克隆模式互斥,将禁用Control Instruction。\n\n"
|
||||||
|
"目前支持的方言包括:\n"
|
||||||
|
"「四川话、粤语、吴语、东北话、河南话、陕西话、山东话、天津话、闽南话」"
|
||||||
)
|
)
|
||||||
|
|
||||||
_EXAMPLES_FOOTER_ZH = (
|
_EXAMPLES_FOOTER_ZH = (
|
||||||
@@ -221,11 +222,11 @@ _APP_THEME = gr.themes.Soft(
|
|||||||
# ---------- Model ----------
|
# ---------- Model ----------
|
||||||
|
|
||||||
class VoxCPMDemo:
|
class VoxCPMDemo:
|
||||||
def __init__(self, model_dir: Optional[str] = None) -> None:
|
def __init__(self, model_id: str = "openbmb/VoxCPM2") -> None:
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
logger.info(f"Running on device: {self.device}")
|
logger.info(f"运行在设备上: {self.device}")
|
||||||
|
|
||||||
self.asr_model_id = "iic/SenseVoiceSmall"
|
self.asr_model_id = "./models/iic/SenseVoiceSmall"
|
||||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||||
model=self.asr_model_id,
|
model=self.asr_model_id,
|
||||||
disable_update=True,
|
disable_update=True,
|
||||||
@@ -234,36 +235,13 @@ class VoxCPMDemo:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||||
self.explicit_model_dir = model_dir
|
self._model_id = model_id
|
||||||
|
|
||||||
def _resolve_model_dir(self) -> str:
|
|
||||||
if self.explicit_model_dir and os.path.isdir(self.explicit_model_dir):
|
|
||||||
return self.explicit_model_dir
|
|
||||||
env_model_dir = os.environ.get("VOXCPM_MODEL_DIR", "").strip()
|
|
||||||
if env_model_dir and os.path.isdir(env_model_dir):
|
|
||||||
return env_model_dir
|
|
||||||
repo_id = os.environ.get("HF_REPO_ID", "").strip()
|
|
||||||
if len(repo_id) > 0:
|
|
||||||
target_dir = os.path.join("models", repo_id.replace("/", "__"))
|
|
||||||
if not os.path.isdir(target_dir):
|
|
||||||
try:
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
os.makedirs(target_dir, exist_ok=True)
|
|
||||||
logger.info(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
|
|
||||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"HF download failed: {e}. Falling back to 'models'.")
|
|
||||||
return "models"
|
|
||||||
return target_dir
|
|
||||||
return "models"
|
|
||||||
|
|
||||||
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
||||||
if self.voxcpm_model is not None:
|
if self.voxcpm_model is not None:
|
||||||
return self.voxcpm_model
|
return self.voxcpm_model
|
||||||
logger.info("Model not loaded, initializing...")
|
logger.info(f"Loading model: {self._model_id}")
|
||||||
model_dir = self._resolve_model_dir()
|
self.voxcpm_model = voxcpm.VoxCPM.from_pretrained(self._model_id, optimize=True,zipenhancer_model_id="./models/iic/speech_zipenhancer_ans_multiloss_16k_base")
|
||||||
logger.info(f"Using model dir: {model_dir}")
|
|
||||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=True)
|
|
||||||
logger.info("Model loaded successfully.")
|
logger.info("Model loaded successfully.")
|
||||||
return self.voxcpm_model
|
return self.voxcpm_model
|
||||||
|
|
||||||
@@ -315,6 +293,9 @@ class VoxCPMDemo:
|
|||||||
raise ValueError("Please input text to synthesize.")
|
raise ValueError("Please input text to synthesize.")
|
||||||
|
|
||||||
control = (control_instruction or "").strip()
|
control = (control_instruction or "").strip()
|
||||||
|
# Strip any parentheses (half-width/full-width) from control text to avoid
|
||||||
|
# breaking the "(control)text" prompt format expected by the model.
|
||||||
|
control = re.sub(r"[()()]", "", control).strip()
|
||||||
final_text = f"({control}){text}" if control else text
|
final_text = f"({control}){text}" if control else text
|
||||||
|
|
||||||
audio_path = reference_wav_path_input if reference_wav_path_input else None
|
audio_path = reference_wav_path_input if reference_wav_path_input else None
|
||||||
@@ -507,9 +488,9 @@ def run_demo(
|
|||||||
server_name: str = "0.0.0.0",
|
server_name: str = "0.0.0.0",
|
||||||
server_port: int = 8808,
|
server_port: int = 8808,
|
||||||
show_error: bool = True,
|
show_error: bool = True,
|
||||||
model_dir: Optional[str] = None,
|
model_id: str = "./models/OpenBMB/VoxCPM2",
|
||||||
):
|
):
|
||||||
demo = VoxCPMDemo(model_dir=model_dir)
|
demo = VoxCPMDemo(model_id=model_id)
|
||||||
interface = create_demo_interface(demo)
|
interface = create_demo_interface(demo)
|
||||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
@@ -524,7 +505,10 @@ def run_demo(
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--model-dir", type=str, default=None, help="Path to VoxCPM2 checkpoint directory")
|
parser.add_argument(
|
||||||
parser.add_argument("--port", type=int, default=8808, help="Server port")
|
"--model-id", type=str, default="./models/OpenBMB/VoxCPM2",
|
||||||
|
help="本地路径或HuggingFace仓库ID(默认:./models/openbmb/VoxCPM2)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--port", type=int, default=8808, help="服务端口")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
run_demo(model_dir=args.model_dir, server_port=args.port)
|
run_demo(model_id=args.model_id, server_port=args.port)
|
||||||
-280
@@ -1,280 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import gradio as gr
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
from funasr import AutoModel
|
|
||||||
from pathlib import Path
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
|
||||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
|
|
||||||
|
|
||||||
import voxcpm
|
|
||||||
|
|
||||||
|
|
||||||
class VoxCPMDemo:
    """Model holder behind the VoxCPM Gradio demo.

    The SenseVoice ASR model (used to transcribe the prompt audio) is created
    eagerly in ``__init__``; the VoxCPM TTS model is created lazily on the
    first call to ``get_or_load_voxcpm``.
    """

    def __init__(self) -> None:
        """Pick the compute device and load the ASR model."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Running on device: {self.device}", file=sys.stderr)

        # ASR model for prompt text recognition
        self.asr_model_id = "iic/SenseVoiceSmall"
        self.asr_model: Optional[AutoModel] = AutoModel(
            model=self.asr_model_id,
            disable_update=True,
            log_level='DEBUG',
            device="cuda:0" if self.device == "cuda" else "cpu",
        )

        # TTS model (lazy init)
        self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
        self.default_local_model_dir = "./models/VoxCPM1.5"

    # ---------- Model helpers ----------
    def _resolve_model_dir(self) -> str:
        """Return the directory the TTS model should be loaded from.

        Resolution order:
          1) ``self.default_local_model_dir`` if that directory exists.
          2) If the ``HF_REPO_ID`` environment variable is set, the directory
             ``models/{repo with '/' replaced by '__'}``, downloading the
             repository snapshot into it when it does not exist yet.
          3) The literal fallback ``"models"``.

        A failed download never raises: the partially created target
        directory is removed (so a later call retries instead of mistaking it
        for a complete model) and ``"models"`` is returned.
        """
        if os.path.isdir(self.default_local_model_dir):
            return self.default_local_model_dir

        repo_id = os.environ.get("HF_REPO_ID", "").strip()
        if not repo_id:
            return "models"

        target_dir = os.path.join("models", repo_id.replace("/", "__"))
        if os.path.isdir(target_dir):
            return target_dir

        try:
            from huggingface_hub import snapshot_download  # type: ignore

            os.makedirs(target_dir, exist_ok=True)
            print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
            snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
        except Exception as e:  # best-effort: any failure falls back to 'models'
            import shutil

            # Without this cleanup the empty/partial directory would satisfy
            # the isdir() check above on the next call.
            shutil.rmtree(target_dir, ignore_errors=True)
            print(f"Warning: HF download failed: {e}. Falling back to 'models'.", file=sys.stderr)
            return "models"
        return target_dir

    def get_or_load_voxcpm(self) -> "voxcpm.VoxCPM":
        """Return the cached TTS model, constructing it on first use."""
        if self.voxcpm_model is not None:
            return self.voxcpm_model
        print("Model not loaded, initializing...", file=sys.stderr)
        model_dir = self._resolve_model_dir()
        print(f"Using model dir: {model_dir}", file=sys.stderr)
        self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
        print("Model loaded successfully.", file=sys.stderr)
        return self.voxcpm_model

    # ---------- Functional endpoints ----------
    def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
        """Transcribe the prompt audio file with the ASR model.

        Returns an empty string when no audio is given or the ASR model
        produces no result. Everything up to and including the last ``|>``
        in the recognized text is dropped before returning.
        """
        if prompt_wav is None:
            return ""
        res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
        if not res:
            return ""
        return res[0]["text"].split('|>')[-1]

    def generate_tts_audio(
        self,
        text_input: str,
        prompt_wav_path_input: Optional[str] = None,
        prompt_text_input: Optional[str] = None,
        cfg_value_input: float = 2.0,
        inference_timesteps_input: int = 10,
        do_normalize: bool = True,
        denoise: bool = True,
    ) -> Tuple[int, np.ndarray]:
        """
        Generate speech from text using VoxCPM; optional reference audio for voice style guidance.

        Empty or missing prompt path / prompt text are passed to the model as ``None``.

        Returns (sample_rate, waveform_numpy)

        Raises:
            ValueError: if ``text_input`` is empty or whitespace-only.
        """
        # Validate first so an empty request never triggers the model load.
        text = (text_input or "").strip()
        if not text:
            raise ValueError("Please input text to synthesize.")

        current_model = self.get_or_load_voxcpm()

        prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
        prompt_text = prompt_text_input if prompt_text_input else None

        print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
        wav = current_model.generate(
            text=text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            cfg_value=float(cfg_value_input),
            inference_timesteps=int(inference_timesteps_input),
            normalize=do_normalize,
            denoise=denoise,
        )
        return (current_model.tts_model.sample_rate, wav)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------- UI Builders ----------
|
|
||||||
|
|
||||||
# Gradio theme object for the demo UI (Soft theme, blue/gray/slate palette).
_APP_THEME = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
)
|
|
||||||
|
|
||||||
_CUSTOM_CSS = """
|
|
||||||
.logo-container {
|
|
||||||
text-align: center;
|
|
||||||
margin: 0.5rem 0 1rem 0;
|
|
||||||
}
|
|
||||||
.logo-container img {
|
|
||||||
height: 80px;
|
|
||||||
width: auto;
|
|
||||||
max-width: 200px;
|
|
||||||
display: inline-block;
|
|
||||||
}
|
|
||||||
/* Bold accordion labels */
|
|
||||||
#acc_quick details > summary,
|
|
||||||
#acc_tips details > summary {
|
|
||||||
font-weight: 600 !important;
|
|
||||||
font-size: 1.1em !important;
|
|
||||||
}
|
|
||||||
/* Bold labels for specific checkboxes */
|
|
||||||
#chk_denoise label,
|
|
||||||
#chk_denoise span,
|
|
||||||
#chk_normalize label,
|
|
||||||
#chk_normalize span {
|
|
||||||
font-weight: 600;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def create_demo_interface(demo: VoxCPMDemo):
    """Build the Gradio UI for VoxCPM demo.

    Args:
        demo: provides ``generate_tts_audio`` (wired to the Generate button)
            and ``prompt_wav_recognition`` (wired to prompt-audio changes).

    Returns:
        The constructed ``gr.Blocks`` interface (not yet launched).
    """
    # Register the assets directory so the logo <img> below can be served.
    gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])

    with gr.Blocks() as interface:
        # Header logo
        gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')

        # Quick Start
        with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
            gr.Markdown("""
            ### How to Use |使用说明
            1. **(Optional) Provide a Voice Prompt** - Upload or record an audio clip to provide the desired voice characteristics for synthesis.
               **(可选)提供参考声音** - 上传或录制一段音频,为声音合成提供音色、语调和情感等个性化特征
            2. **(Optional) Enter prompt text** - If you provided a voice prompt, enter the corresponding transcript here (auto-recognition available).
               **(可选项)输入参考文本** - 如果提供了参考语音,请输入其对应的文本内容(支持自动识别)。
            3. **Enter target text** - Type the text you want the model to speak.
               **输入目标文本** - 输入您希望模型朗读的文字内容。
            4. **Generate Speech** - Click the "Generate" button to create your audio.
               **生成语音** - 点击"生成"按钮,即可为您创造出音频。
            """)

        # Pro Tips
        with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
            gr.Markdown("""
            ### Prompt Speech Enhancement|参考语音降噪
            - **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
              **启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz,限制克隆上限。
            - **Disable** to preserve the original audio's all information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
              **禁用**:保留原始音频的全部信息,包括背景环境声,最高支持44.1kHz的音频复刻。

            ### Text Normalization|文本正则化
            - **Enable** to process general text with an external WeTextProcessing component.
              **启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理。
            - **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input (For Chinese, phonemes are converted using pinyin, {ni3}{hao3}; For English, phonemes are converted using CMUDict, {HH AH0 L OW1}), try it!
              **禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3};英文转CMUDict:{HH AH0 L OW1})和公式符号合成,尝试一下!

            ### CFG Value|CFG 值
            - **Lower CFG** if the voice prompt sounds strained or expressive, or instability occurs with long text input.
              **调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题。
            - **Higher CFG** for better adherence to the prompt speech style or input text, or instability occurs with too short text input.
              **调高**:为更好地贴合提示音频的风格或输入文本, 或者极短文本输入出现稳定性问题。

            ### Inference Timesteps|推理时间步
            - **Lower** for faster synthesis speed.
              **调低**:合成速度更快。
            - **Higher** for better synthesis quality.
              **调高**:合成质量更佳。
            """)

        # Main controls
        with gr.Row():
            with gr.Column():
                prompt_wav = gr.Audio(
                    sources=["upload", 'microphone'],
                    type="filepath",
                    label="Prompt Speech (Optional, or let VoxCPM improvise)",
                    value="./examples/example.wav",
                )
                denoise_prompt_audio = gr.Checkbox(
                    value=False,
                    label="Prompt Speech Enhancement",
                    elem_id="chk_denoise",
                    info="We use ZipEnhancer model to denoise the prompt audio."
                )
                with gr.Row():
                    prompt_text = gr.Textbox(
                        value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
                        label="Prompt Text",
                        placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
                    )
                run_btn = gr.Button("Generate Speech", variant="primary")

            with gr.Column():
                cfg_value = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=2.0,
                    step=0.1,
                    label="CFG Value (Guidance Scale)",
                    info="Higher values increase adherence to prompt, lower values allow more creativity"
                )
                inference_timesteps = gr.Slider(
                    minimum=4,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Inference Timesteps",
                    info="Number of inference timesteps for generation (higher values may improve quality but slower)"
                )
                with gr.Row():
                    text = gr.Textbox(
                        value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
                        label="Target Text",
                    )
                with gr.Row():
                    normalize_text = gr.Checkbox(
                        value=False,
                        label="Text Normalization",
                        elem_id="chk_normalize",
                        info="We use wetext library to normalize the input text."
                    )
                audio_output = gr.Audio(label="Output Audio")

        # Wiring: the input order must match generate_tts_audio's positional
        # parameters (text, prompt wav, prompt text, cfg, timesteps,
        # do_normalize, denoise).
        run_btn.click(
            fn=demo.generate_tts_audio,
            inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, normalize_text, denoise_prompt_audio],
            outputs=[audio_output],
            show_progress=True,
            api_name="generate",
        )
        # Re-run ASR whenever the prompt audio changes and fill the prompt text box.
        prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])

    return interface
|
|
||||||
|
|
||||||
|
|
||||||
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
    """Create the demo models and UI, then launch the Gradio server.

    Args:
        server_name: host name passed to ``launch``.
        server_port: port passed to ``launch``.
        show_error: forwarded to ``launch`` as ``show_error``.
    """
    demo = VoxCPMDemo()
    interface = create_demo_interface(demo)
    # Queue at most 10 pending requests and run one at a time.
    interface.queue(max_size=10, default_concurrency_limit=1).launch(
        server_name=server_name,
        server_port=server_port,
        show_error=show_error,
        theme=_APP_THEME,
        css=_CUSTOM_CSS,
    )


if __name__ == "__main__":
    run_demo()
|
|
||||||
@@ -2,7 +2,7 @@ pretrained_path: /path/to/VoxCPM2/
|
|||||||
train_manifest: /path/to/train.jsonl
|
train_manifest: /path/to/train.jsonl
|
||||||
val_manifest: null
|
val_manifest: null
|
||||||
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||||
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
|
out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
@@ -15,6 +15,7 @@ weight_decay: 0.01
|
|||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
max_steps: 1000
|
max_steps: 1000
|
||||||
max_batch_tokens: 8192
|
max_batch_tokens: 8192
|
||||||
|
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
|
||||||
save_path: /path/to/checkpoints/finetune_all
|
save_path: /path/to/checkpoints/finetune_all
|
||||||
tensorboard: /path/to/logs/finetune_all
|
tensorboard: /path/to/logs/finetune_all
|
||||||
lambdas:
|
lambdas:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ pretrained_path: /path/to/VoxCPM2/
|
|||||||
train_manifest: /path/to/train.jsonl
|
train_manifest: /path/to/train.jsonl
|
||||||
val_manifest: null
|
val_manifest: null
|
||||||
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||||
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
|
out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
@@ -15,6 +15,7 @@ weight_decay: 0.01
|
|||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
max_steps: 1000
|
max_steps: 1000
|
||||||
max_batch_tokens: 8192
|
max_batch_tokens: 8192
|
||||||
|
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
|
||||||
save_path: /path/to/checkpoints/finetune_lora
|
save_path: /path/to/checkpoints/finetune_lora
|
||||||
tensorboard: /path/to/logs/finetune_lora
|
tensorboard: /path/to/logs/finetune_lora
|
||||||
lambdas:
|
lambdas:
|
||||||
|
|||||||
+58
-12
@@ -14,8 +14,10 @@ from typing import Optional
|
|||||||
project_root = Path(__file__).parent
|
project_root = Path(__file__).parent
|
||||||
sys.path.insert(0, str(project_root / "src"))
|
sys.path.insert(0, str(project_root / "src"))
|
||||||
|
|
||||||
# Default pretrained model path relative to this repo
|
# Default pretrained model path: prefer VoxCPM2 if it exists, fallback to VoxCPM1.5
|
||||||
default_pretrained_path = str(project_root / "models" / "openbmb__VoxCPM1.5")
|
_v2_path = project_root / "models" / "openbmb__VoxCPM2"
|
||||||
|
_v15_path = project_root / "models" / "openbmb__VoxCPM1.5"
|
||||||
|
default_pretrained_path = str(_v2_path if _v2_path.exists() else _v15_path)
|
||||||
|
|
||||||
from voxcpm.core import VoxCPM
|
from voxcpm.core import VoxCPM
|
||||||
from voxcpm.model.voxcpm import LoRAConfig
|
from voxcpm.model.voxcpm import LoRAConfig
|
||||||
@@ -279,27 +281,48 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
|||||||
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
||||||
|
|
||||||
# 加载模型
|
# 加载模型
|
||||||
|
lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None
|
||||||
try:
|
try:
|
||||||
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||||
load_model(base_model_path)
|
load_model(base_model_path, lora_to_load)
|
||||||
if lora_selection and lora_selection != "None":
|
if lora_to_load:
|
||||||
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
|
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
|
||||||
print(error_msg, file=sys.stderr)
|
print(error_msg, file=sys.stderr)
|
||||||
return None, error_msg
|
return None, error_msg
|
||||||
|
lora_just_loaded = lora_to_load
|
||||||
|
else:
|
||||||
|
lora_just_loaded = None
|
||||||
|
|
||||||
# Handle LoRA hot-swapping
|
# Handle LoRA hot-swapping
|
||||||
assert current_model is not None, "Model must be loaded before inference"
|
assert current_model is not None, "Model must be loaded before inference"
|
||||||
if lora_selection and lora_selection != "None":
|
if lora_selection and lora_selection != "None":
|
||||||
full_lora_path = os.path.join("lora", lora_selection)
|
full_lora_path = os.path.join("lora", lora_selection)
|
||||||
|
|
||||||
|
if lora_just_loaded != lora_selection:
|
||||||
|
new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path)
|
||||||
|
current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None
|
||||||
|
new_r = new_lora_config.r if new_lora_config else None
|
||||||
|
|
||||||
|
if new_r is not None and current_r is not None and new_r != current_r:
|
||||||
|
print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr)
|
||||||
|
reload_base = (
|
||||||
|
new_base_model if new_base_model and os.path.exists(new_base_model)
|
||||||
|
else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
load_model(reload_base, lora_selection)
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Failed to reload model for LoRA rank change: {e}"
|
||||||
|
else:
|
||||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||||
try:
|
try:
|
||||||
current_model.load_lora(full_lora_path)
|
current_model.load_lora(full_lora_path)
|
||||||
current_model.set_lora_enabled(True)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
||||||
return None, f"Error loading LoRA: {e}"
|
return None, f"Error loading LoRA: {e}"
|
||||||
|
current_model.set_lora_enabled(True)
|
||||||
else:
|
else:
|
||||||
print("Disabling LoRA", file=sys.stderr)
|
print("Disabling LoRA", file=sys.stderr)
|
||||||
current_model.set_lora_enabled(False)
|
current_model.set_lora_enabled(False)
|
||||||
@@ -368,6 +391,7 @@ def start_training(
|
|||||||
warmup_steps=100,
|
warmup_steps=100,
|
||||||
max_steps=None,
|
max_steps=None,
|
||||||
sample_rate=44100,
|
sample_rate=44100,
|
||||||
|
max_grad_norm=1.0,
|
||||||
# LoRA advanced
|
# LoRA advanced
|
||||||
enable_lm=True,
|
enable_lm=True,
|
||||||
enable_dit=True,
|
enable_dit=True,
|
||||||
@@ -409,11 +433,25 @@ def start_training(
|
|||||||
# Resolve max_steps default
|
# Resolve max_steps default
|
||||||
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
||||||
|
|
||||||
|
# Auto-detect out_sample_rate from model config
|
||||||
|
out_sample_rate = 0
|
||||||
|
config_file = os.path.join(pretrained_path, "config.json")
|
||||||
|
if os.path.isfile(config_file):
|
||||||
|
try:
|
||||||
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
|
cfg = json.load(f)
|
||||||
|
out_sr = cfg.get("audio_vae_config", {}).get("out_sample_rate")
|
||||||
|
if out_sr:
|
||||||
|
out_sample_rate = int(out_sr)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"pretrained_path": pretrained_path,
|
"pretrained_path": pretrained_path,
|
||||||
"train_manifest": train_manifest,
|
"train_manifest": train_manifest,
|
||||||
"val_manifest": val_manifest,
|
"val_manifest": val_manifest,
|
||||||
"sample_rate": int(sample_rate),
|
"sample_rate": int(sample_rate),
|
||||||
|
"out_sample_rate": out_sample_rate,
|
||||||
"batch_size": int(batch_size),
|
"batch_size": int(batch_size),
|
||||||
"grad_accum_steps": int(grad_accum_steps),
|
"grad_accum_steps": int(grad_accum_steps),
|
||||||
"num_workers": int(num_workers),
|
"num_workers": int(num_workers),
|
||||||
@@ -425,6 +463,7 @@ def start_training(
|
|||||||
"weight_decay": float(weight_decay),
|
"weight_decay": float(weight_decay),
|
||||||
"warmup_steps": int(warmup_steps),
|
"warmup_steps": int(warmup_steps),
|
||||||
"max_steps": resolved_max_steps,
|
"max_steps": resolved_max_steps,
|
||||||
|
"max_grad_norm": float(max_grad_norm),
|
||||||
"save_path": checkpoints_dir,
|
"save_path": checkpoints_dir,
|
||||||
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
||||||
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||||
@@ -932,17 +971,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
|
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
|
||||||
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
|
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
|
||||||
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
|
max_grad_norm = gr.Number(label="梯度裁剪 (max_grad_norm, 0=关闭)", value=1.0)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
|
||||||
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
|
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
|
||||||
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
|
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
|
||||||
|
with gr.Row():
|
||||||
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
|
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
|
||||||
dropout = gr.Number(label="LoRA Dropout", value=0.0)
|
dropout = gr.Number(label="LoRA Dropout", value=0.0)
|
||||||
|
|
||||||
gr.Markdown("#### 分发选项 (Distribution)")
|
gr.Markdown("#### 分发选项 (Distribution)")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hf_model_id = gr.Textbox(
|
hf_model_id = gr.Textbox(
|
||||||
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
|
label="HuggingFace Model ID (e.g., openbmb/VoxCPM2)", value=""
|
||||||
)
|
)
|
||||||
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
||||||
|
|
||||||
@@ -992,6 +1033,7 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
warmup_steps,
|
warmup_steps,
|
||||||
max_steps,
|
max_steps,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
|
max_grad_norm,
|
||||||
enable_lm,
|
enable_lm,
|
||||||
enable_dit,
|
enable_dit,
|
||||||
enable_proj,
|
enable_proj,
|
||||||
@@ -1150,12 +1192,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
"warmup_steps": "warmup_steps",
|
"warmup_steps": "warmup_steps",
|
||||||
"max_steps": "最大步数 (max_steps)",
|
"max_steps": "最大步数 (max_steps)",
|
||||||
"sample_rate": "采样率 (sample_rate)",
|
"sample_rate": "采样率 (sample_rate)",
|
||||||
|
"max_grad_norm": "梯度裁剪 (max_grad_norm, 0=关闭)",
|
||||||
"enable_lm": "启用 LoRA LM (enable_lm)",
|
"enable_lm": "启用 LoRA LM (enable_lm)",
|
||||||
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
||||||
"enable_proj": "启用投影 (enable_proj)",
|
"enable_proj": "启用投影 (enable_proj)",
|
||||||
"dropout": "LoRA Dropout",
|
"dropout": "LoRA Dropout",
|
||||||
"tensorboard_path": "Tensorboard 路径 (可选)",
|
"tensorboard_path": "Tensorboard 路径 (可选)",
|
||||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
|
||||||
"distribute": "分发模式 (distribute)",
|
"distribute": "分发模式 (distribute)",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
@@ -1168,12 +1211,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
"warmup_steps": "Warmup Steps",
|
"warmup_steps": "Warmup Steps",
|
||||||
"max_steps": "Max Steps",
|
"max_steps": "Max Steps",
|
||||||
"sample_rate": "Sample Rate",
|
"sample_rate": "Sample Rate",
|
||||||
|
"max_grad_norm": "Max Grad Norm (0=disabled)",
|
||||||
"enable_lm": "Enable LoRA LM",
|
"enable_lm": "Enable LoRA LM",
|
||||||
"enable_dit": "Enable LoRA DIT",
|
"enable_dit": "Enable LoRA DIT",
|
||||||
"enable_proj": "Enable Projection",
|
"enable_proj": "Enable Projection",
|
||||||
"dropout": "LoRA Dropout",
|
"dropout": "LoRA Dropout",
|
||||||
"tensorboard_path": "Tensorboard Path (Optional)",
|
"tensorboard_path": "Tensorboard Path (Optional)",
|
||||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
|
||||||
"distribute": "Distribute Mode",
|
"distribute": "Distribute Mode",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1203,11 +1247,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
gr.update(label=adv["warmup_steps"]),
|
gr.update(label=adv["warmup_steps"]),
|
||||||
gr.update(label=adv["max_steps"]),
|
gr.update(label=adv["max_steps"]),
|
||||||
gr.update(label=adv["sample_rate"]),
|
gr.update(label=adv["sample_rate"]),
|
||||||
|
gr.update(label=adv["max_grad_norm"]),
|
||||||
|
gr.update(label=adv["tensorboard_path"]),
|
||||||
gr.update(label=adv["enable_lm"]),
|
gr.update(label=adv["enable_lm"]),
|
||||||
gr.update(label=adv["enable_dit"]),
|
gr.update(label=adv["enable_dit"]),
|
||||||
gr.update(label=adv["enable_proj"]),
|
gr.update(label=adv["enable_proj"]),
|
||||||
gr.update(label=adv["dropout"]),
|
gr.update(label=adv["dropout"]),
|
||||||
gr.update(label=adv["tensorboard_path"]),
|
|
||||||
# Distribution options
|
# Distribution options
|
||||||
gr.update(label=adv["hf_model_id"]),
|
gr.update(label=adv["hf_model_id"]),
|
||||||
gr.update(label=adv["distribute"]),
|
gr.update(label=adv["distribute"]),
|
||||||
@@ -1254,11 +1299,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
warmup_steps,
|
warmup_steps,
|
||||||
max_steps,
|
max_steps,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
|
max_grad_norm,
|
||||||
|
tensorboard_path,
|
||||||
enable_lm,
|
enable_lm,
|
||||||
enable_dit,
|
enable_dit,
|
||||||
enable_proj,
|
enable_proj,
|
||||||
dropout,
|
dropout,
|
||||||
tensorboard_path,
|
|
||||||
# distribution outputs
|
# distribution outputs
|
||||||
hf_model_id,
|
hf_model_id,
|
||||||
distribute,
|
distribute,
|
||||||
|
|||||||
@@ -0,0 +1,34 @@
|
|||||||
|
"""
|
||||||
|
模型下载脚本
|
||||||
|
|
||||||
|
"""
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
def download(repo_id: str, local_dir: str):
    """Download a model repository snapshot below ``local_dir``.

    Args:
        repo_id (str): ``owner/name`` repository identifier,
            e.g. 'stabilityai/sdxl-turbo'.
        local_dir (str or Path): base directory; the snapshot is placed in
            ``{local_dir}/{repo_id}``.

    Returns:
        The value returned by ``snapshot_download`` (the local path of the
        downloaded snapshot).

    Raises:
        ValueError: if ``repo_id`` is not of the form ``owner/name``.
    """
    owner, sep, name = repo_id.partition("/") if isinstance(repo_id, str) else ("", "", "")
    if not sep or not owner.strip() or not name.strip() or "/" in name:
        raise ValueError(f"repo_id must look like 'owner/name', got {repo_id!r}")

    model_dir = snapshot_download(
        repo_id,
        repo_type='model',
        local_dir=f"{local_dir}/{repo_id}",
    )
    return model_dir


if __name__ == "__main__":
    download("OpenBMB/VoxCPM2", "./models")
    download("iic/SenseVoiceSmall", "./models")
    download('iic/speech_zipenhancer_ans_multiloss_16k_base', "./models")
|
||||||
+1
-2
@@ -47,8 +47,7 @@ dependencies = [
|
|||||||
"funasr",
|
"funasr",
|
||||||
"spaces",
|
"spaces",
|
||||||
"argbind",
|
"argbind",
|
||||||
"safetensors"
|
"safetensors",
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@@ -0,0 +1,108 @@
|
|||||||
|
"""Unit checks for pick_runtime_dtype / get_dtype consistency.
|
||||||
|
|
||||||
|
Loads src/voxcpm/model/utils.py directly to avoid the heavy voxcpm package
|
||||||
|
init. Run with: `python scripts/test_pick_runtime_dtype.py`.
|
||||||
|
"""
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Load src/voxcpm/model/utils.py by file path so that the voxcpm package's
# own __init__ is never executed.
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
UTILS = str(REPO_ROOT / "src" / "voxcpm" / "model" / "utils.py")
spec = importlib.util.spec_from_file_location("voxcpm_utils", UTILS)
if spec is None or spec.loader is None:
    # spec_from_file_location may return None (or a spec without a loader).
    raise ImportError(f"Cannot create an import spec for {UTILS}")
utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(utils)

# Names under test, re-exported for brevity.
_LOW_PRECISION_DTYPES = utils._LOW_PRECISION_DTYPES
_VALID_DTYPE_OVERRIDES = utils._VALID_DTYPE_OVERRIDES
get_dtype = utils.get_dtype
pick_runtime_dtype = utils.pick_runtime_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def expect(actual, expected, label):
    """Print an OK/FAIL line comparing ``actual`` to ``expected``.

    Returns True when the two values compare equal, False otherwise.
    """
    ok = actual == expected
    mark = "OK " if ok else "FAIL"
    print(f"[{mark}] {label}: got={actual!r} expected={expected!r}")
    return ok
|
||||||
|
|
||||||
|
|
||||||
|
def expect_raises(fn, exc_type, label):
    """Call ``fn()`` and report whether it raised ``exc_type``.

    Returns True only when ``fn`` raises ``exc_type`` (or a subclass). A
    different exception type, or no exception at all, prints a FAIL line and
    returns False.
    """
    try:
        fn()
    except exc_type as e:
        print(f"[OK ] {label}: raised {exc_type.__name__}: {e}")
        return True
    except Exception as e:
        print(f"[FAIL] {label}: raised {type(e).__name__} not {exc_type.__name__}: {e}")
        return False
    print(f"[FAIL] {label}: no exception raised")
    return False
|
||||||
|
|
||||||
|
|
||||||
|
# Each check appends True/False; the exit code reflects whether all passed.
results = []

print("=== override set sanity ===")
results.append(expect("half" not in _VALID_DTYPE_OVERRIDES, True, "half removed from _VALID_DTYPE_OVERRIDES"))
results.append(expect("half" not in _LOW_PRECISION_DTYPES, True, "half removed from _LOW_PRECISION_DTYPES"))

print("\n=== every accepted override parses through get_dtype ===")
for dt in sorted(_VALID_DTYPE_OVERRIDES):
    try:
        torch_dtype = get_dtype(dt)
    except Exception as e:
        print(f"[FAIL] get_dtype({dt!r}) raised: {e}")
        results.append(False)
    else:
        print(f"[OK ] get_dtype({dt!r}) -> {torch_dtype}")
        results.append(True)

print("\n=== pick_runtime_dtype: non-mps is a no-op ===")
results.append(expect(pick_runtime_dtype("cuda", "bfloat16"), "bfloat16", "cuda/bf16 untouched"))
results.append(expect(pick_runtime_dtype("cpu", "float16"), "float16", "cpu/fp16 untouched"))
results.append(expect(pick_runtime_dtype("cuda", "float32"), "float32", "cuda/fp32 untouched"))

print("\n=== pick_runtime_dtype: mps forces fp32 for low-precision ===")
# Make sure no override from the caller's environment leaks into these checks.
os.environ.pop("VOXCPM_MPS_DTYPE", None)
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "mps/bf16 -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "bf16"), "float32", "mps/bf16-alias -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "float16"), "float32", "mps/fp16 -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "fp16"), "float32", "mps/fp16-alias -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "float32"), "float32", "mps/fp32 stays"))
results.append(expect(pick_runtime_dtype("mps", "fp32"), "fp32", "mps/fp32-alias stays"))

print("\n=== pick_runtime_dtype: VOXCPM_MPS_DTYPE override ===")
os.environ["VOXCPM_MPS_DTYPE"] = "bfloat16"
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "bfloat16", "override bf16 honored"))

os.environ["VOXCPM_MPS_DTYPE"] = "FP16"
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "fp16", "override is case-insensitive"))

os.environ["VOXCPM_MPS_DTYPE"] = " float32 "
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "override is whitespace-trimmed"))

print("\n=== pick_runtime_dtype: 'half' is no longer a valid override ===")
os.environ["VOXCPM_MPS_DTYPE"] = "half"
results.append(
    expect_raises(
        lambda: pick_runtime_dtype("mps", "bfloat16"),
        ValueError,
        "override=half now rejected (was the bug)",
    )
)

os.environ["VOXCPM_MPS_DTYPE"] = "garbage"
results.append(
    expect_raises(
        lambda: pick_runtime_dtype("mps", "bfloat16"),
        ValueError,
        "override=garbage still rejected",
    )
)

os.environ.pop("VOXCPM_MPS_DTYPE", None)

print("\n=== summary ===")
passed = sum(results)
total = len(results)
print(f"{passed}/{total} passed")
sys.exit(0 if passed == total else 1)
|
||||||
@@ -30,7 +30,8 @@ except ImportError:
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
||||||
from voxcpm.model.voxcpm import LoRAConfig
|
from voxcpm.model.voxcpm import LoRAConfig as LoRAConfigV1
|
||||||
|
from voxcpm.model.voxcpm2 import LoRAConfig as LoRAConfigV2
|
||||||
from voxcpm.training import (
|
from voxcpm.training import (
|
||||||
Accelerator,
|
Accelerator,
|
||||||
BatchProcessor,
|
BatchProcessor,
|
||||||
@@ -46,7 +47,7 @@ def train(
|
|||||||
train_manifest: str,
|
train_manifest: str,
|
||||||
val_manifest: str = "",
|
val_manifest: str = "",
|
||||||
sample_rate: int = 16_000,
|
sample_rate: int = 16_000,
|
||||||
out_sample_rate: int = 0, # accepted from YAML for documentation; not used in training
|
out_sample_rate: int = 0, # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
grad_accum_steps: int = 1,
|
grad_accum_steps: int = 1,
|
||||||
num_workers: int = 2,
|
num_workers: int = 2,
|
||||||
@@ -64,12 +65,12 @@ def train(
|
|||||||
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
|
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||||
lora: dict = None,
|
lora: dict = None,
|
||||||
config_path: str = "",
|
config_path: str = "",
|
||||||
|
max_grad_norm: float = 0.0, # gradient clipping; 0 = disabled (backward compat)
|
||||||
# Distribution options (for LoRA checkpoints)
|
# Distribution options (for LoRA checkpoints)
|
||||||
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||||
):
|
):
|
||||||
_ = config_path
|
_ = config_path
|
||||||
_ = out_sample_rate
|
|
||||||
|
|
||||||
# Validate distribution options
|
# Validate distribution options
|
||||||
if lora is not None and distribute and not hf_model_id:
|
if lora is not None and distribute and not hf_model_id:
|
||||||
@@ -93,6 +94,7 @@ def train(
|
|||||||
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
||||||
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
||||||
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
||||||
|
LoRAConfig = LoRAConfigV2 if _arch == "voxcpm2" else LoRAConfigV1
|
||||||
if accelerator.rank == 0:
|
if accelerator.rank == 0:
|
||||||
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
||||||
base_model = _model_cls.from_local(
|
base_model = _model_cls.from_local(
|
||||||
@@ -178,8 +180,12 @@ def train(
|
|||||||
dataset_cnt=dataset_cnt,
|
dataset_cnt=dataset_cnt,
|
||||||
device=accelerator.device,
|
device=accelerator.device,
|
||||||
)
|
)
|
||||||
# Save audio_vae for audio generation
|
# Save audio_vae and output sample rate for audio generation.
|
||||||
|
# Prefer model's actual output rate; fall back to YAML out_sample_rate or encode rate.
|
||||||
audio_vae_for_gen = base_model.audio_vae
|
audio_vae_for_gen = base_model.audio_vae
|
||||||
|
out_sr = base_model.sample_rate # decoder output rate (e.g. 48000 for V2)
|
||||||
|
if out_sr == 0 and out_sample_rate > 0:
|
||||||
|
out_sr = out_sample_rate
|
||||||
del base_model.audio_vae
|
del base_model.audio_vae
|
||||||
model = accelerator.prepare_model(base_model)
|
model = accelerator.prepare_model(base_model)
|
||||||
unwrapped_model = accelerator.unwrap(model)
|
unwrapped_model = accelerator.unwrap(model)
|
||||||
@@ -312,8 +318,8 @@ def train(
|
|||||||
scaler = getattr(accelerator, "scaler", None)
|
scaler = getattr(accelerator, "scaler", None)
|
||||||
if scaler is not None:
|
if scaler is not None:
|
||||||
scaler.unscale_(optimizer)
|
scaler.unscale_(optimizer)
|
||||||
# Use large max_norm to only compute grad_norm without actual clipping
|
effective_max_norm = max_grad_norm if max_grad_norm > 0 else 1e9
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)
|
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=effective_max_norm)
|
||||||
|
|
||||||
accelerator.step(optimizer)
|
accelerator.step(optimizer)
|
||||||
accelerator.update()
|
accelerator.update()
|
||||||
@@ -341,6 +347,7 @@ def train(
|
|||||||
val_ds=val_ds,
|
val_ds=val_ds,
|
||||||
audio_vae=audio_vae_for_gen,
|
audio_vae=audio_vae_for_gen,
|
||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
|
out_sample_rate=out_sr,
|
||||||
val_texts=val_texts,
|
val_texts=val_texts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
valid_interval=valid_interval,
|
valid_interval=valid_interval,
|
||||||
@@ -367,6 +374,7 @@ def validate(
|
|||||||
val_ds=None,
|
val_ds=None,
|
||||||
audio_vae=None,
|
audio_vae=None,
|
||||||
sample_rate=22050,
|
sample_rate=22050,
|
||||||
|
out_sample_rate=0,
|
||||||
val_texts=None,
|
val_texts=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
valid_interval=1000,
|
valid_interval=1000,
|
||||||
@@ -432,6 +440,7 @@ def validate(
|
|||||||
step,
|
step,
|
||||||
accelerator,
|
accelerator,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
|
out_sample_rate=out_sample_rate,
|
||||||
val_texts=val_texts,
|
val_texts=val_texts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
valid_interval=valid_interval,
|
valid_interval=valid_interval,
|
||||||
@@ -534,6 +543,7 @@ def generate_sample_audio(
|
|||||||
step,
|
step,
|
||||||
accelerator,
|
accelerator,
|
||||||
sample_rate=22050,
|
sample_rate=22050,
|
||||||
|
out_sample_rate=0,
|
||||||
val_texts=None,
|
val_texts=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
pretrained_path=None,
|
pretrained_path=None,
|
||||||
@@ -548,6 +558,10 @@ def generate_sample_audio(
|
|||||||
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
|
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
|
||||||
|
|
||||||
unwrapped_model = accelerator.unwrap(model)
|
unwrapped_model = accelerator.unwrap(model)
|
||||||
|
# Determine the correct output sample rate for generated audio.
|
||||||
|
# out_sample_rate is the decoder output rate (e.g. 48kHz for V2);
|
||||||
|
# sample_rate is the encoder input rate (e.g. 16kHz for V2).
|
||||||
|
gen_sr = out_sample_rate if out_sample_rate > 0 else sample_rate
|
||||||
|
|
||||||
for i in range(num_samples):
|
for i in range(num_samples):
|
||||||
sample = val_ds[i]
|
sample = val_ds[i]
|
||||||
@@ -604,10 +618,10 @@ def generate_sample_audio(
|
|||||||
gen_audio_np = normalize_audio(gen_audio_np)
|
gen_audio_np = normalize_audio(gen_audio_np)
|
||||||
|
|
||||||
tag = f"val_sample_{i}"
|
tag = f"val_sample_{i}"
|
||||||
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
|
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=gen_sr)
|
||||||
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
|
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/gen_sr:.2f}s")
|
||||||
|
|
||||||
# Log reference audio
|
# Log reference audio (at encoder input rate, which is what val_ds provides)
|
||||||
if ref_audio_np is not None:
|
if ref_audio_np is not None:
|
||||||
writer.add_audio(
|
writer.add_audio(
|
||||||
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
||||||
@@ -615,9 +629,9 @@ def generate_sample_audio(
|
|||||||
|
|
||||||
# Generate mel spectrogram figure
|
# Generate mel spectrogram figure
|
||||||
try:
|
try:
|
||||||
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
|
mel_gen = compute_mel_spectrogram(gen_audio_np, gen_sr)
|
||||||
mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
|
mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
|
||||||
fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref)
|
fig = create_mel_figure(gen_audio_np, mel_gen, gen_sr, step, ref_audio_np, mel_ref)
|
||||||
writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
|
writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
|
||||||
log(f"[Audio] Created mel spectrogram figure for sample {i}")
|
log(f"[Audio] Created mel spectrogram figure for sample {i}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+60
-6
@@ -11,11 +11,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
from voxcpm.core import VoxCPM
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2"
|
DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2"
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
@@ -173,7 +168,9 @@ def validate_batch_args(args, parser):
|
|||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
def load_model(args) -> VoxCPM:
|
def load_model(args):
|
||||||
|
from voxcpm.core import VoxCPM
|
||||||
|
|
||||||
print("Loading VoxCPM model...", file=sys.stderr)
|
print("Loading VoxCPM model...", file=sys.stderr)
|
||||||
|
|
||||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||||
@@ -209,6 +206,7 @@ def load_model(args) -> VoxCPM:
|
|||||||
zipenhancer_model_path=zipenhancer_path,
|
zipenhancer_model_path=zipenhancer_path,
|
||||||
enable_denoiser=not args.no_denoiser,
|
enable_denoiser=not args.no_denoiser,
|
||||||
optimize=not args.no_optimize,
|
optimize=not args.no_optimize,
|
||||||
|
device=args.device,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
@@ -227,6 +225,7 @@ def load_model(args) -> VoxCPM:
|
|||||||
cache_dir=args.cache_dir,
|
cache_dir=args.cache_dir,
|
||||||
local_files_only=args.local_files_only,
|
local_files_only=args.local_files_only,
|
||||||
optimize=not args.no_optimize,
|
optimize=not args.no_optimize,
|
||||||
|
device=args.device,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
@@ -264,6 +263,8 @@ def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None
|
|||||||
and (args.prompt_audio is not None or args.reference_audio is not None),
|
and (args.prompt_audio is not None or args.reference_audio is not None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||||
|
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
@@ -286,7 +287,27 @@ def cmd_clone(args, parser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_validate(args, parser):
    """Run the ``validate`` subcommand against a training manifest.

    The manifest path from ``args.manifest`` is checked with
    ``require_file_exists``, passed to ``validate_manifest`` together with
    the ``sample_rate``, ``max_samples`` and ``verbose`` options, and the
    outcome is printed with ``print_validation_report``.

    Args:
        args: Parsed CLI namespace; reads ``manifest``, ``sample_rate``,
            ``max_samples`` and ``verbose``.
        parser: The argument parser, forwarded to ``require_file_exists``.

    Exits the process with status 1 when ``result.is_valid`` is falsy.
    """
    # Imported inside the function, so the module is only loaded when this
    # subcommand actually runs.
    from voxcpm.training.validate import print_validation_report, validate_manifest

    manifest = str(require_file_exists(args.manifest, parser, "manifest file"))
    options = {
        "manifest_path": manifest,
        "sample_rate": args.sample_rate,
        "max_samples": args.max_samples,
        "verbose": args.verbose,
    }
    result = validate_manifest(**options)
    print_validation_report(result, manifest)

    if result.is_valid:
        return
    sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def cmd_batch(args, parser):
|
def cmd_batch(args, parser):
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
input_file = require_file_exists(args.input, parser, "input file")
|
input_file = require_file_exists(args.input, parser, "input file")
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -403,6 +424,12 @@ def _add_model_args(parser):
|
|||||||
default=DEFAULT_HF_MODEL_ID,
|
default=DEFAULT_HF_MODEL_ID,
|
||||||
help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
|
help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="auto",
|
||||||
|
help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cache-dir", type=str, help="Cache directory for Hub downloads"
|
"--cache-dir", type=str, help="Cache directory for Hub downloads"
|
||||||
)
|
)
|
||||||
@@ -524,6 +551,30 @@ Examples:
|
|||||||
_add_model_args(batch_parser)
|
_add_model_args(batch_parser)
|
||||||
_add_lora_args(batch_parser)
|
_add_lora_args(batch_parser)
|
||||||
|
|
||||||
|
# Validate subcommand
|
||||||
|
validate_parser = subparsers.add_parser(
|
||||||
|
"validate",
|
||||||
|
help="Validate a training data manifest (JSONL) before fine-tuning",
|
||||||
|
)
|
||||||
|
validate_parser.add_argument(
|
||||||
|
"--manifest", "-m", required=True, help="Path to JSONL training manifest"
|
||||||
|
)
|
||||||
|
validate_parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16_000,
|
||||||
|
help="Expected audio sample rate in Hz (default: 16000)",
|
||||||
|
)
|
||||||
|
validate_parser.add_argument(
|
||||||
|
"--max-samples",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Maximum number of samples to validate (0 = all, default: 0)",
|
||||||
|
)
|
||||||
|
validate_parser.add_argument(
|
||||||
|
"--verbose", "-v", action="store_true", help="Print per-sample progress"
|
||||||
|
)
|
||||||
|
|
||||||
# Legacy root arguments
|
# Legacy root arguments
|
||||||
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
|
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -576,6 +627,9 @@ def main():
|
|||||||
parser = _build_parser()
|
parser = _build_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.command == "validate":
|
||||||
|
return cmd_validate(args, parser)
|
||||||
|
|
||||||
validate_ranges(args, parser)
|
validate_ranges(args, parser)
|
||||||
|
|
||||||
if args.command == "design":
|
if args.command == "design":
|
||||||
|
|||||||
+32
-5
@@ -8,6 +8,7 @@ from typing import Generator, Optional
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||||
from .model.voxcpm2 import VoxCPM2Model
|
from .model.voxcpm2 import VoxCPM2Model
|
||||||
|
from .model.utils import next_and_close
|
||||||
|
|
||||||
|
|
||||||
class VoxCPM:
|
class VoxCPM:
|
||||||
@@ -17,6 +18,7 @@ class VoxCPM:
|
|||||||
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||||
enable_denoiser: bool = True,
|
enable_denoiser: bool = True,
|
||||||
optimize: bool = True,
|
optimize: bool = True,
|
||||||
|
device: str | None = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
lora_weights_path: Optional[str] = None,
|
lora_weights_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
@@ -30,6 +32,9 @@ class VoxCPM:
|
|||||||
id or local path. If None, denoiser will not be initialized.
|
id or local path. If None, denoiser will not be initialized.
|
||||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||||
|
device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM
|
||||||
|
will choose automatically (preferring CUDA, then MPS, then CPU).
|
||||||
|
If set explicitly, that device is used or a clear error is raised.
|
||||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||||
provided without lora_config, a default config will be created.
|
provided without lora_config, a default config will be created.
|
||||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||||
@@ -56,10 +61,20 @@ class VoxCPM:
|
|||||||
arch = config.get("architecture", "voxcpm").lower()
|
arch = config.get("architecture", "voxcpm").lower()
|
||||||
|
|
||||||
if arch == "voxcpm2":
|
if arch == "voxcpm2":
|
||||||
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
self.tts_model = VoxCPM2Model.from_local(
|
||||||
|
voxcpm_model_path,
|
||||||
|
optimize=optimize,
|
||||||
|
device=device,
|
||||||
|
lora_config=lora_config,
|
||||||
|
)
|
||||||
print("Loaded VoxCPM2Model", file=sys.stderr)
|
print("Loaded VoxCPM2Model", file=sys.stderr)
|
||||||
elif arch == "voxcpm":
|
elif arch == "voxcpm":
|
||||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
self.tts_model = VoxCPMModel.from_local(
|
||||||
|
voxcpm_model_path,
|
||||||
|
optimize=optimize,
|
||||||
|
device=device,
|
||||||
|
lora_config=lora_config,
|
||||||
|
)
|
||||||
print("Loaded VoxCPMModel", file=sys.stderr)
|
print("Loaded VoxCPMModel", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported architecture: {arch}")
|
raise ValueError(f"Unsupported architecture: {arch}")
|
||||||
@@ -94,6 +109,7 @@ class VoxCPM:
|
|||||||
cache_dir: str = None,
|
cache_dir: str = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
optimize: bool = True,
|
optimize: bool = True,
|
||||||
|
device: str | None = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
lora_weights_path: Optional[str] = None,
|
lora_weights_path: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -109,6 +125,9 @@ class VoxCPM:
|
|||||||
cache_dir: Custom cache directory for the snapshot.
|
cache_dir: Custom cache directory for the snapshot.
|
||||||
local_files_only: If True, only use local files and do not attempt
|
local_files_only: If True, only use local files and do not attempt
|
||||||
to download.
|
to download.
|
||||||
|
device: Runtime device. Use ``None``/``"auto"`` for automatic
|
||||||
|
fallback, or an explicit value such as ``"cpu"``, ``"mps"``,
|
||||||
|
``"cuda"``, or ``"cuda:0"``.
|
||||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||||
provided without lora_config, a default config will be created with
|
provided without lora_config, a default config will be created with
|
||||||
enable_lm=True and enable_dit=True.
|
enable_lm=True and enable_dit=True.
|
||||||
@@ -130,7 +149,7 @@ class VoxCPM:
|
|||||||
if not repo_id:
|
if not repo_id:
|
||||||
raise ValueError("You must provide hf_model_id")
|
raise ValueError("You must provide hf_model_id")
|
||||||
|
|
||||||
# Load from local path if provided
|
# Load from local path if provided
|
||||||
if os.path.isdir(repo_id):
|
if os.path.isdir(repo_id):
|
||||||
local_path = repo_id
|
local_path = repo_id
|
||||||
else:
|
else:
|
||||||
@@ -146,13 +165,14 @@ class VoxCPM:
|
|||||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||||
enable_denoiser=load_denoiser,
|
enable_denoiser=load_denoiser,
|
||||||
optimize=optimize,
|
optimize=optimize,
|
||||||
|
device=device,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate(self, *args, **kwargs) -> np.ndarray:
|
def generate(self, *args, **kwargs) -> np.ndarray:
|
||||||
return next(self._generate(*args, streaming=False, **kwargs))
|
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
||||||
return self._generate(*args, streaming=True, **kwargs)
|
return self._generate(*args, streaming=True, **kwargs)
|
||||||
@@ -200,7 +220,7 @@ class VoxCPM:
|
|||||||
Yields audio chunks for each generation step if ``streaming=True``,
|
Yields audio chunks for each generation step if ``streaming=True``,
|
||||||
otherwise yields a single array containing the final audio.
|
otherwise yields a single array containing the final audio.
|
||||||
"""
|
"""
|
||||||
if not text.strip() or not isinstance(text, str):
|
if not isinstance(text, str) or not text.strip():
|
||||||
raise ValueError("target text must be a non-empty string")
|
raise ValueError("target text must be a non-empty string")
|
||||||
|
|
||||||
if prompt_wav_path is not None:
|
if prompt_wav_path is not None:
|
||||||
@@ -273,8 +293,15 @@ class VoxCPM:
|
|||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if streaming:
|
||||||
|
try:
|
||||||
for wav, _, _ in generate_result:
|
for wav, _, _ in generate_result:
|
||||||
yield wav.squeeze(0).cpu().numpy()
|
yield wav.squeeze(0).cpu().numpy()
|
||||||
|
finally:
|
||||||
|
generate_result.close()
|
||||||
|
else:
|
||||||
|
wav, _, _ = next_and_close(generate_result)
|
||||||
|
yield wav.squeeze(0).cpu().numpy()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
for tmp_path in temp_files:
|
for tmp_path in temp_files:
|
||||||
|
|||||||
+111
-1
@@ -1,7 +1,25 @@
|
|||||||
from typing import List
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
import torch
|
import torch
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
# Dtype spellings treated as reduced precision. pick_runtime_dtype replaces
# any of these with "float32" when the device is "mps" and no override is set.
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"}

# Values accepted for the VOXCPM_MPS_DTYPE environment override: every
# low-precision spelling plus the two full-precision ones.
_VALID_DTYPE_OVERRIDES = _LOW_PRECISION_DTYPES | {"float32", "fp32"}
|
||||||
|
|
||||||
|
|
||||||
|
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
# Explicitly close partially-consumed generators so inference_mode cleanup
# does not get deferred to Python's GC/finalizer path.
def next_and_close(gen):
    """Return the first item produced by ``gen``, then close it.

    Args:
        gen: A generator, or any iterator. When the object has a ``close``
            method it is called after the first item has been fetched, even
            if fetching raised.

    Returns:
        The first value produced by ``gen``.

    Raises:
        StopIteration: If ``gen`` produces no items.
    """
    try:
        return next(gen)
    finally:
        # Plain iterators (e.g. ``iter(list)``) have no ``close``; calling it
        # unconditionally would raise AttributeError from this ``finally``
        # and discard the value that was just fetched.
        close = getattr(gen, "close", None)
        if close is not None:
            close()
|
||||||
|
|
||||||
|
|
||||||
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||||
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
|
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
|
||||||
@@ -119,3 +137,95 @@ def get_dtype(dtype: str):
|
|||||||
return torch.float32
|
return torch.float32
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def _has_mps() -> bool:
    """Report whether ``torch.backends`` exposes an available MPS backend."""
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None:
        # This torch build has no ``torch.backends.mps`` attribute at all.
        return False
    return mps_backend.is_available()
|
||||||
|
|
||||||
|
|
||||||
|
def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
    """Choose the dtype name to run with on the resolved ``device``.

    Only ``device == "mps"`` is special-cased. On Apple Silicon (MPS),
    bfloat16/float16 introduce enough numerical drift in the diffusion AR
    loop that the output is glitched and the model's badcase detector keeps
    retrying, so low-precision configurations are promoted to ``"float32"``.
    Every other device keeps ``configured_dtype`` unchanged.

    The ``VOXCPM_MPS_DTYPE`` environment variable (trimmed and lower-cased)
    takes precedence on MPS, e.g. to try ``bfloat16`` on future MPS builds.

    Args:
        device: Resolved device string.
        configured_dtype: Dtype name from the model configuration.

    Returns:
        The dtype name to use: the normalized override, ``"float32"``, or
        ``configured_dtype`` as given.

    Raises:
        ValueError: If ``VOXCPM_MPS_DTYPE`` is set to a value that is not in
            ``_VALID_DTYPE_OVERRIDES``.
    """
    on_mps = device == "mps"
    if not on_mps:
        return configured_dtype

    raw_override = os.environ.get("VOXCPM_MPS_DTYPE", "")
    override = raw_override.strip().lower()
    if override and override not in _VALID_DTYPE_OVERRIDES:
        raise ValueError(
            f"VOXCPM_MPS_DTYPE='{override}' is not one of "
            f"{sorted(_VALID_DTYPE_OVERRIDES)}"
        )
    if override:
        # A valid override wins over the configured dtype.
        return override

    normalized = (configured_dtype or "").lower()
    return "float32" if normalized in _LOW_PRECISION_DTYPES else configured_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
    """Pick a runtime device, honouring a preference when possible.

    The preference (default ``"cuda"`` when ``None`` or empty) is trimmed
    and lower-cased. If it names an available device it is returned;
    otherwise the fallback order is CUDA, then MPS, then CPU.

    Args:
        preferred_device: Preferred device string, or ``None``.

    Returns:
        The normalized preferred string when it can be honoured, otherwise
        ``"cuda"``, ``"mps"`` or ``"cpu"`` from the fallback chain.
    """
    wanted = (preferred_device or "cuda").strip().lower()

    # First try to satisfy the stated preference.
    if wanted == "cpu":
        return "cpu"
    if wanted == "mps":
        if _has_mps():
            return "mps"
    elif wanted.startswith("cuda") and torch.cuda.is_available():
        return wanted

    # Preference unavailable or unrecognised: CUDA -> MPS -> CPU.
    if torch.cuda.is_available():
        return "cuda"
    return "mps" if _has_mps() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_runtime_device(device: Optional[str], configured_device: str = "cuda") -> str:
    """Resolve the actual runtime device.

    Semantics:
    - ``device`` is ``None`` or ``"auto"``: use automatic fallback selection,
      with ``configured_device`` as the preferred device.
    - otherwise: treat it as an explicit user choice (trimmed, lower-cased)
      and validate both its spelling and its availability.

    Args:
        device: Requested device, or ``None``/``"auto"`` for automatic choice.
        configured_device: Preference handed to ``auto_select_device`` in the
            automatic case.

    Returns:
        The normalized device string: ``"cpu"``, ``"mps"``, ``"cuda"`` or
        ``"cuda:N"``, or whatever ``auto_select_device`` returns.

    Raises:
        ValueError: If the explicit device is not a supported spelling, is
            not available, or names a CUDA ordinal beyond the visible
            device count.
    """
    explicit = None if device is None else device.strip().lower()

    if explicit is None or explicit == "auto":
        return auto_select_device(configured_device)

    if explicit == "cpu":
        return "cpu"

    if explicit == "mps":
        if not _has_mps():
            raise ValueError(
                "Requested device 'mps', but MPS is not available. "
                "Use device='auto' for automatic fallback."
            )
        return "mps"

    # Accept only "cuda" or "cuda:<N>". A bare prefix match would let strings
    # such as "cudax" or "cuda:" through and defer the failure to torch.
    base, sep, ordinal = explicit.partition(":")
    well_formed_cuda = base == "cuda" and (
        not sep or (ordinal.isascii() and ordinal.isdigit())
    )
    if not well_formed_cuda:
        raise ValueError(
            f"Unsupported device '{device}'. Supported values are 'auto', 'cpu', 'mps', "
            "'cuda', or indexed CUDA devices like 'cuda:0'."
        )

    if not torch.cuda.is_available():
        raise ValueError(
            f"Requested device '{device}', but CUDA is not available. "
            "Use device='auto' for automatic fallback."
        )

    if sep:
        # Fail early with a clear message for an out-of-range ordinal.
        visible = torch.cuda.device_count()
        if int(ordinal) >= visible:
            raise ValueError(
                f"Requested device '{device}', but only {visible} CUDA device(s) are visible."
            )
    return explicit
|
||||||
|
|||||||
+37
-17
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
|||||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||||
from ..modules.locenc import VoxCPMLocEnc
|
from ..modules.locenc import VoxCPMLocEnc
|
||||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
from .utils import (
|
||||||
|
get_dtype,
|
||||||
|
mask_multichar_chinese_tokens,
|
||||||
|
next_and_close,
|
||||||
|
pick_runtime_dtype,
|
||||||
|
resolve_runtime_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VoxCPMEncoderConfig(BaseModel):
|
class VoxCPMEncoderConfig(BaseModel):
|
||||||
@@ -109,18 +115,22 @@ class VoxCPMModel(nn.Module):
|
|||||||
tokenizer: LlamaTokenizerFast,
|
tokenizer: LlamaTokenizerFast,
|
||||||
audio_vae: AudioVAE,
|
audio_vae: AudioVAE,
|
||||||
lora_config: LoRAConfig = None,
|
lora_config: LoRAConfig = None,
|
||||||
|
device: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.lora_config = lora_config
|
self.lora_config = lora_config
|
||||||
self.feat_dim = config.feat_dim
|
self.feat_dim = config.feat_dim
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
self.device = config.device
|
self.device = resolve_runtime_device(device, config.device)
|
||||||
if not torch.cuda.is_available():
|
self.config.device = self.device
|
||||||
if torch.backends.mps.is_available():
|
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||||
self.device = "mps"
|
if resolved_dtype != self.config.dtype:
|
||||||
else:
|
print(
|
||||||
self.device = "cpu"
|
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
self.config.dtype = resolved_dtype
|
||||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||||
|
|
||||||
# Text-Semantic LM
|
# Text-Semantic LM
|
||||||
@@ -227,6 +237,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.residual_lm.forward_step = torch.compile(
|
self.residual_lm.forward_step = torch.compile(
|
||||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||||
)
|
)
|
||||||
|
self._feat_encoder_raw = self.feat_encoder
|
||||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||||
self.feat_decoder.estimator = torch.compile(
|
self.feat_decoder.estimator = torch.compile(
|
||||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||||
@@ -337,7 +348,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
return get_dtype(self.config.dtype)
|
return get_dtype(self.config.dtype)
|
||||||
|
|
||||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||||
return next(self._generate(*args, streaming=False, **kwargs))
|
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||||
return self._generate(*args, streaming=True, **kwargs)
|
return self._generate(*args, streaming=True, **kwargs)
|
||||||
@@ -463,7 +474,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
yield decode_audio
|
yield decode_audio
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next_and_close(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -571,7 +582,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
return merged_cache
|
return merged_cache
|
||||||
|
|
||||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def generate_with_prompt_cache_streaming(
|
def generate_with_prompt_cache_streaming(
|
||||||
self, *args, **kwargs
|
self, *args, **kwargs
|
||||||
@@ -690,7 +701,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next_and_close(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -713,7 +724,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
|
|
||||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
return next(self._inference(*args, streaming=False, **kwargs))
|
return next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||||
return self._inference(*args, streaming=True, **kwargs)
|
return self._inference(*args, streaming=True, **kwargs)
|
||||||
@@ -755,7 +766,8 @@ class VoxCPMModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
B, T, P, D = feat.shape
|
B, T, P, D = feat.shape
|
||||||
|
|
||||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
|
||||||
|
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
|
||||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||||
|
|
||||||
if self.config.lm_config.use_mup:
|
if self.config.lm_config.use_mup:
|
||||||
@@ -845,8 +857,16 @@ class VoxCPMModel(nn.Module):
|
|||||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
def from_local(
|
||||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
cls,
|
||||||
|
path: str,
|
||||||
|
optimize: bool = True,
|
||||||
|
training: bool = False,
|
||||||
|
device: str | None = None,
|
||||||
|
lora_config: LoRAConfig = None,
|
||||||
|
):
|
||||||
|
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
|
||||||
|
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||||
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
||||||
@@ -868,7 +888,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||||
)
|
)
|
||||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
|
||||||
if not training:
|
if not training:
|
||||||
lm_dtype = get_dtype(model.config.dtype)
|
lm_dtype = get_dtype(model.config.dtype)
|
||||||
model = model.to(lm_dtype)
|
model = model.to(lm_dtype)
|
||||||
@@ -950,7 +970,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||||
state_dict = load_file(str(safetensors_file), device=device)
|
state_dict = load_file(str(safetensors_file), device=device)
|
||||||
elif ckpt_file and ckpt_file.exists():
|
elif ckpt_file and ckpt_file.exists():
|
||||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
|
||||||
state_dict = ckpt.get("state_dict", ckpt)
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||||
|
|||||||
+69
-64
@@ -45,28 +45,17 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
|||||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
||||||
from ..modules.locenc import VoxCPMLocEnc
|
from ..modules.locenc import VoxCPMLocEnc
|
||||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
from .utils import (
|
||||||
|
get_dtype,
|
||||||
|
mask_multichar_chinese_tokens,
|
||||||
|
next_and_close,
|
||||||
|
pick_runtime_dtype,
|
||||||
|
resolve_runtime_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _trim_audio_silence_vad(
|
# A simple function to trim audio silence using VAD, not used default
|
||||||
audio: torch.Tensor,
|
def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
|
||||||
sample_rate: int,
|
|
||||||
max_silence_ms: float = 200.0,
|
|
||||||
top_db: float = 35.0,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""使用能量阈值(VAD 方式)截取首尾静音及尾部长段伪静音,首尾各最多保留 max_silence_ms 毫秒静音。
|
|
||||||
|
|
||||||
会同时截掉末尾的长段伪静音(低能量但非完全静音的段落,如长时间底噪)。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio: (1, T) 的音频 tensor
|
|
||||||
sample_rate: 采样率
|
|
||||||
max_silence_ms: 首尾允许保留的最大静音长度(毫秒)
|
|
||||||
top_db: 低于参考电平多少 dB 视为静音
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
截取后的 (1, T') tensor
|
|
||||||
"""
|
|
||||||
if audio.numel() == 0:
|
if audio.numel() == 0:
|
||||||
return audio
|
return audio
|
||||||
y = audio.squeeze(0).numpy()
|
y = audio.squeeze(0).numpy()
|
||||||
@@ -85,7 +74,7 @@ def _trim_audio_silence_vad(
|
|||||||
except Exception:
|
except Exception:
|
||||||
start, end = 0, n
|
start, end = 0, n
|
||||||
|
|
||||||
# 用逐帧 RMS 找「最后一段有持续能量的位置」,截掉末尾长伪静音(低能量底噪等)
|
# Find the last frame with continuous energy, trim the long pseudo-silence at the end (low energy background noise, etc.)
|
||||||
n_frames = max(0, (n - frame_length) // hop_length + 1)
|
n_frames = max(0, (n - frame_length) // hop_length + 1)
|
||||||
last_voice_frame = -1
|
last_voice_frame = -1
|
||||||
for j in range(n_frames):
|
for j in range(n_frames):
|
||||||
@@ -168,18 +157,22 @@ class VoxCPM2Model(nn.Module):
|
|||||||
tokenizer: LlamaTokenizerFast,
|
tokenizer: LlamaTokenizerFast,
|
||||||
audio_vae: AudioVAEV2,
|
audio_vae: AudioVAEV2,
|
||||||
lora_config: LoRAConfig = None,
|
lora_config: LoRAConfig = None,
|
||||||
|
device: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.lora_config = lora_config
|
self.lora_config = lora_config
|
||||||
self.feat_dim = config.feat_dim
|
self.feat_dim = config.feat_dim
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
self.device = config.device
|
self.device = resolve_runtime_device(device, config.device)
|
||||||
if not torch.cuda.is_available():
|
self.config.device = self.device
|
||||||
if torch.backends.mps.is_available():
|
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||||
self.device = "mps"
|
if resolved_dtype != self.config.dtype:
|
||||||
else:
|
print(
|
||||||
self.device = "cpu"
|
f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
self.config.dtype = resolved_dtype
|
||||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||||
|
|
||||||
# Text-Semantic LM
|
# Text-Semantic LM
|
||||||
@@ -246,6 +239,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
# Audio VAE
|
# Audio VAE
|
||||||
self.audio_vae = audio_vae
|
self.audio_vae = audio_vae
|
||||||
self.chunk_size = audio_vae.chunk_size
|
self.chunk_size = audio_vae.chunk_size
|
||||||
|
self._decode_chunk_size = getattr(audio_vae, "decode_chunk_size", audio_vae.chunk_size)
|
||||||
self._encode_sample_rate = audio_vae.sample_rate
|
self._encode_sample_rate = audio_vae.sample_rate
|
||||||
self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)
|
self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)
|
||||||
|
|
||||||
@@ -291,6 +285,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
self.residual_lm.forward_step = torch.compile(
|
self.residual_lm.forward_step = torch.compile(
|
||||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||||
)
|
)
|
||||||
|
self._feat_encoder_raw = self.feat_encoder
|
||||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||||
self.feat_decoder.estimator = torch.compile(
|
self.feat_decoder.estimator = torch.compile(
|
||||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||||
@@ -382,11 +377,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
mu=dit_hidden,
|
mu=dit_hidden,
|
||||||
patch_size=self.patch_size,
|
patch_size=self.patch_size,
|
||||||
cond=feat_cond_for_sample,
|
cond=feat_cond_for_sample,
|
||||||
n_timesteps=(
|
n_timesteps=10,
|
||||||
self.config.dit_config.cfm_config.inference_cfg_rate
|
|
||||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
|
||||||
else 10
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||||
|
|
||||||
@@ -463,7 +454,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
return tokens, feats, t_mask, a_mask
|
return tokens, feats, t_mask, a_mask
|
||||||
|
|
||||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||||
return next(self._generate(*args, streaming=False, **kwargs))
|
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||||
return self._generate(*args, streaming=True, **kwargs)
|
return self._generate(*args, streaming=True, **kwargs)
|
||||||
@@ -656,14 +647,14 @@ class VoxCPM2Model(nn.Module):
|
|||||||
streaming_prefix_len=streaming_prefix_len,
|
streaming_prefix_len=streaming_prefix_len,
|
||||||
)
|
)
|
||||||
if streaming:
|
if streaming:
|
||||||
patch_len = self.patch_size * self.chunk_size
|
with self.audio_vae.streaming_decode() as vae_dec:
|
||||||
for latent_pred, _ in inference_result:
|
for latent_pred, _, _ctx in inference_result:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
|
||||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
decode_audio = decode_audio.squeeze(1).cpu()
|
||||||
yield decode_audio
|
yield decode_audio
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -679,10 +670,9 @@ class VoxCPM2Model(nn.Module):
|
|||||||
|
|
||||||
if not streaming:
|
if not streaming:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
patch_len = self.patch_size * self.chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
has_continuation = bool(prompt_wav_path)
|
if context_len > 0:
|
||||||
if has_continuation:
|
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
|
||||||
else:
|
else:
|
||||||
decode_audio = decode_audio.squeeze(1).cpu()
|
decode_audio = decode_audio.squeeze(1).cpu()
|
||||||
yield decode_audio
|
yield decode_audio
|
||||||
@@ -782,7 +772,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
return merged
|
return merged
|
||||||
|
|
||||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
def generate_with_prompt_cache_streaming(
|
def generate_with_prompt_cache_streaming(
|
||||||
self, *args, **kwargs
|
self, *args, **kwargs
|
||||||
@@ -944,14 +934,14 @@ class VoxCPM2Model(nn.Module):
|
|||||||
streaming_prefix_len=streaming_prefix_len,
|
streaming_prefix_len=streaming_prefix_len,
|
||||||
)
|
)
|
||||||
if streaming:
|
if streaming:
|
||||||
patch_len = self.patch_size * self.chunk_size
|
with self.audio_vae.streaming_decode() as vae_dec:
|
||||||
for latent_pred, pred_audio_feat in inference_result:
|
for latent_pred, pred_audio_feat, _ctx in inference_result:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
|
||||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
decode_audio = decode_audio.squeeze(1).cpu()
|
||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -966,18 +956,20 @@ class VoxCPM2Model(nn.Module):
|
|||||||
break
|
break
|
||||||
if not streaming:
|
if not streaming:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
patch_len = self.patch_size * self.chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
if mode in ("continuation", "ref_continuation"):
|
if context_len > 0:
|
||||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||||
else:
|
else:
|
||||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
decode_audio = decode_audio.squeeze(1).cpu()
|
||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
|
|
||||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
return next(self._inference(*args, streaming=False, **kwargs))
|
feat_pred, generated_feat, _ = next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||||
|
return feat_pred, generated_feat
|
||||||
|
|
||||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||||
return self._inference(*args, streaming=True, **kwargs)
|
for feat_pred, pred_feat_seq, _ in self._inference(*args, streaming=True, **kwargs):
|
||||||
|
yield feat_pred, pred_feat_seq
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _inference(
|
def _inference(
|
||||||
@@ -1016,7 +1008,8 @@ class VoxCPM2Model(nn.Module):
|
|||||||
"""
|
"""
|
||||||
B, T, P, D = feat.shape
|
B, T, P, D = feat.shape
|
||||||
|
|
||||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
|
||||||
|
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
|
||||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||||
|
|
||||||
if self.config.lm_config.use_mup:
|
if self.config.lm_config.use_mup:
|
||||||
@@ -1036,6 +1029,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
# trailing audio patches as initial context so the VAE can decode smoothly.
|
# trailing audio patches as initial context so the VAE can decode smoothly.
|
||||||
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
|
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
|
||||||
has_continuation_audio = feat_mask[0, -1].item() == 1
|
has_continuation_audio = feat_mask[0, -1].item() == 1
|
||||||
|
context_len = 0
|
||||||
if has_continuation_audio:
|
if has_continuation_audio:
|
||||||
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
|
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
|
||||||
context_len = min(streaming_prefix_len - 1, len(audio_indices))
|
context_len = min(streaming_prefix_len - 1, len(audio_indices))
|
||||||
@@ -1085,11 +1079,13 @@ class VoxCPM2Model(nn.Module):
|
|||||||
prefix_feat_cond = pred_feat
|
prefix_feat_cond = pred_feat
|
||||||
|
|
||||||
if streaming:
|
if streaming:
|
||||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
# Yield only the newest patch latent for stateful VAE decode
|
||||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
feat_pred = rearrange(pred_feat.unsqueeze(1), "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
|
||||||
|
|
||||||
yield feat_pred, pred_feat_seq
|
yield feat_pred, pred_feat_seq, context_len
|
||||||
|
|
||||||
|
if len(pred_feat_seq) > streaming_prefix_len:
|
||||||
|
pred_feat_seq = pred_feat_seq[-streaming_prefix_len:]
|
||||||
|
|
||||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||||
if i > min_len and stop_flag == 1:
|
if i > min_len and stop_flag == 1:
|
||||||
@@ -1108,11 +1104,20 @@ class VoxCPM2Model(nn.Module):
|
|||||||
if not streaming:
|
if not streaming:
|
||||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu()
|
||||||
|
yield feat_pred, generated_feat, context_len
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
def from_local(
|
||||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
cls,
|
||||||
|
path: str,
|
||||||
|
optimize: bool = True,
|
||||||
|
training: bool = False,
|
||||||
|
device: str | None = None,
|
||||||
|
lora_config: LoRAConfig = None,
|
||||||
|
):
|
||||||
|
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
|
||||||
|
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||||
audio_vae = AudioVAEV2(config=audio_vae_config) if audio_vae_config else AudioVAEV2()
|
audio_vae = AudioVAEV2(config=audio_vae_config) if audio_vae_config else AudioVAEV2()
|
||||||
@@ -1134,7 +1139,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||||
)
|
)
|
||||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
|
||||||
if not training:
|
if not training:
|
||||||
lm_dtype = get_dtype(model.config.dtype)
|
lm_dtype = get_dtype(model.config.dtype)
|
||||||
model = model.to(lm_dtype)
|
model = model.to(lm_dtype)
|
||||||
@@ -1216,7 +1221,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||||
state_dict = load_file(str(safetensors_file), device=device)
|
state_dict = load_file(str(safetensors_file), device=device)
|
||||||
elif ckpt_file and ckpt_file.exists():
|
elif ckpt_file and ckpt_file.exists():
|
||||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
|
||||||
state_dict = ckpt.get("state_dict", ckpt)
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||||
|
|||||||
@@ -436,6 +436,7 @@ class AudioVAE(nn.Module):
|
|||||||
self.out_sample_rate = out_sample_rate
|
self.out_sample_rate = out_sample_rate
|
||||||
self.sr_bin_boundaries = sr_bin_boundaries
|
self.sr_bin_boundaries = sr_bin_boundaries
|
||||||
self.chunk_size = math.prod(encoder_rates)
|
self.chunk_size = math.prod(encoder_rates)
|
||||||
|
self.decode_chunk_size = math.prod(decoder_rates)
|
||||||
|
|
||||||
def preprocess(self, audio_data, sample_rate):
|
def preprocess(self, audio_data, sample_rate):
|
||||||
if sample_rate is None:
|
if sample_rate is None:
|
||||||
@@ -471,6 +472,20 @@ class AudioVAE(nn.Module):
|
|||||||
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
|
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
|
||||||
return self.decoder(z, sr_cond)
|
return self.decoder(z, sr_cond)
|
||||||
|
|
||||||
|
def streaming_decode(self):
|
||||||
|
"""Return a ``StreamingVAEDecoder`` context manager for stateful
|
||||||
|
chunk-by-chunk decoding. Each call to ``decode_chunk`` processes only
|
||||||
|
the new latent patch and carries causal-conv state internally, avoiding
|
||||||
|
the redundant overlap decode used previously.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
with vae.streaming_decode() as dec:
|
||||||
|
for patch in patches:
|
||||||
|
audio_chunk = dec.decode_chunk(patch)
|
||||||
|
"""
|
||||||
|
return StreamingVAEDecoder(self)
|
||||||
|
|
||||||
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -484,3 +499,82 @@ class AudioVAE(nn.Module):
|
|||||||
|
|
||||||
audio_data = self.preprocess(audio_data, sample_rate)
|
audio_data = self.preprocess(audio_data, sample_rate)
|
||||||
return self.encoder(audio_data)["mu"]
|
return self.encoder(audio_data)["mu"]
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingVAEDecoder:
|
||||||
|
"""Stateful streaming wrapper for :class:`AudioVAE`.
|
||||||
|
|
||||||
|
Carries causal-convolution padding buffers between calls so that each
|
||||||
|
``decode_chunk`` processes only the new latent patch — no overlap needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, vae: AudioVAE):
|
||||||
|
self._vae = vae
|
||||||
|
self._states: dict = {}
|
||||||
|
self._originals: list = []
|
||||||
|
|
||||||
|
# -- context manager --------------------------------------------------
|
||||||
|
def __enter__(self):
|
||||||
|
self._states.clear()
|
||||||
|
self._install()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc):
|
||||||
|
self._restore()
|
||||||
|
self._states.clear()
|
||||||
|
|
||||||
|
# -- public API --------------------------------------------------------
|
||||||
|
def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Decode a single latent chunk and return the audio waveform."""
|
||||||
|
return self._vae.decode(z_chunk)
|
||||||
|
|
||||||
|
# -- internals ---------------------------------------------------------
|
||||||
|
def _install(self):
|
||||||
|
for name, mod in self._vae.decoder.named_modules():
|
||||||
|
if isinstance(mod, CausalConv1d):
|
||||||
|
pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding
|
||||||
|
if pad > 0:
|
||||||
|
self._patch_causal_conv(mod, pad)
|
||||||
|
elif isinstance(mod, CausalTransposeConv1d):
|
||||||
|
trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding
|
||||||
|
ctx = (mod.kernel_size[0] - 1) // mod.stride[0]
|
||||||
|
if ctx > 0:
|
||||||
|
self._patch_transpose_conv(mod, ctx, trim)
|
||||||
|
|
||||||
|
def _patch_causal_conv(self, mod, pad_size):
|
||||||
|
states = self._states
|
||||||
|
key = id(mod)
|
||||||
|
orig = mod.forward
|
||||||
|
|
||||||
|
def fwd(x, _k=key, _p=pad_size, _m=mod):
|
||||||
|
x_pad = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_p, 0))
|
||||||
|
if x.shape[-1] >= _p:
|
||||||
|
states[_k] = x[:, :, -_p:].detach()
|
||||||
|
else:
|
||||||
|
prev = states.get(_k, torch.zeros(x.shape[0], x.shape[1], _p,
|
||||||
|
device=x.device, dtype=x.dtype))
|
||||||
|
states[_k] = torch.cat([prev, x], dim=-1)[:, :, -_p:].detach()
|
||||||
|
return nn.Conv1d.forward(_m, x_pad)
|
||||||
|
|
||||||
|
mod.forward = fwd
|
||||||
|
self._originals.append((mod, orig))
|
||||||
|
|
||||||
|
def _patch_transpose_conv(self, mod, ctx, trim):
|
||||||
|
states = self._states
|
||||||
|
key = id(mod)
|
||||||
|
orig = mod.forward
|
||||||
|
|
||||||
|
def fwd(x, _k=key, _c=ctx, _t=trim, _m=mod):
|
||||||
|
x_full = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_c, 0))
|
||||||
|
states[_k] = x[:, :, -_c:].detach()
|
||||||
|
out = nn.ConvTranspose1d.forward(_m, x_full)
|
||||||
|
left = _c * _m.stride[0]
|
||||||
|
return out[..., left:-_t] if _t > 0 else out[..., left:]
|
||||||
|
|
||||||
|
mod.forward = fwd
|
||||||
|
self._originals.append((mod, orig))
|
||||||
|
|
||||||
|
def _restore(self):
|
||||||
|
for mod, orig in self._originals:
|
||||||
|
mod.forward = orig
|
||||||
|
self._originals.clear()
|
||||||
|
|||||||
@@ -225,7 +225,7 @@ class UnifiedCFM(torch.nn.Module):
|
|||||||
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
|
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
|
||||||
if tgt_mask is not None:
|
if tgt_mask is not None:
|
||||||
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
|
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
|
||||||
loss = (weights * losses).sum() / torch.sum(tgt_mask)
|
loss = (weights * losses).sum() / torch.clamp(torch.sum(tgt_mask), min=1.0)
|
||||||
else:
|
else:
|
||||||
loss = losses.mean()
|
loss = losses.mean()
|
||||||
|
|
||||||
|
|||||||
@@ -196,7 +196,9 @@ class MiniCPMAttention(nn.Module):
|
|||||||
key_cache[:, :, position_id, :] = key_states
|
key_cache[:, :, position_id, :] = key_states
|
||||||
value_cache[:, :, position_id, :] = value_states
|
value_cache[:, :, position_id, :] = value_states
|
||||||
|
|
||||||
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
|
# Use an explicit broadcastable mask shape for SDPA. A 1D mask can
|
||||||
|
# trigger a CPU-side dimension bug in some PyTorch versions.
|
||||||
|
attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1)
|
||||||
|
|
||||||
# ref: https://github.com/pytorch/pytorch/issues/163597
|
# ref: https://github.com/pytorch/pytorch/issues/163597
|
||||||
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from .data import (
|
|||||||
BatchProcessor,
|
BatchProcessor,
|
||||||
)
|
)
|
||||||
from .state import TrainingState
|
from .state import TrainingState
|
||||||
|
from .validate import validate_manifest, ValidationResult
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Accelerator",
|
"Accelerator",
|
||||||
@@ -24,4 +25,6 @@ __all__ = [
|
|||||||
"TrainingState",
|
"TrainingState",
|
||||||
"load_audio_text_datasets",
|
"load_audio_text_datasets",
|
||||||
"build_dataloader",
|
"build_dataloader",
|
||||||
|
"validate_manifest",
|
||||||
|
"ValidationResult",
|
||||||
]
|
]
|
||||||
|
|||||||
+53
-11
@@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker
|
|||||||
|
|
||||||
DEFAULT_TEXT_COLUMN = "text"
|
DEFAULT_TEXT_COLUMN = "text"
|
||||||
DEFAULT_AUDIO_COLUMN = "audio"
|
DEFAULT_AUDIO_COLUMN = "audio"
|
||||||
|
DEFAULT_REF_AUDIO_COLUMN = "ref_audio"
|
||||||
DEFAULT_ID_COLUMN = "dataset_id"
|
DEFAULT_ID_COLUMN = "dataset_id"
|
||||||
|
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ def load_audio_text_datasets(
|
|||||||
val_manifest: str = "",
|
val_manifest: str = "",
|
||||||
text_column: str = DEFAULT_TEXT_COLUMN,
|
text_column: str = DEFAULT_TEXT_COLUMN,
|
||||||
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
||||||
|
ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN,
|
||||||
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
||||||
sample_rate: int = 16_000,
|
sample_rate: int = 16_000,
|
||||||
num_proc: int = 1,
|
num_proc: int = 1,
|
||||||
@@ -34,14 +36,19 @@ def load_audio_text_datasets(
|
|||||||
def prepare(ds: Dataset) -> Dataset:
|
def prepare(ds: Dataset) -> Dataset:
|
||||||
if audio_column not in ds.column_names:
|
if audio_column not in ds.column_names:
|
||||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||||
# We cast to Audio to ensure proper handling during training,
|
|
||||||
# but for length calculation we might need raw path or duration if available.
|
|
||||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
|
||||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||||
if audio_column != DEFAULT_AUDIO_COLUMN:
|
if audio_column != DEFAULT_AUDIO_COLUMN:
|
||||||
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
||||||
if text_column != DEFAULT_TEXT_COLUMN:
|
if text_column != DEFAULT_TEXT_COLUMN:
|
||||||
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
||||||
|
|
||||||
|
# ref_audio is optional — cast to Audio if the column exists
|
||||||
|
ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN
|
||||||
|
if ref_col in ds.column_names:
|
||||||
|
ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate))
|
||||||
|
if ref_col != DEFAULT_REF_AUDIO_COLUMN:
|
||||||
|
ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN)
|
||||||
|
|
||||||
if dataset_id_column and dataset_id_column in ds.column_names:
|
if dataset_id_column and dataset_id_column in ds.column_names:
|
||||||
if dataset_id_column != DEFAULT_ID_COLUMN:
|
if dataset_id_column != DEFAULT_ID_COLUMN:
|
||||||
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
||||||
@@ -67,11 +74,11 @@ def compute_sample_lengths(
|
|||||||
- 音频长度:
|
- 音频长度:
|
||||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||||
t_seq = ceil(t_vae / patch_size)
|
t_seq = ceil(t_vae / patch_size)
|
||||||
- 序列总长约为: text_len + t_seq + 2
|
- 无 ref_audio: text_len + t_seq + 2
|
||||||
|
- 有 ref_audio: text_len + t_seq + ref_seq + 4
|
||||||
|
|
||||||
Optimized: Use batch column access instead of iterating item by item.
|
Optimized: Use batch column access instead of iterating item by item.
|
||||||
"""
|
"""
|
||||||
# Batch access columns - much faster than per-item access
|
|
||||||
text_ids_list = ds["text_ids"]
|
text_ids_list = ds["text_ids"]
|
||||||
text_lens = [len(t) for t in text_ids_list]
|
text_lens = [len(t) for t in text_ids_list]
|
||||||
|
|
||||||
@@ -79,18 +86,35 @@ def compute_sample_lengths(
|
|||||||
if has_duration:
|
if has_duration:
|
||||||
durations = ds["duration"]
|
durations = ds["duration"]
|
||||||
else:
|
else:
|
||||||
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
|
|
||||||
durations = []
|
durations = []
|
||||||
for i in range(len(ds)):
|
for i in range(len(ds)):
|
||||||
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||||
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||||
|
|
||||||
# Vectorized length computation
|
has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names
|
||||||
|
if has_ref_audio:
|
||||||
|
ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None
|
||||||
|
|
||||||
lengths = []
|
lengths = []
|
||||||
for text_len, duration in zip(text_lens, durations):
|
for i, (text_len, duration) in enumerate(zip(text_lens, durations)):
|
||||||
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
||||||
t_seq = math.ceil(t_vae / patch_size)
|
t_seq = math.ceil(t_vae / patch_size)
|
||||||
total_len = text_len + t_seq + 2
|
|
||||||
|
ref_seq = 0
|
||||||
|
if has_ref_audio:
|
||||||
|
# Estimate ref_audio length; ref_audio is None for samples without it
|
||||||
|
if ref_duration_col:
|
||||||
|
ref_dur = ds[i].get(ref_duration_col)
|
||||||
|
else:
|
||||||
|
ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN)
|
||||||
|
ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None
|
||||||
|
if ref_dur is not None and float(ref_dur) > 0:
|
||||||
|
ref_vae = math.ceil(float(ref_dur) * audio_vae_fps)
|
||||||
|
ref_seq = math.ceil(ref_vae / patch_size)
|
||||||
|
|
||||||
|
# +2 for 101/102; +2 more for 103/104 when ref_audio present
|
||||||
|
overhead = 4 if ref_seq > 0 else 2
|
||||||
|
total_len = text_len + t_seq + ref_seq + overhead
|
||||||
lengths.append(total_len)
|
lengths.append(total_len)
|
||||||
|
|
||||||
return lengths
|
return lengths
|
||||||
@@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset):
|
|||||||
PyTorch-friendly samples.
|
PyTorch-friendly samples.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_SENTINEL = [-100.0]
|
||||||
|
|
||||||
def __init__(self, dataset: Dataset):
|
def __init__(self, dataset: Dataset):
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.dataset)
|
return len(self.dataset)
|
||||||
@@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset):
|
|||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
item = self.dataset[idx]
|
item = self.dataset[idx]
|
||||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||||
return {
|
sample = {
|
||||||
"text_ids": item["text_ids"],
|
"text_ids": item["text_ids"],
|
||||||
"audio_array": audio["array"],
|
"audio_array": audio["array"],
|
||||||
"audio_sampling_rate": audio["sampling_rate"],
|
"audio_sampling_rate": audio["sampling_rate"],
|
||||||
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
|
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
|
||||||
"is_prompt": item.get("is_prompt", False),
|
"is_prompt": item.get("is_prompt", False),
|
||||||
}
|
}
|
||||||
|
if self.has_ref_audio:
|
||||||
|
ref = item.get(DEFAULT_REF_AUDIO_COLUMN)
|
||||||
|
sample["ref_audio_array"] = ref["array"] if ref else self._SENTINEL
|
||||||
|
return sample
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
|
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
|
||||||
@@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset):
|
|||||||
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
|
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
|
||||||
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
|
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
"text_tokens": text_padded,
|
"text_tokens": text_padded,
|
||||||
"audio_tokens": audio_padded,
|
"audio_tokens": audio_padded,
|
||||||
"task_ids": task_ids,
|
"task_ids": task_ids,
|
||||||
@@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset):
|
|||||||
"is_prompts": is_prompts,
|
"is_prompts": is_prompts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if "ref_audio_array" in batch[0]:
|
||||||
|
ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch]
|
||||||
|
result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessor:
|
class BatchProcessor:
|
||||||
"""
|
"""
|
||||||
@@ -184,12 +221,17 @@ class BatchProcessor:
|
|||||||
task_ids = batch["task_ids"].to(self.device)
|
task_ids = batch["task_ids"].to(self.device)
|
||||||
dataset_ids = batch["dataset_ids"].to(self.device)
|
dataset_ids = batch["dataset_ids"].to(self.device)
|
||||||
|
|
||||||
|
ref_audio_tokens = None
|
||||||
|
if "ref_audio_tokens" in batch:
|
||||||
|
ref_audio_tokens = batch["ref_audio_tokens"].to(self.device)
|
||||||
|
|
||||||
packed = self.packer(
|
packed = self.packer(
|
||||||
audio_tokens=audio_tokens,
|
audio_tokens=audio_tokens,
|
||||||
text_tokens=text_tokens,
|
text_tokens=text_tokens,
|
||||||
task_ids=task_ids,
|
task_ids=task_ids,
|
||||||
dataset_ids=dataset_ids,
|
dataset_ids=dataset_ids,
|
||||||
is_prompts=batch["is_prompts"],
|
is_prompts=batch["is_prompts"],
|
||||||
|
ref_audio_tokens=ref_audio_tokens,
|
||||||
)
|
)
|
||||||
return packed
|
return packed
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -14,7 +14,6 @@ class AudioFeatureProcessingPacker:
|
|||||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||||
self.audio_start_id = 101
|
self.audio_start_id = 101
|
||||||
self.audio_end_id = 102
|
self.audio_end_id = 102
|
||||||
# unused now
|
|
||||||
self.audio_prompt_start_id = 103
|
self.audio_prompt_start_id = 103
|
||||||
self.audio_prompt_end_id = 104
|
self.audio_prompt_end_id = 104
|
||||||
self.text_eos_token_id = 2
|
self.text_eos_token_id = 2
|
||||||
@@ -78,11 +77,16 @@ class AudioFeatureProcessingPacker:
|
|||||||
task_ids: torch.Tensor,
|
task_ids: torch.Tensor,
|
||||||
dataset_ids: torch.Tensor,
|
dataset_ids: torch.Tensor,
|
||||||
is_prompts: List[bool],
|
is_prompts: List[bool],
|
||||||
|
ref_audio_tokens: Optional[torch.Tensor] = None,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Padding-based batching: each sample in the input batch is processed
|
Padding-based batching: each sample in the input batch is processed
|
||||||
independently and then padded to a common length (capped by ``max_len``).
|
independently and then padded to a common length (capped by ``max_len``).
|
||||||
The result tensors all have shape [B, T, ...].
|
The result tensors all have shape [B, T, ...].
|
||||||
|
|
||||||
|
If ``ref_audio_tokens`` is provided (same batch dim as ``audio_tokens``),
|
||||||
|
samples whose unpadded ref_audio length > 0 will be processed with the
|
||||||
|
reference-audio path (tokens 103/104 prepended, loss only on target audio).
|
||||||
"""
|
"""
|
||||||
device = audio_tokens.device
|
device = audio_tokens.device
|
||||||
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
|
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
|
||||||
@@ -101,13 +105,33 @@ class AudioFeatureProcessingPacker:
|
|||||||
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||||
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
|
ref_iter = ref_audio_tokens if ref_audio_tokens is not None else [None] * audio_tokens.size(0)
|
||||||
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
|
|
||||||
|
for audio_token, text_token, task_id, dataset_idx, is_prompt, ref_token in zip(
|
||||||
|
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts, ref_iter
|
||||||
):
|
):
|
||||||
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
|
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
|
||||||
unpad_text_token = self.unpad_text_tokens(text_token)
|
unpad_text_token = self.unpad_text_tokens(text_token)
|
||||||
usage = self.id_to_task[task_id]
|
usage = self.id_to_task[task_id]
|
||||||
|
|
||||||
|
has_ref = False
|
||||||
|
if ref_token is not None:
|
||||||
|
unpad_ref_token = self.unpad_audio_tokens(ref_token).to(torch.float32)
|
||||||
|
if unpad_ref_token.numel() > 0:
|
||||||
|
has_ref = True
|
||||||
|
|
||||||
|
if has_ref:
|
||||||
|
(
|
||||||
|
packed_text,
|
||||||
|
audio_feat,
|
||||||
|
text_mask,
|
||||||
|
audio_mask,
|
||||||
|
loss_mask,
|
||||||
|
labels,
|
||||||
|
audio_duration,
|
||||||
|
text_token_count,
|
||||||
|
) = self.process_tts_data_with_ref(unpad_ref_token, unpad_audio_token, unpad_text_token)
|
||||||
|
else:
|
||||||
(
|
(
|
||||||
packed_text,
|
packed_text,
|
||||||
audio_feat,
|
audio_feat,
|
||||||
@@ -294,3 +318,89 @@ class AudioFeatureProcessingPacker:
|
|||||||
audio_duration,
|
audio_duration,
|
||||||
text_token_count,
|
text_token_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def process_tts_data_with_ref(
    self,
    ref_audio_token: torch.Tensor,
    target_audio_token: torch.Tensor,
    text_token: torch.Tensor,
):
    """Pack one TTS sample whose sequence starts with a reference-audio prompt.

    The sequence is laid out, one position per entry, as::

        [prompt_start, ref_feats, prompt_end, text, audio_start, target_feats, audio_end]

    using ``self.audio_prompt_start_id`` / ``self.audio_prompt_end_id`` around
    the reference audio and ``self.audio_start_id`` / ``self.audio_end_id``
    around the target audio. Only the target-audio positions are marked in
    the loss mask.

    Args:
        ref_audio_token: Unpadded reference audio, passed to
            ``self.extract_audio_feats``.
        target_audio_token: Unpadded target audio, passed to
            ``self.extract_audio_feats``.
        text_token: 1-D tensor of text token ids; its device is used for
            every tensor created here.

    Returns:
        An 8-tuple ``(text_token_info, audio_feat_info, text_mask,
        audio_mask, loss_mask, labels, audio_duration, text_token_count)``
        where ``audio_duration`` is the sum of the reference and target
        durations reported by ``self.extract_audio_feats`` and
        ``text_token_count`` is ``len(text_token)``.
    """
    device = text_token.device
    txt_len = len(text_token)

    # Feature extraction: reference first, then target. The leading dim is
    # squeezed away (presumably a batch dim of 1 -- confirm against
    # extract_audio_feats), leaving one row per sequence position.
    ref_feats, ref_duration = self.extract_audio_feats(ref_audio_token)
    ref_feats = ref_feats.squeeze(0)
    tgt_feats, tgt_duration = self.extract_audio_feats(target_audio_token)
    tgt_feats = tgt_feats.squeeze(0)
    ref_len, tgt_len = ref_feats.shape[0], tgt_feats.shape[0]

    # One row per segment: (length, text_mask, audio_mask, loss_mask).
    segments = [
        (1, 1, 0, 0),        # prompt start marker
        (ref_len, 0, 1, 0),  # reference audio: no loss
        (1, 1, 0, 0),        # prompt end marker
        (txt_len, 1, 0, 0),  # text: no loss
        (1, 1, 0, 0),        # audio start marker
        (tgt_len, 0, 1, 1),  # target audio: the only positions with loss
        (1, 1, 0, 0),        # audio end marker
    ]

    def _track(column: int) -> torch.Tensor:
        # Expand one flag column of the segment table into a per-position mask.
        return torch.cat(
            [
                torch.full((seg[0],), seg[column], dtype=torch.int32, device=device)
                for seg in segments
            ]
        )

    def _ids(*values: int) -> torch.Tensor:
        return torch.tensor(values, dtype=torch.int32, device=device)

    def _blank_ids(count: int) -> torch.Tensor:
        return torch.zeros(count, dtype=torch.int32, device=device)

    # Text-token track: marker ids and text ids, zeros at audio positions.
    text_token_info = torch.cat(
        [
            _ids(self.audio_prompt_start_id),
            _blank_ids(ref_len),
            _ids(self.audio_prompt_end_id),
            text_token,
            _ids(self.audio_start_id),
            _blank_ids(tgt_len),
            _ids(self.audio_end_id),
        ]
    )

    # Audio-feature track: real features at audio positions, zeros elsewhere.
    feat_shape = (self.patch_size, ref_feats.size(-1))

    def _blank_feats(count: int) -> torch.Tensor:
        return torch.zeros((count,) + feat_shape, dtype=torch.float32, device=device)

    audio_feat_info = torch.cat(
        [
            _blank_feats(1),
            ref_feats,
            _blank_feats(1),
            _blank_feats(txt_len),
            _blank_feats(1),
            tgt_feats,
            _blank_feats(1),
        ],
        dim=0,
    )

    text_mask = _track(1)
    audio_mask = _track(2)
    loss_mask = _track(3)

    # Stop label sits on the second-to-last position, i.e. the last target
    # audio position just before the audio end marker.
    total_len = sum(seg[0] for seg in segments)
    labels = torch.zeros(total_len, dtype=torch.int32, device=device)
    labels[-2] = 1

    return (
        text_token_info,
        audio_feat_info,
        text_mask,
        audio_mask,
        loss_mask,
        labels,
        ref_duration + tgt_duration,
        txt_len,
    )
|
||||||
|
|||||||
@@ -0,0 +1,309 @@
|
|||||||
|
"""
|
||||||
|
Pre-flight validation for VoxCPM training data manifests.
|
||||||
|
|
||||||
|
Validates JSONL manifest files before starting expensive fine-tuning jobs,
|
||||||
|
catching format issues, missing files, and data quality problems early.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ValidationResult:
    """Outcome of one manifest validation run: counters, messages and statistics."""

    # Number of non-blank manifest lines examined, and how many passed.
    total_samples: int = 0
    valid_samples: int = 0
    # Human-readable findings; any entry in ``errors`` makes the run invalid.
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    # Per-sample statistics gathered along the way.
    audio_durations: List[float] = field(default_factory=list)
    text_lengths: List[int] = field(default_factory=list)
    # Count of samples whose ref_audio file was found on disk.
    has_ref_audio: int = 0

    @property
    def is_valid(self) -> bool:
        """Return True when no errors were recorded and at least one sample is valid."""
        if self.errors:
            return False
        return self.valid_samples > 0
|
||||||
|
|
||||||
|
|
||||||
|
def _check_audio_file(audio_path: str, sample_rate: int) -> Optional[str]:
|
||||||
|
"""Check if an audio file exists, is readable, and matches expected sample rate.
|
||||||
|
|
||||||
|
Returns an error message, or None if the file is valid.
|
||||||
|
"""
|
||||||
|
if not os.path.isfile(audio_path):
|
||||||
|
return f"Audio file not found: {audio_path}"
|
||||||
|
try:
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
info = sf.info(audio_path)
|
||||||
|
if info.frames == 0:
|
||||||
|
return f"Audio file is empty: {audio_path}"
|
||||||
|
if info.samplerate != sample_rate:
|
||||||
|
return (
|
||||||
|
f"Sample rate mismatch in {audio_path}: "
|
||||||
|
f"expected {sample_rate} Hz, got {info.samplerate} Hz"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except ImportError:
|
||||||
|
# soundfile not available; just check existence
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
return f"Cannot read audio file {audio_path}: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_audio_duration(audio_path: str) -> Optional[float]:
|
||||||
|
"""Get audio duration in seconds. Returns None if unavailable."""
|
||||||
|
try:
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
info = sf.info(audio_path)
|
||||||
|
return info.duration
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_manifest_path(value, manifest_dir: Path) -> Optional[str]:
    """Turn a manifest audio field into a filesystem path, or None if unusable.

    ``value`` may be a plain string or a dict carrying the path under the
    ``"path"`` key. Relative paths are resolved against ``manifest_dir``.
    """
    if isinstance(value, dict):
        value = value.get("path", "")
    if not isinstance(value, str) or not value:
        return None
    if not os.path.isabs(value):
        value = str(manifest_dir / value)
    return value


def validate_manifest(
    manifest_path: str,
    sample_rate: int = 16_000,
    max_samples: int = 0,
    verbose: bool = False,
) -> ValidationResult:
    """Validate a JSONL training manifest file.

    Checks:
        1. File exists and is readable
        2. Each non-blank line is a JSON object
        3. Required columns present (text, audio)
        4. Audio files pass ``_check_audio_file``
        5. Text content is non-empty (warning only)
        6. Collects duration and text length statistics
        7. Optional ref_audio files exist (warning only)

    Args:
        manifest_path: Path to the JSONL manifest file.
        sample_rate: Expected audio sample rate. It is forwarded to
            ``_check_audio_file``; any message that helper returns is
            recorded as an error for the sample.
        max_samples: Maximum number of samples to validate (0 = all).
            Blank lines are not samples and do not count towards the limit.
        verbose: Print progress to stderr every 100 samples.

    Returns:
        ValidationResult with errors, warnings, and statistics. Line numbers
        in messages are physical (1-based) line numbers of the manifest.
    """
    result = ValidationResult()
    path = Path(manifest_path)

    if not path.exists():
        result.errors.append(f"Manifest file not found: {manifest_path}")
        return result

    if not path.is_file():
        result.errors.append(f"Manifest path is not a file: {manifest_path}")
        return result

    manifest_dir = path.parent

    try:
        with open(path, "r", encoding="utf-8") as f:
            lines = f.readlines()
    except Exception as e:
        result.errors.append(f"Cannot read manifest file: {e}")
        return result

    if not lines:
        result.errors.append("Manifest file is empty")
        return result

    # The sample budget is expressed in non-blank lines so that blank lines
    # neither consume ``max_samples`` nor skew the progress denominator.
    non_blank = sum(1 for raw in lines if raw.strip())
    samples_to_check = min(non_blank, max_samples) if max_samples > 0 else non_blank

    missing_audio_count = 0
    empty_text_count = 0

    for line_no, raw in enumerate(lines, start=1):
        if result.total_samples >= samples_to_check:
            break

        line = raw.strip()
        if not line:
            continue

        result.total_samples += 1

        # Reported up front so that samples rejected early (via ``continue``
        # below) still advance the progress output.
        if verbose and result.total_samples % 100 == 0:
            print(
                f" Validated {result.total_samples}/{samples_to_check} samples...",
                file=sys.stderr,
            )

        # Check JSON validity
        try:
            entry = json.loads(line)
        except json.JSONDecodeError as e:
            result.errors.append(f"Line {line_no}: Invalid JSON — {e}")
            continue

        if not isinstance(entry, dict):
            result.errors.append(
                f"Line {line_no}: Expected JSON object, got {type(entry).__name__}"
            )
            continue

        # Check required columns; report every missing one before skipping.
        has_error = False
        for column in ("text", "audio"):
            if column not in entry:
                result.errors.append(f"Line {line_no}: Missing required column '{column}'")
                has_error = True
        if has_error:
            continue

        # Validate text (a bad text is only a warning; the first 5 are listed).
        text = entry["text"]
        if not isinstance(text, str) or not text.strip():
            empty_text_count += 1
            if empty_text_count <= 5:
                result.warnings.append(f"Line {line_no}: Empty or non-string text")
        else:
            result.text_lengths.append(len(text))

        # Validate audio path
        audio_path = _resolve_manifest_path(entry["audio"], manifest_dir)
        if audio_path is None:
            result.errors.append(f"Line {line_no}: Invalid audio path")
            has_error = True
        else:
            audio_error = _check_audio_file(audio_path, sample_rate)
            if audio_error:
                # Only the first 5 audio errors are listed individually; the
                # remainder is summarized after the loop.
                missing_audio_count += 1
                if missing_audio_count <= 5:
                    result.errors.append(f"Line {line_no}: {audio_error}")
                has_error = True
            else:
                duration = _get_audio_duration(audio_path)
                if duration is not None:
                    result.audio_durations.append(duration)
                    if duration < 0.3:
                        result.warnings.append(
                            f"Line {line_no}: Very short audio ({duration:.2f}s)"
                        )
                    elif duration > 30.0:
                        result.warnings.append(
                            f"Line {line_no}: Very long audio ({duration:.1f}s), may cause OOM"
                        )

        # Validate optional ref_audio (a missing file is only a warning).
        if "ref_audio" in entry:
            ref_path = _resolve_manifest_path(entry["ref_audio"], manifest_dir)
            if ref_path is not None:
                if os.path.isfile(ref_path):
                    result.has_ref_audio += 1
                else:
                    result.warnings.append(
                        f"Line {line_no}: ref_audio file not found: {ref_path}"
                    )

        if not has_error:
            result.valid_samples += 1

    # Summarize truncated errors
    if missing_audio_count > 5:
        result.errors.append(
            f"... and {missing_audio_count - 5} more missing audio files "
            f"({missing_audio_count} total)"
        )
    if empty_text_count > 5:
        result.warnings.append(
            f"... and {empty_text_count - 5} more empty text entries "
            f"({empty_text_count} total)"
        )

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def print_validation_report(result: "ValidationResult", manifest_path: str) -> None:
    """Print a human-readable validation report to stderr.

    Args:
        result: Validation outcome to report. Only its attributes are read
            (sample counters, ``has_ref_audio``, ``audio_durations``,
            ``text_lengths``, ``errors``, ``warnings``, ``is_valid``).
        manifest_path: Path shown in the report header.

    At most 20 errors and 10 warnings are listed; the rest are summarized
    as a count.
    """

    def emit(text: str) -> None:
        # sys.stderr is looked up on every call so stream redirection works.
        print(text, file=sys.stderr)

    rule = "=" * 60
    emit(f"\n{rule}")
    emit(" VoxCPM Training Data Validation Report")
    emit(rule)
    emit(f" Manifest : {manifest_path}")
    emit(f" Samples : {result.valid_samples}/{result.total_samples} valid")

    if result.has_ref_audio > 0:
        emit(f" Ref Audio: {result.has_ref_audio} samples with reference audio")

    # Audio duration statistics
    if result.audio_durations:
        durations = sorted(result.audio_durations)
        count = len(durations)
        total_seconds = sum(durations)
        # True median: for an even count, average the two middle values
        # instead of reporting the upper one.
        mid = count // 2
        if count % 2:
            median = durations[mid]
        else:
            median = (durations[mid - 1] + durations[mid]) / 2
        emit("\n Audio Duration Statistics:")
        emit(f" Total : {total_seconds / 3600:.2f} hours")
        emit(f" Range : {durations[0]:.2f}s — {durations[-1]:.1f}s")
        emit(f" Mean : {total_seconds / count:.2f}s")
        emit(f" Median : {median:.2f}s")

    # Text length statistics
    if result.text_lengths:
        lengths = sorted(result.text_lengths)
        emit("\n Text Length Statistics (characters):")
        emit(f" Range : {lengths[0]} — {lengths[-1]}")
        emit(f" Mean : {sum(lengths) / len(lengths):.0f}")

    # Errors
    if result.errors:
        emit(f"\n ERRORS ({len(result.errors)}):")
        for err in result.errors[:20]:
            emit(f" x {err}")
        if len(result.errors) > 20:
            emit(f" ... ({len(result.errors) - 20} more errors omitted)")

    # Warnings
    if result.warnings:
        emit(f"\n WARNINGS ({len(result.warnings)}):")
        for warn in result.warnings[:10]:
            emit(f" ! {warn}")
        if len(result.warnings) > 10:
            emit(f" ... ({len(result.warnings) - 10} more warnings omitted)")

    # Summary
    emit(f"\n{rule}")
    if result.is_valid:
        emit(" PASSED: Manifest is valid for training.")
    else:
        emit(" FAILED: Fix errors above before starting training.")
    emit(f"{rule}\n")
|
||||||
@@ -58,6 +58,7 @@ def test_parser_defaults_to_voxcpm2():
|
|||||||
parser = cli._build_parser()
|
parser = cli._build_parser()
|
||||||
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
|
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
|
||||||
assert args.hf_model_id == "openbmb/VoxCPM2"
|
assert args.hf_model_id == "openbmb/VoxCPM2"
|
||||||
|
assert args.device == "auto"
|
||||||
assert args.no_optimize is False
|
assert args.no_optimize is False
|
||||||
|
|
||||||
|
|
||||||
@@ -85,6 +86,7 @@ def test_load_model_respects_no_optimize_for_local_model(monkeypatch):
|
|||||||
|
|
||||||
cli.load_model(args)
|
cli.load_model(args)
|
||||||
|
|
||||||
|
assert calls["kwargs"]["device"] == "auto"
|
||||||
assert calls["kwargs"]["optimize"] is False
|
assert calls["kwargs"]["optimize"] is False
|
||||||
|
|
||||||
|
|
||||||
@@ -110,6 +112,7 @@ def test_load_model_defaults_optimize_for_hf(monkeypatch):
|
|||||||
|
|
||||||
cli.load_model(args)
|
cli.load_model(args)
|
||||||
|
|
||||||
|
assert calls["kwargs"]["device"] == "auto"
|
||||||
assert calls["kwargs"]["optimize"] is True
|
assert calls["kwargs"]["optimize"] is True
|
||||||
|
|
||||||
|
|
||||||
@@ -136,9 +139,37 @@ def test_load_model_respects_no_optimize_for_hf(monkeypatch):
|
|||||||
|
|
||||||
cli.load_model(args)
|
cli.load_model(args)
|
||||||
|
|
||||||
|
assert calls["kwargs"]["device"] == "auto"
|
||||||
assert calls["kwargs"]["optimize"] is False
|
assert calls["kwargs"]["optimize"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_passes_explicit_device_to_hf(monkeypatch):
    """An explicit ``--device`` value must reach ``VoxCPM.from_pretrained`` unchanged."""
    captured = {}

    class FakeVoxCPM:
        # Stand-in that records the keyword arguments it is constructed with.
        @classmethod
        def from_pretrained(cls, **kwargs):
            captured["kwargs"] = kwargs
            return DummyModel()

    monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)

    argv = ["design", "--text", "hello", "--output", "out.wav", "--device", "mps"]
    args = cli._build_parser().parse_args(argv)

    cli.load_model(args)

    assert captured["kwargs"]["device"] == "mps"
|
||||||
|
|
||||||
|
|
||||||
def test_design_subcommand_applies_control(monkeypatch, tmp_path):
|
def test_design_subcommand_applies_control(monkeypatch, tmp_path):
|
||||||
dummy_model = DummyModel()
|
dummy_model = DummyModel()
|
||||||
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
||||||
|
|||||||
@@ -0,0 +1,150 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
SRC = ROOT / "src"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_module(name: str, path: Path):
|
||||||
|
spec = importlib.util.spec_from_file_location(name, path)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
assert spec.loader is not None
|
||||||
|
sys.modules[name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def bootstrap_repo_modules(monkeypatch):
    """Load the VoxCPM model modules from source against stubbed dependencies.

    Registers bare ``voxcpm`` / ``voxcpm.model`` / ``voxcpm.modules`` package
    objects whose ``__path__`` points into ``SRC``, replaces
    ``huggingface_hub``, ``pydantic``, ``torchaudio``, ``librosa``, ``einops``,
    ``tqdm``, ``transformers`` and the listed ``voxcpm.modules.*`` modules with
    minimal stand-ins, then executes ``utils.py``, ``voxcpm.py`` and
    ``voxcpm2.py`` from ``src/voxcpm/model``.

    Every ``sys.modules`` entry touched here is recorded by *monkeypatch*, so
    both the stubs and the freshly loaded model modules are rolled back when
    the requesting test finishes.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.

    Returns:
        tuple: ``(VoxCPMModel, VoxCPM2Model)`` taken from the loaded modules.
    """
    # Package shells: only ``__path__`` is set, so submodules resolve from the
    # source tree (presumably to avoid running the real package __init__ --
    # confirm against the package layout).
    for name, path in [
        ("voxcpm", SRC / "voxcpm"),
        ("voxcpm.model", SRC / "voxcpm" / "model"),
        ("voxcpm.modules", SRC / "voxcpm" / "modules"),
    ]:
        pkg = types.ModuleType(name)
        pkg.__path__ = [str(path)]
        monkeypatch.setitem(sys.modules, name, pkg)

    # --- third-party stand-ins -------------------------------------------
    hh = types.ModuleType("huggingface_hub")
    hh.snapshot_download = lambda *a, **k: "/tmp/fake"
    monkeypatch.setitem(sys.modules, "huggingface_hub", hh)

    pydantic = types.ModuleType("pydantic")

    class BaseModel:
        # Only the pydantic entry points stubbed here are available.
        @classmethod
        def model_rebuild(cls):
            return None

        @classmethod
        def model_validate_json(cls, s):
            return cls()

        def model_dump(self):
            return {}

    pydantic.BaseModel = BaseModel
    monkeypatch.setitem(sys.modules, "pydantic", pydantic)

    torchaudio = types.ModuleType("torchaudio")
    monkeypatch.setitem(sys.modules, "torchaudio", torchaudio)

    librosa = types.ModuleType("librosa")
    librosa.effects = types.SimpleNamespace(trim=lambda *a, **k: (None, (0, 0)))
    monkeypatch.setitem(sys.modules, "librosa", librosa)

    einops = types.ModuleType("einops")
    einops.rearrange = lambda x, *a, **k: x
    monkeypatch.setitem(sys.modules, "einops", einops)

    # ``tqdm`` gets a ``__path__`` so it is treated as a package and
    # ``tqdm.auto`` can be registered beneath it.
    tqdm_pkg = types.ModuleType("tqdm")
    tqdm_pkg.__path__ = ["/nonexistent"]
    tqdm_pkg.tqdm = lambda x, *a, **k: x
    monkeypatch.setitem(sys.modules, "tqdm", tqdm_pkg)

    tqdm_auto = types.ModuleType("tqdm.auto")
    tqdm_auto.tqdm = lambda x, *a, **k: x
    monkeypatch.setitem(sys.modules, "tqdm.auto", tqdm_auto)

    transformers = types.ModuleType("transformers")

    class LlamaTokenizerFast:
        pass

    class PreTrainedTokenizer:
        pass

    transformers.LlamaTokenizerFast = LlamaTokenizerFast
    transformers.PreTrainedTokenizer = PreTrainedTokenizer
    monkeypatch.setitem(sys.modules, "transformers", transformers)

    # --- internal voxcpm.modules.* stand-ins ------------------------------
    # Each listed name becomes an empty class, except the LoRA helper
    # function, which becomes a no-op callable.
    internal_mods = {
        "voxcpm.modules.audiovae": ["AudioVAE", "AudioVAEConfig", "AudioVAEV2", "AudioVAEConfigV2"],
        "voxcpm.modules.layers": ["ScalarQuantizationLayer"],
        "voxcpm.modules.locdit": ["CfmConfig", "UnifiedCFM", "VoxCPMLocDiT", "VoxCPMLocDiTV2"],
        "voxcpm.modules.locenc": ["VoxCPMLocEnc"],
        "voxcpm.modules.minicpm4": ["MiniCPM4Config", "MiniCPMModel"],
        "voxcpm.modules.layers.lora": ["apply_lora_to_named_linear_modules", "LoRALinear"],
    }
    for modname, names in internal_mods.items():
        module = types.ModuleType(modname)
        for name in names:
            if name == "apply_lora_to_named_linear_modules":
                setattr(module, name, lambda *a, **k: None)
            else:
                setattr(module, name, type(name, (), {}))
        monkeypatch.setitem(sys.modules, modname, module)

    # --- real model modules ------------------------------------------------
    model_dir = SRC / "voxcpm" / "model"
    loaded = {}
    for stem in ("utils", "voxcpm", "voxcpm2"):
        qualified = f"voxcpm.model.{stem}"
        # _load_module assigns into sys.modules directly. Registering a
        # placeholder through monkeypatch first makes monkeypatch record the
        # pre-test state of this key, so the module loaded against the stubs
        # above is removed at teardown instead of leaking into later tests.
        monkeypatch.setitem(sys.modules, qualified, types.ModuleType(qualified))
        loaded[stem] = _load_module(qualified, model_dir / f"{stem}.py")

    return loaded["voxcpm"].VoxCPMModel, loaded["voxcpm2"].VoxCPM2Model
|
||||||
|
|
||||||
|
|
||||||
|
class DummyModel:
    """Bare stand-in for a model object: a CPU device and no parameters."""

    # Device string exposed to the code under test.
    device = "cpu"

    def named_parameters(self):
        """Return an empty list of parameters."""
        return list()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("module_name", ["v1", "v2"])
def test_load_lora_weights_accepts_tensor_only_legacy_checkpoints(monkeypatch, tmp_path, module_name):
    """A checkpoint holding only tensors loads; its unknown key is reported as skipped."""
    v1_cls, v2_cls = bootstrap_repo_modules(monkeypatch)
    model_cls = v2_cls if module_name != "v1" else v1_cls

    checkpoint = tmp_path / "lora_weights.ckpt"
    payload = {"state_dict": {"fake": torch.zeros(1)}}
    torch.save(payload, checkpoint)

    # DummyModel exposes no named parameters, so "fake" has nothing to map to.
    outcome = model_cls.load_lora_weights(DummyModel(), str(checkpoint), device="cpu")
    loaded, skipped = outcome

    assert loaded == []
    assert skipped == ["fake"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("module_name", ["v1", "v2"])
def test_load_lora_weights_rejects_malicious_pickle_payloads(monkeypatch, tmp_path, module_name):
    """A checkpoint carrying an executable pickle payload must be refused without running it."""
    v1_cls, v2_cls = bootstrap_repo_modules(monkeypatch)
    model_cls = v2_cls if module_name != "v1" else v1_cls

    checkpoint = tmp_path / "lora_weights.ckpt"
    marker = tmp_path / f"{module_name}-marker.txt"

    class Exploit:
        # Unpickling this object would call Path.write_text(marker, ...),
        # leaving the marker file behind as evidence of code execution.
        def __reduce__(self):
            return (Path.write_text, (marker, f"{module_name} executed\n"))

    poisoned = {"state_dict": {"fake": torch.zeros(1)}, "boom": Exploit()}
    torch.save(poisoned, checkpoint)

    with pytest.raises(Exception, match="Weights only load failed"):
        model_cls.load_lora_weights(DummyModel(), str(checkpoint), device="cpu")

    # The payload must never have been executed.
    assert not marker.exists()
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Repository root: one directory above the directory containing this test file.
ROOT = Path(__file__).resolve().parents[1]
UTILS_PATH = ROOT.joinpath("src", "voxcpm", "model", "utils.py")

# Provide a minimal ``transformers`` stand-in, but only when no module of that
# name is registered yet; an already-imported transformers is left alone.
transformers_stub = types.ModuleType("transformers")
transformers_stub.PreTrainedTokenizer = object
if "transformers" not in sys.modules:
    sys.modules["transformers"] = transformers_stub

# Execute utils.py straight from its file. The resulting module is kept in the
# ``utils`` global only and is not registered in ``sys.modules``.
spec = importlib.util.spec_from_file_location("voxcpm.model.utils", UTILS_PATH)
utils = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(utils)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_runtime_device_auto_falls_back_to_cpu(monkeypatch):
    """With CUDA and MPS both patched unavailable, a ``None`` request resolves to ``"cpu"``."""

    def _unavailable():
        return False

    monkeypatch.setattr(utils.torch.cuda, "is_available", _unavailable)
    monkeypatch.setattr(utils, "_has_mps", _unavailable)

    resolved = utils.resolve_runtime_device(None, "cuda")

    assert resolved == "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_runtime_device_auto_uses_mps_when_available(monkeypatch):
    """With CUDA patched unavailable and MPS available, ``"auto"`` resolves to ``"mps"``."""

    def _cuda_missing():
        return False

    def _mps_present():
        return True

    monkeypatch.setattr(utils.torch.cuda, "is_available", _cuda_missing)
    monkeypatch.setattr(utils, "_has_mps", _mps_present)

    resolved = utils.resolve_runtime_device("auto", "cuda")

    assert resolved == "mps"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_runtime_device_respects_explicit_cpu(monkeypatch):
    """An explicit ``"cpu"`` request is honoured even when CUDA and MPS are both available."""

    def _available():
        return True

    monkeypatch.setattr(utils.torch.cuda, "is_available", _available)
    monkeypatch.setattr(utils, "_has_mps", _available)

    resolved = utils.resolve_runtime_device("cpu", "cuda")

    assert resolved == "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_runtime_device_rejects_unavailable_explicit_cuda(monkeypatch):
    """Requesting ``"cuda:0"`` while CUDA is patched unavailable raises ``ValueError``."""

    def _cuda_missing():
        return False

    def _mps_present():
        return True

    monkeypatch.setattr(utils.torch.cuda, "is_available", _cuda_missing)
    monkeypatch.setattr(utils, "_has_mps", _mps_present)

    # MPS being available must not rescue an explicit CUDA request.
    with pytest.raises(ValueError, match="CUDA is not available"):
        utils.resolve_runtime_device("cuda:0", "cuda")
|
||||||
@@ -0,0 +1,252 @@
|
|||||||
|
"""Tests for the training data validation module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Repository root: one directory above the directory containing this test file.
ROOT = Path(__file__).resolve().parents[1]

# Stub the ``voxcpm`` and ``voxcpm.training`` packages so imports work without
# the full dependencies. Each stub only carries a ``__path__`` into the source
# tree and is installed only if no module of that name is registered yet.
pkg = types.ModuleType("voxcpm")
pkg.__path__ = [str(ROOT.joinpath("src", "voxcpm"))]
if "voxcpm" not in sys.modules:
    sys.modules["voxcpm"] = pkg

training_pkg = types.ModuleType("voxcpm.training")
training_pkg.__path__ = [str(ROOT.joinpath("src", "voxcpm", "training"))]
if "voxcpm.training" not in sys.modules:
    sys.modules["voxcpm.training"] = training_pkg
|
||||||
|
|
||||||
|
from voxcpm.training.validate import ValidationResult, validate_manifest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def tmp_dir():
    """Yield a fresh temporary directory as a ``Path``; it is deleted after the test."""
    with tempfile.TemporaryDirectory() as raw_dir:
        scratch = Path(raw_dir)
        yield scratch
|
||||||
|
|
||||||
|
|
||||||
|
def _create_wav(path: Path, duration_s: float = 1.0, sr: int = 16000):
|
||||||
|
"""Create a minimal valid WAV file."""
|
||||||
|
try:
|
||||||
|
import soundfile as sf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
samples = int(duration_s * sr)
|
||||||
|
data = np.zeros(samples, dtype=np.float32)
|
||||||
|
sf.write(str(path), data, sr)
|
||||||
|
except ImportError:
|
||||||
|
# If soundfile is not available, create a minimal WAV header
|
||||||
|
import struct
|
||||||
|
|
||||||
|
samples = int(duration_s * sr)
|
||||||
|
data_size = samples * 2 # 16-bit PCM
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(b"RIFF")
|
||||||
|
f.write(struct.pack("<I", 36 + data_size))
|
||||||
|
f.write(b"WAVEfmt ")
|
||||||
|
f.write(struct.pack("<IHHIIHH", 16, 1, 1, sr, sr * 2, 2, 16))
|
||||||
|
f.write(b"data")
|
||||||
|
f.write(struct.pack("<I", data_size))
|
||||||
|
f.write(b"\x00" * data_size)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_manifest(path: Path, entries: list[dict]):
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
for entry in entries:
|
||||||
|
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateManifest:
    """Tests for ``validate_manifest`` and ``ValidationResult``."""

    def test_valid_manifest(self, tmp_dir):
        """Two well-formed rows pointing at existing audio validate without errors."""
        first_clip = tmp_dir / "audio1.wav"
        second_clip = tmp_dir / "audio2.wav"
        _create_wav(first_clip, 2.0)
        _create_wav(second_clip, 3.0)

        rows = [
            {"text": "Hello world", "audio": str(first_clip)},
            {"text": "Goodbye world", "audio": str(second_clip)},
        ]
        manifest = tmp_dir / "train.jsonl"
        _write_manifest(manifest, rows)

        report = validate_manifest(str(manifest))
        assert report.total_samples == 2
        assert report.valid_samples == 2
        assert report.is_valid
        assert len(report.errors) == 0

    def test_missing_manifest(self):
        """A manifest path that does not exist yields a "not found" error."""
        report = validate_manifest("/nonexistent/path.jsonl")
        assert not report.is_valid
        assert any("not found" in message for message in report.errors)

    def test_empty_manifest(self, tmp_dir):
        """A zero-byte manifest is not valid."""
        manifest = tmp_dir / "empty.jsonl"
        manifest.write_text("")
        report = validate_manifest(str(manifest))
        assert not report.is_valid

    def test_invalid_json(self, tmp_dir):
        """Every unparseable line is reported as an "Invalid JSON" error."""
        manifest = tmp_dir / "bad.jsonl"
        manifest.write_text("not json\n{bad json}\n")
        report = validate_manifest(str(manifest))
        assert len(report.errors) >= 2
        assert any("Invalid JSON" in message for message in report.errors)

    def test_missing_columns(self, tmp_dir):
        """Rows lacking ``audio`` or ``text`` are each reported, naming the missing key."""
        rows = [
            {"text": "hello"},  # missing audio
            {"audio": "test.wav"},  # missing text
        ]
        manifest = tmp_dir / "missing.jsonl"
        _write_manifest(manifest, rows)
        report = validate_manifest(str(manifest))
        assert len(report.errors) >= 2
        assert any("'audio'" in message for message in report.errors)
        assert any("'text'" in message for message in report.errors)

    def test_missing_audio_file(self, tmp_dir):
        """A row whose audio file does not exist invalidates the manifest."""
        manifest = tmp_dir / "missing_audio.jsonl"
        _write_manifest(manifest, [{"text": "hello", "audio": "/nonexistent/audio.wav"}])
        report = validate_manifest(str(manifest))
        assert not report.is_valid
        assert any("not found" in message for message in report.errors)

    def test_empty_text_warning(self, tmp_dir):
        """An empty ``text`` field produces a warning mentioning "Empty"."""
        clip = tmp_dir / "audio.wav"
        _create_wav(clip)
        manifest = tmp_dir / "empty_text.jsonl"
        _write_manifest(manifest, [{"text": "", "audio": str(clip)}])
        report = validate_manifest(str(manifest))
        assert len(report.warnings) > 0
        assert any("Empty" in message for message in report.warnings)

    def test_relative_audio_path(self, tmp_dir):
        """A bare file name next to the manifest is accepted as a valid audio path."""
        clip = tmp_dir / "audio.wav"
        _create_wav(clip)
        manifest = tmp_dir / "rel.jsonl"
        # The entry names the clip relative to the manifest's own directory.
        _write_manifest(manifest, [{"text": "hello", "audio": "audio.wav"}])
        report = validate_manifest(str(manifest))
        assert report.valid_samples == 1
        assert report.is_valid

    def test_max_samples_limit(self, tmp_dir):
        """``max_samples`` caps how many rows are counted."""
        clip = tmp_dir / "audio.wav"
        _create_wav(clip)
        rows = [{"text": f"sample {i}", "audio": str(clip)} for i in range(100)]
        manifest = tmp_dir / "many.jsonl"
        _write_manifest(manifest, rows)
        report = validate_manifest(str(manifest), max_samples=10)
        assert report.total_samples == 10

    def test_ref_audio_counted(self, tmp_dir):
        """A row with an existing ``ref_audio`` file is counted in ``has_ref_audio``."""
        clip = tmp_dir / "audio.wav"
        reference = tmp_dir / "ref.wav"
        _create_wav(clip)
        _create_wav(reference)
        manifest = tmp_dir / "ref.jsonl"
        _write_manifest(
            manifest,
            [{"text": "hello", "audio": str(clip), "ref_audio": str(reference)}],
        )
        report = validate_manifest(str(manifest))
        assert report.has_ref_audio == 1

    def test_validation_result_properties(self):
        """``is_valid`` requires at least one sample and no errors."""
        clean = ValidationResult(total_samples=5, valid_samples=5)
        assert clean.is_valid

        with_error = ValidationResult(total_samples=5, valid_samples=5, errors=["err"])
        assert not with_error.is_valid

        no_samples = ValidationResult(total_samples=0, valid_samples=0)
        assert not no_samples.is_valid

    def test_invalid_audio_not_counted_as_valid(self, tmp_dir):
        """A row with a bad audio path must not increment valid_samples."""
        manifest = tmp_dir / "bad_audio.jsonl"
        _write_manifest(manifest, [{"text": "hello", "audio": "/nonexistent/audio.wav"}])
        report = validate_manifest(str(manifest))
        assert report.total_samples == 1
        assert report.valid_samples == 0
        assert not report.is_valid
        assert any("not found" in message for message in report.errors)

    def test_sample_rate_mismatch(self, tmp_dir):
        """A file with a different sample rate should be reported as an error."""
        try:
            import soundfile as sf
            import numpy as np
        except ImportError:
            pytest.skip("soundfile not available")

        clip = tmp_dir / "audio_8k.wav"
        silence = np.zeros(8000, dtype=np.float32)
        sf.write(str(clip), silence, 8000)

        manifest = tmp_dir / "sr_mismatch.jsonl"
        _write_manifest(manifest, [{"text": "hello", "audio": str(clip)}])

        report = validate_manifest(str(manifest), sample_rate=16000)
        assert report.valid_samples == 0
        assert not report.is_valid
        # Case-insensitive match; it also covers the "Sample rate mismatch" wording.
        assert any("sample rate" in message.lower() for message in report.errors)

    def test_mixed_ref_audio_warns_for_each_missing(self, tmp_dir):
        """Missing ref_audio entries should each generate a warning independently."""
        clip = tmp_dir / "audio.wav"
        good_reference = tmp_dir / "ref_good.wav"
        _create_wav(clip)
        _create_wav(good_reference)

        rows = [
            {"text": "row1", "audio": str(clip), "ref_audio": str(good_reference)},
            {"text": "row2", "audio": str(clip), "ref_audio": "/nonexistent/ref.wav"},
        ]
        manifest = tmp_dir / "mixed_ref.jsonl"
        _write_manifest(manifest, rows)
        report = validate_manifest(str(manifest))
        assert report.has_ref_audio == 1
        assert any("ref_audio file not found" in message for message in report.warnings)

    def test_cli_validate_exit_code(self, tmp_dir):
        """validate subcommand must exit 1 on validation error (missing audio)."""
        import subprocess

        manifest = tmp_dir / "bad.jsonl"
        _write_manifest(manifest, [{"text": "hi", "audio": "/nonexistent/x.wav"}])

        command = [sys.executable, "-m", "voxcpm.cli", "validate", "--manifest", str(manifest)]
        proc = subprocess.run(command, capture_output=True, text=True)
        assert proc.returncode == 1, f"Expected exit 1, got {proc.returncode}"
        assert "FAILED" in proc.stderr or "Audio file not found" in proc.stderr
|
||||||
Reference in New Issue
Block a user