35 Commits

Author SHA1 Message Date
Pine 2e4c601d1d - Fix path error 2026-05-06 14:58:46 +08:00
Pine 525ebc65a3 - Fix local model usage 2026-05-06 14:28:36 +08:00
Pine 564b98b2d5 - Add standalone self-deployment support (for 添哥) 2026-05-06 13:41:26 +08:00
liuxin 19b6bf7590 fix: handle LoRA rank mismatch during inference in lora_ft_webui
Pass the selected LoRA checkpoint to load_model() on first load so the
model initializes with the correct rank from lora_config.json instead of
always defaulting to r=32.

On subsequent LoRA hot-swaps, detect rank incompatibility and
automatically reload the model with the new checkpoint's config,
preventing tensor shape mismatch errors (fixes #283).

Made-with: Cursor
2026-04-28 10:52:57 +08:00
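The rank check described in this commit could look roughly like the sketch below. It is an illustration only: `read_lora_rank` and `needs_reload` are hypothetical names, and the `"r"` key inside `lora_config.json` is an assumed layout that the commit message does not confirm.

```python
import json
from pathlib import Path

DEFAULT_RANK = 32  # per the commit message, the loader previously always used r=32


def read_lora_rank(checkpoint_dir: str) -> int:
    """Read the LoRA rank from lora_config.json (the "r" key is an assumption)."""
    config_path = Path(checkpoint_dir) / "lora_config.json"
    if not config_path.is_file():
        return DEFAULT_RANK
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    return int(config.get("r", DEFAULT_RANK))


def needs_reload(loaded_rank: int, checkpoint_dir: str) -> bool:
    """Return True when hot-swapping this checkpoint would mismatch tensor shapes."""
    return read_lora_rank(checkpoint_dir) != loaded_rank
```

A caller would compare the rank of the currently loaded model against the newly selected checkpoint and rebuild the model only when `needs_reload` returns True.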
ZGY 86bff0fc82 Merge pull request #253 from SuperMarioYL/feat/validate-training-data
feat: add voxcpm validate CLI for pre-flight training data checks
2026-04-27 21:09:41 +08:00
supermario_leo dd7b78f2c0 refactor(cli): defer soundfile and voxcpm.core imports to inference commands
Move `import soundfile as sf` and `from voxcpm.core import VoxCPM` from
module-level into the functions that require model inference (load_model,
_run_single, cmd_batch), so `voxcpm validate` can run without loading
the model/inference stack.
2026-04-25 05:09:23 +08:00
supermario_leo 29577d57f8 test: fix test_cli_validate_exit_code to use --manifest flag and assert specific exit code
Pass manifest path via --manifest flag (required) instead of as a
positional argument, so the test exercises cmd_validate rather than
argparse error handling.  Also assert returncode==1 and check stderr
for the FAILED/error message to prevent false positives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-24 10:15:57 +08:00
supermario_leo 4509becfde fix: address four validation correctness issues from review
- Invalid audio rows (bad path or sample-rate mismatch) no longer
  increment valid_samples; has_error is now set on any audio failure
- _check_audio_file now enforces the expected sample rate when soundfile
  is available, making --sample-rate actually useful
- ref_audio missing-file warning is emitted for every invalid entry
  independently, not only before the first valid one is seen
- New tests cover each of the four corrected behaviours: invalid audio
  count, sample-rate mismatch, mixed ref_audio, and CLI exit code
2026-04-22 05:06:35 +08:00
ZGY cd79a647fa Merge pull request #263 from Oumnya/fix/mps-bf16-dtype
fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss
2026-04-21 18:49:48 +08:00
Oumnya 96d605b9de fix(mps): align VOXCPM_MPS_DTYPE override set with get_dtype parser
Drop "half" from _VALID_DTYPE_OVERRIDES / _LOW_PRECISION_DTYPES.
get_dtype() has never accepted "half", so VOXCPM_MPS_DTYPE=half would
pass override validation and then crash downstream with
"Unsupported dtype: half". The remaining aliases (bfloat16/bf16,
float16/fp16, float32/fp32) already cover the intended dtype space.

Adds a standalone unit check under scripts/ to guard the invariant
that every accepted override parses through get_dtype().

Addresses review feedback on #263.
2026-04-21 18:24:53 +08:00
ZGY a9b03a768c Merge pull request #277 from gluttony-10/main
feat: enhance control text processing in VoxCPMDemo
2026-04-21 17:11:42 +08:00
ZGY 77f847fcba Merge pull request #268 from shaun0927/fix/lora-weights-only
fix: load legacy LoRA checkpoints with weights_only=True
2026-04-21 16:55:42 +08:00
gluttony-10 d3cc88722c feat: enhance control text processing in VoxCPMDemo
Added regex to strip parentheses from control instructions in the text synthesis method to ensure compatibility with the expected prompt format. This change improves the robustness of the input handling.
2026-04-21 07:07:24 +00:00
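A minimal sketch of the parenthesis stripping described in this commit follows. `build_prompt` is a hypothetical helper name, and the exact character class (half-width plus full-width parentheses) is an assumption based on the commit description; the `(control)text` format itself comes from the demo code.

```python
import re

# Half-width "(" ")" and full-width U+FF08 / U+FF09 parentheses (assumed set).
_PARENS = re.compile(r"[()\uFF08\uFF09]")


def build_prompt(text: str, control: str = "") -> str:
    """Prefix text with "(control)" after removing parentheses from control."""
    control = _PARENS.sub("", control or "").strip()
    return f"({control}){text}" if control else text
```

Removing the parentheses first keeps a user-typed `(calm, slow)` from producing a doubled `((calm, slow))` prefix.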
JunghwanNA ec2acec8a1 Harden LoRA checkpoint loading against untrusted pickle payloads
LoRA is a first-class workflow in VoxCPM, and the project already prefers
safetensors plus weights-only fallback loading for base model artifacts. The
legacy LoRA .ckpt/.pth path was the remaining place that still deserialized
arbitrary pickle objects, so this switches it to weights_only=True and adds
focused regression coverage for both model loaders.

Constraint: Must preserve compatibility with tensor-only legacy LoRA checkpoints
Rejected: Remove .ckpt/.pth support entirely | too disruptive for existing users
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Keep LoRA artifact handling aligned with the existing safetensors-first, weights-only loading pattern
Tested: python3 -m pytest -q tests/test_lora_checkpoint_loading.py tests/test_model_utils.py -q
Not-tested: Full end-to-end LoRA hot-load with heavyweight model assets
2026-04-18 00:31:28 +09:00
xliucs 13605c5a0e Merge pull request #266 from linyueqian/docs/add-vllm-omni-references
docs: add vLLM-Omni serving references
2026-04-17 10:46:21 +08:00
Yueqian Lin afa63e6195 docs: add vLLM-Omni serving references
Document vLLM-Omni as a production serving option for VoxCPM2
alongside the existing Nano-vLLM reference. Mirrors the addition in
README_zh.md, and adds an ecosystem table entry.

Install snippet follows the upstream vLLM-Omni installation guide
(from source, since vllm-omni is rapidly evolving).

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
2026-04-16 21:19:27 -05:00
liuxin eae0a29908 docs: add ComfyUI RH link
Made-with: Cursor
2026-04-16 11:46:40 +08:00
Labmem-Zhouyx 35895982d7 Merge PR #212: perf: stateful streaming VAE decode — eliminate redundant overlap
- StreamingVAEDecoder caches CausalConv1d/CausalTransposeConv1d left-pad
  state between calls — one patch in, one patch out, no overlap
- _inference yields single-patch latents in streaming mode
- 2x faster streaming VAE decode, more accurate (max diff 0.0005 vs 0.0011)
2026-04-15 16:01:38 +08:00
Labmem-Zhouyx f7f1b78c4d fix: correct transpose conv context 2026-04-15 16:01:02 +08:00
oumnya 38d61cdf03 fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss
VoxCPM checkpoints default to bfloat16. Following commit e4e0496 which
added MPS device routing, running with `device=mps` selects bf16 on
Apple Silicon. On Metal, bf16 introduces enough numerical drift in the
diffusion AR loop that the synthesized audio is glitched and trips the
model's badcase detector, which retries until the per-call retry budget
is exhausted. Effectively MPS support is unusable in the default config.

This patch adds a single helper, `pick_runtime_dtype(device, dtype)`,
that promotes any low-precision dtype to float32 when the resolved
device is `mps`. CUDA and CPU paths are untouched. An opt-out env var
`VOXCPM_MPS_DTYPE` lets users force a specific dtype on MPS once future
PyTorch / macOS releases improve bf16 stability.

Both VoxCPMModel and VoxCPM2Model adopt the helper in their __init__,
replacing what would otherwise be duplicated inline checks.

Verified locally on Apple M5 Max, PyTorch 2.11, macOS 15:
- VoxCPM2 (2B): clean output, RTF ~0.78 steady state
- VoxCPM 0.5B: clean output, RTF ~0.92
- No badcase retries fired in any test
- VOXCPM_MPS_DTYPE=bfloat16 round-trips and reproduces the original
  glitched output, confirming the override path.
2026-04-15 12:22:56 +08:00
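The promotion rule described above can be sketched without PyTorch by using dtype names as strings. This is not the repository's implementation: the real helper is `pick_runtime_dtype(device, dtype)`, while the `env` parameter and the alias table here are assumptions added to make the sketch testable.

```python
import os
from typing import Mapping, Optional

_ALIASES = {"bf16": "bfloat16", "fp16": "float16", "fp32": "float32"}
_KNOWN = {"bfloat16", "float16", "float32"}
_LOW_PRECISION = {"bfloat16", "float16"}


def pick_runtime_dtype(
    device: str, dtype: str, env: Optional[Mapping[str, str]] = None
) -> str:
    """Promote low-precision dtypes to float32 on MPS unless overridden."""
    env = os.environ if env is None else env
    dtype = _ALIASES.get(dtype, dtype)
    if device != "mps":
        return dtype  # CUDA and CPU paths are left untouched
    override = env.get("VOXCPM_MPS_DTYPE", "").strip().lower()
    if override:
        override = _ALIASES.get(override, override)
        if override not in _KNOWN:
            raise ValueError(f"Unsupported dtype: {override}")
        return override
    return "float32" if dtype in _LOW_PRECISION else dtype
```

Rejecting unknown override values up front mirrors the follow-up fix in 96d605b9de, where `half` was dropped because `get_dtype()` never accepted it.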
刘鑫 1565e83efe fix: complete shared generator cleanup coverage
Move generator close handling into a shared utility and wire the core generation pipeline through it so partially-consumed prompt cache generators are cleaned up consistently across both model variants and the public VoxCPM wrapper.

Made-with: Cursor
2026-04-13 17:39:05 +08:00
刘鑫 61b36d4e56 refactor: centralize generator cleanup in model helpers
Factor repeated next-and-close patterns into a shared helper in both VoxCPM model variants so non-streaming inference cleans up generators consistently while keeping the issue reference close to the workaround.

Made-with: Cursor
2026-04-13 16:57:08 +08:00
刘鑫 b1584aec7c fix: stabilize CPU SDPA mask broadcasting
Use an explicit broadcastable attention mask shape during MiniCPM incremental decoding so CPU runtimes avoid a PyTorch SDPA dimension error without changing attention semantics.

Made-with: Cursor
2026-04-13 15:38:53 +08:00
supermario_leo 4457617953 feat: add voxcpm validate CLI for pre-flight training data checks
Add a new `validate` subcommand that checks JSONL training manifests
before starting expensive fine-tuning jobs. This catches format issues,
missing audio files, and data quality problems early.

The validator performs:
- JSONL format validation (each line must be valid JSON)
- Required column checks (text, audio)
- Audio file existence and readability verification
- Duration and text length statistics (min, max, mean, median)
- Optional ref_audio column validation
- Warnings for very short (<0.3s) or very long (>30s) audio samples

Usage:
  voxcpm validate --manifest train.jsonl
  voxcpm validate --manifest train.jsonl --sample-rate 16000 --verbose

The module uses lazy imports for soundfile, so it works even in
minimal environments. Includes 11 unit tests covering all validation
paths.
2026-04-13 03:15:50 +08:00
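The first two checks in the list above (JSONL format and required columns) can be illustrated with the standard library alone. The function name and return shape below are hypothetical, and the audio existence, duration, and `ref_audio` checks are deliberately left out.

```python
import json
from typing import Dict, List

REQUIRED_COLUMNS = ("text", "audio")


def validate_manifest_lines(lines: List[str]) -> Dict[str, object]:
    """Check JSON validity and required columns for each manifest line."""
    errors: List[str] = []
    valid = 0
    for lineno, raw in enumerate(lines, start=1):
        if not raw.strip():
            continue  # ignore blank lines
        try:
            row = json.loads(raw)
        except json.JSONDecodeError as exc:
            errors.append(f"line {lineno}: invalid JSON ({exc.msg})")
            continue
        if not isinstance(row, dict):
            errors.append(f"line {lineno}: expected a JSON object")
            continue
        missing = [c for c in REQUIRED_COLUMNS if not row.get(c)]
        if missing:
            errors.append(f"line {lineno}: missing {', '.join(missing)}")
            continue
        valid += 1
    return {"valid_samples": valid, "errors": errors}
```

A CLI wrapper would typically exit non-zero when `errors` is non-empty, which is what the later test fix (29577d57f8) asserts with `returncode==1`.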
xliucs 5510503182 Merge pull request #246 from sharziki/fix/unclosed-file-handles
fix: close file handles in from_local() config loading
2026-04-11 13:10:04 +08:00
sharziki fb46aad9a5 fix: close file handles in from_local() config loading
Use context managers when reading config.json in VoxCPMModel.from_local()
and VoxCPM2Model.from_local() to prevent file descriptor leaks. Also add
explicit encoding="utf-8" to avoid locale-dependent decode errors.

Closes #235

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-11 00:01:14 -04:00
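The pattern this fix applies is the usual context-manager idiom. The sketch below uses a hypothetical `load_config` helper rather than the actual `from_local()` code.

```python
import json
import os


def load_config(model_dir: str) -> dict:
    """Read config.json, closing the handle even if parsing fails."""
    config_path = os.path.join(model_dir, "config.json")
    # An explicit encoding avoids locale-dependent decode errors.
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)
```

Compared with `json.load(open(path))`, the `with` block releases the file descriptor deterministically instead of waiting for garbage collection.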
刘鑫 e4e049624c update finetuning pipeline and runtime device handling
Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits.

Made-with: Cursor
2026-04-11 11:08:50 +08:00
xliucs abf01b9bf3 Merge pull request #229 from kuishou68/fix/issue-228-validate-text-type-order
fix: correct isinstance/strip order in _generate() to prevent AttributeError on non-string input
2026-04-10 10:30:15 +08:00
cocoon 4f4a5b9f6c fix: correct type-check order in _generate() to prevent AttributeError on non-string input
The previous guard `not text.strip() or not isinstance(text, str)` called
.strip() before verifying that text is actually a string, causing an
AttributeError (e.g. for int input) instead of the intended ValueError.

Swap operand order so isinstance check short-circuits first.

Closes #228
2026-04-09 16:13:40 +00:00
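The operand order matters because `or` short-circuits left to right. A small stand-alone sketch of the corrected guard follows; `check_text` and the error message are illustrative, not the actual `_generate()` code.

```python
def check_text(text):
    """Raise ValueError for non-string or blank input, else return text."""
    # isinstance runs first, so .strip() is never reached for non-strings.
    if not isinstance(text, str) or not text.strip():
        raise ValueError("target text must be a non-empty string")
    return text
```

With the operands in the old order, `check_text(123)` would fail inside `.strip()` with an AttributeError before the type check was ever evaluated.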
刘鑫 79c0cf68dd chore: remove accidentally committed app_local.py
Made-with: Cursor
2026-04-09 16:05:18 +08:00
刘鑫 75cfa3e9b8 fix: use uncompiled feat_encoder for prefill to prevent CUDA Graph dynamic shape accumulation (#209) 2026-04-09 16:00:17 +08:00
Labmem-Zhouyx 5611bd08a0 optim app.py 2026-04-09 00:30:19 +08:00
Kevin Knoedler 66205135fc perf: stateful streaming VAE decode — eliminate redundant overlap
Streaming decode previously re-decoded 4 overlapping patches through
the VAE each step, discarding 75% of the output. Replace with stateful
decode that carries causal conv padding buffers between calls — one
patch in, one patch out, no overlap.

Changes:
- Add StreamingVAEDecoder to audiovae/audio_vae_v2.py — caches
  CausalConv1d and CausalTransposeConv1d left-pad state between calls
- AudioVAE.streaming_decode() context manager for clean lifecycle
- _inference yields single-patch latents in streaming mode
- _generate and _generate_with_prompt_cache use StreamingVAEDecoder

Streaming VAE decode time (isolated): 289ms → 148ms (2x faster)
Stateful vs full decode: cosine 1.0000, max diff 0.0005
(more accurate than previous overlap approach at max diff 0.001)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-08 09:09:22 -07:00
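The idea of carrying left-pad state between calls can be shown on a plain stride-1 causal convolution. This is a simplified stand-in written in pure Python, not the repository's `StreamingVAEDecoder`; it ignores strides, channels, and transposed convolutions, and the class name is hypothetical.

```python
from typing import List, Sequence


def causal_conv1d(x: Sequence[float], weights: Sequence[float]) -> List[float]:
    """Full-sequence causal convolution with zero left padding."""
    k = len(weights)
    padded = [0.0] * (k - 1) + list(x)
    return [
        sum(w * padded[t + j] for j, w in enumerate(weights))
        for t in range(len(x))
    ]


class StreamingCausalConv1d:
    """Same convolution, fed chunk by chunk, caching the last k-1 inputs."""

    def __init__(self, weights: Sequence[float]) -> None:
        self.weights = list(weights)
        self.state = [0.0] * (len(self.weights) - 1)

    def __call__(self, chunk: Sequence[float]) -> List[float]:
        k = len(self.weights)
        window = self.state + list(chunk)
        out = [
            sum(w * window[t + j] for j, w in enumerate(self.weights))
            for t in range(len(chunk))
        ]
        # Keep only the trailing k-1 samples as left context for the next call.
        self.state = window[len(window) - (k - 1):] if k > 1 else []
        return out
```

Because each chunk sees exactly the context the full convolution would have seen, no overlapping region has to be re-decoded and discarded.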
Labmem-Zhouyx 364eff6840 update readme: python version 2026-04-08 23:07:38 +08:00
Labmem-Zhouyx 6d10932b09 update readme 2026-04-08 18:48:58 +08:00
24 changed files with 1627 additions and 430 deletions
+4 -1
@@ -1,4 +1,7 @@
launch.json
__pycache__
voxcpm.egg-info
.DS_Store
.DS_Store
./pretrained_models/
app_local.py
models/
+36 -8
@@ -46,7 +46,7 @@ VoxCPM is a **tokenizer-free** Text-to-Speech system that directly generates con
- 🎙️ **Ultimate Cloning** — Reproduce every vocal nuance: provide both reference audio and its transcript, and the model continues seamlessly from the reference, faithfully preserving every vocal detail — timbre, rhythm, emotion, and style (same as VoxCPM1.5)
- 🔊 **48kHz High-Quality Audio** — Accepts 16kHz reference audio and directly outputs 48kHz studio-quality audio via AudioVAE V2's asymmetric encode/decode design, with built-in super-resolution — no external upsampler needed
- 🧠 **Context-Aware Synthesis** — Automatically infers appropriate prosody and expressiveness from text content
- **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-VLLM](https://github.com/a710128/nanovllm-voxcpm)
- **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) or [vLLM-Omni](https://github.com/vllm-project/vllm-omni) — official vLLM omni-modal serving for VoxCPM2 with PagedAttention and an OpenAI-compatible API
- 📜 **Fully Open-Source & Commercial-Ready** — Weights and code released under the [Apache-2.0](LICENSE) license, free for commercial use
@@ -92,7 +92,7 @@ Chinese Dialect: 四川话, 粤语, 吴语, 东北话, 河南话, 陕西话, 山
pip install voxcpm
```
> **Requirements:** Python ≥ 3.10, PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
> **Requirements:** Python ≥ 3.10 (<3.13), PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
### Python API
@@ -123,12 +123,12 @@ pip install modelscope
```
```python
from modelscope.hub.snapshot_download import snapshot_download
from modelscope import snapshot_download
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # specify the local directory to save the model
from voxcpm import VoxCPM
import soundfile as sf
local_model_dir = snapshot_download("OpenBMB/VoxCPM2")
model = VoxCPM.from_pretrained(local_model_dir, load_denoiser=False)
model = VoxCPM.from_pretrained("./pretrained_models/VoxCPM2", load_denoiser=False)
wav = model.generate(
text="VoxCPM2 is the current recommended release for realistic multilingual speech synthesis.",
@@ -238,8 +238,8 @@ voxcpm --help
### Web Demo
```bash
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, open http://localhost:8808
```bash
python app.py --port 8808 # then open in browser: http://localhost:8808
```
### 🚢 Production Deployment (Nano-vLLM)
@@ -262,6 +262,32 @@ server.stop()
> **RTF as low as ~0.13 on NVIDIA RTX 4090** (vs ~0.3 with the standard PyTorch implementation), with support for batched concurrent requests and a FastAPI HTTP server. See the [Nano-vLLM-VoxCPM repo](https://github.com/a710128/nanovllm-voxcpm) for deployment details.
### 🏭 Production Serving (vLLM-Omni)
For production multi-tenant deployments, use [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — the official vLLM project's omni-modal extension with native **VoxCPM2** support. PagedAttention KV cache, continuous batching, and a drop-in **OpenAI-compatible** `/v1/audio/speech` endpoint.
```bash
# Install from source (latest main — vllm-omni is rapidly evolving)
uv pip install vllm==0.19.0 --torch-backend=auto
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
uv pip install -e .
```
See the [vLLM-Omni installation guide](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/) for other platforms (ROCm, XPU, MUSA, NPU) and Docker images.
```bash
# Launch an OpenAI-compatible TTS server (--omni enables omni-modal serving)
vllm serve openbmb/VoxCPM2 --omni --port 8000
# Call it from any OpenAI client
curl http://localhost:8000/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{"model":"openbmb/VoxCPM2","input":"Hello from VoxCPM2 on vLLM-Omni!","voice":"default"}' \
--output out.wav
```
> Built on the upstream vLLM scheduler, with batched concurrent requests, streaming chunk delivery, and multi-GPU deployment out of the box. See the [VoxCPM2 example](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2) for full deployment recipes.
> **Full parameter reference, multi-scenario examples, and voice cloning tips →** [Quick Start Guide](https://voxcpm.readthedocs.io/en/latest/quickstart.html) | [Usage Guide](https://voxcpm.readthedocs.io/en/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/en/latest/cookbook.html)
---
@@ -528,11 +554,13 @@ Full documentation: **[voxcpm.readthedocs.io](https://voxcpm.readthedocs.io/en/l
| Project | Description |
|---|---|
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | High-throughput and Fast GPU serving |
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | Official vLLM omni-modal serving for VoxCPM2 — PagedAttention, OpenAI-compatible API |
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF: CPU, CUDA, Vulkan inference |
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX export for CPU inference |
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine backend |
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust re-implementation |
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI node-based workflows |
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | Feature-complete ComfyUI workflow for VoxCPM 2 with multi-speaker generation, LoRA, and auto-ASR |
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS extension |
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | Browser-based TTS extension |
+35 -7
@@ -46,7 +46,7 @@ VoxCPM 是一个**无离散音频分词器**(Tokenizer-Free)的语音合成
- 🎙️ **极致克隆** — 提供参考音频及其文本内容,模型接着参考音频进行无缝续写,从而精准还原声音细节特征(与 VoxCPM1.5 一致)
- 🔊 **48kHz 高质量音频** — 输入 16kHz 参考音频,通过 AudioVAE V2 的非对称编解码设计直接输出 48kHz 高质量音频,内置超分能力
- 🧠 **语境感知合成** — 根据文本内容自动推断合适的韵律和表现力
- **实时流式合成** — 在 NVIDIA RTX 4090 上 RTF 低至 ~0.3,通过 [Nano-VLLM](https://github.com/a710128/nanovllm-voxcpm) 加速后可达 ~0.13
- **实时流式合成** — 在 NVIDIA RTX 4090 上 RTF 低至 ~0.3,通过 [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) 或 [vLLM-Omni](https://github.com/vllm-project/vllm-omni)(官方 vLLM 全模态服务,原生支持 VoxCPM2,提供 PagedAttention 与 OpenAI 兼容 API)加速后可达 ~0.13
- 📜 **完全开源,商用就绪** — 权重和代码基于 [Apache-2.0](LICENSE) 协议发布,免费商用
<summary><b>🌍 支持的语言(30种)</b></summary>
@@ -91,7 +91,7 @@ VoxCPM 是一个**无离散音频分词器**Tokenizer-Free)的语音合成
pip install voxcpm
```
> **环境要求:** Python ≥ 3.10,PyTorch ≥ 2.5.0,CUDA ≥ 12.0。详见 [快速开始文档](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html)。
> **环境要求:** Python ≥ 3.10 (<3.13),PyTorch ≥ 2.5.0,CUDA ≥ 12.0。详见 [快速开始文档](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html)。
### Python API
@@ -122,12 +122,12 @@ pip install modelscope
```
```python
from modelscope.hub.snapshot_download import snapshot_download
from modelscope import snapshot_download
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # 指定模型保存的本地路径
from voxcpm import VoxCPM
import soundfile as sf
local_model_dir = snapshot_download("OpenBMB/VoxCPM2")
model = VoxCPM.from_pretrained(local_model_dir, load_denoiser=False)
model = VoxCPM.from_pretrained('./pretrained_models/VoxCPM2', load_denoiser=False)
wav = model.generate(
text="VoxCPM2 是目前推荐使用的多语言语音合成版本。",
@@ -238,7 +238,7 @@ voxcpm --help
### Web Demo
```bash
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # 指定本地模型路径,然后打开 http://localhost:8808
python app.py --port 8808 # 然后在浏览器打开 http://localhost:8808
```
### 🚢 生产部署(Nano-vLLM)
@@ -261,6 +261,32 @@ server.stop()
> **在 NVIDIA RTX 4090 上 RTF 低至 ~0.13**(标准 PyTorch 实现约 ~0.3),支持批量并发请求和 FastAPI HTTP 服务。详见 [Nano-vLLM-VoxCPM 仓库](https://github.com/a710128/nanovllm-voxcpm)。
### 🏭 生产环境部署(vLLM-Omni)
如需生产级多租户部署,使用 [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — 官方 vLLM 项目的全模态扩展,原生支持 **VoxCPM2**。具备 PagedAttention KV 缓存、连续批处理,以及与 OpenAI 完全兼容的 `/v1/audio/speech` 接口。
```bash
# 从源码安装(最新 main 分支 —— vllm-omni 正在快速迭代)
uv pip install vllm==0.19.0 --torch-backend=auto
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
uv pip install -e .
```
其他平台(ROCm、XPU、MUSA、NPU)与 Docker 镜像请参考 [vLLM-Omni 安装文档](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/)。
```bash
# 启动 OpenAI 兼容的 TTS 服务(--omni 启用全模态服务)
vllm serve openbmb/VoxCPM2 --omni --port 8000
# 任意 OpenAI 客户端均可调用
curl http://localhost:8000/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{"model":"openbmb/VoxCPM2","input":"你好,欢迎使用 VoxCPM2 on vLLM-Omni","voice":"default"}' \
--output out.wav
```
> 基于上游 vLLM 调度器构建,开箱即用支持批量并发、流式分块输出和多 GPU 部署。完整示例见 [VoxCPM2 部署样例](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2)。
> **完整参数说明、多场景示例与声音克隆技巧 →** [快速开始指南](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html) | [使用指南](https://voxcpm.readthedocs.io/zh-cn/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/zh-cn/latest/cookbook.html)
---
@@ -521,11 +547,13 @@ python lora_ft_webui.py # 然后打开 http://localhost:7860
| 项目 | 说明 |
|---|---|
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | 高吞吐快速 GPU 推理引擎 |
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | 官方 vLLM 全模态服务(原生支持 VoxCPM2)— PagedAttention、OpenAI 兼容 API |
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF:CPU、CUDA、Vulkan 推理 |
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX 导出,支持 CPU 推理 |
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine 后端 |
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust 重新实现 |
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI 节点工作流 |
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | 面向 VoxCPM 2 的功能更完整的 ComfyUI 工作流,支持多说话人、LoRA 和自动 ASR |
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS 扩展 |
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | 浏览器端 TTS 扩展 |
+21 -37
@@ -1,4 +1,5 @@
import os
import re
import sys
import logging
import numpy as np
@@ -9,8 +10,6 @@ from funasr import AutoModel
from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
import voxcpm
@@ -55,7 +54,7 @@ _EXAMPLES_FOOTER_EN = (
)
_USAGE_INSTRUCTIONS_ZH = (
"**VoxCPM2 — 三种语音生成方式:**\n\n"
"**三种语音生成方式:**\n\n"
"🎨 **声音设计(Voice Design)** \n"
"无需参考音频。在 **Control Instruction** 中描述目标音色特征"
"(性别、年龄、语气、情绪、语速等),VoxCPM2 即可为你从零创造独一无二的声音。\n\n"
@@ -66,6 +65,8 @@ _USAGE_INSTRUCTIONS_ZH = (
"开启 **极致克隆模式** 并提供参考音频的文字内容(可自动识别)。"
"模型会将参考音频视为已说出的前文,以**音频续写**的方式完整还原参考音频中的所有声音细节。"
"注意:该模式与可控克隆模式互斥,将禁用Control Instruction。\n\n"
"目前支持的方言包括:\n"
"「四川话、粤语、吴语、东北话、河南话、陕西话、山东话、天津话、闽南话」"
)
_EXAMPLES_FOOTER_ZH = (
@@ -221,11 +222,11 @@ _APP_THEME = gr.themes.Soft(
# ---------- Model ----------
class VoxCPMDemo:
def __init__(self, model_dir: Optional[str] = None) -> None:
def __init__(self, model_id: str = "openbmb/VoxCPM2") -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Running on device: {self.device}")
logger.info(f"运行在设备上: {self.device}")
self.asr_model_id = "iic/SenseVoiceSmall"
self.asr_model_id = "./models/iic/SenseVoiceSmall"
self.asr_model: Optional[AutoModel] = AutoModel(
model=self.asr_model_id,
disable_update=True,
@@ -234,36 +235,13 @@ class VoxCPMDemo:
)
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.explicit_model_dir = model_dir
def _resolve_model_dir(self) -> str:
if self.explicit_model_dir and os.path.isdir(self.explicit_model_dir):
return self.explicit_model_dir
env_model_dir = os.environ.get("VOXCPM_MODEL_DIR", "").strip()
if env_model_dir and os.path.isdir(env_model_dir):
return env_model_dir
repo_id = os.environ.get("HF_REPO_ID", "").strip()
if len(repo_id) > 0:
target_dir = os.path.join("models", repo_id.replace("/", "__"))
if not os.path.isdir(target_dir):
try:
from huggingface_hub import snapshot_download
os.makedirs(target_dir, exist_ok=True)
logger.info(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
except Exception as e:
logger.warning(f"HF download failed: {e}. Falling back to 'models'.")
return "models"
return target_dir
return "models"
self._model_id = model_id
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
if self.voxcpm_model is not None:
return self.voxcpm_model
logger.info("Model not loaded, initializing...")
model_dir = self._resolve_model_dir()
logger.info(f"Using model dir: {model_dir}")
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=True)
logger.info(f"Loading model: {self._model_id}")
self.voxcpm_model = voxcpm.VoxCPM.from_pretrained(self._model_id, optimize=True,zipenhancer_model_id="./models/iic/speech_zipenhancer_ans_multiloss_16k_base")
logger.info("Model loaded successfully.")
return self.voxcpm_model
@@ -315,6 +293,9 @@ class VoxCPMDemo:
raise ValueError("Please input text to synthesize.")
control = (control_instruction or "").strip()
# Strip any parentheses (half-width/full-width) from control text to avoid
# breaking the "(control)text" prompt format expected by the model.
control = re.sub(r"[()()]", "", control).strip()
final_text = f"({control}){text}" if control else text
audio_path = reference_wav_path_input if reference_wav_path_input else None
@@ -507,9 +488,9 @@ def run_demo(
server_name: str = "0.0.0.0",
server_port: int = 8808,
show_error: bool = True,
model_dir: Optional[str] = None,
model_id: str = "./models/OpenBMB/VoxCPM2",
):
demo = VoxCPMDemo(model_dir=model_dir)
demo = VoxCPMDemo(model_id=model_id)
interface = create_demo_interface(demo)
interface.queue(max_size=10, default_concurrency_limit=1).launch(
server_name=server_name,
@@ -524,7 +505,10 @@ def run_demo(
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model-dir", type=str, default=None, help="Path to VoxCPM2 checkpoint directory")
parser.add_argument("--port", type=int, default=8808, help="Server port")
parser.add_argument(
"--model-id", type=str, default="./models/OpenBMB/VoxCPM2",
help="本地路径或HuggingFace仓库ID(默认:./models/openbmb/VoxCPM2)",
)
parser.add_argument("--port", type=int, default=8808, help="服务端口")
args = parser.parse_args()
run_demo(model_dir=args.model_dir, server_port=args.port)
run_demo(model_id=args.model_id, server_port=args.port)
-280
@@ -1,280 +0,0 @@
import os
import sys
import numpy as np
import torch
import gradio as gr
from typing import Optional, Tuple
from funasr import AutoModel
from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
import voxcpm
class VoxCPMDemo:
def __init__(self) -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {self.device}", file=sys.stderr)
# ASR model for prompt text recognition
self.asr_model_id = "iic/SenseVoiceSmall"
self.asr_model: Optional[AutoModel] = AutoModel(
model=self.asr_model_id,
disable_update=True,
log_level='DEBUG',
device="cuda:0" if self.device == "cuda" else "cpu",
)
# TTS model (lazy init)
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.default_local_model_dir = "./models/VoxCPM1.5"
# ---------- Model helpers ----------
def _resolve_model_dir(self) -> str:
"""
Resolve model directory:
1) Use local checkpoint directory if exists
2) If HF_REPO_ID env is set, download into models/{repo}
3) Fallback to 'models'
"""
if os.path.isdir(self.default_local_model_dir):
return self.default_local_model_dir
repo_id = os.environ.get("HF_REPO_ID", "").strip()
if len(repo_id) > 0:
target_dir = os.path.join("models", repo_id.replace("/", "__"))
if not os.path.isdir(target_dir):
try:
from huggingface_hub import snapshot_download # type: ignore
os.makedirs(target_dir, exist_ok=True)
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
except Exception as e:
print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr)
return "models"
return target_dir
return "models"
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
if self.voxcpm_model is not None:
return self.voxcpm_model
print("Model not loaded, initializing...", file=sys.stderr)
model_dir = self._resolve_model_dir()
print(f"Using model dir: {model_dir}", file=sys.stderr)
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
print("Model loaded successfully.", file=sys.stderr)
return self.voxcpm_model
# ---------- Functional endpoints ----------
def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
if prompt_wav is None:
return ""
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
text = res[0]["text"].split('|>')[-1]
return text
def generate_tts_audio(
self,
text_input: str,
prompt_wav_path_input: Optional[str] = None,
prompt_text_input: Optional[str] = None,
cfg_value_input: float = 2.0,
inference_timesteps_input: int = 10,
do_normalize: bool = True,
denoise: bool = True,
) -> Tuple[int, np.ndarray]:
"""
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
Returns (sample_rate, waveform_numpy)
"""
current_model = self.get_or_load_voxcpm()
text = (text_input or "").strip()
if len(text) == 0:
raise ValueError("Please input text to synthesize.")
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
prompt_text = prompt_text_input if prompt_text_input else None
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
wav = current_model.generate(
text=text,
prompt_text=prompt_text,
prompt_wav_path=prompt_wav_path,
cfg_value=float(cfg_value_input),
inference_timesteps=int(inference_timesteps_input),
normalize=do_normalize,
denoise=denoise,
)
return (current_model.tts_model.sample_rate, wav)
# ---------- UI Builders ----------
_APP_THEME = gr.themes.Soft(
primary_hue="blue",
secondary_hue="gray",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
)
_CUSTOM_CSS = """
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 80px;
width: auto;
max-width: 200px;
display: inline-block;
}
/* Bold accordion labels */
#acc_quick details > summary,
#acc_tips details > summary {
font-weight: 600 !important;
font-size: 1.1em !important;
}
/* Bold labels for specific checkboxes */
#chk_denoise label,
#chk_denoise span,
#chk_normalize label,
#chk_normalize span {
font-weight: 600;
}
"""
def create_demo_interface(demo: VoxCPMDemo):
"""Build the Gradio UI for VoxCPM demo."""
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
with gr.Blocks() as interface:
# Header logo
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
# Quick Start
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
gr.Markdown("""
### How to Use |使用说明
1. **(Optional) Provide a Voice Prompt** - Upload or record an audio clip to provide the desired voice characteristics for synthesis.
**(可选)提供参考声音** - 上传或录制一段音频,为声音合成提供音色、语调和情感等个性化特征
2. **(Optional) Enter prompt text** - If you provided a voice prompt, enter the corresponding transcript here (auto-recognition available).
**(可选项)输入参考文本** - 如果提供了参考语音,请输入其对应的文本内容(支持自动识别)。
3. **Enter target text** - Type the text you want the model to speak.
**输入目标文本** - 输入您希望模型朗读的文字内容。
4. **Generate Speech** - Click the "Generate" button to create your audio.
**生成语音** - 点击"生成"按钮,即可为您创造出音频。
""")
# Pro Tips
with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
gr.Markdown("""
### Prompt Speech Enhancement|参考语音降噪
- **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
**启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz,限制克隆上限。
- **Disable** to preserve all information in the original audio, including background atmosphere, and support audio cloning at sampling rates up to 44.1kHz.
**禁用**:保留原始音频的全部信息,包括背景环境声,最高支持44.1kHz的音频复刻。
### Text Normalization|文本正则化
- **Enable** to process general text with an external WeTextProcessing component.
**启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理。
- **Disable** to use VoxCPM's native text understanding ability. For example, it supports phoneme input (for Chinese, phonemes are written in pinyin, e.g. {ni3}{hao3}; for English, phonemes follow CMUDict, e.g. {HH AH0 L OW1}). Try it!
**禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3};英文转CMUDict{HH AH0 L OW1})和公式符号合成,尝试一下!
### CFG Value|CFG 值
- **Lower CFG** if the voice prompt sounds strained or overly expressive, or if instability occurs with long text input.
**调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题。
- **Higher CFG** for better adherence to the prompt speech style or input text, or if instability occurs with very short text input.
**调高**:为更好地贴合提示音频的风格或输入文本, 或者极短文本输入出现稳定性问题。
### Inference Timesteps|推理时间步
- **Lower** for faster synthesis speed.
**调低**:合成速度更快。
- **Higher** for better synthesis quality.
**调高**:合成质量更佳。
""")
# Main controls
with gr.Row():
with gr.Column():
prompt_wav = gr.Audio(
sources=["upload", 'microphone'],
type="filepath",
label="Prompt Speech (Optional, or let VoxCPM improvise)",
value="./examples/example.wav",
)
DoDenoisePromptAudio = gr.Checkbox(
value=False,
label="Prompt Speech Enhancement",
elem_id="chk_denoise",
info="We use ZipEnhancer model to denoise the prompt audio."
)
with gr.Row():
prompt_text = gr.Textbox(
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
label="Prompt Text",
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
)
run_btn = gr.Button("Generate Speech", variant="primary")
with gr.Column():
cfg_value = gr.Slider(
minimum=1.0,
maximum=3.0,
value=2.0,
step=0.1,
label="CFG Value (Guidance Scale)",
info="Higher values increase adherence to prompt, lower values allow more creativity"
)
inference_timesteps = gr.Slider(
minimum=4,
maximum=30,
value=10,
step=1,
label="Inference Timesteps",
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
)
with gr.Row():
text = gr.Textbox(
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
label="Target Text",
)
with gr.Row():
DoNormalizeText = gr.Checkbox(
value=False,
label="Text Normalization",
elem_id="chk_normalize",
info="We use wetext library to normalize the input text."
)
audio_output = gr.Audio(label="Output Audio")
# Wiring
run_btn.click(
fn=demo.generate_tts_audio,
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
outputs=[audio_output],
show_progress=True,
api_name="generate",
)
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
return interface
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
demo = VoxCPMDemo()
interface = create_demo_interface(demo)
interface.queue(max_size=10, default_concurrency_limit=1).launch(
server_name=server_name,
server_port=server_port,
show_error=show_error,
theme=_APP_THEME,
css=_CUSTOM_CSS,
)
if __name__ == "__main__":
run_demo()
+31 -10
@@ -281,27 +281,48 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
# Load model
lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None
try:
print(f"Loading base model: {base_model_path}", file=sys.stderr)
load_model(base_model_path)
if lora_selection and lora_selection != "None":
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
load_model(base_model_path, lora_to_load)
if lora_to_load:
print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr)
except Exception as e:
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
print(error_msg, file=sys.stderr)
return None, error_msg
lora_just_loaded = lora_to_load
else:
lora_just_loaded = None
# Handle LoRA hot-swapping
assert current_model is not None, "Model must be loaded before inference"
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
try:
current_model.load_lora(full_lora_path)
current_model.set_lora_enabled(True)
except Exception as e:
print(f"Error loading LoRA: {e}", file=sys.stderr)
return None, f"Error loading LoRA: {e}"
if lora_just_loaded != lora_selection:
new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path)
current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None
new_r = new_lora_config.r if new_lora_config else None
if new_r is not None and current_r is not None and new_r != current_r:
print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr)
reload_base = (
new_base_model if new_base_model and os.path.exists(new_base_model)
else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path)
)
try:
load_model(reload_base, lora_selection)
except Exception as e:
return None, f"Failed to reload model for LoRA rank change: {e}"
else:
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
try:
current_model.load_lora(full_lora_path)
except Exception as e:
print(f"Error loading LoRA: {e}", file=sys.stderr)
return None, f"Error loading LoRA: {e}"
current_model.set_lora_enabled(True)
else:
print("Disabling LoRA", file=sys.stderr)
current_model.set_lora_enabled(False)
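The branching in the hunk above reduces to a four-way decision: disable LoRA, just enable the checkpoint that was loaded together with the model, reload the whole model on a rank mismatch, or hot-swap the weights. The following is a minimal sketch of that decision as a pure function. The name `plan_lora_action` and its string labels are hypothetical and do not exist in the web UI; the real code performs the load and reload calls inline.

```python
from typing import Optional


def plan_lora_action(
    selection: Optional[str],
    just_loaded: Optional[str],
    current_rank: Optional[int],
    new_rank: Optional[int],
) -> str:
    """Return the step the inference path would take for a LoRA selection.

    Hypothetical helper: it mirrors the control flow of the diff but only
    reports the decision instead of touching a model.
    """
    # No LoRA selected: the adapter is switched off.
    if not selection or selection == "None":
        return "disable"
    # The checkpoint was already passed to load_model() on first load.
    if just_loaded == selection:
        return "enable"
    # Both ranks known and different: tensor shapes would not match,
    # so the model is rebuilt with the new checkpoint's config.
    if current_rank is not None and new_rank is not None and current_rank != new_rank:
        return "reload"
    # Same rank (or rank unknown): swap the weights in place.
    return "hot_swap"


for case in [("None", None, 32, 32), ("a", "a", 32, 32), ("a", None, 32, 16), ("a", None, 32, 32)]:
    print(case, "->", plan_lora_action(*case))
```

In every branch except `"disable"` the diff finishes by calling `set_lora_enabled(True)`.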
+34
@@ -0,0 +1,34 @@
"""
Model download script
"""
from modelscope import snapshot_download


def download(repo_id: str, local_dir: str):
    """
    Download a model repository or a single file.

    Args:
        repo_id (str): user name/repository name, e.g. 'stabilityai/sdxl-turbo'
        local_dir (str or Path): local directory in which to place the downloaded files

    Returns:
        str: local path of the downloaded files

    Raises:
        ValueError: if the format of repo_id is incorrect
    """
    model_dir = snapshot_download(
        repo_id,
        repo_type='model',
        local_dir=f"{local_dir}/{repo_id}",
    )
    # Return the path so the documented return value is actually produced.
    return model_dir


if __name__ == "__main__":
    download("OpenBMB/VoxCPM2", "./models")
    download("iic/SenseVoiceSmall", "./models")
    download('iic/speech_zipenhancer_ans_multiloss_16k_base', "./models")
+1 -2
@@ -47,8 +47,7 @@ dependencies = [
"funasr",
"spaces",
"argbind",
"safetensors"
"safetensors",
]
[project.optional-dependencies]
+108
@@ -0,0 +1,108 @@
"""Unit checks for pick_runtime_dtype / get_dtype consistency.

Loads src/voxcpm/model/utils.py directly to avoid the heavy voxcpm package
init. Run with: `python scripts/test_pick_runtime_dtype.py`.
"""

import importlib.util
import os
import pathlib
import sys

REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
UTILS = str(REPO_ROOT / "src" / "voxcpm" / "model" / "utils.py")

spec = importlib.util.spec_from_file_location("voxcpm_utils", UTILS)
utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(utils)

_LOW_PRECISION_DTYPES = utils._LOW_PRECISION_DTYPES
_VALID_DTYPE_OVERRIDES = utils._VALID_DTYPE_OVERRIDES
get_dtype = utils.get_dtype
pick_runtime_dtype = utils.pick_runtime_dtype


def expect(actual, expected, label):
    ok = actual == expected
    mark = "OK " if ok else "FAIL"
    print(f"[{mark}] {label}: got={actual!r} expected={expected!r}")
    return ok


def expect_raises(fn, exc_type, label):
    try:
        fn()
    except exc_type as e:
        print(f"[OK ] {label}: raised {exc_type.__name__}: {e}")
        return True
    except Exception as e:
        print(f"[FAIL] {label}: raised {type(e).__name__} not {exc_type.__name__}: {e}")
        return False
    print(f"[FAIL] {label}: no exception raised")
    return False


results = []

print("=== override set sanity ===")
results.append(expect("half" not in _VALID_DTYPE_OVERRIDES, True, "half removed from _VALID_DTYPE_OVERRIDES"))
results.append(expect("half" not in _LOW_PRECISION_DTYPES, True, "half removed from _LOW_PRECISION_DTYPES"))

print("\n=== every accepted override parses through get_dtype ===")
for dt in sorted(_VALID_DTYPE_OVERRIDES):
    try:
        torch_dtype = get_dtype(dt)
        print(f"[OK ] get_dtype({dt!r}) -> {torch_dtype}")
        results.append(True)
    except Exception as e:
        print(f"[FAIL] get_dtype({dt!r}) raised: {e}")
        results.append(False)

print("\n=== pick_runtime_dtype: non-mps is a no-op ===")
results.append(expect(pick_runtime_dtype("cuda", "bfloat16"), "bfloat16", "cuda/bf16 untouched"))
results.append(expect(pick_runtime_dtype("cpu", "float16"), "float16", "cpu/fp16 untouched"))
results.append(expect(pick_runtime_dtype("cuda", "float32"), "float32", "cuda/fp32 untouched"))

print("\n=== pick_runtime_dtype: mps forces fp32 for low-precision ===")
os.environ.pop("VOXCPM_MPS_DTYPE", None)
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "mps/bf16 -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "bf16"), "float32", "mps/bf16-alias -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "float16"), "float32", "mps/fp16 -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "fp16"), "float32", "mps/fp16-alias -> fp32"))
results.append(expect(pick_runtime_dtype("mps", "float32"), "float32", "mps/fp32 stays"))
results.append(expect(pick_runtime_dtype("mps", "fp32"), "fp32", "mps/fp32-alias stays"))

print("\n=== pick_runtime_dtype: VOXCPM_MPS_DTYPE override ===")
os.environ["VOXCPM_MPS_DTYPE"] = "bfloat16"
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "bfloat16", "override bf16 honored"))
os.environ["VOXCPM_MPS_DTYPE"] = "FP16"
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "fp16", "override is case-insensitive"))
os.environ["VOXCPM_MPS_DTYPE"] = " float32 "
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "override is whitespace-trimmed"))

print("\n=== pick_runtime_dtype: 'half' is no longer a valid override ===")
os.environ["VOXCPM_MPS_DTYPE"] = "half"
results.append(
    expect_raises(
        lambda: pick_runtime_dtype("mps", "bfloat16"),
        ValueError,
        "override=half now rejected (was the bug)",
    )
)
os.environ["VOXCPM_MPS_DTYPE"] = "garbage"
results.append(
    expect_raises(
        lambda: pick_runtime_dtype("mps", "bfloat16"),
        ValueError,
        "override=garbage still rejected",
    )
)
os.environ.pop("VOXCPM_MPS_DTYPE", None)

print("\n=== summary ===")
passed = sum(results)
total = len(results)
print(f"{passed}/{total} passed")
sys.exit(0 if passed == total else 1)
+60 -6
@@ -11,11 +11,6 @@ import os
import sys
from pathlib import Path
import soundfile as sf
from voxcpm.core import VoxCPM
DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2"
# -----------------------------
@@ -173,7 +168,9 @@ def validate_batch_args(args, parser):
# -----------------------------
def load_model(args) -> VoxCPM:
def load_model(args):
from voxcpm.core import VoxCPM
print("Loading VoxCPM model...", file=sys.stderr)
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
@@ -209,6 +206,7 @@ def load_model(args) -> VoxCPM:
zipenhancer_model_path=zipenhancer_path,
enable_denoiser=not args.no_denoiser,
optimize=not args.no_optimize,
device=args.device,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
)
@@ -227,6 +225,7 @@ def load_model(args) -> VoxCPM:
cache_dir=args.cache_dir,
local_files_only=args.local_files_only,
optimize=not args.no_optimize,
device=args.device,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
)
@@ -264,6 +263,8 @@ def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None
and (args.prompt_audio is not None or args.reference_audio is not None),
)
import soundfile as sf
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
duration = len(audio_array) / model.tts_model.sample_rate
@@ -286,7 +287,27 @@ def cmd_clone(args, parser):
)
def cmd_validate(args, parser):
from voxcpm.training.validate import (
print_validation_report,
validate_manifest,
)
manifest = str(require_file_exists(args.manifest, parser, "manifest file"))
result = validate_manifest(
manifest_path=manifest,
sample_rate=args.sample_rate,
max_samples=args.max_samples,
verbose=args.verbose,
)
print_validation_report(result, manifest)
if not result.is_valid:
sys.exit(1)
def cmd_batch(args, parser):
import soundfile as sf
input_file = require_file_exists(args.input, parser, "input file")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@@ -403,6 +424,12 @@ def _add_model_args(parser):
default=DEFAULT_HF_MODEL_ID,
help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
)
parser.add_argument(
"--device",
type=str,
default="auto",
help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)",
)
parser.add_argument(
"--cache-dir", type=str, help="Cache directory for Hub downloads"
)
@@ -524,6 +551,30 @@ Examples:
_add_model_args(batch_parser)
_add_lora_args(batch_parser)
# Validate subcommand
validate_parser = subparsers.add_parser(
"validate",
help="Validate a training data manifest (JSONL) before fine-tuning",
)
validate_parser.add_argument(
"--manifest", "-m", required=True, help="Path to JSONL training manifest"
)
validate_parser.add_argument(
"--sample-rate",
type=int,
default=16_000,
help="Expected audio sample rate in Hz (default: 16000)",
)
validate_parser.add_argument(
"--max-samples",
type=int,
default=0,
help="Maximum number of samples to validate (0 = all, default: 0)",
)
validate_parser.add_argument(
"--verbose", "-v", action="store_true", help="Print per-sample progress"
)
# Legacy root arguments
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
parser.add_argument(
@@ -576,6 +627,9 @@ def main():
parser = _build_parser()
args = parser.parse_args()
if args.command == "validate":
return cmd_validate(args, parser)
validate_ranges(args, parser)
if args.command == "design":
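The `validate` subcommand registered in the hunks above can be tried in isolation. The following is a minimal sketch that rebuilds only that subparser, using the flags and defaults shown in the diff. The function name `build_validate_parser`, the `prog` value, and the manifest path `train.jsonl` are illustrative assumptions; the real CLI registers further subcommands and the legacy root arguments.

```python
import argparse


def build_validate_parser() -> argparse.ArgumentParser:
    """Build a parser containing only the `validate` subcommand (illustrative)."""
    parser = argparse.ArgumentParser(prog="voxcpm")
    subparsers = parser.add_subparsers(dest="command")
    validate_parser = subparsers.add_parser(
        "validate",
        help="Validate a training data manifest (JSONL) before fine-tuning",
    )
    validate_parser.add_argument(
        "--manifest", "-m", required=True, help="Path to JSONL training manifest"
    )
    validate_parser.add_argument(
        "--sample-rate", type=int, default=16_000,
        help="Expected audio sample rate in Hz (default: 16000)",
    )
    validate_parser.add_argument(
        "--max-samples", type=int, default=0,
        help="Maximum number of samples to validate (0 = all, default: 0)",
    )
    validate_parser.add_argument(
        "--verbose", "-v", action="store_true", help="Print per-sample progress"
    )
    return parser


# Parse a hypothetical command line; only --manifest is required.
args = build_validate_parser().parse_args(
    ["validate", "-m", "train.jsonl", "--sample-rate", "44100"]
)
print(args.command, args.manifest, args.sample_rate, args.max_samples, args.verbose)
```

Because `main()` dispatches to `cmd_validate` before `validate_ranges` and before any model import, this path does not need the inference stack.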
+33 -6
@@ -8,6 +8,7 @@ from typing import Generator, Optional
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel, LoRAConfig
from .model.voxcpm2 import VoxCPM2Model
from .model.utils import next_and_close
class VoxCPM:
@@ -17,6 +18,7 @@ class VoxCPM:
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser: bool = True,
optimize: bool = True,
device: str | None = None,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
@@ -30,6 +32,9 @@ class VoxCPM:
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM
will choose automatically (preferring CUDA, then MPS, then CPU).
If set explicitly, that device is used or a clear error is raised.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
@@ -56,10 +61,20 @@ class VoxCPM:
arch = config.get("architecture", "voxcpm").lower()
if arch == "voxcpm2":
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
self.tts_model = VoxCPM2Model.from_local(
voxcpm_model_path,
optimize=optimize,
device=device,
lora_config=lora_config,
)
print("Loaded VoxCPM2Model", file=sys.stderr)
elif arch == "voxcpm":
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
self.tts_model = VoxCPMModel.from_local(
voxcpm_model_path,
optimize=optimize,
device=device,
lora_config=lora_config,
)
print("Loaded VoxCPMModel", file=sys.stderr)
else:
raise ValueError(f"Unsupported architecture: {arch}")
@@ -94,6 +109,7 @@ class VoxCPM:
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
device: str | None = None,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
@@ -109,6 +125,9 @@ class VoxCPM:
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
device: Runtime device. Use ``None``/``"auto"`` for automatic
fallback, or an explicit value such as ``"cpu"``, ``"mps"``,
``"cuda"``, or ``"cuda:0"``.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created with
enable_lm=True and enable_dit=True.
@@ -130,7 +149,7 @@ class VoxCPM:
if not repo_id:
raise ValueError("You must provide hf_model_id")
# Load from local path if provided
# Load from the local path (if provided)
if os.path.isdir(repo_id):
local_path = repo_id
else:
@@ -146,13 +165,14 @@ class VoxCPM:
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
optimize=optimize,
device=device,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
**kwargs,
)
def generate(self, *args, **kwargs) -> np.ndarray:
return next(self._generate(*args, streaming=False, **kwargs))
return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@@ -200,7 +220,7 @@ class VoxCPM:
Yields audio chunks for each generation step if ``streaming=True``,
otherwise yields a single array containing the final audio.
"""
if not text.strip() or not isinstance(text, str):
if not isinstance(text, str) or not text.strip():
raise ValueError("target text must be a non-empty string")
if prompt_wav_path is not None:
@@ -273,7 +293,14 @@ class VoxCPM:
streaming=streaming,
)
for wav, _, _ in generate_result:
if streaming:
try:
for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy()
finally:
generate_result.close()
else:
wav, _, _ = next_and_close(generate_result)
yield wav.squeeze(0).cpu().numpy()
finally:
+111 -1
@@ -1,7 +1,25 @@
from typing import List
import os
from typing import List, Optional
import torch
from transformers import PreTrainedTokenizer
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"}
_VALID_DTYPE_OVERRIDES = {
"bfloat16", "bf16",
"float16", "fp16",
"float32", "fp32",
}
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
# Explicitly close partially-consumed generators so inference_mode cleanup
# does not get deferred to Python's GC/finalizer path.
def next_and_close(gen):
try:
return next(gen)
finally:
gen.close()
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
@@ -119,3 +137,95 @@ def get_dtype(dtype: str):
return torch.float32
else:
raise ValueError(f"Unsupported dtype: {dtype}")
def _has_mps() -> bool:
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
"""Pick a safe runtime dtype for the resolved device.
On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift
in the diffusion AR loop that the output is glitched and the model's
badcase detector triggers infinite retries. float32 is the only stable
option today. CUDA and CPU keep whatever the checkpoint was trained with.
Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when
they want to test future MPS improvements.
"""
if device != "mps":
return configured_dtype
override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
if override:
if override not in _VALID_DTYPE_OVERRIDES:
raise ValueError(
f"VOXCPM_MPS_DTYPE='{override}' is not one of "
f"{sorted(_VALID_DTYPE_OVERRIDES)}"
)
return override
if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES:
return "float32"
return configured_dtype
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
"""
Choose a runtime device automatically.
Preference order:
- if the preferred device is available, use it
- otherwise fall back to CUDA -> MPS -> CPU
"""
preferred = (preferred_device or "cuda").strip().lower()
if preferred.startswith("cuda") and torch.cuda.is_available():
return preferred
if preferred == "mps" and _has_mps():
return "mps"
if preferred == "cpu":
return "cpu"
if torch.cuda.is_available():
return "cuda"
if _has_mps():
return "mps"
return "cpu"
def resolve_runtime_device(device: Optional[str], configured_device: str = "cuda") -> str:
"""
Resolve the actual runtime device.
Semantics:
- ``device`` is ``None`` or ``"auto"``: use automatic fallback selection
- otherwise: treat it as an explicit user choice and validate availability
"""
explicit = None if device is None else device.strip().lower()
if explicit is None or explicit == "auto":
return auto_select_device(configured_device)
if explicit.startswith("cuda"):
if not torch.cuda.is_available():
raise ValueError(
f"Requested device '{device}', but CUDA is not available. "
"Use device='auto' for automatic fallback."
)
return explicit
if explicit == "mps":
if not _has_mps():
raise ValueError(
"Requested device 'mps', but MPS is not available. "
"Use device='auto' for automatic fallback."
)
return "mps"
if explicit == "cpu":
return "cpu"
raise ValueError(
f"Unsupported device '{device}'. Supported values are 'auto', 'cpu', 'mps', "
"'cuda', or indexed CUDA devices like 'cuda:0'."
)
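The two device helpers above separate "auto" selection (fall back silently) from an explicit choice (fail loudly). The following is a torch-free restatement of those rules, offered as a sketch for reasoning about the behaviour rather than as the project's code. The function `resolve_device` and its `cuda_ok` / `mps_ok` parameters are assumptions introduced here so that availability can be passed in instead of queried from `torch`.

```python
from typing import Optional


def resolve_device(device: Optional[str], configured: str, cuda_ok: bool, mps_ok: bool) -> str:
    """Resolve a runtime device from a user choice and injected availability flags."""
    explicit = None if device is None else device.strip().lower()

    if explicit is None or explicit == "auto":
        # Automatic mode: honour the configured device when it is usable,
        # otherwise fall back in the order CUDA -> MPS -> CPU.
        preferred = (configured or "cuda").strip().lower()
        if preferred.startswith("cuda") and cuda_ok:
            return preferred
        if preferred == "mps" and mps_ok:
            return "mps"
        if preferred == "cpu":
            return "cpu"
        if cuda_ok:
            return "cuda"
        return "mps" if mps_ok else "cpu"

    # Explicit mode: never fall back, raise if the device is unavailable.
    if explicit.startswith("cuda"):
        if not cuda_ok:
            raise ValueError(f"Requested device '{device}', but CUDA is not available.")
        return explicit
    if explicit == "mps":
        if not mps_ok:
            raise ValueError("Requested device 'mps', but MPS is not available.")
        return "mps"
    if explicit == "cpu":
        return "cpu"
    raise ValueError(f"Unsupported device '{device}'.")


# On a machine with MPS but no CUDA, "auto" falls back while "cuda" raises.
print(resolve_device("auto", "cuda", cuda_ok=False, mps_ok=True))
try:
    resolve_device("cuda", "cuda", cuda_ok=False, mps_ok=True)
except ValueError as exc:
    print("error:", exc)
```

Passing availability in as arguments keeps the sketch testable on any machine; the real helpers call `torch.cuda.is_available()` and `_has_mps()` directly.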
+37 -17
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
from .utils import (
get_dtype,
mask_multichar_chinese_tokens,
next_and_close,
pick_runtime_dtype,
resolve_runtime_device,
)
class VoxCPMEncoderConfig(BaseModel):
@@ -109,18 +115,22 @@ class VoxCPMModel(nn.Module):
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE,
lora_config: LoRAConfig = None,
device: str | None = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.feat_dim = config.feat_dim
self.patch_size = config.patch_size
self.device = config.device
if not torch.cuda.is_available():
if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
self.device = resolve_runtime_device(device, config.device)
self.config.device = self.device
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
if resolved_dtype != self.config.dtype:
print(
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
file=sys.stderr,
)
self.config.dtype = resolved_dtype
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM
@@ -227,6 +237,7 @@ class VoxCPMModel(nn.Module):
self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self._feat_encoder_raw = self.feat_encoder
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
@@ -337,7 +348,7 @@ class VoxCPMModel(nn.Module):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@@ -463,7 +474,7 @@ class VoxCPMModel(nn.Module):
yield decode_audio
break
else:
latent_pred, pred_audio_feat = next(inference_result)
latent_pred, pred_audio_feat = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -571,7 +582,7 @@ class VoxCPMModel(nn.Module):
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
@@ -690,7 +701,7 @@ class VoxCPMModel(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
latent_pred, pred_audio_feat = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -713,7 +724,7 @@ class VoxCPMModel(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
return next_and_close(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
@@ -755,7 +766,8 @@ class VoxCPMModel(nn.Module):
"""
B, T, P, D = feat.shape
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup:
@@ -845,8 +857,16 @@ class VoxCPMModel(nn.Module):
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
def from_local(
cls,
path: str,
optimize: bool = True,
training: bool = False,
device: str | None = None,
lora_config: LoRAConfig = None,
):
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, "audio_vae_config", None)
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
@@ -868,7 +888,7 @@ class VoxCPMModel(nn.Module):
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config)
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
@@ -950,7 +970,7 @@ class VoxCPMModel(nn.Module):
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
elif ckpt_file and ckpt_file.exists():
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
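Several hunks above replace `next(...)` with `next_and_close(...)` on generators that are consumed only once. The following is a minimal sketch of why that matters: a suspended generator keeps its `finally` block (or context manager) pending until it is closed or garbage-collected, and closing it explicitly runs the cleanup at once. `fake_generate` and the `events` list are stand-ins invented for this illustration; only `next_and_close` is taken from the diff.

```python
def next_and_close(gen):
    """Return the first item of ``gen`` and close the generator immediately."""
    try:
        return next(gen)
    finally:
        gen.close()


events = []


def fake_generate():
    """Stand-in for a generator that holds a resource while suspended."""
    events.append("enter")
    try:
        yield "audio"
    finally:
        # Runs when the generator is closed, not when the value is yielded.
        events.append("cleanup")


result = next_and_close(fake_generate())
print(result, events)
```

With a plain `next(fake_generate())` the `"cleanup"` entry would only appear once the interpreter finalizes the abandoned generator, which is the deferred path the comment in `utils.py` refers to.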
+49 -29
@@ -45,7 +45,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
from .utils import (
get_dtype,
mask_multichar_chinese_tokens,
next_and_close,
pick_runtime_dtype,
resolve_runtime_device,
)
# A simple function to trim audio silence using VAD, not used default
@@ -151,18 +157,22 @@ class VoxCPM2Model(nn.Module):
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAEV2,
lora_config: LoRAConfig = None,
device: str | None = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.feat_dim = config.feat_dim
self.patch_size = config.patch_size
self.device = config.device
if not torch.cuda.is_available():
if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
self.device = resolve_runtime_device(device, config.device)
self.config.device = self.device
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
if resolved_dtype != self.config.dtype:
print(
f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
file=sys.stderr,
)
self.config.dtype = resolved_dtype
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM
@@ -275,6 +285,7 @@ class VoxCPM2Model(nn.Module):
self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self._feat_encoder_raw = self.feat_encoder
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
@@ -443,7 +454,7 @@ class VoxCPM2Model(nn.Module):
return tokens, feats, t_mask, a_mask
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
return next_and_close(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@@ -636,14 +647,14 @@ class VoxCPM2Model(nn.Module):
streaming_prefix_len=streaming_prefix_len,
)
if streaming:
decode_patch_len = self.patch_size * self._decode_chunk_size
for latent_pred, _, _ctx in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
yield decode_audio
with self.audio_vae.streaming_decode() as vae_dec:
for latent_pred, _, _ctx in inference_result:
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
decode_audio = decode_audio.squeeze(1).cpu()
yield decode_audio
break
else:
latent_pred, pred_audio_feat, context_len = next(inference_result)
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -761,7 +772,7 @@ class VoxCPM2Model(nn.Module):
return merged
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
@@ -923,14 +934,14 @@ class VoxCPM2Model(nn.Module):
streaming_prefix_len=streaming_prefix_len,
)
if streaming:
decode_patch_len = self.patch_size * self._decode_chunk_size
for latent_pred, pred_audio_feat, _ctx in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
yield (decode_audio, target_text_token, pred_audio_feat)
with self.audio_vae.streaming_decode() as vae_dec:
for latent_pred, pred_audio_feat, _ctx in inference_result:
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
decode_audio = decode_audio.squeeze(1).cpu()
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat, context_len = next(inference_result)
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -953,7 +964,7 @@ class VoxCPM2Model(nn.Module):
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs))
feat_pred, generated_feat, _ = next_and_close(self._inference(*args, streaming=False, **kwargs))
return feat_pred, generated_feat
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
@@ -997,7 +1008,8 @@ class VoxCPM2Model(nn.Module):
"""
B, T, P, D = feat.shape
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup:
@@ -1067,8 +1079,8 @@ class VoxCPM2Model(nn.Module):
prefix_feat_cond = pred_feat
if streaming:
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
# Yield only the newest patch latent for stateful VAE decode
feat_pred = rearrange(pred_feat.unsqueeze(1), "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq, context_len
@@ -1096,8 +1108,16 @@ class VoxCPM2Model(nn.Module):
yield feat_pred, generated_feat, context_len
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
def from_local(
cls,
path: str,
optimize: bool = True,
training: bool = False,
device: str | None = None,
lora_config: LoRAConfig = None,
):
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, "audio_vae_config", None)
audio_vae = AudioVAEV2(config=audio_vae_config) if audio_vae_config else AudioVAEV2()
@@ -1119,7 +1139,7 @@ class VoxCPM2Model(nn.Module):
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config)
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
@@ -1201,7 +1221,7 @@ class VoxCPM2Model(nn.Module):
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
elif ckpt_file and ckpt_file.exists():
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
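The hunks above replace `next(...)` with `next_and_close(...)`, but the helper itself is not part of this excerpt. A minimal sketch of what such a helper could look like follows; the body and the `demo` function are assumptions for illustration, not the repository's implementation. The idea is that closing the generator right after taking its first item makes any `with` or `finally` cleanup inside it run immediately instead of waiting for garbage collection.

```python
from typing import Generator, List, Tuple, TypeVar

T = TypeVar("T")


def next_and_close(gen: Generator[T, None, None]) -> T:
    """Return the first item of ``gen``, then close it (hypothetical sketch)."""
    try:
        return next(gen)
    finally:
        # close() raises GeneratorExit at the paused ``yield``, so cleanup
        # code inside the generator runs now rather than at collection time.
        gen.close()


def demo() -> Tuple[str, List[str]]:
    """Show that the generator's ``finally`` block runs as soon as we return."""
    events: List[str] = []

    def produce():
        try:
            yield "first"
            yield "second"
        finally:
            events.append("cleaned up")

    value = next_and_close(produce())
    return value, events
```

With a plain `next(...)`, the generator would stay suspended at its first `yield` until it is garbage collected, which matters once the generator body holds a context manager such as `streaming_decode()`.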
@@ -472,6 +472,20 @@ class AudioVAE(nn.Module):
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
return self.decoder(z, sr_cond)
def streaming_decode(self):
"""Return a ``StreamingVAEDecoder`` context manager for stateful
chunk-by-chunk decoding. Each call to ``decode_chunk`` processes only
the new latent patch and carries causal-conv state internally, avoiding
the redundant overlap decode used previously.
Usage::
with vae.streaming_decode() as dec:
for patch in patches:
audio_chunk = dec.decode_chunk(patch)
"""
return StreamingVAEDecoder(self)
def encode(self, audio_data: torch.Tensor, sample_rate: int):
"""
Args:
@@ -485,3 +499,82 @@ class AudioVAE(nn.Module):
audio_data = self.preprocess(audio_data, sample_rate)
return self.encoder(audio_data)["mu"]
class StreamingVAEDecoder:
"""Stateful streaming wrapper for :class:`AudioVAE`.
Carries causal-convolution padding buffers between calls so that each
``decode_chunk`` processes only the new latent patch — no overlap needed.
"""
def __init__(self, vae: AudioVAE):
self._vae = vae
self._states: dict = {}
self._originals: list = []
# -- context manager --------------------------------------------------
def __enter__(self):
self._states.clear()
self._install()
return self
def __exit__(self, *exc):
self._restore()
self._states.clear()
# -- public API --------------------------------------------------------
def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor:
"""Decode a single latent chunk and return the audio waveform."""
return self._vae.decode(z_chunk)
# -- internals ---------------------------------------------------------
def _install(self):
for name, mod in self._vae.decoder.named_modules():
if isinstance(mod, CausalConv1d):
pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding
if pad > 0:
self._patch_causal_conv(mod, pad)
elif isinstance(mod, CausalTransposeConv1d):
trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding
ctx = (mod.kernel_size[0] - 1) // mod.stride[0]
if ctx > 0:
self._patch_transpose_conv(mod, ctx, trim)
def _patch_causal_conv(self, mod, pad_size):
states = self._states
key = id(mod)
orig = mod.forward
def fwd(x, _k=key, _p=pad_size, _m=mod):
x_pad = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_p, 0))
if x.shape[-1] >= _p:
states[_k] = x[:, :, -_p:].detach()
else:
prev = states.get(_k, torch.zeros(x.shape[0], x.shape[1], _p,
device=x.device, dtype=x.dtype))
states[_k] = torch.cat([prev, x], dim=-1)[:, :, -_p:].detach()
return nn.Conv1d.forward(_m, x_pad)
mod.forward = fwd
self._originals.append((mod, orig))
def _patch_transpose_conv(self, mod, ctx, trim):
states = self._states
key = id(mod)
orig = mod.forward
def fwd(x, _k=key, _c=ctx, _t=trim, _m=mod):
x_full = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_c, 0))
states[_k] = x[:, :, -_c:].detach()
out = nn.ConvTranspose1d.forward(_m, x_full)
left = _c * _m.stride[0]
return out[..., left:-_t] if _t > 0 else out[..., left:]
mod.forward = fwd
self._originals.append((mod, orig))
def _restore(self):
for mod, orig in self._originals:
mod.forward = orig
self._originals.clear()
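The state-carrying idea behind `StreamingVAEDecoder` can be illustrated without PyTorch. The sketch below is a toy, pure-Python causal convolution (stride 1, no dilation, left context of `kernel_size - 1` samples); these simplifications and the names `causal_conv` and `stream` are assumptions for illustration and do not mirror the padding arithmetic of `CausalConv1d`. It shows that feeding chunks while carrying the last inputs as state gives the same output as convolving the whole signal at once.

```python
from typing import List, Optional, Tuple


def causal_conv(
    x: List[float],
    kernel: List[float],
    state: Optional[List[float]] = None,
) -> Tuple[List[float], List[float]]:
    """Toy causal 1-D convolution that returns (output, new_state)."""
    pad = len(kernel) - 1
    # First call: zero left-padding. Later calls: previous inputs as context.
    left = state if state is not None else [0.0] * pad
    padded = left + x
    out = [
        sum(k * padded[i + j] for j, k in enumerate(kernel))
        for i in range(len(x))
    ]
    # Keep the last ``pad`` samples; this also works when the chunk is
    # shorter than ``pad`` because older context is still inside ``padded``.
    new_state = padded[len(padded) - pad:]
    return out, new_state


def stream(chunks: List[List[float]], kernel: List[float]) -> List[float]:
    """Run ``causal_conv`` chunk by chunk, carrying state between calls."""
    state: Optional[List[float]] = None
    out: List[float] = []
    for chunk in chunks:
        y, state = causal_conv(chunk, kernel, state)
        out.extend(y)
    return out
```

Because each chunk only needs the carried context, no overlapping region has to be decoded again and trimmed, which is the redundancy the removed `decode_patch_len` slicing was working around.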
+3 -1
@@ -196,7 +196,9 @@ class MiniCPMAttention(nn.Module):
key_cache[:, :, position_id, :] = key_states
value_cache[:, :, position_id, :] = value_states
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
# Use an explicit broadcastable mask shape for SDPA. A 1D mask can
# trigger a CPU-side dimension bug in some PyTorch versions.
attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1)
# ref: https://github.com/pytorch/pytorch/issues/163597
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+3
@@ -15,6 +15,7 @@ from .data import (
BatchProcessor,
)
from .state import TrainingState
from .validate import validate_manifest, ValidationResult
__all__ = [
"Accelerator",
@@ -24,4 +25,6 @@ __all__ = [
"TrainingState",
"load_audio_text_datasets",
"build_dataloader",
"validate_manifest",
"ValidationResult",
]
+53 -11
@@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_REF_AUDIO_COLUMN = "ref_audio"
DEFAULT_ID_COLUMN = "dataset_id"
@@ -21,6 +22,7 @@ def load_audio_text_datasets(
val_manifest: str = "",
text_column: str = DEFAULT_TEXT_COLUMN,
audio_column: str = DEFAULT_AUDIO_COLUMN,
ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN,
dataset_id_column: str = DEFAULT_ID_COLUMN,
sample_rate: int = 16_000,
num_proc: int = 1,
@@ -34,14 +36,19 @@ def load_audio_text_datasets(
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
if audio_column != DEFAULT_AUDIO_COLUMN:
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
if text_column != DEFAULT_TEXT_COLUMN:
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
# ref_audio is optional — cast to Audio if the column exists
ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN
if ref_col in ds.column_names:
ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate))
if ref_col != DEFAULT_REF_AUDIO_COLUMN:
ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN)
if dataset_id_column and dataset_id_column in ds.column_names:
if dataset_id_column != DEFAULT_ID_COLUMN:
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
@@ -67,11 +74,11 @@ def compute_sample_lengths(
- Audio length:
duration(s) * audio_vae_fps -> approximate number of VAE frames t_vae
t_seq = ceil(t_vae / patch_size)
- Total sequence length is approximately: text_len + t_seq + 2
- Without ref_audio: text_len + t_seq + 2
- With ref_audio: text_len + t_seq + ref_seq + 4
Optimized: Use batch column access instead of iterating item by item.
"""
# Batch access columns - much faster than per-item access
text_ids_list = ds["text_ids"]
text_lens = [len(t) for t in text_ids_list]
@@ -79,18 +86,35 @@ def compute_sample_lengths(
if has_duration:
durations = ds["duration"]
else:
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
durations = []
for i in range(len(ds)):
audio = ds[i][DEFAULT_AUDIO_COLUMN]
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
# Vectorized length computation
has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names
if has_ref_audio:
ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None
lengths = []
for text_len, duration in zip(text_lens, durations):
for i, (text_len, duration) in enumerate(zip(text_lens, durations)):
t_vae = math.ceil(float(duration) * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2
ref_seq = 0
if has_ref_audio:
# Estimate ref_audio length; ref_audio is None for samples without it
if ref_duration_col:
ref_dur = ds[i].get(ref_duration_col)
else:
ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN)
ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None
if ref_dur is not None and float(ref_dur) > 0:
ref_vae = math.ceil(float(ref_dur) * audio_vae_fps)
ref_seq = math.ceil(ref_vae / patch_size)
# +2 for 101/102; +2 more for 103/104 when ref_audio present
overhead = 4 if ref_seq > 0 else 2
total_len = text_len + t_seq + ref_seq + overhead
lengths.append(total_len)
return lengths
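The length estimate computed in the loop above can be written as a standalone function for a single sample. This is a sketch of the same arithmetic; the function name `estimate_sequence_length` is hypothetical, and the frame rate and patch size used in the usage note are illustrative numbers, not values taken from the model config.

```python
import math
from typing import Optional


def estimate_sequence_length(
    text_len: int,
    duration_s: float,
    audio_vae_fps: float,
    patch_size: int,
    ref_duration_s: Optional[float] = None,
) -> int:
    """Estimate the packed sequence length for one training sample."""
    t_vae = math.ceil(duration_s * audio_vae_fps)   # approximate VAE frames
    t_seq = math.ceil(t_vae / patch_size)           # frames grouped into patches

    ref_seq = 0
    if ref_duration_s is not None and ref_duration_s > 0:
        ref_vae = math.ceil(ref_duration_s * audio_vae_fps)
        ref_seq = math.ceil(ref_vae / patch_size)

    # Two marker tokens (101/102) always; two more (103/104) with ref audio.
    overhead = 4 if ref_seq > 0 else 2
    return text_len + t_seq + ref_seq + overhead
```

For example, with an assumed 25 frames per second and a patch size of 4, a 20-token text with 3.0 s of audio gives 20 + 19 + 2 = 41, and adding 2.0 s of reference audio gives 20 + 19 + 13 + 4 = 56.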
@@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset):
PyTorch-friendly samples.
"""
_SENTINEL = [-100.0]
def __init__(self, dataset: Dataset):
self.dataset = dataset
self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names
def __len__(self):
return len(self.dataset)
@@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset):
def __getitem__(self, idx: int):
item = self.dataset[idx]
audio = item[DEFAULT_AUDIO_COLUMN]
return {
sample = {
"text_ids": item["text_ids"],
"audio_array": audio["array"],
"audio_sampling_rate": audio["sampling_rate"],
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
"is_prompt": item.get("is_prompt", False),
}
if self.has_ref_audio:
ref = item.get(DEFAULT_REF_AUDIO_COLUMN)
sample["ref_audio_array"] = ref["array"] if ref else self._SENTINEL
return sample
@staticmethod
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
@@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset):
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
return {
result = {
"text_tokens": text_padded,
"audio_tokens": audio_padded,
"task_ids": task_ids,
@@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset):
"is_prompts": is_prompts,
}
if "ref_audio_array" in batch[0]:
ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch]
result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0)
return result
class BatchProcessor:
"""
@@ -184,12 +221,17 @@ class BatchProcessor:
task_ids = batch["task_ids"].to(self.device)
dataset_ids = batch["dataset_ids"].to(self.device)
ref_audio_tokens = None
if "ref_audio_tokens" in batch:
ref_audio_tokens = batch["ref_audio_tokens"].to(self.device)
packed = self.packer(
audio_tokens=audio_tokens,
text_tokens=text_tokens,
task_ids=task_ids,
dataset_ids=dataset_ids,
is_prompts=batch["is_prompts"],
ref_audio_tokens=ref_audio_tokens,
)
return packed
+124 -14
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Optional
import torch
import torch.nn as nn
@@ -14,7 +14,6 @@ class AudioFeatureProcessingPacker:
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101
self.audio_end_id = 102
# unused now
self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104
self.text_eos_token_id = 2
@@ -78,11 +77,16 @@ class AudioFeatureProcessingPacker:
task_ids: torch.Tensor,
dataset_ids: torch.Tensor,
is_prompts: List[bool],
ref_audio_tokens: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""
Padding-based batching: each sample in the input batch is processed
independently and then padded to a common length (capped by ``max_len``).
The result tensors all have shape [B, T, ...].
If ``ref_audio_tokens`` is provided (same batch dim as ``audio_tokens``),
samples whose unpadded ref_audio length > 0 will be processed with the
reference-audio path (tokens 103/104 prepended, loss only on target audio).
"""
device = audio_tokens.device
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
@@ -101,23 +105,43 @@ class AudioFeatureProcessingPacker:
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
ref_iter = ref_audio_tokens if ref_audio_tokens is not None else [None] * audio_tokens.size(0)
for audio_token, text_token, task_id, dataset_idx, is_prompt, ref_token in zip(
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts, ref_iter
):
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
unpad_text_token = self.unpad_text_tokens(text_token)
usage = self.id_to_task[task_id]
(
packed_text,
audio_feat,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
has_ref = False
if ref_token is not None:
unpad_ref_token = self.unpad_audio_tokens(ref_token).to(torch.float32)
if unpad_ref_token.numel() > 0:
has_ref = True
if has_ref:
(
packed_text,
audio_feat,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
) = self.process_tts_data_with_ref(unpad_ref_token, unpad_audio_token, unpad_text_token)
else:
(
packed_text,
audio_feat,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
audio_duration_consumed[dataset_idx] += audio_duration
text_token_consumed[dataset_idx] += text_token_count
@@ -294,3 +318,89 @@ class AudioFeatureProcessingPacker:
audio_duration,
text_token_count,
)
def process_tts_data_with_ref(
self,
ref_audio_token: torch.Tensor,
target_audio_token: torch.Tensor,
text_token: torch.Tensor,
):
"""
Build a training sequence with reference audio prepended:
[103, ref_feats, 104, text, 101, target_feats, 102]
Loss is computed only on the target audio segment.
"""
device = text_token.device
txt_len = len(text_token)
ref_feats, ref_duration = self.extract_audio_feats(ref_audio_token)
ref_feats = ref_feats.squeeze(0) # [R, P, D]
ref_len = ref_feats.shape[0]
tgt_feats, tgt_duration = self.extract_audio_feats(target_audio_token)
tgt_feats = tgt_feats.squeeze(0) # [A, P, D]
tgt_len = tgt_feats.shape[0]
feat_shape = (self.patch_size, ref_feats.size(-1))
def _tok(ids):
return torch.tensor(ids, dtype=torch.int32, device=device)
# -- text token track --
# [103, 0×R, 104, text_ids, 101, 0×A, 102]
text_token_info = torch.cat([
_tok([self.audio_prompt_start_id]),
torch.zeros(ref_len, dtype=torch.int32, device=device),
_tok([self.audio_prompt_end_id]),
text_token,
_tok([self.audio_start_id]),
torch.zeros(tgt_len, dtype=torch.int32, device=device),
_tok([self.audio_end_id]),
])
# -- audio feature track --
zero_1 = torch.zeros((1,) + feat_shape, dtype=torch.float32, device=device)
zero_txt = torch.zeros((txt_len,) + feat_shape, dtype=torch.float32, device=device)
audio_feat_info = torch.cat([
zero_1, ref_feats, zero_1, # 103, ref, 104
zero_txt, # text
zero_1, tgt_feats, zero_1, # 101, target, 102
], dim=0)
# -- masks --
text_mask = torch.cat([
torch.ones(1), torch.zeros(ref_len), torch.ones(1),
torch.ones(txt_len),
torch.ones(1), torch.zeros(tgt_len), torch.ones(1),
]).to(torch.int32).to(device)
audio_mask = torch.cat([
torch.zeros(1), torch.ones(ref_len), torch.zeros(1),
torch.zeros(txt_len),
torch.zeros(1), torch.ones(tgt_len), torch.zeros(1),
]).to(torch.int32).to(device)
loss_mask = torch.cat([
torch.zeros(1 + ref_len + 1), # ref part: no loss
torch.zeros(txt_len), # text: no loss
torch.zeros(1), # 101: no loss
torch.ones(tgt_len), # target audio: LOSS
torch.zeros(1), # 102: no loss
]).to(torch.int32).to(device)
total_len = 1 + ref_len + 1 + txt_len + 1 + tgt_len + 1
labels = torch.zeros(total_len, dtype=torch.int32, device=device)
labels[-2] = 1 # stop label at last target audio position
return (
text_token_info,
audio_feat_info,
text_mask,
audio_mask,
loss_mask,
labels,
ref_duration + tgt_duration,
txt_len,
)
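The sequence layout built by `process_tts_data_with_ref` is easier to check with plain lists. The sketch below rebuilds only the mask and label layout for given segment lengths; it uses Python lists instead of tensors, and the function name `build_ref_layout` is an assumption for illustration.

```python
from typing import Dict, List


def build_ref_layout(ref_len: int, txt_len: int, tgt_len: int) -> Dict[str, List[int]]:
    """Masks for the layout [103, ref, 104, text, 101, target, 102]."""
    text_mask = (
        [1] + [0] * ref_len + [1]      # 103, ref features, 104
        + [1] * txt_len                # text tokens
        + [1] + [0] * tgt_len + [1]    # 101, target features, 102
    )
    audio_mask = (
        [0] + [1] * ref_len + [0]
        + [0] * txt_len
        + [0] + [1] * tgt_len + [0]
    )
    loss_mask = (
        [0] * (1 + ref_len + 1)        # reference part: no loss
        + [0] * txt_len                # text: no loss
        + [0]                          # 101: no loss
        + [1] * tgt_len                # target audio: loss
        + [0]                          # 102: no loss
    )
    labels = [0] * len(text_mask)
    labels[-2] = 1                     # stop label at the last target position
    return {
        "text_mask": text_mask,
        "audio_mask": audio_mask,
        "loss_mask": loss_mask,
        "labels": labels,
    }
```

Every position belongs to exactly one of the text and audio tracks, and the loss mask covers only the `tgt_len` target positions, which matches the docstring's statement that loss is computed only on the target audio segment.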
+309
@@ -0,0 +1,309 @@
"""
Pre-flight validation for VoxCPM training data manifests.
Validates JSONL manifest files before starting expensive fine-tuning jobs,
catching format issues, missing files, and data quality problems early.
"""
import json
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
@dataclass
class ValidationResult:
"""Structured result of a manifest validation run."""
total_samples: int = 0
valid_samples: int = 0
errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
audio_durations: List[float] = field(default_factory=list)
text_lengths: List[int] = field(default_factory=list)
has_ref_audio: int = 0
@property
def is_valid(self) -> bool:
return len(self.errors) == 0 and self.valid_samples > 0
def _check_audio_file(audio_path: str, sample_rate: int) -> Optional[str]:
"""Check if an audio file exists, is readable, and matches expected sample rate.
Returns an error message, or None if the file is valid.
"""
if not os.path.isfile(audio_path):
return f"Audio file not found: {audio_path}"
try:
import soundfile as sf
info = sf.info(audio_path)
if info.frames == 0:
return f"Audio file is empty: {audio_path}"
if info.samplerate != sample_rate:
return (
f"Sample rate mismatch in {audio_path}: "
f"expected {sample_rate} Hz, got {info.samplerate} Hz"
)
return None
except ImportError:
# soundfile not available; just check existence
return None
except Exception as e:
return f"Cannot read audio file {audio_path}: {e}"
def _get_audio_duration(audio_path: str) -> Optional[float]:
"""Get audio duration in seconds. Returns None if unavailable."""
try:
import soundfile as sf
info = sf.info(audio_path)
return info.duration
except Exception:
return None
def validate_manifest(
manifest_path: str,
sample_rate: int = 16_000,
max_samples: int = 0,
verbose: bool = False,
) -> ValidationResult:
"""Validate a JSONL training manifest file.
Checks:
1. File exists and is readable
2. Each line is valid JSON
3. Required columns present (text, audio)
4. Audio files exist and are readable
5. Text content is non-empty
6. Collects duration and text length statistics
7. Validates optional ref_audio column
Args:
manifest_path: Path to the JSONL manifest file.
sample_rate: Expected audio sample rate. When soundfile is available, files with a different rate are reported as errors.
max_samples: Maximum number of samples to validate (0 = all).
verbose: Print per-sample progress.
Returns:
ValidationResult with errors, warnings, and statistics.
"""
result = ValidationResult()
path = Path(manifest_path)
if not path.exists():
result.errors.append(f"Manifest file not found: {manifest_path}")
return result
if not path.is_file():
result.errors.append(f"Manifest path is not a file: {manifest_path}")
return result
manifest_dir = path.parent
try:
with open(path, "r", encoding="utf-8") as f:
lines = f.readlines()
except Exception as e:
result.errors.append(f"Cannot read manifest file: {e}")
return result
if not lines:
result.errors.append("Manifest file is empty")
return result
samples_to_check = len(lines)
if max_samples > 0:
samples_to_check = min(samples_to_check, max_samples)
missing_audio_count = 0
empty_text_count = 0
for i, line in enumerate(lines[:samples_to_check]):
line = line.strip()
if not line:
continue
result.total_samples += 1
# Check JSON validity
try:
entry = json.loads(line)
except json.JSONDecodeError as e:
result.errors.append(f"Line {i + 1}: Invalid JSON — {e}")
continue
if not isinstance(entry, dict):
result.errors.append(f"Line {i + 1}: Expected JSON object, got {type(entry).__name__}")
continue
# Check required columns
has_error = False
if "text" not in entry:
result.errors.append(f"Line {i + 1}: Missing required column 'text'")
has_error = True
if "audio" not in entry:
result.errors.append(f"Line {i + 1}: Missing required column 'audio'")
has_error = True
if has_error:
continue
# Validate text
text = entry["text"]
if not isinstance(text, str) or not text.strip():
empty_text_count += 1
if empty_text_count <= 5:
result.warnings.append(f"Line {i + 1}: Empty or non-string text")
else:
result.text_lengths.append(len(text))
# Validate audio path
audio_path = entry["audio"]
if isinstance(audio_path, dict):
# HuggingFace Audio format with {"path": ..., "array": ...}
audio_path = audio_path.get("path", "")
if isinstance(audio_path, str) and audio_path:
# Resolve relative paths against manifest directory
if not os.path.isabs(audio_path):
audio_path = str(manifest_dir / audio_path)
audio_error = _check_audio_file(audio_path, sample_rate)
if audio_error:
missing_audio_count += 1
if missing_audio_count <= 5:
result.errors.append(f"Line {i + 1}: {audio_error}")
has_error = True
else:
duration = _get_audio_duration(audio_path)
if duration is not None:
result.audio_durations.append(duration)
if duration < 0.3:
result.warnings.append(
f"Line {i + 1}: Very short audio ({duration:.2f}s)"
)
elif duration > 30.0:
result.warnings.append(
f"Line {i + 1}: Very long audio ({duration:.1f}s), may cause OOM"
)
else:
result.errors.append(f"Line {i + 1}: Invalid audio path")
has_error = True
# Validate optional ref_audio
if "ref_audio" in entry:
ref_path = entry["ref_audio"]
if isinstance(ref_path, dict):
ref_path = ref_path.get("path", "")
if isinstance(ref_path, str) and ref_path:
if not os.path.isabs(ref_path):
ref_path = str(manifest_dir / ref_path)
if os.path.isfile(ref_path):
result.has_ref_audio += 1
else:
result.warnings.append(
f"Line {i + 1}: ref_audio file not found: {ref_path}"
)
if not has_error:
result.valid_samples += 1
if verbose and (i + 1) % 100 == 0:
print(f" Validated {i + 1}/{samples_to_check} samples...", file=sys.stderr)
# Summarize truncated errors
if missing_audio_count > 5:
result.errors.append(
f"... and {missing_audio_count - 5} more missing audio files "
f"({missing_audio_count} total)"
)
if empty_text_count > 5:
result.warnings.append(
f"... and {empty_text_count - 5} more empty text entries "
f"({empty_text_count} total)"
)
return result
def print_validation_report(result: ValidationResult, manifest_path: str) -> None:
"""Print a human-readable validation report to stderr."""
print(f"\n{'=' * 60}", file=sys.stderr)
print(f" VoxCPM Training Data Validation Report", file=sys.stderr)
print(f"{'=' * 60}", file=sys.stderr)
print(f" Manifest : {manifest_path}", file=sys.stderr)
print(f" Samples : {result.valid_samples}/{result.total_samples} valid", file=sys.stderr)
if result.has_ref_audio > 0:
print(
f" Ref Audio: {result.has_ref_audio} samples with reference audio",
file=sys.stderr,
)
# Audio duration statistics
if result.audio_durations:
durations = sorted(result.audio_durations)
total_hrs = sum(durations) / 3600
print(f"\n Audio Duration Statistics:", file=sys.stderr)
print(f" Total : {total_hrs:.2f} hours", file=sys.stderr)
print(
f" Range : {durations[0]:.2f}s — {durations[-1]:.1f}s",
file=sys.stderr,
)
print(
f" Mean : {sum(durations) / len(durations):.2f}s",
file=sys.stderr,
)
median_idx = len(durations) // 2
print(f" Median : {durations[median_idx]:.2f}s", file=sys.stderr)
# Text length statistics
if result.text_lengths:
lengths = sorted(result.text_lengths)
print(f"\n Text Length Statistics (characters):", file=sys.stderr)
print(
f" Range : {lengths[0]} - {lengths[-1]}",
file=sys.stderr,
)
print(
f" Mean : {sum(lengths) / len(lengths):.0f}",
file=sys.stderr,
)
# Errors
if result.errors:
print(f"\n ERRORS ({len(result.errors)}):", file=sys.stderr)
for err in result.errors[:20]:
print(f" x {err}", file=sys.stderr)
if len(result.errors) > 20:
print(
f" ... ({len(result.errors) - 20} more errors omitted)",
file=sys.stderr,
)
# Warnings
if result.warnings:
print(f"\n WARNINGS ({len(result.warnings)}):", file=sys.stderr)
for warn in result.warnings[:10]:
print(f" ! {warn}", file=sys.stderr)
if len(result.warnings) > 10:
print(
f" ... ({len(result.warnings) - 10} more warnings omitted)",
file=sys.stderr,
)
# Summary
print(f"\n{'=' * 60}", file=sys.stderr)
if result.is_valid:
print(" PASSED: Manifest is valid for training.", file=sys.stderr)
else:
print(" FAILED: Fix errors above before starting training.", file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
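To make the expected manifest shape concrete, the sketch below runs only the structural checks (valid JSON, JSON object, required `text` and `audio` columns) on a few in-memory lines. It is a simplified stand-in written for illustration, not the `validate_manifest` function above; the name `check_manifest_lines` and the sample paths are hypothetical, and no audio files are opened.

```python
import json
from typing import List

REQUIRED_COLUMNS = ("text", "audio")

# Illustrative manifest lines; the paths are made up.
SAMPLE_LINES = [
    '{"text": "hello", "audio": "wavs/a.wav"}',
    '{"text": "with ref", "audio": "wavs/b.wav", "ref_audio": "wavs/ref.wav"}',
    '{"text": "no audio"}',
    "not json",
]


def check_manifest_lines(lines: List[str]) -> List[str]:
    """Return structural errors for JSONL manifest lines (1-based line numbers)."""
    errors: List[str] = []
    for lineno, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # blank lines are skipped, as in the validator above
        try:
            entry = json.loads(line)
        except json.JSONDecodeError as exc:
            errors.append(f"Line {lineno}: invalid JSON ({exc.msg})")
            continue
        if not isinstance(entry, dict):
            errors.append(f"Line {lineno}: expected JSON object")
            continue
        for column in REQUIRED_COLUMNS:
            if column not in entry:
                errors.append(f"Line {lineno}: missing required column '{column}'")
    return errors
```

Running it on `SAMPLE_LINES` reports the third and fourth lines; the first two pass because `ref_audio` is optional.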
+31
@@ -58,6 +58,7 @@ def test_parser_defaults_to_voxcpm2():
parser = cli._build_parser()
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
assert args.hf_model_id == "openbmb/VoxCPM2"
assert args.device == "auto"
assert args.no_optimize is False
@@ -85,6 +86,7 @@ def test_load_model_respects_no_optimize_for_local_model(monkeypatch):
cli.load_model(args)
assert calls["kwargs"]["device"] == "auto"
assert calls["kwargs"]["optimize"] is False
@@ -110,6 +112,7 @@ def test_load_model_defaults_optimize_for_hf(monkeypatch):
cli.load_model(args)
assert calls["kwargs"]["device"] == "auto"
assert calls["kwargs"]["optimize"] is True
@@ -136,9 +139,37 @@ def test_load_model_respects_no_optimize_for_hf(monkeypatch):
cli.load_model(args)
assert calls["kwargs"]["device"] == "auto"
assert calls["kwargs"]["optimize"] is False
def test_load_model_passes_explicit_device_to_hf(monkeypatch):
calls = {}
class FakeVoxCPM:
@classmethod
def from_pretrained(cls, **kwargs):
calls["kwargs"] = kwargs
return DummyModel()
monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)
args = cli._build_parser().parse_args(
[
"design",
"--text",
"hello",
"--output",
"out.wav",
"--device",
"mps",
]
)
cli.load_model(args)
assert calls["kwargs"]["device"] == "mps"
def test_design_subcommand_applies_control(monkeypatch, tmp_path):
dummy_model = DummyModel()
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
+150
@@ -0,0 +1,150 @@
from __future__ import annotations
import importlib.util
import sys
import types
from pathlib import Path
import pytest
import torch
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
def _load_module(name: str, path: Path):
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
sys.modules[name] = module
spec.loader.exec_module(module)
return module
def bootstrap_repo_modules(monkeypatch):
for name, path in [
("voxcpm", SRC / "voxcpm"),
("voxcpm.model", SRC / "voxcpm" / "model"),
("voxcpm.modules", SRC / "voxcpm" / "modules"),
]:
pkg = types.ModuleType(name)
pkg.__path__ = [str(path)]
monkeypatch.setitem(sys.modules, name, pkg)
hh = types.ModuleType("huggingface_hub")
hh.snapshot_download = lambda *a, **k: "/tmp/fake"
monkeypatch.setitem(sys.modules, "huggingface_hub", hh)
pydantic = types.ModuleType("pydantic")
class BaseModel:
@classmethod
def model_rebuild(cls):
return None
@classmethod
def model_validate_json(cls, s):
return cls()
def model_dump(self):
return {}
pydantic.BaseModel = BaseModel
monkeypatch.setitem(sys.modules, "pydantic", pydantic)
torchaudio = types.ModuleType("torchaudio")
monkeypatch.setitem(sys.modules, "torchaudio", torchaudio)
librosa = types.ModuleType("librosa")
librosa.effects = types.SimpleNamespace(trim=lambda *a, **k: (None, (0, 0)))
monkeypatch.setitem(sys.modules, "librosa", librosa)
einops = types.ModuleType("einops")
einops.rearrange = lambda x, *a, **k: x
monkeypatch.setitem(sys.modules, "einops", einops)
tqdm_pkg = types.ModuleType("tqdm")
tqdm_pkg.__path__ = ["/nonexistent"]
tqdm_pkg.tqdm = lambda x, *a, **k: x
monkeypatch.setitem(sys.modules, "tqdm", tqdm_pkg)
tqdm_auto = types.ModuleType("tqdm.auto")
tqdm_auto.tqdm = lambda x, *a, **k: x
monkeypatch.setitem(sys.modules, "tqdm.auto", tqdm_auto)
transformers = types.ModuleType("transformers")
class LlamaTokenizerFast:
pass
class PreTrainedTokenizer:
pass
transformers.LlamaTokenizerFast = LlamaTokenizerFast
transformers.PreTrainedTokenizer = PreTrainedTokenizer
monkeypatch.setitem(sys.modules, "transformers", transformers)
internal_mods = {
"voxcpm.modules.audiovae": ["AudioVAE", "AudioVAEConfig", "AudioVAEV2", "AudioVAEConfigV2"],
"voxcpm.modules.layers": ["ScalarQuantizationLayer"],
"voxcpm.modules.locdit": ["CfmConfig", "UnifiedCFM", "VoxCPMLocDiT", "VoxCPMLocDiTV2"],
"voxcpm.modules.locenc": ["VoxCPMLocEnc"],
"voxcpm.modules.minicpm4": ["MiniCPM4Config", "MiniCPMModel"],
"voxcpm.modules.layers.lora": ["apply_lora_to_named_linear_modules", "LoRALinear"],
}
for modname, names in internal_mods.items():
module = types.ModuleType(modname)
for name in names:
if name == "apply_lora_to_named_linear_modules":
setattr(module, name, lambda *a, **k: None)
else:
setattr(module, name, type(name, (), {}))
monkeypatch.setitem(sys.modules, modname, module)
_load_module("voxcpm.model.utils", SRC / "voxcpm" / "model" / "utils.py")
voxcpm = _load_module("voxcpm.model.voxcpm", SRC / "voxcpm" / "model" / "voxcpm.py")
voxcpm2 = _load_module("voxcpm.model.voxcpm2", SRC / "voxcpm" / "model" / "voxcpm2.py")
return voxcpm.VoxCPMModel, voxcpm2.VoxCPM2Model
class DummyModel:
device = "cpu"
def named_parameters(self):
return []
@pytest.mark.parametrize("module_name", ["v1", "v2"])
def test_load_lora_weights_accepts_tensor_only_legacy_checkpoints(monkeypatch, tmp_path, module_name):
VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch)
cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model
ckpt_path = tmp_path / "lora_weights.ckpt"
torch.save({"state_dict": {"fake": torch.zeros(1)}}, ckpt_path)
loaded, skipped = cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")
assert loaded == []
assert skipped == ["fake"]
@pytest.mark.parametrize("module_name", ["v1", "v2"])
def test_load_lora_weights_rejects_malicious_pickle_payloads(monkeypatch, tmp_path, module_name):
    VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch)
    cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model
    ckpt_path = tmp_path / "lora_weights.ckpt"
    marker_path = tmp_path / f"{module_name}-marker.txt"

    class Exploit:
        def __reduce__(self):
            import pathlib

            # If unpickled without restrictions, this would write the marker file.
            return (pathlib.Path.write_text, (marker_path, f"{module_name} executed\n"))

    torch.save({"state_dict": {"fake": torch.zeros(1)}, "boom": Exploit()}, ckpt_path)
    with pytest.raises(Exception, match="Weights only load failed"):
        cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")
    # The payload must never have run.
    assert not marker_path.exists()
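Note on the malicious-payload test above: it relies on pickle's `__reduce__` hook, which lets a pickled object name an arbitrary callable to run at load time. The stdlib-only sketch below illustrates why an allow-list unpickler stops such a payload before anything executes. `RestrictedUnpickler` and `restricted_loads` are hypothetical names for illustration; this is not PyTorch's implementation and not code from this repository.

```python
import io
import pickle


class Exploit:
    """Object whose unpickling would call an arbitrary callable."""

    def __reduce__(self):
        # An unrestricted load would call print("payload executed").
        return (print, ("payload executed",))


class RestrictedUnpickler(pickle.Unpickler):
    """Hypothetical allow-list unpickler that refuses every global lookup."""

    def find_class(self, module, name):
        # Any callable referenced by the stream is resolved here, so
        # raising blocks the payload before it can be invoked.
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Unpickle data that contains only plain containers and scalars."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


plain = pickle.dumps({"state_dict": {"fake": [0.0]}})
malicious = pickle.dumps({"state_dict": {}, "boom": Exploit()})

print(restricted_loads(plain))
try:
    restricted_loads(malicious)
except pickle.UnpicklingError as exc:
    print("rejected:", exc)
```

A real tensor loader needs a small allow-list rather than an empty one, since tensors themselves are rebuilt through globals; the sketch uses plain lists to stay dependency-free.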
@@ -0,0 +1,49 @@
from __future__ import annotations
import importlib.util
import sys
import types
from pathlib import Path
import pytest
ROOT = Path(__file__).resolve().parents[1]
UTILS_PATH = ROOT / "src" / "voxcpm" / "model" / "utils.py"
transformers_stub = types.ModuleType("transformers")
transformers_stub.PreTrainedTokenizer = object
sys.modules.setdefault("transformers", transformers_stub)
spec = importlib.util.spec_from_file_location("voxcpm.model.utils", UTILS_PATH)
utils = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(utils)
def test_resolve_runtime_device_auto_falls_back_to_cpu(monkeypatch):
    monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
    monkeypatch.setattr(utils, "_has_mps", lambda: False)
    assert utils.resolve_runtime_device(None, "cuda") == "cpu"


def test_resolve_runtime_device_auto_uses_mps_when_available(monkeypatch):
    monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
    monkeypatch.setattr(utils, "_has_mps", lambda: True)
    assert utils.resolve_runtime_device("auto", "cuda") == "mps"


def test_resolve_runtime_device_respects_explicit_cpu(monkeypatch):
    monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(utils, "_has_mps", lambda: True)
    assert utils.resolve_runtime_device("cpu", "cuda") == "cpu"


def test_resolve_runtime_device_rejects_unavailable_explicit_cuda(monkeypatch):
    monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
    monkeypatch.setattr(utils, "_has_mps", lambda: True)
    with pytest.raises(ValueError, match="CUDA is not available"):
        utils.resolve_runtime_device("cuda:0", "cuda")
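Note on the four device tests above: together they pin down a small contract (automatic selection falls back through the accelerators to CPU, explicit CPU is always honoured, and an explicit but unavailable CUDA device is an error). The sketch below is one way to satisfy that contract. It is an assumption, not the repository's `resolve_runtime_device`: the name `resolve_device_sketch`, the injected availability flags, the fallback order, and the reading of the second argument as a preferred device are all hypothetical.

```python
from __future__ import annotations


def resolve_device_sketch(
    requested: str | None,
    preferred: str,
    *,
    cuda_available: bool,
    mps_available: bool,
) -> str:
    """Pick a device string; availability is injected instead of probing torch."""
    available = {"cuda": cuda_available, "mps": mps_available, "cpu": True}

    # None and "auto" both mean "choose for me".
    if requested in (None, "auto"):
        # Assumed order: the preferred device first, then cuda, mps, cpu.
        for candidate in (preferred, "cuda", "mps", "cpu"):
            if available.get(candidate.split(":")[0], False):
                return candidate
        return "cpu"

    # An explicit request is honoured or rejected, never silently replaced.
    kind = requested.split(":")[0]
    if kind == "cuda" and not cuda_available:
        raise ValueError("CUDA is not available")
    if kind == "mps" and not mps_available:
        raise ValueError("MPS is not available")
    return requested


print(resolve_device_sketch("auto", "cuda", cuda_available=False, mps_available=True))
```

Injecting the availability flags keeps the sketch testable without patching `torch.cuda.is_available`, which is what the tests above do for the real function.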
@@ -0,0 +1,252 @@
"""Tests for the training data validation module."""
from __future__ import annotations
import json
import os
import sys
import tempfile
import types
from pathlib import Path
import pytest
ROOT = Path(__file__).resolve().parents[1]
# Stub voxcpm package so imports work without full dependencies
pkg = types.ModuleType("voxcpm")
pkg.__path__ = [str(ROOT / "src" / "voxcpm")]
sys.modules.setdefault("voxcpm", pkg)
training_pkg = types.ModuleType("voxcpm.training")
training_pkg.__path__ = [str(ROOT / "src" / "voxcpm" / "training")]
sys.modules.setdefault("voxcpm.training", training_pkg)
from voxcpm.training.validate import ValidationResult, validate_manifest
@pytest.fixture
def tmp_dir():
    with tempfile.TemporaryDirectory() as d:
        yield Path(d)


def _create_wav(path: Path, duration_s: float = 1.0, sr: int = 16000):
    """Create a minimal valid WAV file."""
    try:
        import soundfile as sf
        import numpy as np

        samples = int(duration_s * sr)
        data = np.zeros(samples, dtype=np.float32)
        sf.write(str(path), data, sr)
    except ImportError:
        # If soundfile is not available, write a minimal 16-bit mono PCM WAV by hand
        import struct

        samples = int(duration_s * sr)
        data_size = samples * 2  # 16-bit PCM
        with open(path, "wb") as f:
            f.write(b"RIFF")
            f.write(struct.pack("<I", 36 + data_size))
            f.write(b"WAVEfmt ")
            # fmt chunk: size 16, PCM, 1 channel, sample rate, byte rate, block align, bits
            f.write(struct.pack("<IHHIIHH", 16, 1, 1, sr, sr * 2, 2, 16))
            f.write(b"data")
            f.write(struct.pack("<I", data_size))
            f.write(b"\x00" * data_size)


def _write_manifest(path: Path, entries: list[dict]):
    """Write one JSON object per line (JSONL)."""
    with open(path, "w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
class TestValidateManifest:
    def test_valid_manifest(self, tmp_dir):
        audio1 = tmp_dir / "audio1.wav"
        audio2 = tmp_dir / "audio2.wav"
        _create_wav(audio1, 2.0)
        _create_wav(audio2, 3.0)
        manifest = tmp_dir / "train.jsonl"
        _write_manifest(
            manifest,
            [
                {"text": "Hello world", "audio": str(audio1)},
                {"text": "Goodbye world", "audio": str(audio2)},
            ],
        )
        result = validate_manifest(str(manifest))
        assert result.total_samples == 2
        assert result.valid_samples == 2
        assert result.is_valid
        assert len(result.errors) == 0

    def test_missing_manifest(self):
        result = validate_manifest("/nonexistent/path.jsonl")
        assert not result.is_valid
        assert any("not found" in e for e in result.errors)

    def test_empty_manifest(self, tmp_dir):
        manifest = tmp_dir / "empty.jsonl"
        manifest.write_text("")
        result = validate_manifest(str(manifest))
        assert not result.is_valid

    def test_invalid_json(self, tmp_dir):
        manifest = tmp_dir / "bad.jsonl"
        manifest.write_text("not json\n{bad json}\n")
        result = validate_manifest(str(manifest))
        assert len(result.errors) >= 2
        assert any("Invalid JSON" in e for e in result.errors)
    def test_missing_columns(self, tmp_dir):
        manifest = tmp_dir / "missing.jsonl"
        _write_manifest(
            manifest,
            [
                {"text": "hello"},  # missing audio
                {"audio": "test.wav"},  # missing text
            ],
        )
        result = validate_manifest(str(manifest))
        assert len(result.errors) >= 2
        assert any("'audio'" in e for e in result.errors)
        assert any("'text'" in e for e in result.errors)

    def test_missing_audio_file(self, tmp_dir):
        manifest = tmp_dir / "missing_audio.jsonl"
        _write_manifest(
            manifest,
            [{"text": "hello", "audio": "/nonexistent/audio.wav"}],
        )
        result = validate_manifest(str(manifest))
        assert not result.is_valid
        assert any("not found" in e for e in result.errors)

    def test_empty_text_warning(self, tmp_dir):
        audio = tmp_dir / "audio.wav"
        _create_wav(audio)
        manifest = tmp_dir / "empty_text.jsonl"
        _write_manifest(
            manifest,
            [{"text": "", "audio": str(audio)}],
        )
        result = validate_manifest(str(manifest))
        assert len(result.warnings) > 0
        assert any("Empty" in w for w in result.warnings)

    def test_relative_audio_path(self, tmp_dir):
        audio = tmp_dir / "audio.wav"
        _create_wav(audio)
        manifest = tmp_dir / "rel.jsonl"
        # The relative path is expected to resolve against the manifest's directory.
        _write_manifest(
            manifest,
            [{"text": "hello", "audio": "audio.wav"}],
        )
        result = validate_manifest(str(manifest))
        assert result.valid_samples == 1
        assert result.is_valid
    def test_max_samples_limit(self, tmp_dir):
        audio = tmp_dir / "audio.wav"
        _create_wav(audio)
        manifest = tmp_dir / "many.jsonl"
        _write_manifest(
            manifest,
            [{"text": f"sample {i}", "audio": str(audio)} for i in range(100)],
        )
        result = validate_manifest(str(manifest), max_samples=10)
        assert result.total_samples == 10

    def test_ref_audio_counted(self, tmp_dir):
        audio = tmp_dir / "audio.wav"
        ref = tmp_dir / "ref.wav"
        _create_wav(audio)
        _create_wav(ref)
        manifest = tmp_dir / "ref.jsonl"
        _write_manifest(
            manifest,
            [{"text": "hello", "audio": str(audio), "ref_audio": str(ref)}],
        )
        result = validate_manifest(str(manifest))
        assert result.has_ref_audio == 1

    def test_validation_result_properties(self):
        r = ValidationResult(total_samples=5, valid_samples=5)
        assert r.is_valid
        r2 = ValidationResult(total_samples=5, valid_samples=5, errors=["err"])
        assert not r2.is_valid
        r3 = ValidationResult(total_samples=0, valid_samples=0)
        assert not r3.is_valid
    def test_invalid_audio_not_counted_as_valid(self, tmp_dir):
        """A row with a bad audio path must not increment valid_samples."""
        manifest = tmp_dir / "bad_audio.jsonl"
        _write_manifest(
            manifest,
            [{"text": "hello", "audio": "/nonexistent/audio.wav"}],
        )
        result = validate_manifest(str(manifest))
        assert result.total_samples == 1
        assert result.valid_samples == 0
        assert not result.is_valid
        assert any("not found" in e for e in result.errors)

    def test_sample_rate_mismatch(self, tmp_dir):
        """A file with a different sample rate should be reported as an error."""
        try:
            import numpy as np
            import soundfile as sf
        except ImportError:
            pytest.skip("soundfile not available")
        # Write one second of silence at 8 kHz, then validate against 16 kHz.
        audio = tmp_dir / "audio_8k.wav"
        samples = np.zeros(8000, dtype=np.float32)
        sf.write(str(audio), samples, 8000)
        manifest = tmp_dir / "sr_mismatch.jsonl"
        _write_manifest(manifest, [{"text": "hello", "audio": str(audio)}])
        result = validate_manifest(str(manifest), sample_rate=16000)
        assert result.valid_samples == 0
        assert not result.is_valid
        # Case-insensitive match; this also covers "Sample rate mismatch".
        assert any("sample rate" in e.lower() for e in result.errors)
    def test_mixed_ref_audio_warns_for_each_missing(self, tmp_dir):
        """Missing ref_audio entries should each generate a warning independently."""
        audio = tmp_dir / "audio.wav"
        ref_good = tmp_dir / "ref_good.wav"
        _create_wav(audio)
        _create_wav(ref_good)
        manifest = tmp_dir / "mixed_ref.jsonl"
        _write_manifest(
            manifest,
            [
                {"text": "row1", "audio": str(audio), "ref_audio": str(ref_good)},
                {"text": "row2", "audio": str(audio), "ref_audio": "/nonexistent/ref.wav"},
            ],
        )
        result = validate_manifest(str(manifest))
        assert result.has_ref_audio == 1
        assert any("ref_audio file not found" in w for w in result.warnings)

    def test_cli_validate_exit_code(self, tmp_dir):
        """validate subcommand must exit 1 on validation error (missing audio)."""
        import subprocess

        manifest = tmp_dir / "bad.jsonl"
        _write_manifest(manifest, [{"text": "hi", "audio": "/nonexistent/x.wav"}])
        proc = subprocess.run(
            [sys.executable, "-m", "voxcpm.cli", "validate", "--manifest", str(manifest)],
            capture_output=True,
            text=True,
        )
        assert proc.returncode == 1, f"Expected exit 1, got {proc.returncode}"
        assert "FAILED" in proc.stderr or "Audio file not found" in proc.stderr
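Note on the result object these tests exercise: `test_validation_result_properties` implies that a result is valid only when it has no errors and at least one valid sample. The dataclass below is a minimal sketch that reproduces that behaviour. It is an assumption inferred from the assertions above, not the repository's `ValidationResult`; the class name `ValidationResultSketch`, the field defaults, and the exact validity rule are hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ValidationResultSketch:
    """Hypothetical stand-in mirroring the fields the tests read."""

    total_samples: int = 0
    valid_samples: int = 0
    has_ref_audio: int = 0
    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    @property
    def is_valid(self) -> bool:
        # Assumed rule: no errors and at least one usable sample.
        # Warnings do not affect validity.
        return not self.errors and self.valid_samples > 0


ok = ValidationResultSketch(total_samples=5, valid_samples=5)
failed = ValidationResultSketch(total_samples=5, valid_samples=5, errors=["err"])
empty = ValidationResultSketch()
print(ok.is_valid, failed.is_valid, empty.is_valid)
```

Using `field(default_factory=list)` gives each instance its own `errors` and `warnings` lists, so appending to one result never leaks into another.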