Update: VoxCPM1.5 and fine-tuning supprt
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Training utilities for VoxCPM fine-tuning.
|
||||
|
||||
This package mirrors the training mechanics used in the minicpm-audio
|
||||
tooling while relying solely on local audio-text datasets managed via
|
||||
the HuggingFace ``datasets`` library.
|
||||
"""
|
||||
|
||||
from .accelerator import Accelerator
|
||||
from .tracker import TrainingTracker
|
||||
from .data import (
|
||||
load_audio_text_datasets,
|
||||
HFVoxCPMDataset,
|
||||
build_dataloader,
|
||||
BatchProcessor,
|
||||
)
|
||||
from .state import TrainingState
|
||||
|
||||
__all__ = [
|
||||
"Accelerator",
|
||||
"TrainingTracker",
|
||||
"HFVoxCPMDataset",
|
||||
"BatchProcessor",
|
||||
"TrainingState",
|
||||
"load_audio_text_datasets",
|
||||
"build_dataloader",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
import typing
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
|
||||
class Accelerator:
|
||||
"""
|
||||
Simplified accelerator that mirrors the behaviour of the minicpm-audio
|
||||
training utilities. It initializes a distributed process group when
|
||||
``torchrun`` is used and exposes helpers for AMP, gradient scaling and
|
||||
preparing models/dataloaders for DDP.
|
||||
"""
|
||||
|
||||
def __init__(self, amp: bool = False, seed: int = 42):
|
||||
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
if self.world_size > 1 and not dist.is_initialized():
|
||||
dist.init_process_group("nccl", init_method="env://")
|
||||
|
||||
self.rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||
self.amp = amp
|
||||
|
||||
# Set random seed to ensure model initialization consistency
|
||||
self._set_seed(seed)
|
||||
|
||||
class DummyScaler:
|
||||
def step(self, optimizer):
|
||||
optimizer.step()
|
||||
|
||||
def scale(self, loss):
|
||||
return loss
|
||||
|
||||
def unscale_(self, optimizer):
|
||||
return optimizer
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
|
||||
self.device_ctx = (
|
||||
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
)
|
||||
self._ddp_model = None # For no_sync support
|
||||
|
||||
def _set_seed(self, seed: int):
|
||||
"""Set random seed to ensure model initialization consistency across multiple GPUs"""
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def __enter__(self):
|
||||
if self.device_ctx is not None:
|
||||
self.device_ctx.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if self.device_ctx is not None:
|
||||
self.device_ctx.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
def barrier(self):
|
||||
"""Synchronize all processes"""
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.AVG):
|
||||
"""All-reduce tensor across processes"""
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(tensor, op=op)
|
||||
return tensor
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Model helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
||||
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
|
||||
model.device = self.device
|
||||
model = model.to(self.device)
|
||||
if self.world_size > 1:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
|
||||
self._ddp_model = model # Save DDP model reference for no_sync support
|
||||
return model
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_sync(self):
|
||||
"""
|
||||
Context manager to skip gradient synchronization during gradient accumulation.
|
||||
Only used outside the last micro-batch.
|
||||
"""
|
||||
if self._ddp_model is not None:
|
||||
with self._ddp_model.no_sync():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda", self.local_rank)
|
||||
if torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
return torch.device("cpu")
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# AMP helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def autocast(self, *args, **kwargs):
|
||||
return torch.amp.autocast("cuda", enabled=self.amp, *args, **kwargs)
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
def step(self, optimizer: torch.optim.Optimizer):
|
||||
self.scaler.step(optimizer)
|
||||
|
||||
def update(self):
|
||||
self.scaler.update()
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Data helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_dataloader(
|
||||
self,
|
||||
dataset: typing.Iterable,
|
||||
*,
|
||||
batch_size: int,
|
||||
num_workers: int = 0,
|
||||
shuffle: bool = True,
|
||||
collate_fn=None,
|
||||
drop_last: bool = False,
|
||||
) -> torch.utils.data.DataLoader:
|
||||
if self.world_size > 1:
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
|
||||
)
|
||||
shuffle = False
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle if sampler is None else False,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
drop_last=drop_last,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
|
||||
return model.module if hasattr(model, "module") else model
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argbind
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def load_yaml_config(path: str | Path) -> Dict[str, Any]:
|
||||
"""
|
||||
Load a YAML configuration file into a dictionary suitable for argbind.
|
||||
"""
|
||||
path = Path(path)
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Configuration file {path} must contain a top-level mapping.")
|
||||
return data
|
||||
|
||||
|
||||
def parse_args_with_config(config_path: str | Path | None = None):
|
||||
"""
|
||||
Helper to unify CLI arguments and YAML configuration.
|
||||
|
||||
Usage mirrors minicpm-audio:
|
||||
args = parse_args_with_config("conf/voxcpm/finetune.yml")
|
||||
with argbind.scope(args):
|
||||
...
|
||||
"""
|
||||
cli_args = argbind.parse_args()
|
||||
if config_path is None:
|
||||
return cli_args
|
||||
|
||||
yaml_args = load_yaml_config(config_path)
|
||||
with argbind.scope(cli_args):
|
||||
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
|
||||
cli_args.update(yaml_args)
|
||||
return cli_args
|
||||
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import argbind
|
||||
import torch
|
||||
from datasets import Audio, Dataset, DatasetDict, load_dataset
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
|
||||
from ..model.voxcpm import VoxCPMConfig
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
|
||||
|
||||
@argbind.bind()
|
||||
def load_audio_text_datasets(
|
||||
train_manifest: str,
|
||||
val_manifest: str = "",
|
||||
text_column: str = DEFAULT_TEXT_COLUMN,
|
||||
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
||||
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
||||
sample_rate: int = 16_000,
|
||||
num_proc: int = 1,
|
||||
) -> Tuple[Dataset, Optional[Dataset]]:
|
||||
data_files = {"train": train_manifest}
|
||||
if val_manifest:
|
||||
data_files["validation"] = val_manifest
|
||||
|
||||
dataset_dict: DatasetDict = load_dataset("json", data_files=data_files)
|
||||
|
||||
def prepare(ds: Dataset) -> Dataset:
|
||||
if audio_column not in ds.column_names:
|
||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# but for length calculation we might need raw path or duration if available.
|
||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||
if audio_column != DEFAULT_AUDIO_COLUMN:
|
||||
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
||||
if text_column != DEFAULT_TEXT_COLUMN:
|
||||
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
||||
if dataset_id_column and dataset_id_column in ds.column_names:
|
||||
if dataset_id_column != DEFAULT_ID_COLUMN:
|
||||
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
||||
else:
|
||||
ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
|
||||
return ds
|
||||
|
||||
train_ds = prepare(dataset_dict["train"])
|
||||
val_ds = prepare(dataset_dict["validation"]) if "validation" in dataset_dict else None
|
||||
return train_ds, val_ds
|
||||
|
||||
|
||||
def compute_sample_lengths(
|
||||
ds: Dataset,
|
||||
audio_vae_fps: int = 25,
|
||||
patch_size: int = 1,
|
||||
) -> List[int]:
|
||||
"""
|
||||
预估每个样本经过 packer 之后的大致序列长度(text+audio),用于过滤超长样本。
|
||||
|
||||
逻辑与 AudioFeatureProcessingPacker / AudioVAE 一致:
|
||||
- 文本长度: len(text_ids)
|
||||
- 音频长度:
|
||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||
t_seq = ceil(t_vae / patch_size)
|
||||
- 序列总长约为: text_len + t_seq + 2
|
||||
"""
|
||||
lengths: List[int] = []
|
||||
|
||||
has_duration = "duration" in ds.column_names
|
||||
|
||||
for i in range(len(ds)):
|
||||
item = ds[i]
|
||||
text_len = len(item["text_ids"])
|
||||
|
||||
# 音频时长(尽量不解码;若 manifest 里已有 duration 列则优先使用)
|
||||
if has_duration:
|
||||
duration = float(item["duration"])
|
||||
else:
|
||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||
duration = len(audio["array"]) / float(audio["sampling_rate"])
|
||||
|
||||
t_vae = math.ceil(duration * audio_vae_fps)
|
||||
t_seq = math.ceil(t_vae / patch_size)
|
||||
|
||||
total_len = text_len + t_seq + 2
|
||||
lengths.append(total_len)
|
||||
|
||||
return lengths
|
||||
|
||||
|
||||
class HFVoxCPMDataset(TorchDataset):
|
||||
"""
|
||||
Thin wrapper around a tokenized HuggingFace dataset that returns
|
||||
PyTorch-friendly samples.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
item = self.dataset[idx]
|
||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||
return {
|
||||
"text_ids": item["text_ids"],
|
||||
"audio_array": audio["array"],
|
||||
"audio_sampling_rate": audio["sampling_rate"],
|
||||
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
|
||||
"is_prompt": item.get("is_prompt", False),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
|
||||
if not seqs:
|
||||
return torch.empty(0)
|
||||
max_len = max(seq.shape[0] for seq in seqs)
|
||||
padded = []
|
||||
for seq in seqs:
|
||||
if seq.shape[0] < max_len:
|
||||
pad_width = (0, max_len - seq.shape[0])
|
||||
seq = torch.nn.functional.pad(seq, pad_width, value=pad_value)
|
||||
padded.append(seq)
|
||||
return torch.stack(padded)
|
||||
|
||||
@classmethod
|
||||
def collate_fn(cls, batch: List[Dict]):
|
||||
text_tensors = [torch.tensor(sample["text_ids"], dtype=torch.int32) for sample in batch]
|
||||
audio_tensors = [torch.tensor(sample["audio_array"], dtype=torch.float32) for sample in batch]
|
||||
dataset_ids = torch.tensor([sample["dataset_id"] for sample in batch], dtype=torch.int32)
|
||||
is_prompts = [bool(sample.get("is_prompt", False)) for sample in batch]
|
||||
|
||||
text_padded = cls.pad_sequences(text_tensors, pad_value=-100)
|
||||
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
|
||||
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
|
||||
|
||||
return {
|
||||
"text_tokens": text_padded,
|
||||
"audio_tokens": audio_padded,
|
||||
"task_ids": task_ids,
|
||||
"dataset_ids": dataset_ids,
|
||||
"is_prompts": is_prompts,
|
||||
}
|
||||
|
||||
|
||||
class BatchProcessor:
|
||||
"""
|
||||
Wraps ``AudioFeatureProcessingPacker`` so the training loop can mirror
|
||||
the minicpm-audio mechanics.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: VoxCPMConfig,
|
||||
audio_vae: AudioVAE,
|
||||
dataset_cnt: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.device = device
|
||||
self.dataset_cnt = dataset_cnt
|
||||
self.audio_vae = audio_vae
|
||||
self.audio_vae.to(device)
|
||||
self.packer = AudioFeatureProcessingPacker(
|
||||
dataset_cnt=dataset_cnt,
|
||||
max_len=config.max_length,
|
||||
patch_size=config.patch_size,
|
||||
feat_dim=config.feat_dim,
|
||||
audio_vae=self.audio_vae,
|
||||
)
|
||||
|
||||
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
audio_tokens = batch["audio_tokens"].to(self.device)
|
||||
text_tokens = batch["text_tokens"].to(self.device)
|
||||
task_ids = batch["task_ids"].to(self.device)
|
||||
dataset_ids = batch["dataset_ids"].to(self.device)
|
||||
|
||||
packed = self.packer(
|
||||
audio_tokens=audio_tokens,
|
||||
text_tokens=text_tokens,
|
||||
task_ids=task_ids,
|
||||
dataset_ids=dataset_ids,
|
||||
is_prompts=batch["is_prompts"],
|
||||
)
|
||||
return packed
|
||||
|
||||
|
||||
def build_dataloader(
|
||||
hf_dataset: Dataset,
|
||||
*,
|
||||
accelerator,
|
||||
batch_size: int,
|
||||
num_workers: int,
|
||||
drop_last: bool = False,
|
||||
) -> torch.utils.data.DataLoader:
|
||||
torch_dataset = HFVoxCPMDataset(hf_dataset)
|
||||
# Standard padding-based batching; Accelerator will attach DistributedSampler if needed.
|
||||
return accelerator.prepare_dataloader(
|
||||
torch_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=True,
|
||||
collate_fn=HFVoxCPMDataset.collate_fn,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class AudioFeatureProcessingPacker:
|
||||
"""
|
||||
Adapted from the minicpm-audio training utilities. It converts raw text and
|
||||
audio tokens into the packed multimodal representation required by VoxCPM.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||
self.audio_start_id = 101
|
||||
self.audio_end_id = 102
|
||||
# unused now
|
||||
self.audio_prompt_start_id = 103
|
||||
self.audio_prompt_end_id = 104
|
||||
self.text_eos_token_id = 2
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_len = audio_vae.hop_length * self.patch_size
|
||||
self.feat_dim = feat_dim
|
||||
self.dataset_cnt = max(dataset_cnt, 1)
|
||||
self.max_len = max_len
|
||||
|
||||
self.audio_vae = audio_vae
|
||||
|
||||
self.process_functions = {"tts": self.process_tts_data}
|
||||
self.task_id_map = {"tts": 1}
|
||||
self.id_to_task = {idx: usage for usage, idx in self.task_id_map.items()}
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
def _first_pad_position(tokens: torch.Tensor):
|
||||
positions = (tokens == -100).nonzero(as_tuple=True)
|
||||
if positions[0].numel() == 0:
|
||||
return None
|
||||
return int(positions[0][0])
|
||||
|
||||
def unpad_text_tokens(self, tokens: torch.Tensor):
|
||||
pad_pos = self._first_pad_position(tokens)
|
||||
return tokens if pad_pos is None else tokens[:pad_pos]
|
||||
|
||||
def unpad_audio_tokens(self, tokens: torch.Tensor):
|
||||
pad_pos = self._first_pad_position(tokens)
|
||||
return tokens if pad_pos is None else tokens[:pad_pos]
|
||||
|
||||
def encode_audio(self, wav: torch.Tensor):
|
||||
"""
|
||||
Encode raw waveform into latent features using AudioVAE.
|
||||
|
||||
AudioVAE.encode expects shape [B, 1, T'] and returns [B, D, T].
|
||||
We then transpose to [B, T, D] to match downstream expectations.
|
||||
"""
|
||||
wav = wav.unsqueeze(0) # [1, T]
|
||||
wav = wav.unsqueeze(1) # [1, 1, T]
|
||||
wav_len = wav.size(-1)
|
||||
if wav_len % self.patch_len != 0:
|
||||
padding_size = self.patch_len - wav_len % self.patch_len
|
||||
wav = torch.nn.functional.pad(wav, (0, padding_size))
|
||||
|
||||
with torch.no_grad():
|
||||
z = self.audio_vae.encode(wav, self.audio_vae.sample_rate) # [1, D, T']
|
||||
feat = z.transpose(1, 2) # [1, T', D]
|
||||
return feat
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Main entry point
|
||||
# ------------------------------------------------------------------ #
|
||||
def __call__(
|
||||
self,
|
||||
audio_tokens: torch.Tensor,
|
||||
text_tokens: torch.Tensor,
|
||||
task_ids: torch.Tensor,
|
||||
dataset_ids: torch.Tensor,
|
||||
is_prompts: List[bool],
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Padding-based batching: each sample in the input batch is processed
|
||||
independently and then padded to a common length (capped by ``max_len``).
|
||||
The result tensors all have shape [B, T, ...].
|
||||
"""
|
||||
device = audio_tokens.device
|
||||
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
|
||||
dataset_cnt = max(self.dataset_cnt, max_dataset_id + 1)
|
||||
|
||||
text_tokens_list: List[torch.Tensor] = []
|
||||
audio_feats_list: List[torch.Tensor] = []
|
||||
text_mask_list: List[torch.Tensor] = []
|
||||
audio_mask_list: List[torch.Tensor] = []
|
||||
loss_mask_list: List[torch.Tensor] = []
|
||||
labels_list: List[torch.Tensor] = []
|
||||
audio_task_ids_list: List[torch.Tensor] = []
|
||||
audio_dataset_ids_list: List[torch.Tensor] = []
|
||||
lengths: List[int] = []
|
||||
|
||||
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||
|
||||
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
|
||||
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
|
||||
):
|
||||
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
|
||||
unpad_text_token = self.unpad_text_tokens(text_token)
|
||||
usage = self.id_to_task[task_id]
|
||||
|
||||
(
|
||||
packed_text,
|
||||
audio_feat,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
|
||||
|
||||
audio_duration_consumed[dataset_idx] += audio_duration
|
||||
text_token_consumed[dataset_idx] += text_token_count
|
||||
|
||||
audio_task_id = torch.zeros_like(audio_mask)
|
||||
audio_task_id[audio_mask == 1] = self.task_id_map[usage]
|
||||
|
||||
audio_dataset_id = torch.zeros_like(audio_mask)
|
||||
audio_dataset_id[audio_mask == 1] = dataset_idx + 1
|
||||
|
||||
text_tokens_list.append(packed_text)
|
||||
text_mask_list.append(text_mask)
|
||||
audio_feats_list.append(audio_feat)
|
||||
audio_mask_list.append(audio_mask)
|
||||
loss_mask_list.append(loss_mask)
|
||||
labels_list.append(labels)
|
||||
audio_task_ids_list.append(audio_task_id)
|
||||
audio_dataset_ids_list.append(audio_dataset_id)
|
||||
lengths.append(packed_text.shape[0])
|
||||
|
||||
# Determine padded length per batch (cap by self.max_len)
|
||||
if lengths:
|
||||
max_len = min(self.max_len, max(lengths))
|
||||
else:
|
||||
max_len = self.max_len
|
||||
|
||||
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
|
||||
if x.size(0) >= max_len:
|
||||
return x[: max_len]
|
||||
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=0)
|
||||
|
||||
def pad_3d(x: torch.Tensor) -> torch.Tensor:
|
||||
# x: [T, P, D]
|
||||
if x.size(0) >= max_len:
|
||||
return x[: max_len]
|
||||
pad = torch.zeros(
|
||||
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
|
||||
)
|
||||
return torch.cat([x, pad], dim=0)
|
||||
if lengths:
|
||||
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
|
||||
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
|
||||
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
|
||||
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
|
||||
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
|
||||
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
|
||||
audio_task_ids_batch = torch.stack(
|
||||
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
|
||||
)
|
||||
audio_dataset_ids_batch = torch.stack(
|
||||
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
|
||||
)
|
||||
|
||||
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
|
||||
position_ids_list = []
|
||||
for L in lengths:
|
||||
L_clip = min(L, max_len)
|
||||
pos = torch.arange(0, L_clip, device=device)
|
||||
if L_clip < max_len:
|
||||
pad = torch.zeros(max_len - L_clip, dtype=pos.dtype, device=device)
|
||||
pos = torch.cat([pos, pad], dim=0)
|
||||
position_ids_list.append(pos)
|
||||
position_ids = torch.stack(position_ids_list, dim=0)
|
||||
else:
|
||||
# Empty batch fallback (shouldn't really happen)
|
||||
text_tokens_batch = torch.zeros((0, self.max_len), dtype=torch.int32, device=device)
|
||||
text_mask_batch = torch.zeros_like(text_tokens_batch)
|
||||
audio_feats_batch = torch.zeros(
|
||||
(0, self.max_len, self.patch_size, self.feat_dim), dtype=torch.float32, device=device
|
||||
)
|
||||
audio_mask_batch = torch.zeros_like(text_tokens_batch)
|
||||
loss_mask_batch = torch.zeros_like(text_tokens_batch)
|
||||
labels_batch = torch.zeros_like(text_tokens_batch)
|
||||
audio_task_ids_batch = torch.zeros_like(text_tokens_batch)
|
||||
audio_dataset_ids_batch = torch.zeros_like(text_tokens_batch)
|
||||
position_ids = torch.zeros_like(text_tokens_batch)
|
||||
|
||||
audio_duration_consumed = audio_duration_consumed.to(torch.long)
|
||||
text_token_consumed = text_token_consumed.to(torch.long)
|
||||
|
||||
return {
|
||||
"text_tokens": text_tokens_batch,
|
||||
"audio_feats": audio_feats_batch,
|
||||
"text_mask": text_mask_batch,
|
||||
"audio_mask": audio_mask_batch,
|
||||
"loss_mask": loss_mask_batch,
|
||||
"position_ids": position_ids,
|
||||
"labels": labels_batch,
|
||||
"audio_task_ids": audio_task_ids_batch,
|
||||
"audio_dataset_ids": audio_dataset_ids_batch,
|
||||
"audio_duration_consumed": audio_duration_consumed,
|
||||
"text_token_consumed": text_token_consumed,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Feature extraction helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def extract_audio_feats(self, audio_data: torch.Tensor):
|
||||
audio_feats = self.encode_audio(audio_data)
|
||||
if audio_feats.size(1) % self.patch_size != 0:
|
||||
audio_feats_ = audio_feats.transpose(1, 2)
|
||||
padding = nn.functional.pad(audio_feats_, (0, self.patch_size - audio_feats.size(1) % self.patch_size))
|
||||
audio_feats = padding.transpose(1, 2)
|
||||
|
||||
audio_duration = audio_feats.size(1) / 25
|
||||
audio_feats = rearrange(audio_feats, "b (t p) c -> b t p c", p=self.patch_size)
|
||||
return audio_feats, audio_duration
|
||||
|
||||
def process_tts_data(self, audio_token: torch.Tensor, text_token: torch.Tensor, is_prompt: bool = False):
|
||||
text_token_info = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
torch.tensor(
|
||||
[self.audio_prompt_start_id if is_prompt else self.audio_start_id],
|
||||
dtype=torch.int32,
|
||||
device=text_token.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
text_token_count = len(text_token)
|
||||
text_length = text_token_info.shape[0]
|
||||
audio_feat_info, audio_duration = self.extract_audio_feats(audio_token)
|
||||
audio_feat_info = audio_feat_info.squeeze(0)
|
||||
audio_length = audio_feat_info.shape[0]
|
||||
|
||||
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
|
||||
text_token_info = torch.cat(
|
||||
[
|
||||
text_token_info,
|
||||
text_pad_token,
|
||||
torch.tensor(
|
||||
[self.audio_prompt_end_id if is_prompt else self.audio_end_id],
|
||||
dtype=torch.int32,
|
||||
device=text_token.device,
|
||||
),
|
||||
]
|
||||
)
|
||||
audio_pad_feat = torch.zeros(
|
||||
(text_length, self.patch_size, audio_feat_info.size(-1)),
|
||||
dtype=torch.float32,
|
||||
device=text_token.device,
|
||||
)
|
||||
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
|
||||
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
|
||||
text_token.device
|
||||
)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
|
||||
torch.int32
|
||||
).to(text_token.device)
|
||||
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
|
||||
|
||||
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
|
||||
labels[-2] = 1
|
||||
|
||||
return (
|
||||
text_token_info,
|
||||
audio_feat_info,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingState:
|
||||
"""
|
||||
Container that mirrors the object returned in the minicpm-audio training
|
||||
loop. It holds persistent references to the model, optimizer, scheduler,
|
||||
dataloaders and tracker.
|
||||
"""
|
||||
|
||||
generator: object
|
||||
optimizer: object
|
||||
scheduler: object
|
||||
train_loader: object
|
||||
val_loader: object
|
||||
tracker: object
|
||||
batch_processor: object
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class TrainingTracker:
|
||||
"""
|
||||
Lightweight tracker inspired by the minimcpm-audio training workflow.
|
||||
|
||||
It keeps track of the current global step, prints rank-aware messages,
|
||||
optionally writes to TensorBoard via a provided writer, and stores progress
|
||||
in a logfile for later inspection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
writer=None,
|
||||
log_file: Optional[str] = None,
|
||||
rank: int = 0,
|
||||
):
|
||||
self.writer = writer
|
||||
self.log_file = Path(log_file) if log_file else None
|
||||
if self.log_file:
|
||||
self.log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.rank = rank
|
||||
self.step = 0
|
||||
# Record the time of the last log to calculate the interval
|
||||
self._last_log_time: float | None = None
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Logging helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def print(self, message: str):
|
||||
if self.rank == 0:
|
||||
print(message, flush=True)
|
||||
if self.log_file:
|
||||
with self.log_file.open("a", encoding="utf-8") as f:
|
||||
f.write(message + "\n")
|
||||
|
||||
def log_metrics(self, metrics: Dict[str, float], split: str):
|
||||
if self.rank == 0:
|
||||
now = time.time()
|
||||
dt_str = ""
|
||||
if self._last_log_time is not None:
|
||||
dt = now - self._last_log_time
|
||||
dt_str = f", log interval: {dt:.2f}s"
|
||||
self._last_log_time = now
|
||||
|
||||
formatted = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
|
||||
self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")
|
||||
if self.writer is not None:
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, (int, float)):
|
||||
self.writer.add_scalar(f"{split}/{key}", value, self.step)
|
||||
|
||||
def done(self, split: str, message: str):
|
||||
self.print(f"[{split}] {message}")
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# State dict
|
||||
# ------------------------------------------------------------------ #
|
||||
def state_dict(self):
|
||||
return {"step": self.step}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
self.step = int(state.get("step", 0))
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Context manager compatibility (for parity with minicpm-audio code)
|
||||
# ------------------------------------------------------------------ #
|
||||
@contextlib.contextmanager
|
||||
def live(self):
|
||||
yield
|
||||
|
||||
Reference in New Issue
Block a user