Update: VoxCPM1.5 and fine-tuning supprt

This commit is contained in:
Labmem-Zhouyx
2025-12-05 21:00:01 +08:00
parent d1bb6aaf41
commit 461ad7e506
29 changed files with 2928 additions and 228 deletions
+28
View File
@@ -0,0 +1,28 @@
"""
Training utilities for VoxCPM fine-tuning.
This package mirrors the training mechanics used in the minicpm-audio
tooling while relying solely on local audio-text datasets managed via
the HuggingFace ``datasets`` library.
"""
from .accelerator import Accelerator
from .tracker import TrainingTracker
from .data import (
load_audio_text_datasets,
HFVoxCPMDataset,
build_dataloader,
BatchProcessor,
)
from .state import TrainingState
__all__ = [
"Accelerator",
"TrainingTracker",
"HFVoxCPMDataset",
"BatchProcessor",
"TrainingState",
"load_audio_text_datasets",
"build_dataloader",
]
+166
View File
@@ -0,0 +1,166 @@
from __future__ import annotations
import contextlib
import os
import random
import typing
import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data
from torch.nn.parallel import DistributedDataParallel
class Accelerator:
"""
Simplified accelerator that mirrors the behaviour of the minicpm-audio
training utilities. It initializes a distributed process group when
``torchrun`` is used and exposes helpers for AMP, gradient scaling and
preparing models/dataloaders for DDP.
"""
def __init__(self, amp: bool = False, seed: int = 42):
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
if self.world_size > 1 and not dist.is_initialized():
dist.init_process_group("nccl", init_method="env://")
self.rank = dist.get_rank() if dist.is_initialized() else 0
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
self.amp = amp
# Set random seed to ensure model initialization consistency
self._set_seed(seed)
class DummyScaler:
def step(self, optimizer):
optimizer.step()
def scale(self, loss):
return loss
def unscale_(self, optimizer):
return optimizer
def update(self):
pass
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
self.device_ctx = (
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
self._ddp_model = None # For no_sync support
def _set_seed(self, seed: int):
"""Set random seed to ensure model initialization consistency across multiple GPUs"""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def __enter__(self):
if self.device_ctx is not None:
self.device_ctx.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.device_ctx is not None:
self.device_ctx.__exit__(exc_type, exc_value, traceback)
def barrier(self):
"""Synchronize all processes"""
if dist.is_initialized():
dist.barrier()
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.AVG):
"""All-reduce tensor across processes"""
if dist.is_initialized():
dist.all_reduce(tensor, op=op)
return tensor
# ------------------------------------------------------------------ #
# Model helpers
# ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs):
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
model.device = self.device
model = model.to(self.device)
if self.world_size > 1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
self._ddp_model = model # Save DDP model reference for no_sync support
return model
@contextlib.contextmanager
def no_sync(self):
"""
Context manager to skip gradient synchronization during gradient accumulation.
Only used outside the last micro-batch.
"""
if self._ddp_model is not None:
with self._ddp_model.no_sync():
yield
else:
yield
@property
def device(self):
if torch.cuda.is_available():
return torch.device("cuda", self.local_rank)
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
# ------------------------------------------------------------------ #
# AMP helpers
# ------------------------------------------------------------------ #
def autocast(self, *args, **kwargs):
return torch.amp.autocast("cuda", enabled=self.amp, *args, **kwargs)
def backward(self, loss: torch.Tensor):
self.scaler.scale(loss).backward()
def step(self, optimizer: torch.optim.Optimizer):
self.scaler.step(optimizer)
def update(self):
self.scaler.update()
# ------------------------------------------------------------------ #
# Data helpers
# ------------------------------------------------------------------ #
def prepare_dataloader(
self,
dataset: typing.Iterable,
*,
batch_size: int,
num_workers: int = 0,
shuffle: bool = True,
collate_fn=None,
drop_last: bool = False,
) -> torch.utils.data.DataLoader:
if self.world_size > 1:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
)
shuffle = False
else:
sampler = None
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
drop_last=drop_last,
pin_memory=True,
)
@staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model
+40
View File
@@ -0,0 +1,40 @@
from __future__ import annotations
import argbind
import yaml
from pathlib import Path
from typing import Dict, Any
def load_yaml_config(path: str | Path) -> Dict[str, Any]:
"""
Load a YAML configuration file into a dictionary suitable for argbind.
"""
path = Path(path)
with path.open("r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Configuration file {path} must contain a top-level mapping.")
return data
def parse_args_with_config(config_path: str | Path | None = None):
"""
Helper to unify CLI arguments and YAML configuration.
Usage mirrors minicpm-audio:
args = parse_args_with_config("conf/voxcpm/finetune.yml")
with argbind.scope(args):
...
"""
cli_args = argbind.parse_args()
if config_path is None:
return cli_args
yaml_args = load_yaml_config(config_path)
with argbind.scope(cli_args):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args)
return cli_args
+214
View File
@@ -0,0 +1,214 @@
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import argbind
import torch
from datasets import Audio, Dataset, DatasetDict, load_dataset
from torch.utils.data import Dataset as TorchDataset
from ..model.voxcpm import VoxCPMConfig
from ..modules.audiovae import AudioVAE
from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_ID_COLUMN = "dataset_id"
@argbind.bind()
def load_audio_text_datasets(
train_manifest: str,
val_manifest: str = "",
text_column: str = DEFAULT_TEXT_COLUMN,
audio_column: str = DEFAULT_AUDIO_COLUMN,
dataset_id_column: str = DEFAULT_ID_COLUMN,
sample_rate: int = 16_000,
num_proc: int = 1,
) -> Tuple[Dataset, Optional[Dataset]]:
data_files = {"train": train_manifest}
if val_manifest:
data_files["validation"] = val_manifest
dataset_dict: DatasetDict = load_dataset("json", data_files=data_files)
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
if audio_column != DEFAULT_AUDIO_COLUMN:
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
if text_column != DEFAULT_TEXT_COLUMN:
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
if dataset_id_column and dataset_id_column in ds.column_names:
if dataset_id_column != DEFAULT_ID_COLUMN:
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
else:
ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
return ds
train_ds = prepare(dataset_dict["train"])
val_ds = prepare(dataset_dict["validation"]) if "validation" in dataset_dict else None
return train_ds, val_ds
def compute_sample_lengths(
ds: Dataset,
audio_vae_fps: int = 25,
patch_size: int = 1,
) -> List[int]:
"""
预估每个样本经过 packer 之后的大致序列长度(text+audio),用于过滤超长样本。
逻辑与 AudioFeatureProcessingPacker / AudioVAE 一致:
- 文本长度: len(text_ids)
- 音频长度:
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
t_seq = ceil(t_vae / patch_size)
- 序列总长约为: text_len + t_seq + 2
"""
lengths: List[int] = []
has_duration = "duration" in ds.column_names
for i in range(len(ds)):
item = ds[i]
text_len = len(item["text_ids"])
# 音频时长(尽量不解码;若 manifest 里已有 duration 列则优先使用)
if has_duration:
duration = float(item["duration"])
else:
audio = item[DEFAULT_AUDIO_COLUMN]
duration = len(audio["array"]) / float(audio["sampling_rate"])
t_vae = math.ceil(duration * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2
lengths.append(total_len)
return lengths
class HFVoxCPMDataset(TorchDataset):
"""
Thin wrapper around a tokenized HuggingFace dataset that returns
PyTorch-friendly samples.
"""
def __init__(self, dataset: Dataset):
self.dataset = dataset
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx: int):
item = self.dataset[idx]
audio = item[DEFAULT_AUDIO_COLUMN]
return {
"text_ids": item["text_ids"],
"audio_array": audio["array"],
"audio_sampling_rate": audio["sampling_rate"],
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
"is_prompt": item.get("is_prompt", False),
}
@staticmethod
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
if not seqs:
return torch.empty(0)
max_len = max(seq.shape[0] for seq in seqs)
padded = []
for seq in seqs:
if seq.shape[0] < max_len:
pad_width = (0, max_len - seq.shape[0])
seq = torch.nn.functional.pad(seq, pad_width, value=pad_value)
padded.append(seq)
return torch.stack(padded)
@classmethod
def collate_fn(cls, batch: List[Dict]):
text_tensors = [torch.tensor(sample["text_ids"], dtype=torch.int32) for sample in batch]
audio_tensors = [torch.tensor(sample["audio_array"], dtype=torch.float32) for sample in batch]
dataset_ids = torch.tensor([sample["dataset_id"] for sample in batch], dtype=torch.int32)
is_prompts = [bool(sample.get("is_prompt", False)) for sample in batch]
text_padded = cls.pad_sequences(text_tensors, pad_value=-100)
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
return {
"text_tokens": text_padded,
"audio_tokens": audio_padded,
"task_ids": task_ids,
"dataset_ids": dataset_ids,
"is_prompts": is_prompts,
}
class BatchProcessor:
"""
Wraps ``AudioFeatureProcessingPacker`` so the training loop can mirror
the minicpm-audio mechanics.
"""
def __init__(
self,
*,
config: VoxCPMConfig,
audio_vae: AudioVAE,
dataset_cnt: int,
device: torch.device,
):
self.device = device
self.dataset_cnt = dataset_cnt
self.audio_vae = audio_vae
self.audio_vae.to(device)
self.packer = AudioFeatureProcessingPacker(
dataset_cnt=dataset_cnt,
max_len=config.max_length,
patch_size=config.patch_size,
feat_dim=config.feat_dim,
audio_vae=self.audio_vae,
)
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
audio_tokens = batch["audio_tokens"].to(self.device)
text_tokens = batch["text_tokens"].to(self.device)
task_ids = batch["task_ids"].to(self.device)
dataset_ids = batch["dataset_ids"].to(self.device)
packed = self.packer(
audio_tokens=audio_tokens,
text_tokens=text_tokens,
task_ids=task_ids,
dataset_ids=dataset_ids,
is_prompts=batch["is_prompts"],
)
return packed
def build_dataloader(
hf_dataset: Dataset,
*,
accelerator,
batch_size: int,
num_workers: int,
drop_last: bool = False,
) -> torch.utils.data.DataLoader:
torch_dataset = HFVoxCPMDataset(hf_dataset)
# Standard padding-based batching; Accelerator will attach DistributedSampler if needed.
return accelerator.prepare_dataloader(
torch_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
collate_fn=HFVoxCPMDataset.collate_fn,
drop_last=drop_last,
)
+289
View File
@@ -0,0 +1,289 @@
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from einops import rearrange
class AudioFeatureProcessingPacker:
"""
Adapted from the minicpm-audio training utilities. It converts raw text and
audio tokens into the packed multimodal representation required by VoxCPM.
"""
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101
self.audio_end_id = 102
# unused now
self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104
self.text_eos_token_id = 2
self.patch_size = patch_size
self.patch_len = audio_vae.hop_length * self.patch_size
self.feat_dim = feat_dim
self.dataset_cnt = max(dataset_cnt, 1)
self.max_len = max_len
self.audio_vae = audio_vae
self.process_functions = {"tts": self.process_tts_data}
self.task_id_map = {"tts": 1}
self.id_to_task = {idx: usage for usage, idx in self.task_id_map.items()}
# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
@staticmethod
def _first_pad_position(tokens: torch.Tensor):
positions = (tokens == -100).nonzero(as_tuple=True)
if positions[0].numel() == 0:
return None
return int(positions[0][0])
def unpad_text_tokens(self, tokens: torch.Tensor):
pad_pos = self._first_pad_position(tokens)
return tokens if pad_pos is None else tokens[:pad_pos]
def unpad_audio_tokens(self, tokens: torch.Tensor):
pad_pos = self._first_pad_position(tokens)
return tokens if pad_pos is None else tokens[:pad_pos]
def encode_audio(self, wav: torch.Tensor):
"""
Encode raw waveform into latent features using AudioVAE.
AudioVAE.encode expects shape [B, 1, T'] and returns [B, D, T].
We then transpose to [B, T, D] to match downstream expectations.
"""
wav = wav.unsqueeze(0) # [1, T]
wav = wav.unsqueeze(1) # [1, 1, T]
wav_len = wav.size(-1)
if wav_len % self.patch_len != 0:
padding_size = self.patch_len - wav_len % self.patch_len
wav = torch.nn.functional.pad(wav, (0, padding_size))
with torch.no_grad():
z = self.audio_vae.encode(wav, self.audio_vae.sample_rate) # [1, D, T']
feat = z.transpose(1, 2) # [1, T', D]
return feat
# ------------------------------------------------------------------ #
# Main entry point
# ------------------------------------------------------------------ #
def __call__(
self,
audio_tokens: torch.Tensor,
text_tokens: torch.Tensor,
task_ids: torch.Tensor,
dataset_ids: torch.Tensor,
is_prompts: List[bool],
) -> Dict[str, torch.Tensor]:
"""
Padding-based batching: each sample in the input batch is processed
independently and then padded to a common length (capped by ``max_len``).
The result tensors all have shape [B, T, ...].
"""
device = audio_tokens.device
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
dataset_cnt = max(self.dataset_cnt, max_dataset_id + 1)
text_tokens_list: List[torch.Tensor] = []
audio_feats_list: List[torch.Tensor] = []
text_mask_list: List[torch.Tensor] = []
audio_mask_list: List[torch.Tensor] = []
loss_mask_list: List[torch.Tensor] = []
labels_list: List[torch.Tensor] = []
audio_task_ids_list: List[torch.Tensor] = []
audio_dataset_ids_list: List[torch.Tensor] = []
lengths: List[int] = []
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
):
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
unpad_text_token = self.unpad_text_tokens(text_token)
usage = self.id_to_task[task_id]
(
packed_text,
audio_feat,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
audio_duration_consumed[dataset_idx] += audio_duration
text_token_consumed[dataset_idx] += text_token_count
audio_task_id = torch.zeros_like(audio_mask)
audio_task_id[audio_mask == 1] = self.task_id_map[usage]
audio_dataset_id = torch.zeros_like(audio_mask)
audio_dataset_id[audio_mask == 1] = dataset_idx + 1
text_tokens_list.append(packed_text)
text_mask_list.append(text_mask)
audio_feats_list.append(audio_feat)
audio_mask_list.append(audio_mask)
loss_mask_list.append(loss_mask)
labels_list.append(labels)
audio_task_ids_list.append(audio_task_id)
audio_dataset_ids_list.append(audio_dataset_id)
lengths.append(packed_text.shape[0])
# Determine padded length per batch (cap by self.max_len)
if lengths:
max_len = min(self.max_len, max(lengths))
else:
max_len = self.max_len
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
def pad_3d(x: torch.Tensor) -> torch.Tensor:
# x: [T, P, D]
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.zeros(
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
)
return torch.cat([x, pad], dim=0)
if lengths:
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
audio_task_ids_batch = torch.stack(
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
)
audio_dataset_ids_batch = torch.stack(
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
)
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
position_ids_list = []
for L in lengths:
L_clip = min(L, max_len)
pos = torch.arange(0, L_clip, device=device)
if L_clip < max_len:
pad = torch.zeros(max_len - L_clip, dtype=pos.dtype, device=device)
pos = torch.cat([pos, pad], dim=0)
position_ids_list.append(pos)
position_ids = torch.stack(position_ids_list, dim=0)
else:
# Empty batch fallback (shouldn't really happen)
text_tokens_batch = torch.zeros((0, self.max_len), dtype=torch.int32, device=device)
text_mask_batch = torch.zeros_like(text_tokens_batch)
audio_feats_batch = torch.zeros(
(0, self.max_len, self.patch_size, self.feat_dim), dtype=torch.float32, device=device
)
audio_mask_batch = torch.zeros_like(text_tokens_batch)
loss_mask_batch = torch.zeros_like(text_tokens_batch)
labels_batch = torch.zeros_like(text_tokens_batch)
audio_task_ids_batch = torch.zeros_like(text_tokens_batch)
audio_dataset_ids_batch = torch.zeros_like(text_tokens_batch)
position_ids = torch.zeros_like(text_tokens_batch)
audio_duration_consumed = audio_duration_consumed.to(torch.long)
text_token_consumed = text_token_consumed.to(torch.long)
return {
"text_tokens": text_tokens_batch,
"audio_feats": audio_feats_batch,
"text_mask": text_mask_batch,
"audio_mask": audio_mask_batch,
"loss_mask": loss_mask_batch,
"position_ids": position_ids,
"labels": labels_batch,
"audio_task_ids": audio_task_ids_batch,
"audio_dataset_ids": audio_dataset_ids_batch,
"audio_duration_consumed": audio_duration_consumed,
"text_token_consumed": text_token_consumed,
}
# ------------------------------------------------------------------ #
# Feature extraction helpers
# ------------------------------------------------------------------ #
def extract_audio_feats(self, audio_data: torch.Tensor):
audio_feats = self.encode_audio(audio_data)
if audio_feats.size(1) % self.patch_size != 0:
audio_feats_ = audio_feats.transpose(1, 2)
padding = nn.functional.pad(audio_feats_, (0, self.patch_size - audio_feats.size(1) % self.patch_size))
audio_feats = padding.transpose(1, 2)
audio_duration = audio_feats.size(1) / 25
audio_feats = rearrange(audio_feats, "b (t p) c -> b t p c", p=self.patch_size)
return audio_feats, audio_duration
def process_tts_data(self, audio_token: torch.Tensor, text_token: torch.Tensor, is_prompt: bool = False):
text_token_info = torch.cat(
[
text_token,
torch.tensor(
[self.audio_prompt_start_id if is_prompt else self.audio_start_id],
dtype=torch.int32,
device=text_token.device,
),
],
dim=-1,
)
text_token_count = len(text_token)
text_length = text_token_info.shape[0]
audio_feat_info, audio_duration = self.extract_audio_feats(audio_token)
audio_feat_info = audio_feat_info.squeeze(0)
audio_length = audio_feat_info.shape[0]
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
text_token_info = torch.cat(
[
text_token_info,
text_pad_token,
torch.tensor(
[self.audio_prompt_end_id if is_prompt else self.audio_end_id],
dtype=torch.int32,
device=text_token.device,
),
]
)
audio_pad_feat = torch.zeros(
(text_length, self.patch_size, audio_feat_info.size(-1)),
dtype=torch.float32,
device=text_token.device,
)
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
text_token.device
)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
torch.int32
).to(text_token.device)
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
labels[-2] = 1
return (
text_token_info,
audio_feat_info,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
)
+21
View File
@@ -0,0 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class TrainingState:
"""
Container that mirrors the object returned in the minicpm-audio training
loop. It holds persistent references to the model, optimizer, scheduler,
dataloaders and tracker.
"""
generator: object
optimizer: object
scheduler: object
train_loader: object
val_loader: object
tracker: object
batch_processor: object
+78
View File
@@ -0,0 +1,78 @@
from __future__ import annotations
import contextlib
import time
from pathlib import Path
from typing import Dict, Optional
class TrainingTracker:
"""
Lightweight tracker inspired by the minimcpm-audio training workflow.
It keeps track of the current global step, prints rank-aware messages,
optionally writes to TensorBoard via a provided writer, and stores progress
in a logfile for later inspection.
"""
def __init__(
self,
*,
writer=None,
log_file: Optional[str] = None,
rank: int = 0,
):
self.writer = writer
self.log_file = Path(log_file) if log_file else None
if self.log_file:
self.log_file.parent.mkdir(parents=True, exist_ok=True)
self.rank = rank
self.step = 0
# Record the time of the last log to calculate the interval
self._last_log_time: float | None = None
# ------------------------------------------------------------------ #
# Logging helpers
# ------------------------------------------------------------------ #
def print(self, message: str):
if self.rank == 0:
print(message, flush=True)
if self.log_file:
with self.log_file.open("a", encoding="utf-8") as f:
f.write(message + "\n")
def log_metrics(self, metrics: Dict[str, float], split: str):
if self.rank == 0:
now = time.time()
dt_str = ""
if self._last_log_time is not None:
dt = now - self._last_log_time
dt_str = f", log interval: {dt:.2f}s"
self._last_log_time = now
formatted = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")
if self.writer is not None:
for key, value in metrics.items():
if isinstance(value, (int, float)):
self.writer.add_scalar(f"{split}/{key}", value, self.step)
def done(self, split: str, message: str):
self.print(f"[{split}] {message}")
# ------------------------------------------------------------------ #
# State dict
# ------------------------------------------------------------------ #
def state_dict(self):
return {"step": self.step}
def load_state_dict(self, state):
self.step = int(state.get("step", 0))
# ------------------------------------------------------------------ #
# Context manager compatibility (for parity with minicpm-audio code)
# ------------------------------------------------------------------ #
@contextlib.contextmanager
def live(self):
yield