update finetuning pipeline and runtime device handling

Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits.

Made-with: Cursor
This commit is contained in:
刘鑫
2026-04-11 11:08:50 +08:00
parent abf01b9bf3
commit e4e049624c
10 changed files with 379 additions and 47 deletions
+53 -11
View File
@@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_REF_AUDIO_COLUMN = "ref_audio"
DEFAULT_ID_COLUMN = "dataset_id"
@@ -21,6 +22,7 @@ def load_audio_text_datasets(
val_manifest: str = "",
text_column: str = DEFAULT_TEXT_COLUMN,
audio_column: str = DEFAULT_AUDIO_COLUMN,
ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN,
dataset_id_column: str = DEFAULT_ID_COLUMN,
sample_rate: int = 16_000,
num_proc: int = 1,
@@ -34,14 +36,19 @@ def load_audio_text_datasets(
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
if audio_column != DEFAULT_AUDIO_COLUMN:
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
if text_column != DEFAULT_TEXT_COLUMN:
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
# ref_audio is optional — cast to Audio if the column exists
ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN
if ref_col in ds.column_names:
ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate))
if ref_col != DEFAULT_REF_AUDIO_COLUMN:
ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN)
if dataset_id_column and dataset_id_column in ds.column_names:
if dataset_id_column != DEFAULT_ID_COLUMN:
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
@@ -67,11 +74,11 @@ def compute_sample_lengths(
- 音频长度:
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
t_seq = ceil(t_vae / patch_size)
- 序列总长约为: text_len + t_seq + 2
- 无 ref_audio: text_len + t_seq + 2
- 有 ref_audio: text_len + t_seq + ref_seq + 4
Optimized: Use batch column access instead of iterating item by item.
"""
# Batch access columns - much faster than per-item access
text_ids_list = ds["text_ids"]
text_lens = [len(t) for t in text_ids_list]
@@ -79,18 +86,35 @@ def compute_sample_lengths(
if has_duration:
durations = ds["duration"]
else:
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
durations = []
for i in range(len(ds)):
audio = ds[i][DEFAULT_AUDIO_COLUMN]
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
# Vectorized length computation
has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names
if has_ref_audio:
ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None
lengths = []
for text_len, duration in zip(text_lens, durations):
for i, (text_len, duration) in enumerate(zip(text_lens, durations)):
t_vae = math.ceil(float(duration) * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2
ref_seq = 0
if has_ref_audio:
# Estimate ref_audio length; ref_audio is None for samples without it
if ref_duration_col:
ref_dur = ds[i].get(ref_duration_col)
else:
ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN)
ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None
if ref_dur is not None and float(ref_dur) > 0:
ref_vae = math.ceil(float(ref_dur) * audio_vae_fps)
ref_seq = math.ceil(ref_vae / patch_size)
# +2 for 101/102; +2 more for 103/104 when ref_audio present
overhead = 4 if ref_seq > 0 else 2
total_len = text_len + t_seq + ref_seq + overhead
lengths.append(total_len)
return lengths
@@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset):
PyTorch-friendly samples.
"""
_SENTINEL = [-100.0]
def __init__(self, dataset: Dataset):
self.dataset = dataset
self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names
def __len__(self):
return len(self.dataset)
@@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset):
def __getitem__(self, idx: int):
item = self.dataset[idx]
audio = item[DEFAULT_AUDIO_COLUMN]
return {
sample = {
"text_ids": item["text_ids"],
"audio_array": audio["array"],
"audio_sampling_rate": audio["sampling_rate"],
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
"is_prompt": item.get("is_prompt", False),
}
if self.has_ref_audio:
ref = item.get(DEFAULT_REF_AUDIO_COLUMN)
sample["ref_audio_array"] = ref["array"] if ref else self._SENTINEL
return sample
@staticmethod
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
@@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset):
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
return {
result = {
"text_tokens": text_padded,
"audio_tokens": audio_padded,
"task_ids": task_ids,
@@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset):
"is_prompts": is_prompts,
}
if "ref_audio_array" in batch[0]:
ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch]
result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0)
return result
class BatchProcessor:
"""
@@ -184,12 +221,17 @@ class BatchProcessor:
task_ids = batch["task_ids"].to(self.device)
dataset_ids = batch["dataset_ids"].to(self.device)
ref_audio_tokens = None
if "ref_audio_tokens" in batch:
ref_audio_tokens = batch["ref_audio_tokens"].to(self.device)
packed = self.packer(
audio_tokens=audio_tokens,
text_tokens=text_tokens,
task_ids=task_ids,
dataset_ids=dataset_ids,
is_prompts=batch["is_prompts"],
ref_audio_tokens=ref_audio_tokens,
)
return packed