update finetuning pipeline and runtime device handling
Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits. Made-with: Cursor
This commit is contained in:
+53
-11
@@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_REF_AUDIO_COLUMN = "ref_audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
|
||||
|
||||
@@ -21,6 +22,7 @@ def load_audio_text_datasets(
|
||||
val_manifest: str = "",
|
||||
text_column: str = DEFAULT_TEXT_COLUMN,
|
||||
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
||||
ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN,
|
||||
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
||||
sample_rate: int = 16_000,
|
||||
num_proc: int = 1,
|
||||
@@ -34,14 +36,19 @@ def load_audio_text_datasets(
|
||||
def prepare(ds: Dataset) -> Dataset:
|
||||
if audio_column not in ds.column_names:
|
||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# but for length calculation we might need raw path or duration if available.
|
||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||
if audio_column != DEFAULT_AUDIO_COLUMN:
|
||||
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
||||
if text_column != DEFAULT_TEXT_COLUMN:
|
||||
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
||||
|
||||
# ref_audio is optional — cast to Audio if the column exists
|
||||
ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN
|
||||
if ref_col in ds.column_names:
|
||||
ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate))
|
||||
if ref_col != DEFAULT_REF_AUDIO_COLUMN:
|
||||
ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN)
|
||||
|
||||
if dataset_id_column and dataset_id_column in ds.column_names:
|
||||
if dataset_id_column != DEFAULT_ID_COLUMN:
|
||||
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
||||
@@ -67,11 +74,11 @@ def compute_sample_lengths(
|
||||
- 音频长度:
|
||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||
t_seq = ceil(t_vae / patch_size)
|
||||
- 序列总长约为: text_len + t_seq + 2
|
||||
- 无 ref_audio: text_len + t_seq + 2
|
||||
- 有 ref_audio: text_len + t_seq + ref_seq + 4
|
||||
|
||||
Optimized: Use batch column access instead of iterating item by item.
|
||||
"""
|
||||
# Batch access columns - much faster than per-item access
|
||||
text_ids_list = ds["text_ids"]
|
||||
text_lens = [len(t) for t in text_ids_list]
|
||||
|
||||
@@ -79,18 +86,35 @@ def compute_sample_lengths(
|
||||
if has_duration:
|
||||
durations = ds["duration"]
|
||||
else:
|
||||
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
|
||||
durations = []
|
||||
for i in range(len(ds)):
|
||||
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||
|
||||
# Vectorized length computation
|
||||
has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names
|
||||
if has_ref_audio:
|
||||
ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None
|
||||
|
||||
lengths = []
|
||||
for text_len, duration in zip(text_lens, durations):
|
||||
for i, (text_len, duration) in enumerate(zip(text_lens, durations)):
|
||||
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
||||
t_seq = math.ceil(t_vae / patch_size)
|
||||
total_len = text_len + t_seq + 2
|
||||
|
||||
ref_seq = 0
|
||||
if has_ref_audio:
|
||||
# Estimate ref_audio length; ref_audio is None for samples without it
|
||||
if ref_duration_col:
|
||||
ref_dur = ds[i].get(ref_duration_col)
|
||||
else:
|
||||
ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN)
|
||||
ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None
|
||||
if ref_dur is not None and float(ref_dur) > 0:
|
||||
ref_vae = math.ceil(float(ref_dur) * audio_vae_fps)
|
||||
ref_seq = math.ceil(ref_vae / patch_size)
|
||||
|
||||
# +2 for 101/102; +2 more for 103/104 when ref_audio present
|
||||
overhead = 4 if ref_seq > 0 else 2
|
||||
total_len = text_len + t_seq + ref_seq + overhead
|
||||
lengths.append(total_len)
|
||||
|
||||
return lengths
|
||||
@@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
PyTorch-friendly samples.
|
||||
"""
|
||||
|
||||
_SENTINEL = [-100.0]
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.dataset = dataset
|
||||
self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
@@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
def __getitem__(self, idx: int):
    """
    Return the sample at ``idx`` as a plain dict.

    The dict always carries ``text_ids``, ``audio_array``,
    ``audio_sampling_rate``, ``dataset_id`` (0 when the row has no id) and
    ``is_prompt`` (False when the row has no flag). When the dataset has a
    reference-audio column, ``ref_audio_array`` is added as well; rows with
    an empty/missing reference get the class-level ``_SENTINEL`` instead.
    """
    row = self.dataset[idx]
    waveform = row[DEFAULT_AUDIO_COLUMN]

    sample = dict(
        text_ids=row["text_ids"],
        audio_array=waveform["array"],
        audio_sampling_rate=waveform["sampling_rate"],
        dataset_id=row.get(DEFAULT_ID_COLUMN, 0),
        is_prompt=row.get("is_prompt", False),
    )

    # Datasets without a reference column produce samples without the key,
    # so downstream code can detect the feature by key presence.
    if not self.has_ref_audio:
        return sample

    reference = row.get(DEFAULT_REF_AUDIO_COLUMN)
    if reference:
        sample["ref_audio_array"] = reference["array"]
    else:
        # Keeps every sample of the batch shaped alike when only some rows
        # carry a reference clip.
        sample["ref_audio_array"] = self._SENTINEL
    return sample
|
||||
|
||||
@staticmethod
|
||||
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
|
||||
@@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
|
||||
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
|
||||
|
||||
return {
|
||||
result = {
|
||||
"text_tokens": text_padded,
|
||||
"audio_tokens": audio_padded,
|
||||
"task_ids": task_ids,
|
||||
@@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset):
|
||||
"is_prompts": is_prompts,
|
||||
}
|
||||
|
||||
if "ref_audio_array" in batch[0]:
|
||||
ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch]
|
||||
result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class BatchProcessor:
|
||||
"""
|
||||
@@ -184,12 +221,17 @@ class BatchProcessor:
|
||||
task_ids = batch["task_ids"].to(self.device)
|
||||
dataset_ids = batch["dataset_ids"].to(self.device)
|
||||
|
||||
ref_audio_tokens = None
|
||||
if "ref_audio_tokens" in batch:
|
||||
ref_audio_tokens = batch["ref_audio_tokens"].to(self.device)
|
||||
|
||||
packed = self.packer(
|
||||
audio_tokens=audio_tokens,
|
||||
text_tokens=text_tokens,
|
||||
task_ids=task_ids,
|
||||
dataset_ids=dataset_ids,
|
||||
is_prompts=batch["is_prompts"],
|
||||
ref_audio_tokens=ref_audio_tokens,
|
||||
)
|
||||
return packed
|
||||
|
||||
|
||||
+124
-14
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -14,7 +14,6 @@ class AudioFeatureProcessingPacker:
|
||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||
self.audio_start_id = 101
|
||||
self.audio_end_id = 102
|
||||
# unused now
|
||||
self.audio_prompt_start_id = 103
|
||||
self.audio_prompt_end_id = 104
|
||||
self.text_eos_token_id = 2
|
||||
@@ -78,11 +77,16 @@ class AudioFeatureProcessingPacker:
|
||||
task_ids: torch.Tensor,
|
||||
dataset_ids: torch.Tensor,
|
||||
is_prompts: List[bool],
|
||||
ref_audio_tokens: Optional[torch.Tensor] = None,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Padding-based batching: each sample in the input batch is processed
|
||||
independently and then padded to a common length (capped by ``max_len``).
|
||||
The result tensors all have shape [B, T, ...].
|
||||
|
||||
If ``ref_audio_tokens`` is provided (same batch dim as ``audio_tokens``),
|
||||
samples whose unpadded ref_audio length > 0 will be processed with the
|
||||
reference-audio path (tokens 103/104 prepended, loss only on target audio).
|
||||
"""
|
||||
device = audio_tokens.device
|
||||
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
|
||||
@@ -101,23 +105,43 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||
|
||||
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
|
||||
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
|
||||
ref_iter = ref_audio_tokens if ref_audio_tokens is not None else [None] * audio_tokens.size(0)
|
||||
|
||||
for audio_token, text_token, task_id, dataset_idx, is_prompt, ref_token in zip(
|
||||
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts, ref_iter
|
||||
):
|
||||
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
|
||||
unpad_text_token = self.unpad_text_tokens(text_token)
|
||||
usage = self.id_to_task[task_id]
|
||||
|
||||
(
|
||||
packed_text,
|
||||
audio_feat,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
|
||||
has_ref = False
|
||||
if ref_token is not None:
|
||||
unpad_ref_token = self.unpad_audio_tokens(ref_token).to(torch.float32)
|
||||
if unpad_ref_token.numel() > 0:
|
||||
has_ref = True
|
||||
|
||||
if has_ref:
|
||||
(
|
||||
packed_text,
|
||||
audio_feat,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
) = self.process_tts_data_with_ref(unpad_ref_token, unpad_audio_token, unpad_text_token)
|
||||
else:
|
||||
(
|
||||
packed_text,
|
||||
audio_feat,
|
||||
text_mask,
|
||||
audio_mask,
|
||||
loss_mask,
|
||||
labels,
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
|
||||
|
||||
audio_duration_consumed[dataset_idx] += audio_duration
|
||||
text_token_consumed[dataset_idx] += text_token_count
|
||||
@@ -294,3 +318,89 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
)
|
||||
|
||||
def process_tts_data_with_ref(
    self,
    ref_audio_token: torch.Tensor,
    target_audio_token: torch.Tensor,
    text_token: torch.Tensor,
):
    """
    Build a training sequence with reference audio prepended.

    Sequence layout (one position per element)::

        [103, ref_feats, 104, text, 101, target_feats, 102]

    Loss is computed only on the target audio segment.

    Args:
        ref_audio_token: Unpadded reference audio, passed to
            ``self.extract_audio_feats``.
        target_audio_token: Unpadded target audio, passed to
            ``self.extract_audio_feats``.
        text_token: 1-D tensor of text token ids; its device is used for
            every tensor created here.

    Returns:
        An 8-tuple ``(text_token_info, audio_feat_info, text_mask,
        audio_mask, loss_mask, labels, audio_duration, text_token_count)``
        where ``audio_duration`` is the sum of the reference and target
        durations reported by ``extract_audio_feats`` and
        ``text_token_count`` is ``len(text_token)``.
    """
    device = text_token.device
    txt_len = len(text_token)

    # NOTE(review): assumes extract_audio_feats returns (feats, duration)
    # with a leading batch dim of size 1 on feats -- confirm against helper.
    ref_feats, ref_duration = self.extract_audio_feats(ref_audio_token)
    ref_feats = ref_feats.squeeze(0)  # [R, P, D]
    ref_len = ref_feats.shape[0]

    tgt_feats, tgt_duration = self.extract_audio_feats(target_audio_token)
    tgt_feats = tgt_feats.squeeze(0)  # [A, P, D]
    tgt_len = tgt_feats.shape[0]

    feat_shape = (self.patch_size, ref_feats.size(-1))

    def _tok(ids):
        return torch.tensor(ids, dtype=torch.int32, device=device)

    def _segments(*runs):
        # Concatenate constant runs given as (value, length) pairs. Built
        # directly as int32 on the target device, so no float32 CPU
        # temporaries and no per-sample host-to-device copies are needed.
        return torch.cat([
            torch.full((length,), value, dtype=torch.int32, device=device)
            for value, length in runs
        ])

    # -- text token track --
    # [103, 0 x R, 104, text_ids, 101, 0 x A, 102]
    text_token_info = torch.cat([
        _tok([self.audio_prompt_start_id]),
        torch.zeros(ref_len, dtype=torch.int32, device=device),
        _tok([self.audio_prompt_end_id]),
        text_token,
        _tok([self.audio_start_id]),
        torch.zeros(tgt_len, dtype=torch.int32, device=device),
        _tok([self.audio_end_id]),
    ])

    # -- audio feature track --
    # Zero features at every non-audio position (special tokens and text).
    zero_1 = torch.zeros((1,) + feat_shape, dtype=torch.float32, device=device)
    zero_txt = torch.zeros((txt_len,) + feat_shape, dtype=torch.float32, device=device)
    audio_feat_info = torch.cat([
        zero_1, ref_feats, zero_1,  # 103, ref, 104
        zero_txt,                   # text
        zero_1, tgt_feats, zero_1,  # 101, target, 102
    ], dim=0)

    # -- masks --
    # text_mask is 1 on special tokens and text, 0 on audio positions.
    text_mask = _segments(
        (1, 1), (0, ref_len), (1, 1),
        (1, txt_len),
        (1, 1), (0, tgt_len), (1, 1),
    )

    # audio_mask is 1 on reference and target audio positions only.
    audio_mask = _segments(
        (0, 1), (1, ref_len), (0, 1),
        (0, txt_len),
        (0, 1), (1, tgt_len), (0, 1),
    )

    # Only the target audio contributes to the loss; the reference
    # segment, the text and all special tokens are masked out.
    loss_mask = _segments(
        (0, 1 + ref_len + 1),  # ref part: no loss
        (0, txt_len),          # text: no loss
        (0, 1),                # 101: no loss
        (1, tgt_len),          # target audio: LOSS
        (0, 1),                # 102: no loss
    )

    total_len = 1 + ref_len + 1 + txt_len + 1 + tgt_len + 1
    labels = torch.zeros(total_len, dtype=torch.int32, device=device)
    # Stop label at the last target audio position.
    # NOTE(review): with an empty target (tgt_len == 0) index -2 is the
    # 101 position instead -- confirm callers never pass empty targets.
    labels[-2] = 1

    return (
        text_token_info,
        audio_feat_info,
        text_mask,
        audio_mask,
        loss_mask,
        labels,
        ref_duration + tgt_duration,
        txt_len,
    )
|
||||
|
||||
Reference in New Issue
Block a user