4 Commits

Author SHA1 Message Date
Labmem-Zhouyx 68af4fe502 fix: ft log and setting 2026-04-08 18:15:17 +08:00
Labmem-Zhouyx ee3649c1b3 fix: streaming decode 2026-04-08 17:25:54 +08:00
Labmem-Zhouyx 82d77d445c fix: decode chunksize for audiovae_v2 2026-04-08 16:31:36 +08:00
Labmem-Zhouyx 8f95d13073 update readme: 30-language asr result on internal benchmark 2026-04-08 15:36:56 +08:00
9 changed files with 188 additions and 73 deletions
+47 -3
View File
@@ -238,8 +238,8 @@ voxcpm --help
### Web Demo ### Web Demo
```bash ```bash
python app.py # then open http://localhost:7860 python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, open http://localhost:8808
``` ```
### 🚢 Production Deployment (Nano-vLLM) ### 🚢 Production Deployment (Nano-vLLM)
@@ -415,10 +415,54 @@ VoxCPM2 achieves state-of-the-art or comparable results on public zero-shot and
</details> </details>
### Internal 30-Language ASR Benchmark
We additionally run an internal multilingual intelligibility benchmark with **30 languages × 500 samples**. ASR transcription is evaluated via **Gemini 3.1 Flash Lite API**.
<details>
<summary><b>Internal 30-Language ASR Benchmark (click to expand)</b></summary>
| Language | Metric | VoxCPM2 | Fish S2-Pro |
|---|---:|---:|---:|
| ar (Arabic) | CER | 1.23% | 0.30% |
| da (Danish) | WER | 2.70% | 3.52% |
| de (German) | WER | 0.96% | 0.64% |
| el (Greek) | WER | 3.17% | 4.61% |
| en (English) | WER | 0.42% | 1.03% |
| es (Spanish) | WER | 1.33% | 0.64% |
| fi (Finnish) | WER | 2.24% | 2.80% |
| fr (French) | WER | 2.16% | 2.34% |
| he (Hebrew) | CER | 2.98% | 15.27% |
| hi (Hindi) | CER | 0.79% | 0.91% |
| id (Indonesian) | WER | 1.36% | 1.68% |
| it (Italian) | WER | 1.65% | 1.08% |
| ja (Japanese) | CER | 2.40% | 1.82% |
| km (Khmer) | CER | 2.05% | 75.15% |
| ko (Korean) | CER | 0.95% | 0.29% |
| lo (Lao) | CER | 1.90% | 87.40% |
| ms (Malay) | WER | 1.75% | 1.41% |
| my (Burmese) | CER | 1.42% | 85.27% |
| nl (Dutch) | WER | 1.25% | 1.68% |
| no (Norwegian) | WER | 2.49% | 3.76% |
| pl (Polish) | WER | 1.90% | 1.65% |
| pt (Portuguese) | WER | 1.48% | 1.49% |
| ru (Russian) | WER | 0.90% | 0.86% |
| sv (Swedish) | WER | 2.22% | 2.63% |
| sw (Swahili) | CER | 1.07% | 2.02% |
| th (Thai) | CER | 0.94% | 1.92% |
| tl (Tagalog) | WER | 2.63% | 4.00% |
| tr (Turkish) | WER | 1.65% | 1.65% |
| vi (Vietnamese) | WER | 1.56% | 5.56% |
| zh (Chinese) | CER | 0.92% | 1.02% |
| Average (30 languages) | | **1.68%** | - |
</details>
### InstructTTSEval ### InstructTTSEval
<details> <details>
<summary><b>Instruction-Guided Voice Design Results</b></summary> <summary><b>Instruction-Guided Voice Design Results (click to expand)</b></summary>
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | | | Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|-------|:---:|:----:|:----:|:----:|:----:|:----:| |-------|:---:|:----:|:----:|:----:|:----:|:----:|
+46 -2
View File
@@ -238,7 +238,7 @@ voxcpm --help
### Web Demo ### Web Demo
```bash ```bash
python app.py # 然后打开 http://localhost:7860 python app.py --model-dir /path/to/VoxCPM2 --port 8808 # 指定本地模型路径,然后打开 http://localhost:8808
``` ```
### 🚢 生产部署(Nano-vLLM ### 🚢 生产部署(Nano-vLLM
@@ -414,10 +414,54 @@ VoxCPM2 在公开的零样本和可控 TTS 基准测试中取得了 SOTA 或可
</details> </details>
### Internal 30-Language ASR Benchmark
我们额外进行了内部多语言可懂度评测:**30 语种 × 500 样本**,ASR 转写评估使用 **Gemini 3.1 Flash Lite API**
<details>
<summary><b>内部30语种评测集ASR结果(点击展开)</b></summary>
| 语言 | 指标 | VoxCPM2 | Fish S2-Pro |
|---|---:|---:|---:|
| ar (阿拉伯语) | CER | 1.23% | 0.30% |
| da (丹麦语) | WER | 2.70% | 3.52% |
| de (德语) | WER | 0.96% | 0.64% |
| el (希腊语) | WER | 3.17% | 4.61% |
| en (英语) | WER | 0.42% | 1.03% |
| es (西班牙语) | WER | 1.33% | 0.64% |
| fi (芬兰语) | WER | 2.24% | 2.80% |
| fr (法语) | WER | 2.16% | 2.34% |
| he (希伯来语) | CER | 2.98% | 15.27% |
| hi (印地语) | CER | 0.79% | 0.91% |
| id (印尼语) | WER | 1.36% | 1.68% |
| it (意大利语) | WER | 1.65% | 1.08% |
| ja (日语) | CER | 2.40% | 1.82% |
| km (高棉语) | CER | 2.05% | 75.15% |
| ko (韩语) | CER | 0.95% | 0.29% |
| lo (老挝语) | CER | 1.90% | 87.40% |
| ms (马来语) | WER | 1.75% | 1.41% |
| my (缅甸语) | CER | 1.42% | 85.27% |
| nl (荷兰语) | WER | 1.25% | 1.68% |
| no (挪威语) | WER | 2.49% | 3.76% |
| pl (波兰语) | WER | 1.90% | 1.65% |
| pt (葡萄牙语) | WER | 1.48% | 1.49% |
| ru (俄语) | WER | 0.90% | 0.86% |
| sv (瑞典语) | WER | 2.22% | 2.63% |
| sw (斯瓦希里语) | CER | 1.07% | 2.02% |
| th (泰语) | CER | 0.94% | 1.92% |
| tl (菲律宾语) | WER | 2.63% | 4.00% |
| tr (土耳其语) | WER | 1.65% | 1.65% |
| vi (越南语) | WER | 1.56% | 5.56% |
| zh (中文) | CER | 0.92% | 1.02% |
| 平均(30 语种) | | **1.68%** | - |
</details>
### InstructTTSEval ### InstructTTSEval
<details> <details>
<summary><b>指令驱动音色设计结果</b></summary> <summary><b>指令驱动音色设计结果 (点击展开)</b></summary>
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | | | Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|-------|:---:|:----:|:----:|:----:|:----:|:----:| |-------|:---:|:----:|:----:|:----:|:----:|:----:|
+2 -1
View File
@@ -2,7 +2,7 @@ pretrained_path: /path/to/VoxCPM2/
train_manifest: /path/to/train.jsonl train_manifest: /path/to/train.jsonl
val_manifest: null val_manifest: null
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
batch_size: 2 batch_size: 2
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16 grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
num_workers: 8 num_workers: 8
@@ -15,6 +15,7 @@ weight_decay: 0.01
warmup_steps: 100 warmup_steps: 100
max_steps: 1000 max_steps: 1000
max_batch_tokens: 8192 max_batch_tokens: 8192
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
save_path: /path/to/checkpoints/finetune_all save_path: /path/to/checkpoints/finetune_all
tensorboard: /path/to/logs/finetune_all tensorboard: /path/to/logs/finetune_all
lambdas: lambdas:
+2 -1
View File
@@ -2,7 +2,7 @@ pretrained_path: /path/to/VoxCPM2/
train_manifest: /path/to/train.jsonl train_manifest: /path/to/train.jsonl
val_manifest: null val_manifest: null
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
batch_size: 2 batch_size: 2
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16 grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
num_workers: 8 num_workers: 8
@@ -15,6 +15,7 @@ weight_decay: 0.01
warmup_steps: 100 warmup_steps: 100
max_steps: 1000 max_steps: 1000
max_batch_tokens: 8192 max_batch_tokens: 8192
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
save_path: /path/to/checkpoints/finetune_lora save_path: /path/to/checkpoints/finetune_lora
tensorboard: /path/to/logs/finetune_lora tensorboard: /path/to/logs/finetune_lora
lambdas: lambdas:
+33 -8
View File
@@ -14,8 +14,10 @@ from typing import Optional
project_root = Path(__file__).parent project_root = Path(__file__).parent
sys.path.insert(0, str(project_root / "src")) sys.path.insert(0, str(project_root / "src"))
# Default pretrained model path relative to this repo # Default pretrained model path: prefer VoxCPM2 if it exists, fallback to VoxCPM1.5
default_pretrained_path = str(project_root / "models" / "openbmb__VoxCPM1.5") _v2_path = project_root / "models" / "openbmb__VoxCPM2"
_v15_path = project_root / "models" / "openbmb__VoxCPM1.5"
default_pretrained_path = str(_v2_path if _v2_path.exists() else _v15_path)
from voxcpm.core import VoxCPM from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig from voxcpm.model.voxcpm import LoRAConfig
@@ -368,6 +370,7 @@ def start_training(
warmup_steps=100, warmup_steps=100,
max_steps=None, max_steps=None,
sample_rate=44100, sample_rate=44100,
max_grad_norm=1.0,
# LoRA advanced # LoRA advanced
enable_lm=True, enable_lm=True,
enable_dit=True, enable_dit=True,
@@ -409,11 +412,25 @@ def start_training(
# Resolve max_steps default # Resolve max_steps default
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters) resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
# Auto-detect out_sample_rate from model config
out_sample_rate = 0
config_file = os.path.join(pretrained_path, "config.json")
if os.path.isfile(config_file):
try:
with open(config_file, "r", encoding="utf-8") as f:
cfg = json.load(f)
out_sr = cfg.get("audio_vae_config", {}).get("out_sample_rate")
if out_sr:
out_sample_rate = int(out_sr)
except Exception:
pass
config = { config = {
"pretrained_path": pretrained_path, "pretrained_path": pretrained_path,
"train_manifest": train_manifest, "train_manifest": train_manifest,
"val_manifest": val_manifest, "val_manifest": val_manifest,
"sample_rate": int(sample_rate), "sample_rate": int(sample_rate),
"out_sample_rate": out_sample_rate,
"batch_size": int(batch_size), "batch_size": int(batch_size),
"grad_accum_steps": int(grad_accum_steps), "grad_accum_steps": int(grad_accum_steps),
"num_workers": int(num_workers), "num_workers": int(num_workers),
@@ -425,6 +442,7 @@ def start_training(
"weight_decay": float(weight_decay), "weight_decay": float(weight_decay),
"warmup_steps": int(warmup_steps), "warmup_steps": int(warmup_steps),
"max_steps": resolved_max_steps, "max_steps": resolved_max_steps,
"max_grad_norm": float(max_grad_norm),
"save_path": checkpoints_dir, "save_path": checkpoints_dir,
"tensorboard": tensorboard_path if tensorboard_path else logs_dir, "tensorboard": tensorboard_path if tensorboard_path else logs_dir,
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0}, "lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
@@ -932,17 +950,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
with gr.Row(): with gr.Row():
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0) max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0) sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="") max_grad_norm = gr.Number(label="梯度裁剪 (max_grad_norm, 0=关闭)", value=1.0)
with gr.Row(): with gr.Row():
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True) enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True) enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
with gr.Row():
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False) enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
dropout = gr.Number(label="LoRA Dropout", value=0.0) dropout = gr.Number(label="LoRA Dropout", value=0.0)
gr.Markdown("#### 分发选项 (Distribution)") gr.Markdown("#### 分发选项 (Distribution)")
with gr.Row(): with gr.Row():
hf_model_id = gr.Textbox( hf_model_id = gr.Textbox(
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5" label="HuggingFace Model ID (e.g., openbmb/VoxCPM2)", value=""
) )
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False) distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
@@ -992,6 +1012,7 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
warmup_steps, warmup_steps,
max_steps, max_steps,
sample_rate, sample_rate,
max_grad_norm,
enable_lm, enable_lm,
enable_dit, enable_dit,
enable_proj, enable_proj,
@@ -1150,12 +1171,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
"warmup_steps": "warmup_steps", "warmup_steps": "warmup_steps",
"max_steps": "最大步数 (max_steps)", "max_steps": "最大步数 (max_steps)",
"sample_rate": "采样率 (sample_rate)", "sample_rate": "采样率 (sample_rate)",
"max_grad_norm": "梯度裁剪 (max_grad_norm, 0=关闭)",
"enable_lm": "启用 LoRA LM (enable_lm)", "enable_lm": "启用 LoRA LM (enable_lm)",
"enable_dit": "启用 LoRA DIT (enable_dit)", "enable_dit": "启用 LoRA DIT (enable_dit)",
"enable_proj": "启用投影 (enable_proj)", "enable_proj": "启用投影 (enable_proj)",
"dropout": "LoRA Dropout", "dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard 路径 (可选)", "tensorboard_path": "Tensorboard 路径 (可选)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", "hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
"distribute": "分发模式 (distribute)", "distribute": "分发模式 (distribute)",
} }
else: else:
@@ -1168,12 +1190,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
"warmup_steps": "Warmup Steps", "warmup_steps": "Warmup Steps",
"max_steps": "Max Steps", "max_steps": "Max Steps",
"sample_rate": "Sample Rate", "sample_rate": "Sample Rate",
"max_grad_norm": "Max Grad Norm (0=disabled)",
"enable_lm": "Enable LoRA LM", "enable_lm": "Enable LoRA LM",
"enable_dit": "Enable LoRA DIT", "enable_dit": "Enable LoRA DIT",
"enable_proj": "Enable Projection", "enable_proj": "Enable Projection",
"dropout": "LoRA Dropout", "dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard Path (Optional)", "tensorboard_path": "Tensorboard Path (Optional)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", "hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
"distribute": "Distribute Mode", "distribute": "Distribute Mode",
} }
@@ -1203,11 +1226,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
gr.update(label=adv["warmup_steps"]), gr.update(label=adv["warmup_steps"]),
gr.update(label=adv["max_steps"]), gr.update(label=adv["max_steps"]),
gr.update(label=adv["sample_rate"]), gr.update(label=adv["sample_rate"]),
gr.update(label=adv["max_grad_norm"]),
gr.update(label=adv["tensorboard_path"]),
gr.update(label=adv["enable_lm"]), gr.update(label=adv["enable_lm"]),
gr.update(label=adv["enable_dit"]), gr.update(label=adv["enable_dit"]),
gr.update(label=adv["enable_proj"]), gr.update(label=adv["enable_proj"]),
gr.update(label=adv["dropout"]), gr.update(label=adv["dropout"]),
gr.update(label=adv["tensorboard_path"]),
# Distribution options # Distribution options
gr.update(label=adv["hf_model_id"]), gr.update(label=adv["hf_model_id"]),
gr.update(label=adv["distribute"]), gr.update(label=adv["distribute"]),
@@ -1254,11 +1278,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
warmup_steps, warmup_steps,
max_steps, max_steps,
sample_rate, sample_rate,
max_grad_norm,
tensorboard_path,
enable_lm, enable_lm,
enable_dit, enable_dit,
enable_proj, enable_proj,
dropout, dropout,
tensorboard_path,
# distribution outputs # distribution outputs
hf_model_id, hf_model_id,
distribute, distribute,
+25 -11
View File
@@ -30,7 +30,8 @@ except ImportError:
import json import json
from voxcpm.model import VoxCPMModel, VoxCPM2Model from voxcpm.model import VoxCPMModel, VoxCPM2Model
from voxcpm.model.voxcpm import LoRAConfig from voxcpm.model.voxcpm import LoRAConfig as LoRAConfigV1
from voxcpm.model.voxcpm2 import LoRAConfig as LoRAConfigV2
from voxcpm.training import ( from voxcpm.training import (
Accelerator, Accelerator,
BatchProcessor, BatchProcessor,
@@ -46,7 +47,7 @@ def train(
train_manifest: str, train_manifest: str,
val_manifest: str = "", val_manifest: str = "",
sample_rate: int = 16_000, sample_rate: int = 16_000,
out_sample_rate: int = 0, # accepted from YAML for documentation; not used in training out_sample_rate: int = 0, # AudioVAE decoder output rate; used for TensorBoard audio logging
batch_size: int = 1, batch_size: int = 1,
grad_accum_steps: int = 1, grad_accum_steps: int = 1,
num_workers: int = 2, num_workers: int = 2,
@@ -64,12 +65,12 @@ def train(
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0}, lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
lora: dict = None, lora: dict = None,
config_path: str = "", config_path: str = "",
max_grad_norm: float = 0.0, # gradient clipping; 0 = disabled (backward compat)
# Distribution options (for LoRA checkpoints) # Distribution options (for LoRA checkpoints)
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5") hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
): ):
_ = config_path _ = config_path
_ = out_sample_rate
# Validate distribution options # Validate distribution options
if lora is not None and distribute and not hf_model_id: if lora is not None and distribute and not hf_model_id:
@@ -93,6 +94,7 @@ def train(
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f: with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
_arch = json.load(_f).get("architecture", "voxcpm").lower() _arch = json.load(_f).get("architecture", "voxcpm").lower()
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel _model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
LoRAConfig = LoRAConfigV2 if _arch == "voxcpm2" else LoRAConfigV1
if accelerator.rank == 0: if accelerator.rank == 0:
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr) print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
base_model = _model_cls.from_local( base_model = _model_cls.from_local(
@@ -178,8 +180,12 @@ def train(
dataset_cnt=dataset_cnt, dataset_cnt=dataset_cnt,
device=accelerator.device, device=accelerator.device,
) )
# Save audio_vae for audio generation # Save audio_vae and output sample rate for audio generation.
# Prefer model's actual output rate; fall back to YAML out_sample_rate or encode rate.
audio_vae_for_gen = base_model.audio_vae audio_vae_for_gen = base_model.audio_vae
out_sr = base_model.sample_rate # decoder output rate (e.g. 48000 for V2)
if out_sr == 0 and out_sample_rate > 0:
out_sr = out_sample_rate
del base_model.audio_vae del base_model.audio_vae
model = accelerator.prepare_model(base_model) model = accelerator.prepare_model(base_model)
unwrapped_model = accelerator.unwrap(model) unwrapped_model = accelerator.unwrap(model)
@@ -312,8 +318,8 @@ def train(
scaler = getattr(accelerator, "scaler", None) scaler = getattr(accelerator, "scaler", None)
if scaler is not None: if scaler is not None:
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
# Use large max_norm to only compute grad_norm without actual clipping effective_max_norm = max_grad_norm if max_grad_norm > 0 else 1e9
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9) grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=effective_max_norm)
accelerator.step(optimizer) accelerator.step(optimizer)
accelerator.update() accelerator.update()
@@ -341,6 +347,7 @@ def train(
val_ds=val_ds, val_ds=val_ds,
audio_vae=audio_vae_for_gen, audio_vae=audio_vae_for_gen,
sample_rate=sample_rate, sample_rate=sample_rate,
out_sample_rate=out_sr,
val_texts=val_texts, val_texts=val_texts,
tokenizer=tokenizer, tokenizer=tokenizer,
valid_interval=valid_interval, valid_interval=valid_interval,
@@ -367,6 +374,7 @@ def validate(
val_ds=None, val_ds=None,
audio_vae=None, audio_vae=None,
sample_rate=22050, sample_rate=22050,
out_sample_rate=0,
val_texts=None, val_texts=None,
tokenizer=None, tokenizer=None,
valid_interval=1000, valid_interval=1000,
@@ -432,6 +440,7 @@ def validate(
step, step,
accelerator, accelerator,
sample_rate, sample_rate,
out_sample_rate=out_sample_rate,
val_texts=val_texts, val_texts=val_texts,
tokenizer=tokenizer, tokenizer=tokenizer,
valid_interval=valid_interval, valid_interval=valid_interval,
@@ -534,6 +543,7 @@ def generate_sample_audio(
step, step,
accelerator, accelerator,
sample_rate=22050, sample_rate=22050,
out_sample_rate=0,
val_texts=None, val_texts=None,
tokenizer=None, tokenizer=None,
pretrained_path=None, pretrained_path=None,
@@ -548,6 +558,10 @@ def generate_sample_audio(
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}") log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
unwrapped_model = accelerator.unwrap(model) unwrapped_model = accelerator.unwrap(model)
# Determine the correct output sample rate for generated audio.
# out_sample_rate is the decoder output rate (e.g. 48kHz for V2);
# sample_rate is the encoder input rate (e.g. 16kHz for V2).
gen_sr = out_sample_rate if out_sample_rate > 0 else sample_rate
for i in range(num_samples): for i in range(num_samples):
sample = val_ds[i] sample = val_ds[i]
@@ -604,10 +618,10 @@ def generate_sample_audio(
gen_audio_np = normalize_audio(gen_audio_np) gen_audio_np = normalize_audio(gen_audio_np)
tag = f"val_sample_{i}" tag = f"val_sample_{i}"
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate) writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=gen_sr)
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s") log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/gen_sr:.2f}s")
# Log reference audio # Log reference audio (at encoder input rate, which is what val_ds provides)
if ref_audio_np is not None: if ref_audio_np is not None:
writer.add_audio( writer.add_audio(
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
@@ -615,9 +629,9 @@ def generate_sample_audio(
# Generate mel spectrogram figure # Generate mel spectrogram figure
try: try:
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate) mel_gen = compute_mel_spectrogram(gen_audio_np, gen_sr)
mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref) fig = create_mel_figure(gen_audio_np, mel_gen, gen_sr, step, ref_audio_np, mel_ref)
writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step) writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
log(f"[Audio] Created mel spectrogram figure for sample {i}") log(f"[Audio] Created mel spectrogram figure for sample {i}")
except Exception as e: except Exception as e:
+31 -46
View File
@@ -48,25 +48,8 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens from .utils import get_dtype, mask_multichar_chinese_tokens
def _trim_audio_silence_vad( # A simple function to trim audio silence using VAD, not used default
audio: torch.Tensor, def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
sample_rate: int,
max_silence_ms: float = 200.0,
top_db: float = 35.0,
) -> torch.Tensor:
"""使用能量阈值(VAD 方式)截取首尾静音及尾部长段伪静音,首尾各最多保留 max_silence_ms 毫秒静音。
会同时截掉末尾的长段伪静音(低能量但非完全静音的段落,如长时间底噪)。
Args:
audio: (1, T) 的音频 tensor
sample_rate: 采样率
max_silence_ms: 首尾允许保留的最大静音长度(毫秒)
top_db: 低于参考电平多少 dB 视为静音
Returns:
截取后的 (1, T') tensor
"""
if audio.numel() == 0: if audio.numel() == 0:
return audio return audio
y = audio.squeeze(0).numpy() y = audio.squeeze(0).numpy()
@@ -85,7 +68,7 @@ def _trim_audio_silence_vad(
except Exception: except Exception:
start, end = 0, n start, end = 0, n
# 用逐帧 RMS 找「最后一段有持续能量的位置」,截掉末尾长伪静音(低能量底噪等) # Find the last frame with continuous energy, trim the long pseudo-silence at the end (low energy background noise, etc.)
n_frames = max(0, (n - frame_length) // hop_length + 1) n_frames = max(0, (n - frame_length) // hop_length + 1)
last_voice_frame = -1 last_voice_frame = -1
for j in range(n_frames): for j in range(n_frames):
@@ -246,6 +229,7 @@ class VoxCPM2Model(nn.Module):
# Audio VAE # Audio VAE
self.audio_vae = audio_vae self.audio_vae = audio_vae
self.chunk_size = audio_vae.chunk_size self.chunk_size = audio_vae.chunk_size
self._decode_chunk_size = getattr(audio_vae, "decode_chunk_size", audio_vae.chunk_size)
self._encode_sample_rate = audio_vae.sample_rate self._encode_sample_rate = audio_vae.sample_rate
self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate) self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)
@@ -382,11 +366,7 @@ class VoxCPM2Model(nn.Module):
mu=dit_hidden, mu=dit_hidden,
patch_size=self.patch_size, patch_size=self.patch_size,
cond=feat_cond_for_sample, cond=feat_cond_for_sample,
n_timesteps=( n_timesteps=10,
self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10
),
) )
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size) feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
@@ -656,14 +636,14 @@ class VoxCPM2Model(nn.Module):
streaming_prefix_len=streaming_prefix_len, streaming_prefix_len=streaming_prefix_len,
) )
if streaming: if streaming:
patch_len = self.patch_size * self.chunk_size decode_patch_len = self.patch_size * self._decode_chunk_size
for latent_pred, _ in inference_result: for latent_pred, _, _ctx in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
yield decode_audio yield decode_audio
break break
else: else:
latent_pred, pred_audio_feat = next(inference_result) latent_pred, pred_audio_feat, context_len = next(inference_result)
if retry_badcase: if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print( print(
@@ -679,10 +659,9 @@ class VoxCPM2Model(nn.Module):
if not streaming: if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size decode_patch_len = self.patch_size * self._decode_chunk_size
has_continuation = bool(prompt_wav_path) if context_len > 0:
if has_continuation: decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
else: else:
decode_audio = decode_audio.squeeze(1).cpu() decode_audio = decode_audio.squeeze(1).cpu()
yield decode_audio yield decode_audio
@@ -944,14 +923,14 @@ class VoxCPM2Model(nn.Module):
streaming_prefix_len=streaming_prefix_len, streaming_prefix_len=streaming_prefix_len,
) )
if streaming: if streaming:
patch_len = self.patch_size * self.chunk_size decode_patch_len = self.patch_size * self._decode_chunk_size
for latent_pred, pred_audio_feat in inference_result: for latent_pred, pred_audio_feat, _ctx in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
yield (decode_audio, target_text_token, pred_audio_feat) yield (decode_audio, target_text_token, pred_audio_feat)
break break
else: else:
latent_pred, pred_audio_feat = next(inference_result) latent_pred, pred_audio_feat, context_len = next(inference_result)
if retry_badcase: if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print( print(
@@ -966,18 +945,20 @@ class VoxCPM2Model(nn.Module):
break break
if not streaming: if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size decode_patch_len = self.patch_size * self._decode_chunk_size
if mode in ("continuation", "ref_continuation"): if context_len > 0:
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu() decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
else: else:
decode_audio = decode_audio[..., :].squeeze(1).cpu() decode_audio = decode_audio.squeeze(1).cpu()
yield (decode_audio, target_text_token, pred_audio_feat) yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs)) feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs))
return feat_pred, generated_feat
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs) for feat_pred, pred_feat_seq, _ in self._inference(*args, streaming=True, **kwargs):
yield feat_pred, pred_feat_seq
@torch.inference_mode() @torch.inference_mode()
def _inference( def _inference(
@@ -1036,6 +1017,7 @@ class VoxCPM2Model(nn.Module):
# trailing audio patches as initial context so the VAE can decode smoothly. # trailing audio patches as initial context so the VAE can decode smoothly.
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch. # - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
has_continuation_audio = feat_mask[0, -1].item() == 1 has_continuation_audio = feat_mask[0, -1].item() == 1
context_len = 0
if has_continuation_audio: if has_continuation_audio:
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0] audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
context_len = min(streaming_prefix_len - 1, len(audio_indices)) context_len = min(streaming_prefix_len - 1, len(audio_indices))
@@ -1085,11 +1067,13 @@ class VoxCPM2Model(nn.Module):
prefix_feat_cond = pred_feat prefix_feat_cond = pred_feat
if streaming: if streaming:
# return the last three predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1) pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size) feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq yield feat_pred, pred_feat_seq, context_len
if len(pred_feat_seq) > streaming_prefix_len:
pred_feat_seq = pred_feat_seq[-streaming_prefix_len:]
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item() stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1: if i > min_len and stop_flag == 1:
@@ -1108,7 +1092,8 @@ class VoxCPM2Model(nn.Module):
if not streaming: if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size) feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu() generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu()
yield feat_pred, generated_feat, context_len
@classmethod @classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None): def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
@@ -436,6 +436,7 @@ class AudioVAE(nn.Module):
self.out_sample_rate = out_sample_rate self.out_sample_rate = out_sample_rate
self.sr_bin_boundaries = sr_bin_boundaries self.sr_bin_boundaries = sr_bin_boundaries
self.chunk_size = math.prod(encoder_rates) self.chunk_size = math.prod(encoder_rates)
self.decode_chunk_size = math.prod(decoder_rates)
def preprocess(self, audio_data, sample_rate): def preprocess(self, audio_data, sample_rate):
if sample_rate is None: if sample_rate is None:
+1 -1
View File
@@ -225,7 +225,7 @@ class UnifiedCFM(torch.nn.Module):
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1) losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
if tgt_mask is not None: if tgt_mask is not None:
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1)) weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
loss = (weights * losses).sum() / torch.sum(tgt_mask) loss = (weights * losses).sum() / torch.clamp(torch.sum(tgt_mask), min=1.0)
else: else:
loss = losses.mean() loss = losses.mean()