update voxcpm2

This commit is contained in:
刘鑫
2026-03-31 11:50:37 +08:00
parent 23ed7ffeee
commit d9cf376e16
36 changed files with 8163 additions and 834 deletions
+194 -84
View File
@@ -7,7 +7,7 @@ project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "src"))
import contextlib
from typing import Dict, Optional
from typing import Dict
import argbind
import torch
@@ -17,16 +17,19 @@ from transformers import get_cosine_schedule_with_warmup
import signal
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
from safetensors.torch import save_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
from voxcpm.model import VoxCPMModel
import json
from voxcpm.model import VoxCPMModel, VoxCPM2Model
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import (
Accelerator,
@@ -61,15 +64,15 @@ def train(
lora: dict = None,
config_path: str = "",
# Distribution options (for LoRA checkpoints)
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
):
_ = config_path
# Validate distribution options
if lora is not None and distribute and not hf_model_id:
raise ValueError("hf_model_id is required when distribute=True")
accelerator = Accelerator(amp=True)
save_dir = Path(save_path)
@@ -84,7 +87,15 @@ def train(
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
# Auto-detect model architecture from config.json
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
_arch = json.load(_f).get("architecture", "voxcpm").lower()
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
if accelerator.rank == 0:
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
base_model = _model_cls.from_local(
pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None
)
tokenizer = base_model.text_tokenizer
train_ds, val_ds = load_audio_text_datasets(
@@ -166,7 +177,6 @@ def train(
unwrapped_model = accelerator.unwrap(model)
unwrapped_model.train()
# Only print param info on rank 0 to avoid cluttered output
if accelerator.rank == 0:
for name, param in model.named_parameters():
@@ -191,7 +201,7 @@ def train(
# All ranks load the same checkpoint to keep model and optimizer state in sync.
start_step = load_checkpoint(model, optimizer, scheduler, save_dir, rank=accelerator.rank)
accelerator.barrier()
if start_step > 0 and accelerator.rank == 0:
tracker.print(f"Resuming training from step {start_step}")
@@ -199,7 +209,19 @@ def train(
resume = {"step": start_step}
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
def _signal_handler(
signum,
frame,
_model=model,
_optim=optimizer,
_sched=scheduler,
_save_dir=save_dir,
_pretrained=pretrained_path,
_hf_id=hf_model_id,
_dist=distribute,
_resume=resume,
_rank=accelerator.rank,
):
try:
cur_step = int(_resume.get("step", start_step))
except Exception:
@@ -229,8 +251,8 @@ def train(
except StopIteration:
data_epoch += 1
# Key: set DistributedSampler epoch to ensure different data order each epoch
sampler = getattr(train_loader, 'sampler', None)
if hasattr(sampler, 'set_epoch'):
sampler = getattr(train_loader, "sampler", None)
if hasattr(sampler, "set_epoch"):
sampler.set_epoch(data_epoch)
train_iter = iter(train_loader)
return next(train_iter)
@@ -250,7 +272,7 @@ def train(
# Only sync gradients on the last micro-batch
# Use no_sync() for intermediate steps to reduce communication overhead
is_last_micro_step = (micro_step == grad_accum_steps - 1)
is_last_micro_step = micro_step == grad_accum_steps - 1
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
with sync_context:
@@ -299,10 +321,22 @@ def train(
tracker.log_metrics(loss_values, split="train")
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
valid_interval=valid_interval)
validate(
model,
val_loader,
batch_processor,
accelerator,
tracker,
lambdas,
writer=writer,
step=step,
val_ds=val_ds,
audio_vae=audio_vae_for_gen,
sample_rate=sample_rate,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
)
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
@@ -313,13 +347,26 @@ def train(
writer.close()
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
val_texts=None, tokenizer=None, valid_interval=1000):
def validate(
model,
val_loader,
batch_processor,
accelerator,
tracker,
lambdas,
writer=None,
step=0,
val_ds=None,
audio_vae=None,
sample_rate=22050,
val_texts=None,
tokenizer=None,
valid_interval=1000,
):
"""Validate and generate sample audio"""
import numpy as np
import numpy as np # noqa: F401
from collections import defaultdict
model.eval()
total_losses = []
sub_losses = defaultdict(list) # Track individual sub-losses
@@ -356,26 +403,37 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
# Compute mean total loss
mean_total_loss = torch.stack(total_losses).mean()
accelerator.all_reduce(mean_total_loss)
# Compute mean of each sub-loss
val_metrics = {"loss/total": mean_total_loss.item()}
for key, values in sub_losses.items():
mean_sub_loss = torch.stack(values).mean()
accelerator.all_reduce(mean_sub_loss)
val_metrics[key] = mean_sub_loss.item()
tracker.log_metrics(val_metrics, split="val")
# Generate sample audio for TensorBoard display
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
try:
generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
tracker=tracker)
generate_sample_audio(
model,
val_ds,
audio_vae,
writer,
step,
accelerator,
sample_rate,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
tracker=tracker,
)
except Exception as e:
tracker.print(f"[Warning] Failed to generate sample audio: {e}")
import traceback
import io
buf = io.StringIO()
traceback.print_exc(file=buf)
tracker.print(buf.getvalue())
@@ -390,7 +448,7 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
missing.append("audio_vae")
if missing and accelerator.rank == 0:
tracker.print(f"[Warning] Skip audio generation: missing {', '.join(missing)}")
model.train()
@@ -398,6 +456,7 @@ def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
"""Compute Mel Spectrogram (dB) using librosa"""
import numpy as np
import librosa
audio_np = audio_np.flatten().astype(np.float32)
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
return librosa.power_to_db(mel, ref=np.max)
@@ -408,31 +467,45 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
"""
import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import librosa.display
fmax = sample_rate // 2
step_str = f" @ Step {step}" if step is not None else ""
if ref_audio_np is not None and ref_mel is not None:
# Comparison mode: reference vs generated
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
img_ref = librosa.display.specshow(
ref_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_ref
)
ax_ref.set_title(
f"Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}",
fontsize=10,
fontweight="bold",
color="#28A745",
)
plt.colorbar(img_ref, ax=ax_ref, format="%+2.0f dB", pad=0.02)
img_gen = librosa.display.specshow(
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_gen
)
ax_gen.set_title(
f"Generated - {len(gen_audio_np)/sample_rate:.2f}s", fontsize=10, fontweight="bold", color="#DC3545"
)
plt.colorbar(img_gen, ax=ax_gen, format="%+2.0f dB", pad=0.02)
else:
# Single figure mode: show generated only
fig, ax = plt.subplots(figsize=(12, 4))
img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
img = librosa.display.specshow(
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax
)
ax.set_title(f"Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}", fontsize=11, fontweight="bold")
plt.colorbar(img, ax=ax, format="%+2.0f dB", pad=0.02)
plt.tight_layout()
return fig
@@ -440,26 +513,38 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
def normalize_audio(audio_np):
    """Peak-normalize audio so its largest absolute sample becomes 0.9.

    Args:
        audio_np: NumPy array of audio samples.

    Returns:
        ``audio_np / peak * 0.9`` where ``peak`` is the maximum absolute
        sample value. The input is returned unchanged when it is empty or
        when ``peak`` is not greater than zero (e.g. all-zero silence).
    """
    import numpy as np

    # A max() reduction over a zero-size array raises ValueError, so return
    # empty input as-is instead of crashing the caller.
    if np.size(audio_np) == 0:
        return audio_np

    max_val = np.abs(audio_np).max()
    # Skip scaling when the peak is zero to avoid dividing by zero.
    return audio_np / max_val * 0.9 if max_val > 0 else audio_np
def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
tracker=None):
def generate_sample_audio(
model,
val_ds,
audio_vae,
writer,
step,
accelerator,
sample_rate=22050,
val_texts=None,
tokenizer=None,
pretrained_path=None,
valid_interval=1000,
tracker=None,
):
"""Select 2 fixed validation samples, generate audio and log to TensorBoard"""
import numpy as np
log = tracker.print if tracker else print
num_samples = min(2, len(val_ds))
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
unwrapped_model = accelerator.unwrap(model)
for i in range(num_samples):
sample = val_ds[i]
text = val_texts[i] if val_texts and i < len(val_texts) else "Hello, this is a test."
# Load reference audio
ref_audio_np = None
try:
@@ -468,7 +553,10 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
ref_sr = sample["audio"].get("sampling_rate", sample_rate)
if ref_sr != sample_rate:
import torchaudio.functional as F
ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
ref_audio_np = (
F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
)
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
except Exception as e:
log(f"[Warning] Failed to load reference audio: {e}")
@@ -480,7 +568,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
unwrapped_model.eval()
# unwrapped_model.to(torch.bfloat16)
unwrapped_model.audio_vae = audio_vae.to(torch.float32)
log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
autocast_ctx = (
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
@@ -490,27 +578,33 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
with torch.no_grad():
with autocast_ctx:
generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
# Restore training setup
# unwrapped_model.to(torch.float32)
# unwrapped_model.audio_vae = None
if generated is None or len(generated) == 0:
log(f"[Warning] Generated audio is empty for sample {i}")
continue
# Process generated audio
gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
gen_audio_np = (
generated.cpu().float().numpy().flatten()
if isinstance(generated, torch.Tensor)
else np.array(generated, dtype=np.float32).flatten()
)
gen_audio_np = normalize_audio(gen_audio_np)
tag = f"val_sample_{i}"
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
# Log reference audio
if ref_audio_np is not None:
writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
writer.add_audio(
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
)
# Generate mel spectrogram figure
try:
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
@@ -520,10 +614,11 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
log(f"[Audio] Created mel spectrogram figure for sample {i}")
except Exception as e:
log(f"[Warning] Failed to create mel spectrogram: {e}")
except Exception as e:
log(f"[Warning] Failed to generate audio for sample {i}: {e}")
import traceback
traceback.print_exc()
finally:
@@ -545,30 +640,29 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
Called by all ranks so that distributed state stays aligned.
Returns the step number to resume from, or 0 if no checkpoint found.
"""
import json
latest_folder = save_dir / "latest"
if not latest_folder.exists():
return 0
unwrapped = model.module if hasattr(model, "module") else model
lora_cfg = unwrapped.lora_config
# Load model weights
if lora_cfg is not None:
# LoRA: load lora_weights
lora_weights_path = latest_folder / "lora_weights.safetensors"
if not lora_weights_path.exists():
lora_weights_path = latest_folder / "lora_weights.ckpt"
if lora_weights_path.exists():
if lora_weights_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(lora_weights_path))
else:
ckpt = torch.load(lora_weights_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False)
if rank == 0:
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
@@ -577,33 +671,34 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
model_path = latest_folder / "model.safetensors"
if not model_path.exists():
model_path = latest_folder / "pytorch_model.bin"
if model_path.exists():
if model_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(model_path))
else:
ckpt = torch.load(model_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False)
if rank == 0:
print(f"Loaded model weights from {model_path}", file=sys.stderr)
# Load optimizer state
optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
if rank == 0:
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
# Load scheduler state
scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
if rank == 0:
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
state_path = latest_folder / "training_state.json"
if state_path.exists():
with open(state_path, "r", encoding="utf-8") as f:
@@ -621,28 +716,36 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
if rank == 0:
print(f"Resuming from step {resume_step}", file=sys.stderr)
return resume_step
return 0
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
def save_checkpoint(
model,
optimizer,
scheduler,
save_dir: Path,
step: int,
pretrained_path: str = None,
hf_model_id: str = "",
distribute: bool = False,
):
"""
Save checkpoint with different strategies for full finetune vs LoRA:
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
"""
import json
import shutil
save_dir.mkdir(parents=True, exist_ok=True)
tag = f"step_{step:07d}"
folder = save_dir / tag
folder.mkdir(parents=True, exist_ok=True)
unwrapped = model.module if hasattr(model, "module") else model
full_state = unwrapped.state_dict()
lora_cfg = unwrapped.lora_config
if lora_cfg is not None:
# LoRA finetune: save only lora_A/lora_B weights
state_dict = {k: v for k, v in full_state.items() if "lora_" in k}
@@ -650,7 +753,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
save_file(state_dict, folder / "lora_weights.safetensors")
else:
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
# Save LoRA config and base model path to a separate JSON file
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
@@ -667,16 +770,23 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
save_file(state_dict, folder / "model.safetensors")
else:
torch.save({"state_dict": state_dict}, folder / "pytorch_model.bin")
# Copy config files from pretrained path
if pretrained_path:
pretrained_dir = Path(pretrained_path)
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
files_to_copy = [
"config.json",
"audiovae.pth",
"audiovae.safetensors",
"tokenizer.json",
"special_tokens_map.json",
"tokenizer_config.json",
]
for fname in files_to_copy:
src = pretrained_dir / fname
if src.exists():
shutil.copy2(src, folder / fname)
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
with open(folder / "training_state.json", "w", encoding="utf-8") as f: