update voxcpm2
This commit is contained in:
@@ -25,4 +25,3 @@ __all__ = [
|
||||
"load_audio_text_datasets",
|
||||
"build_dataloader",
|
||||
]
|
||||
|
||||
|
||||
@@ -47,9 +47,7 @@ class Accelerator:
|
||||
pass
|
||||
|
||||
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
|
||||
self.device_ctx = (
|
||||
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
)
|
||||
self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
self._ddp_model = None # For no_sync support
|
||||
|
||||
def _set_seed(self, seed: int):
|
||||
@@ -84,7 +82,7 @@ class Accelerator:
|
||||
# Model helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
||||
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
|
||||
if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
|
||||
model.device = self.device
|
||||
model = model.to(self.device)
|
||||
if self.world_size > 1:
|
||||
@@ -163,4 +161,3 @@ class Accelerator:
|
||||
@staticmethod
|
||||
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
|
||||
return model.module if hasattr(model, "module") else model
|
||||
|
||||
|
||||
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
|
||||
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
|
||||
cli_args.update(yaml_args)
|
||||
return cli_args
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import argbind
|
||||
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
@@ -36,7 +34,7 @@ def load_audio_text_datasets(
|
||||
def prepare(ds: Dataset) -> Dataset:
|
||||
if audio_column not in ds.column_names:
|
||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# but for length calculation we might need raw path or duration if available.
|
||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||
@@ -70,13 +68,13 @@ def compute_sample_lengths(
|
||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||
t_seq = ceil(t_vae / patch_size)
|
||||
- 序列总长约为: text_len + t_seq + 2
|
||||
|
||||
|
||||
Optimized: Use batch column access instead of iterating item by item.
|
||||
"""
|
||||
# Batch access columns - much faster than per-item access
|
||||
text_ids_list = ds["text_ids"]
|
||||
text_lens = [len(t) for t in text_ids_list]
|
||||
|
||||
|
||||
has_duration = "duration" in ds.column_names
|
||||
if has_duration:
|
||||
durations = ds["duration"]
|
||||
@@ -86,7 +84,7 @@ def compute_sample_lengths(
|
||||
for i in range(len(ds)):
|
||||
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||
|
||||
|
||||
# Vectorized length computation
|
||||
lengths = []
|
||||
for text_len, duration in zip(text_lens, durations):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -15,7 +14,7 @@ class AudioFeatureProcessingPacker:
|
||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||
self.audio_start_id = 101
|
||||
self.audio_end_id = 102
|
||||
# unused now
|
||||
# unused now
|
||||
self.audio_prompt_start_id = 103
|
||||
self.audio_prompt_end_id = 104
|
||||
self.text_eos_token_id = 2
|
||||
@@ -147,31 +146,26 @@ class AudioFeatureProcessingPacker:
|
||||
|
||||
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
|
||||
if x.size(0) >= max_len:
|
||||
return x[: max_len]
|
||||
return x[:max_len]
|
||||
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=0)
|
||||
|
||||
def pad_3d(x: torch.Tensor) -> torch.Tensor:
|
||||
# x: [T, P, D]
|
||||
if x.size(0) >= max_len:
|
||||
return x[: max_len]
|
||||
pad = torch.zeros(
|
||||
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
|
||||
)
|
||||
return x[:max_len]
|
||||
pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=0)
|
||||
|
||||
if lengths:
|
||||
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
|
||||
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
|
||||
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
|
||||
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
|
||||
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
|
||||
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
|
||||
audio_task_ids_batch = torch.stack(
|
||||
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
|
||||
)
|
||||
audio_dataset_ids_batch = torch.stack(
|
||||
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
|
||||
)
|
||||
labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
|
||||
audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
|
||||
audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
|
||||
|
||||
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
|
||||
position_ids_list = []
|
||||
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
|
||||
)
|
||||
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
|
||||
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
|
||||
text_token.device
|
||||
text_mask = (
|
||||
torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
audio_mask = (
|
||||
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
loss_mask = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(text_length),
|
||||
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
|
||||
torch.zeros(1),
|
||||
]
|
||||
)
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
|
||||
torch.int32
|
||||
).to(text_token.device)
|
||||
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
|
||||
|
||||
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
|
||||
labels[-2] = 1
|
||||
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,4 +18,3 @@ class TrainingState:
|
||||
val_loader: object
|
||||
tracker: object
|
||||
batch_processor: object
|
||||
|
||||
|
||||
@@ -76,4 +76,3 @@ class TrainingTracker:
|
||||
@contextlib.contextmanager
|
||||
def live(self):
|
||||
yield
|
||||
|
||||
|
||||
Reference in New Issue
Block a user