46 Commits

Author SHA1 Message Date
Pine 2e4c601d1d - Fix path error 2026-05-06 14:58:46 +08:00
Pine 525ebc65a3 - Fix a local model usage issue 2026-05-06 14:28:36 +08:00
Pine 564b98b2d5 - Support standalone self-hosted use: for 添哥 2026-05-06 13:41:26 +08:00
liuxin 19b6bf7590 fix: handle LoRA rank mismatch during inference in lora_ft_webui
Pass the selected LoRA checkpoint to load_model() on first load so the
model initializes with the correct rank from lora_config.json instead of
always defaulting to r=32.

On subsequent LoRA hot-swaps, detect rank incompatibility and
automatically reload the model with the new checkpoint's config,
preventing tensor shape mismatch errors (fixes #283).

Made-with: Cursor
2026-04-28 10:52:57 +08:00
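The rank check described in this commit can be sketched roughly as follows. This is a minimal illustration, not the code in lora_ft_webui: the helper names `read_lora_rank` and `needs_reload` are hypothetical, and it assumes the rank is stored under a top-level `"r"` key in lora_config.json.

```python
import json
from pathlib import Path

DEFAULT_RANK = 32  # the default rank named in the commit message (r=32)


def read_lora_rank(checkpoint_dir: str) -> int:
    """Return the LoRA rank recorded next to a checkpoint.

    Assumption: lora_config.json keeps the rank under a top-level "r" key.
    Falls back to DEFAULT_RANK when the file or the key is missing.
    """
    config_path = Path(checkpoint_dir) / "lora_config.json"
    if not config_path.is_file():
        return DEFAULT_RANK
    with open(config_path, encoding="utf-8") as f:
        return int(json.load(f).get("r", DEFAULT_RANK))


def needs_reload(loaded_rank: int, checkpoint_dir: str) -> bool:
    """True when hot-swapping would hit a tensor shape mismatch."""
    return read_lora_rank(checkpoint_dir) != loaded_rank
```

A caller would check `needs_reload(current_rank, new_dir)` before a hot-swap and rebuild the model with the new config only when it returns True.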
ZGY 86bff0fc82 Merge pull request #253 from SuperMarioYL/feat/validate-training-data
feat: add voxcpm validate CLI for pre-flight training data checks
2026-04-27 21:09:41 +08:00
supermario_leo dd7b78f2c0 refactor(cli): defer soundfile and voxcpm.core imports to inference commands
Move `import soundfile as sf` and `from voxcpm.core import VoxCPM` from
module-level into the functions that require model inference (load_model,
_run_single, cmd_batch), so `voxcpm validate` can run without loading
the model/inference stack.
2026-04-25 05:09:23 +08:00
supermario_leo 29577d57f8 test: fix test_cli_validate_exit_code to use --manifest flag and assert specific exit code
Pass manifest path via --manifest flag (required) instead of as a
positional argument, so the test exercises cmd_validate rather than
argparse error handling.  Also assert returncode==1 and check stderr
for the FAILED/error message to prevent false positives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-24 10:15:57 +08:00
supermario_leo 4509becfde fix: address four validation correctness issues from review
- Invalid audio rows (bad path or sample-rate mismatch) no longer
  increment valid_samples; has_error is now set on any audio failure
- _check_audio_file now enforces the expected sample rate when soundfile
  is available, making --sample-rate actually useful
- ref_audio missing-file warning is emitted for every invalid entry
  independently, not only before the first valid one is seen
- New tests cover each of the four corrected behaviours: invalid audio
  count, sample-rate mismatch, mixed ref_audio, and CLI exit code
2026-04-22 05:06:35 +08:00
ZGY cd79a647fa Merge pull request #263 from Oumnya/fix/mps-bf16-dtype
fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss
2026-04-21 18:49:48 +08:00
Oumnya 96d605b9de fix(mps): align VOXCPM_MPS_DTYPE override set with get_dtype parser
Drop "half" from _VALID_DTYPE_OVERRIDES / _LOW_PRECISION_DTYPES.
get_dtype() has never accepted "half", so VOXCPM_MPS_DTYPE=half would
pass override validation and then crash downstream with
"Unsupported dtype: half". The remaining aliases (bfloat16/bf16,
float16/fp16, float32/fp32) already cover the intended dtype space.

Adds a standalone unit check under scripts/ to guard the invariant
that every accepted override parses through get_dtype().

Addresses review feedback on #263.
2026-04-21 18:24:53 +08:00
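The invariant this commit guards (every accepted override must parse) could be checked along these lines. The sketch below is an assumption-laden stand-in: the real `get_dtype()` presumably returns framework dtypes, whereas this hypothetical version returns canonical strings so the example stays self-contained.

```python
# Hypothetical stand-in for the project's parser; canonical names as strings.
_PARSER_TABLE = {
    "bfloat16": "bfloat16", "bf16": "bfloat16",
    "float16": "float16", "fp16": "float16",
    "float32": "float32", "fp32": "float32",
}

# The override set after the fix: "half" is intentionally absent.
VALID_DTYPE_OVERRIDES = frozenset(
    {"bfloat16", "bf16", "float16", "fp16", "float32", "fp32"}
)


def get_dtype(name: str) -> str:
    """Parse a dtype name, rejecting anything the table does not know."""
    key = name.strip().lower()
    if key not in _PARSER_TABLE:
        raise ValueError(f"Unsupported dtype: {name}")
    return _PARSER_TABLE[key]


def check_overrides_parse() -> None:
    """Fail loudly if any accepted override would crash downstream."""
    for name in sorted(VALID_DTYPE_OVERRIDES):
        get_dtype(name)
```

Keeping the check separate from the parser means a future alias added to only one of the two sets is caught by a unit test rather than at inference time.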
ZGY a9b03a768c Merge pull request #277 from gluttony-10/main
feat: enhance control text processing in VoxCPMDemo
2026-04-21 17:11:42 +08:00
ZGY 77f847fcba Merge pull request #268 from shaun0927/fix/lora-weights-only
fix: load legacy LoRA checkpoints with weights_only=True
2026-04-21 16:55:42 +08:00
gluttony-10 d3cc88722c feat: enhance control text processing in VoxCPMDemo
Added regex to strip parentheses from control instructions in the text synthesis method to ensure compatibility with the expected prompt format. This change improves the robustness of the input handling.
2026-04-21 07:07:24 +00:00
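A rough sketch of the kind of normalization this commit describes is shown below. The function name `build_prompt` and the exact character class are assumptions; the only detail taken from the source is that parentheses are stripped from the control instruction so the final text matches the `(description)text` prompt format.

```python
import re

# Matches ASCII and full-width parentheses (assumed character set).
_PARENS = re.compile(r"[()()]")


def build_prompt(control: str, text: str) -> str:
    """Wrap a control instruction in exactly one pair of parentheses."""
    control = _PARENS.sub("", control).strip()
    return f"({control}){text}" if control else text
```

With this, `build_prompt("(cheerful tone)", "Hello.")` and `build_prompt("cheerful tone", "Hello.")` produce the same prompt, so user-typed parentheses cannot nest.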
JunghwanNA ec2acec8a1 Harden LoRA checkpoint loading against untrusted pickle payloads
LoRA is a first-class workflow in VoxCPM, and the project already prefers
safetensors plus weights-only fallback loading for base model artifacts. The
legacy LoRA .ckpt/.pth path was the remaining place that still deserialized
arbitrary pickle objects, so this switches it to weights_only=True and adds
focused regression coverage for both model loaders.

Constraint: Must preserve compatibility with tensor-only legacy LoRA checkpoints
Rejected: Remove .ckpt/.pth support entirely | too disruptive for existing users
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Keep LoRA artifact handling aligned with the existing safetensors-first, weights-only loading pattern
Tested: python3 -m pytest -q tests/test_lora_checkpoint_loading.py tests/test_model_utils.py -q
Not-tested: Full end-to-end LoRA hot-load with heavyweight model assets
2026-04-18 00:31:28 +09:00
xliucs 13605c5a0e Merge pull request #266 from linyueqian/docs/add-vllm-omni-references
docs: add vLLM-Omni serving references
2026-04-17 10:46:21 +08:00
Yueqian Lin afa63e6195 docs: add vLLM-Omni serving references
Document vLLM-Omni as a production serving option for VoxCPM2
alongside the existing Nano-vLLM reference. Mirrors the addition in
README_zh.md, and adds an ecosystem table entry.

Install snippet follows the upstream vLLM-Omni installation guide
(from source, since vllm-omni is rapidly evolving).

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
2026-04-16 21:19:27 -05:00
liuxin eae0a29908 docs: add ComfyUI RH link
Made-with: Cursor
2026-04-16 11:46:40 +08:00
Labmem-Zhouyx 35895982d7 Merge PR #212: perf: stateful streaming VAE decode — eliminate redundant overlap
- StreamingVAEDecoder caches CausalConv1d/CausalTransposeConv1d left-pad
  state between calls — one patch in, one patch out, no overlap
- _inference yields single-patch latents in streaming mode
- 2x faster streaming VAE decode, more accurate (max diff 0.0005 vs 0.0011)
2026-04-15 16:01:38 +08:00
Labmem-Zhouyx f7f1b78c4d fix: correct transpose conv context 2026-04-15 16:01:02 +08:00
oumnya 38d61cdf03 fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss
VoxCPM checkpoints default to bfloat16. Following commit e4e0496 which
added MPS device routing, running with `device=mps` selects bf16 on
Apple Silicon. On Metal, bf16 introduces enough numerical drift in the
diffusion AR loop that the synthesized audio is glitched and trips the
model's badcase detector, which retries until the per-call retry budget
is exhausted. Effectively MPS support is unusable in the default config.

This patch adds a single helper, `pick_runtime_dtype(device, dtype)`,
that promotes any low-precision dtype to float32 when the resolved
device is `mps`. CUDA and CPU paths are untouched. An opt-out env var
`VOXCPM_MPS_DTYPE` lets users force a specific dtype on MPS once future
PyTorch / macOS releases improve bf16 stability.

Both VoxCPMModel and VoxCPM2Model adopt the helper in their __init__,
replacing what would otherwise be duplicated inline checks.

Verified locally on Apple M5 Max, PyTorch 2.11, macOS 15:
- VoxCPM2 (2B): clean output, RTF ~0.78 steady state
- VoxCPM 0.5B: clean output, RTF ~0.92
- No badcase retries fired in any test
- VOXCPM_MPS_DTYPE=bfloat16 round-trips and reproduces the original
  glitched output, confirming the override path.
2026-04-15 12:22:56 +08:00
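The promotion rule can be illustrated with a small sketch. It is not the project's implementation: dtypes are modelled as plain strings here to avoid a PyTorch dependency, and the alias table is an assumption. Only the helper name, the `mps` rule, and the `VOXCPM_MPS_DTYPE` variable come from the commit message.

```python
import os

_ALIASES = {"bf16": "bfloat16", "fp16": "float16", "fp32": "float32"}
_LOW_PRECISION = {"bfloat16", "float16"}
_VALID = {"bfloat16", "float16", "float32"}


def pick_runtime_dtype(device: str, dtype: str) -> str:
    """Promote low-precision dtypes to float32 on MPS unless overridden."""
    dtype = _ALIASES.get(dtype, dtype)
    if device != "mps":
        return dtype  # CUDA and CPU paths are left untouched
    override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
    if override:
        override = _ALIASES.get(override, override)
        if override not in _VALID:
            raise ValueError(f"Unsupported dtype: {override}")
        return override  # explicit opt-out wins over the promotion
    return "float32" if dtype in _LOW_PRECISION else dtype
```

Resolving the override inside the helper keeps both model classes free of duplicated device checks, which is the design goal the commit states.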
刘鑫 1565e83efe fix: complete shared generator cleanup coverage
Move generator close handling into a shared utility and wire the core generation pipeline through it so partially-consumed prompt cache generators are cleaned up consistently across both model variants and the public VoxCPM wrapper.

Made-with: Cursor
2026-04-13 17:39:05 +08:00
刘鑫 61b36d4e56 refactor: centralize generator cleanup in model helpers
Factor repeated next-and-close patterns into a shared helper in both VoxCPM model variants so non-streaming inference cleans up generators consistently while keeping the issue reference close to the workaround.

Made-with: Cursor
2026-04-13 16:57:08 +08:00
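The next-and-close pattern mentioned in these two commits looks roughly like this in plain Python. The helper name `next_and_close` is hypothetical; the point is that closing a partially consumed generator runs its `finally` blocks immediately instead of waiting for garbage collection.

```python
from typing import Iterator, Optional, TypeVar

T = TypeVar("T")


def next_and_close(gen: Iterator[T], default: Optional[T] = None) -> Optional[T]:
    """Take the first item from a generator, then close it deterministically."""
    try:
        return next(gen, default)
    finally:
        close = getattr(gen, "close", None)  # plain iterators have no close()
        if close is not None:
            close()
```

Non-streaming inference only needs the first yielded result, so routing it through one helper keeps cleanup consistent across call sites.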
刘鑫 b1584aec7c fix: stabilize CPU SDPA mask broadcasting
Use an explicit broadcastable attention mask shape during MiniCPM incremental decoding so CPU runtimes avoid a PyTorch SDPA dimension error without changing attention semantics.

Made-with: Cursor
2026-04-13 15:38:53 +08:00
supermario_leo 4457617953 feat: add voxcpm validate CLI for pre-flight training data checks
Add a new `validate` subcommand that checks JSONL training manifests
before starting expensive fine-tuning jobs. This catches format issues,
missing audio files, and data quality problems early.

The validator performs:
- JSONL format validation (each line must be valid JSON)
- Required column checks (text, audio)
- Audio file existence and readability verification
- Duration and text length statistics (min, max, mean, median)
- Optional ref_audio column validation
- Warnings for very short (<0.3s) or very long (>30s) audio samples

Usage:
  voxcpm validate --manifest train.jsonl
  voxcpm validate --manifest train.jsonl --sample-rate 16000 --verbose

The module uses lazy imports for soundfile, so it works even in
minimal environments. Includes 11 unit tests covering all validation
paths.
2026-04-13 03:15:50 +08:00
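The first two checks in the list above (JSONL validity and required columns) can be sketched as follows. This is an illustrative reduction, not the `voxcpm validate` implementation: the function name and error wording are assumptions, and audio inspection is left out entirely.

```python
import json
from typing import Iterable, List

REQUIRED_COLUMNS = ("text", "audio")  # the required columns named above


def validate_manifest_lines(lines: Iterable[str]) -> List[str]:
    """Return human-readable errors for a JSONL manifest; empty means OK."""
    errors: List[str] = []
    for lineno, raw in enumerate(lines, start=1):
        if not raw.strip():
            continue  # skip blank lines
        try:
            row = json.loads(raw)
        except json.JSONDecodeError as exc:
            errors.append(f"line {lineno}: invalid JSON ({exc.msg})")
            continue
        if not isinstance(row, dict):
            errors.append(f"line {lineno}: expected a JSON object")
            continue
        for column in REQUIRED_COLUMNS:
            if not row.get(column):
                errors.append(f"line {lineno}: missing or empty '{column}'")
    return errors
```

A CLI wrapper would print these messages and exit non-zero when the list is non-empty, so a fine-tuning job never starts on a broken manifest.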
xliucs 5510503182 Merge pull request #246 from sharziki/fix/unclosed-file-handles
fix: close file handles in from_local() config loading
2026-04-11 13:10:04 +08:00
sharziki fb46aad9a5 fix: close file handles in from_local() config loading
Use context managers when reading config.json in VoxCPMModel.from_local()
and VoxCPM2Model.from_local() to prevent file descriptor leaks. Also add
explicit encoding="utf-8" to avoid locale-dependent decode errors.

Closes #235

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-11 00:01:14 -04:00
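The fix amounts to the standard context-manager idiom. A minimal sketch, with a hypothetical `load_config` helper standing in for the config-reading step inside `from_local()`:

```python
import json


def load_config(path: str) -> dict:
    """Read a JSON config without leaking the file descriptor."""
    # The with-block closes the file even if json.load raises, and the
    # explicit encoding avoids locale-dependent decode errors.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
```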
刘鑫 e4e049624c update finetuning pipeline and runtime device handling
Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits.

Made-with: Cursor
2026-04-11 11:08:50 +08:00
xliucs abf01b9bf3 Merge pull request #229 from kuishou68/fix/issue-228-validate-text-type-order
fix: correct isinstance/strip order in _generate() to prevent AttributeError on non-string input
2026-04-10 10:30:15 +08:00
cocoon 4f4a5b9f6c fix: correct type-check order in _generate() to prevent AttributeError on non-string input
The previous guard `not text.strip() or not isinstance(text, str)` called
.strip() before verifying that text is actually a string, causing an
AttributeError (e.g. for int input) instead of the intended ValueError.

Swap operand order so isinstance check short-circuits first.

Closes #228
2026-04-09 16:13:40 +00:00
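The corrected operand order can be shown in isolation. The function name `validate_text` and the error message are illustrative only; the guard expression is the one the commit describes.

```python
def validate_text(text) -> str:
    """Reject non-string or blank input with ValueError."""
    # isinstance must come first: `or` short-circuits, so .strip() is
    # only reached once we know text is a str.
    if not isinstance(text, str) or not text.strip():
        raise ValueError("target text must be a non-empty string")
    return text
```

With the operands in the old order, `validate_text(123)` would raise AttributeError from `.strip()` before the type check ever ran.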
刘鑫 79c0cf68dd chore: remove accidentally committed app_local.py
Made-with: Cursor
2026-04-09 16:05:18 +08:00
刘鑫 75cfa3e9b8 fix: use uncompiled feat_encoder for prefill to prevent CUDA Graph dynamic shape accumulation (#209) 2026-04-09 16:00:17 +08:00
Labmem-Zhouyx 5611bd08a0 optim app.py 2026-04-09 00:30:19 +08:00
Kevin Knoedler 66205135fc perf: stateful streaming VAE decode — eliminate redundant overlap
Streaming decode previously re-decoded 4 overlapping patches through
the VAE each step, discarding 75% of the output. Replace with stateful
decode that carries causal conv padding buffers between calls — one
patch in, one patch out, no overlap.

Changes:
- Add StreamingVAEDecoder to audiovae/audio_vae_v2.py — caches
  CausalConv1d and CausalTransposeConv1d left-pad state between calls
- AudioVAE.streaming_decode() context manager for clean lifecycle
- _inference yields single-patch latents in streaming mode
- _generate and _generate_with_prompt_cache use StreamingVAEDecoder

Streaming VAE decode time (isolated): 289ms → 148ms (2x faster)
Stateful vs full decode: cosine 1.0000, max diff 0.0005
(more accurate than previous overlap approach at max diff 0.001)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-08 09:09:22 -07:00
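The idea of carrying causal-conv padding between calls can be demonstrated on a toy 1-D convolution. This is a pure-Python sketch of the general technique, not the `StreamingVAEDecoder` code; the class and function names are hypothetical.

```python
from typing import List, Sequence


def causal_conv(x: Sequence[float], kernel: Sequence[float]) -> List[float]:
    """Full-sequence causal convolution with zero left-padding."""
    k = len(kernel)
    padded = [0.0] * (k - 1) + list(x)
    return [sum(kernel[j] * padded[t + j] for j in range(k)) for t in range(len(x))]


class StreamingCausalConv:
    """Same convolution, fed chunk by chunk with cached left context."""

    def __init__(self, kernel: Sequence[float]) -> None:
        self.kernel = list(kernel)
        self.state = [0.0] * (len(self.kernel) - 1)  # starts as the zero pad

    def __call__(self, chunk: Sequence[float]) -> List[float]:
        k = len(self.kernel)
        padded = self.state + list(chunk)
        out = [
            sum(self.kernel[j] * padded[t + j] for j in range(k))
            for t in range(len(chunk))
        ]
        # Keep the last k-1 inputs as left context for the next call.
        self.state = padded[len(padded) - (k - 1):]
        return out
```

Because each call only needs the previous `k - 1` inputs, every chunk is processed once and nothing is re-decoded or discarded, which is where the overlap savings come from.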
Labmem-Zhouyx 364eff6840 update readme: python version 2026-04-08 23:07:38 +08:00
Labmem-Zhouyx 6d10932b09 update readme 2026-04-08 18:48:58 +08:00
Labmem-Zhouyx 68af4fe502 fix: ft log and setting 2026-04-08 18:15:17 +08:00
Labmem-Zhouyx ee3649c1b3 fix: streaming decode 2026-04-08 17:25:54 +08:00
Labmem-Zhouyx 82d77d445c fix: decode chunksize for audiovae_v2 2026-04-08 16:31:36 +08:00
Labmem-Zhouyx 8f95d13073 update readme: 30-language asr result on internal benchmark 2026-04-08 15:36:56 +08:00
Labmem-Zhouyx df38f0a167 update readme for modelscope download 2026-04-08 11:29:19 +08:00
Labmem-Zhouyx 9adfaf6996 update demo for zh 2026-04-08 00:15:16 +08:00
刘鑫 46cfce0c97 fix VoxCPM2 training sample_rate: 48000 -> 16000 (match AudioVAE encoder)
Made-with: Cursor
2026-04-07 22:59:18 +08:00
Labmem-Zhouyx da700f264e update ZH readme 2026-04-07 18:04:56 +08:00
Labmem-Zhouyx 9da570d409 remove wechat link 2026-04-07 15:29:12 +08:00
Labmem-Zhouyx 9374524c47 update readme 2026-04-06 23:01:16 +08:00
Labmem-Zhouyx ec6d30e996 update readme 2026-04-06 22:56:06 +08:00
32 changed files with 2454 additions and 491 deletions
+3
@@ -2,3 +2,6 @@ launch.json
__pycache__
voxcpm.egg-info
.DS_Store
./pretrained_models/
app_local.py
models/
+110 -11
@@ -1,5 +1,9 @@
<h2 align="center">VoxCPM2: Tokenizer-Free TTS for Multilingual Speech Generation, Creative Voice Design, and True-to-Life Cloning</h2>
<p align="center">
<b>English</b> | <a href="./README_zh.md">中文</a>
</p>
<p align="center">
<a href="https://github.com/OpenBMB/VoxCPM/"><img src="https://img.shields.io/badge/Project%20Page-GitHub-blue" alt="Project Page"></a>
<a href="https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo"><img src="https://img.shields.io/badge/Live%20Playground-Demo-orange" alt="Live Playground"></a>
@@ -42,16 +46,16 @@ VoxCPM is a **tokenizer-free** Text-to-Speech system that directly generates con
- 🎙️ **Ultimate Cloning** — Reproduce every vocal nuance: provide both reference audio and its transcript, and the model continues seamlessly from the reference, faithfully preserving every vocal detail — timbre, rhythm, emotion, and style (same as VoxCPM1.5)
- 🔊 **48kHz High-Quality Audio** — Accepts 16kHz reference audio and directly outputs 48kHz studio-quality audio via AudioVAE V2's asymmetric encode/decode design, with built-in super-resolution — no external upsampler needed
- 🧠 **Context-Aware Synthesis** — Automatically infers appropriate prosody and expressiveness from text content
- **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-VLLM](https://github.com/a710128/nanovllm-voxcpm)
- **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) or [vLLM-Omni](https://github.com/vllm-project/vllm-omni) — official vLLM omni-modal serving for VoxCPM2 with PagedAttention and an OpenAI-compatible API
- 📜 **Fully Open-Source & Commercial-Ready** — Weights and code released under the [Apache-2.0](LICENSE) license, free for commercial use
<details>
<summary><b>🌍 Supported Languages (30)</b></summary>
<br>
Arabic, Burmese, Chinese, Danish, Dutch, English, Finnish, French, German, Greek, Hebrew, Hindi, Indonesian, Italian, Japanese, Khmer, Korean, Lao, Malay, Norwegian, Polish, Portuguese, Russian, Spanish, Swahili, Swedish, Tagalog, Thai, Turkish, Vietnamese
Chinese Dialect: 四川话, 粤语, 吴语, 东北话, 河南话, 陕西话, 山东话, 天津话, 闽南话
</details>
### News
@@ -88,7 +92,7 @@ Chinese Dialect: 四川话, 粤语, 吴语, 东北话, 河南话, 陕西话, 山
pip install voxcpm
```
> **Requirements:** Python ≥ 3.10, PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
> **Requirements:** Python ≥ 3.10 (<3.13), PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
### Python API
@@ -99,7 +103,7 @@ from voxcpm import VoxCPM
import soundfile as sf
model = VoxCPM.from_pretrained(
"openbmb/VoxCPM2"
"openbmb/VoxCPM2",
load_denoiser=False,
)
@@ -112,6 +116,28 @@ sf.write("demo.wav", wav, model.tts_model.sample_rate)
print("saved: demo.wav")
```
If you prefer downloading from ModelScope first, you can use:
```bash
pip install modelscope
```
```python
from modelscope import snapshot_download
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # specify the local directory to save the model
from voxcpm import VoxCPM
import soundfile as sf
model = VoxCPM.from_pretrained("./pretrained_models/VoxCPM2", load_denoiser=False)
wav = model.generate(
text="VoxCPM2 is the current recommended release for realistic multilingual speech synthesis.",
cfg_value=2.0,
inference_timesteps=10,
)
sf.write("demo.wav", wav, model.tts_model.sample_rate)
```
#### 🎨 Voice Design
Create a voice from a natural-language description, with no reference audio needed. **Format:** put the description in parentheses at the start of `text` (e.g. `"(your voice description)The text to synthesize."`):
@@ -132,13 +158,13 @@ Upload a reference audio. The model clones the timbre, and you can still use con
```python
wav = model.generate(
text="This is a cloned voice generated by VoxCPM2.",
reference_wav_path="speaker.wav",
reference_wav_path="path/to/voice.wav",
)
sf.write("clone.wav", wav, model.tts_model.sample_rate)
wav = model.generate(
text="(slightly faster, cheerful tone)This is a cloned voice with style control.",
reference_wav_path="speaker.wav",
reference_wav_path="path/to/voice.wav",
cfg_value=2.0,
inference_timesteps=10,
)
@@ -152,9 +178,9 @@ Provide both the reference audio and its exact transcript for audio-continuation
```python
wav = model.generate(
text="This is an ultimate cloning demonstration using VoxCPM2.",
prompt_wav_path="speaker_reference.wav",
prompt_wav_path="path/to/voice.wav",
prompt_text="The transcript of the reference audio.",
reference_wav_path="speaker_reference.wav",
reference_wav_path="path/to/voice.wav", # optional, for better similarity
)
sf.write("hifi_clone.wav", wav, model.tts_model.sample_rate)
```
@@ -200,6 +226,7 @@ voxcpm clone \
--text "This is a voice cloning demo." \
--prompt-audio path/to/voice.wav \
--prompt-text "reference transcript" \
--reference-audio path/to/voice.wav \
--output out.wav  # --reference-audio is optional, for better similarity
# Batch processing
@@ -212,7 +239,7 @@ voxcpm --help
### Web Demo
```bash
python app.py # then open http://localhost:7860
python app.py --port 8808 # then open in browser: http://localhost:8808
```
### 🚢 Production Deployment (Nano-vLLM)
@@ -235,6 +262,32 @@ server.stop()
> **RTF as low as ~0.13 on NVIDIA RTX 4090** (vs ~0.3 with the standard PyTorch implementation), with support for batched concurrent requests and a FastAPI HTTP server. See the [Nano-vLLM-VoxCPM repo](https://github.com/a710128/nanovllm-voxcpm) for deployment details.
### 🏭 Production Serving (vLLM-Omni)
For production multi-tenant deployments, use [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — the official vLLM project's omni-modal extension with native **VoxCPM2** support. PagedAttention KV cache, continuous batching, and a drop-in **OpenAI-compatible** `/v1/audio/speech` endpoint.
```bash
# Install from source (latest main — vllm-omni is rapidly evolving)
uv pip install vllm==0.19.0 --torch-backend=auto
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
uv pip install -e .
```
See the [vLLM-Omni installation guide](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/) for other platforms (ROCm, XPU, MUSA, NPU) and Docker images.
```bash
# Launch an OpenAI-compatible TTS server (--omni enables omni-modal serving)
vllm serve openbmb/VoxCPM2 --omni --port 8000
# Call it from any OpenAI client
curl http://localhost:8000/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{"model":"openbmb/VoxCPM2","input":"Hello from VoxCPM2 on vLLM-Omni!","voice":"default"}' \
--output out.wav
```
> Built on the upstream vLLM scheduler, with batched concurrent requests, streaming chunk delivery, and multi-GPU deployment out of the box. See the [VoxCPM2 example](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2) for full deployment recipes.
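The same request can be issued from Python using only the standard library. This is a minimal sketch that mirrors the curl call above; the helper names `build_speech_request` and `synthesize` are hypothetical, and it assumes the server from the previous snippet is listening on `http://localhost:8000`.

```python
import json
import urllib.request


def build_speech_request(
    text: str, base_url: str = "http://localhost:8000"
) -> urllib.request.Request:
    """Build the POST request for the /v1/audio/speech endpoint."""
    payload = {"model": "openbmb/VoxCPM2", "input": text, "voice": "default"}
    return urllib.request.Request(
        f"{base_url}/v1/audio/speech",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def synthesize(text: str, out_path: str = "out.wav") -> None:
    """Send the request and write the returned audio bytes to disk."""
    request = build_speech_request(text)
    with urllib.request.urlopen(request) as response, open(out_path, "wb") as f:
        f.write(response.read())
```

Calling `synthesize("Hello from VoxCPM2 on vLLM-Omni!")` should then write `out.wav`, provided the server is running.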
> **Full parameter reference, multi-scenario examples, and voice cloning tips →** [Quick Start Guide](https://voxcpm.readthedocs.io/en/latest/quickstart.html) | [Usage Guide](https://voxcpm.readthedocs.io/en/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/en/latest/cookbook.html)
---
@@ -388,10 +441,54 @@ VoxCPM2 achieves state-of-the-art or comparable results on public zero-shot and
</details>
### Internal 30-Language ASR Benchmark
We additionally run an internal multilingual intelligibility benchmark with **30 languages × 500 samples**. ASR transcription is evaluated via **Gemini 3.1 Flash Lite API**.
<details>
<summary><b>Internal 30-Language ASR Benchmark (click to expand)</b></summary>
| Language | Metric | VoxCPM2 | Fish S2-Pro |
|---|---:|---:|---:|
| ar (Arabic) | CER | 1.23% | 0.30% |
| da (Danish) | WER | 2.70% | 3.52% |
| de (German) | WER | 0.96% | 0.64% |
| el (Greek) | WER | 3.17% | 4.61% |
| en (English) | WER | 0.42% | 1.03% |
| es (Spanish) | WER | 1.33% | 0.64% |
| fi (Finnish) | WER | 2.24% | 2.80% |
| fr (French) | WER | 2.16% | 2.34% |
| he (Hebrew) | CER | 2.98% | 15.27% |
| hi (Hindi) | CER | 0.79% | 0.91% |
| id (Indonesian) | WER | 1.36% | 1.68% |
| it (Italian) | WER | 1.65% | 1.08% |
| ja (Japanese) | CER | 2.40% | 1.82% |
| km (Khmer) | CER | 2.05% | 75.15% |
| ko (Korean) | CER | 0.95% | 0.29% |
| lo (Lao) | CER | 1.90% | 87.40% |
| ms (Malay) | WER | 1.75% | 1.41% |
| my (Burmese) | CER | 1.42% | 85.27% |
| nl (Dutch) | WER | 1.25% | 1.68% |
| no (Norwegian) | WER | 2.49% | 3.76% |
| pl (Polish) | WER | 1.90% | 1.65% |
| pt (Portuguese) | WER | 1.48% | 1.49% |
| ru (Russian) | WER | 0.90% | 0.86% |
| sv (Swedish) | WER | 2.22% | 2.63% |
| sw (Swahili) | CER | 1.07% | 2.02% |
| th (Thai) | CER | 0.94% | 1.92% |
| tl (Tagalog) | WER | 2.63% | 4.00% |
| tr (Turkish) | WER | 1.65% | 1.65% |
| vi (Vietnamese) | WER | 1.56% | 5.56% |
| zh (Chinese) | CER | 0.92% | 1.02% |
| Average (30 languages) | | **1.68%** | - |
</details>
### InstructTTSEval
<details>
<summary><b>Instruction-Guided Voice Design Results</b></summary>
<summary><b>Instruction-Guided Voice Design Results (click to expand)</b></summary>
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
@@ -457,11 +554,13 @@ Full documentation: **[voxcpm.readthedocs.io](https://voxcpm.readthedocs.io/en/l
| Project | Description |
|---|---|
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | High-throughput and Fast GPU serving |
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | Official vLLM omni-modal serving for VoxCPM2 — PagedAttention, OpenAI-compatible API |
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF: CPU, CUDA, Vulkan inference |
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX export for CPU inference |
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine backend |
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust re-implementation |
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI node-based workflows |
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | Feature-complete ComfyUI workflow for VoxCPM 2 with multi-speaker generation, LoRA, and auto-ASR |
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS extension |
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | Browser-based TTS extension |
+618
@@ -0,0 +1,618 @@
<h2 align="center">VoxCPM2:基于连续表征的多语言语音合成、创意音色设计与高保真声音克隆</h2>
<p align="center">
<a href="./README.md">English</a> | <b>中文</b>
</p>
<p align="center">
<a href="https://github.com/OpenBMB/VoxCPM/"><img src="https://img.shields.io/badge/Project%20Page-GitHub-blue" alt="Project Page"></a>
<a href="https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo"><img src="https://img.shields.io/badge/Live%20Playground-Demo-orange" alt="Live Playground"></a>
<a href="https://voxcpm.readthedocs.io/zh-cn/latest/"><img src="https://img.shields.io/badge/Docs-ReadTheDocs-8CA1AF" alt="Documentation"></a>
<a href="https://huggingface.co/openbmb/VoxCPM2"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-VoxCPM2-yellow" alt="Hugging Face"></a>
<a href="https://modelscope.cn/models/OpenBMB/VoxCPM2"><img src="https://img.shields.io/badge/ModelScope-VoxCPM2-purple" alt="ModelScope"></a>
<a href="https://openbmb.github.io/voxcpm2-demopage/"><img src="https://img.shields.io/badge/DemoPage-Audio%20Samples-red"></a>
</p>
<div align="center">
<img src="assets/voxcpm_logo.png" alt="VoxCPM Logo" width="35%">
<br><br>
<a href="https://trendshift.io/repositories/17704" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17704" alt="OpenBMB%2FVoxCPM | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
<br>
<p align="center">
👋 欢迎加入社区,参与讨论与交流!
<br>
<a href="./assets/feishu-group.png" style="display:inline-block;vertical-align:middle; margin-left: 10px;">
<img src="./assets/feishu-logo.png" width="16" height="16" style="vertical-align:middle;"> 飞书群
</a>
&nbsp;|&nbsp;
<a href="https://discord.gg/KZUx7tVNwz" style="display:inline-block;vertical-align:middle;">
<img src="./assets/discord-logo.png" width="16" height="16" style="vertical-align:middle;"> Discord
</a>
</p>
VoxCPM 是一个**无离散音频分词器**(Tokenizer-Free)的语音合成系统,通过端到端的**扩散自回归架构**直接生成连续语音表征,绕过对音频的离散编码步骤,实现高度自然且富有表现力的语音合成。
**VoxCPM2** 是最新的版本 — 基于 [MiniCPM-4](https://github.com/OpenBMB/MiniCPM) 基座构建,总计 **20亿** 参数,在超过 **200万小时** 的多语种音频数据上训练,支持 **30种全球语言+9种中文方言**、**音色设计**、**可控声音克隆**,原生输出 **48kHz** 高质量音频。
### ✨ 核心特性
- 🌍 **30种语言语音合成** — 直接输入原始文本即可合成(支持语言详见下文),无需额外语言标签
- 🎨 **音色设计** — 用自然语言描述(性别、年龄、音色、情绪、语速……)凭空创建全新音色,无需参考音频
- 🎛️ **可控声音克隆** — 从参考音频片段克隆任意声音,可叠加风格指令控制情绪、语速和表现力,同时保持原始音色
- 🎙️ **极致克隆** — 提供参考音频及其文本内容,模型接着参考音频进行无缝续写,从而精准还原声音细节特征(与 VoxCPM1.5 一致)
- 🔊 **48kHz 高质量音频** — 输入 16kHz 参考音频,通过 AudioVAE V2 的非对称编解码设计直接输出 48kHz 高质量音频,内置超分能力
- 🧠 **语境感知合成** — 根据文本内容自动推断合适的韵律和表现力
- **实时流式合成** — 在 NVIDIA RTX 4090 上 RTF 低至 ~0.3,通过 [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) 或 [vLLM-Omni](https://github.com/vllm-project/vllm-omni)(官方 vLLM 全模态服务,原生支持 VoxCPM2,提供 PagedAttention 与 OpenAI 兼容 API)加速后可达 ~0.13
- 📜 **完全开源,商用就绪** — 权重和代码基于 [Apache-2.0](LICENSE) 协议发布,免费商用
<details>
<summary><b>🌍 支持的语言(30种)</b></summary>
<br>
阿拉伯语、缅甸语、中文、丹麦语、荷兰语、英语、芬兰语、法语、德语、希腊语、希伯来语、印地语、印尼语、意大利语、日语、高棉语、韩语、老挝语、马来语、挪威语、波兰语、葡萄牙语、俄语、西班牙语、斯瓦希里语、瑞典语、菲律宾语、泰语、土耳其语、越南语
中国方言:四川话、粤语、吴语、东北话、河南话、陕西话、山东话、天津话、闽南话
</details>
### 最新动态
* **[2026.04]** 🔥 发布 **VoxCPM2** — 20亿参数,30种语言,音色设计与可控声音克隆,48kHz 音频输出! [模型权重](https://huggingface.co/openbmb/VoxCPM2) | [使用文档](https://voxcpm.readthedocs.io/zh-cn/latest/) | [在线体验](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) | [官网体验](https://voxcpm.modelbest.cn/) (适用国内访问)
* **[2025.12]** 🎉 开源 **VoxCPM1.5** [模型权重](https://huggingface.co/openbmb/VoxCPM1.5),支持 SFT 和 LoRA 微调。(**🏆 GitHub Trending #1**)
* **[2025.09]** 🔥 发布 VoxCPM [技术报告](https://arxiv.org/abs/2509.24650)。
* **[2025.09]** 🎉 开源 **VoxCPM-0.5B** [模型权重](https://huggingface.co/openbmb/VoxCPM-0.5B) (**🏆 HuggingFace Trending #1**)
---
## 目录
- [快速开始](#-快速开始)
- [安装](#安装)
- [Python API](#python-api)
- [命令行使用](#命令行使用)
- [Web Demo](#web-demo)
- [生产部署](#-生产部署nano-vllm)
- [模型与版本](#-模型与版本)
- [性能评测](#-性能评测)
- [微调](#%EF%B8%8F-微调)
- [文档](#-文档)
- [生态与社区](#-生态与社区)
- [风险与局限性](#%EF%B8%8F-风险与局限性)
- [引用](#-引用)
---
## 🚀 快速开始
### 安装
```sh
pip install voxcpm
```
> **环境要求:** Python ≥ 3.10 (<3.13),PyTorch ≥ 2.5.0,CUDA ≥ 12.0。详见 [快速开始文档](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html)。
### Python API
#### 🗣️ 文本转语音
```python
from voxcpm import VoxCPM
import soundfile as sf
model = VoxCPM.from_pretrained(
"openbmb/VoxCPM2",
load_denoiser=False,
)
wav = model.generate(
text="VoxCPM2 是目前推荐使用的多语言语音合成版本。",
cfg_value=2.0,
inference_timesteps=10,
)
sf.write("demo.wav", wav, model.tts_model.sample_rate)
print("已保存: demo.wav")
```
如果你希望先从 ModelScope 下载模型到本地(适用于国内网络访问),可以使用:
```bash
pip install modelscope
```
```python
from modelscope import snapshot_download
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # 指定模型保存的本地路径
from voxcpm import VoxCPM
import soundfile as sf
model = VoxCPM.from_pretrained('./pretrained_models/VoxCPM2', load_denoiser=False)
wav = model.generate(
text="VoxCPM2 是目前推荐使用的多语言语音合成版本。",
cfg_value=2.0,
inference_timesteps=10,
)
sf.write("demo.wav", wav, model.tts_model.sample_rate)
```
#### 🎨 音色设计
用自然语言描述创建全新音色,无需参考音频。**格式:** 在 `text` 开头用括号写入音色描述(如 `"(音色描述)要合成的文本。"`):
```python
wav = model.generate(
text="(年轻女性,声音温柔甜美)你好,欢迎使用VoxCPM2!",
cfg_value=2.0,
inference_timesteps=10,
)
sf.write("voice_design.wav", wav, model.tts_model.sample_rate)
```
#### 🎛️ 可控声音克隆
上传一段参考音频,模型克隆其音色,同时可以使用控制指令调节语速、情绪或风格。
```python
wav = model.generate(
text="这是VoxCPM2生成的克隆语音。",
reference_wav_path="path/to/voice.wav",
)
sf.write("clone.wav", wav, model.tts_model.sample_rate)
wav = model.generate(
text="(稍快一点,欢快的语气)这是带风格控制的克隆语音。",
reference_wav_path="path/to/voice.wav",
cfg_value=2.0,
inference_timesteps=10,
)
sf.write("controllable_clone.wav", wav, model.tts_model.sample_rate)
```
#### 🎙️ 极致克隆
提供参考音频及其精确文本转录,实现基于音频续写的高保真克隆。为获得最高克隆相似度,可将同一音频同时传给 `reference_wav_path` 和 `prompt_wav_path`:
```python
wav = model.generate(
text="这是使用VoxCPM2的极致克隆演示。",
prompt_wav_path="path/to/voice.wav",
prompt_text="参考音频的文本转录。",
reference_wav_path="path/to/voice.wav", # 可选,提升相似度
)
sf.write("hifi_clone.wav", wav, model.tts_model.sample_rate)
```
<details>
<summary><b>🔄 流式 API</b></summary>
```python
import numpy as np
chunks = []
for chunk in model.generate_streaming(
text="使用VoxCPM进行流式语音合成非常简单!",
):
chunks.append(chunk)
wav = np.concatenate(chunks)
sf.write("streaming.wav", wav, model.tts_model.sample_rate)
```
</details>
### 命令行使用
```bash
# 音色设计(无需参考音频)
voxcpm design \
--text "VoxCPM2带来全新语音合成体验。" \
--output out.wav
# 可控声音克隆(带风格控制)
voxcpm design \
--text "VoxCPM2带来全新语音合成体验。" \
--control "年轻女声,温暖温柔,略带微笑" \
--output out.wav
# 声音克隆(参考音频)
voxcpm clone \
--text "这是一个声音克隆的演示。" \
--reference-audio path/to/voice.wav \
--output out.wav
# 极致克隆(提示音频 + 转录文本)
voxcpm clone \
--text "这是一个声音克隆的演示。" \
--prompt-audio path/to/voice.wav \
--prompt-text "参考音频转录文本" \
--reference-audio path/to/voice.wav \
--output out.wav
# 批量处理
voxcpm batch --input examples/input.txt --output-dir outs
# 帮助
voxcpm --help
```
### Web Demo
```bash
python app.py --port 8808 # 然后在浏览器打开 http://localhost:8808
```
### 🚢 生产部署(Nano-vLLM)
如需高吞吐量部署,使用 [**Nano-vLLM-VoxCPM**](https://github.com/a710128/nanovllm-voxcpm) — 基于 Nano-vLLM 构建的专用推理引擎,支持并发请求和异步 API。
```bash
pip install nano-vllm-voxcpm
```
```python
from nanovllm_voxcpm import VoxCPM
import numpy as np, soundfile as sf
server = VoxCPM.from_pretrained(model="/path/to/VoxCPM", devices=[0])
chunks = list(server.generate(target_text="你好,我来自VoxCPM"))
sf.write("out.wav", np.concatenate(chunks), 48000)
server.stop()
```
> **在 NVIDIA RTX 4090 上 RTF 低至 ~0.13**(标准 PyTorch 实现约 ~0.3),支持批量并发请求和 FastAPI HTTP 服务。详见 [Nano-vLLM-VoxCPM 仓库](https://github.com/a710128/nanovllm-voxcpm)。
### 🏭 生产环境部署(vLLM-Omni)
如需生产级多租户部署,使用 [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — 官方 vLLM 项目的全模态扩展,原生支持 **VoxCPM2**。具备 PagedAttention KV 缓存、连续批处理,以及与 OpenAI 完全兼容的 `/v1/audio/speech` 接口。
```bash
# 从源码安装(最新 main 分支 —— vllm-omni 正在快速迭代)
uv pip install vllm==0.19.0 --torch-backend=auto
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
uv pip install -e .
```
其他平台(ROCm、XPU、MUSA、NPU)与 Docker 镜像请参考 [vLLM-Omni 安装文档](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/)。
```bash
# 启动 OpenAI 兼容的 TTS 服务(--omni 启用全模态服务)
vllm serve openbmb/VoxCPM2 --omni --port 8000
# 任意 OpenAI 客户端均可调用
curl http://localhost:8000/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{"model":"openbmb/VoxCPM2","input":"你好,欢迎使用 VoxCPM2 on vLLM-Omni","voice":"default"}' \
--output out.wav
```
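> vLLM-Omni is built on the upstream vLLM scheduler and supports batched concurrency, streamed chunked output, and multi-GPU deployment out of the box. For a complete example, see the [VoxCPM2 deployment sample](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2).
> **Full parameter reference, multi-scenario examples, and voice cloning tips →** [Quick Start Guide](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html) | [Usage Guide](https://voxcpm.readthedocs.io/zh-cn/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/zh-cn/latest/cookbook.html)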
---
## 📦 Models and Versions
| | **VoxCPM2** | **VoxCPM1.5** | **VoxCPM-0.5B** |
|---|:---:|:---:|:---:|
| **Status** | 🟢 Latest | Stable | Legacy |
| **Main model parameters** | 2B | 0.6B | 0.5B |
| **Audio sample rate** | 48kHz | 44.1kHz | 16kHz |
| **LM processing rate** | 6.25Hz | 6.25Hz | 12.5Hz |
| **Supported languages** | 30 | 2 (Chinese, English) | 2 (Chinese, English) |
| **Cloning modes** | Isolated reference audio (no transcript needed) & audio continuation | Audio continuation only | Audio continuation only |
| **Voice design** | ✅ | - | - |
| **Controllable voice cloning** | ✅ | - | - |
| **SFT / LoRA** | ✅ | ✅ | ✅ |
| **RTF (RTX 4090)** | ~0.30 | ~0.15 | ~0.17 |
| **RTF Nano-vLLM (RTX 4090)** | ~0.13 | ~0.08 | ~0.10 |
| **VRAM usage** | ~8 GB | ~6 GB | ~5 GB |
| **Model weights** | [🤗 HF](https://huggingface.co/openbmb/VoxCPM2) / [MS](https://modelscope.cn/models/OpenBMB/VoxCPM2) | [🤗 HF](https://huggingface.co/openbmb/VoxCPM1.5) / [MS](https://modelscope.cn/models/OpenBMB/VoxCPM1.5) | [🤗 HF](https://huggingface.co/openbmb/VoxCPM-0.5B) / [MS](https://modelscope.cn/models/OpenBMB/VoxCPM-0.5B) |
| **Technical report** | Coming soon | - | [arXiv](https://arxiv.org/abs/2509.24650) [ICLR 2026](https://openreview.net/forum?id=h5KLpGoqzC) |
| **Demo page** | [Audio samples](https://openbmb.github.io/voxcpm2-demopage) | - | [Audio samples](https://openbmb.github.io/VoxCPM-demopage) |
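VoxCPM2 follows a **continuous audio representation, diffusion autoregressive** paradigm. Operating in the continuous latent space of **AudioVAE**, the model processes speech in four stages, **LocEnc → TSLM → RALM → LocDiT**, to deliver expressive speech synthesis and native 48kHz audio output.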
<div align="center">
<img src="assets/voxcpm_model.png" alt="VoxCPM2 model architecture" width="90%">
</div>
> For full architecture details, what changed in VoxCPM2, and a model comparison table, see the [architecture design docs](https://voxcpm.readthedocs.io/zh-cn/latest/models/architecture.html).
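To relate two rows of the model table, dividing the audio sample rate by the LM processing rate gives the number of output samples each LM step accounts for. This is a sketch of that arithmetic only, under the assumption that the LM processing rate means LM steps per second of audio; `samples_per_lm_step` is an illustrative helper.

```python
def samples_per_lm_step(sample_rate_hz: float, lm_rate_hz: float) -> float:
    """Number of output audio samples covered by one LM step."""
    return sample_rate_hz / lm_rate_hz


# (audio sample rate in Hz, LM processing rate in Hz), taken from the table.
models = {
    "VoxCPM2": (48_000, 6.25),
    "VoxCPM1.5": (44_100, 6.25),
    "VoxCPM-0.5B": (16_000, 12.5),
}

for name, (sample_rate, lm_rate) in models.items():
    step_ms = 1000 / lm_rate  # audio duration covered by one step
    samples = samples_per_lm_step(sample_rate, lm_rate)
    print(f"{name}: {samples:.0f} samples per step ({step_ms:.0f} ms)")
```

Under that reading, a lower LM rate means fewer autoregressive steps per second of speech, with each step covering a longer stretch of audio.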
---
## 📊 Performance Evaluation
VoxCPM2 achieves state-of-the-art or comparable results on public zero-shot and controllable TTS benchmarks.
### Seed-TTS-eval
<details>
<summary><b>Seed-TTS-eval WER (⬇) & SIM (⬆) results (click to expand)</b></summary>
| Model | Parameters | Open-Source | test-EN | | test-ZH | | test-Hard | |
|------|------|------|:------------:|:--:|:------------:|:--:|:-------------:|:--:|
| | | | WER/%⬇ | SIM/%⬆| CER/%⬇| SIM/%⬆ | CER/%⬇ | SIM/%⬆ |
| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
| F5-TTS | 0.3B | ✅ | 2.00 | 67.0 | 1.53 | 76.0 | 8.67 | 71.3 |
| MaskGCT | 1B | ✅ | 2.62 | 71.7 | 2.27 | 77.4 | - | - |
| CosyVoice | 0.3B | ✅ | 4.29 | 60.9 | 3.63 | 72.3 | 11.75 | 70.9 |
| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | 6.83 | 72.4 |
| SparkTTS | 0.5B | ✅ | 3.14 | 57.3 | 1.54 | 66.0 | - | - |
| FireRedTTS | 0.5B | ✅ | 3.82 | 46.0 | 1.51 | 63.5 | 17.45 | 62.1 |
| FireRedTTS-2 | 1.5B | ✅ | 1.95 | 66.5 | 1.14 | 73.6 | - | - |
| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | 74.7 |
| Qwen3-Omni | 30B-A3B | ✅ | 1.39 | - | 1.07 | - | - | - |
| OpenAudio-s1-mini | 0.5B | ✅ | 1.94 | 55.0 | 1.18 | 68.5 | 23.37 | 64.3 |
| IndexTTS2 | 1.5B | ✅ | 2.23 | 70.6 | 1.03 | 76.5 | 7.12 | 75.5 |
| VibeVoice | 1.5B | ✅ | 3.04 | 68.9 | 1.16 | 74.4 | - | - |
| HiggsAudio-v2 | 3B | ✅ | 2.44 | 67.7 | 1.50 | 74.0 | 55.07 | 65.6 |
| VoxCPM-0.5B | 0.6B | ✅ | 1.85 | 72.9 | 0.93 | 77.2 | 8.87 | 73.0 |
| VoxCPM1.5 | 0.8B | ✅ | 2.12 | 71.4 | 1.18 | 77.0 | 7.74 | 73.1 |
| MOSS-TTS | | ✅ | 1.85 | 73.4 | 1.20 | 78.8 | - | - |
| Qwen3-TTS | 1.7B | ✅ | 1.23 | 71.7 | 1.22 | 77.0 | 6.76 | 74.8 |
| FishAudio S2 | 4B | ✅ | 0.99 | - | 0.54 | - | 5.99 | - |
| LongCat-Audio-DiT | 3.5B | ✅ | 1.50 | 78.6 | 1.09 | 81.8 | 6.04 | 79.7 |
| **VoxCPM2** | 2B | ✅ | 1.84 | 75.3 | 0.97| 79.5| 8.13 | 75.3 |
</details>
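For reference, WER and CER are edit-distance error rates: substitutions, deletions, and insertions divided by the reference length, counted over words or characters. The sketch below is a generic implementation of that definition, not the benchmark's own scoring script, which also involves ASR transcription and text normalization.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance between two token sequences."""
    prev: List[int] = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def wer(reference: str, hypothesis: str) -> float:
    """Word error rate over whitespace-separated tokens."""
    ref = reference.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    return edit_distance(ref, hypothesis.split()) / len(ref)


def cer(reference: str, hypothesis: str) -> float:
    """Character error rate, ignoring spaces."""
    ref = list(reference.replace(" ", ""))
    if not ref:
        raise ValueError("reference must contain at least one character")
    return edit_distance(ref, list(hypothesis.replace(" ", ""))) / len(ref)


print(f"WER = {wer('the cat sat on the mat', 'the cat sit on mat'):.3f}")
print(f"CER = {cer('你好世界', '你好视界'):.3f}")
```

CER is the usual choice for languages without whitespace word boundaries, which is why the Chinese columns report CER rather than WER.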
### CV3-eval
<details>
<summary><b>CV3-eval multilingual WER/CER (⬇) results (click to expand)</b></summary>
| Model | zh | en | hard-zh | hard-en | ja | ko | de | es | fr | it | ru |
|-------|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| CosyVoice2 | 4.08 | 6.32 | 12.58| 11.96| 9.13 | 19.7 |- | - | - | - | - |
| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 10.55 | 7.57 | 5.69 | 6.43 | 4.47 | 11.8 | 10.5 | 6.64 |
| Fish Audio S2 | 2.65 | 2.43 | 9.10 | 4.40 | 3.96 | 2.76 | 2.22 | 2.00 | 6.26 | 2.04 | 2.78 |
| **VoxCPM2** | 3.65 | 5.00 | 8.55 | 8.48 | 5.96 | 5.69 | 4.77 | 3.80 | 9.85 | 4.25 | 5.21 |
</details>
### MiniMax-Multilingual-Test
<details>
<summary><b>Minimax-MLS-test WER (⬇) results (click to expand)</b></summary>
| Language | Minimax | ElevenLabs | Qwen3-TTS | FishAudio S2 | **VoxCPM2** |
|----------|:-------:|:----------:|:--------------------:|:------------:|:-----------:|
| Arabic | **1.665** | 1.666 | | 3.500 | 13.046 |
| Cantonese | 34.111 | 51.513 | | **30.670** | 38.584 |
| Chinese | 2.252 | 16.026 | 0.928 | **0.730** | 1.136 |
| Czech | 3.875 | **2.108** | | 2.840 | 24.132 |
| Dutch | 1.143 | **0.803** | | 0.990 | 0.913 |
| English | 2.164 | 2.339 | **0.934** | 1.620 | 2.289 |
| Finnish | 4.666 | 2.964 | | 3.330 | **2.632** |
| French | 4.099 | 5.216 | **2.858** | 3.050 | 4.534 |
| German | 1.906 | 0.572 | 1.235 | **0.550** | 0.679 |
| Greek | 2.016 | **0.991** | | 5.740 | 2.844 |
| Hindi | 6.962 | **5.827** | | 14.640 | 19.699 |
| Indonesian | 1.237 | **1.059** | | 1.460 | 1.084 |
| Italian | 1.543 | 1.743 | **0.948** | 1.270 | 1.563 |
| Japanese | 3.519 | 10.646 | 3.823 | **2.760** | 4.628 |
| Korean | 1.747 | 1.865 | 1.755 | **1.180** | 1.962 |
| Polish | 1.415 | **0.766** | | 1.260 | 1.141 |
| Portuguese | 1.877 | 1.331 | 1.526 | **1.140** | 1.938 |
| Romanian | 2.878 | **1.347** | | 10.740 | 21.577 |
| Russian | 4.281 | 3.878 | 3.212 | **2.400** | 3.634 |
| Spanish | 1.029 | 1.084 | 1.126 | **0.910** | 1.438 |
| Thai | **2.701** | 73.936 | | 4.230 | 2.961 |
| Turkish | 1.52 | **0.699** | | 0.870 | 0.817 |
| Ukrainian | 1.082 | **0.997** | | 2.300 | 6.316 |
| Vietnamese | **0.88** | 73.415 | | 7.410 | 3.307 |
</details>
<details>
<summary><b>Minimax-MLS-test SIM (⬆) results (click to expand)</b></summary>
| Language | Minimax | ElevenLabs | Qwen3-TTS | FishAudio S2 | **VoxCPM2** |
|----------|:-------:|:----------:|:--------------------:|:------------:|:-----------:|
| Arabic | 73.6 | 70.6 | | 75.0 | **79.1** |
| Cantonese | 77.8 | 67.0 | | 80.5 | **83.5** |
| Chinese | 78.0 | 67.7 | 79.9 | 81.6 | **82.5** |
| Czech | 79.6 | 68.5 | | **79.8** | 78.3 |
| Dutch | 73.8 | 68.0 | | 73.0 | **80.8** |
| English | 75.6 | 61.3 | 77.5 | 79.7 | **85.4** |
| Finnish | 83.5 | 75.9 | | 81.9 | **89.0** |
| French | 62.8 | 53.5 | 62.8 | 69.8 | **73.5** |
| German | 73.3 | 61.4 | 77.5 | 76.7 | **80.3** |
| Greek | 82.6 | 73.3 | | 79.5 | **86.0** |
| Hindi | 81.8 | 73.0 | | 82.1 | **85.6** |
| Indonesian | 72.9 | 66.0 | | 76.3 | **80.0** |
| Italian | 69.9 | 57.9 | 81.7 | 74.7 | **78.0** |
| Japanese | 77.6 | 73.8 | 78.8 | 79.6 | **82.8** |
| Korean | 77.6 | 70.0 | 79.9 | 81.7 | **83.3** |
| Polish | 80.2 | 72.9 | | 81.9 | **88.4** |
| Portuguese | 80.5 | 71.1 | 81.7 | 78.1 | **83.7** |
| Romanian | **80.9** | 69.9 | | 73.3 | 79.7 |
| Russian | 76.1 | 67.6 | 79.2 | 79.0 | **81.1** |
| Spanish | 76.2 | 61.5 | 81.4 | 77.6 | **83.1** |
| Thai | 80.0 | 58.8 | | 78.6 | **84.0** |
| Turkish | 77.9 | 59.6 | | 83.5 | **87.1** |
| Ukrainian | 73.0 | 64.7 | | 74.7 | **79.8** |
| Vietnamese | 74.3 | 36.9 | | 74.0 | **80.6** |
</details>
### Internal 30-Language ASR Benchmark
We also ran an internal multilingual intelligibility evaluation: **30 languages × 500 samples**, with ASR transcription performed by the **Gemini 3.1 Flash Lite API**.
<details>
<summary><b>Internal 30-language evaluation set ASR results (click to expand)</b></summary>
| Language | Metric | VoxCPM2 | Fish S2-Pro |
|---|---:|---:|---:|
| ar (Arabic) | CER | 1.23% | 0.30% |
| da (Danish) | WER | 2.70% | 3.52% |
| de (German) | WER | 0.96% | 0.64% |
| el (Greek) | WER | 3.17% | 4.61% |
| en (English) | WER | 0.42% | 1.03% |
| es (Spanish) | WER | 1.33% | 0.64% |
| fi (Finnish) | WER | 2.24% | 2.80% |
| fr (French) | WER | 2.16% | 2.34% |
| he (Hebrew) | CER | 2.98% | 15.27% |
| hi (Hindi) | CER | 0.79% | 0.91% |
| id (Indonesian) | WER | 1.36% | 1.68% |
| it (Italian) | WER | 1.65% | 1.08% |
| ja (Japanese) | CER | 2.40% | 1.82% |
| km (Khmer) | CER | 2.05% | 75.15% |
| ko (Korean) | CER | 0.95% | 0.29% |
| lo (Lao) | CER | 1.90% | 87.40% |
| ms (Malay) | WER | 1.75% | 1.41% |
| my (Burmese) | CER | 1.42% | 85.27% |
| nl (Dutch) | WER | 1.25% | 1.68% |
| no (Norwegian) | WER | 2.49% | 3.76% |
| pl (Polish) | WER | 1.90% | 1.65% |
| pt (Portuguese) | WER | 1.48% | 1.49% |
| ru (Russian) | WER | 0.90% | 0.86% |
| sv (Swedish) | WER | 2.22% | 2.63% |
| sw (Swahili) | CER | 1.07% | 2.02% |
| th (Thai) | CER | 0.94% | 1.92% |
| tl (Filipino) | WER | 2.63% | 4.00% |
| tr (Turkish) | WER | 1.65% | 1.65% |
| vi (Vietnamese) | WER | 1.56% | 5.56% |
| zh (Chinese) | CER | 0.92% | 1.02% |
| Average (30 languages) | | **1.68%** | - |
</details>
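The 30-language average in the last row is consistent with an unweighted mean of the per-language VoxCPM2 error rates, mixing WER and CER as listed. A quick check in Python, with the values copied from the table:

```python
# VoxCPM2 error rates (percent) per language code, copied from the table.
voxcpm2_error_rates = {
    "ar": 1.23, "da": 2.70, "de": 0.96, "el": 3.17, "en": 0.42, "es": 1.33,
    "fi": 2.24, "fr": 2.16, "he": 2.98, "hi": 0.79, "id": 1.36, "it": 1.65,
    "ja": 2.40, "km": 2.05, "ko": 0.95, "lo": 1.90, "ms": 1.75, "my": 1.42,
    "nl": 1.25, "no": 2.49, "pl": 1.90, "pt": 1.48, "ru": 0.90, "sv": 2.22,
    "sw": 1.07, "th": 0.94, "tl": 2.63, "tr": 1.65, "vi": 1.56, "zh": 0.92,
}

# Unweighted (macro) mean: every language counts equally.
average = sum(voxcpm2_error_rates.values()) / len(voxcpm2_error_rates)
print(f"{len(voxcpm2_error_rates)} languages, mean error rate = {average:.2f}%")
```

Because each language contributes equally, a single weak language shifts the average by at most one thirtieth of its error rate.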
### InstructTTSEval
<details>
<summary><b>Instruction-driven voice design results (click to expand)</b></summary>
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
| | APS⬆| DSD⬆ | RP⬆| APS⬆ | DSD⬆ | RP⬆ |
| Hume | | | | 83.0 | 75.3 | 54.3 |
| VoxInstruct | 47.5 | 52.3 | 42.6 | 54.9 | 57.0 | 39.3 |
| Parler-tts-mini | | | | 63.4 | 48.7 | 28.6 |
| Parler-tts-large | | | | 60.0 | 45.9 | 31.2 |
| PromptTTS | | | | 64.3 | 47.2 | 31.4 |
| PromptStyle | | | | 57.4 | 46.4 | 30.9 |
| VoiceSculptor | 75.7 | 64.7 | 61.5 | | | |
| Mimo-Audio-7B-Instruct | 75.7 | 74.3 | 61.5 | 80.6 | 77.6 | 59.5 |
| Qwen3TTS-12Hz-1.7B-VD | **85.2** | **81.1** | **65.1** | 82.9 | 82.4 | 68.4 |
| **VoxCPM2** | **85.2** | 71.5 | 60.8 | **84.2** | **83.2** | **71.4** |
</details>
---
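## ⚙️ Fine-tuning
VoxCPM supports both **full-parameter fine-tuning (SFT)** and **LoRA fine-tuning**. As little as **5-10 minutes** of audio is enough to adapt the model to a specific speaker, language, or domain.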
```bash
# LoRA fine-tuning (parameter-efficient, recommended)
python scripts/train_voxcpm_finetune.py \
--config_path conf/voxcpm_v2/voxcpm_finetune_lora.yaml
# Full-parameter fine-tuning
python scripts/train_voxcpm_finetune.py \
--config_path conf/voxcpm_v2/voxcpm_finetune_all.yaml
# WebUI for training and inference
python lora_ft_webui.py # then open http://localhost:7860
```
> **Full guide →** [Fine-tuning docs](https://voxcpm.readthedocs.io/zh-cn/latest/finetuning/finetune.html) (data preparation, configuration, training, LoRA hot-swapping, FAQ)
---
## 📚 Documentation
Full documentation: **[voxcpm.readthedocs.io](https://voxcpm.readthedocs.io/zh-cn/latest/)**
| Topic | Link |
|---|---|
| Quick start and installation | [Quick Start](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html) |
| Usage guide and Cookbook | [Usage Guide](https://voxcpm.readthedocs.io/zh-cn/latest/usage_guide.html) |
| VoxCPM model family | [Model List](https://voxcpm.readthedocs.io/zh-cn/latest/models/version_history.html) |
| Fine-tuning (SFT & LoRA) | [Fine-tuning Guide](https://voxcpm.readthedocs.io/zh-cn/latest/finetuning/finetune.html) |
| FAQ | [FAQ](https://voxcpm.readthedocs.io/zh-cn/latest/faq.html) |
---
## 🌟 Ecosystem and Community
| Project | Description |
|---|---|
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | High-throughput, fast GPU inference engine |
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | Official vLLM omni-modal serving (native VoxCPM2 support): PagedAttention, OpenAI-compatible API |
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF: CPU, CUDA, and Vulkan inference |
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX export with CPU inference support |
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine backend |
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust reimplementation |
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI node workflow |
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | A more full-featured ComfyUI workflow for VoxCPM 2, with multi-speaker, LoRA, and automatic ASR support |
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS extension |
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | Browser-based TTS extension |
> See the [docs](https://voxcpm.readthedocs.io/zh-cn/latest/) for the full ecosystem. Community projects are not officially maintained by OpenBMB. Built something interesting? [Open an Issue or PR](https://github.com/OpenBMB/VoxCPM/issues) to get it added!
---
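## ⚠️ Risks and Limitations
- **Misuse risk:** VoxCPM's voice cloning can produce highly realistic synthetic speech. Using VoxCPM for impersonation, fraud, or disinformation is **strictly prohibited**. We strongly recommend clearly labeling all AI-generated content.
- **Controllable generation stability:** Results from voice design and controllable voice cloning can vary between runs; we suggest generating 1 to 3 times to get the desired timbre or style. We are actively working on more consistent controllability.
- **Language coverage:** VoxCPM2 officially supports 30 languages. For languages not on the list, you are welcome to test directly or fine-tune on your own data. We plan to expand language coverage in future releases.
- **Usage notes:** This model is released under the Apache-2.0 license. For production deployment, we recommend thorough testing and safety evaluation for your specific use case.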
---
## 📖 Citation
If VoxCPM is helpful to you, please consider citing our work and starring the repository ⭐!
```bib
@article{voxcpm2_2026,
title = {VoxCPM2: Tokenizer-Free TTS for Multilingual Speech Generation, Creative Voice Design, and True-to-Life Cloning},
author = {VoxCPM Team},
journal = {GitHub},
year = {2026},
}
@article{voxcpm2025,
title = {VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation
and True-to-Life Voice Cloning},
author = {Zhou, Yixuan and Zeng, Guoyang and Liu, Xin and Li, Xiang and
Yu, Renjie and Wang, Ziyang and Ye, Runchuan and Sun, Weiyue and
Gui, Jiancheng and Li, Kehan and Wu, Zhiyong and Liu, Zhiyuan},
journal = {arXiv preprint arXiv:2509.24650},
year = {2025},
}
```
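## 📄 License
The VoxCPM model weights and code are open-sourced under the [Apache-2.0](LICENSE) license.
## 🙏 Acknowledgments
- [DiTAR](https://arxiv.org/abs/2502.03930) for the diffusion autoregressive backbone
- [MiniCPM-4](https://github.com/OpenBMB/MiniCPM) for the language model foundation
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for the Flow Matching-based LocDiT implementation
- [DAC](https://github.com/descriptinc/descript-audio-codec) for the Audio VAE backbone
- Thanks to all community users who try VoxCPM, report issues, share ideas, and contribute; your support keeps the project improving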
## Institutions
<p>
<a href="https://modelbest.cn/"><img src="assets/modelbest_logo.png" width="28px"> ModelBest</a>
&nbsp;&nbsp;&nbsp;
<a href="https://github.com/thuhcsi"><img src="assets/thuhcsi_logo.png" width="28px"> Tsinghua University Human-Computer Speech Interaction Lab (THUHCSI)</a>
</p>
## ⭐ Star History
[![Star History Chart](https://api.star-history.com/svg?repos=OpenBMB/VoxCPM&type=Date)](https://star-history.com/#OpenBMB/VoxCPM&Date)
+21 -37
@@ -1,4 +1,5 @@
import os
import re
import sys
import logging
import numpy as np
@@ -9,8 +10,6 @@ from funasr import AutoModel
from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
import voxcpm
@@ -55,7 +54,7 @@ _EXAMPLES_FOOTER_EN = (
)
_USAGE_INSTRUCTIONS_ZH = (
"**VoxCPM2 — 三种语音生成方式:**\n\n"
"**三种语音生成方式:**\n\n"
"🎨 **声音设计(Voice Design)** \n"
"无需参考音频。在 **Control Instruction** 中描述目标音色特征"
"(性别、年龄、语气、情绪、语速等),VoxCPM2 即可为你从零创造独一无二的声音。\n\n"
@@ -66,6 +65,8 @@ _USAGE_INSTRUCTIONS_ZH = (
"开启 **极致克隆模式** 并提供参考音频的文字内容(可自动识别)。"
"模型会将参考音频视为已说出的前文,以**音频续写**的方式完整还原参考音频中的所有声音细节。"
"注意:该模式与可控克隆模式互斥,将禁用Control Instruction。\n\n"
"目前支持的方言包括:\n"
"「四川话、粤语、吴语、东北话、河南话、陕西话、山东话、天津话、闽南话」"
)
_EXAMPLES_FOOTER_ZH = (
@@ -221,11 +222,11 @@ _APP_THEME = gr.themes.Soft(
# ---------- Model ----------
class VoxCPMDemo:
def __init__(self, model_dir: Optional[str] = None) -> None:
def __init__(self, model_id: str = "openbmb/VoxCPM2") -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Running on device: {self.device}")
logger.info(f"运行在设备上: {self.device}")
self.asr_model_id = "iic/SenseVoiceSmall"
self.asr_model_id = "./models/iic/SenseVoiceSmall"
self.asr_model: Optional[AutoModel] = AutoModel(
model=self.asr_model_id,
disable_update=True,
@@ -234,36 +235,13 @@ class VoxCPMDemo:
)
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.explicit_model_dir = model_dir
def _resolve_model_dir(self) -> str:
if self.explicit_model_dir and os.path.isdir(self.explicit_model_dir):
return self.explicit_model_dir
env_model_dir = os.environ.get("VOXCPM_MODEL_DIR", "").strip()
if env_model_dir and os.path.isdir(env_model_dir):
return env_model_dir
repo_id = os.environ.get("HF_REPO_ID", "").strip()
if len(repo_id) > 0:
target_dir = os.path.join("models", repo_id.replace("/", "__"))
if not os.path.isdir(target_dir):
try:
from huggingface_hub import snapshot_download
os.makedirs(target_dir, exist_ok=True)
logger.info(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
except Exception as e:
logger.warning(f"HF download failed: {e}. Falling back to 'models'.")
return "models"
return target_dir
return "models"
self._model_id = model_id
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
if self.voxcpm_model is not None:
return self.voxcpm_model
logger.info("Model not loaded, initializing...")
model_dir = self._resolve_model_dir()
logger.info(f"Using model dir: {model_dir}")
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=True)
logger.info(f"Loading model: {self._model_id}")
self.voxcpm_model = voxcpm.VoxCPM.from_pretrained(self._model_id, optimize=True,zipenhancer_model_id="./models/iic/speech_zipenhancer_ans_multiloss_16k_base")
logger.info("Model loaded successfully.")
return self.voxcpm_model
@@ -315,6 +293,9 @@ class VoxCPMDemo:
raise ValueError("Please input text to synthesize.")
control = (control_instruction or "").strip()
# Strip any parentheses (half-width/full-width) from control text to avoid
# breaking the "(control)text" prompt format expected by the model.
control = re.sub(r"[()()]", "", control).strip()
final_text = f"({control}){text}" if control else text
audio_path = reference_wav_path_input if reference_wav_path_input else None
@@ -507,9 +488,9 @@ def run_demo(
server_name: str = "0.0.0.0",
server_port: int = 8808,
show_error: bool = True,
model_dir: Optional[str] = None,
model_id: str = "./models/OpenBMB/VoxCPM2",
):
demo = VoxCPMDemo(model_dir=model_dir)
demo = VoxCPMDemo(model_id=model_id)
interface = create_demo_interface(demo)
interface.queue(max_size=10, default_concurrency_limit=1).launch(
server_name=server_name,
@@ -524,7 +505,10 @@ def run_demo(
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model-dir", type=str, default=None, help="Path to VoxCPM2 checkpoint directory")
parser.add_argument("--port", type=int, default=8808, help="Server port")
parser.add_argument(
"--model-id", type=str, default="./models/OpenBMB/VoxCPM2",
help="本地路径或HuggingFace仓库ID(默认:./models/openbmb/VoxCPM2)",
)
parser.add_argument("--port", type=int, default=8808, help="服务端口")
args = parser.parse_args()
run_demo(model_dir=args.model_dir, server_port=args.port)
run_demo(model_id=args.model_id, server_port=args.port)
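The control-instruction handling in this diff can be exercised on its own. The sketch below restates that logic as a standalone function; `build_prompt` is a name chosen here for illustration, and the character class assumes the intent stated in the code comment, namely that both half-width and full-width parentheses are stripped.

```python
import re
from typing import Optional

# Half-width and full-width parentheses that would break the "(control)text" format.
_PARENS = re.compile(r"[()()]")


def build_prompt(text: str, control_instruction: Optional[str] = None) -> str:
    """Return "(control)text" when a control instruction is given, else just the text."""
    text = (text or "").strip()
    if not text:
        raise ValueError("Please input text to synthesize.")
    control = _PARENS.sub("", (control_instruction or "").strip()).strip()
    return f"({control}){text}" if control else text


print(build_prompt("Hello there.", "a calm (male) voice"))
print(build_prompt("你好", "(温柔的女声)"))
print(build_prompt("Hello there.", "()"))
```

Stripping the parentheses first means an instruction that consists only of brackets collapses to an empty string, so the text is passed through without a control prefix.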
-280
@@ -1,280 +0,0 @@
import os
import sys
import numpy as np
import torch
import gradio as gr
from typing import Optional, Tuple
from funasr import AutoModel
from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
import voxcpm
class VoxCPMDemo:
def __init__(self) -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {self.device}", file=sys.stderr)
# ASR model for prompt text recognition
self.asr_model_id = "iic/SenseVoiceSmall"
self.asr_model: Optional[AutoModel] = AutoModel(
model=self.asr_model_id,
disable_update=True,
log_level='DEBUG',
device="cuda:0" if self.device == "cuda" else "cpu",
)
# TTS model (lazy init)
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.default_local_model_dir = "./models/VoxCPM1.5"
# ---------- Model helpers ----------
def _resolve_model_dir(self) -> str:
"""
Resolve model directory:
1) Use local checkpoint directory if exists
2) If HF_REPO_ID env is set, download into models/{repo}
3) Fallback to 'models'
"""
if os.path.isdir(self.default_local_model_dir):
return self.default_local_model_dir
repo_id = os.environ.get("HF_REPO_ID", "").strip()
if len(repo_id) > 0:
target_dir = os.path.join("models", repo_id.replace("/", "__"))
if not os.path.isdir(target_dir):
try:
from huggingface_hub import snapshot_download # type: ignore
os.makedirs(target_dir, exist_ok=True)
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
except Exception as e:
print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr)
return "models"
return target_dir
return "models"
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
if self.voxcpm_model is not None:
return self.voxcpm_model
print("Model not loaded, initializing...", file=sys.stderr)
model_dir = self._resolve_model_dir()
print(f"Using model dir: {model_dir}", file=sys.stderr)
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
print("Model loaded successfully.", file=sys.stderr)
return self.voxcpm_model
# ---------- Functional endpoints ----------
def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
if prompt_wav is None:
return ""
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
text = res[0]["text"].split('|>')[-1]
return text
def generate_tts_audio(
self,
text_input: str,
prompt_wav_path_input: Optional[str] = None,
prompt_text_input: Optional[str] = None,
cfg_value_input: float = 2.0,
inference_timesteps_input: int = 10,
do_normalize: bool = True,
denoise: bool = True,
) -> Tuple[int, np.ndarray]:
"""
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
Returns (sample_rate, waveform_numpy)
"""
current_model = self.get_or_load_voxcpm()
text = (text_input or "").strip()
if len(text) == 0:
raise ValueError("Please input text to synthesize.")
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
prompt_text = prompt_text_input if prompt_text_input else None
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
wav = current_model.generate(
text=text,
prompt_text=prompt_text,
prompt_wav_path=prompt_wav_path,
cfg_value=float(cfg_value_input),
inference_timesteps=int(inference_timesteps_input),
normalize=do_normalize,
denoise=denoise,
)
return (current_model.tts_model.sample_rate, wav)
# ---------- UI Builders ----------
_APP_THEME = gr.themes.Soft(
primary_hue="blue",
secondary_hue="gray",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
)
_CUSTOM_CSS = """
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 80px;
width: auto;
max-width: 200px;
display: inline-block;
}
/* Bold accordion labels */
#acc_quick details > summary,
#acc_tips details > summary {
font-weight: 600 !important;
font-size: 1.1em !important;
}
/* Bold labels for specific checkboxes */
#chk_denoise label,
#chk_denoise span,
#chk_normalize label,
#chk_normalize span {
font-weight: 600;
}
"""
def create_demo_interface(demo: VoxCPMDemo):
"""Build the Gradio UI for VoxCPM demo."""
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
with gr.Blocks() as interface:
# Header logo
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
# Quick Start
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
gr.Markdown("""
### How to Use |使用说明
1. **(Optional) Provide a Voice Prompt** - Upload or record an audio clip to provide the desired voice characteristics for synthesis.
**(可选)提供参考声音** - 上传或录制一段音频,为声音合成提供音色、语调和情感等个性化特征
2. **(Optional) Enter prompt text** - If you provided a voice prompt, enter the corresponding transcript here (auto-recognition available).
**(可选项)输入参考文本** - 如果提供了参考语音,请输入其对应的文本内容(支持自动识别)。
3. **Enter target text** - Type the text you want the model to speak.
**输入目标文本** - 输入您希望模型朗读的文字内容。
4. **Generate Speech** - Click the "Generate" button to create your audio.
**生成语音** - 点击"生成"按钮,即可为您创造出音频。
""")
# Pro Tips
with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
gr.Markdown("""
### Prompt Speech Enhancement|参考语音降噪
- **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
**启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz,限制克隆上限。
- **Disable** to preserve the original audio's all information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
**禁用**:保留原始音频的全部信息,包括背景环境声,最高支持44.1kHz的音频复刻。
### Text Normalization|文本正则化
- **Enable** to process general text with an external WeTextProcessing component.
**启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理。
- **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input (For Chinese, phonemes are converted using pinyin, {ni3}{hao3}; For English, phonemes are converted using CMUDict, {HH AH0 L OW1}), try it!
**禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3};英文转CMUDict:{HH AH0 L OW1})和公式符号合成,尝试一下!
### CFG Value|CFG 值
- **Lower CFG** if the voice prompt sounds strained or expressive, or instability occurs with long text input.
**调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题。
- **Higher CFG** for better adherence to the prompt speech style or input text, or instability occurs with too short text input.
**调高**:为更好地贴合提示音频的风格或输入文本, 或者极短文本输入出现稳定性问题。
### Inference Timesteps|推理时间步
- **Lower** for faster synthesis speed.
**调低**:合成速度更快。
- **Higher** for better synthesis quality.
**调高**:合成质量更佳。
""")
# Main controls
with gr.Row():
with gr.Column():
prompt_wav = gr.Audio(
sources=["upload", 'microphone'],
type="filepath",
label="Prompt Speech (Optional, or let VoxCPM improvise)",
value="./examples/example.wav",
)
DoDenoisePromptAudio = gr.Checkbox(
value=False,
label="Prompt Speech Enhancement",
elem_id="chk_denoise",
info="We use ZipEnhancer model to denoise the prompt audio."
)
with gr.Row():
prompt_text = gr.Textbox(
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
label="Prompt Text",
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
)
run_btn = gr.Button("Generate Speech", variant="primary")
with gr.Column():
cfg_value = gr.Slider(
minimum=1.0,
maximum=3.0,
value=2.0,
step=0.1,
label="CFG Value (Guidance Scale)",
info="Higher values increase adherence to prompt, lower values allow more creativity"
)
inference_timesteps = gr.Slider(
minimum=4,
maximum=30,
value=10,
step=1,
label="Inference Timesteps",
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
)
with gr.Row():
text = gr.Textbox(
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
label="Target Text",
)
with gr.Row():
DoNormalizeText = gr.Checkbox(
value=False,
label="Text Normalization",
elem_id="chk_normalize",
info="We use wetext library to normalize the input text."
)
audio_output = gr.Audio(label="Output Audio")
# Wiring
run_btn.click(
fn=demo.generate_tts_audio,
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
outputs=[audio_output],
show_progress=True,
api_name="generate",
)
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
return interface
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
demo = VoxCPMDemo()
interface = create_demo_interface(demo)
interface.queue(max_size=10, default_concurrency_limit=1).launch(
server_name=server_name,
server_port=server_port,
show_error=show_error,
theme=_APP_THEME,
css=_CUSTOM_CSS,
)
if __name__ == "__main__":
run_demo()
BIN
Binary file not shown.

Size before the change: 9.5 KiB

+3 -1
@@ -1,7 +1,8 @@
pretrained_path: /path/to/VoxCPM2/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 48000
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
batch_size: 2
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
num_workers: 8
@@ -14,6 +15,7 @@ weight_decay: 0.01
warmup_steps: 100
max_steps: 1000
max_batch_tokens: 8192
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
save_path: /path/to/checkpoints/finetune_all
tensorboard: /path/to/logs/finetune_all
lambdas:
+3 -1
@@ -1,7 +1,8 @@
pretrained_path: /path/to/VoxCPM2/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 48000
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
batch_size: 2
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
num_workers: 8
@@ -14,6 +15,7 @@ weight_decay: 0.01
warmup_steps: 100
max_steps: 1000
max_batch_tokens: 8192
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
save_path: /path/to/checkpoints/finetune_lora
tensorboard: /path/to/logs/finetune_lora
lambdas:
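Two of the settings in these configs are easy to reason about in isolation: the effective batch size is `batch_size × grad_accum_steps`, and `max_grad_norm` caps the global L2 norm of the gradients, with `0` disabling clipping. The sketch below illustrates both in plain Python on a flat list of numbers; the function names are illustrative and it makes no claim about how the training script implements either setting.

```python
import math
from typing import List


def effective_batch_size(batch_size: int, grad_accum_steps: int) -> int:
    """Samples contributing to one optimizer step."""
    return batch_size * grad_accum_steps


def clip_by_global_norm(grads: List[float], max_grad_norm: float) -> List[float]:
    """Rescale grads so their L2 norm is at most max_grad_norm; 0 disables clipping."""
    if max_grad_norm <= 0:
        return list(grads)
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_grad_norm:
        return list(grads)
    scale = max_grad_norm / norm
    return [g * scale for g in grads]


print(effective_batch_size(batch_size=2, grad_accum_steps=8))
clipped = clip_by_global_norm([3.0, 4.0], max_grad_norm=1.0)
print(math.sqrt(sum(g * g for g in clipped)))
```

Clipping rescales all gradients by the same factor, so the update direction is preserved and only its magnitude is bounded.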
+99 -12
@@ -14,8 +14,10 @@ from typing import Optional
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root / "src"))
# Default pretrained model path relative to this repo
default_pretrained_path = str(project_root / "models" / "openbmb__VoxCPM1.5")
# Default pretrained model path: prefer VoxCPM2 if it exists, fallback to VoxCPM1.5
_v2_path = project_root / "models" / "openbmb__VoxCPM2"
_v15_path = project_root / "models" / "openbmb__VoxCPM1.5"
default_pretrained_path = str(_v2_path if _v2_path.exists() else _v15_path)
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
@@ -99,6 +101,24 @@ def get_timestamp_str():
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def detect_sample_rate(pretrained_path: str) -> Optional[int]:
"""Read audio_vae_config.sample_rate from the model's config.json.
This is the AudioVAE *encoder* input rate, which is the correct rate for
resampling training data. Returns None when detection fails.
"""
config_file = os.path.join(pretrained_path, "config.json")
if not os.path.isfile(config_file):
return None
try:
with open(config_file, "r", encoding="utf-8") as f:
cfg = json.load(f)
return int(cfg["audio_vae_config"]["sample_rate"])
except (KeyError, ValueError, json.JSONDecodeError) as e:
print(f"Warning: failed to detect sample_rate from {config_file}: {e}", file=sys.stderr)
return None
def get_or_load_asr_model():
global asr_model
if asr_model is None:
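A standalone way to try the detection logic added in this hunk is to point it at a throwaway `config.json`. The sketch below is a variant written for illustration rather than the repository's function: `read_audio_vae_rates` is a hypothetical helper that returns both `sample_rate` and `out_sample_rate`, and the 16000/48000 values simply mirror the YAML configs above.

```python
import json
import os
import tempfile
from typing import Optional, Tuple


def _as_int(value) -> Optional[int]:
    """Convert to int, returning None for missing or malformed values."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def read_audio_vae_rates(pretrained_path: str) -> Tuple[Optional[int], Optional[int]]:
    """Return (encoder input rate, decoder output rate) from config.json, or None where unavailable."""
    config_file = os.path.join(pretrained_path, "config.json")
    try:
        with open(config_file, "r", encoding="utf-8") as f:
            cfg = json.load(f)
    except (OSError, ValueError):  # missing file, unreadable file, or invalid JSON
        return None, None
    vae_cfg = cfg.get("audio_vae_config") if isinstance(cfg, dict) else None
    if not isinstance(vae_cfg, dict):
        return None, None
    return _as_int(vae_cfg.get("sample_rate")), _as_int(vae_cfg.get("out_sample_rate"))


with tempfile.TemporaryDirectory() as model_dir:
    with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"audio_vae_config": {"sample_rate": 16000, "out_sample_rate": 48000}}, f)
    rates = read_audio_vae_rates(model_dir)
    missing = read_audio_vae_rates(os.path.join(model_dir, "does_not_exist"))

print(rates, missing)
```

Returning `None` instead of raising lets the caller fall back to the user-supplied value, which is the behavior the WebUI relies on when detection fails.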
@@ -261,27 +281,48 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
# 加载模型
lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None
try:
print(f"Loading base model: {base_model_path}", file=sys.stderr)
load_model(base_model_path)
if lora_selection and lora_selection != "None":
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
load_model(base_model_path, lora_to_load)
if lora_to_load:
print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr)
except Exception as e:
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
print(error_msg, file=sys.stderr)
return None, error_msg
lora_just_loaded = lora_to_load
else:
lora_just_loaded = None
# Handle LoRA hot-swapping
assert current_model is not None, "Model must be loaded before inference"
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
if lora_just_loaded != lora_selection:
new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path)
current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None
new_r = new_lora_config.r if new_lora_config else None
if new_r is not None and current_r is not None and new_r != current_r:
print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr)
reload_base = (
new_base_model if new_base_model and os.path.exists(new_base_model)
else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path)
)
try:
load_model(reload_base, lora_selection)
except Exception as e:
return None, f"Failed to reload model for LoRA rank change: {e}"
else:
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
try:
current_model.load_lora(full_lora_path)
current_model.set_lora_enabled(True)
except Exception as e:
print(f"Error loading LoRA: {e}", file=sys.stderr)
return None, f"Error loading LoRA: {e}"
current_model.set_lora_enabled(True)
else:
print("Disabling LoRA", file=sys.stderr)
current_model.set_lora_enabled(False)
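The branch above boils down to a small decision per request: disable LoRA, keep the adapter that was just loaded with the model, hot-swap the weights, or reload the whole model when the checkpoint's rank differs from the rank the model was built with. A minimal sketch of that decision as a pure function follows; `plan_lora_action` and its string results are illustrative names, not part of `lora_ft_webui.py`.

```python
from typing import Optional


def plan_lora_action(
    lora_selection: Optional[str],
    just_loaded: Optional[str],
    current_r: Optional[int],
    new_r: Optional[int],
) -> str:
    """Decide how to bring the loaded model in line with the selected LoRA checkpoint."""
    if not lora_selection or lora_selection == "None":
        return "disable"   # no adapter requested
    if just_loaded == lora_selection:
        return "keep"      # model was just initialized with this adapter
    if current_r is not None and new_r is not None and current_r != new_r:
        return "reload"    # tensor shapes differ, so hot-swapping would fail
    return "hot_swap"      # same rank (or unknown): load weights in place


print(plan_lora_action("speaker_a", None, current_r=32, new_r=16))
print(plan_lora_action("speaker_b", None, current_r=32, new_r=32))
print(plan_lora_action("None", None, current_r=32, new_r=None))
```

Separating the decision from the side effects makes the rank-mismatch case easy to unit test without loading a model.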
@@ -350,6 +391,7 @@ def start_training(
warmup_steps=100,
max_steps=None,
sample_rate=44100,
max_grad_norm=1.0,
# LoRA advanced
enable_lm=True,
enable_dit=True,
@@ -377,15 +419,39 @@ def start_training(
os.makedirs(checkpoints_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)
# Auto-detect sample_rate from model config.json to prevent mismatch
detected_sr = detect_sample_rate(pretrained_path)
if detected_sr is not None:
if int(sample_rate) != detected_sr:
training_log += (
f"[Auto-fix] sample_rate changed from {int(sample_rate)} to {detected_sr} "
f"(read from {pretrained_path}/config.json audio_vae_config.sample_rate)\n"
)
sample_rate = detected_sr
# Create config dictionary
# Resolve max_steps default
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
# Auto-detect out_sample_rate from model config
out_sample_rate = 0
config_file = os.path.join(pretrained_path, "config.json")
if os.path.isfile(config_file):
try:
with open(config_file, "r", encoding="utf-8") as f:
cfg = json.load(f)
out_sr = cfg.get("audio_vae_config", {}).get("out_sample_rate")
if out_sr:
out_sample_rate = int(out_sr)
except Exception:
pass
config = {
"pretrained_path": pretrained_path,
"train_manifest": train_manifest,
"val_manifest": val_manifest,
"sample_rate": int(sample_rate),
"out_sample_rate": out_sample_rate,
"batch_size": int(batch_size),
"grad_accum_steps": int(grad_accum_steps),
"num_workers": int(num_workers),
@@ -397,6 +463,7 @@ def start_training(
"weight_decay": float(weight_decay),
"warmup_steps": int(warmup_steps),
"max_steps": resolved_max_steps,
"max_grad_norm": float(max_grad_norm),
"save_path": checkpoints_dir,
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
@@ -904,17 +971,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
with gr.Row():
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
max_grad_norm = gr.Number(label="梯度裁剪 (max_grad_norm, 0=关闭)", value=1.0)
with gr.Row():
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
with gr.Row():
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
dropout = gr.Number(label="LoRA Dropout", value=0.0)
gr.Markdown("#### 分发选项 (Distribution)")
with gr.Row():
hf_model_id = gr.Textbox(
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
label="HuggingFace Model ID (e.g., openbmb/VoxCPM2)", value=""
)
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
@@ -929,6 +998,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
show_label=False,
)
def on_pretrained_path_change(path):
"""Auto-detect sample_rate when pretrained model path changes."""
sr = detect_sample_rate(path)
if sr is not None:
return gr.update(value=sr)
return gr.update()
train_pretrained_path.change(
on_pretrained_path_change,
inputs=[train_pretrained_path],
outputs=[sample_rate],
)
start_btn.click(
start_training,
inputs=[
@@ -951,6 +1033,7 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
warmup_steps,
max_steps,
sample_rate,
max_grad_norm,
enable_lm,
enable_dit,
enable_proj,
@@ -1109,12 +1192,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
"warmup_steps": "warmup_steps",
"max_steps": "最大步数 (max_steps)",
"sample_rate": "采样率 (sample_rate)",
"max_grad_norm": "梯度裁剪 (max_grad_norm, 0=关闭)",
"enable_lm": "启用 LoRA LM (enable_lm)",
"enable_dit": "启用 LoRA DIT (enable_dit)",
"enable_proj": "启用投影 (enable_proj)",
"dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard 路径 (可选)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
"distribute": "分发模式 (distribute)",
}
else:
@@ -1127,12 +1211,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
"warmup_steps": "Warmup Steps",
"max_steps": "Max Steps",
"sample_rate": "Sample Rate",
"max_grad_norm": "Max Grad Norm (0=disabled)",
"enable_lm": "Enable LoRA LM",
"enable_dit": "Enable LoRA DIT",
"enable_proj": "Enable Projection",
"dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard Path (Optional)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
"distribute": "Distribute Mode",
}
@@ -1162,11 +1247,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
gr.update(label=adv["warmup_steps"]),
gr.update(label=adv["max_steps"]),
gr.update(label=adv["sample_rate"]),
gr.update(label=adv["max_grad_norm"]),
gr.update(label=adv["tensorboard_path"]),
gr.update(label=adv["enable_lm"]),
gr.update(label=adv["enable_dit"]),
gr.update(label=adv["enable_proj"]),
gr.update(label=adv["dropout"]),
gr.update(label=adv["tensorboard_path"]),
# Distribution options
gr.update(label=adv["hf_model_id"]),
gr.update(label=adv["distribute"]),
@@ -1213,11 +1299,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
warmup_steps,
max_steps,
sample_rate,
max_grad_norm,
tensorboard_path,
enable_lm,
enable_dit,
enable_proj,
dropout,
tensorboard_path,
# distribution outputs
hf_model_id,
distribute,
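The hunks above call `detect_sample_rate(pretrained_path)`, but its definition is not part of the visible diff. The following is a minimal sketch of what such a helper might look like, assuming it reads `audio_vae_config.sample_rate` from `config.json`, as the auto-fix log message suggests. The name `detect_sample_rate_sketch` is hypothetical, and the real implementation may differ.

```python
import json
import os
from typing import Optional


def detect_sample_rate_sketch(pretrained_path: str) -> Optional[int]:
    """Hypothetical stand-in for detect_sample_rate (not the repository's code).

    Returns audio_vae_config.sample_rate from <pretrained_path>/config.json,
    or None when the file or the key is missing or unreadable.
    """
    config_file = os.path.join(pretrained_path or "", "config.json")
    if not os.path.isfile(config_file):
        return None
    try:
        with open(config_file, "r", encoding="utf-8") as f:
            cfg = json.load(f)
    except (OSError, ValueError):
        # Unreadable file or invalid JSON: report "unknown" instead of raising.
        return None
    if not isinstance(cfg, dict):
        return None
    vae_cfg = cfg.get("audio_vae_config")
    if not isinstance(vae_cfg, dict):
        return None
    sr = vae_cfg.get("sample_rate")
    return int(sr) if isinstance(sr, (int, float)) and sr > 0 else None
```

Returning `None` rather than raising keeps both callers simple: the UI callback leaves the field untouched, and `start_training` keeps the user-supplied value.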
+34
@@ -0,0 +1,34 @@
"""
Model download script
"""
from modelscope import snapshot_download
def download(repo_id:str, local_dir:str):
"""
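Download a model repository or a single file
Args:
repo_id (str): User name/repository name, e.g. 'stabilityai/sdxl-turbo'
local_dir (str or Path): Local directory in which to place the downloaded files
Returns:
str: Local path of the downloaded files
Raises:
ValueError: If repo_id is malformed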
"""
model_dir = snapshot_download(
repo_id,
repo_type='model',
local_dir=f"{local_dir}/{repo_id}",
)
if __name__ == "__main__":
download("OpenBMB/VoxCPM2", "./models")
download("iic/SenseVoiceSmall", "./models")
download('iic/speech_zipenhancer_ans_multiloss_16k_base',"./models")
+1 -2
@@ -47,8 +47,7 @@ dependencies = [
"funasr",
"spaces",
"argbind",
"safetensors"
"safetensors",
]
[project.optional-dependencies]
+108
@@ -0,0 +1,108 @@
"""Unit checks for pick_runtime_dtype / get_dtype consistency.
Loads src/voxcpm/model/utils.py directly to avoid the heavy voxcpm package
init. Run with: `python scripts/test_pick_runtime_dtype.py`.
"""
import importlib.util
import os
import pathlib
import sys
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
UTILS = str(REPO_ROOT / "src" / "voxcpm" / "model" / "utils.py")
spec = importlib.util.spec_from_file_location("voxcpm_utils", UTILS)
utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(utils)
_LOW_PRECISION_DTYPES = utils._LOW_PRECISION_DTYPES
_VALID_DTYPE_OVERRIDES = utils._VALID_DTYPE_OVERRIDES
get_dtype = utils.get_dtype
pick_runtime_dtype = utils.pick_runtime_dtype
def expect(actual, expected, label):
ok = actual == expected
mark = "OK " if ok else "FAIL"
print(f"[{mark}] {label}: got={actual!r} expected={expected!r}")
return ok
def expect_raises(fn, exc_type, label):
try:
fn()
except exc_type as e:
print(f"[OK ] {label}: raised {exc_type.__name__}: {e}")
return True
except Exception as e:
print(f"[FAIL] {label}: raised {type(e).__name__} not {exc_type.__name__}: {e}")
return False
print(f"[FAIL] {label}: no exception raised")
return False
results = []
print("=== override set sanity ===")
results.append(expect("half" not in _VALID_DTYPE_OVERRIDES, True, "half removed from _VALID_DTYPE_OVERRIDES"))
results.append(expect("half" not in _LOW_PRECISION_DTYPES, True, "half removed from _LOW_PRECISION_DTYPES"))
print("\n=== every accepted override parses through get_dtype ===")
for dt in sorted(_VALID_DTYPE_OVERRIDES):
try:
torch_dtype = get_dtype(dt)
print(f"[OK ] get_dtype({dt!r}) -> {torch_dtype}")
results.append(True)
except Exception as e:
print(f"[FAIL] get_dtype({dt!r}) raised: {e}")
results.append(False)
print("\n=== pick_runtime_dtype: non-mps is a no-op ===")
results.append(expect(pick_runtime_dtype("cuda", "bfloat16"), "bfloat16", "cuda/bf16 untouched"))
results.append(expect(pick_runtime_dtype("cpu", "float16"), "float16", "cpu/fp16 untouched"))
results.append(expect(pick_runtime_dtype("cuda", "float32"), "float32", "cuda/fp32 untouched"))
print("\n=== pick_runtime_dtype: mps forces fp32 for low-precision ===")
os.environ.pop("VOXCPM_MPS_DTYPE", None)
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "mps/bf16 -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "bf16"), "float32", "mps/bf16-alias -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "float16"), "float32", "mps/fp16 -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "fp16"), "float32", "mps/fp16-alias -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "float32"), "float32", "mps/fp32 stays"))
results.append(expect(pick_runtime_dtype("mps", "fp32"), "fp32", "mps/fp32-alias stays"))
print("\n=== pick_runtime_dtype: VOXCPM_MPS_DTYPE override ===")
os.environ["VOXCPM_MPS_DTYPE"] = "bfloat16"
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "bfloat16", "override bf16 honored"))
os.environ["VOXCPM_MPS_DTYPE"] = "FP16"
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "fp16", "override is case-insensitive"))
os.environ["VOXCPM_MPS_DTYPE"] = " float32 "
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "override is whitespace-trimmed"))
print("\n=== pick_runtime_dtype: 'half' is no longer a valid override ===")
os.environ["VOXCPM_MPS_DTYPE"] = "half"
results.append(
expect_raises(
lambda: pick_runtime_dtype("mps", "bfloat16"),
ValueError,
"override=half now rejected (was the bug)",
)
)
os.environ["VOXCPM_MPS_DTYPE"] = "garbage"
results.append(
expect_raises(
lambda: pick_runtime_dtype("mps", "bfloat16"),
ValueError,
"override=garbage still rejected",
)
)
os.environ.pop("VOXCPM_MPS_DTYPE", None)
print("\n=== summary ===")
passed = sum(results)
total = len(results)
print(f"{passed}/{total} passed")
sys.exit(0 if passed == total else 1)
+31 -9
@@ -30,7 +30,8 @@ except ImportError:
import json
from voxcpm.model import VoxCPMModel, VoxCPM2Model
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.model.voxcpm import LoRAConfig as LoRAConfigV1
from voxcpm.model.voxcpm2 import LoRAConfig as LoRAConfigV2
from voxcpm.training import (
Accelerator,
BatchProcessor,
@@ -46,6 +47,7 @@ def train(
train_manifest: str,
val_manifest: str = "",
sample_rate: int = 16_000,
out_sample_rate: int = 0, # AudioVAE decoder output rate; used for TensorBoard audio logging
batch_size: int = 1,
grad_accum_steps: int = 1,
num_workers: int = 2,
@@ -63,6 +65,7 @@ def train(
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
lora: dict = None,
config_path: str = "",
max_grad_norm: float = 0.0, # gradient clipping; 0 = disabled (backward compat)
# Distribution options (for LoRA checkpoints)
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
@@ -91,6 +94,7 @@ def train(
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
_arch = json.load(_f).get("architecture", "voxcpm").lower()
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
LoRAConfig = LoRAConfigV2 if _arch == "voxcpm2" else LoRAConfigV1
if accelerator.rank == 0:
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
base_model = _model_cls.from_local(
@@ -98,6 +102,12 @@ def train(
)
tokenizer = base_model.text_tokenizer
expected_sr = base_model.audio_vae.sample_rate
assert sample_rate == expected_sr, (
f"sample_rate mismatch: config says {sample_rate}, but the AudioVAE encoder expects {expected_sr}. "
f"Please set sample_rate: {expected_sr} in your training config. "
)
train_ds, val_ds = load_audio_text_datasets(
train_manifest=train_manifest,
val_manifest=val_manifest,
@@ -170,8 +180,12 @@ def train(
dataset_cnt=dataset_cnt,
device=accelerator.device,
)
# Save audio_vae for audio generation
# Save audio_vae and output sample rate for audio generation.
# Prefer model's actual output rate; fall back to YAML out_sample_rate or encode rate.
audio_vae_for_gen = base_model.audio_vae
out_sr = base_model.sample_rate # decoder output rate (e.g. 48000 for V2)
if out_sr == 0 and out_sample_rate > 0:
out_sr = out_sample_rate
del base_model.audio_vae
model = accelerator.prepare_model(base_model)
unwrapped_model = accelerator.unwrap(model)
@@ -304,8 +318,8 @@ def train(
scaler = getattr(accelerator, "scaler", None)
if scaler is not None:
scaler.unscale_(optimizer)
# Use large max_norm to only compute grad_norm without actual clipping
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)
effective_max_norm = max_grad_norm if max_grad_norm > 0 else 1e9
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=effective_max_norm)
accelerator.step(optimizer)
accelerator.update()
@@ -333,6 +347,7 @@ def train(
val_ds=val_ds,
audio_vae=audio_vae_for_gen,
sample_rate=sample_rate,
out_sample_rate=out_sr,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
@@ -359,6 +374,7 @@ def validate(
val_ds=None,
audio_vae=None,
sample_rate=22050,
out_sample_rate=0,
val_texts=None,
tokenizer=None,
valid_interval=1000,
@@ -424,6 +440,7 @@ def validate(
step,
accelerator,
sample_rate,
out_sample_rate=out_sample_rate,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
@@ -526,6 +543,7 @@ def generate_sample_audio(
step,
accelerator,
sample_rate=22050,
out_sample_rate=0,
val_texts=None,
tokenizer=None,
pretrained_path=None,
@@ -540,6 +558,10 @@ def generate_sample_audio(
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
unwrapped_model = accelerator.unwrap(model)
# Determine the correct output sample rate for generated audio.
# out_sample_rate is the decoder output rate (e.g. 48kHz for V2);
# sample_rate is the encoder input rate (e.g. 16kHz for V2).
gen_sr = out_sample_rate if out_sample_rate > 0 else sample_rate
for i in range(num_samples):
sample = val_ds[i]
@@ -596,10 +618,10 @@ def generate_sample_audio(
gen_audio_np = normalize_audio(gen_audio_np)
tag = f"val_sample_{i}"
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=gen_sr)
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/gen_sr:.2f}s")
# Log reference audio
# Log reference audio (at encoder input rate, which is what val_ds provides)
if ref_audio_np is not None:
writer.add_audio(
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
@@ -607,9 +629,9 @@ def generate_sample_audio(
# Generate mel spectrogram figure
try:
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
mel_gen = compute_mel_spectrogram(gen_audio_np, gen_sr)
mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref)
fig = create_mel_figure(gen_audio_np, mel_gen, gen_sr, step, ref_audio_np, mel_ref)
writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
log(f"[Audio] Created mel spectrogram figure for sample {i}")
except Exception as e:
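As an illustration only, the two fallbacks introduced in this file can be condensed into small pure functions. The helper names `pick_generation_rate` and `effective_clip_norm` are hypothetical and do not exist in the repository. They restate the logic of the hunks above, assuming a model output rate of 0 means "unknown".

```python
def pick_generation_rate(model_out_rate: int, out_sample_rate: int, sample_rate: int) -> int:
    """Choose the sample rate used when logging generated audio.

    Order of preference, mirroring the diff: the model's own output rate,
    then the configured out_sample_rate, then the encoder input sample_rate.
    """
    out_sr = model_out_rate
    if out_sr == 0 and out_sample_rate > 0:
        out_sr = out_sample_rate
    return out_sr if out_sr > 0 else sample_rate


def effective_clip_norm(max_grad_norm: float) -> float:
    """Return the max_norm to pass to gradient clipping.

    A non-positive max_grad_norm disables clipping; the very large threshold
    means the norm is still computed but gradients are left unchanged in practice.
    """
    return max_grad_norm if max_grad_norm > 0 else 1e9


if __name__ == "__main__":
    print(pick_generation_rate(48000, 0, 16000))
    print(pick_generation_rate(0, 0, 16000))
    print(effective_clip_norm(0.0), effective_clip_norm(1.0))
```

Keeping the "disabled" case as a large threshold rather than skipping the call preserves the `grad_norm` value that the training loop computes on every step.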
Binary file not shown.
Binary file not shown.
Binary file not shown.
+60 -6
@@ -11,11 +11,6 @@ import os
import sys
from pathlib import Path
import soundfile as sf
from voxcpm.core import VoxCPM
DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2"
# -----------------------------
@@ -173,7 +168,9 @@ def validate_batch_args(args, parser):
# -----------------------------
def load_model(args) -> VoxCPM:
def load_model(args):
from voxcpm.core import VoxCPM
print("Loading VoxCPM model...", file=sys.stderr)
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
@@ -209,6 +206,7 @@ def load_model(args) -> VoxCPM:
zipenhancer_model_path=zipenhancer_path,
enable_denoiser=not args.no_denoiser,
optimize=not args.no_optimize,
device=args.device,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
)
@@ -227,6 +225,7 @@ def load_model(args) -> VoxCPM:
cache_dir=args.cache_dir,
local_files_only=args.local_files_only,
optimize=not args.no_optimize,
device=args.device,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
)
@@ -264,6 +263,8 @@ def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None
and (args.prompt_audio is not None or args.reference_audio is not None),
)
import soundfile as sf
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
duration = len(audio_array) / model.tts_model.sample_rate
@@ -286,7 +287,27 @@ def cmd_clone(args, parser):
)
def cmd_validate(args, parser):
from voxcpm.training.validate import (
print_validation_report,
validate_manifest,
)
manifest = str(require_file_exists(args.manifest, parser, "manifest file"))
result = validate_manifest(
manifest_path=manifest,
sample_rate=args.sample_rate,
max_samples=args.max_samples,
verbose=args.verbose,
)
print_validation_report(result, manifest)
if not result.is_valid:
sys.exit(1)
def cmd_batch(args, parser):
import soundfile as sf
input_file = require_file_exists(args.input, parser, "input file")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@@ -403,6 +424,12 @@ def _add_model_args(parser):
default=DEFAULT_HF_MODEL_ID,
help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
)
parser.add_argument(
"--device",
type=str,
default="auto",
help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)",
)
parser.add_argument(
"--cache-dir", type=str, help="Cache directory for Hub downloads"
)
@@ -524,6 +551,30 @@ Examples:
_add_model_args(batch_parser)
_add_lora_args(batch_parser)
# Validate subcommand
validate_parser = subparsers.add_parser(
"validate",
help="Validate a training data manifest (JSONL) before fine-tuning",
)
validate_parser.add_argument(
"--manifest", "-m", required=True, help="Path to JSONL training manifest"
)
validate_parser.add_argument(
"--sample-rate",
type=int,
default=16_000,
help="Expected audio sample rate in Hz (default: 16000)",
)
validate_parser.add_argument(
"--max-samples",
type=int,
default=0,
help="Maximum number of samples to validate (0 = all, default: 0)",
)
validate_parser.add_argument(
"--verbose", "-v", action="store_true", help="Print per-sample progress"
)
# Legacy root arguments
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
parser.add_argument(
@@ -576,6 +627,9 @@ def main():
parser = _build_parser()
args = parser.parse_args()
if args.command == "validate":
return cmd_validate(args, parser)
validate_ranges(args, parser)
if args.command == "design":
+32 -5
@@ -8,6 +8,7 @@ from typing import Generator, Optional
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel, LoRAConfig
from .model.voxcpm2 import VoxCPM2Model
from .model.utils import next_and_close
class VoxCPM:
@@ -17,6 +18,7 @@ class VoxCPM:
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser: bool = True,
optimize: bool = True,
device: str | None = None,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
@@ -30,6 +32,9 @@ class VoxCPM:
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM
will choose automatically (preferring CUDA, then MPS, then CPU).
If set explicitly, that device is used or a clear error is raised.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
@@ -56,10 +61,20 @@ class VoxCPM:
arch = config.get("architecture", "voxcpm").lower()
if arch == "voxcpm2":
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
self.tts_model = VoxCPM2Model.from_local(
voxcpm_model_path,
optimize=optimize,
device=device,
lora_config=lora_config,
)
print("Loaded VoxCPM2Model", file=sys.stderr)
elif arch == "voxcpm":
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
self.tts_model = VoxCPMModel.from_local(
voxcpm_model_path,
optimize=optimize,
device=device,
lora_config=lora_config,
)
print("Loaded VoxCPMModel", file=sys.stderr)
else:
raise ValueError(f"Unsupported architecture: {arch}")
@@ -94,6 +109,7 @@ class VoxCPM:
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
device: str | None = None,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
@@ -109,6 +125,9 @@ class VoxCPM:
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
device: Runtime device. Use ``None``/``"auto"`` for automatic
fallback, or an explicit value such as ``"cpu"``, ``"mps"``,
``"cuda"``, or ``"cuda:0"``.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created with
enable_lm=True and enable_dit=True.
@@ -130,7 +149,7 @@ class VoxCPM:
if not repo_id:
raise ValueError("You must provide hf_model_id")
# Load from local path if provided
# Load from the local path (if provided)
if os.path.isdir(repo_id):
local_path = repo_id
else:
@@ -146,13 +165,14 @@ class VoxCPM:
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
optimize=optimize,
device=device,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
**kwargs,
)
def generate(self, *args, **kwargs) -> np.ndarray:
return next(self._generate(*args, streaming=False, **kwargs))
return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@@ -200,7 +220,7 @@ class VoxCPM:
Yields audio chunks for each generation step if ``streaming=True``,
otherwise yields a single array containing the final audio.
"""
if not text.strip() or not isinstance(text, str):
if not isinstance(text, str) or not text.strip():
raise ValueError("target text must be a non-empty string")
if prompt_wav_path is not None:
@@ -273,8 +293,15 @@ class VoxCPM:
streaming=streaming,
)
if streaming:
try:
for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy()
finally:
generate_result.close()
else:
wav, _, _ = next_and_close(generate_result)
yield wav.squeeze(0).cpu().numpy()
finally:
for tmp_path in temp_files:
+111 -1
@@ -1,7 +1,25 @@
from typing import List
import os
from typing import List, Optional
import torch
from transformers import PreTrainedTokenizer
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"}
_VALID_DTYPE_OVERRIDES = {
"bfloat16", "bf16",
"float16", "fp16",
"float32", "fp32",
}
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
# Explicitly close partially-consumed generators so inference_mode cleanup
# does not get deferred to Python's GC/finalizer path.
def next_and_close(gen):
try:
return next(gen)
finally:
gen.close()
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
@@ -119,3 +137,95 @@ def get_dtype(dtype: str):
return torch.float32
else:
raise ValueError(f"Unsupported dtype: {dtype}")
def _has_mps() -> bool:
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
"""Pick a safe runtime dtype for the resolved device.
On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift
in the diffusion AR loop that the output is glitched and the model's
badcase detector triggers infinite retries. float32 is the only stable
option today. CUDA and CPU keep whatever the checkpoint was trained with.
Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when
they want to test future MPS improvements.
"""
if device != "mps":
return configured_dtype
override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
if override:
if override not in _VALID_DTYPE_OVERRIDES:
raise ValueError(
f"VOXCPM_MPS_DTYPE='{override}' is not one of "
f"{sorted(_VALID_DTYPE_OVERRIDES)}"
)
return override
if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES:
return "float32"
return configured_dtype
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
"""
Choose a runtime device automatically.
Preference order:
- if the preferred device is available, use it
- otherwise fall back to CUDA -> MPS -> CPU
"""
preferred = (preferred_device or "cuda").strip().lower()
if preferred.startswith("cuda") and torch.cuda.is_available():
return preferred
if preferred == "mps" and _has_mps():
return "mps"
if preferred == "cpu":
return "cpu"
if torch.cuda.is_available():
return "cuda"
if _has_mps():
return "mps"
return "cpu"
def resolve_runtime_device(device: Optional[str], configured_device: str = "cuda") -> str:
"""
Resolve the actual runtime device.
Semantics:
- ``device`` is ``None`` or ``"auto"``: use automatic fallback selection
- otherwise: treat it as an explicit user choice and validate availability
"""
explicit = None if device is None else device.strip().lower()
if explicit is None or explicit == "auto":
return auto_select_device(configured_device)
if explicit.startswith("cuda"):
if not torch.cuda.is_available():
raise ValueError(
f"Requested device '{device}', but CUDA is not available. "
"Use device='auto' for automatic fallback."
)
return explicit
if explicit == "mps":
if not _has_mps():
raise ValueError(
"Requested device 'mps', but MPS is not available. "
"Use device='auto' for automatic fallback."
)
return "mps"
if explicit == "cpu":
return "cpu"
raise ValueError(
f"Unsupported device '{device}'. Supported values are 'auto', 'cpu', 'mps', "
"'cuda', or indexed CUDA devices like 'cuda:0'."
)
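The effect of `next_and_close` can be seen without loading any model. The sketch below copies the helper from the hunk above and drives it with a hypothetical generator, `one_shot`, which stands in for an inference loop whose cleanup lives in a `finally` block. The `events` list is only there to make the ordering observable.

```python
def next_and_close(gen):
    """Return the first item from ``gen`` and close it immediately."""
    try:
        return next(gen)
    finally:
        gen.close()


events = []


def one_shot():
    # Hypothetical generator standing in for an inference loop.
    try:
        events.append("start")
        yield "audio"
    finally:
        # Runs as soon as the generator is closed or exhausted.
        events.append("cleanup")


# Plain next(): the generator stays suspended at the yield, so cleanup is pending.
pending = one_shot()
first = next(pending)
events_before_close = list(events)
pending.close()  # explicit close runs the finally block now
events.clear()

# next_and_close(): cleanup has already run by the time the value is returned.
result = next_and_close(one_shot())
events_after_helper = list(events)

if __name__ == "__main__":
    print(first, events_before_close)
    print(result, events_after_helper)
```

With plain `next()`, the `finally` block runs only when the generator is later closed or garbage-collected. The helper makes that point deterministic.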
+37 -17
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
from .utils import (
get_dtype,
mask_multichar_chinese_tokens,
next_and_close,
pick_runtime_dtype,
resolve_runtime_device,
)
class VoxCPMEncoderConfig(BaseModel):
@@ -109,18 +115,22 @@ class VoxCPMModel(nn.Module):
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE,
lora_config: LoRAConfig = None,
device: str | None = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.feat_dim = config.feat_dim
self.patch_size = config.patch_size
self.device = config.device
if not torch.cuda.is_available():
if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
self.device = resolve_runtime_device(device, config.device)
self.config.device = self.device
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
if resolved_dtype != self.config.dtype:
print(
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
file=sys.stderr,
)
self.config.dtype = resolved_dtype
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM
@@ -227,6 +237,7 @@ class VoxCPMModel(nn.Module):
self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self._feat_encoder_raw = self.feat_encoder
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
@@ -337,7 +348,7 @@ class VoxCPMModel(nn.Module):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@@ -463,7 +474,7 @@ class VoxCPMModel(nn.Module):
yield decode_audio
break
else:
latent_pred, pred_audio_feat = next(inference_result)
latent_pred, pred_audio_feat = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -571,7 +582,7 @@ class VoxCPMModel(nn.Module):
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
@@ -690,7 +701,7 @@ class VoxCPMModel(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
latent_pred, pred_audio_feat = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -713,7 +724,7 @@ class VoxCPMModel(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
return next_and_close(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
@@ -755,7 +766,8 @@ class VoxCPMModel(nn.Module):
"""
B, T, P, D = feat.shape
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup:
@@ -845,8 +857,16 @@ class VoxCPMModel(nn.Module):
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
def from_local(
cls,
path: str,
optimize: bool = True,
training: bool = False,
device: str | None = None,
lora_config: LoRAConfig = None,
):
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, "audio_vae_config", None)
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
@@ -868,7 +888,7 @@ class VoxCPMModel(nn.Module):
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config)
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
@@ -950,7 +970,7 @@ class VoxCPMModel(nn.Module):
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
elif ckpt_file and ckpt_file.exists():
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
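Review note: the hunks above replace bare `next(...)` with `next_and_close(...)`, which is imported from `.utils` but whose body does not appear in this diff. The sketch below is an assumption about what such a helper plausibly does (take the first item, then close the generator so that any `finally` or `with` block inside it runs right away); it is not the repository's implementation, and `demo_generator` is a made-up example.

```python
from typing import Generator, List, TypeVar

T = TypeVar("T")


def next_and_close(gen: Generator[T, None, None]) -> T:
    """Return the first yielded item, then close the generator.

    Hypothetical re-implementation, for illustration only.
    """
    try:
        return next(gen)
    finally:
        # close() raises GeneratorExit at the paused yield, so pending
        # `finally` / `with` cleanup inside the generator runs immediately.
        gen.close()


def demo_generator(log: List[str]):
    """Made-up generator that records when its cleanup runs."""
    try:
        yield "first"
        yield "second"
    finally:
        log.append("cleaned up")


cleanup_log: List[str] = []
first_item = next_and_close(demo_generator(cleanup_log))
```

With a plain `next(...)`, the suspended generator would only be finalized when it is garbage collected, which matters once the generator body holds a context manager such as `streaming_decode()`.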
+103 -71
@@ -45,28 +45,17 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
from .utils import (
get_dtype,
mask_multichar_chinese_tokens,
next_and_close,
pick_runtime_dtype,
resolve_runtime_device,
)
def _trim_audio_silence_vad(
audio: torch.Tensor,
sample_rate: int,
max_silence_ms: float = 200.0,
top_db: float = 35.0,
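) -> torch.Tensor:
"""Trim leading/trailing silence and long trailing pseudo-silence using an energy threshold (VAD-style), keeping at most max_silence_ms milliseconds of silence at each end.
Also trims long pseudo-silence at the end (low-energy but not fully silent segments, such as prolonged background noise).
Args:
audio: audio tensor of shape (1, T)
sample_rate: sample rate
max_silence_ms: maximum silence length (in milliseconds) allowed to remain at the head and tail
top_db: how many dB below the reference level counts as silence
Returns:
the trimmed tensor of shape (1, T')
"""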
# A simple function to trim audio silence using VAD; not used by default
def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
if audio.numel() == 0:
return audio
y = audio.squeeze(0).numpy()
@@ -85,7 +74,7 @@ def _trim_audio_silence_vad(
except Exception:
start, end = 0, n
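# Use per-frame RMS to find the last position with sustained energy, and trim the long pseudo-silence at the end (low-energy background noise, etc.)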
# Find the last frame with continuous energy, trim the long pseudo-silence at the end (low energy background noise, etc.)
n_frames = max(0, (n - frame_length) // hop_length + 1)
last_voice_frame = -1
for j in range(n_frames):
@@ -168,18 +157,22 @@ class VoxCPM2Model(nn.Module):
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAEV2,
lora_config: LoRAConfig = None,
device: str | None = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.feat_dim = config.feat_dim
self.patch_size = config.patch_size
self.device = config.device
if not torch.cuda.is_available():
if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
self.device = resolve_runtime_device(device, config.device)
self.config.device = self.device
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
if resolved_dtype != self.config.dtype:
print(
f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
file=sys.stderr,
)
self.config.dtype = resolved_dtype
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM
@@ -246,6 +239,7 @@ class VoxCPM2Model(nn.Module):
# Audio VAE
self.audio_vae = audio_vae
self.chunk_size = audio_vae.chunk_size
self._decode_chunk_size = getattr(audio_vae, "decode_chunk_size", audio_vae.chunk_size)
self._encode_sample_rate = audio_vae.sample_rate
self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)
@@ -291,6 +285,7 @@ class VoxCPM2Model(nn.Module):
self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self._feat_encoder_raw = self.feat_encoder
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
@@ -382,11 +377,7 @@ class VoxCPM2Model(nn.Module):
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=(
self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10
),
n_timesteps=10,
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
@@ -402,18 +393,25 @@ class VoxCPM2Model(nn.Module):
def _dtype(self):
return get_dtype(self.config.dtype)
def _encode_wav(self, wav_path: str, padding_mode: str = "right") -> torch.Tensor:
def _encode_wav(
self,
wav_path: str,
padding_mode: str = "right",
trim_silence_vad: bool = False,
) -> torch.Tensor:
"""Load, trim, pad and VAE-encode an audio file.
Args:
wav_path: path to the audio file.
padding_mode: "right" (default) or "left" padding for alignment.
trim_silence_vad: whether to apply VAD-based silence trimming.
Returns:
audio_feat: (T, P, D) tensor of latent patches.
"""
audio, _ = librosa.load(wav_path, sr=self._encode_sample_rate, mono=True)
audio = torch.from_numpy(audio).unsqueeze(0)
if trim_silence_vad:
audio = _trim_audio_silence_vad(audio, self._encode_sample_rate, max_silence_ms=200.0)
patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0:
@@ -456,7 +454,7 @@ class VoxCPM2Model(nn.Module):
return tokens, feats, t_mask, a_mask
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@@ -475,6 +473,7 @@ class VoxCPM2Model(nn.Module):
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
trim_silence_vad: bool = False,
streaming: bool = False,
streaming_prefix_len: int = 4,
) -> Generator[torch.Tensor, None, None]:
@@ -495,8 +494,12 @@ class VoxCPM2Model(nn.Module):
)
text_length = text_token.shape[0]
ref_feat = self._encode_wav(reference_wav_path, padding_mode="right")
prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left")
ref_feat = self._encode_wav(
reference_wav_path,
padding_mode="right",
trim_silence_vad=trim_silence_vad,
)
prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left", trim_silence_vad=trim_silence_vad)
prompt_audio_length = prompt_feat.size(0)
ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)
@@ -538,7 +541,11 @@ class VoxCPM2Model(nn.Module):
)
text_length = text_token.shape[0]
ref_feat = self._encode_wav(reference_wav_path, padding_mode="right")
ref_feat = self._encode_wav(
reference_wav_path,
padding_mode="right",
trim_silence_vad=trim_silence_vad,
)
ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)
text_pad_feat = torch.zeros(
@@ -595,7 +602,7 @@ class VoxCPM2Model(nn.Module):
)
text_length = text_token.shape[0]
prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left")
prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left", trim_silence_vad=trim_silence_vad)
prompt_audio_length = prompt_feat.size(0)
prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
text_pad_feat = torch.zeros(
@@ -640,14 +647,14 @@ class VoxCPM2Model(nn.Module):
streaming_prefix_len=streaming_prefix_len,
)
if streaming:
patch_len = self.patch_size * self.chunk_size
for latent_pred, _ in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
with self.audio_vae.streaming_decode() as vae_dec:
for latent_pred, _, _ctx in inference_result:
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
decode_audio = decode_audio.squeeze(1).cpu()
yield decode_audio
break
else:
latent_pred, pred_audio_feat = next(inference_result)
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -663,10 +670,9 @@ class VoxCPM2Model(nn.Module):
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size
has_continuation = bool(prompt_wav_path)
if has_continuation:
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
decode_patch_len = self.patch_size * self._decode_chunk_size
if context_len > 0:
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
else:
decode_audio = decode_audio.squeeze(1).cpu()
yield decode_audio
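Review note: in the non-streaming path the decoded waveform still contains the audio for the `context_len` prompt patches that were fed in as context, so the hunk above trims `decode_patch_len * context_len` samples from the front, where `decode_patch_len` uses the decoder-side chunk size. A minimal sketch of that arithmetic on a plain list, assuming a 1-D waveform; the function name and the numeric values are illustrative, not taken from the repository:

```python
import math
from typing import List, Sequence


def trim_context_audio(
    decoded: List[float],
    patch_size: int,
    decoder_rates: Sequence[int],
    streaming_prefix_len: int,
    num_prompt_patches: int,
) -> List[float]:
    """Drop the samples that correspond to prompt-context patches."""
    # Samples produced per latent frame on the decoder side.
    decode_chunk_size = math.prod(decoder_rates)
    decode_patch_len = patch_size * decode_chunk_size
    # Mirrors: context_len = min(streaming_prefix_len - 1, len(audio_indices))
    context_len = min(streaming_prefix_len - 1, num_prompt_patches)
    if context_len > 0:
        return decoded[decode_patch_len * context_len:]
    return decoded


waveform = [float(i) for i in range(60)]
continued = trim_context_audio(waveform, 2, (2, 3), 4, 5)
zero_shot = trim_context_audio(waveform, 2, (2, 3), 4, 0)
```

Here `continued` drops 3 patches of 12 samples each, while `zero_shot` is returned untouched because there is no prompt context.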
@@ -677,6 +683,7 @@ class VoxCPM2Model(nn.Module):
prompt_text: str = None,
prompt_wav_path: str = None,
reference_wav_path: str = None,
trim_silence_vad: bool = False,
):
"""
Build prompt cache for subsequent generation.
@@ -693,6 +700,8 @@ class VoxCPM2Model(nn.Module):
Must be paired with ``prompt_text``.
reference_wav_path: reference audio path for voice cloning
(structurally isolated via ref_audio tokens).
trim_silence_vad: whether to apply VAD-based silence trimming
before encoding prompt/reference audio.
Returns:
prompt_cache: dict used by ``_generate_with_prompt_cache``.
@@ -705,11 +714,19 @@ class VoxCPM2Model(nn.Module):
cache = {}
if reference_wav_path:
cache["ref_audio_feat"] = self._encode_wav(reference_wav_path, padding_mode="right")
cache["ref_audio_feat"] = self._encode_wav(
reference_wav_path,
padding_mode="right",
trim_silence_vad=trim_silence_vad,
)
if prompt_wav_path and prompt_text is not None:
cache["prompt_text"] = prompt_text
cache["audio_feat"] = self._encode_wav(prompt_wav_path, padding_mode="left")
cache["audio_feat"] = self._encode_wav(
prompt_wav_path,
padding_mode="left",
trim_silence_vad=trim_silence_vad,
)
has_ref = "ref_audio_feat" in cache
has_prompt = "audio_feat" in cache
@@ -755,7 +772,7 @@ class VoxCPM2Model(nn.Module):
return merged
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
@@ -917,14 +934,14 @@ class VoxCPM2Model(nn.Module):
streaming_prefix_len=streaming_prefix_len,
)
if streaming:
patch_len = self.patch_size * self.chunk_size
for latent_pred, pred_audio_feat in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
with self.audio_vae.streaming_decode() as vae_dec:
for latent_pred, pred_audio_feat, _ctx in inference_result:
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
decode_audio = decode_audio.squeeze(1).cpu()
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -939,18 +956,20 @@ class VoxCPM2Model(nn.Module):
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size
if mode in ("continuation", "ref_continuation"):
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
decode_patch_len = self.patch_size * self._decode_chunk_size
if context_len > 0:
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
else:
decode_audio = decode_audio[..., :].squeeze(1).cpu()
decode_audio = decode_audio.squeeze(1).cpu()
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
feat_pred, generated_feat, _ = next_and_close(self._inference(*args, streaming=False, **kwargs))
return feat_pred, generated_feat
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
for feat_pred, pred_feat_seq, _ in self._inference(*args, streaming=True, **kwargs):
yield feat_pred, pred_feat_seq
@torch.inference_mode()
def _inference(
@@ -989,7 +1008,8 @@ class VoxCPM2Model(nn.Module):
"""
B, T, P, D = feat.shape
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup:
@@ -1009,6 +1029,7 @@ class VoxCPM2Model(nn.Module):
# trailing audio patches as initial context so the VAE can decode smoothly.
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
has_continuation_audio = feat_mask[0, -1].item() == 1
context_len = 0
if has_continuation_audio:
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
context_len = min(streaming_prefix_len - 1, len(audio_indices))
@@ -1058,11 +1079,13 @@ class VoxCPM2Model(nn.Module):
prefix_feat_cond = pred_feat
if streaming:
# return the last three predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
# Yield only the newest patch latent for stateful VAE decode
feat_pred = rearrange(pred_feat.unsqueeze(1), "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
yield feat_pred, pred_feat_seq, context_len
if len(pred_feat_seq) > streaming_prefix_len:
pred_feat_seq = pred_feat_seq[-streaming_prefix_len:]
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1:
@@ -1081,11 +1104,20 @@ class VoxCPM2Model(nn.Module):
if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu()
yield feat_pred, generated_feat, context_len
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
def from_local(
cls,
path: str,
optimize: bool = True,
training: bool = False,
device: str | None = None,
lora_config: LoRAConfig = None,
):
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, "audio_vae_config", None)
audio_vae = AudioVAEV2(config=audio_vae_config) if audio_vae_config else AudioVAEV2()
@@ -1107,7 +1139,7 @@ class VoxCPM2Model(nn.Module):
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config)
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
@@ -1189,7 +1221,7 @@ class VoxCPM2Model(nn.Module):
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
elif ckpt_file and ckpt_file.exists():
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
@@ -436,6 +436,7 @@ class AudioVAE(nn.Module):
self.out_sample_rate = out_sample_rate
self.sr_bin_boundaries = sr_bin_boundaries
self.chunk_size = math.prod(encoder_rates)
self.decode_chunk_size = math.prod(decoder_rates)
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
@@ -471,6 +472,20 @@ class AudioVAE(nn.Module):
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
return self.decoder(z, sr_cond)
def streaming_decode(self):
"""Return a ``StreamingVAEDecoder`` context manager for stateful
chunk-by-chunk decoding. Each call to ``decode_chunk`` processes only
the new latent patch and carries causal-conv state internally, avoiding
the redundant overlap decode used previously.
Usage::
with vae.streaming_decode() as dec:
for patch in patches:
audio_chunk = dec.decode_chunk(patch)
"""
return StreamingVAEDecoder(self)
def encode(self, audio_data: torch.Tensor, sample_rate: int):
"""
Args:
@@ -484,3 +499,82 @@ class AudioVAE(nn.Module):
audio_data = self.preprocess(audio_data, sample_rate)
return self.encoder(audio_data)["mu"]
class StreamingVAEDecoder:
"""Stateful streaming wrapper for :class:`AudioVAE`.
Carries causal-convolution padding buffers between calls so that each
``decode_chunk`` processes only the new latent patch — no overlap needed.
"""
def __init__(self, vae: AudioVAE):
self._vae = vae
self._states: dict = {}
self._originals: list = []
# -- context manager --------------------------------------------------
def __enter__(self):
self._states.clear()
self._install()
return self
def __exit__(self, *exc):
self._restore()
self._states.clear()
# -- public API --------------------------------------------------------
def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor:
"""Decode a single latent chunk and return the audio waveform."""
return self._vae.decode(z_chunk)
# -- internals ---------------------------------------------------------
def _install(self):
for name, mod in self._vae.decoder.named_modules():
if isinstance(mod, CausalConv1d):
pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding
if pad > 0:
self._patch_causal_conv(mod, pad)
elif isinstance(mod, CausalTransposeConv1d):
trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding
ctx = (mod.kernel_size[0] - 1) // mod.stride[0]
if ctx > 0:
self._patch_transpose_conv(mod, ctx, trim)
def _patch_causal_conv(self, mod, pad_size):
states = self._states
key = id(mod)
orig = mod.forward
def fwd(x, _k=key, _p=pad_size, _m=mod):
x_pad = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_p, 0))
if x.shape[-1] >= _p:
states[_k] = x[:, :, -_p:].detach()
else:
prev = states.get(_k, torch.zeros(x.shape[0], x.shape[1], _p,
device=x.device, dtype=x.dtype))
states[_k] = torch.cat([prev, x], dim=-1)[:, :, -_p:].detach()
return nn.Conv1d.forward(_m, x_pad)
mod.forward = fwd
self._originals.append((mod, orig))
def _patch_transpose_conv(self, mod, ctx, trim):
states = self._states
key = id(mod)
orig = mod.forward
def fwd(x, _k=key, _c=ctx, _t=trim, _m=mod):
x_full = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_c, 0))
states[_k] = x[:, :, -_c:].detach()
out = nn.ConvTranspose1d.forward(_m, x_full)
left = _c * _m.stride[0]
return out[..., left:-_t] if _t > 0 else out[..., left:]
mod.forward = fwd
self._originals.append((mod, orig))
def _restore(self):
for mod, orig in self._originals:
mod.forward = orig
self._originals.clear()
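Review note: `StreamingVAEDecoder` relies on the fact that a causal convolution only needs the last few inputs of the previous chunk to produce exactly the same output as a full-sequence pass. The sketch below shows that idea in pure Python for a single stride-1 causal convolution; it is a simplified illustration (no transposed convolutions, no output padding), and all names in it are hypothetical.

```python
from typing import List, Optional


def causal_conv_full(signal: List[float], kernel: List[float]) -> List[float]:
    """Causal 1-D convolution: left-pad with zeros so output[t] sees only x[<=t]."""
    pad = len(kernel) - 1
    padded = [0.0] * pad + signal
    return [
        sum(kernel[j] * padded[t + j] for j in range(len(kernel)))
        for t in range(len(signal))
    ]


class StreamingCausalConv:
    """Chunked variant that carries the last `pad` inputs between calls."""

    def __init__(self, kernel: List[float]):
        self.kernel = kernel
        self.pad = len(kernel) - 1
        self.state: Optional[List[float]] = None  # None until the first chunk

    def process_chunk(self, chunk: List[float]) -> List[float]:
        # First call: zero padding. Later calls: saved left context.
        left = self.state if self.state is not None else [0.0] * self.pad
        padded = left + chunk
        # Keep the newest `pad` samples as left context for the next call.
        self.state = padded[len(padded) - self.pad:]
        return [
            sum(self.kernel[j] * padded[t + j] for j in range(len(self.kernel)))
            for t in range(len(chunk))
        ]


signal = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
kernel = [0.5, 0.25, 0.25]
full_output = causal_conv_full(signal, kernel)

stream = StreamingCausalConv(kernel)
chunked_output = (
    stream.process_chunk(signal[:2])
    + stream.process_chunk(signal[2:3])
    + stream.process_chunk(signal[3:])
)
```

The middle chunk is shorter than the padding on purpose: the state is taken from the padded buffer rather than from the chunk alone, which corresponds to the `else` branch in `_patch_causal_conv`.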
+1 -1
@@ -225,7 +225,7 @@ class UnifiedCFM(torch.nn.Module):
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
if tgt_mask is not None:
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
loss = (weights * losses).sum() / torch.sum(tgt_mask)
loss = (weights * losses).sum() / torch.clamp(torch.sum(tgt_mask), min=1.0)
else:
loss = losses.mean()
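Review note: the clamp above guards the case where `tgt_mask` sums to zero, for example a batch whose target positions are all masked out; in PyTorch the unclamped division would yield `nan` (0/0) and poison the gradients. A pure-Python sketch of the same arithmetic, assuming the weights are already zero wherever the mask is zero; `masked_mean_loss` is an illustrative name:

```python
from typing import Sequence


def masked_mean_loss(
    losses: Sequence[float],
    weights: Sequence[float],
    mask: Sequence[float],
    min_denominator: float = 1.0,
) -> float:
    """Weighted loss sum divided by the mask count, clamped from below."""
    numerator = sum(w * l for w, l in zip(weights, losses))
    # Equivalent of torch.clamp(torch.sum(tgt_mask), min=1.0)
    denominator = max(sum(mask), min_denominator)
    return numerator / denominator


normal_batch = masked_mean_loss([2.0, 4.0, 6.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0])
empty_batch = masked_mean_loss([2.0, 4.0], [0.0, 0.0], [0.0, 0.0])
```

For any batch with at least one unmasked position the clamp leaves the result unchanged, since the mask sum is already at least 1.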
+3 -1
@@ -196,7 +196,9 @@ class MiniCPMAttention(nn.Module):
key_cache[:, :, position_id, :] = key_states
value_cache[:, :, position_id, :] = value_states
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
# Use an explicit broadcastable mask shape for SDPA. A 1D mask can
# trigger a CPU-side dimension bug in some PyTorch versions.
attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1)
# ref: https://github.com/pytorch/pytorch/issues/163597
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+3
@@ -15,6 +15,7 @@ from .data import (
BatchProcessor,
)
from .state import TrainingState
from .validate import validate_manifest, ValidationResult
__all__ = [
"Accelerator",
@@ -24,4 +25,6 @@ __all__ = [
"TrainingState",
"load_audio_text_datasets",
"build_dataloader",
"validate_manifest",
"ValidationResult",
]
+53 -11
@@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_REF_AUDIO_COLUMN = "ref_audio"
DEFAULT_ID_COLUMN = "dataset_id"
@@ -21,6 +22,7 @@ def load_audio_text_datasets(
val_manifest: str = "",
text_column: str = DEFAULT_TEXT_COLUMN,
audio_column: str = DEFAULT_AUDIO_COLUMN,
ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN,
dataset_id_column: str = DEFAULT_ID_COLUMN,
sample_rate: int = 16_000,
num_proc: int = 1,
@@ -34,14 +36,19 @@ def load_audio_text_datasets(
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
if audio_column != DEFAULT_AUDIO_COLUMN:
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
if text_column != DEFAULT_TEXT_COLUMN:
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
# ref_audio is optional — cast to Audio if the column exists
ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN
if ref_col in ds.column_names:
ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate))
if ref_col != DEFAULT_REF_AUDIO_COLUMN:
ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN)
if dataset_id_column and dataset_id_column in ds.column_names:
if dataset_id_column != DEFAULT_ID_COLUMN:
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
@@ -67,11 +74,11 @@ def compute_sample_lengths(
- Audio length:
duration(s) * audio_vae_fps -> approximate number of VAE frames t_vae
t_seq = ceil(t_vae / patch_size)
- Total sequence length is approximately: text_len + t_seq + 2
- Without ref_audio: text_len + t_seq + 2
- With ref_audio: text_len + t_seq + ref_seq + 4
Optimized: Use batch column access instead of iterating item by item.
"""
# Batch access columns - much faster than per-item access
text_ids_list = ds["text_ids"]
text_lens = [len(t) for t in text_ids_list]
@@ -79,18 +86,35 @@ def compute_sample_lengths(
if has_duration:
durations = ds["duration"]
else:
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
durations = []
for i in range(len(ds)):
audio = ds[i][DEFAULT_AUDIO_COLUMN]
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
# Vectorized length computation
has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names
if has_ref_audio:
ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None
lengths = []
for text_len, duration in zip(text_lens, durations):
for i, (text_len, duration) in enumerate(zip(text_lens, durations)):
t_vae = math.ceil(float(duration) * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2
ref_seq = 0
if has_ref_audio:
# Estimate ref_audio length; ref_audio is None for samples without it
if ref_duration_col:
ref_dur = ds[i].get(ref_duration_col)
else:
ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN)
ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None
if ref_dur is not None and float(ref_dur) > 0:
ref_vae = math.ceil(float(ref_dur) * audio_vae_fps)
ref_seq = math.ceil(ref_vae / patch_size)
# +2 for 101/102; +2 more for 103/104 when ref_audio present
overhead = 4 if ref_seq > 0 else 2
total_len = text_len + t_seq + ref_seq + overhead
lengths.append(total_len)
return lengths
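Review note: the per-sample length estimate in `compute_sample_lengths` can be checked by hand. The sketch below restates the formula for a single sample as a standalone function; the function name and the numeric inputs (25 frames per second, patch size 4) are illustrative assumptions, not values from the repository's config.

```python
import math
from typing import Optional


def estimate_sequence_length(
    text_len: int,
    duration_s: float,
    audio_vae_fps: float,
    patch_size: int,
    ref_duration_s: Optional[float] = None,
) -> int:
    """Approximate packed sequence length for one training sample."""
    t_vae = math.ceil(float(duration_s) * audio_vae_fps)  # VAE frames
    t_seq = math.ceil(t_vae / patch_size)                 # audio patches
    ref_seq = 0
    if ref_duration_s is not None and float(ref_duration_s) > 0:
        ref_vae = math.ceil(float(ref_duration_s) * audio_vae_fps)
        ref_seq = math.ceil(ref_vae / patch_size)
    # +2 for tokens 101/102; +2 more for 103/104 when ref_audio is present
    overhead = 4 if ref_seq > 0 else 2
    return text_len + t_seq + ref_seq + overhead


without_ref = estimate_sequence_length(20, 3.0, 25, 4)
with_ref = estimate_sequence_length(20, 3.0, 25, 4, ref_duration_s=2.0)
```

With these inputs the target audio contributes 19 patches and the reference audio 13, so the two estimates differ by 13 patches plus the two extra marker tokens.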
@@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset):
PyTorch-friendly samples.
"""
_SENTINEL = [-100.0]
def __init__(self, dataset: Dataset):
self.dataset = dataset
self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names
def __len__(self):
return len(self.dataset)
@@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset):
def __getitem__(self, idx: int):
item = self.dataset[idx]
audio = item[DEFAULT_AUDIO_COLUMN]
return {
sample = {
"text_ids": item["text_ids"],
"audio_array": audio["array"],
"audio_sampling_rate": audio["sampling_rate"],
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
"is_prompt": item.get("is_prompt", False),
}
if self.has_ref_audio:
ref = item.get(DEFAULT_REF_AUDIO_COLUMN)
sample["ref_audio_array"] = ref["array"] if ref else self._SENTINEL
return sample
@staticmethod
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
@@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset):
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
return {
result = {
"text_tokens": text_padded,
"audio_tokens": audio_padded,
"task_ids": task_ids,
@@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset):
"is_prompts": is_prompts,
}
if "ref_audio_array" in batch[0]:
ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch]
result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0)
return result
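Review note: samples without reference audio are represented by `_SENTINEL = [-100.0]`, which is the same value used for padding, so after collation such a row consists of padding only. The sketch below shows why that lets the packer detect "no reference" per sample. It assumes that `unpad_audio_tokens` (not shown in this diff) strips entries equal to the pad value; `unpad` here is a hypothetical stand-in.

```python
from typing import List

PAD_VALUE = -100.0         # same value the diff passes as pad_value
MISSING_REF = [PAD_VALUE]  # mirrors HFVoxCPMDataset._SENTINEL


def pad_sequences(seqs: List[List[float]], pad_value: float) -> List[List[float]]:
    """Right-pad every sequence to the longest one in the batch."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_value] * (max_len - len(s)) for s in seqs]


def unpad(seq: List[float], pad_value: float = PAD_VALUE) -> List[float]:
    """Hypothetical stand-in for unpad_audio_tokens: drop padding entries."""
    return [v for v in seq if v != pad_value]


# One sample with reference audio, one without.
batch_refs = [[0.1, -0.2, 0.3], MISSING_REF]
padded = pad_sequences(batch_refs, PAD_VALUE)
has_ref = [len(unpad(row)) > 0 for row in padded]
```

This matches the `unpad_ref_token.numel() > 0` check in the packer hunk, which routes only samples with real reference audio through `process_tts_data_with_ref`.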
class BatchProcessor:
"""
@@ -184,12 +221,17 @@ class BatchProcessor:
task_ids = batch["task_ids"].to(self.device)
dataset_ids = batch["dataset_ids"].to(self.device)
ref_audio_tokens = None
if "ref_audio_tokens" in batch:
ref_audio_tokens = batch["ref_audio_tokens"].to(self.device)
packed = self.packer(
audio_tokens=audio_tokens,
text_tokens=text_tokens,
task_ids=task_ids,
dataset_ids=dataset_ids,
is_prompts=batch["is_prompts"],
ref_audio_tokens=ref_audio_tokens,
)
return packed
+114 -4
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Optional
import torch
import torch.nn as nn
@@ -14,7 +14,6 @@ class AudioFeatureProcessingPacker:
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101
self.audio_end_id = 102
# unused now
self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104
self.text_eos_token_id = 2
@@ -78,11 +77,16 @@ class AudioFeatureProcessingPacker:
task_ids: torch.Tensor,
dataset_ids: torch.Tensor,
is_prompts: List[bool],
ref_audio_tokens: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""
Padding-based batching: each sample in the input batch is processed
independently and then padded to a common length (capped by ``max_len``).
The result tensors all have shape [B, T, ...].
If ``ref_audio_tokens`` is provided (same batch dim as ``audio_tokens``),
samples whose unpadded ref_audio length > 0 will be processed with the
reference-audio path (tokens 103/104 prepended, loss only on target audio).
"""
device = audio_tokens.device
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
@@ -101,13 +105,33 @@ class AudioFeatureProcessingPacker:
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
ref_iter = ref_audio_tokens if ref_audio_tokens is not None else [None] * audio_tokens.size(0)
for audio_token, text_token, task_id, dataset_idx, is_prompt, ref_token in zip(
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts, ref_iter
):
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
unpad_text_token = self.unpad_text_tokens(text_token)
usage = self.id_to_task[task_id]
has_ref = False
if ref_token is not None:
unpad_ref_token = self.unpad_audio_tokens(ref_token).to(torch.float32)
if unpad_ref_token.numel() > 0:
has_ref = True
if has_ref:
(
packed_text,
audio_feat,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
) = self.process_tts_data_with_ref(unpad_ref_token, unpad_audio_token, unpad_text_token)
else:
(
packed_text,
audio_feat,
@@ -294,3 +318,89 @@ class AudioFeatureProcessingPacker:
audio_duration,
text_token_count,
)
def process_tts_data_with_ref(
self,
ref_audio_token: torch.Tensor,
target_audio_token: torch.Tensor,
text_token: torch.Tensor,
):
"""
Build a training sequence with reference audio prepended:
[103, ref_feats, 104, text, 101, target_feats, 102]
Loss is computed only on the target audio segment.
"""
device = text_token.device
txt_len = len(text_token)
ref_feats, ref_duration = self.extract_audio_feats(ref_audio_token)
ref_feats = ref_feats.squeeze(0) # [R, P, D]
ref_len = ref_feats.shape[0]
tgt_feats, tgt_duration = self.extract_audio_feats(target_audio_token)
tgt_feats = tgt_feats.squeeze(0) # [A, P, D]
tgt_len = tgt_feats.shape[0]
feat_shape = (self.patch_size, ref_feats.size(-1))
def _tok(ids):
return torch.tensor(ids, dtype=torch.int32, device=device)
# -- text token track --
# [103, 0×R, 104, text_ids, 101, 0×A, 102]
text_token_info = torch.cat([
_tok([self.audio_prompt_start_id]),
torch.zeros(ref_len, dtype=torch.int32, device=device),
_tok([self.audio_prompt_end_id]),
text_token,
_tok([self.audio_start_id]),
torch.zeros(tgt_len, dtype=torch.int32, device=device),
_tok([self.audio_end_id]),
])
# -- audio feature track --
zero_1 = torch.zeros((1,) + feat_shape, dtype=torch.float32, device=device)
zero_txt = torch.zeros((txt_len,) + feat_shape, dtype=torch.float32, device=device)
audio_feat_info = torch.cat([
zero_1, ref_feats, zero_1, # 103, ref, 104
zero_txt, # text
zero_1, tgt_feats, zero_1, # 101, target, 102
], dim=0)
# -- masks --
text_mask = torch.cat([
torch.ones(1), torch.zeros(ref_len), torch.ones(1),
torch.ones(txt_len),
torch.ones(1), torch.zeros(tgt_len), torch.ones(1),
]).to(torch.int32).to(device)
audio_mask = torch.cat([
torch.zeros(1), torch.ones(ref_len), torch.zeros(1),
torch.zeros(txt_len),
torch.zeros(1), torch.ones(tgt_len), torch.zeros(1),
]).to(torch.int32).to(device)
loss_mask = torch.cat([
torch.zeros(1 + ref_len + 1), # ref part: no loss
torch.zeros(txt_len), # text: no loss
torch.zeros(1), # 101: no loss
torch.ones(tgt_len), # target audio: LOSS
torch.zeros(1), # 102: no loss
]).to(torch.int32).to(device)
total_len = 1 + ref_len + 1 + txt_len + 1 + tgt_len + 1
labels = torch.zeros(total_len, dtype=torch.int32, device=device)
labels[-2] = 1 # stop label at last target audio position
return (
text_token_info,
audio_feat_info,
text_mask,
audio_mask,
loss_mask,
labels,
ref_duration + tgt_duration,
txt_len,
)
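Review note: the layout built by `process_tts_data_with_ref` is easier to verify on a tiny example. The sketch below rebuilds the token track, masks and stop label with plain Python lists for a given reference length, text and target length; `build_ref_layout` is an illustrative helper, and the text ids are made up.

```python
from typing import Dict, List


def build_ref_layout(ref_len: int, text_ids: List[int], tgt_len: int) -> Dict[str, List[int]]:
    """Layout: [103, ref x R, 104, text, 101, target x A, 102]."""
    txt_len = len(text_ids)
    token_track = (
        [103] + [0] * ref_len + [104] + list(text_ids) + [101] + [0] * tgt_len + [102]
    )
    # 1 where the position carries a text/marker token, 0 where it carries audio.
    text_mask = [1] + [0] * ref_len + [1] + [1] * txt_len + [1] + [0] * tgt_len + [1]
    audio_mask = [1 - m for m in text_mask]
    # Loss only on the target audio segment.
    loss_mask = [0] * (1 + ref_len + 1) + [0] * txt_len + [0] + [1] * tgt_len + [0]
    labels = [0] * len(token_track)
    labels[-2] = 1  # stop label at the last target audio position
    return {
        "tokens": token_track,
        "text_mask": text_mask,
        "audio_mask": audio_mask,
        "loss_mask": loss_mask,
        "labels": labels,
    }


layout = build_ref_layout(ref_len=2, text_ids=[11, 12, 13], tgt_len=3)
```

In this example the sequence has 12 positions, the loss mask selects exactly the 3 target patches, and the stop label sits on position 10, the last target patch before token 102.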
+309
@@ -0,0 +1,309 @@
"""
Pre-flight validation for VoxCPM training data manifests.
Validates JSONL manifest files before starting expensive fine-tuning jobs,
catching format issues, missing files, and data quality problems early.
"""
import json
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
@dataclass
class ValidationResult:
"""Structured result of a manifest validation run."""
total_samples: int = 0
valid_samples: int = 0
errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
audio_durations: List[float] = field(default_factory=list)
text_lengths: List[int] = field(default_factory=list)
has_ref_audio: int = 0
@property
def is_valid(self) -> bool:
return len(self.errors) == 0 and self.valid_samples > 0
def _check_audio_file(audio_path: str, sample_rate: int) -> Optional[str]:
"""Check if an audio file exists, is readable, and matches expected sample rate.
Returns an error message, or None if the file is valid.
"""
if not os.path.isfile(audio_path):
return f"Audio file not found: {audio_path}"
try:
import soundfile as sf
info = sf.info(audio_path)
if info.frames == 0:
return f"Audio file is empty: {audio_path}"
if info.samplerate != sample_rate:
return (
f"Sample rate mismatch in {audio_path}: "
f"expected {sample_rate} Hz, got {info.samplerate} Hz"
)
return None
except ImportError:
# soundfile not available; just check existence
return None
except Exception as e:
return f"Cannot read audio file {audio_path}: {e}"
def _get_audio_duration(audio_path: str) -> Optional[float]:
"""Get audio duration in seconds. Returns None if unavailable."""
try:
import soundfile as sf
info = sf.info(audio_path)
return info.duration
except Exception:
return None
def validate_manifest(
manifest_path: str,
sample_rate: int = 16_000,
max_samples: int = 0,
verbose: bool = False,
) -> ValidationResult:
"""Validate a JSONL training manifest file.
Checks:
1. File exists and is readable
2. Each line is valid JSON
3. Required columns present (text, audio)
4. Audio files exist, are readable, and match the expected sample rate
5. Text content is non-empty
6. Collects duration and text length statistics
7. Validates optional ref_audio column
Args:
manifest_path: Path to the JSONL manifest file.
sample_rate: Expected audio sample rate in Hz. Files with a different rate are reported as errors (checked only when soundfile is installed).
max_samples: Maximum number of samples to validate (0 = all).
verbose: Print per-sample progress.
Returns:
ValidationResult with errors, warnings, and statistics.
"""
result = ValidationResult()
path = Path(manifest_path)
if not path.exists():
result.errors.append(f"Manifest file not found: {manifest_path}")
return result
if not path.is_file():
result.errors.append(f"Manifest path is not a file: {manifest_path}")
return result
manifest_dir = path.parent
try:
with open(path, "r", encoding="utf-8") as f:
lines = f.readlines()
except Exception as e:
result.errors.append(f"Cannot read manifest file: {e}")
return result
if not lines:
result.errors.append("Manifest file is empty")
return result
samples_to_check = len(lines)
if max_samples > 0:
samples_to_check = min(samples_to_check, max_samples)
missing_audio_count = 0
empty_text_count = 0
for i, line in enumerate(lines[:samples_to_check]):
line = line.strip()
if not line:
continue
result.total_samples += 1
# Check JSON validity
try:
entry = json.loads(line)
except json.JSONDecodeError as e:
result.errors.append(f"Line {i + 1}: Invalid JSON — {e}")
continue
if not isinstance(entry, dict):
result.errors.append(f"Line {i + 1}: Expected JSON object, got {type(entry).__name__}")
continue
# Check required columns
has_error = False
if "text" not in entry:
result.errors.append(f"Line {i + 1}: Missing required column 'text'")
has_error = True
if "audio" not in entry:
result.errors.append(f"Line {i + 1}: Missing required column 'audio'")
has_error = True
if has_error:
continue
# Validate text
text = entry["text"]
if not isinstance(text, str) or not text.strip():
empty_text_count += 1
if empty_text_count <= 5:
result.warnings.append(f"Line {i + 1}: Empty or non-string text")
else:
result.text_lengths.append(len(text))
# Validate audio path
audio_path = entry["audio"]
if isinstance(audio_path, dict):
# HuggingFace Audio format with {"path": ..., "array": ...}
audio_path = audio_path.get("path", "")
if isinstance(audio_path, str) and audio_path:
# Resolve relative paths against manifest directory
if not os.path.isabs(audio_path):
audio_path = str(manifest_dir / audio_path)
audio_error = _check_audio_file(audio_path, sample_rate)
if audio_error:
missing_audio_count += 1
if missing_audio_count <= 5:
result.errors.append(f"Line {i + 1}: {audio_error}")
has_error = True
else:
duration = _get_audio_duration(audio_path)
if duration is not None:
result.audio_durations.append(duration)
if duration < 0.3:
result.warnings.append(
f"Line {i + 1}: Very short audio ({duration:.2f}s)"
)
elif duration > 30.0:
result.warnings.append(
f"Line {i + 1}: Very long audio ({duration:.1f}s), may cause OOM"
)
else:
result.errors.append(f"Line {i + 1}: Invalid audio path")
has_error = True
# Validate optional ref_audio
if "ref_audio" in entry:
ref_path = entry["ref_audio"]
if isinstance(ref_path, dict):
ref_path = ref_path.get("path", "")
if isinstance(ref_path, str) and ref_path:
if not os.path.isabs(ref_path):
ref_path = str(manifest_dir / ref_path)
if os.path.isfile(ref_path):
result.has_ref_audio += 1
else:
result.warnings.append(
f"Line {i + 1}: ref_audio file not found: {ref_path}"
)
if not has_error:
result.valid_samples += 1
if verbose and (i + 1) % 100 == 0:
print(f" Validated {i + 1}/{samples_to_check} samples...", file=sys.stderr)
# Summarize truncated errors
if missing_audio_count > 5:
result.errors.append(
f"... and {missing_audio_count - 5} more audio file errors "
f"({missing_audio_count} total)"
)
if empty_text_count > 5:
result.warnings.append(
f"... and {empty_text_count - 5} more empty text entries "
f"({empty_text_count} total)"
)
return result
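The way `validate_manifest` normalises the `audio` and `ref_audio` fields can be isolated into a few lines. The sketch below is illustrative only: `resolve_audio_field` is a hypothetical helper name, not a function from this module, and it mirrors the rules visible above (dict entries carry the path under `"path"`, relative paths resolve against the manifest directory).

```python
from pathlib import Path
from typing import Any, Optional


def resolve_audio_field(value: Any, manifest_dir: Path) -> Optional[str]:
    """Hypothetical helper: return a usable path string, or None if invalid."""
    if isinstance(value, dict):
        # HuggingFace-style entry: {"path": ..., "array": ...}
        value = value.get("path", "")
    if not isinstance(value, str) or not value:
        return None
    path = Path(value)
    # Relative paths are interpreted relative to the manifest's directory.
    return str(path if path.is_absolute() else manifest_dir / path)


manifest_dir = Path.cwd() / "data"
print(resolve_audio_field("clips/a.wav", manifest_dir))
print(resolve_audio_field({"path": "clips/b.wav"}, manifest_dir))
print(resolve_audio_field(42, manifest_dir))
```

Returning `None` for anything that is not a non-empty string corresponds to the "Invalid audio path" branch in the validator.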
def print_validation_report(result: ValidationResult, manifest_path: str) -> None:
"""Print a human-readable validation report to stderr."""
print(f"\n{'=' * 60}", file=sys.stderr)
print(" VoxCPM Training Data Validation Report", file=sys.stderr)
print(f"{'=' * 60}", file=sys.stderr)
print(f" Manifest : {manifest_path}", file=sys.stderr)
print(f" Samples : {result.valid_samples}/{result.total_samples} valid", file=sys.stderr)
if result.has_ref_audio > 0:
print(
f" Ref Audio: {result.has_ref_audio} samples with reference audio",
file=sys.stderr,
)
# Audio duration statistics
if result.audio_durations:
durations = sorted(result.audio_durations)
total_hrs = sum(durations) / 3600
print("\n Audio Duration Statistics:", file=sys.stderr)
print(f" Total : {total_hrs:.2f} hours", file=sys.stderr)
print(
f" Range : {durations[0]:.2f}s — {durations[-1]:.1f}s",
file=sys.stderr,
)
print(
f" Mean : {sum(durations) / len(durations):.2f}s",
file=sys.stderr,
)
median_idx = len(durations) // 2
print(f" Median : {durations[median_idx]:.2f}s", file=sys.stderr)
# Text length statistics
if result.text_lengths:
lengths = sorted(result.text_lengths)
print("\n Text Length Statistics (characters):", file=sys.stderr)
print(
f" Range : {lengths[0]} - {lengths[-1]}",
file=sys.stderr,
)
print(
f" Mean : {sum(lengths) / len(lengths):.0f}",
file=sys.stderr,
)
# Errors
if result.errors:
print(f"\n ERRORS ({len(result.errors)}):", file=sys.stderr)
for err in result.errors[:20]:
print(f" x {err}", file=sys.stderr)
if len(result.errors) > 20:
print(
f" ... ({len(result.errors) - 20} more errors omitted)",
file=sys.stderr,
)
# Warnings
if result.warnings:
print(f"\n WARNINGS ({len(result.warnings)}):", file=sys.stderr)
for warn in result.warnings[:10]:
print(f" ! {warn}", file=sys.stderr)
if len(result.warnings) > 10:
print(
f" ... ({len(result.warnings) - 10} more warnings omitted)",
file=sys.stderr,
)
# Summary
print(f"\n{'=' * 60}", file=sys.stderr)
if result.is_valid:
print(" PASSED: Manifest is valid for training.", file=sys.stderr)
else:
print(" FAILED: Fix errors above before starting training.", file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
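One detail of the report worth knowing: the "Median" line picks the element at index `len(durations) // 2`, which for an even number of samples is the upper of the two middle values rather than their mean. The comparison below is a small standalone sketch; `upper_median` is a hypothetical name used only for illustration.

```python
import statistics
from typing import List


def upper_median(values: List[float]) -> float:
    """Middle element by index, as the report does (upper middle for even sizes)."""
    ordered = sorted(values)
    return ordered[len(ordered) // 2]


durations = [1.0, 2.0, 3.0, 10.0]
print(upper_median(durations))       # index-based pick
print(statistics.median(durations))  # mean of the two middle values
```

For odd-sized lists the two agree, so the difference only shows up on small, even-sized manifests.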
+31
@@ -58,6 +58,7 @@ def test_parser_defaults_to_voxcpm2():
parser = cli._build_parser()
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
assert args.hf_model_id == "openbmb/VoxCPM2"
assert args.device == "auto"
assert args.no_optimize is False
@@ -85,6 +86,7 @@ def test_load_model_respects_no_optimize_for_local_model(monkeypatch):
cli.load_model(args)
assert calls["kwargs"]["device"] == "auto"
assert calls["kwargs"]["optimize"] is False
@@ -110,6 +112,7 @@ def test_load_model_defaults_optimize_for_hf(monkeypatch):
cli.load_model(args)
assert calls["kwargs"]["device"] == "auto"
assert calls["kwargs"]["optimize"] is True
@@ -136,9 +139,37 @@ def test_load_model_respects_no_optimize_for_hf(monkeypatch):
cli.load_model(args)
assert calls["kwargs"]["device"] == "auto"
assert calls["kwargs"]["optimize"] is False
def test_load_model_passes_explicit_device_to_hf(monkeypatch):
calls = {}
class FakeVoxCPM:
@classmethod
def from_pretrained(cls, **kwargs):
calls["kwargs"] = kwargs
return DummyModel()
monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)
args = cli._build_parser().parse_args(
[
"design",
"--text",
"hello",
"--output",
"out.wav",
"--device",
"mps",
]
)
cli.load_model(args)
assert calls["kwargs"]["device"] == "mps"
def test_design_subcommand_applies_control(monkeypatch, tmp_path):
dummy_model = DummyModel()
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
+150
@@ -0,0 +1,150 @@
from __future__ import annotations
import importlib.util
import sys
import types
from pathlib import Path
import pytest
import torch
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
def _load_module(name: str, path: Path):
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
sys.modules[name] = module
spec.loader.exec_module(module)
return module
def bootstrap_repo_modules(monkeypatch):
for name, path in [
("voxcpm", SRC / "voxcpm"),
("voxcpm.model", SRC / "voxcpm" / "model"),
("voxcpm.modules", SRC / "voxcpm" / "modules"),
]:
pkg = types.ModuleType(name)
pkg.__path__ = [str(path)]
monkeypatch.setitem(sys.modules, name, pkg)
hh = types.ModuleType("huggingface_hub")
hh.snapshot_download = lambda *a, **k: "/tmp/fake"
monkeypatch.setitem(sys.modules, "huggingface_hub", hh)
pydantic = types.ModuleType("pydantic")
class BaseModel:
@classmethod
def model_rebuild(cls):
return None
@classmethod
def model_validate_json(cls, s):
return cls()
def model_dump(self):
return {}
pydantic.BaseModel = BaseModel
monkeypatch.setitem(sys.modules, "pydantic", pydantic)
torchaudio = types.ModuleType("torchaudio")
monkeypatch.setitem(sys.modules, "torchaudio", torchaudio)
librosa = types.ModuleType("librosa")
librosa.effects = types.SimpleNamespace(trim=lambda *a, **k: (None, (0, 0)))
monkeypatch.setitem(sys.modules, "librosa", librosa)
einops = types.ModuleType("einops")
einops.rearrange = lambda x, *a, **k: x
monkeypatch.setitem(sys.modules, "einops", einops)
tqdm_pkg = types.ModuleType("tqdm")
tqdm_pkg.__path__ = ["/nonexistent"]
tqdm_pkg.tqdm = lambda x, *a, **k: x
monkeypatch.setitem(sys.modules, "tqdm", tqdm_pkg)
tqdm_auto = types.ModuleType("tqdm.auto")
tqdm_auto.tqdm = lambda x, *a, **k: x
monkeypatch.setitem(sys.modules, "tqdm.auto", tqdm_auto)
transformers = types.ModuleType("transformers")
class LlamaTokenizerFast:
pass
class PreTrainedTokenizer:
pass
transformers.LlamaTokenizerFast = LlamaTokenizerFast
transformers.PreTrainedTokenizer = PreTrainedTokenizer
monkeypatch.setitem(sys.modules, "transformers", transformers)
internal_mods = {
"voxcpm.modules.audiovae": ["AudioVAE", "AudioVAEConfig", "AudioVAEV2", "AudioVAEConfigV2"],
"voxcpm.modules.layers": ["ScalarQuantizationLayer"],
"voxcpm.modules.locdit": ["CfmConfig", "UnifiedCFM", "VoxCPMLocDiT", "VoxCPMLocDiTV2"],
"voxcpm.modules.locenc": ["VoxCPMLocEnc"],
"voxcpm.modules.minicpm4": ["MiniCPM4Config", "MiniCPMModel"],
"voxcpm.modules.layers.lora": ["apply_lora_to_named_linear_modules", "LoRALinear"],
}
for modname, names in internal_mods.items():
module = types.ModuleType(modname)
for name in names:
if name == "apply_lora_to_named_linear_modules":
setattr(module, name, lambda *a, **k: None)
else:
setattr(module, name, type(name, (), {}))
monkeypatch.setitem(sys.modules, modname, module)
_load_module("voxcpm.model.utils", SRC / "voxcpm" / "model" / "utils.py")
voxcpm = _load_module("voxcpm.model.voxcpm", SRC / "voxcpm" / "model" / "voxcpm.py")
voxcpm2 = _load_module("voxcpm.model.voxcpm2", SRC / "voxcpm" / "model" / "voxcpm2.py")
return voxcpm.VoxCPMModel, voxcpm2.VoxCPM2Model
class DummyModel:
device = "cpu"
def named_parameters(self):
return []
@pytest.mark.parametrize("module_name", ["v1", "v2"])
def test_load_lora_weights_accepts_tensor_only_legacy_checkpoints(monkeypatch, tmp_path, module_name):
VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch)
cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model
ckpt_path = tmp_path / "lora_weights.ckpt"
torch.save({"state_dict": {"fake": torch.zeros(1)}}, ckpt_path)
loaded, skipped = cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")
assert loaded == []
assert skipped == ["fake"]
@pytest.mark.parametrize("module_name", ["v1", "v2"])
def test_load_lora_weights_rejects_malicious_pickle_payloads(monkeypatch, tmp_path, module_name):
VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch)
cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model
ckpt_path = tmp_path / "lora_weights.ckpt"
marker_path = tmp_path / f"{module_name}-marker.txt"
class Exploit:
def __reduce__(self):
import pathlib
return (pathlib.Path.write_text, (marker_path, f"{module_name} executed\n"))
torch.save({"state_dict": {"fake": torch.zeros(1)}, "boom": Exploit()}, ckpt_path)
with pytest.raises(Exception, match="Weights only load failed"):
cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")
assert not marker_path.exists()
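The second test guards against a general pickle hazard: an object's `__reduce__` can name any callable to run at load time. The sketch below reproduces that idea with the standard library only. It is not how `torch.load` works internally; `RestrictedUnpickler` follows the "restricting globals" pattern from the Python `pickle` documentation and, as an assumption made for brevity, refuses every global lookup.

```python
import io
import pickle


class RestrictedUnpickler(pickle.Unpickler):
    """Illustrative unpickler that refuses to resolve any global name."""

    def find_class(self, module, name):
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Load pickled data made of plain containers and scalars only."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


class Exploit:
    """Stand-in payload: a normal load would call print() while unpickling."""

    def __reduce__(self):
        return (print, ("payload executed",))


# Plain data needs no global lookups, so it loads normally.
print(restricted_loads(pickle.dumps({"state_dict": {"fake": [0.0]}})))

# The payload references builtins.print, which the restricted loader rejects.
try:
    restricted_loads(pickle.dumps(Exploit()))
except pickle.UnpicklingError as exc:
    print(f"rejected: {exc}")
```

The test above asserts the same two properties for the real loader: tensor-only checkpoints still load, and a payload's side effect (the marker file) never happens.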
+49
@@ -0,0 +1,49 @@
from __future__ import annotations
import importlib.util
import sys
import types
from pathlib import Path
import pytest
ROOT = Path(__file__).resolve().parents[1]
UTILS_PATH = ROOT / "src" / "voxcpm" / "model" / "utils.py"
transformers_stub = types.ModuleType("transformers")
transformers_stub.PreTrainedTokenizer = object
sys.modules.setdefault("transformers", transformers_stub)
spec = importlib.util.spec_from_file_location("voxcpm.model.utils", UTILS_PATH)
utils = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(utils)
def test_resolve_runtime_device_auto_falls_back_to_cpu(monkeypatch):
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
monkeypatch.setattr(utils, "_has_mps", lambda: False)
assert utils.resolve_runtime_device(None, "cuda") == "cpu"
def test_resolve_runtime_device_auto_uses_mps_when_available(monkeypatch):
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
monkeypatch.setattr(utils, "_has_mps", lambda: True)
assert utils.resolve_runtime_device("auto", "cuda") == "mps"
def test_resolve_runtime_device_respects_explicit_cpu(monkeypatch):
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: True)
monkeypatch.setattr(utils, "_has_mps", lambda: True)
assert utils.resolve_runtime_device("cpu", "cuda") == "cpu"
def test_resolve_runtime_device_rejects_unavailable_explicit_cuda(monkeypatch):
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
monkeypatch.setattr(utils, "_has_mps", lambda: True)
with pytest.raises(ValueError, match="CUDA is not available"):
utils.resolve_runtime_device("cuda:0", "cuda")
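Read together, these four tests describe a small decision table. The sketch below is one hypothetical resolver that satisfies it, not the code in `utils.py`: `resolve_device` takes availability flags as arguments so it runs without torch, and the CUDA-before-MPS preference and the MPS error message are assumptions the tests do not confirm.

```python
from typing import Optional


def resolve_device(requested: Optional[str], cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical resolver; availability is injected instead of queried from torch."""
    if requested is None or requested == "auto":
        # Assumed preference order: CUDA, then MPS, then CPU.
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"
    if requested.startswith("cuda") and not cuda_available:
        raise ValueError("CUDA is not available")
    if requested == "mps" and not mps_available:
        # Assumed behaviour; the tests above do not cover this case.
        raise ValueError("MPS is not available")
    return requested


print(resolve_device(None, cuda_available=False, mps_available=False))
print(resolve_device("auto", cuda_available=False, mps_available=True))
print(resolve_device("cpu", cuda_available=True, mps_available=True))
```

Injecting the flags keeps the logic testable on any machine, which is the same effect the tests achieve by monkeypatching `torch.cuda.is_available` and `_has_mps`.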
+252
@@ -0,0 +1,252 @@
"""Tests for the training data validation module."""
from __future__ import annotations
import json
import os
import sys
import tempfile
import types
from pathlib import Path
import pytest
ROOT = Path(__file__).resolve().parents[1]
# Stub voxcpm package so imports work without full dependencies
pkg = types.ModuleType("voxcpm")
pkg.__path__ = [str(ROOT / "src" / "voxcpm")]
sys.modules.setdefault("voxcpm", pkg)
training_pkg = types.ModuleType("voxcpm.training")
training_pkg.__path__ = [str(ROOT / "src" / "voxcpm" / "training")]
sys.modules.setdefault("voxcpm.training", training_pkg)
from voxcpm.training.validate import ValidationResult, validate_manifest
@pytest.fixture
def tmp_dir():
with tempfile.TemporaryDirectory() as d:
yield Path(d)
def _create_wav(path: Path, duration_s: float = 1.0, sr: int = 16000):
"""Create a minimal valid WAV file."""
try:
import soundfile as sf
import numpy as np
samples = int(duration_s * sr)
data = np.zeros(samples, dtype=np.float32)
sf.write(str(path), data, sr)
except ImportError:
# If soundfile is not available, create a minimal WAV header
import struct
samples = int(duration_s * sr)
data_size = samples * 2 # 16-bit PCM
with open(path, "wb") as f:
f.write(b"RIFF")
f.write(struct.pack("<I", 36 + data_size))
f.write(b"WAVEfmt ")
f.write(struct.pack("<IHHIIHH", 16, 1, 1, sr, sr * 2, 2, 16))
f.write(b"data")
f.write(struct.pack("<I", data_size))
f.write(b"\x00" * data_size)
def _write_manifest(path: Path, entries: list[dict]):
with open(path, "w", encoding="utf-8") as f:
for entry in entries:
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
class TestValidateManifest:
def test_valid_manifest(self, tmp_dir):
audio1 = tmp_dir / "audio1.wav"
audio2 = tmp_dir / "audio2.wav"
_create_wav(audio1, 2.0)
_create_wav(audio2, 3.0)
manifest = tmp_dir / "train.jsonl"
_write_manifest(
manifest,
[
{"text": "Hello world", "audio": str(audio1)},
{"text": "Goodbye world", "audio": str(audio2)},
],
)
result = validate_manifest(str(manifest))
assert result.total_samples == 2
assert result.valid_samples == 2
assert result.is_valid
assert len(result.errors) == 0
def test_missing_manifest(self):
result = validate_manifest("/nonexistent/path.jsonl")
assert not result.is_valid
assert any("not found" in e for e in result.errors)
def test_empty_manifest(self, tmp_dir):
manifest = tmp_dir / "empty.jsonl"
manifest.write_text("")
result = validate_manifest(str(manifest))
assert not result.is_valid
def test_invalid_json(self, tmp_dir):
manifest = tmp_dir / "bad.jsonl"
manifest.write_text("not json\n{bad json}\n")
result = validate_manifest(str(manifest))
assert len(result.errors) >= 2
assert any("Invalid JSON" in e for e in result.errors)
def test_missing_columns(self, tmp_dir):
manifest = tmp_dir / "missing.jsonl"
_write_manifest(
manifest,
[
{"text": "hello"}, # missing audio
{"audio": "test.wav"}, # missing text
],
)
result = validate_manifest(str(manifest))
assert len(result.errors) >= 2
assert any("'audio'" in e for e in result.errors)
assert any("'text'" in e for e in result.errors)
def test_missing_audio_file(self, tmp_dir):
manifest = tmp_dir / "missing_audio.jsonl"
_write_manifest(
manifest,
[{"text": "hello", "audio": "/nonexistent/audio.wav"}],
)
result = validate_manifest(str(manifest))
assert not result.is_valid
assert any("not found" in e for e in result.errors)
def test_empty_text_warning(self, tmp_dir):
audio = tmp_dir / "audio.wav"
_create_wav(audio)
manifest = tmp_dir / "empty_text.jsonl"
_write_manifest(
manifest,
[{"text": "", "audio": str(audio)}],
)
result = validate_manifest(str(manifest))
assert len(result.warnings) > 0
assert any("Empty" in w for w in result.warnings)
def test_relative_audio_path(self, tmp_dir):
audio = tmp_dir / "audio.wav"
_create_wav(audio)
manifest = tmp_dir / "rel.jsonl"
_write_manifest(
manifest,
[{"text": "hello", "audio": "audio.wav"}],
)
result = validate_manifest(str(manifest))
assert result.valid_samples == 1
assert result.is_valid
def test_max_samples_limit(self, tmp_dir):
audio = tmp_dir / "audio.wav"
_create_wav(audio)
manifest = tmp_dir / "many.jsonl"
_write_manifest(
manifest,
[{"text": f"sample {i}", "audio": str(audio)} for i in range(100)],
)
result = validate_manifest(str(manifest), max_samples=10)
assert result.total_samples == 10
def test_ref_audio_counted(self, tmp_dir):
audio = tmp_dir / "audio.wav"
ref = tmp_dir / "ref.wav"
_create_wav(audio)
_create_wav(ref)
manifest = tmp_dir / "ref.jsonl"
_write_manifest(
manifest,
[{"text": "hello", "audio": str(audio), "ref_audio": str(ref)}],
)
result = validate_manifest(str(manifest))
assert result.has_ref_audio == 1
def test_validation_result_properties(self):
r = ValidationResult(total_samples=5, valid_samples=5)
assert r.is_valid
r2 = ValidationResult(total_samples=5, valid_samples=5, errors=["err"])
assert not r2.is_valid
r3 = ValidationResult(total_samples=0, valid_samples=0)
assert not r3.is_valid
def test_invalid_audio_not_counted_as_valid(self, tmp_dir):
"""A row with a bad audio path must not increment valid_samples."""
manifest = tmp_dir / "bad_audio.jsonl"
_write_manifest(
manifest,
[{"text": "hello", "audio": "/nonexistent/audio.wav"}],
)
result = validate_manifest(str(manifest))
assert result.total_samples == 1
assert result.valid_samples == 0
assert not result.is_valid
assert any("not found" in e for e in result.errors)
def test_sample_rate_mismatch(self, tmp_dir):
"""A file with a different sample rate should be reported as an error."""
sf = pytest.importorskip("soundfile")
np = pytest.importorskip("numpy")
audio = tmp_dir / "audio_8k.wav"
samples = np.zeros(8000, dtype=np.float32)
sf.write(str(audio), samples, 8000)
manifest = tmp_dir / "sr_mismatch.jsonl"
_write_manifest(manifest, [{"text": "hello", "audio": str(audio)}])
result = validate_manifest(str(manifest), sample_rate=16000)
assert result.valid_samples == 0
assert not result.is_valid
assert any("sample rate" in e.lower() for e in result.errors)
def test_mixed_ref_audio_warns_for_each_missing(self, tmp_dir):
"""Missing ref_audio entries should each generate a warning independently."""
audio = tmp_dir / "audio.wav"
ref_good = tmp_dir / "ref_good.wav"
_create_wav(audio)
_create_wav(ref_good)
manifest = tmp_dir / "mixed_ref.jsonl"
_write_manifest(
manifest,
[
{"text": "row1", "audio": str(audio), "ref_audio": str(ref_good)},
{"text": "row2", "audio": str(audio), "ref_audio": "/nonexistent/ref.wav"},
],
)
result = validate_manifest(str(manifest))
assert result.has_ref_audio == 1
assert any("ref_audio file not found" in w for w in result.warnings)
def test_cli_validate_exit_code(self, tmp_dir):
"""validate subcommand must exit 1 on validation error (missing audio)."""
import subprocess
manifest = tmp_dir / "bad.jsonl"
_write_manifest(manifest, [{"text": "hi", "audio": "/nonexistent/x.wav"}])
proc = subprocess.run(
[sys.executable, "-m", "voxcpm.cli", "validate", "--manifest", str(manifest)],
capture_output=True,
text=True,
)
assert proc.returncode == 1, f"Expected exit 1, got {proc.returncode}"
assert "FAILED" in proc.stderr or "Audio file not found" in proc.stderr
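The hand-written header in the `_create_wav` fallback can be sanity-checked with the standard library's `wave` module. The sketch below is a standalone check, not part of the test suite: `write_pcm16_wav` repeats the same byte layout under a hypothetical name, and `read_wav_info` is likewise an illustrative helper.

```python
import struct
import tempfile
import wave
from pathlib import Path


def write_pcm16_wav(path: Path, duration_s: float = 1.0, sr: int = 16000) -> int:
    """Write a silent mono 16-bit PCM WAV by hand; return the frame count."""
    samples = int(duration_s * sr)
    data_size = samples * 2  # 2 bytes per 16-bit mono frame
    with open(path, "wb") as f:
        f.write(b"RIFF")
        f.write(struct.pack("<I", 36 + data_size))
        f.write(b"WAVEfmt ")
        # fmt chunk size, PCM tag, channels, sample rate, byte rate, block align, bits
        f.write(struct.pack("<IHHIIHH", 16, 1, 1, sr, sr * 2, 2, 16))
        f.write(b"data")
        f.write(struct.pack("<I", data_size))
        f.write(b"\x00" * data_size)
    return samples


def read_wav_info(path: Path):
    """Return (channels, sample width in bytes, sample rate, frame count)."""
    with wave.open(str(path), "rb") as w:
        return w.getnchannels(), w.getsampwidth(), w.getframerate(), w.getnframes()


with tempfile.TemporaryDirectory() as d:
    wav_path = Path(d) / "silence.wav"
    frames = write_pcm16_wav(wav_path, duration_s=0.5, sr=8000)
    print(read_wav_info(wav_path), frames)
```

If `wave` parses the file and reports the expected rate and frame count, the fallback produces a file that a sample-rate check can meaningfully inspect.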