Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2e4c601d1d | |||
| 525ebc65a3 | |||
| 564b98b2d5 | |||
| 19b6bf7590 | |||
| 86bff0fc82 | |||
| dd7b78f2c0 | |||
| 29577d57f8 | |||
| 4509becfde | |||
| cd79a647fa | |||
| 96d605b9de | |||
| a9b03a768c | |||
| 77f847fcba | |||
| d3cc88722c | |||
| ec2acec8a1 | |||
| 13605c5a0e | |||
| afa63e6195 | |||
| eae0a29908 | |||
| 35895982d7 | |||
| f7f1b78c4d | |||
| 38d61cdf03 | |||
| 1565e83efe | |||
| 61b36d4e56 | |||
| b1584aec7c | |||
| 4457617953 | |||
| 5510503182 | |||
| fb46aad9a5 | |||
| e4e049624c | |||
| abf01b9bf3 | |||
| 4f4a5b9f6c | |||
| 79c0cf68dd | |||
| 75cfa3e9b8 | |||
| 5611bd08a0 | |||
| 66205135fc | |||
| 364eff6840 | |||
| 6d10932b09 |
@@ -2,3 +2,6 @@ launch.json
|
||||
__pycache__
|
||||
voxcpm.egg-info
|
||||
.DS_Store
|
||||
./pretrained_models/
|
||||
app_local.py
|
||||
models/
|
||||
@@ -46,7 +46,7 @@ VoxCPM is a **tokenizer-free** Text-to-Speech system that directly generates con
|
||||
- 🎙️ **Ultimate Cloning** — Reproduce every vocal nuance: provide both reference audio and its transcript, and the model continues seamlessly from the reference, faithfully preserving every vocal detail — timbre, rhythm, emotion, and style (same as VoxCPM1.5)
|
||||
- 🔊 **48kHz High-Quality Audio** — Accepts 16kHz reference audio and directly outputs 48kHz studio-quality audio via AudioVAE V2's asymmetric encode/decode design, with built-in super-resolution — no external upsampler needed
|
||||
- 🧠 **Context-Aware Synthesis** — Automatically infers appropriate prosody and expressiveness from text content
|
||||
- ⚡ **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-VLLM](https://github.com/a710128/nanovllm-voxcpm)
|
||||
- ⚡ **Real-Time Streaming** — RTF as low as ~0.3 on NVIDIA RTX 4090, and ~0.13 accelerated by [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) or [vLLM-Omni](https://github.com/vllm-project/vllm-omni) — official vLLM omni-modal serving for VoxCPM2 with PagedAttention and an OpenAI-compatible API
|
||||
- 📜 **Fully Open-Source & Commercial-Ready** — Weights and code released under the [Apache-2.0](LICENSE) license, free for commercial use
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ Chinese Dialect: 四川话, 粤语, 吴语, 东北话, 河南话, 陕西话, 山
|
||||
pip install voxcpm
|
||||
```
|
||||
|
||||
> **Requirements:** Python ≥ 3.10, PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
|
||||
> **Requirements:** Python ≥ 3.10 (<3.13), PyTorch ≥ 2.5.0, CUDA ≥ 12.0. See [Quick Start Docs](https://voxcpm.readthedocs.io/en/latest/quickstart.html) for details.
|
||||
|
||||
### Python API
|
||||
|
||||
@@ -123,12 +123,12 @@ pip install modelscope
|
||||
```
|
||||
|
||||
```python
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # specify the local directory to save the model
|
||||
|
||||
from voxcpm import VoxCPM
|
||||
import soundfile as sf
|
||||
|
||||
local_model_dir = snapshot_download("OpenBMB/VoxCPM2")
|
||||
model = VoxCPM.from_pretrained(local_model_dir, load_denoiser=False)
|
||||
model = VoxCPM.from_pretrained("./pretrained_models/VoxCPM2", load_denoiser=False)
|
||||
|
||||
wav = model.generate(
|
||||
text="VoxCPM2 is the current recommended release for realistic multilingual speech synthesis.",
|
||||
@@ -239,7 +239,7 @@ voxcpm --help
|
||||
### Web Demo
|
||||
|
||||
```bash
|
||||
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, open http://localhost:8808
|
||||
python app.py --port 8808 # then open in browser: http://localhost:8808
|
||||
```
|
||||
|
||||
### 🚢 Production Deployment (Nano-vLLM)
|
||||
@@ -262,6 +262,32 @@ server.stop()
|
||||
|
||||
> **RTF as low as ~0.13 on NVIDIA RTX 4090** (vs ~0.3 with the standard PyTorch implementation), with support for batched concurrent requests and a FastAPI HTTP server. See the [Nano-vLLM-VoxCPM repo](https://github.com/a710128/nanovllm-voxcpm) for deployment details.
|
||||
|
||||
### 🏭 Production Serving (vLLM-Omni)
|
||||
|
||||
For production multi-tenant deployments, use [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — the official vLLM project's omni-modal extension with native **VoxCPM2** support. PagedAttention KV cache, continuous batching, and a drop-in **OpenAI-compatible** `/v1/audio/speech` endpoint.
|
||||
|
||||
```bash
|
||||
# Install from source (latest main — vllm-omni is rapidly evolving)
|
||||
uv pip install vllm==0.19.0 --torch-backend=auto
|
||||
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
See the [vLLM-Omni installation guide](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/) for other platforms (ROCm, XPU, MUSA, NPU) and Docker images.
|
||||
|
||||
```bash
|
||||
# Launch an OpenAI-compatible TTS server (--omni enables omni-modal serving)
|
||||
vllm serve openbmb/VoxCPM2 --omni --port 8000
|
||||
|
||||
# Call it from any OpenAI client
|
||||
curl http://localhost:8000/v1/audio/speech \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model":"openbmb/VoxCPM2","input":"Hello from VoxCPM2 on vLLM-Omni!","voice":"default"}' \
|
||||
--output out.wav
|
||||
```
|
||||
|
||||
> Built on the upstream vLLM scheduler, with batched concurrent requests, streaming chunk delivery, and multi-GPU deployment out of the box. See the [VoxCPM2 example](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2) for full deployment recipes.
|
||||
|
||||
> **Full parameter reference, multi-scenario examples, and voice cloning tips →** [Quick Start Guide](https://voxcpm.readthedocs.io/en/latest/quickstart.html) | [Usage Guide](https://voxcpm.readthedocs.io/en/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/en/latest/cookbook.html)
|
||||
|
||||
---
|
||||
@@ -528,11 +554,13 @@ Full documentation: **[voxcpm.readthedocs.io](https://voxcpm.readthedocs.io/en/l
|
||||
| Project | Description |
|
||||
|---|---|
|
||||
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | High-throughput and Fast GPU serving |
|
||||
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | Official vLLM omni-modal serving for VoxCPM2 — PagedAttention, OpenAI-compatible API |
|
||||
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF: CPU, CUDA, Vulkan inference |
|
||||
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX export for CPU inference |
|
||||
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine backend |
|
||||
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust re-implementation |
|
||||
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI node-based workflows |
|
||||
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | Feature-complete ComfyUI workflow for VoxCPM 2 with multi-speaker generation, LoRA, and auto-ASR |
|
||||
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS extension |
|
||||
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | Browser-based TTS extension |
|
||||
|
||||
|
||||
+35
-7
@@ -46,7 +46,7 @@ VoxCPM 是一个**无离散音频分词器**(Tokenizer-Free)的语音合成
|
||||
- 🎙️ **极致克隆** — 提供参考音频及其文本内容,模型接着参考音频进行无缝续写,从而精准还原声音细节特征(与 VoxCPM1.5 一致)
|
||||
- 🔊 **48kHz 高质量音频** — 输入 16kHz 参考音频,通过 AudioVAE V2 的非对称编解码设计直接输出 48kHz 高质量音频,内置超分能力
|
||||
- 🧠 **语境感知合成** — 根据文本内容自动推断合适的韵律和表现力
|
||||
- ⚡ **实时流式合成** — 在 NVIDIA RTX 4090 上 RTF 低至 ~0.3,通过 [Nano-VLLM](https://github.com/a710128/nanovllm-voxcpm) 加速后可达 ~0.13
|
||||
- ⚡ **实时流式合成** — 在 NVIDIA RTX 4090 上 RTF 低至 ~0.3,通过 [Nano-vLLM](https://github.com/a710128/nanovllm-voxcpm) 或 [vLLM-Omni](https://github.com/vllm-project/vllm-omni)(官方 vLLM 全模态服务,原生支持 VoxCPM2,提供 PagedAttention 与 OpenAI 兼容 API)加速后可达 ~0.13
|
||||
- 📜 **完全开源,商用就绪** — 权重和代码基于 [Apache-2.0](LICENSE) 协议发布,免费商用
|
||||
|
||||
<summary><b>🌍 支持的语言(30种)</b></summary>
|
||||
@@ -91,7 +91,7 @@ VoxCPM 是一个**无离散音频分词器**(Tokenizer-Free)的语音合成
|
||||
pip install voxcpm
|
||||
```
|
||||
|
||||
> **环境要求:** Python ≥ 3.10,PyTorch ≥ 2.5.0,CUDA ≥ 12.0。详见 [快速开始文档](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html)。
|
||||
> **环境要求:** Python ≥ 3.10 (<3.13),PyTorch ≥ 2.5.0,CUDA ≥ 12.0。详见 [快速开始文档](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html)。
|
||||
|
||||
### Python API
|
||||
|
||||
@@ -122,12 +122,12 @@ pip install modelscope
|
||||
```
|
||||
|
||||
```python
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download("OpenBMB/VoxCPM2", local_dir='./pretrained_models/VoxCPM2') # 指定模型保存的本地路径
|
||||
|
||||
from voxcpm import VoxCPM
|
||||
import soundfile as sf
|
||||
|
||||
local_model_dir = snapshot_download("OpenBMB/VoxCPM2")
|
||||
model = VoxCPM.from_pretrained(local_model_dir, load_denoiser=False)
|
||||
model = VoxCPM.from_pretrained('./pretrained_models/VoxCPM2', load_denoiser=False)
|
||||
|
||||
wav = model.generate(
|
||||
text="VoxCPM2 是目前推荐使用的多语言语音合成版本。",
|
||||
@@ -238,7 +238,7 @@ voxcpm --help
|
||||
### Web Demo
|
||||
|
||||
```bash
|
||||
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # 指定本地模型路径,然后打开 http://localhost:8808
|
||||
python app.py --port 8808 # 然后在浏览器打开 http://localhost:8808
|
||||
```
|
||||
|
||||
### 🚢 生产部署(Nano-vLLM)
|
||||
@@ -261,6 +261,32 @@ server.stop()
|
||||
|
||||
> **在 NVIDIA RTX 4090 上 RTF 低至 ~0.13**(标准 PyTorch 实现约 ~0.3),支持批量并发请求和 FastAPI HTTP 服务。详见 [Nano-vLLM-VoxCPM 仓库](https://github.com/a710128/nanovllm-voxcpm)。
|
||||
|
||||
### 🏭 生产环境部署(vLLM-Omni)
|
||||
|
||||
如需生产级多租户部署,使用 [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) — 官方 vLLM 项目的全模态扩展,原生支持 **VoxCPM2**。具备 PagedAttention KV 缓存、连续批处理,以及与 OpenAI 完全兼容的 `/v1/audio/speech` 接口。
|
||||
|
||||
```bash
|
||||
# 从源码安装(最新 main 分支 —— vllm-omni 正在快速迭代)
|
||||
uv pip install vllm==0.19.0 --torch-backend=auto
|
||||
git clone https://github.com/vllm-project/vllm-omni.git && cd vllm-omni
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
其他平台(ROCm、XPU、MUSA、NPU)与 Docker 镜像请参考 [vLLM-Omni 安装文档](https://vllm-omni.readthedocs.io/en/latest/getting_started/installation/)。
|
||||
|
||||
```bash
|
||||
# 启动 OpenAI 兼容的 TTS 服务(--omni 启用全模态服务)
|
||||
vllm serve openbmb/VoxCPM2 --omni --port 8000
|
||||
|
||||
# 任意 OpenAI 客户端均可调用
|
||||
curl http://localhost:8000/v1/audio/speech \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model":"openbmb/VoxCPM2","input":"你好,欢迎使用 VoxCPM2 on vLLM-Omni!","voice":"default"}' \
|
||||
--output out.wav
|
||||
```
|
||||
|
||||
> 基于上游 vLLM 调度器构建,开箱即用支持批量并发、流式分块输出和多 GPU 部署。完整示例见 [VoxCPM2 部署样例](https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/voxcpm2)。
|
||||
|
||||
> **完整参数说明、多场景示例与声音克隆技巧 →** [快速开始指南](https://voxcpm.readthedocs.io/zh-cn/latest/quickstart.html) | [使用指南](https://voxcpm.readthedocs.io/zh-cn/latest/usage_guide.html) | [Cookbook](https://voxcpm.readthedocs.io/zh-cn/latest/cookbook.html)
|
||||
|
||||
---
|
||||
@@ -521,11 +547,13 @@ python lora_ft_webui.py # 然后打开 http://localhost:7860
|
||||
| 项目 | 说明 |
|
||||
|---|---|
|
||||
| [**Nano-vLLM**](https://github.com/a710128/nanovllm-voxcpm) | 高吞吐快速 GPU 推理引擎 |
|
||||
| [**vLLM-Omni**](https://github.com/vllm-project/vllm-omni) | 官方 vLLM 全模态服务(原生支持 VoxCPM2)— PagedAttention、OpenAI 兼容 API |
|
||||
| [**VoxCPM.cpp**](https://github.com/bluryar/VoxCPM.cpp) | GGML/GGUF:CPU、CUDA、Vulkan 推理 |
|
||||
| [**VoxCPM-ONNX**](https://github.com/bluryar/VoxCPM-ONNX) | ONNX 导出,支持 CPU 推理 |
|
||||
| [**VoxCPMANE**](https://github.com/0seba/VoxCPMANE) | Apple Neural Engine 后端 |
|
||||
| [**voxcpm_rs**](https://github.com/madushan1000/voxcpm_rs) | Rust 重新实现 |
|
||||
| [**ComfyUI-VoxCPM**](https://github.com/wildminder/ComfyUI-VoxCPM) | ComfyUI 节点工作流 |
|
||||
| [**ComfyUI_RH_VoxCPM**](https://github.com/HM-RunningHub/ComfyUI_RH_VoxCPM) | 面向 VoxCPM 2 的功能更完整的 ComfyUI 工作流,支持多说话人、LoRA 和自动 ASR |
|
||||
| [**ComfyUI-VoxCPMTTS**](https://github.com/1038lab/ComfyUI-VoxCPMTTS) | ComfyUI TTS 扩展 |
|
||||
| [**TTS WebUI**](https://github.com/rsxdalv/tts_webui_extension.vox_cpm) | 浏览器端 TTS 扩展 |
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import logging
|
||||
import numpy as np
|
||||
@@ -9,8 +10,6 @@ from funasr import AutoModel
|
||||
from pathlib import Path
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
|
||||
|
||||
import voxcpm
|
||||
|
||||
@@ -55,7 +54,7 @@ _EXAMPLES_FOOTER_EN = (
|
||||
)
|
||||
|
||||
_USAGE_INSTRUCTIONS_ZH = (
|
||||
"**VoxCPM2 — 三种语音生成方式:**\n\n"
|
||||
"**三种语音生成方式:**\n\n"
|
||||
"🎨 **声音设计(Voice Design)** \n"
|
||||
"无需参考音频。在 **Control Instruction** 中描述目标音色特征"
|
||||
"(性别、年龄、语气、情绪、语速等),VoxCPM2 即可为你从零创造独一无二的声音。\n\n"
|
||||
@@ -66,6 +65,8 @@ _USAGE_INSTRUCTIONS_ZH = (
|
||||
"开启 **极致克隆模式** 并提供参考音频的文字内容(可自动识别)。"
|
||||
"模型会将参考音频视为已说出的前文,以**音频续写**的方式完整还原参考音频中的所有声音细节。"
|
||||
"注意:该模式与可控克隆模式互斥,将禁用Control Instruction。\n\n"
|
||||
"目前支持的方言包括:\n"
|
||||
"「四川话、粤语、吴语、东北话、河南话、陕西话、山东话、天津话、闽南话」"
|
||||
)
|
||||
|
||||
_EXAMPLES_FOOTER_ZH = (
|
||||
@@ -221,11 +222,11 @@ _APP_THEME = gr.themes.Soft(
|
||||
# ---------- Model ----------
|
||||
|
||||
class VoxCPMDemo:
|
||||
def __init__(self, model_dir: Optional[str] = None) -> None:
|
||||
def __init__(self, model_id: str = "openbmb/VoxCPM2") -> None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Running on device: {self.device}")
|
||||
logger.info(f"运行在设备上: {self.device}")
|
||||
|
||||
self.asr_model_id = "iic/SenseVoiceSmall"
|
||||
self.asr_model_id = "./models/iic/SenseVoiceSmall"
|
||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||
model=self.asr_model_id,
|
||||
disable_update=True,
|
||||
@@ -234,36 +235,13 @@ class VoxCPMDemo:
|
||||
)
|
||||
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.explicit_model_dir = model_dir
|
||||
|
||||
def _resolve_model_dir(self) -> str:
|
||||
if self.explicit_model_dir and os.path.isdir(self.explicit_model_dir):
|
||||
return self.explicit_model_dir
|
||||
env_model_dir = os.environ.get("VOXCPM_MODEL_DIR", "").strip()
|
||||
if env_model_dir and os.path.isdir(env_model_dir):
|
||||
return env_model_dir
|
||||
repo_id = os.environ.get("HF_REPO_ID", "").strip()
|
||||
if len(repo_id) > 0:
|
||||
target_dir = os.path.join("models", repo_id.replace("/", "__"))
|
||||
if not os.path.isdir(target_dir):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
logger.info(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
|
||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||
except Exception as e:
|
||||
logger.warning(f"HF download failed: {e}. Falling back to 'models'.")
|
||||
return "models"
|
||||
return target_dir
|
||||
return "models"
|
||||
self._model_id = model_id
|
||||
|
||||
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
||||
if self.voxcpm_model is not None:
|
||||
return self.voxcpm_model
|
||||
logger.info("Model not loaded, initializing...")
|
||||
model_dir = self._resolve_model_dir()
|
||||
logger.info(f"Using model dir: {model_dir}")
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=True)
|
||||
logger.info(f"Loading model: {self._model_id}")
|
||||
self.voxcpm_model = voxcpm.VoxCPM.from_pretrained(self._model_id, optimize=True,zipenhancer_model_id="./models/iic/speech_zipenhancer_ans_multiloss_16k_base")
|
||||
logger.info("Model loaded successfully.")
|
||||
return self.voxcpm_model
|
||||
|
||||
@@ -315,6 +293,9 @@ class VoxCPMDemo:
|
||||
raise ValueError("Please input text to synthesize.")
|
||||
|
||||
control = (control_instruction or "").strip()
|
||||
# Strip any parentheses (half-width/full-width) from control text to avoid
|
||||
# breaking the "(control)text" prompt format expected by the model.
|
||||
control = re.sub(r"[()()]", "", control).strip()
|
||||
final_text = f"({control}){text}" if control else text
|
||||
|
||||
audio_path = reference_wav_path_input if reference_wav_path_input else None
|
||||
@@ -507,9 +488,9 @@ def run_demo(
|
||||
server_name: str = "0.0.0.0",
|
||||
server_port: int = 8808,
|
||||
show_error: bool = True,
|
||||
model_dir: Optional[str] = None,
|
||||
model_id: str = "./models/OpenBMB/VoxCPM2",
|
||||
):
|
||||
demo = VoxCPMDemo(model_dir=model_dir)
|
||||
demo = VoxCPMDemo(model_id=model_id)
|
||||
interface = create_demo_interface(demo)
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
||||
server_name=server_name,
|
||||
@@ -524,7 +505,10 @@ def run_demo(
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-dir", type=str, default=None, help="Path to VoxCPM2 checkpoint directory")
|
||||
parser.add_argument("--port", type=int, default=8808, help="Server port")
|
||||
parser.add_argument(
|
||||
"--model-id", type=str, default="./models/OpenBMB/VoxCPM2",
|
||||
help="本地路径或HuggingFace仓库ID(默认:./models/openbmb/VoxCPM2)",
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8808, help="服务端口")
|
||||
args = parser.parse_args()
|
||||
run_demo(model_dir=args.model_dir, server_port=args.port)
|
||||
run_demo(model_id=args.model_id, server_port=args.port)
|
||||
-280
@@ -1,280 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
from typing import Optional, Tuple
|
||||
from funasr import AutoModel
|
||||
from pathlib import Path
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
|
||||
|
||||
import voxcpm
|
||||
|
||||
|
||||
class VoxCPMDemo:
|
||||
def __init__(self) -> None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"🚀 Running on device: {self.device}", file=sys.stderr)
|
||||
|
||||
# ASR model for prompt text recognition
|
||||
self.asr_model_id = "iic/SenseVoiceSmall"
|
||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||
model=self.asr_model_id,
|
||||
disable_update=True,
|
||||
log_level='DEBUG',
|
||||
device="cuda:0" if self.device == "cuda" else "cpu",
|
||||
)
|
||||
|
||||
# TTS model (lazy init)
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.default_local_model_dir = "./models/VoxCPM1.5"
|
||||
|
||||
# ---------- Model helpers ----------
|
||||
def _resolve_model_dir(self) -> str:
|
||||
"""
|
||||
Resolve model directory:
|
||||
1) Use local checkpoint directory if exists
|
||||
2) If HF_REPO_ID env is set, download into models/{repo}
|
||||
3) Fallback to 'models'
|
||||
"""
|
||||
if os.path.isdir(self.default_local_model_dir):
|
||||
return self.default_local_model_dir
|
||||
|
||||
repo_id = os.environ.get("HF_REPO_ID", "").strip()
|
||||
if len(repo_id) > 0:
|
||||
target_dir = os.path.join("models", repo_id.replace("/", "__"))
|
||||
if not os.path.isdir(target_dir):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
|
||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||
except Exception as e:
|
||||
print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr)
|
||||
return "models"
|
||||
return target_dir
|
||||
return "models"
|
||||
|
||||
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
||||
if self.voxcpm_model is not None:
|
||||
return self.voxcpm_model
|
||||
print("Model not loaded, initializing...", file=sys.stderr)
|
||||
model_dir = self._resolve_model_dir()
|
||||
print(f"Using model dir: {model_dir}", file=sys.stderr)
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
||||
print("Model loaded successfully.", file=sys.stderr)
|
||||
return self.voxcpm_model
|
||||
|
||||
# ---------- Functional endpoints ----------
|
||||
def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
|
||||
if prompt_wav is None:
|
||||
return ""
|
||||
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
||||
text = res[0]["text"].split('|>')[-1]
|
||||
return text
|
||||
|
||||
def generate_tts_audio(
|
||||
self,
|
||||
text_input: str,
|
||||
prompt_wav_path_input: Optional[str] = None,
|
||||
prompt_text_input: Optional[str] = None,
|
||||
cfg_value_input: float = 2.0,
|
||||
inference_timesteps_input: int = 10,
|
||||
do_normalize: bool = True,
|
||||
denoise: bool = True,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
"""
|
||||
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
|
||||
Returns (sample_rate, waveform_numpy)
|
||||
"""
|
||||
current_model = self.get_or_load_voxcpm()
|
||||
|
||||
text = (text_input or "").strip()
|
||||
if len(text) == 0:
|
||||
raise ValueError("Please input text to synthesize.")
|
||||
|
||||
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
|
||||
prompt_text = prompt_text_input if prompt_text_input else None
|
||||
|
||||
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
|
||||
wav = current_model.generate(
|
||||
text=text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
cfg_value=float(cfg_value_input),
|
||||
inference_timesteps=int(inference_timesteps_input),
|
||||
normalize=do_normalize,
|
||||
denoise=denoise,
|
||||
)
|
||||
return (current_model.tts_model.sample_rate, wav)
|
||||
|
||||
|
||||
# ---------- UI Builders ----------
|
||||
|
||||
_APP_THEME = gr.themes.Soft(
|
||||
primary_hue="blue",
|
||||
secondary_hue="gray",
|
||||
neutral_hue="slate",
|
||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
|
||||
)
|
||||
|
||||
_CUSTOM_CSS = """
|
||||
.logo-container {
|
||||
text-align: center;
|
||||
margin: 0.5rem 0 1rem 0;
|
||||
}
|
||||
.logo-container img {
|
||||
height: 80px;
|
||||
width: auto;
|
||||
max-width: 200px;
|
||||
display: inline-block;
|
||||
}
|
||||
/* Bold accordion labels */
|
||||
#acc_quick details > summary,
|
||||
#acc_tips details > summary {
|
||||
font-weight: 600 !important;
|
||||
font-size: 1.1em !important;
|
||||
}
|
||||
/* Bold labels for specific checkboxes */
|
||||
#chk_denoise label,
|
||||
#chk_denoise span,
|
||||
#chk_normalize label,
|
||||
#chk_normalize span {
|
||||
font-weight: 600;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def create_demo_interface(demo: VoxCPMDemo):
|
||||
"""Build the Gradio UI for VoxCPM demo."""
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
|
||||
|
||||
with gr.Blocks() as interface:
|
||||
# Header logo
|
||||
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
|
||||
|
||||
# Quick Start
|
||||
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
|
||||
gr.Markdown("""
|
||||
### How to Use |使用说明
|
||||
1. **(Optional) Provide a Voice Prompt** - Upload or record an audio clip to provide the desired voice characteristics for synthesis.
|
||||
**(可选)提供参考声音** - 上传或录制一段音频,为声音合成提供音色、语调和情感等个性化特征
|
||||
2. **(Optional) Enter prompt text** - If you provided a voice prompt, enter the corresponding transcript here (auto-recognition available).
|
||||
**(可选项)输入参考文本** - 如果提供了参考语音,请输入其对应的文本内容(支持自动识别)。
|
||||
3. **Enter target text** - Type the text you want the model to speak.
|
||||
**输入目标文本** - 输入您希望模型朗读的文字内容。
|
||||
4. **Generate Speech** - Click the "Generate" button to create your audio.
|
||||
**生成语音** - 点击"生成"按钮,即可为您创造出音频。
|
||||
""")
|
||||
|
||||
# Pro Tips
|
||||
with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
|
||||
gr.Markdown("""
|
||||
### Prompt Speech Enhancement|参考语音降噪
|
||||
- **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
|
||||
**启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz,限制克隆上限。
|
||||
- **Disable** to preserve the original audio's all information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
|
||||
**禁用**:保留原始音频的全部信息,包括背景环境声,最高支持44.1kHz的音频复刻。
|
||||
|
||||
### Text Normalization|文本正则化
|
||||
- **Enable** to process general text with an external WeTextProcessing component.
|
||||
**启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理。
|
||||
- **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input (For Chinese, phonemes are converted using pinyin, {ni3}{hao3}; For English, phonemes are converted using CMUDict, {HH AH0 L OW1}), try it!
|
||||
**禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3};英文转CMUDict:{HH AH0 L OW1})和公式符号合成,尝试一下!
|
||||
|
||||
### CFG Value|CFG 值
|
||||
- **Lower CFG** if the voice prompt sounds strained or expressive, or instability occurs with long text input.
|
||||
**调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题。
|
||||
- **Higher CFG** for better adherence to the prompt speech style or input text, or instability occurs with too short text input.
|
||||
**调高**:为更好地贴合提示音频的风格或输入文本, 或者极短文本输入出现稳定性问题。
|
||||
|
||||
### Inference Timesteps|推理时间步
|
||||
- **Lower** for faster synthesis speed.
|
||||
**调低**:合成速度更快。
|
||||
- **Higher** for better synthesis quality.
|
||||
**调高**:合成质量更佳。
|
||||
""")
|
||||
|
||||
# Main controls
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
prompt_wav = gr.Audio(
|
||||
sources=["upload", 'microphone'],
|
||||
type="filepath",
|
||||
label="Prompt Speech (Optional, or let VoxCPM improvise)",
|
||||
value="./examples/example.wav",
|
||||
)
|
||||
DoDenoisePromptAudio = gr.Checkbox(
|
||||
value=False,
|
||||
label="Prompt Speech Enhancement",
|
||||
elem_id="chk_denoise",
|
||||
info="We use ZipEnhancer model to denoise the prompt audio."
|
||||
)
|
||||
with gr.Row():
|
||||
prompt_text = gr.Textbox(
|
||||
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
|
||||
label="Prompt Text",
|
||||
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
|
||||
)
|
||||
run_btn = gr.Button("Generate Speech", variant="primary")
|
||||
|
||||
with gr.Column():
|
||||
cfg_value = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=3.0,
|
||||
value=2.0,
|
||||
step=0.1,
|
||||
label="CFG Value (Guidance Scale)",
|
||||
info="Higher values increase adherence to prompt, lower values allow more creativity"
|
||||
)
|
||||
inference_timesteps = gr.Slider(
|
||||
minimum=4,
|
||||
maximum=30,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Inference Timesteps",
|
||||
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
|
||||
)
|
||||
with gr.Row():
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
)
|
||||
with gr.Row():
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="We use wetext library to normalize the input text."
|
||||
)
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
# Wiring
|
||||
run_btn.click(
|
||||
fn=demo.generate_tts_audio,
|
||||
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
|
||||
outputs=[audio_output],
|
||||
show_progress=True,
|
||||
api_name="generate",
|
||||
)
|
||||
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
|
||||
demo = VoxCPMDemo()
|
||||
interface = create_demo_interface(demo)
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
||||
server_name=server_name,
|
||||
server_port=server_port,
|
||||
show_error=show_error,
|
||||
theme=_APP_THEME,
|
||||
css=_CUSTOM_CSS,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_demo()
|
||||
+25
-4
@@ -281,27 +281,48 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
||||
|
||||
# 加载模型
|
||||
lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None
|
||||
try:
|
||||
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||
load_model(base_model_path)
|
||||
if lora_selection and lora_selection != "None":
|
||||
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
||||
load_model(base_model_path, lora_to_load)
|
||||
if lora_to_load:
|
||||
print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
|
||||
print(error_msg, file=sys.stderr)
|
||||
return None, error_msg
|
||||
lora_just_loaded = lora_to_load
|
||||
else:
|
||||
lora_just_loaded = None
|
||||
|
||||
# Handle LoRA hot-swapping
|
||||
assert current_model is not None, "Model must be loaded before inference"
|
||||
if lora_selection and lora_selection != "None":
|
||||
full_lora_path = os.path.join("lora", lora_selection)
|
||||
|
||||
if lora_just_loaded != lora_selection:
|
||||
new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path)
|
||||
current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None
|
||||
new_r = new_lora_config.r if new_lora_config else None
|
||||
|
||||
if new_r is not None and current_r is not None and new_r != current_r:
|
||||
print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr)
|
||||
reload_base = (
|
||||
new_base_model if new_base_model and os.path.exists(new_base_model)
|
||||
else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path)
|
||||
)
|
||||
try:
|
||||
load_model(reload_base, lora_selection)
|
||||
except Exception as e:
|
||||
return None, f"Failed to reload model for LoRA rank change: {e}"
|
||||
else:
|
||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||
try:
|
||||
current_model.load_lora(full_lora_path)
|
||||
current_model.set_lora_enabled(True)
|
||||
except Exception as e:
|
||||
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
||||
return None, f"Error loading LoRA: {e}"
|
||||
current_model.set_lora_enabled(True)
|
||||
else:
|
||||
print("Disabling LoRA", file=sys.stderr)
|
||||
current_model.set_lora_enabled(False)
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
模型下载脚本
|
||||
|
||||
"""
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
def download(repo_id:str, local_dir:str):
|
||||
"""
|
||||
下载模型仓库或单个文件
|
||||
|
||||
|
||||
Args:
|
||||
repo_id (str): 用户名/仓库名,例如 'stabilityai/sdxl-turbo'
|
||||
local_dir (str or Path): 下载文件放置的本地目录路径
|
||||
|
||||
Returns:
|
||||
str: 下载文件的本地路径
|
||||
|
||||
Raises:
|
||||
ValueError: 当 repo_id 格式不正确时
|
||||
"""
|
||||
model_dir = snapshot_download(
|
||||
repo_id,
|
||||
repo_type='model',
|
||||
local_dir=f"{local_dir}/{repo_id}",
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download("OpenBMB/VoxCPM2", "./models")
|
||||
download("iic/SenseVoiceSmall", "./models")
|
||||
download('iic/speech_zipenhancer_ans_multiloss_16k_base',"./models")
|
||||
+1
-2
@@ -47,8 +47,7 @@ dependencies = [
|
||||
"funasr",
|
||||
"spaces",
|
||||
"argbind",
|
||||
"safetensors"
|
||||
|
||||
"safetensors",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Unit checks for pick_runtime_dtype / get_dtype consistency.
|
||||
|
||||
Loads src/voxcpm/model/utils.py directly to avoid the heavy voxcpm package
|
||||
init. Run with: `python scripts/test_pick_runtime_dtype.py`.
|
||||
"""
|
||||
import importlib.util
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
UTILS = str(REPO_ROOT / "src" / "voxcpm" / "model" / "utils.py")
|
||||
spec = importlib.util.spec_from_file_location("voxcpm_utils", UTILS)
|
||||
utils = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(utils)
|
||||
|
||||
_LOW_PRECISION_DTYPES = utils._LOW_PRECISION_DTYPES
|
||||
_VALID_DTYPE_OVERRIDES = utils._VALID_DTYPE_OVERRIDES
|
||||
get_dtype = utils.get_dtype
|
||||
pick_runtime_dtype = utils.pick_runtime_dtype
|
||||
|
||||
|
||||
def expect(actual, expected, label):
|
||||
ok = actual == expected
|
||||
mark = "OK " if ok else "FAIL"
|
||||
print(f"[{mark}] {label}: got={actual!r} expected={expected!r}")
|
||||
return ok
|
||||
|
||||
|
||||
def expect_raises(fn, exc_type, label):
|
||||
try:
|
||||
fn()
|
||||
except exc_type as e:
|
||||
print(f"[OK ] {label}: raised {exc_type.__name__}: {e}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[FAIL] {label}: raised {type(e).__name__} not {exc_type.__name__}: {e}")
|
||||
return False
|
||||
print(f"[FAIL] {label}: no exception raised")
|
||||
return False
|
||||
|
||||
|
||||
results = []
|
||||
|
||||
print("=== override set sanity ===")
|
||||
results.append(expect("half" not in _VALID_DTYPE_OVERRIDES, True, "half removed from _VALID_DTYPE_OVERRIDES"))
|
||||
results.append(expect("half" not in _LOW_PRECISION_DTYPES, True, "half removed from _LOW_PRECISION_DTYPES"))
|
||||
|
||||
print("\n=== every accepted override parses through get_dtype ===")
|
||||
for dt in sorted(_VALID_DTYPE_OVERRIDES):
|
||||
try:
|
||||
torch_dtype = get_dtype(dt)
|
||||
print(f"[OK ] get_dtype({dt!r}) -> {torch_dtype}")
|
||||
results.append(True)
|
||||
except Exception as e:
|
||||
print(f"[FAIL] get_dtype({dt!r}) raised: {e}")
|
||||
results.append(False)
|
||||
|
||||
print("\n=== pick_runtime_dtype: non-mps is a no-op ===")
|
||||
results.append(expect(pick_runtime_dtype("cuda", "bfloat16"), "bfloat16", "cuda/bf16 untouched"))
|
||||
results.append(expect(pick_runtime_dtype("cpu", "float16"), "float16", "cpu/fp16 untouched"))
|
||||
results.append(expect(pick_runtime_dtype("cuda", "float32"), "float32", "cuda/fp32 untouched"))
|
||||
|
||||
print("\n=== pick_runtime_dtype: mps forces fp32 for low-precision ===")
|
||||
os.environ.pop("VOXCPM_MPS_DTYPE", None)
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "mps/bf16 -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "bf16"), "float32", "mps/bf16-alias -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "float16"), "float32", "mps/fp16 -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "fp16"), "float32", "mps/fp16-alias -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "float32"), "float32", "mps/fp32 stays"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "fp32"), "fp32", "mps/fp32-alias stays"))
|
||||
|
||||
print("\n=== pick_runtime_dtype: VOXCPM_MPS_DTYPE override ===")
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "bfloat16"
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "bfloat16", "override bf16 honored"))
|
||||
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "FP16"
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "fp16", "override is case-insensitive"))
|
||||
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = " float32 "
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "override is whitespace-trimmed"))
|
||||
|
||||
print("\n=== pick_runtime_dtype: 'half' is no longer a valid override ===")
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "half"
|
||||
results.append(
|
||||
expect_raises(
|
||||
lambda: pick_runtime_dtype("mps", "bfloat16"),
|
||||
ValueError,
|
||||
"override=half now rejected (was the bug)",
|
||||
)
|
||||
)
|
||||
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "garbage"
|
||||
results.append(
|
||||
expect_raises(
|
||||
lambda: pick_runtime_dtype("mps", "bfloat16"),
|
||||
ValueError,
|
||||
"override=garbage still rejected",
|
||||
)
|
||||
)
|
||||
|
||||
os.environ.pop("VOXCPM_MPS_DTYPE", None)
|
||||
|
||||
print("\n=== summary ===")
|
||||
passed = sum(results)
|
||||
total = len(results)
|
||||
print(f"{passed}/{total} passed")
|
||||
sys.exit(0 if passed == total else 1)
|
||||
+60
-6
@@ -11,11 +11,6 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
|
||||
DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2"
|
||||
|
||||
# -----------------------------
|
||||
@@ -173,7 +168,9 @@ def validate_batch_args(args, parser):
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def load_model(args) -> VoxCPM:
|
||||
def load_model(args):
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
print("Loading VoxCPM model...", file=sys.stderr)
|
||||
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||
@@ -209,6 +206,7 @@ def load_model(args) -> VoxCPM:
|
||||
zipenhancer_model_path=zipenhancer_path,
|
||||
enable_denoiser=not args.no_denoiser,
|
||||
optimize=not args.no_optimize,
|
||||
device=args.device,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
@@ -227,6 +225,7 @@ def load_model(args) -> VoxCPM:
|
||||
cache_dir=args.cache_dir,
|
||||
local_files_only=args.local_files_only,
|
||||
optimize=not args.no_optimize,
|
||||
device=args.device,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
@@ -264,6 +263,8 @@ def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None
|
||||
and (args.prompt_audio is not None or args.reference_audio is not None),
|
||||
)
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||
|
||||
duration = len(audio_array) / model.tts_model.sample_rate
|
||||
@@ -286,7 +287,27 @@ def cmd_clone(args, parser):
|
||||
)
|
||||
|
||||
|
||||
def cmd_validate(args, parser):
|
||||
from voxcpm.training.validate import (
|
||||
print_validation_report,
|
||||
validate_manifest,
|
||||
)
|
||||
|
||||
manifest = str(require_file_exists(args.manifest, parser, "manifest file"))
|
||||
result = validate_manifest(
|
||||
manifest_path=manifest,
|
||||
sample_rate=args.sample_rate,
|
||||
max_samples=args.max_samples,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
print_validation_report(result, manifest)
|
||||
if not result.is_valid:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_batch(args, parser):
|
||||
import soundfile as sf
|
||||
|
||||
input_file = require_file_exists(args.input, parser, "input file")
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -403,6 +424,12 @@ def _add_model_args(parser):
|
||||
default=DEFAULT_HF_MODEL_ID,
|
||||
help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="auto",
|
||||
help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-dir", type=str, help="Cache directory for Hub downloads"
|
||||
)
|
||||
@@ -524,6 +551,30 @@ Examples:
|
||||
_add_model_args(batch_parser)
|
||||
_add_lora_args(batch_parser)
|
||||
|
||||
# Validate subcommand
|
||||
validate_parser = subparsers.add_parser(
|
||||
"validate",
|
||||
help="Validate a training data manifest (JSONL) before fine-tuning",
|
||||
)
|
||||
validate_parser.add_argument(
|
||||
"--manifest", "-m", required=True, help="Path to JSONL training manifest"
|
||||
)
|
||||
validate_parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16_000,
|
||||
help="Expected audio sample rate in Hz (default: 16000)",
|
||||
)
|
||||
validate_parser.add_argument(
|
||||
"--max-samples",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Maximum number of samples to validate (0 = all, default: 0)",
|
||||
)
|
||||
validate_parser.add_argument(
|
||||
"--verbose", "-v", action="store_true", help="Print per-sample progress"
|
||||
)
|
||||
|
||||
# Legacy root arguments
|
||||
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
|
||||
parser.add_argument(
|
||||
@@ -576,6 +627,9 @@ def main():
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "validate":
|
||||
return cmd_validate(args, parser)
|
||||
|
||||
validate_ranges(args, parser)
|
||||
|
||||
if args.command == "design":
|
||||
|
||||
+32
-5
@@ -8,6 +8,7 @@ from typing import Generator, Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||
from .model.voxcpm2 import VoxCPM2Model
|
||||
from .model.utils import next_and_close
|
||||
|
||||
|
||||
class VoxCPM:
|
||||
@@ -17,6 +18,7 @@ class VoxCPM:
|
||||
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser: bool = True,
|
||||
optimize: bool = True,
|
||||
device: str | None = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
@@ -30,6 +32,9 @@ class VoxCPM:
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM
|
||||
will choose automatically (preferring CUDA, then MPS, then CPU).
|
||||
If set explicitly, that device is used or a clear error is raised.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
@@ -56,10 +61,20 @@ class VoxCPM:
|
||||
arch = config.get("architecture", "voxcpm").lower()
|
||||
|
||||
if arch == "voxcpm2":
|
||||
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
self.tts_model = VoxCPM2Model.from_local(
|
||||
voxcpm_model_path,
|
||||
optimize=optimize,
|
||||
device=device,
|
||||
lora_config=lora_config,
|
||||
)
|
||||
print("Loaded VoxCPM2Model", file=sys.stderr)
|
||||
elif arch == "voxcpm":
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
self.tts_model = VoxCPMModel.from_local(
|
||||
voxcpm_model_path,
|
||||
optimize=optimize,
|
||||
device=device,
|
||||
lora_config=lora_config,
|
||||
)
|
||||
print("Loaded VoxCPMModel", file=sys.stderr)
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {arch}")
|
||||
@@ -94,6 +109,7 @@ class VoxCPM:
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
device: str | None = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
@@ -109,6 +125,9 @@ class VoxCPM:
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
device: Runtime device. Use ``None``/``"auto"`` for automatic
|
||||
fallback, or an explicit value such as ``"cpu"``, ``"mps"``,
|
||||
``"cuda"``, or ``"cuda:0"``.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created with
|
||||
enable_lm=True and enable_dit=True.
|
||||
@@ -130,7 +149,7 @@ class VoxCPM:
|
||||
if not repo_id:
|
||||
raise ValueError("You must provide hf_model_id")
|
||||
|
||||
# Load from local path if provided
|
||||
# 从本地路径加载(如果提供)
|
||||
if os.path.isdir(repo_id):
|
||||
local_path = repo_id
|
||||
else:
|
||||
@@ -146,13 +165,14 @@ class VoxCPM:
|
||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||
enable_denoiser=load_denoiser,
|
||||
optimize=optimize,
|
||||
device=device,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def generate(self, *args, **kwargs) -> np.ndarray:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
@@ -200,7 +220,7 @@ class VoxCPM:
|
||||
Yields audio chunks for each generation step if ``streaming=True``,
|
||||
otherwise yields a single array containing the final audio.
|
||||
"""
|
||||
if not text.strip() or not isinstance(text, str):
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
raise ValueError("target text must be a non-empty string")
|
||||
|
||||
if prompt_wav_path is not None:
|
||||
@@ -273,8 +293,15 @@ class VoxCPM:
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
if streaming:
|
||||
try:
|
||||
for wav, _, _ in generate_result:
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
finally:
|
||||
generate_result.close()
|
||||
else:
|
||||
wav, _, _ = next_and_close(generate_result)
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
|
||||
finally:
|
||||
for tmp_path in temp_files:
|
||||
|
||||
+111
-1
@@ -1,7 +1,25 @@
|
||||
from typing import List
|
||||
import os
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"}
|
||||
_VALID_DTYPE_OVERRIDES = {
|
||||
"bfloat16", "bf16",
|
||||
"float16", "fp16",
|
||||
"float32", "fp32",
|
||||
}
|
||||
|
||||
|
||||
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
|
||||
# Explicitly close partially-consumed generators so inference_mode cleanup
|
||||
# does not get deferred to Python's GC/finalizer path.
|
||||
def next_and_close(gen):
|
||||
try:
|
||||
return next(gen)
|
||||
finally:
|
||||
gen.close()
|
||||
|
||||
|
||||
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
|
||||
@@ -119,3 +137,95 @@ def get_dtype(dtype: str):
|
||||
return torch.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def _has_mps() -> bool:
|
||||
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
|
||||
|
||||
def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
|
||||
"""Pick a safe runtime dtype for the resolved device.
|
||||
|
||||
On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift
|
||||
in the diffusion AR loop that the output is glitched and the model's
|
||||
badcase detector triggers infinite retries. float32 is the only stable
|
||||
option today. CUDA and CPU keep whatever the checkpoint was trained with.
|
||||
|
||||
Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when
|
||||
they want to test future MPS improvements.
|
||||
"""
|
||||
if device != "mps":
|
||||
return configured_dtype
|
||||
|
||||
override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
|
||||
if override:
|
||||
if override not in _VALID_DTYPE_OVERRIDES:
|
||||
raise ValueError(
|
||||
f"VOXCPM_MPS_DTYPE='{override}' is not one of "
|
||||
f"{sorted(_VALID_DTYPE_OVERRIDES)}"
|
||||
)
|
||||
return override
|
||||
|
||||
if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES:
|
||||
return "float32"
|
||||
return configured_dtype
|
||||
|
||||
|
||||
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
|
||||
"""
|
||||
Choose a runtime device automatically.
|
||||
|
||||
Preference order:
|
||||
- if the preferred device is available, use it
|
||||
- otherwise fall back to CUDA -> MPS -> CPU
|
||||
"""
|
||||
preferred = (preferred_device or "cuda").strip().lower()
|
||||
|
||||
if preferred.startswith("cuda") and torch.cuda.is_available():
|
||||
return preferred
|
||||
if preferred == "mps" and _has_mps():
|
||||
return "mps"
|
||||
if preferred == "cpu":
|
||||
return "cpu"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
if _has_mps():
|
||||
return "mps"
|
||||
return "cpu"
|
||||
|
||||
|
||||
def resolve_runtime_device(device: Optional[str], configured_device: str = "cuda") -> str:
|
||||
"""
|
||||
Resolve the actual runtime device.
|
||||
|
||||
Semantics:
|
||||
- ``device`` is ``None`` or ``"auto"``: use automatic fallback selection
|
||||
- otherwise: treat it as an explicit user choice and validate availability
|
||||
"""
|
||||
explicit = None if device is None else device.strip().lower()
|
||||
|
||||
if explicit is None or explicit == "auto":
|
||||
return auto_select_device(configured_device)
|
||||
|
||||
if explicit.startswith("cuda"):
|
||||
if not torch.cuda.is_available():
|
||||
raise ValueError(
|
||||
f"Requested device '{device}', but CUDA is not available. "
|
||||
"Use device='auto' for automatic fallback."
|
||||
)
|
||||
return explicit
|
||||
if explicit == "mps":
|
||||
if not _has_mps():
|
||||
raise ValueError(
|
||||
"Requested device 'mps', but MPS is not available. "
|
||||
"Use device='auto' for automatic fallback."
|
||||
)
|
||||
return "mps"
|
||||
if explicit == "cpu":
|
||||
return "cpu"
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported device '{device}'. Supported values are 'auto', 'cpu', 'mps', "
|
||||
"'cuda', or indexed CUDA devices like 'cuda:0'."
|
||||
)
|
||||
|
||||
+37
-17
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||
from .utils import (
|
||||
get_dtype,
|
||||
mask_multichar_chinese_tokens,
|
||||
next_and_close,
|
||||
pick_runtime_dtype,
|
||||
resolve_runtime_device,
|
||||
)
|
||||
|
||||
|
||||
class VoxCPMEncoderConfig(BaseModel):
|
||||
@@ -109,18 +115,22 @@ class VoxCPMModel(nn.Module):
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAE,
|
||||
lora_config: LoRAConfig = None,
|
||||
device: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
if not torch.cuda.is_available():
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
self.device = resolve_runtime_device(device, config.device)
|
||||
self.config.device = self.device
|
||||
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||
if resolved_dtype != self.config.dtype:
|
||||
print(
|
||||
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
self.config.dtype = resolved_dtype
|
||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||
|
||||
# Text-Semantic LM
|
||||
@@ -227,6 +237,7 @@ class VoxCPMModel(nn.Module):
|
||||
self.residual_lm.forward_step = torch.compile(
|
||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
self._feat_encoder_raw = self.feat_encoder
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(
|
||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||
@@ -337,7 +348,7 @@ class VoxCPMModel(nn.Module):
|
||||
return get_dtype(self.config.dtype)
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
@@ -463,7 +474,7 @@ class VoxCPMModel(nn.Module):
|
||||
yield decode_audio
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
latent_pred, pred_audio_feat = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -571,7 +582,7 @@ class VoxCPMModel(nn.Module):
|
||||
return merged_cache
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
@@ -690,7 +701,7 @@ class VoxCPMModel(nn.Module):
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
latent_pred, pred_audio_feat = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -713,7 +724,7 @@ class VoxCPMModel(nn.Module):
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return next(self._inference(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._inference(*args, streaming=True, **kwargs)
|
||||
@@ -755,7 +766,8 @@ class VoxCPMModel(nn.Module):
|
||||
"""
|
||||
B, T, P, D = feat.shape
|
||||
|
||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
|
||||
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
@@ -845,8 +857,16 @@ class VoxCPMModel(nn.Module):
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
def from_local(
|
||||
cls,
|
||||
path: str,
|
||||
optimize: bool = True,
|
||||
training: bool = False,
|
||||
device: str | None = None,
|
||||
lora_config: LoRAConfig = None,
|
||||
):
|
||||
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
|
||||
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
||||
@@ -868,7 +888,7 @@ class VoxCPMModel(nn.Module):
|
||||
raise FileNotFoundError(
|
||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||
)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
@@ -950,7 +970,7 @@ class VoxCPMModel(nn.Module):
|
||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||
state_dict = load_file(str(safetensors_file), device=device)
|
||||
elif ckpt_file and ckpt_file.exists():
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
else:
|
||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||
|
||||
+45
-25
@@ -45,7 +45,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||
from .utils import (
|
||||
get_dtype,
|
||||
mask_multichar_chinese_tokens,
|
||||
next_and_close,
|
||||
pick_runtime_dtype,
|
||||
resolve_runtime_device,
|
||||
)
|
||||
|
||||
|
||||
# A simple function to trim audio silence using VAD, not used default
|
||||
@@ -151,18 +157,22 @@ class VoxCPM2Model(nn.Module):
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAEV2,
|
||||
lora_config: LoRAConfig = None,
|
||||
device: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
if not torch.cuda.is_available():
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
self.device = resolve_runtime_device(device, config.device)
|
||||
self.config.device = self.device
|
||||
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||
if resolved_dtype != self.config.dtype:
|
||||
print(
|
||||
f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
self.config.dtype = resolved_dtype
|
||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||
|
||||
# Text-Semantic LM
|
||||
@@ -275,6 +285,7 @@ class VoxCPM2Model(nn.Module):
|
||||
self.residual_lm.forward_step = torch.compile(
|
||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
self._feat_encoder_raw = self.feat_encoder
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(
|
||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||
@@ -443,7 +454,7 @@ class VoxCPM2Model(nn.Module):
|
||||
return tokens, feats, t_mask, a_mask
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
@@ -636,14 +647,14 @@ class VoxCPM2Model(nn.Module):
|
||||
streaming_prefix_len=streaming_prefix_len,
|
||||
)
|
||||
if streaming:
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
with self.audio_vae.streaming_decode() as vae_dec:
|
||||
for latent_pred, _, _ctx in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -761,7 +772,7 @@ class VoxCPM2Model(nn.Module):
|
||||
return merged
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
return next_and_close(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
@@ -923,14 +934,14 @@ class VoxCPM2Model(nn.Module):
|
||||
streaming_prefix_len=streaming_prefix_len,
|
||||
)
|
||||
if streaming:
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
with self.audio_vae.streaming_decode() as vae_dec:
|
||||
for latent_pred, pred_audio_feat, _ctx in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||
latent_pred, pred_audio_feat, context_len = next_and_close(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -953,7 +964,7 @@ class VoxCPM2Model(nn.Module):
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs))
|
||||
feat_pred, generated_feat, _ = next_and_close(self._inference(*args, streaming=False, **kwargs))
|
||||
return feat_pred, generated_feat
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
@@ -997,7 +1008,8 @@ class VoxCPM2Model(nn.Module):
|
||||
"""
|
||||
B, T, P, D = feat.shape
|
||||
|
||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
|
||||
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
@@ -1067,8 +1079,8 @@ class VoxCPM2Model(nn.Module):
|
||||
prefix_feat_cond = pred_feat
|
||||
|
||||
if streaming:
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
# Yield only the newest patch latent for stateful VAE decode
|
||||
feat_pred = rearrange(pred_feat.unsqueeze(1), "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
yield feat_pred, pred_feat_seq, context_len
|
||||
|
||||
@@ -1096,8 +1108,16 @@ class VoxCPM2Model(nn.Module):
|
||||
yield feat_pred, generated_feat, context_len
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
def from_local(
|
||||
cls,
|
||||
path: str,
|
||||
optimize: bool = True,
|
||||
training: bool = False,
|
||||
device: str | None = None,
|
||||
lora_config: LoRAConfig = None,
|
||||
):
|
||||
with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
|
||||
config = VoxCPMConfig.model_validate_json(_cfg_f.read())
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||
audio_vae = AudioVAEV2(config=audio_vae_config) if audio_vae_config else AudioVAEV2()
|
||||
@@ -1119,7 +1139,7 @@ class VoxCPM2Model(nn.Module):
|
||||
raise FileNotFoundError(
|
||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||
)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
@@ -1201,7 +1221,7 @@ class VoxCPM2Model(nn.Module):
|
||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||
state_dict = load_file(str(safetensors_file), device=device)
|
||||
elif ckpt_file and ckpt_file.exists():
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
else:
|
||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||
|
||||
@@ -472,6 +472,20 @@ class AudioVAE(nn.Module):
|
||||
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
|
||||
return self.decoder(z, sr_cond)
|
||||
|
||||
def streaming_decode(self):
|
||||
"""Return a ``StreamingVAEDecoder`` context manager for stateful
|
||||
chunk-by-chunk decoding. Each call to ``decode_chunk`` processes only
|
||||
the new latent patch and carries causal-conv state internally, avoiding
|
||||
the redundant overlap decode used previously.
|
||||
|
||||
Usage::
|
||||
|
||||
with vae.streaming_decode() as dec:
|
||||
for patch in patches:
|
||||
audio_chunk = dec.decode_chunk(patch)
|
||||
"""
|
||||
return StreamingVAEDecoder(self)
|
||||
|
||||
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
||||
"""
|
||||
Args:
|
||||
@@ -485,3 +499,82 @@ class AudioVAE(nn.Module):
|
||||
|
||||
audio_data = self.preprocess(audio_data, sample_rate)
|
||||
return self.encoder(audio_data)["mu"]
|
||||
|
||||
|
||||
class StreamingVAEDecoder:
|
||||
"""Stateful streaming wrapper for :class:`AudioVAE`.
|
||||
|
||||
Carries causal-convolution padding buffers between calls so that each
|
||||
``decode_chunk`` processes only the new latent patch — no overlap needed.
|
||||
"""
|
||||
|
||||
def __init__(self, vae: AudioVAE):
|
||||
self._vae = vae
|
||||
self._states: dict = {}
|
||||
self._originals: list = []
|
||||
|
||||
# -- context manager --------------------------------------------------
|
||||
def __enter__(self):
|
||||
self._states.clear()
|
||||
self._install()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self._restore()
|
||||
self._states.clear()
|
||||
|
||||
# -- public API --------------------------------------------------------
|
||||
def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor:
|
||||
"""Decode a single latent chunk and return the audio waveform."""
|
||||
return self._vae.decode(z_chunk)
|
||||
|
||||
# -- internals ---------------------------------------------------------
|
||||
def _install(self):
|
||||
for name, mod in self._vae.decoder.named_modules():
|
||||
if isinstance(mod, CausalConv1d):
|
||||
pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding
|
||||
if pad > 0:
|
||||
self._patch_causal_conv(mod, pad)
|
||||
elif isinstance(mod, CausalTransposeConv1d):
|
||||
trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding
|
||||
ctx = (mod.kernel_size[0] - 1) // mod.stride[0]
|
||||
if ctx > 0:
|
||||
self._patch_transpose_conv(mod, ctx, trim)
|
||||
|
||||
def _patch_causal_conv(self, mod, pad_size):
|
||||
states = self._states
|
||||
key = id(mod)
|
||||
orig = mod.forward
|
||||
|
||||
def fwd(x, _k=key, _p=pad_size, _m=mod):
|
||||
x_pad = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_p, 0))
|
||||
if x.shape[-1] >= _p:
|
||||
states[_k] = x[:, :, -_p:].detach()
|
||||
else:
|
||||
prev = states.get(_k, torch.zeros(x.shape[0], x.shape[1], _p,
|
||||
device=x.device, dtype=x.dtype))
|
||||
states[_k] = torch.cat([prev, x], dim=-1)[:, :, -_p:].detach()
|
||||
return nn.Conv1d.forward(_m, x_pad)
|
||||
|
||||
mod.forward = fwd
|
||||
self._originals.append((mod, orig))
|
||||
|
||||
def _patch_transpose_conv(self, mod, ctx, trim):
|
||||
states = self._states
|
||||
key = id(mod)
|
||||
orig = mod.forward
|
||||
|
||||
def fwd(x, _k=key, _c=ctx, _t=trim, _m=mod):
|
||||
x_full = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_c, 0))
|
||||
states[_k] = x[:, :, -_c:].detach()
|
||||
out = nn.ConvTranspose1d.forward(_m, x_full)
|
||||
left = _c * _m.stride[0]
|
||||
return out[..., left:-_t] if _t > 0 else out[..., left:]
|
||||
|
||||
mod.forward = fwd
|
||||
self._originals.append((mod, orig))
|
||||
|
||||
def _restore(self):
|
||||
for mod, orig in self._originals:
|
||||
mod.forward = orig
|
||||
self._originals.clear()
|
||||
|
||||
@@ -196,7 +196,9 @@ class MiniCPMAttention(nn.Module):
|
||||
key_cache[:, :, position_id, :] = key_states
|
||||
value_cache[:, :, position_id, :] = value_states
|
||||
|
||||
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
|
||||
# Use an explicit broadcastable mask shape for SDPA. A 1D mask can
|
||||
# trigger a CPU-side dimension bug in some PyTorch versions.
|
||||
attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1)
|
||||
|
||||
# ref: https://github.com/pytorch/pytorch/issues/163597
|
||||
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
||||
|
||||
@@ -15,6 +15,7 @@ from .data import (
|
||||
BatchProcessor,
|
||||
)
|
||||
from .state import TrainingState
|
||||
from .validate import validate_manifest, ValidationResult
|
||||
|
||||
__all__ = [
|
||||
"Accelerator",
|
||||
@@ -24,4 +25,6 @@ __all__ = [
|
||||
"TrainingState",
|
||||
"load_audio_text_datasets",
|
||||
"build_dataloader",
|
||||
"validate_manifest",
|
||||
"ValidationResult",
|
||||
]
|
||||
|
||||
+53
-11
@@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_REF_AUDIO_COLUMN = "ref_audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
|
||||
|
||||
@@ -21,6 +22,7 @@ def load_audio_text_datasets(
|
||||
val_manifest: str = "",
|
||||
text_column: str = DEFAULT_TEXT_COLUMN,
|
||||
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
||||
ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN,
|
||||
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
||||
sample_rate: int = 16_000,
|
||||
num_proc: int = 1,
|
||||
@@ -34,14 +36,19 @@ def load_audio_text_datasets(
|
||||
def prepare(ds: Dataset) -> Dataset:
|
||||
if audio_column not in ds.column_names:
|
||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# but for length calculation we might need raw path or duration if available.
|
||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||
if audio_column != DEFAULT_AUDIO_COLUMN:
|
||||
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
||||
if text_column != DEFAULT_TEXT_COLUMN:
|
||||
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
||||
|
||||
# ref_audio is optional — cast to Audio if the column exists
|
||||
ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN
|
||||
if ref_col in ds.column_names:
|
||||
ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate))
|
||||
if ref_col != DEFAULT_REF_AUDIO_COLUMN:
|
||||
ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN)
|
||||
|
||||
if dataset_id_column and dataset_id_column in ds.column_names:
|
||||
if dataset_id_column != DEFAULT_ID_COLUMN:
|
||||
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
||||
@@ -67,11 +74,11 @@ def compute_sample_lengths(
|
||||
- 音频长度:
|
||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||
t_seq = ceil(t_vae / patch_size)
|
||||
- 序列总长约为: text_len + t_seq + 2
|
||||
- 无 ref_audio: text_len + t_seq + 2
|
||||
- 有 ref_audio: text_len + t_seq + ref_seq + 4
|
||||
|
||||
Optimized: Use batch column access instead of iterating item by item.
|
||||
"""
|
||||
# Batch access columns - much faster than per-item access
|
||||
text_ids_list = ds["text_ids"]
|
||||
text_lens = [len(t) for t in text_ids_list]
|
||||
|
||||
@@ -79,18 +86,35 @@ def compute_sample_lengths(
|
||||
if has_duration:
|
||||
durations = ds["duration"]
|
||||
else:
|
||||
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
|
||||
durations = []
|
||||
for i in range(len(ds)):
|
||||
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||
|
||||
# Vectorized length computation
|
||||
has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names
|
||||
if has_ref_audio:
|
||||
ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None
|
||||
|
||||
lengths = []
|
||||
for text_len, duration in zip(text_lens, durations):
|
||||
for i, (text_len, duration) in enumerate(zip(text_lens, durations)):
|
||||
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
||||
t_seq = math.ceil(t_vae / patch_size)
|
||||
total_len = text_len + t_seq + 2
|
||||
|
||||
ref_seq = 0
|
||||
if has_ref_audio:
|
||||
# Estimate ref_audio length; ref_audio is None for samples without it
|
||||
if ref_duration_col:
|
||||
ref_dur = ds[i].get(ref_duration_col)
|
||||
else:
|
||||
ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN)
|
||||
ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None
|
||||
if ref_dur is not None and float(ref_dur) > 0:
|
||||
ref_vae = math.ceil(float(ref_dur) * audio_vae_fps)
|
||||
ref_seq = math.ceil(ref_vae / patch_size)
|
||||
|
||||
# +2 for 101/102; +2 more for 103/104 when ref_audio present
|
||||
overhead = 4 if ref_seq > 0 else 2
|
||||
total_len = text_len + t_seq + ref_seq + overhead
|
||||
lengths.append(total_len)
|
||||
|
||||
return lengths
|
||||
@@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
PyTorch-friendly samples.
|
||||
"""
|
||||
|
||||
_SENTINEL = [-100.0]
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.dataset = dataset
|
||||
self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
@@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
def __getitem__(self, idx: int):
|
||||
item = self.dataset[idx]
|
||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||
return {
|
||||
sample = {
|
||||
"text_ids": item["text_ids"],
|
||||
"audio_array": audio["array"],
|
||||
"audio_sampling_rate": audio["sampling_rate"],
|
||||
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
|
||||
"is_prompt": item.get("is_prompt", False),
|
||||
}
|
||||
if self.has_ref_audio:
|
||||
ref = item.get(DEFAULT_REF_AUDIO_COLUMN)
|
||||
sample["ref_audio_array"] = ref["array"] if ref else self._SENTINEL
|
||||
return sample
|
||||
|
||||
@staticmethod
|
||||
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
|
||||
@@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
|
||||
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
|
||||
|
||||
return {
|
||||
result = {
|
||||
"text_tokens": text_padded,
|
||||
"audio_tokens": audio_padded,
|
||||
"task_ids": task_ids,
|
||||
@@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
"is_prompts": is_prompts,
|
||||
}
|
||||
|
||||
if "ref_audio_array" in batch[0]:
|
||||
ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch]
|
||||
result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class BatchProcessor:
|
||||
"""
|
||||
@@ -184,12 +221,17 @@ class BatchProcessor:
|
||||
task_ids = batch["task_ids"].to(self.device)
|
||||
dataset_ids = batch["dataset_ids"].to(self.device)
|
||||
|
||||
ref_audio_tokens = None
|
||||
if "ref_audio_tokens" in batch:
|
||||
ref_audio_tokens = batch["ref_audio_tokens"].to(self.device)
|
||||
|
||||
packed = self.packer(
|
||||
audio_tokens=audio_tokens,
|
||||
text_tokens=text_tokens,
|
||||
task_ids=task_ids,
|
||||
dataset_ids=dataset_ids,
|
||||
is_prompts=batch["is_prompts"],
|
||||
ref_audio_tokens=ref_audio_tokens,
|
||||
)
|
||||
return packed
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -14,7 +14,6 @@ class AudioFeatureProcessingPacker:
|
||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||
self.audio_start_id = 101
|
||||
self.audio_end_id = 102
|
||||
# unused now
|
||||
self.audio_prompt_start_id = 103
|
||||
self.audio_prompt_end_id = 104
|
||||
self.text_eos_token_id = 2
|
||||
@@ -78,11 +77,16 @@ class AudioFeatureProcessingPacker:
|
||||
task_ids: torch.Tensor,
|
||||
dataset_ids: torch.Tensor,
|
||||
is_prompts: List[bool],
|
||||
ref_audio_tokens: Optional[torch.Tensor] = None,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Padding-based batching: each sample in the input batch is processed
|
||||
independently and then padded to a common length (capped by ``max_len``).
|
||||
The result tensors all have shape [B, T, ...].
|
||||
|
||||
If ``ref_audio_tokens`` is provided (same batch dim as ``audio_tokens``),
|
||||
samples whose unpadded ref_audio length > 0 will be processed with the
|
||||
reference-audio path (tokens 103/104 prepended, loss only on target audio).
|
||||
"""
|
||||
device = audio_tokens.device
|
||||
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
|
||||
@@ -101,13 +105,33 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||
|
||||
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
|
||||
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
|
||||
ref_iter = ref_audio_tokens if ref_audio_tokens is not None else [None] * audio_tokens.size(0)
|
||||
|
||||
for audio_token, text_token, task_id, dataset_idx, is_prompt, ref_token in zip(
|
||||
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts, ref_iter
|
||||
):
|
||||
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
|
||||
unpad_text_token = self.unpad_text_tokens(text_token)
|
||||
usage = self.id_to_task[task_id]
|
||||
|
||||
has_ref = False
|
||||
if ref_token is not None:
|
||||
unpad_ref_token = self.unpad_audio_tokens(ref_token).to(torch.float32)
|
||||
if unpad_ref_token.numel() > 0:
|
||||
has_ref = True
|
||||
|
||||
if has_ref:
|
||||
(
|
||||
packed_text,
|
||||
audio_feat,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
) = self.process_tts_data_with_ref(unpad_ref_token, unpad_audio_token, unpad_text_token)
|
||||
else:
|
||||
(
|
||||
packed_text,
|
||||
audio_feat,
|
||||
@@ -294,3 +318,89 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
)
|
||||
|
||||
def process_tts_data_with_ref(
|
||||
self,
|
||||
ref_audio_token: torch.Tensor,
|
||||
target_audio_token: torch.Tensor,
|
||||
text_token: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Build a training sequence with reference audio prepended:
|
||||
|
||||
[103, ref_feats, 104, text, 101, target_feats, 102]
|
||||
|
||||
Loss is computed only on the target audio segment.
|
||||
"""
|
||||
device = text_token.device
|
||||
txt_len = len(text_token)
|
||||
|
||||
ref_feats, ref_duration = self.extract_audio_feats(ref_audio_token)
|
||||
ref_feats = ref_feats.squeeze(0) # [R, P, D]
|
||||
ref_len = ref_feats.shape[0]
|
||||
|
||||
tgt_feats, tgt_duration = self.extract_audio_feats(target_audio_token)
|
||||
tgt_feats = tgt_feats.squeeze(0) # [A, P, D]
|
||||
tgt_len = tgt_feats.shape[0]
|
||||
|
||||
feat_shape = (self.patch_size, ref_feats.size(-1))
|
||||
|
||||
def _tok(ids):
|
||||
return torch.tensor(ids, dtype=torch.int32, device=device)
|
||||
|
||||
# -- text token track --
|
||||
# [103, 0×R, 104, text_ids, 101, 0×A, 102]
|
||||
text_token_info = torch.cat([
|
||||
_tok([self.audio_prompt_start_id]),
|
||||
torch.zeros(ref_len, dtype=torch.int32, device=device),
|
||||
_tok([self.audio_prompt_end_id]),
|
||||
text_token,
|
||||
_tok([self.audio_start_id]),
|
||||
torch.zeros(tgt_len, dtype=torch.int32, device=device),
|
||||
_tok([self.audio_end_id]),
|
||||
])
|
||||
|
||||
# -- audio feature track --
|
||||
zero_1 = torch.zeros((1,) + feat_shape, dtype=torch.float32, device=device)
|
||||
zero_txt = torch.zeros((txt_len,) + feat_shape, dtype=torch.float32, device=device)
|
||||
audio_feat_info = torch.cat([
|
||||
zero_1, ref_feats, zero_1, # 103, ref, 104
|
||||
zero_txt, # text
|
||||
zero_1, tgt_feats, zero_1, # 101, target, 102
|
||||
], dim=0)
|
||||
|
||||
# -- masks --
|
||||
text_mask = torch.cat([
|
||||
torch.ones(1), torch.zeros(ref_len), torch.ones(1),
|
||||
torch.ones(txt_len),
|
||||
torch.ones(1), torch.zeros(tgt_len), torch.ones(1),
|
||||
]).to(torch.int32).to(device)
|
||||
|
||||
audio_mask = torch.cat([
|
||||
torch.zeros(1), torch.ones(ref_len), torch.zeros(1),
|
||||
torch.zeros(txt_len),
|
||||
torch.zeros(1), torch.ones(tgt_len), torch.zeros(1),
|
||||
]).to(torch.int32).to(device)
|
||||
|
||||
loss_mask = torch.cat([
|
||||
torch.zeros(1 + ref_len + 1), # ref part: no loss
|
||||
torch.zeros(txt_len), # text: no loss
|
||||
torch.zeros(1), # 101: no loss
|
||||
torch.ones(tgt_len), # target audio: LOSS
|
||||
torch.zeros(1), # 102: no loss
|
||||
]).to(torch.int32).to(device)
|
||||
|
||||
total_len = 1 + ref_len + 1 + txt_len + 1 + tgt_len + 1
|
||||
labels = torch.zeros(total_len, dtype=torch.int32, device=device)
|
||||
labels[-2] = 1 # stop label at last target audio position
|
||||
|
||||
return (
|
||||
text_token_info,
|
||||
audio_feat_info,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
ref_duration + tgt_duration,
|
||||
txt_len,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
Pre-flight validation for VoxCPM training data manifests.
|
||||
|
||||
Validates JSONL manifest files before starting expensive fine-tuning jobs,
|
||||
catching format issues, missing files, and data quality problems early.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Structured result of a manifest validation run."""
|
||||
|
||||
total_samples: int = 0
|
||||
valid_samples: int = 0
|
||||
errors: List[str] = field(default_factory=list)
|
||||
warnings: List[str] = field(default_factory=list)
|
||||
audio_durations: List[float] = field(default_factory=list)
|
||||
text_lengths: List[int] = field(default_factory=list)
|
||||
has_ref_audio: int = 0
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return len(self.errors) == 0 and self.valid_samples > 0
|
||||
|
||||
|
||||
def _check_audio_file(audio_path: str, sample_rate: int) -> Optional[str]:
|
||||
"""Check if an audio file exists, is readable, and matches expected sample rate.
|
||||
|
||||
Returns an error message, or None if the file is valid.
|
||||
"""
|
||||
if not os.path.isfile(audio_path):
|
||||
return f"Audio file not found: {audio_path}"
|
||||
try:
|
||||
import soundfile as sf
|
||||
|
||||
info = sf.info(audio_path)
|
||||
if info.frames == 0:
|
||||
return f"Audio file is empty: {audio_path}"
|
||||
if info.samplerate != sample_rate:
|
||||
return (
|
||||
f"Sample rate mismatch in {audio_path}: "
|
||||
f"expected {sample_rate} Hz, got {info.samplerate} Hz"
|
||||
)
|
||||
return None
|
||||
except ImportError:
|
||||
# soundfile not available; just check existence
|
||||
return None
|
||||
except Exception as e:
|
||||
return f"Cannot read audio file {audio_path}: {e}"
|
||||
|
||||
|
||||
def _get_audio_duration(audio_path: str) -> Optional[float]:
|
||||
"""Get audio duration in seconds. Returns None if unavailable."""
|
||||
try:
|
||||
import soundfile as sf
|
||||
|
||||
info = sf.info(audio_path)
|
||||
return info.duration
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def validate_manifest(
|
||||
manifest_path: str,
|
||||
sample_rate: int = 16_000,
|
||||
max_samples: int = 0,
|
||||
verbose: bool = False,
|
||||
) -> ValidationResult:
|
||||
"""Validate a JSONL training manifest file.
|
||||
|
||||
Checks:
|
||||
1. File exists and is readable
|
||||
2. Each line is valid JSON
|
||||
3. Required columns present (text, audio)
|
||||
4. Audio files exist and are readable
|
||||
5. Text content is non-empty
|
||||
6. Collects duration and text length statistics
|
||||
7. Validates optional ref_audio column
|
||||
|
||||
Args:
|
||||
manifest_path: Path to the JSONL manifest file.
|
||||
sample_rate: Expected audio sample rate (for informational purposes).
|
||||
max_samples: Maximum number of samples to validate (0 = all).
|
||||
verbose: Print per-sample progress.
|
||||
|
||||
Returns:
|
||||
ValidationResult with errors, warnings, and statistics.
|
||||
"""
|
||||
result = ValidationResult()
|
||||
path = Path(manifest_path)
|
||||
|
||||
if not path.exists():
|
||||
result.errors.append(f"Manifest file not found: {manifest_path}")
|
||||
return result
|
||||
|
||||
if not path.is_file():
|
||||
result.errors.append(f"Manifest path is not a file: {manifest_path}")
|
||||
return result
|
||||
|
||||
manifest_dir = path.parent
|
||||
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
except Exception as e:
|
||||
result.errors.append(f"Cannot read manifest file: {e}")
|
||||
return result
|
||||
|
||||
if not lines:
|
||||
result.errors.append("Manifest file is empty")
|
||||
return result
|
||||
|
||||
samples_to_check = len(lines)
|
||||
if max_samples > 0:
|
||||
samples_to_check = min(samples_to_check, max_samples)
|
||||
|
||||
missing_audio_count = 0
|
||||
empty_text_count = 0
|
||||
|
||||
for i, line in enumerate(lines[:samples_to_check]):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
result.total_samples += 1
|
||||
|
||||
# Check JSON validity
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
result.errors.append(f"Line {i + 1}: Invalid JSON — {e}")
|
||||
continue
|
||||
|
||||
if not isinstance(entry, dict):
|
||||
result.errors.append(f"Line {i + 1}: Expected JSON object, got {type(entry).__name__}")
|
||||
continue
|
||||
|
||||
# Check required columns
|
||||
has_error = False
|
||||
|
||||
if "text" not in entry:
|
||||
result.errors.append(f"Line {i + 1}: Missing required column 'text'")
|
||||
has_error = True
|
||||
|
||||
if "audio" not in entry:
|
||||
result.errors.append(f"Line {i + 1}: Missing required column 'audio'")
|
||||
has_error = True
|
||||
|
||||
if has_error:
|
||||
continue
|
||||
|
||||
# Validate text
|
||||
text = entry["text"]
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
empty_text_count += 1
|
||||
if empty_text_count <= 5:
|
||||
result.warnings.append(f"Line {i + 1}: Empty or non-string text")
|
||||
else:
|
||||
result.text_lengths.append(len(text))
|
||||
|
||||
# Validate audio path
|
||||
audio_path = entry["audio"]
|
||||
if isinstance(audio_path, dict):
|
||||
# HuggingFace Audio format with {"path": ..., "array": ...}
|
||||
audio_path = audio_path.get("path", "")
|
||||
|
||||
if isinstance(audio_path, str) and audio_path:
|
||||
# Resolve relative paths against manifest directory
|
||||
if not os.path.isabs(audio_path):
|
||||
audio_path = str(manifest_dir / audio_path)
|
||||
|
||||
audio_error = _check_audio_file(audio_path, sample_rate)
|
||||
if audio_error:
|
||||
missing_audio_count += 1
|
||||
if missing_audio_count <= 5:
|
||||
result.errors.append(f"Line {i + 1}: {audio_error}")
|
||||
has_error = True
|
||||
else:
|
||||
duration = _get_audio_duration(audio_path)
|
||||
if duration is not None:
|
||||
result.audio_durations.append(duration)
|
||||
if duration < 0.3:
|
||||
result.warnings.append(
|
||||
f"Line {i + 1}: Very short audio ({duration:.2f}s)"
|
||||
)
|
||||
elif duration > 30.0:
|
||||
result.warnings.append(
|
||||
f"Line {i + 1}: Very long audio ({duration:.1f}s), may cause OOM"
|
||||
)
|
||||
else:
|
||||
result.errors.append(f"Line {i + 1}: Invalid audio path")
|
||||
has_error = True
|
||||
|
||||
# Validate optional ref_audio
|
||||
if "ref_audio" in entry:
|
||||
ref_path = entry["ref_audio"]
|
||||
if isinstance(ref_path, dict):
|
||||
ref_path = ref_path.get("path", "")
|
||||
if isinstance(ref_path, str) and ref_path:
|
||||
if not os.path.isabs(ref_path):
|
||||
ref_path = str(manifest_dir / ref_path)
|
||||
if os.path.isfile(ref_path):
|
||||
result.has_ref_audio += 1
|
||||
else:
|
||||
result.warnings.append(
|
||||
f"Line {i + 1}: ref_audio file not found: {ref_path}"
|
||||
)
|
||||
|
||||
if not has_error:
|
||||
result.valid_samples += 1
|
||||
|
||||
if verbose and (i + 1) % 100 == 0:
|
||||
print(f" Validated {i + 1}/{samples_to_check} samples...", file=sys.stderr)
|
||||
|
||||
# Summarize truncated errors
|
||||
if missing_audio_count > 5:
|
||||
result.errors.append(
|
||||
f"... and {missing_audio_count - 5} more missing audio files "
|
||||
f"({missing_audio_count} total)"
|
||||
)
|
||||
if empty_text_count > 5:
|
||||
result.warnings.append(
|
||||
f"... and {empty_text_count - 5} more empty text entries "
|
||||
f"({empty_text_count} total)"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def print_validation_report(result: ValidationResult, manifest_path: str) -> None:
|
||||
"""Print a human-readable validation report to stderr."""
|
||||
print(f"\n{'=' * 60}", file=sys.stderr)
|
||||
print(f" VoxCPM Training Data Validation Report", file=sys.stderr)
|
||||
print(f"{'=' * 60}", file=sys.stderr)
|
||||
print(f" Manifest : {manifest_path}", file=sys.stderr)
|
||||
print(f" Samples : {result.valid_samples}/{result.total_samples} valid", file=sys.stderr)
|
||||
|
||||
if result.has_ref_audio > 0:
|
||||
print(
|
||||
f" Ref Audio: {result.has_ref_audio} samples with reference audio",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Audio duration statistics
|
||||
if result.audio_durations:
|
||||
durations = sorted(result.audio_durations)
|
||||
total_hrs = sum(durations) / 3600
|
||||
print(f"\n Audio Duration Statistics:", file=sys.stderr)
|
||||
print(f" Total : {total_hrs:.2f} hours", file=sys.stderr)
|
||||
print(
|
||||
f" Range : {durations[0]:.2f}s — {durations[-1]:.1f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
f" Mean : {sum(durations) / len(durations):.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
median_idx = len(durations) // 2
|
||||
print(f" Median : {durations[median_idx]:.2f}s", file=sys.stderr)
|
||||
|
||||
# Text length statistics
|
||||
if result.text_lengths:
|
||||
lengths = sorted(result.text_lengths)
|
||||
print(f"\n Text Length Statistics (characters):", file=sys.stderr)
|
||||
print(
|
||||
f" Range : {lengths[0]} — {lengths[-1]}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
f" Mean : {sum(lengths) / len(lengths):.0f}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Errors
|
||||
if result.errors:
|
||||
print(f"\n ERRORS ({len(result.errors)}):", file=sys.stderr)
|
||||
for err in result.errors[:20]:
|
||||
print(f" x {err}", file=sys.stderr)
|
||||
if len(result.errors) > 20:
|
||||
print(
|
||||
f" ... ({len(result.errors) - 20} more errors omitted)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Warnings
|
||||
if result.warnings:
|
||||
print(f"\n WARNINGS ({len(result.warnings)}):", file=sys.stderr)
|
||||
for warn in result.warnings[:10]:
|
||||
print(f" ! {warn}", file=sys.stderr)
|
||||
if len(result.warnings) > 10:
|
||||
print(
|
||||
f" ... ({len(result.warnings) - 10} more warnings omitted)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Summary
|
||||
print(f"\n{'=' * 60}", file=sys.stderr)
|
||||
if result.is_valid:
|
||||
print(" PASSED: Manifest is valid for training.", file=sys.stderr)
|
||||
else:
|
||||
print(" FAILED: Fix errors above before starting training.", file=sys.stderr)
|
||||
print(f"{'=' * 60}\n", file=sys.stderr)
|
||||
@@ -58,6 +58,7 @@ def test_parser_defaults_to_voxcpm2():
|
||||
parser = cli._build_parser()
|
||||
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
|
||||
assert args.hf_model_id == "openbmb/VoxCPM2"
|
||||
assert args.device == "auto"
|
||||
assert args.no_optimize is False
|
||||
|
||||
|
||||
@@ -85,6 +86,7 @@ def test_load_model_respects_no_optimize_for_local_model(monkeypatch):
|
||||
|
||||
cli.load_model(args)
|
||||
|
||||
assert calls["kwargs"]["device"] == "auto"
|
||||
assert calls["kwargs"]["optimize"] is False
|
||||
|
||||
|
||||
@@ -110,6 +112,7 @@ def test_load_model_defaults_optimize_for_hf(monkeypatch):
|
||||
|
||||
cli.load_model(args)
|
||||
|
||||
assert calls["kwargs"]["device"] == "auto"
|
||||
assert calls["kwargs"]["optimize"] is True
|
||||
|
||||
|
||||
@@ -136,9 +139,37 @@ def test_load_model_respects_no_optimize_for_hf(monkeypatch):
|
||||
|
||||
cli.load_model(args)
|
||||
|
||||
assert calls["kwargs"]["device"] == "auto"
|
||||
assert calls["kwargs"]["optimize"] is False
|
||||
|
||||
|
||||
def test_load_model_passes_explicit_device_to_hf(monkeypatch):
|
||||
calls = {}
|
||||
|
||||
class FakeVoxCPM:
|
||||
@classmethod
|
||||
def from_pretrained(cls, **kwargs):
|
||||
calls["kwargs"] = kwargs
|
||||
return DummyModel()
|
||||
|
||||
monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)
|
||||
args = cli._build_parser().parse_args(
|
||||
[
|
||||
"design",
|
||||
"--text",
|
||||
"hello",
|
||||
"--output",
|
||||
"out.wav",
|
||||
"--device",
|
||||
"mps",
|
||||
]
|
||||
)
|
||||
|
||||
cli.load_model(args)
|
||||
|
||||
assert calls["kwargs"]["device"] == "mps"
|
||||
|
||||
|
||||
def test_design_subcommand_applies_control(monkeypatch, tmp_path):
|
||||
dummy_model = DummyModel()
|
||||
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
|
||||
|
||||
def _load_module(name: str, path: Path):
|
||||
spec = importlib.util.spec_from_file_location(name, path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
sys.modules[name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def bootstrap_repo_modules(monkeypatch):
|
||||
for name, path in [
|
||||
("voxcpm", SRC / "voxcpm"),
|
||||
("voxcpm.model", SRC / "voxcpm" / "model"),
|
||||
("voxcpm.modules", SRC / "voxcpm" / "modules"),
|
||||
]:
|
||||
pkg = types.ModuleType(name)
|
||||
pkg.__path__ = [str(path)]
|
||||
monkeypatch.setitem(sys.modules, name, pkg)
|
||||
|
||||
hh = types.ModuleType("huggingface_hub")
|
||||
hh.snapshot_download = lambda *a, **k: "/tmp/fake"
|
||||
monkeypatch.setitem(sys.modules, "huggingface_hub", hh)
|
||||
|
||||
pydantic = types.ModuleType("pydantic")
|
||||
|
||||
class BaseModel:
|
||||
@classmethod
|
||||
def model_rebuild(cls):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def model_validate_json(cls, s):
|
||||
return cls()
|
||||
|
||||
def model_dump(self):
|
||||
return {}
|
||||
|
||||
pydantic.BaseModel = BaseModel
|
||||
monkeypatch.setitem(sys.modules, "pydantic", pydantic)
|
||||
|
||||
torchaudio = types.ModuleType("torchaudio")
|
||||
monkeypatch.setitem(sys.modules, "torchaudio", torchaudio)
|
||||
|
||||
librosa = types.ModuleType("librosa")
|
||||
librosa.effects = types.SimpleNamespace(trim=lambda *a, **k: (None, (0, 0)))
|
||||
monkeypatch.setitem(sys.modules, "librosa", librosa)
|
||||
|
||||
einops = types.ModuleType("einops")
|
||||
einops.rearrange = lambda x, *a, **k: x
|
||||
monkeypatch.setitem(sys.modules, "einops", einops)
|
||||
|
||||
tqdm_pkg = types.ModuleType("tqdm")
|
||||
tqdm_pkg.__path__ = ["/nonexistent"]
|
||||
tqdm_pkg.tqdm = lambda x, *a, **k: x
|
||||
monkeypatch.setitem(sys.modules, "tqdm", tqdm_pkg)
|
||||
|
||||
tqdm_auto = types.ModuleType("tqdm.auto")
|
||||
tqdm_auto.tqdm = lambda x, *a, **k: x
|
||||
monkeypatch.setitem(sys.modules, "tqdm.auto", tqdm_auto)
|
||||
|
||||
transformers = types.ModuleType("transformers")
|
||||
|
||||
class LlamaTokenizerFast:
|
||||
pass
|
||||
|
||||
class PreTrainedTokenizer:
|
||||
pass
|
||||
|
||||
transformers.LlamaTokenizerFast = LlamaTokenizerFast
|
||||
transformers.PreTrainedTokenizer = PreTrainedTokenizer
|
||||
monkeypatch.setitem(sys.modules, "transformers", transformers)
|
||||
|
||||
internal_mods = {
|
||||
"voxcpm.modules.audiovae": ["AudioVAE", "AudioVAEConfig", "AudioVAEV2", "AudioVAEConfigV2"],
|
||||
"voxcpm.modules.layers": ["ScalarQuantizationLayer"],
|
||||
"voxcpm.modules.locdit": ["CfmConfig", "UnifiedCFM", "VoxCPMLocDiT", "VoxCPMLocDiTV2"],
|
||||
"voxcpm.modules.locenc": ["VoxCPMLocEnc"],
|
||||
"voxcpm.modules.minicpm4": ["MiniCPM4Config", "MiniCPMModel"],
|
||||
"voxcpm.modules.layers.lora": ["apply_lora_to_named_linear_modules", "LoRALinear"],
|
||||
}
|
||||
for modname, names in internal_mods.items():
|
||||
module = types.ModuleType(modname)
|
||||
for name in names:
|
||||
if name == "apply_lora_to_named_linear_modules":
|
||||
setattr(module, name, lambda *a, **k: None)
|
||||
else:
|
||||
setattr(module, name, type(name, (), {}))
|
||||
monkeypatch.setitem(sys.modules, modname, module)
|
||||
|
||||
_load_module("voxcpm.model.utils", SRC / "voxcpm" / "model" / "utils.py")
|
||||
voxcpm = _load_module("voxcpm.model.voxcpm", SRC / "voxcpm" / "model" / "voxcpm.py")
|
||||
voxcpm2 = _load_module("voxcpm.model.voxcpm2", SRC / "voxcpm" / "model" / "voxcpm2.py")
|
||||
return voxcpm.VoxCPMModel, voxcpm2.VoxCPM2Model
|
||||
|
||||
|
||||
class DummyModel:
|
||||
device = "cpu"
|
||||
|
||||
def named_parameters(self):
|
||||
return []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("module_name", ["v1", "v2"])
|
||||
def test_load_lora_weights_accepts_tensor_only_legacy_checkpoints(monkeypatch, tmp_path, module_name):
|
||||
VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch)
|
||||
cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model
|
||||
|
||||
ckpt_path = tmp_path / "lora_weights.ckpt"
|
||||
torch.save({"state_dict": {"fake": torch.zeros(1)}}, ckpt_path)
|
||||
|
||||
loaded, skipped = cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")
|
||||
|
||||
assert loaded == []
|
||||
assert skipped == ["fake"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("module_name", ["v1", "v2"])
|
||||
def test_load_lora_weights_rejects_malicious_pickle_payloads(monkeypatch, tmp_path, module_name):
|
||||
VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch)
|
||||
cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model
|
||||
|
||||
ckpt_path = tmp_path / "lora_weights.ckpt"
|
||||
marker_path = tmp_path / f"{module_name}-marker.txt"
|
||||
|
||||
class Exploit:
|
||||
def __reduce__(self):
|
||||
import pathlib
|
||||
|
||||
return (pathlib.Path.write_text, (marker_path, f"{module_name} executed\n"))
|
||||
|
||||
torch.save({"state_dict": {"fake": torch.zeros(1)}, "boom": Exploit()}, ckpt_path)
|
||||
|
||||
with pytest.raises(Exception, match="Weights only load failed"):
|
||||
cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")
|
||||
|
||||
assert not marker_path.exists()
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
UTILS_PATH = ROOT / "src" / "voxcpm" / "model" / "utils.py"
|
||||
|
||||
transformers_stub = types.ModuleType("transformers")
|
||||
transformers_stub.PreTrainedTokenizer = object
|
||||
sys.modules.setdefault("transformers", transformers_stub)
|
||||
|
||||
spec = importlib.util.spec_from_file_location("voxcpm.model.utils", UTILS_PATH)
|
||||
utils = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(utils)
|
||||
|
||||
|
||||
def test_resolve_runtime_device_auto_falls_back_to_cpu(monkeypatch):
|
||||
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
|
||||
monkeypatch.setattr(utils, "_has_mps", lambda: False)
|
||||
|
||||
assert utils.resolve_runtime_device(None, "cuda") == "cpu"
|
||||
|
||||
|
||||
def test_resolve_runtime_device_auto_uses_mps_when_available(monkeypatch):
|
||||
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
|
||||
monkeypatch.setattr(utils, "_has_mps", lambda: True)
|
||||
|
||||
assert utils.resolve_runtime_device("auto", "cuda") == "mps"
|
||||
|
||||
|
||||
def test_resolve_runtime_device_respects_explicit_cpu(monkeypatch):
|
||||
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: True)
|
||||
monkeypatch.setattr(utils, "_has_mps", lambda: True)
|
||||
|
||||
assert utils.resolve_runtime_device("cpu", "cuda") == "cpu"
|
||||
|
||||
|
||||
def test_resolve_runtime_device_rejects_unavailable_explicit_cuda(monkeypatch):
|
||||
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
|
||||
monkeypatch.setattr(utils, "_has_mps", lambda: True)
|
||||
|
||||
with pytest.raises(ValueError, match="CUDA is not available"):
|
||||
utils.resolve_runtime_device("cuda:0", "cuda")
|
||||
@@ -0,0 +1,252 @@
|
||||
"""Tests for the training data validation module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
# Stub voxcpm package so imports work without full dependencies
|
||||
pkg = types.ModuleType("voxcpm")
|
||||
pkg.__path__ = [str(ROOT / "src" / "voxcpm")]
|
||||
sys.modules.setdefault("voxcpm", pkg)
|
||||
|
||||
training_pkg = types.ModuleType("voxcpm.training")
|
||||
training_pkg.__path__ = [str(ROOT / "src" / "voxcpm" / "training")]
|
||||
sys.modules.setdefault("voxcpm.training", training_pkg)
|
||||
|
||||
from voxcpm.training.validate import ValidationResult, validate_manifest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_dir():
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
yield Path(d)
|
||||
|
||||
|
||||
def _create_wav(path: Path, duration_s: float = 1.0, sr: int = 16000):
|
||||
"""Create a minimal valid WAV file."""
|
||||
try:
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
samples = int(duration_s * sr)
|
||||
data = np.zeros(samples, dtype=np.float32)
|
||||
sf.write(str(path), data, sr)
|
||||
except ImportError:
|
||||
# If soundfile is not available, create a minimal WAV header
|
||||
import struct
|
||||
|
||||
samples = int(duration_s * sr)
|
||||
data_size = samples * 2 # 16-bit PCM
|
||||
with open(path, "wb") as f:
|
||||
f.write(b"RIFF")
|
||||
f.write(struct.pack("<I", 36 + data_size))
|
||||
f.write(b"WAVEfmt ")
|
||||
f.write(struct.pack("<IHHIIHH", 16, 1, 1, sr, sr * 2, 2, 16))
|
||||
f.write(b"data")
|
||||
f.write(struct.pack("<I", data_size))
|
||||
f.write(b"\x00" * data_size)
|
||||
|
||||
|
||||
def _write_manifest(path: Path, entries: list[dict]):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
for entry in entries:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
class TestValidateManifest:
|
||||
def test_valid_manifest(self, tmp_dir):
|
||||
audio1 = tmp_dir / "audio1.wav"
|
||||
audio2 = tmp_dir / "audio2.wav"
|
||||
_create_wav(audio1, 2.0)
|
||||
_create_wav(audio2, 3.0)
|
||||
|
||||
manifest = tmp_dir / "train.jsonl"
|
||||
_write_manifest(
|
||||
manifest,
|
||||
[
|
||||
{"text": "Hello world", "audio": str(audio1)},
|
||||
{"text": "Goodbye world", "audio": str(audio2)},
|
||||
],
|
||||
)
|
||||
|
||||
result = validate_manifest(str(manifest))
|
||||
assert result.total_samples == 2
|
||||
assert result.valid_samples == 2
|
||||
assert result.is_valid
|
||||
assert len(result.errors) == 0
|
||||
|
||||
def test_missing_manifest(self):
|
||||
result = validate_manifest("/nonexistent/path.jsonl")
|
||||
assert not result.is_valid
|
||||
assert any("not found" in e for e in result.errors)
|
||||
|
||||
def test_empty_manifest(self, tmp_dir):
|
||||
manifest = tmp_dir / "empty.jsonl"
|
||||
manifest.write_text("")
|
||||
result = validate_manifest(str(manifest))
|
||||
assert not result.is_valid
|
||||
|
||||
def test_invalid_json(self, tmp_dir):
|
||||
manifest = tmp_dir / "bad.jsonl"
|
||||
manifest.write_text("not json\n{bad json}\n")
|
||||
result = validate_manifest(str(manifest))
|
||||
assert len(result.errors) >= 2
|
||||
assert any("Invalid JSON" in e for e in result.errors)
|
||||
|
||||
def test_missing_columns(self, tmp_dir):
|
||||
manifest = tmp_dir / "missing.jsonl"
|
||||
_write_manifest(
|
||||
manifest,
|
||||
[
|
||||
{"text": "hello"}, # missing audio
|
||||
{"audio": "test.wav"}, # missing text
|
||||
],
|
||||
)
|
||||
result = validate_manifest(str(manifest))
|
||||
assert len(result.errors) >= 2
|
||||
assert any("'audio'" in e for e in result.errors)
|
||||
assert any("'text'" in e for e in result.errors)
|
||||
|
||||
def test_missing_audio_file(self, tmp_dir):
|
||||
manifest = tmp_dir / "missing_audio.jsonl"
|
||||
_write_manifest(
|
||||
manifest,
|
||||
[{"text": "hello", "audio": "/nonexistent/audio.wav"}],
|
||||
)
|
||||
result = validate_manifest(str(manifest))
|
||||
assert not result.is_valid
|
||||
assert any("not found" in e for e in result.errors)
|
||||
|
||||
def test_empty_text_warning(self, tmp_dir):
|
||||
audio = tmp_dir / "audio.wav"
|
||||
_create_wav(audio)
|
||||
manifest = tmp_dir / "empty_text.jsonl"
|
||||
_write_manifest(
|
||||
manifest,
|
||||
[{"text": "", "audio": str(audio)}],
|
||||
)
|
||||
result = validate_manifest(str(manifest))
|
||||
assert len(result.warnings) > 0
|
||||
assert any("Empty" in w for w in result.warnings)
|
||||
|
||||
def test_relative_audio_path(self, tmp_dir):
|
||||
audio = tmp_dir / "audio.wav"
|
||||
_create_wav(audio)
|
||||
manifest = tmp_dir / "rel.jsonl"
|
||||
_write_manifest(
|
||||
manifest,
|
||||
[{"text": "hello", "audio": "audio.wav"}],
|
||||
)
|
||||
result = validate_manifest(str(manifest))
|
||||
assert result.valid_samples == 1
|
||||
assert result.is_valid
|
||||
|
||||
def test_max_samples_limit(self, tmp_dir):
|
||||
audio = tmp_dir / "audio.wav"
|
||||
_create_wav(audio)
|
||||
manifest = tmp_dir / "many.jsonl"
|
||||
_write_manifest(
|
||||
manifest,
|
||||
[{"text": f"sample {i}", "audio": str(audio)} for i in range(100)],
|
||||
)
|
||||
result = validate_manifest(str(manifest), max_samples=10)
|
||||
assert result.total_samples == 10
|
||||
|
||||
def test_ref_audio_counted(self, tmp_dir):
|
||||
audio = tmp_dir / "audio.wav"
|
||||
ref = tmp_dir / "ref.wav"
|
||||
_create_wav(audio)
|
||||
_create_wav(ref)
|
||||
manifest = tmp_dir / "ref.jsonl"
|
||||
_write_manifest(
|
||||
manifest,
|
||||
[{"text": "hello", "audio": str(audio), "ref_audio": str(ref)}],
|
||||
)
|
||||
result = validate_manifest(str(manifest))
|
||||
assert result.has_ref_audio == 1
|
||||
|
||||
def test_validation_result_properties(self):
|
||||
r = ValidationResult(total_samples=5, valid_samples=5)
|
||||
assert r.is_valid
|
||||
|
||||
r2 = ValidationResult(total_samples=5, valid_samples=5, errors=["err"])
|
||||
assert not r2.is_valid
|
||||
|
||||
r3 = ValidationResult(total_samples=0, valid_samples=0)
|
||||
assert not r3.is_valid
|
||||
|
||||
def test_invalid_audio_not_counted_as_valid(self, tmp_dir):
|
||||
"""A row with a bad audio path must not increment valid_samples."""
|
||||
manifest = tmp_dir / "bad_audio.jsonl"
|
||||
_write_manifest(
|
||||
manifest,
|
||||
[{"text": "hello", "audio": "/nonexistent/audio.wav"}],
|
||||
)
|
||||
result = validate_manifest(str(manifest))
|
||||
assert result.total_samples == 1
|
||||
assert result.valid_samples == 0
|
||||
assert not result.is_valid
|
||||
assert any("not found" in e for e in result.errors)
|
||||
|
||||
def test_sample_rate_mismatch(self, tmp_dir):
|
||||
"""A file with a different sample rate should be reported as an error."""
|
||||
try:
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
pytest.skip("soundfile not available")
|
||||
|
||||
audio = tmp_dir / "audio_8k.wav"
|
||||
import numpy as np
|
||||
samples = np.zeros(8000, dtype=np.float32)
|
||||
sf.write(str(audio), samples, 8000)
|
||||
|
||||
manifest = tmp_dir / "sr_mismatch.jsonl"
|
||||
_write_manifest(manifest, [{"text": "hello", "audio": str(audio)}])
|
||||
|
||||
result = validate_manifest(str(manifest), sample_rate=16000)
|
||||
assert result.valid_samples == 0
|
||||
assert not result.is_valid
|
||||
assert any("Sample rate mismatch" in e or "sample rate" in e.lower() for e in result.errors)
|
||||
|
||||
def test_mixed_ref_audio_warns_for_each_missing(self, tmp_dir):
|
||||
"""Missing ref_audio entries should each generate a warning independently."""
|
||||
audio = tmp_dir / "audio.wav"
|
||||
ref_good = tmp_dir / "ref_good.wav"
|
||||
_create_wav(audio)
|
||||
_create_wav(ref_good)
|
||||
|
||||
manifest = tmp_dir / "mixed_ref.jsonl"
|
||||
_write_manifest(
|
||||
manifest,
|
||||
[
|
||||
{"text": "row1", "audio": str(audio), "ref_audio": str(ref_good)},
|
||||
{"text": "row2", "audio": str(audio), "ref_audio": "/nonexistent/ref.wav"},
|
||||
],
|
||||
)
|
||||
result = validate_manifest(str(manifest))
|
||||
assert result.has_ref_audio == 1
|
||||
assert any("ref_audio file not found" in w for w in result.warnings)
|
||||
|
||||
def test_cli_validate_exit_code(self, tmp_dir):
|
||||
"""validate subcommand must exit 1 on validation error (missing audio)."""
|
||||
import subprocess
|
||||
manifest = tmp_dir / "bad.jsonl"
|
||||
_write_manifest(manifest, [{"text": "hi", "audio": "/nonexistent/x.wav"}])
|
||||
|
||||
proc = subprocess.run(
|
||||
[sys.executable, "-m", "voxcpm.cli", "validate", "--manifest", str(manifest)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
assert proc.returncode == 1, f"Expected exit 1, got {proc.returncode}"
|
||||
assert "FAILED" in proc.stderr or "Audio file not found" in proc.stderr
|
||||
Reference in New Issue
Block a user