From e4e049624c4695ed682da791bbc2155be7a8f008 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E9=91=AB?= Date: Sat, 11 Apr 2026 11:08:50 +0800 Subject: [PATCH] update finetuning pipeline and runtime device handling Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits. Made-with: Cursor --- .gitignore | 3 +- src/voxcpm/cli.py | 8 ++ src/voxcpm/core.py | 23 +++++- src/voxcpm/model/utils.py | 66 +++++++++++++++- src/voxcpm/model/voxcpm.py | 22 +++--- src/voxcpm/model/voxcpm2.py | 22 +++--- src/voxcpm/training/data.py | 64 ++++++++++++--- src/voxcpm/training/packers.py | 138 +++++++++++++++++++++++++++++---- tests/test_cli.py | 31 ++++++++ tests/test_model_utils.py | 49 ++++++++++++ 10 files changed, 379 insertions(+), 47 deletions(-) create mode 100644 tests/test_model_utils.py diff --git a/.gitignore b/.gitignore index 2d6ef51..d397292 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ launch.json __pycache__ voxcpm.egg-info .DS_Store -./pretrained_models/ \ No newline at end of file +./pretrained_models/ +app_local.py \ No newline at end of file diff --git a/src/voxcpm/cli.py b/src/voxcpm/cli.py index 60f8264..8f2e41a 100644 --- a/src/voxcpm/cli.py +++ b/src/voxcpm/cli.py @@ -209,6 +209,7 @@ def load_model(args) -> VoxCPM: zipenhancer_model_path=zipenhancer_path, enable_denoiser=not args.no_denoiser, optimize=not args.no_optimize, + device=args.device, lora_config=lora_config, lora_weights_path=lora_weights_path, ) @@ -227,6 +228,7 @@ def load_model(args) -> VoxCPM: cache_dir=args.cache_dir, local_files_only=args.local_files_only, optimize=not args.no_optimize, + device=args.device, lora_config=lora_config, lora_weights_path=lora_weights_path, ) @@ -403,6 +405,12 @@ def _add_model_args(parser): default=DEFAULT_HF_MODEL_ID, help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})", ) + parser.add_argument( + "--device", + type=str, + default="auto", + help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)", + ) parser.add_argument( "--cache-dir", type=str, help="Cache directory for Hub downloads" ) diff --git a/src/voxcpm/core.py b/src/voxcpm/core.py index 2b95928..9fd7015 100644 --- a/src/voxcpm/core.py +++ b/src/voxcpm/core.py @@ -17,6 +17,7 @@ class VoxCPM: zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base", enable_denoiser: bool = True, optimize: bool = True, + device: str | None = None, lora_config: Optional[LoRAConfig] = None, lora_weights_path: Optional[str] = None, ): @@ -30,6 +31,9 @@ class VoxCPM: id or local path. If None, denoiser will not be initialized. enable_denoiser: Whether to initialize the denoiser pipeline. optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging. + device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM + will choose automatically (preferring CUDA, then MPS, then CPU). + If set explicitly, that device is used or a clear error is raised. lora_config: LoRA configuration for fine-tuning. If lora_weights_path is provided without lora_config, a default config will be created. lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory @@ -56,10 +60,20 @@ class VoxCPM: arch = config.get("architecture", "voxcpm").lower() if arch == "voxcpm2": - self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config) + self.tts_model = VoxCPM2Model.from_local( + voxcpm_model_path, + optimize=optimize, + device=device, + lora_config=lora_config, + ) print("Loaded VoxCPM2Model", file=sys.stderr) elif arch == "voxcpm": - self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config) + self.tts_model = VoxCPMModel.from_local( + voxcpm_model_path, + optimize=optimize, + device=device, + lora_config=lora_config, + ) print("Loaded VoxCPMModel", file=sys.stderr) else: raise ValueError(f"Unsupported architecture: {arch}") @@ -94,6 +108,7 @@ class VoxCPM: cache_dir: str = None, local_files_only: bool = False, optimize: bool = True, + device: str | None = None, lora_config: Optional[LoRAConfig] = None, lora_weights_path: Optional[str] = None, **kwargs, @@ -109,6 +124,9 @@ class VoxCPM: cache_dir: Custom cache directory for the snapshot. local_files_only: If True, only use local files and do not attempt to download. + device: Runtime device. Use ``None``/``"auto"`` for automatic + fallback, or an explicit value such as ``"cpu"``, ``"mps"``, + ``"cuda"``, or ``"cuda:0"``. lora_config: LoRA configuration for fine-tuning. If lora_weights_path is provided without lora_config, a default config will be created with enable_lm=True and enable_dit=True. @@ -146,6 +164,7 @@ class VoxCPM: zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None, enable_denoiser=load_denoiser, optimize=optimize, + device=device, lora_config=lora_config, lora_weights_path=lora_weights_path, **kwargs, diff --git a/src/voxcpm/model/utils.py b/src/voxcpm/model/utils.py index 98ebe86..904c54d 100644 --- a/src/voxcpm/model/utils.py +++ b/src/voxcpm/model/utils.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import torch from transformers import PreTrainedTokenizer @@ -119,3 +119,67 @@ def get_dtype(dtype: str): return torch.float32 else: raise ValueError(f"Unsupported dtype: {dtype}") + + +def _has_mps() -> bool: + return hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + + +def auto_select_device(preferred_device: Optional[str] = "cuda") -> str: + """ + Choose a runtime device automatically. + + Preference order: + - if the preferred device is available, use it + - otherwise fall back to CUDA -> MPS -> CPU + """ + preferred = (preferred_device or "cuda").strip().lower() + + if preferred.startswith("cuda") and torch.cuda.is_available(): + return preferred + if preferred == "mps" and _has_mps(): + return "mps" + if preferred == "cpu": + return "cpu" + + if torch.cuda.is_available(): + return "cuda" + if _has_mps(): + return "mps" + return "cpu" + + +def resolve_runtime_device(device: Optional[str], configured_device: str = "cuda") -> str: + """ + Resolve the actual runtime device. + + Semantics: + - ``device`` is ``None`` or ``"auto"``: use automatic fallback selection + - otherwise: treat it as an explicit user choice and validate availability + """ + explicit = None if device is None else device.strip().lower() + + if explicit is None or explicit == "auto": + return auto_select_device(configured_device) + + if explicit.startswith("cuda"): + if not torch.cuda.is_available(): + raise ValueError( + f"Requested device '{device}', but CUDA is not available. " + "Use device='auto' for automatic fallback." + ) + return explicit + if explicit == "mps": + if not _has_mps(): + raise ValueError( + "Requested device 'mps', but MPS is not available. " + "Use device='auto' for automatic fallback." + ) + return "mps" + if explicit == "cpu": + return "cpu" + + raise ValueError( + f"Unsupported device '{device}'. Supported values are 'auto', 'cpu', 'mps', " + "'cuda', or indexed CUDA devices like 'cuda:0'." + ) diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index b75a847..662a118 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -44,7 +44,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT from ..modules.locenc import VoxCPMLocEnc from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel -from .utils import get_dtype, mask_multichar_chinese_tokens +from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device class VoxCPMEncoderConfig(BaseModel): @@ -109,18 +109,15 @@ class VoxCPMModel(nn.Module): tokenizer: LlamaTokenizerFast, audio_vae: AudioVAE, lora_config: LoRAConfig = None, + device: str | None = None, ): super().__init__() self.config = config self.lora_config = lora_config self.feat_dim = config.feat_dim self.patch_size = config.patch_size - self.device = config.device - if not torch.cuda.is_available(): - if torch.backends.mps.is_available(): - self.device = "mps" - else: - self.device = "cpu" + self.device = resolve_runtime_device(device, config.device) + self.config.device = self.device print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) # Text-Semantic LM @@ -847,7 +844,14 @@ class VoxCPMModel(nn.Module): yield feat_pred, pred_feat_seq.squeeze(0).cpu() @classmethod - def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None): + def from_local( + cls, + path: str, + optimize: bool = True, + training: bool = False, + device: str | None = None, + lora_config: LoRAConfig = None, + ): config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read()) tokenizer = LlamaTokenizerFast.from_pretrained(path) audio_vae_config = getattr(config, "audio_vae_config", None) @@ -870,7 +874,7 @@ class VoxCPMModel(nn.Module): raise FileNotFoundError( f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}" ) - model = cls(config, tokenizer, audio_vae, lora_config) + model = cls(config, tokenizer, audio_vae, lora_config, device=device) if not training: lm_dtype = get_dtype(model.config.dtype) model = model.to(lm_dtype) diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index 376567a..268583b 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -45,7 +45,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2 from ..modules.locenc import VoxCPMLocEnc from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel -from .utils import get_dtype, mask_multichar_chinese_tokens +from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device # A simple function to trim audio silence using VAD, not used default @@ -151,18 +151,15 @@ class VoxCPM2Model(nn.Module): tokenizer: LlamaTokenizerFast, audio_vae: AudioVAEV2, lora_config: LoRAConfig = None, + device: str | None = None, ): super().__init__() self.config = config self.lora_config = lora_config self.feat_dim = config.feat_dim self.patch_size = config.patch_size - self.device = config.device - if not torch.cuda.is_available(): - if torch.backends.mps.is_available(): - self.device = "mps" - else: - self.device = "cpu" + self.device = resolve_runtime_device(device, config.device) + self.config.device = self.device print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) # Text-Semantic LM @@ -1098,7 +1095,14 @@ class VoxCPM2Model(nn.Module): yield feat_pred, generated_feat, context_len @classmethod - def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None): + def from_local( + cls, + path: str, + optimize: bool = True, + training: bool = False, + device: str | None = None, + lora_config: LoRAConfig = None, + ): config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read()) tokenizer = LlamaTokenizerFast.from_pretrained(path) audio_vae_config = getattr(config, "audio_vae_config", None) @@ -1121,7 +1125,7 @@ class VoxCPM2Model(nn.Module): raise FileNotFoundError( f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}" ) - model = cls(config, tokenizer, audio_vae, lora_config) + model = cls(config, tokenizer, audio_vae, lora_config, device=device) if not training: lm_dtype = get_dtype(model.config.dtype) model = model.to(lm_dtype) diff --git a/src/voxcpm/training/data.py b/src/voxcpm/training/data.py index ec8a2cf..7664650 100644 --- a/src/voxcpm/training/data.py +++ b/src/voxcpm/training/data.py @@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker DEFAULT_TEXT_COLUMN = "text" DEFAULT_AUDIO_COLUMN = "audio" +DEFAULT_REF_AUDIO_COLUMN = "ref_audio" DEFAULT_ID_COLUMN = "dataset_id" @@ -21,6 +22,7 @@ def load_audio_text_datasets( val_manifest: str = "", text_column: str = DEFAULT_TEXT_COLUMN, audio_column: str = DEFAULT_AUDIO_COLUMN, + ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN, dataset_id_column: str = DEFAULT_ID_COLUMN, sample_rate: int = 16_000, num_proc: int = 1, @@ -34,14 +36,19 @@ def load_audio_text_datasets( def prepare(ds: Dataset) -> Dataset: if audio_column not in ds.column_names: raise ValueError(f"Expected '{audio_column}' column in manifest.") - # We cast to Audio to ensure proper handling during training, - # but for length calculation we might need raw path or duration if available. - # HF datasets usually don't compute duration automatically for 'Audio' column. ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate)) if audio_column != DEFAULT_AUDIO_COLUMN: ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN) if text_column != DEFAULT_TEXT_COLUMN: ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN) + + # ref_audio is optional — cast to Audio if the column exists + ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN + if ref_col in ds.column_names: + ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate)) + if ref_col != DEFAULT_REF_AUDIO_COLUMN: + ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN) + if dataset_id_column and dataset_id_column in ds.column_names: if dataset_id_column != DEFAULT_ID_COLUMN: ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN) @@ -67,11 +74,11 @@ def compute_sample_lengths( - 音频长度: duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae t_seq = ceil(t_vae / patch_size) - - 序列总长约为: text_len + t_seq + 2 + - 无 ref_audio: text_len + t_seq + 2 + - 有 ref_audio: text_len + t_seq + ref_seq + 4 Optimized: Use batch column access instead of iterating item by item. """ - # Batch access columns - much faster than per-item access text_ids_list = ds["text_ids"] text_lens = [len(t) for t in text_ids_list] @@ -79,18 +86,35 @@ def compute_sample_lengths( if has_duration: durations = ds["duration"] else: - # Fallback: need to compute from audio (slow, but unavoidable without duration column) durations = [] for i in range(len(ds)): audio = ds[i][DEFAULT_AUDIO_COLUMN] durations.append(len(audio["array"]) / float(audio["sampling_rate"])) - # Vectorized length computation + has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names + if has_ref_audio: + ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None + lengths = [] - for text_len, duration in zip(text_lens, durations): + for i, (text_len, duration) in enumerate(zip(text_lens, durations)): t_vae = math.ceil(float(duration) * audio_vae_fps) t_seq = math.ceil(t_vae / patch_size) - total_len = text_len + t_seq + 2 + + ref_seq = 0 + if has_ref_audio: + # Estimate ref_audio length; ref_audio is None for samples without it + if ref_duration_col: + ref_dur = ds[i].get(ref_duration_col) + else: + ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN) + ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None + if ref_dur is not None and float(ref_dur) > 0: + ref_vae = math.ceil(float(ref_dur) * audio_vae_fps) + ref_seq = math.ceil(ref_vae / patch_size) + + # +2 for 101/102; +2 more for 103/104 when ref_audio present + overhead = 4 if ref_seq > 0 else 2 + total_len = text_len + t_seq + ref_seq + overhead lengths.append(total_len) return lengths @@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset): PyTorch-friendly samples. """ + _SENTINEL = [-100.0] + def __init__(self, dataset: Dataset): self.dataset = dataset + self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names def __len__(self): return len(self.dataset) @@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset): def __getitem__(self, idx: int): item = self.dataset[idx] audio = item[DEFAULT_AUDIO_COLUMN] - return { + sample = { "text_ids": item["text_ids"], "audio_array": audio["array"], "audio_sampling_rate": audio["sampling_rate"], "dataset_id": item.get(DEFAULT_ID_COLUMN, 0), "is_prompt": item.get("is_prompt", False), } + if self.has_ref_audio: + ref = item.get(DEFAULT_REF_AUDIO_COLUMN) + sample["ref_audio_array"] = ref["array"] if ref else self._SENTINEL + return sample @staticmethod def pad_sequences(seqs: List[torch.Tensor], pad_value: float): @@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset): audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0) task_ids = torch.ones(text_padded.size(0), dtype=torch.int32) - return { + result = { "text_tokens": text_padded, "audio_tokens": audio_padded, "task_ids": task_ids, @@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset): "is_prompts": is_prompts, } + if "ref_audio_array" in batch[0]: + ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch] + result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0) + + return result + class BatchProcessor: """ @@ -184,12 +221,17 @@ class BatchProcessor: task_ids = batch["task_ids"].to(self.device) dataset_ids = batch["dataset_ids"].to(self.device) + ref_audio_tokens = None + if "ref_audio_tokens" in batch: + ref_audio_tokens = batch["ref_audio_tokens"].to(self.device) + packed = self.packer( audio_tokens=audio_tokens, text_tokens=text_tokens, task_ids=task_ids, dataset_ids=dataset_ids, is_prompts=batch["is_prompts"], + ref_audio_tokens=ref_audio_tokens, ) return packed diff --git a/src/voxcpm/training/packers.py b/src/voxcpm/training/packers.py index a6ee46f..cfcc193 100644 --- a/src/voxcpm/training/packers.py +++ b/src/voxcpm/training/packers.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Optional import torch import torch.nn as nn @@ -14,7 +14,6 @@ class AudioFeatureProcessingPacker: def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module): self.audio_start_id = 101 self.audio_end_id = 102 - # unused now self.audio_prompt_start_id = 103 self.audio_prompt_end_id = 104 self.text_eos_token_id = 2 @@ -78,11 +77,16 @@ class AudioFeatureProcessingPacker: task_ids: torch.Tensor, dataset_ids: torch.Tensor, is_prompts: List[bool], + ref_audio_tokens: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """ Padding-based batching: each sample in the input batch is processed independently and then padded to a common length (capped by ``max_len``). The result tensors all have shape [B, T, ...]. + + If ``ref_audio_tokens`` is provided (same batch dim as ``audio_tokens``), + samples whose unpadded ref_audio length > 0 will be processed with the + reference-audio path (tokens 103/104 prepended, loss only on target audio). """ device = audio_tokens.device max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1 @@ -101,23 +105,43 @@ class AudioFeatureProcessingPacker: audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device) text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device) - for audio_token, text_token, task_id, dataset_idx, is_prompt in zip( - audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts + ref_iter = ref_audio_tokens if ref_audio_tokens is not None else [None] * audio_tokens.size(0) + + for audio_token, text_token, task_id, dataset_idx, is_prompt, ref_token in zip( + audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts, ref_iter ): unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32) unpad_text_token = self.unpad_text_tokens(text_token) usage = self.id_to_task[task_id] - ( - packed_text, - audio_feat, - text_mask, - audio_mask, - loss_mask, - labels, - audio_duration, - text_token_count, - ) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt) + has_ref = False + if ref_token is not None: + unpad_ref_token = self.unpad_audio_tokens(ref_token).to(torch.float32) + if unpad_ref_token.numel() > 0: + has_ref = True + + if has_ref: + ( + packed_text, + audio_feat, + text_mask, + audio_mask, + loss_mask, + labels, + audio_duration, + text_token_count, + ) = self.process_tts_data_with_ref(unpad_ref_token, unpad_audio_token, unpad_text_token) + else: + ( + packed_text, + audio_feat, + text_mask, + audio_mask, + loss_mask, + labels, + audio_duration, + text_token_count, + ) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt) audio_duration_consumed[dataset_idx] += audio_duration text_token_consumed[dataset_idx] += text_token_count @@ -294,3 +318,89 @@ class AudioFeatureProcessingPacker: audio_duration, text_token_count, ) + + def process_tts_data_with_ref( + self, + ref_audio_token: torch.Tensor, + target_audio_token: torch.Tensor, + text_token: torch.Tensor, + ): + """ + Build a training sequence with reference audio prepended: + + [103, ref_feats, 104, text, 101, target_feats, 102] + + Loss is computed only on the target audio segment. + """ + device = text_token.device + txt_len = len(text_token) + + ref_feats, ref_duration = self.extract_audio_feats(ref_audio_token) + ref_feats = ref_feats.squeeze(0) # [R, P, D] + ref_len = ref_feats.shape[0] + + tgt_feats, tgt_duration = self.extract_audio_feats(target_audio_token) + tgt_feats = tgt_feats.squeeze(0) # [A, P, D] + tgt_len = tgt_feats.shape[0] + + feat_shape = (self.patch_size, ref_feats.size(-1)) + + def _tok(ids): + return torch.tensor(ids, dtype=torch.int32, device=device) + + # -- text token track -- + # [103, 0×R, 104, text_ids, 101, 0×A, 102] + text_token_info = torch.cat([ + _tok([self.audio_prompt_start_id]), + torch.zeros(ref_len, dtype=torch.int32, device=device), + _tok([self.audio_prompt_end_id]), + text_token, + _tok([self.audio_start_id]), + torch.zeros(tgt_len, dtype=torch.int32, device=device), + _tok([self.audio_end_id]), + ]) + + # -- audio feature track -- + zero_1 = torch.zeros((1,) + feat_shape, dtype=torch.float32, device=device) + zero_txt = torch.zeros((txt_len,) + feat_shape, dtype=torch.float32, device=device) + audio_feat_info = torch.cat([ + zero_1, ref_feats, zero_1, # 103, ref, 104 + zero_txt, # text + zero_1, tgt_feats, zero_1, # 101, target, 102 + ], dim=0) + + # -- masks -- + text_mask = torch.cat([ + torch.ones(1), torch.zeros(ref_len), torch.ones(1), + torch.ones(txt_len), + torch.ones(1), torch.zeros(tgt_len), torch.ones(1), + ]).to(torch.int32).to(device) + + audio_mask = torch.cat([ + torch.zeros(1), torch.ones(ref_len), torch.zeros(1), + torch.zeros(txt_len), + torch.zeros(1), torch.ones(tgt_len), torch.zeros(1), + ]).to(torch.int32).to(device) + + loss_mask = torch.cat([ + torch.zeros(1 + ref_len + 1), # ref part: no loss + torch.zeros(txt_len), # text: no loss + torch.zeros(1), # 101: no loss + torch.ones(tgt_len), # target audio: LOSS + torch.zeros(1), # 102: no loss + ]).to(torch.int32).to(device) + + total_len = 1 + ref_len + 1 + txt_len + 1 + tgt_len + 1 + labels = torch.zeros(total_len, dtype=torch.int32, device=device) + labels[-2] = 1 # stop label at last target audio position + + return ( + text_token_info, + audio_feat_info, + text_mask, + audio_mask, + loss_mask, + labels, + ref_duration + tgt_duration, + txt_len, + ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 509ef6d..cae8ec8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -58,6 +58,7 @@ def test_parser_defaults_to_voxcpm2(): parser = cli._build_parser() args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"]) assert args.hf_model_id == "openbmb/VoxCPM2" + assert args.device == "auto" assert args.no_optimize is False @@ -85,6 +86,7 @@ def test_load_model_respects_no_optimize_for_local_model(monkeypatch): cli.load_model(args) + assert calls["kwargs"]["device"] == "auto" assert calls["kwargs"]["optimize"] is False @@ -110,6 +112,7 @@ def test_load_model_defaults_optimize_for_hf(monkeypatch): cli.load_model(args) + assert calls["kwargs"]["device"] == "auto" assert calls["kwargs"]["optimize"] is True @@ -136,9 +139,37 @@ def test_load_model_respects_no_optimize_for_hf(monkeypatch): cli.load_model(args) + assert calls["kwargs"]["device"] == "auto" assert calls["kwargs"]["optimize"] is False +def test_load_model_passes_explicit_device_to_hf(monkeypatch): + calls = {} + + class FakeVoxCPM: + @classmethod + def from_pretrained(cls, **kwargs): + calls["kwargs"] = kwargs + return DummyModel() + + monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM) + args = cli._build_parser().parse_args( + [ + "design", + "--text", + "hello", + "--output", + "out.wav", + "--device", + "mps", + ] + ) + + cli.load_model(args) + + assert calls["kwargs"]["device"] == "mps" + + def test_design_subcommand_applies_control(monkeypatch, tmp_path): dummy_model = DummyModel() monkeypatch.setattr(cli, "load_model", lambda args: dummy_model) diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py new file mode 100644 index 0000000..bb69ffc --- /dev/null +++ b/tests/test_model_utils.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path + +import pytest + +ROOT = Path(__file__).resolve().parents[1] +UTILS_PATH = ROOT / "src" / "voxcpm" / "model" / "utils.py" + +transformers_stub = types.ModuleType("transformers") +transformers_stub.PreTrainedTokenizer = object +sys.modules.setdefault("transformers", transformers_stub) + +spec = importlib.util.spec_from_file_location("voxcpm.model.utils", UTILS_PATH) +utils = importlib.util.module_from_spec(spec) +assert spec.loader is not None +spec.loader.exec_module(utils) + + +def test_resolve_runtime_device_auto_falls_back_to_cpu(monkeypatch): + monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False) + monkeypatch.setattr(utils, "_has_mps", lambda: False) + + assert utils.resolve_runtime_device(None, "cuda") == "cpu" + + +def test_resolve_runtime_device_auto_uses_mps_when_available(monkeypatch): + monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False) + monkeypatch.setattr(utils, "_has_mps", lambda: True) + + assert utils.resolve_runtime_device("auto", "cuda") == "mps" + + +def test_resolve_runtime_device_respects_explicit_cpu(monkeypatch): + monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(utils, "_has_mps", lambda: True) + + assert utils.resolve_runtime_device("cpu", "cuda") == "cpu" + + +def test_resolve_runtime_device_rejects_unavailable_explicit_cuda(monkeypatch): + monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False) + monkeypatch.setattr(utils, "_has_mps", lambda: True) + + with pytest.raises(ValueError, match="CUDA is not available"): + utils.resolve_runtime_device("cuda:0", "cuda")