update voxcpm2

This commit is contained in:
刘鑫
2026-03-31 11:50:37 +08:00
parent 23ed7ffeee
commit d9cf376e16
36 changed files with 8163 additions and 834 deletions
-1
View File
@@ -25,4 +25,3 @@ __all__ = [
"load_audio_text_datasets",
"build_dataloader",
]
+2 -5
View File
@@ -47,9 +47,7 @@ class Accelerator:
pass
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
self.device_ctx = (
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
self._ddp_model = None # For no_sync support
def _set_seed(self, seed: int):
@@ -84,7 +82,7 @@ class Accelerator:
# Model helpers
# ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs):
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
model.device = self.device
model = model.to(self.device)
if self.world_size > 1:
@@ -163,4 +161,3 @@ class Accelerator:
@staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
    """Return ``model.module`` when that attribute exists, otherwise ``model``.

    Args:
        model: The object to unwrap.

    Returns:
        The value of ``model.module`` if ``model`` has a ``module``
        attribute; otherwise ``model`` itself, unchanged.
    """
    # Guard clause: only objects exposing ``module`` are unwrapped.
    if hasattr(model, "module"):
        return model.module
    return model
-2
View File
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args)
return cli_args
+4 -6
View File
@@ -1,5 +1,4 @@
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import argbind
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
from ..modules.audiovae import AudioVAE
from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_ID_COLUMN = "dataset_id"
@@ -36,7 +34,7 @@ def load_audio_text_datasets(
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
@@ -70,13 +68,13 @@ def compute_sample_lengths(
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
t_seq = ceil(t_vae / patch_size)
- 序列总长约为: text_len + t_seq + 2
Optimized: Use batch column access instead of iterating item by item.
"""
# Batch access columns - much faster than per-item access
text_ids_list = ds["text_ids"]
text_lens = [len(t) for t in text_ids_list]
has_duration = "duration" in ds.column_names
if has_duration:
durations = ds["duration"]
@@ -86,7 +84,7 @@ def compute_sample_lengths(
for i in range(len(ds)):
audio = ds[i][DEFAULT_AUDIO_COLUMN]
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
# Vectorized length computation
lengths = []
for text_len, duration in zip(text_lens, durations):
+29 -22
View File
@@ -1,5 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List
import torch
import torch.nn as nn
@@ -15,7 +14,7 @@ class AudioFeatureProcessingPacker:
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101
self.audio_end_id = 102
# unused now
# unused now
self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104
self.text_eos_token_id = 2
@@ -147,31 +146,26 @@ class AudioFeatureProcessingPacker:
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
if x.size(0) >= max_len:
return x[: max_len]
return x[:max_len]
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
def pad_3d(x: torch.Tensor) -> torch.Tensor:
# x: [T, P, D]
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.zeros(
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
)
return x[:max_len]
pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
if lengths:
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
audio_task_ids_batch = torch.stack(
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
)
audio_dataset_ids_batch = torch.stack(
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
)
labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
position_ids_list = []
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
)
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
text_token.device
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
.type(torch.int32)
.to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
.type(torch.int32)
.to(text_token.device)
)
loss_mask = (
torch.cat(
[
torch.zeros(text_length),
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
torch.zeros(1),
]
)
.type(torch.int32)
.to(text_token.device)
)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
torch.int32
).to(text_token.device)
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
labels[-2] = 1
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
audio_duration,
text_token_count,
)
-1
View File
@@ -18,4 +18,3 @@ class TrainingState:
val_loader: object
tracker: object
batch_processor: object
-1
View File
@@ -76,4 +76,3 @@ class TrainingTracker:
@contextlib.contextmanager
def live(self):
    """Context manager that yields once with no setup or teardown.

    The value bound by ``with ... as`` is ``None``.
    """
    yield None