update finetuning pipeline and runtime device handling

Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits.

Made-with: Cursor
This commit is contained in:
刘鑫
2026-04-11 11:08:50 +08:00
parent abf01b9bf3
commit e4e049624c
10 changed files with 379 additions and 47 deletions
+1
View File
@@ -3,3 +3,4 @@ __pycache__
voxcpm.egg-info voxcpm.egg-info
.DS_Store .DS_Store
./pretrained_models/ ./pretrained_models/
app_local.py
+8
View File
@@ -209,6 +209,7 @@ def load_model(args) -> VoxCPM:
zipenhancer_model_path=zipenhancer_path, zipenhancer_model_path=zipenhancer_path,
enable_denoiser=not args.no_denoiser, enable_denoiser=not args.no_denoiser,
optimize=not args.no_optimize, optimize=not args.no_optimize,
device=args.device,
lora_config=lora_config, lora_config=lora_config,
lora_weights_path=lora_weights_path, lora_weights_path=lora_weights_path,
) )
@@ -227,6 +228,7 @@ def load_model(args) -> VoxCPM:
cache_dir=args.cache_dir, cache_dir=args.cache_dir,
local_files_only=args.local_files_only, local_files_only=args.local_files_only,
optimize=not args.no_optimize, optimize=not args.no_optimize,
device=args.device,
lora_config=lora_config, lora_config=lora_config,
lora_weights_path=lora_weights_path, lora_weights_path=lora_weights_path,
) )
@@ -403,6 +405,12 @@ def _add_model_args(parser):
default=DEFAULT_HF_MODEL_ID, default=DEFAULT_HF_MODEL_ID,
help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})", help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
) )
parser.add_argument(
"--device",
type=str,
default="auto",
help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)",
)
parser.add_argument( parser.add_argument(
"--cache-dir", type=str, help="Cache directory for Hub downloads" "--cache-dir", type=str, help="Cache directory for Hub downloads"
) )
+21 -2
View File
@@ -17,6 +17,7 @@ class VoxCPM:
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base", zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser: bool = True, enable_denoiser: bool = True,
optimize: bool = True, optimize: bool = True,
device: str | None = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None, lora_weights_path: Optional[str] = None,
): ):
@@ -30,6 +31,9 @@ class VoxCPM:
id or local path. If None, denoiser will not be initialized. id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline. enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging. optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM
will choose automatically (preferring CUDA, then MPS, then CPU).
If set explicitly, that device is used or a clear error is raised.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created. provided without lora_config, a default config will be created.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
@@ -56,10 +60,20 @@ class VoxCPM:
arch = config.get("architecture", "voxcpm").lower() arch = config.get("architecture", "voxcpm").lower()
if arch == "voxcpm2": if arch == "voxcpm2":
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config) self.tts_model = VoxCPM2Model.from_local(
voxcpm_model_path,
optimize=optimize,
device=device,
lora_config=lora_config,
)
print("Loaded VoxCPM2Model", file=sys.stderr) print("Loaded VoxCPM2Model", file=sys.stderr)
elif arch == "voxcpm": elif arch == "voxcpm":
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config) self.tts_model = VoxCPMModel.from_local(
voxcpm_model_path,
optimize=optimize,
device=device,
lora_config=lora_config,
)
print("Loaded VoxCPMModel", file=sys.stderr) print("Loaded VoxCPMModel", file=sys.stderr)
else: else:
raise ValueError(f"Unsupported architecture: {arch}") raise ValueError(f"Unsupported architecture: {arch}")
@@ -94,6 +108,7 @@ class VoxCPM:
cache_dir: str = None, cache_dir: str = None,
local_files_only: bool = False, local_files_only: bool = False,
optimize: bool = True, optimize: bool = True,
device: str | None = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None, lora_weights_path: Optional[str] = None,
**kwargs, **kwargs,
@@ -109,6 +124,9 @@ class VoxCPM:
cache_dir: Custom cache directory for the snapshot. cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt local_files_only: If True, only use local files and do not attempt
to download. to download.
device: Runtime device. Use ``None``/``"auto"`` for automatic
fallback, or an explicit value such as ``"cpu"``, ``"mps"``,
``"cuda"``, or ``"cuda:0"``.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created with provided without lora_config, a default config will be created with
enable_lm=True and enable_dit=True. enable_lm=True and enable_dit=True.
@@ -146,6 +164,7 @@ class VoxCPM:
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None, zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser, enable_denoiser=load_denoiser,
optimize=optimize, optimize=optimize,
device=device,
lora_config=lora_config, lora_config=lora_config,
lora_weights_path=lora_weights_path, lora_weights_path=lora_weights_path,
**kwargs, **kwargs,
+65 -1
View File
@@ -1,4 +1,4 @@
from typing import List from typing import List, Optional
import torch import torch
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@@ -119,3 +119,67 @@ def get_dtype(dtype: str):
return torch.float32 return torch.float32
else: else:
raise ValueError(f"Unsupported dtype: {dtype}") raise ValueError(f"Unsupported dtype: {dtype}")
def _has_mps() -> bool:
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
"""
Choose a runtime device automatically.
Preference order:
- if the preferred device is available, use it
- otherwise fall back to CUDA -> MPS -> CPU
"""
preferred = (preferred_device or "cuda").strip().lower()
if preferred.startswith("cuda") and torch.cuda.is_available():
return preferred
if preferred == "mps" and _has_mps():
return "mps"
if preferred == "cpu":
return "cpu"
if torch.cuda.is_available():
return "cuda"
if _has_mps():
return "mps"
return "cpu"
def resolve_runtime_device(device: Optional[str], configured_device: str = "cuda") -> str:
"""
Resolve the actual runtime device.
Semantics:
- ``device`` is ``None`` or ``"auto"``: use automatic fallback selection
- otherwise: treat it as an explicit user choice and validate availability
"""
explicit = None if device is None else device.strip().lower()
if explicit is None or explicit == "auto":
return auto_select_device(configured_device)
if explicit.startswith("cuda"):
if not torch.cuda.is_available():
raise ValueError(
f"Requested device '{device}', but CUDA is not available. "
"Use device='auto' for automatic fallback."
)
return explicit
if explicit == "mps":
if not _has_mps():
raise ValueError(
"Requested device 'mps', but MPS is not available. "
"Use device='auto' for automatic fallback."
)
return "mps"
if explicit == "cpu":
return "cpu"
raise ValueError(
f"Unsupported device '{device}'. Supported values are 'auto', 'cpu', 'mps', "
"'cuda', or indexed CUDA devices like 'cuda:0'."
)
+13 -9
View File
@@ -44,7 +44,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
class VoxCPMEncoderConfig(BaseModel): class VoxCPMEncoderConfig(BaseModel):
@@ -109,18 +109,15 @@ class VoxCPMModel(nn.Module):
tokenizer: LlamaTokenizerFast, tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE, audio_vae: AudioVAE,
lora_config: LoRAConfig = None, lora_config: LoRAConfig = None,
device: str | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.feat_dim = config.feat_dim self.feat_dim = config.feat_dim
self.patch_size = config.patch_size self.patch_size = config.patch_size
self.device = config.device self.device = resolve_runtime_device(device, config.device)
if not torch.cuda.is_available(): self.config.device = self.device
if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM # Text-Semantic LM
@@ -847,7 +844,14 @@ class VoxCPMModel(nn.Module):
yield feat_pred, pred_feat_seq.squeeze(0).cpu() yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod @classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None): def from_local(
cls,
path: str,
optimize: bool = True,
training: bool = False,
device: str | None = None,
lora_config: LoRAConfig = None,
):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read()) config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path) tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, "audio_vae_config", None) audio_vae_config = getattr(config, "audio_vae_config", None)
@@ -870,7 +874,7 @@ class VoxCPMModel(nn.Module):
raise FileNotFoundError( raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}" f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
) )
model = cls(config, tokenizer, audio_vae, lora_config) model = cls(config, tokenizer, audio_vae, lora_config, device=device)
if not training: if not training:
lm_dtype = get_dtype(model.config.dtype) lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype) model = model.to(lm_dtype)
+13 -9
View File
@@ -45,7 +45,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2 from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
from ..modules.locenc import VoxCPMLocEnc from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
# A simple function to trim audio silence using VAD, not used default # A simple function to trim audio silence using VAD, not used default
@@ -151,18 +151,15 @@ class VoxCPM2Model(nn.Module):
tokenizer: LlamaTokenizerFast, tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAEV2, audio_vae: AudioVAEV2,
lora_config: LoRAConfig = None, lora_config: LoRAConfig = None,
device: str | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.feat_dim = config.feat_dim self.feat_dim = config.feat_dim
self.patch_size = config.patch_size self.patch_size = config.patch_size
self.device = config.device self.device = resolve_runtime_device(device, config.device)
if not torch.cuda.is_available(): self.config.device = self.device
if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM # Text-Semantic LM
@@ -1098,7 +1095,14 @@ class VoxCPM2Model(nn.Module):
yield feat_pred, generated_feat, context_len yield feat_pred, generated_feat, context_len
@classmethod @classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None): def from_local(
cls,
path: str,
optimize: bool = True,
training: bool = False,
device: str | None = None,
lora_config: LoRAConfig = None,
):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read()) config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path) tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, "audio_vae_config", None) audio_vae_config = getattr(config, "audio_vae_config", None)
@@ -1121,7 +1125,7 @@ class VoxCPM2Model(nn.Module):
raise FileNotFoundError( raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}" f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
) )
model = cls(config, tokenizer, audio_vae, lora_config) model = cls(config, tokenizer, audio_vae, lora_config, device=device)
if not training: if not training:
lm_dtype = get_dtype(model.config.dtype) lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype) model = model.to(lm_dtype)
+53 -11
View File
@@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text" DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio" DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_REF_AUDIO_COLUMN = "ref_audio"
DEFAULT_ID_COLUMN = "dataset_id" DEFAULT_ID_COLUMN = "dataset_id"
@@ -21,6 +22,7 @@ def load_audio_text_datasets(
val_manifest: str = "", val_manifest: str = "",
text_column: str = DEFAULT_TEXT_COLUMN, text_column: str = DEFAULT_TEXT_COLUMN,
audio_column: str = DEFAULT_AUDIO_COLUMN, audio_column: str = DEFAULT_AUDIO_COLUMN,
ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN,
dataset_id_column: str = DEFAULT_ID_COLUMN, dataset_id_column: str = DEFAULT_ID_COLUMN,
sample_rate: int = 16_000, sample_rate: int = 16_000,
num_proc: int = 1, num_proc: int = 1,
@@ -34,14 +36,19 @@ def load_audio_text_datasets(
def prepare(ds: Dataset) -> Dataset: def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names: if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.") raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate)) ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
if audio_column != DEFAULT_AUDIO_COLUMN: if audio_column != DEFAULT_AUDIO_COLUMN:
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN) ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
if text_column != DEFAULT_TEXT_COLUMN: if text_column != DEFAULT_TEXT_COLUMN:
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN) ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
# ref_audio is optional — cast to Audio if the column exists
ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN
if ref_col in ds.column_names:
ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate))
if ref_col != DEFAULT_REF_AUDIO_COLUMN:
ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN)
if dataset_id_column and dataset_id_column in ds.column_names: if dataset_id_column and dataset_id_column in ds.column_names:
if dataset_id_column != DEFAULT_ID_COLUMN: if dataset_id_column != DEFAULT_ID_COLUMN:
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN) ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
@@ -67,11 +74,11 @@ def compute_sample_lengths(
- 音频长度: - 音频长度:
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
t_seq = ceil(t_vae / patch_size) t_seq = ceil(t_vae / patch_size)
- 序列总长约为: text_len + t_seq + 2 - 无 ref_audio: text_len + t_seq + 2
- 有 ref_audio: text_len + t_seq + ref_seq + 4
Optimized: Use batch column access instead of iterating item by item. Optimized: Use batch column access instead of iterating item by item.
""" """
# Batch access columns - much faster than per-item access
text_ids_list = ds["text_ids"] text_ids_list = ds["text_ids"]
text_lens = [len(t) for t in text_ids_list] text_lens = [len(t) for t in text_ids_list]
@@ -79,18 +86,35 @@ def compute_sample_lengths(
if has_duration: if has_duration:
durations = ds["duration"] durations = ds["duration"]
else: else:
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
durations = [] durations = []
for i in range(len(ds)): for i in range(len(ds)):
audio = ds[i][DEFAULT_AUDIO_COLUMN] audio = ds[i][DEFAULT_AUDIO_COLUMN]
durations.append(len(audio["array"]) / float(audio["sampling_rate"])) durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
# Vectorized length computation has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names
if has_ref_audio:
ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None
lengths = [] lengths = []
for text_len, duration in zip(text_lens, durations): for i, (text_len, duration) in enumerate(zip(text_lens, durations)):
t_vae = math.ceil(float(duration) * audio_vae_fps) t_vae = math.ceil(float(duration) * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size) t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2
ref_seq = 0
if has_ref_audio:
# Estimate ref_audio length; ref_audio is None for samples without it
if ref_duration_col:
ref_dur = ds[i].get(ref_duration_col)
else:
ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN)
ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None
if ref_dur is not None and float(ref_dur) > 0:
ref_vae = math.ceil(float(ref_dur) * audio_vae_fps)
ref_seq = math.ceil(ref_vae / patch_size)
# +2 for 101/102; +2 more for 103/104 when ref_audio present
overhead = 4 if ref_seq > 0 else 2
total_len = text_len + t_seq + ref_seq + overhead
lengths.append(total_len) lengths.append(total_len)
return lengths return lengths
@@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset):
PyTorch-friendly samples. PyTorch-friendly samples.
""" """
_SENTINEL = [-100.0]
def __init__(self, dataset: Dataset): def __init__(self, dataset: Dataset):
self.dataset = dataset self.dataset = dataset
self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names
def __len__(self): def __len__(self):
return len(self.dataset) return len(self.dataset)
@@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset):
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
item = self.dataset[idx] item = self.dataset[idx]
audio = item[DEFAULT_AUDIO_COLUMN] audio = item[DEFAULT_AUDIO_COLUMN]
return { sample = {
"text_ids": item["text_ids"], "text_ids": item["text_ids"],
"audio_array": audio["array"], "audio_array": audio["array"],
"audio_sampling_rate": audio["sampling_rate"], "audio_sampling_rate": audio["sampling_rate"],
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0), "dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
"is_prompt": item.get("is_prompt", False), "is_prompt": item.get("is_prompt", False),
} }
if self.has_ref_audio:
ref = item.get(DEFAULT_REF_AUDIO_COLUMN)
sample["ref_audio_array"] = ref["array"] if ref else self._SENTINEL
return sample
@staticmethod @staticmethod
def pad_sequences(seqs: List[torch.Tensor], pad_value: float): def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
@@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset):
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0) audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32) task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
return { result = {
"text_tokens": text_padded, "text_tokens": text_padded,
"audio_tokens": audio_padded, "audio_tokens": audio_padded,
"task_ids": task_ids, "task_ids": task_ids,
@@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset):
"is_prompts": is_prompts, "is_prompts": is_prompts,
} }
if "ref_audio_array" in batch[0]:
ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch]
result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0)
return result
class BatchProcessor: class BatchProcessor:
""" """
@@ -184,12 +221,17 @@ class BatchProcessor:
task_ids = batch["task_ids"].to(self.device) task_ids = batch["task_ids"].to(self.device)
dataset_ids = batch["dataset_ids"].to(self.device) dataset_ids = batch["dataset_ids"].to(self.device)
ref_audio_tokens = None
if "ref_audio_tokens" in batch:
ref_audio_tokens = batch["ref_audio_tokens"].to(self.device)
packed = self.packer( packed = self.packer(
audio_tokens=audio_tokens, audio_tokens=audio_tokens,
text_tokens=text_tokens, text_tokens=text_tokens,
task_ids=task_ids, task_ids=task_ids,
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
is_prompts=batch["is_prompts"], is_prompts=batch["is_prompts"],
ref_audio_tokens=ref_audio_tokens,
) )
return packed return packed
+114 -4
View File
@@ -1,4 +1,4 @@
from typing import Dict, List from typing import Dict, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -14,7 +14,6 @@ class AudioFeatureProcessingPacker:
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module): def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101 self.audio_start_id = 101
self.audio_end_id = 102 self.audio_end_id = 102
# unused now
self.audio_prompt_start_id = 103 self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104 self.audio_prompt_end_id = 104
self.text_eos_token_id = 2 self.text_eos_token_id = 2
@@ -78,11 +77,16 @@ class AudioFeatureProcessingPacker:
task_ids: torch.Tensor, task_ids: torch.Tensor,
dataset_ids: torch.Tensor, dataset_ids: torch.Tensor,
is_prompts: List[bool], is_prompts: List[bool],
ref_audio_tokens: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
""" """
Padding-based batching: each sample in the input batch is processed Padding-based batching: each sample in the input batch is processed
independently and then padded to a common length (capped by ``max_len``). independently and then padded to a common length (capped by ``max_len``).
The result tensors all have shape [B, T, ...]. The result tensors all have shape [B, T, ...].
If ``ref_audio_tokens`` is provided (same batch dim as ``audio_tokens``),
samples whose unpadded ref_audio length > 0 will be processed with the
reference-audio path (tokens 103/104 prepended, loss only on target audio).
""" """
device = audio_tokens.device device = audio_tokens.device
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1 max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
@@ -101,13 +105,33 @@ class AudioFeatureProcessingPacker:
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device) audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device) text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip( ref_iter = ref_audio_tokens if ref_audio_tokens is not None else [None] * audio_tokens.size(0)
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
for audio_token, text_token, task_id, dataset_idx, is_prompt, ref_token in zip(
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts, ref_iter
): ):
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32) unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
unpad_text_token = self.unpad_text_tokens(text_token) unpad_text_token = self.unpad_text_tokens(text_token)
usage = self.id_to_task[task_id] usage = self.id_to_task[task_id]
has_ref = False
if ref_token is not None:
unpad_ref_token = self.unpad_audio_tokens(ref_token).to(torch.float32)
if unpad_ref_token.numel() > 0:
has_ref = True
if has_ref:
(
packed_text,
audio_feat,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
) = self.process_tts_data_with_ref(unpad_ref_token, unpad_audio_token, unpad_text_token)
else:
( (
packed_text, packed_text,
audio_feat, audio_feat,
@@ -294,3 +318,89 @@ class AudioFeatureProcessingPacker:
audio_duration, audio_duration,
text_token_count, text_token_count,
) )
def process_tts_data_with_ref(
self,
ref_audio_token: torch.Tensor,
target_audio_token: torch.Tensor,
text_token: torch.Tensor,
):
"""
Build a training sequence with reference audio prepended:
[103, ref_feats, 104, text, 101, target_feats, 102]
Loss is computed only on the target audio segment.
"""
device = text_token.device
txt_len = len(text_token)
ref_feats, ref_duration = self.extract_audio_feats(ref_audio_token)
ref_feats = ref_feats.squeeze(0) # [R, P, D]
ref_len = ref_feats.shape[0]
tgt_feats, tgt_duration = self.extract_audio_feats(target_audio_token)
tgt_feats = tgt_feats.squeeze(0) # [A, P, D]
tgt_len = tgt_feats.shape[0]
feat_shape = (self.patch_size, ref_feats.size(-1))
def _tok(ids):
return torch.tensor(ids, dtype=torch.int32, device=device)
# -- text token track --
# [103, 0×R, 104, text_ids, 101, 0×A, 102]
text_token_info = torch.cat([
_tok([self.audio_prompt_start_id]),
torch.zeros(ref_len, dtype=torch.int32, device=device),
_tok([self.audio_prompt_end_id]),
text_token,
_tok([self.audio_start_id]),
torch.zeros(tgt_len, dtype=torch.int32, device=device),
_tok([self.audio_end_id]),
])
# -- audio feature track --
zero_1 = torch.zeros((1,) + feat_shape, dtype=torch.float32, device=device)
zero_txt = torch.zeros((txt_len,) + feat_shape, dtype=torch.float32, device=device)
audio_feat_info = torch.cat([
zero_1, ref_feats, zero_1, # 103, ref, 104
zero_txt, # text
zero_1, tgt_feats, zero_1, # 101, target, 102
], dim=0)
# -- masks --
text_mask = torch.cat([
torch.ones(1), torch.zeros(ref_len), torch.ones(1),
torch.ones(txt_len),
torch.ones(1), torch.zeros(tgt_len), torch.ones(1),
]).to(torch.int32).to(device)
audio_mask = torch.cat([
torch.zeros(1), torch.ones(ref_len), torch.zeros(1),
torch.zeros(txt_len),
torch.zeros(1), torch.ones(tgt_len), torch.zeros(1),
]).to(torch.int32).to(device)
loss_mask = torch.cat([
torch.zeros(1 + ref_len + 1), # ref part: no loss
torch.zeros(txt_len), # text: no loss
torch.zeros(1), # 101: no loss
torch.ones(tgt_len), # target audio: LOSS
torch.zeros(1), # 102: no loss
]).to(torch.int32).to(device)
total_len = 1 + ref_len + 1 + txt_len + 1 + tgt_len + 1
labels = torch.zeros(total_len, dtype=torch.int32, device=device)
labels[-2] = 1 # stop label at last target audio position
return (
text_token_info,
audio_feat_info,
text_mask,
audio_mask,
loss_mask,
labels,
ref_duration + tgt_duration,
txt_len,
)
+31
View File
@@ -58,6 +58,7 @@ def test_parser_defaults_to_voxcpm2():
parser = cli._build_parser() parser = cli._build_parser()
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"]) args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
assert args.hf_model_id == "openbmb/VoxCPM2" assert args.hf_model_id == "openbmb/VoxCPM2"
assert args.device == "auto"
assert args.no_optimize is False assert args.no_optimize is False
@@ -85,6 +86,7 @@ def test_load_model_respects_no_optimize_for_local_model(monkeypatch):
cli.load_model(args) cli.load_model(args)
assert calls["kwargs"]["device"] == "auto"
assert calls["kwargs"]["optimize"] is False assert calls["kwargs"]["optimize"] is False
@@ -110,6 +112,7 @@ def test_load_model_defaults_optimize_for_hf(monkeypatch):
cli.load_model(args) cli.load_model(args)
assert calls["kwargs"]["device"] == "auto"
assert calls["kwargs"]["optimize"] is True assert calls["kwargs"]["optimize"] is True
@@ -136,9 +139,37 @@ def test_load_model_respects_no_optimize_for_hf(monkeypatch):
cli.load_model(args) cli.load_model(args)
assert calls["kwargs"]["device"] == "auto"
assert calls["kwargs"]["optimize"] is False assert calls["kwargs"]["optimize"] is False
def test_load_model_passes_explicit_device_to_hf(monkeypatch):
calls = {}
class FakeVoxCPM:
@classmethod
def from_pretrained(cls, **kwargs):
calls["kwargs"] = kwargs
return DummyModel()
monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)
args = cli._build_parser().parse_args(
[
"design",
"--text",
"hello",
"--output",
"out.wav",
"--device",
"mps",
]
)
cli.load_model(args)
assert calls["kwargs"]["device"] == "mps"
def test_design_subcommand_applies_control(monkeypatch, tmp_path): def test_design_subcommand_applies_control(monkeypatch, tmp_path):
dummy_model = DummyModel() dummy_model = DummyModel()
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model) monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
+49
View File
@@ -0,0 +1,49 @@
from __future__ import annotations
import importlib.util
import sys
import types
from pathlib import Path
import pytest
ROOT = Path(__file__).resolve().parents[1]
UTILS_PATH = ROOT / "src" / "voxcpm" / "model" / "utils.py"
transformers_stub = types.ModuleType("transformers")
transformers_stub.PreTrainedTokenizer = object
sys.modules.setdefault("transformers", transformers_stub)
spec = importlib.util.spec_from_file_location("voxcpm.model.utils", UTILS_PATH)
utils = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(utils)
def test_resolve_runtime_device_auto_falls_back_to_cpu(monkeypatch):
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
monkeypatch.setattr(utils, "_has_mps", lambda: False)
assert utils.resolve_runtime_device(None, "cuda") == "cpu"
def test_resolve_runtime_device_auto_uses_mps_when_available(monkeypatch):
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
monkeypatch.setattr(utils, "_has_mps", lambda: True)
assert utils.resolve_runtime_device("auto", "cuda") == "mps"
def test_resolve_runtime_device_respects_explicit_cpu(monkeypatch):
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: True)
monkeypatch.setattr(utils, "_has_mps", lambda: True)
assert utils.resolve_runtime_device("cpu", "cuda") == "cpu"
def test_resolve_runtime_device_rejects_unavailable_explicit_cuda(monkeypatch):
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
monkeypatch.setattr(utils, "_has_mps", lambda: True)
with pytest.raises(ValueError, match="CUDA is not available"):
utils.resolve_runtime_device("cuda:0", "cuda")