update finetuning pipeline and runtime device handling
Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits. Made-with: Cursor
This commit is contained in:
@@ -3,3 +3,4 @@ __pycache__
|
|||||||
voxcpm.egg-info
|
voxcpm.egg-info
|
||||||
.DS_Store
|
.DS_Store
|
||||||
./pretrained_models/
|
./pretrained_models/
|
||||||
|
app_local.py
|
||||||
@@ -209,6 +209,7 @@ def load_model(args) -> VoxCPM:
|
|||||||
zipenhancer_model_path=zipenhancer_path,
|
zipenhancer_model_path=zipenhancer_path,
|
||||||
enable_denoiser=not args.no_denoiser,
|
enable_denoiser=not args.no_denoiser,
|
||||||
optimize=not args.no_optimize,
|
optimize=not args.no_optimize,
|
||||||
|
device=args.device,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
@@ -227,6 +228,7 @@ def load_model(args) -> VoxCPM:
|
|||||||
cache_dir=args.cache_dir,
|
cache_dir=args.cache_dir,
|
||||||
local_files_only=args.local_files_only,
|
local_files_only=args.local_files_only,
|
||||||
optimize=not args.no_optimize,
|
optimize=not args.no_optimize,
|
||||||
|
device=args.device,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
@@ -403,6 +405,12 @@ def _add_model_args(parser):
|
|||||||
default=DEFAULT_HF_MODEL_ID,
|
default=DEFAULT_HF_MODEL_ID,
|
||||||
help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
|
help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="auto",
|
||||||
|
help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cache-dir", type=str, help="Cache directory for Hub downloads"
|
"--cache-dir", type=str, help="Cache directory for Hub downloads"
|
||||||
)
|
)
|
||||||
|
|||||||
+21
-2
@@ -17,6 +17,7 @@ class VoxCPM:
|
|||||||
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||||
enable_denoiser: bool = True,
|
enable_denoiser: bool = True,
|
||||||
optimize: bool = True,
|
optimize: bool = True,
|
||||||
|
device: str | None = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
lora_weights_path: Optional[str] = None,
|
lora_weights_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
@@ -30,6 +31,9 @@ class VoxCPM:
|
|||||||
id or local path. If None, denoiser will not be initialized.
|
id or local path. If None, denoiser will not be initialized.
|
||||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||||
|
device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM
|
||||||
|
will choose automatically (preferring CUDA, then MPS, then CPU).
|
||||||
|
If set explicitly, that device is used or a clear error is raised.
|
||||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||||
provided without lora_config, a default config will be created.
|
provided without lora_config, a default config will be created.
|
||||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||||
@@ -56,10 +60,20 @@ class VoxCPM:
|
|||||||
arch = config.get("architecture", "voxcpm").lower()
|
arch = config.get("architecture", "voxcpm").lower()
|
||||||
|
|
||||||
if arch == "voxcpm2":
|
if arch == "voxcpm2":
|
||||||
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
self.tts_model = VoxCPM2Model.from_local(
|
||||||
|
voxcpm_model_path,
|
||||||
|
optimize=optimize,
|
||||||
|
device=device,
|
||||||
|
lora_config=lora_config,
|
||||||
|
)
|
||||||
print("Loaded VoxCPM2Model", file=sys.stderr)
|
print("Loaded VoxCPM2Model", file=sys.stderr)
|
||||||
elif arch == "voxcpm":
|
elif arch == "voxcpm":
|
||||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
self.tts_model = VoxCPMModel.from_local(
|
||||||
|
voxcpm_model_path,
|
||||||
|
optimize=optimize,
|
||||||
|
device=device,
|
||||||
|
lora_config=lora_config,
|
||||||
|
)
|
||||||
print("Loaded VoxCPMModel", file=sys.stderr)
|
print("Loaded VoxCPMModel", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported architecture: {arch}")
|
raise ValueError(f"Unsupported architecture: {arch}")
|
||||||
@@ -94,6 +108,7 @@ class VoxCPM:
|
|||||||
cache_dir: str = None,
|
cache_dir: str = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
optimize: bool = True,
|
optimize: bool = True,
|
||||||
|
device: str | None = None,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
lora_weights_path: Optional[str] = None,
|
lora_weights_path: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -109,6 +124,9 @@ class VoxCPM:
|
|||||||
cache_dir: Custom cache directory for the snapshot.
|
cache_dir: Custom cache directory for the snapshot.
|
||||||
local_files_only: If True, only use local files and do not attempt
|
local_files_only: If True, only use local files and do not attempt
|
||||||
to download.
|
to download.
|
||||||
|
device: Runtime device. Use ``None``/``"auto"`` for automatic
|
||||||
|
fallback, or an explicit value such as ``"cpu"``, ``"mps"``,
|
||||||
|
``"cuda"``, or ``"cuda:0"``.
|
||||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||||
provided without lora_config, a default config will be created with
|
provided without lora_config, a default config will be created with
|
||||||
enable_lm=True and enable_dit=True.
|
enable_lm=True and enable_dit=True.
|
||||||
@@ -146,6 +164,7 @@ class VoxCPM:
|
|||||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||||
enable_denoiser=load_denoiser,
|
enable_denoiser=load_denoiser,
|
||||||
optimize=optimize,
|
optimize=optimize,
|
||||||
|
device=device,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import List, Optional
|
||||||
import torch
|
import torch
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
@@ -119,3 +119,67 @@ def get_dtype(dtype: str):
|
|||||||
return torch.float32
|
return torch.float32
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def _has_mps() -> bool:
|
||||||
|
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||||
|
|
||||||
|
|
||||||
|
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
|
||||||
|
"""
|
||||||
|
Choose a runtime device automatically.
|
||||||
|
|
||||||
|
Preference order:
|
||||||
|
- if the preferred device is available, use it
|
||||||
|
- otherwise fall back to CUDA -> MPS -> CPU
|
||||||
|
"""
|
||||||
|
preferred = (preferred_device or "cuda").strip().lower()
|
||||||
|
|
||||||
|
if preferred.startswith("cuda") and torch.cuda.is_available():
|
||||||
|
return preferred
|
||||||
|
if preferred == "mps" and _has_mps():
|
||||||
|
return "mps"
|
||||||
|
if preferred == "cpu":
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return "cuda"
|
||||||
|
if _has_mps():
|
||||||
|
return "mps"
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_runtime_device(device: Optional[str], configured_device: str = "cuda") -> str:
|
||||||
|
"""
|
||||||
|
Resolve the actual runtime device.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
- ``device`` is ``None`` or ``"auto"``: use automatic fallback selection
|
||||||
|
- otherwise: treat it as an explicit user choice and validate availability
|
||||||
|
"""
|
||||||
|
explicit = None if device is None else device.strip().lower()
|
||||||
|
|
||||||
|
if explicit is None or explicit == "auto":
|
||||||
|
return auto_select_device(configured_device)
|
||||||
|
|
||||||
|
if explicit.startswith("cuda"):
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
raise ValueError(
|
||||||
|
f"Requested device '{device}', but CUDA is not available. "
|
||||||
|
"Use device='auto' for automatic fallback."
|
||||||
|
)
|
||||||
|
return explicit
|
||||||
|
if explicit == "mps":
|
||||||
|
if not _has_mps():
|
||||||
|
raise ValueError(
|
||||||
|
"Requested device 'mps', but MPS is not available. "
|
||||||
|
"Use device='auto' for automatic fallback."
|
||||||
|
)
|
||||||
|
return "mps"
|
||||||
|
if explicit == "cpu":
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported device '{device}'. Supported values are 'auto', 'cpu', 'mps', "
|
||||||
|
"'cuda', or indexed CUDA devices like 'cuda:0'."
|
||||||
|
)
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
|||||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||||
from ..modules.locenc import VoxCPMLocEnc
|
from ..modules.locenc import VoxCPMLocEnc
|
||||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
|
||||||
|
|
||||||
|
|
||||||
class VoxCPMEncoderConfig(BaseModel):
|
class VoxCPMEncoderConfig(BaseModel):
|
||||||
@@ -109,18 +109,15 @@ class VoxCPMModel(nn.Module):
|
|||||||
tokenizer: LlamaTokenizerFast,
|
tokenizer: LlamaTokenizerFast,
|
||||||
audio_vae: AudioVAE,
|
audio_vae: AudioVAE,
|
||||||
lora_config: LoRAConfig = None,
|
lora_config: LoRAConfig = None,
|
||||||
|
device: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.lora_config = lora_config
|
self.lora_config = lora_config
|
||||||
self.feat_dim = config.feat_dim
|
self.feat_dim = config.feat_dim
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
self.device = config.device
|
self.device = resolve_runtime_device(device, config.device)
|
||||||
if not torch.cuda.is_available():
|
self.config.device = self.device
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
self.device = "mps"
|
|
||||||
else:
|
|
||||||
self.device = "cpu"
|
|
||||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||||
|
|
||||||
# Text-Semantic LM
|
# Text-Semantic LM
|
||||||
@@ -847,7 +844,14 @@ class VoxCPMModel(nn.Module):
|
|||||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
def from_local(
|
||||||
|
cls,
|
||||||
|
path: str,
|
||||||
|
optimize: bool = True,
|
||||||
|
training: bool = False,
|
||||||
|
device: str | None = None,
|
||||||
|
lora_config: LoRAConfig = None,
|
||||||
|
):
|
||||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||||
@@ -870,7 +874,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||||
)
|
)
|
||||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
|
||||||
if not training:
|
if not training:
|
||||||
lm_dtype = get_dtype(model.config.dtype)
|
lm_dtype = get_dtype(model.config.dtype)
|
||||||
model = model.to(lm_dtype)
|
model = model.to(lm_dtype)
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
|||||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
||||||
from ..modules.locenc import VoxCPMLocEnc
|
from ..modules.locenc import VoxCPMLocEnc
|
||||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
|
||||||
|
|
||||||
|
|
||||||
# A simple function to trim audio silence using VAD, not used default
|
# A simple function to trim audio silence using VAD, not used default
|
||||||
@@ -151,18 +151,15 @@ class VoxCPM2Model(nn.Module):
|
|||||||
tokenizer: LlamaTokenizerFast,
|
tokenizer: LlamaTokenizerFast,
|
||||||
audio_vae: AudioVAEV2,
|
audio_vae: AudioVAEV2,
|
||||||
lora_config: LoRAConfig = None,
|
lora_config: LoRAConfig = None,
|
||||||
|
device: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.lora_config = lora_config
|
self.lora_config = lora_config
|
||||||
self.feat_dim = config.feat_dim
|
self.feat_dim = config.feat_dim
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
self.device = config.device
|
self.device = resolve_runtime_device(device, config.device)
|
||||||
if not torch.cuda.is_available():
|
self.config.device = self.device
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
self.device = "mps"
|
|
||||||
else:
|
|
||||||
self.device = "cpu"
|
|
||||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||||
|
|
||||||
# Text-Semantic LM
|
# Text-Semantic LM
|
||||||
@@ -1098,7 +1095,14 @@ class VoxCPM2Model(nn.Module):
|
|||||||
yield feat_pred, generated_feat, context_len
|
yield feat_pred, generated_feat, context_len
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
def from_local(
|
||||||
|
cls,
|
||||||
|
path: str,
|
||||||
|
optimize: bool = True,
|
||||||
|
training: bool = False,
|
||||||
|
device: str | None = None,
|
||||||
|
lora_config: LoRAConfig = None,
|
||||||
|
):
|
||||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||||
@@ -1121,7 +1125,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||||
)
|
)
|
||||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
|
||||||
if not training:
|
if not training:
|
||||||
lm_dtype = get_dtype(model.config.dtype)
|
lm_dtype = get_dtype(model.config.dtype)
|
||||||
model = model.to(lm_dtype)
|
model = model.to(lm_dtype)
|
||||||
|
|||||||
+53
-11
@@ -12,6 +12,7 @@ from .packers import AudioFeatureProcessingPacker
|
|||||||
|
|
||||||
DEFAULT_TEXT_COLUMN = "text"
|
DEFAULT_TEXT_COLUMN = "text"
|
||||||
DEFAULT_AUDIO_COLUMN = "audio"
|
DEFAULT_AUDIO_COLUMN = "audio"
|
||||||
|
DEFAULT_REF_AUDIO_COLUMN = "ref_audio"
|
||||||
DEFAULT_ID_COLUMN = "dataset_id"
|
DEFAULT_ID_COLUMN = "dataset_id"
|
||||||
|
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ def load_audio_text_datasets(
|
|||||||
val_manifest: str = "",
|
val_manifest: str = "",
|
||||||
text_column: str = DEFAULT_TEXT_COLUMN,
|
text_column: str = DEFAULT_TEXT_COLUMN,
|
||||||
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
||||||
|
ref_audio_column: str = DEFAULT_REF_AUDIO_COLUMN,
|
||||||
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
||||||
sample_rate: int = 16_000,
|
sample_rate: int = 16_000,
|
||||||
num_proc: int = 1,
|
num_proc: int = 1,
|
||||||
@@ -34,14 +36,19 @@ def load_audio_text_datasets(
|
|||||||
def prepare(ds: Dataset) -> Dataset:
|
def prepare(ds: Dataset) -> Dataset:
|
||||||
if audio_column not in ds.column_names:
|
if audio_column not in ds.column_names:
|
||||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||||
# We cast to Audio to ensure proper handling during training,
|
|
||||||
# but for length calculation we might need raw path or duration if available.
|
|
||||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
|
||||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||||
if audio_column != DEFAULT_AUDIO_COLUMN:
|
if audio_column != DEFAULT_AUDIO_COLUMN:
|
||||||
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
||||||
if text_column != DEFAULT_TEXT_COLUMN:
|
if text_column != DEFAULT_TEXT_COLUMN:
|
||||||
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
||||||
|
|
||||||
|
# ref_audio is optional — cast to Audio if the column exists
|
||||||
|
ref_col = ref_audio_column if ref_audio_column in ds.column_names else DEFAULT_REF_AUDIO_COLUMN
|
||||||
|
if ref_col in ds.column_names:
|
||||||
|
ds = ds.cast_column(ref_col, Audio(sampling_rate=sample_rate))
|
||||||
|
if ref_col != DEFAULT_REF_AUDIO_COLUMN:
|
||||||
|
ds = ds.rename_column(ref_col, DEFAULT_REF_AUDIO_COLUMN)
|
||||||
|
|
||||||
if dataset_id_column and dataset_id_column in ds.column_names:
|
if dataset_id_column and dataset_id_column in ds.column_names:
|
||||||
if dataset_id_column != DEFAULT_ID_COLUMN:
|
if dataset_id_column != DEFAULT_ID_COLUMN:
|
||||||
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
||||||
@@ -67,11 +74,11 @@ def compute_sample_lengths(
|
|||||||
- 音频长度:
|
- 音频长度:
|
||||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||||
t_seq = ceil(t_vae / patch_size)
|
t_seq = ceil(t_vae / patch_size)
|
||||||
- 序列总长约为: text_len + t_seq + 2
|
- 无 ref_audio: text_len + t_seq + 2
|
||||||
|
- 有 ref_audio: text_len + t_seq + ref_seq + 4
|
||||||
|
|
||||||
Optimized: Use batch column access instead of iterating item by item.
|
Optimized: Use batch column access instead of iterating item by item.
|
||||||
"""
|
"""
|
||||||
# Batch access columns - much faster than per-item access
|
|
||||||
text_ids_list = ds["text_ids"]
|
text_ids_list = ds["text_ids"]
|
||||||
text_lens = [len(t) for t in text_ids_list]
|
text_lens = [len(t) for t in text_ids_list]
|
||||||
|
|
||||||
@@ -79,18 +86,35 @@ def compute_sample_lengths(
|
|||||||
if has_duration:
|
if has_duration:
|
||||||
durations = ds["duration"]
|
durations = ds["duration"]
|
||||||
else:
|
else:
|
||||||
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
|
|
||||||
durations = []
|
durations = []
|
||||||
for i in range(len(ds)):
|
for i in range(len(ds)):
|
||||||
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||||
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||||
|
|
||||||
# Vectorized length computation
|
has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in ds.column_names
|
||||||
|
if has_ref_audio:
|
||||||
|
ref_duration_col = "ref_duration" if "ref_duration" in ds.column_names else None
|
||||||
|
|
||||||
lengths = []
|
lengths = []
|
||||||
for text_len, duration in zip(text_lens, durations):
|
for i, (text_len, duration) in enumerate(zip(text_lens, durations)):
|
||||||
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
||||||
t_seq = math.ceil(t_vae / patch_size)
|
t_seq = math.ceil(t_vae / patch_size)
|
||||||
total_len = text_len + t_seq + 2
|
|
||||||
|
ref_seq = 0
|
||||||
|
if has_ref_audio:
|
||||||
|
# Estimate ref_audio length; ref_audio is None for samples without it
|
||||||
|
if ref_duration_col:
|
||||||
|
ref_dur = ds[i].get(ref_duration_col)
|
||||||
|
else:
|
||||||
|
ref_item = ds[i].get(DEFAULT_REF_AUDIO_COLUMN)
|
||||||
|
ref_dur = len(ref_item["array"]) / float(ref_item["sampling_rate"]) if ref_item else None
|
||||||
|
if ref_dur is not None and float(ref_dur) > 0:
|
||||||
|
ref_vae = math.ceil(float(ref_dur) * audio_vae_fps)
|
||||||
|
ref_seq = math.ceil(ref_vae / patch_size)
|
||||||
|
|
||||||
|
# +2 for 101/102; +2 more for 103/104 when ref_audio present
|
||||||
|
overhead = 4 if ref_seq > 0 else 2
|
||||||
|
total_len = text_len + t_seq + ref_seq + overhead
|
||||||
lengths.append(total_len)
|
lengths.append(total_len)
|
||||||
|
|
||||||
return lengths
|
return lengths
|
||||||
@@ -102,8 +126,11 @@ class HFVoxCPMDataset(TorchDataset):
|
|||||||
PyTorch-friendly samples.
|
PyTorch-friendly samples.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_SENTINEL = [-100.0]
|
||||||
|
|
||||||
def __init__(self, dataset: Dataset):
|
def __init__(self, dataset: Dataset):
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
self.has_ref_audio = DEFAULT_REF_AUDIO_COLUMN in dataset.column_names
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.dataset)
|
return len(self.dataset)
|
||||||
@@ -111,13 +138,17 @@ class HFVoxCPMDataset(TorchDataset):
|
|||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
item = self.dataset[idx]
|
item = self.dataset[idx]
|
||||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||||
return {
|
sample = {
|
||||||
"text_ids": item["text_ids"],
|
"text_ids": item["text_ids"],
|
||||||
"audio_array": audio["array"],
|
"audio_array": audio["array"],
|
||||||
"audio_sampling_rate": audio["sampling_rate"],
|
"audio_sampling_rate": audio["sampling_rate"],
|
||||||
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
|
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
|
||||||
"is_prompt": item.get("is_prompt", False),
|
"is_prompt": item.get("is_prompt", False),
|
||||||
}
|
}
|
||||||
|
if self.has_ref_audio:
|
||||||
|
ref = item.get(DEFAULT_REF_AUDIO_COLUMN)
|
||||||
|
sample["ref_audio_array"] = ref["array"] if ref else self._SENTINEL
|
||||||
|
return sample
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
|
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
|
||||||
@@ -143,7 +174,7 @@ class HFVoxCPMDataset(TorchDataset):
|
|||||||
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
|
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
|
||||||
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
|
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
"text_tokens": text_padded,
|
"text_tokens": text_padded,
|
||||||
"audio_tokens": audio_padded,
|
"audio_tokens": audio_padded,
|
||||||
"task_ids": task_ids,
|
"task_ids": task_ids,
|
||||||
@@ -151,6 +182,12 @@ class HFVoxCPMDataset(TorchDataset):
|
|||||||
"is_prompts": is_prompts,
|
"is_prompts": is_prompts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if "ref_audio_array" in batch[0]:
|
||||||
|
ref_tensors = [torch.tensor(s["ref_audio_array"], dtype=torch.float32) for s in batch]
|
||||||
|
result["ref_audio_tokens"] = cls.pad_sequences(ref_tensors, pad_value=-100.0)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessor:
|
class BatchProcessor:
|
||||||
"""
|
"""
|
||||||
@@ -184,12 +221,17 @@ class BatchProcessor:
|
|||||||
task_ids = batch["task_ids"].to(self.device)
|
task_ids = batch["task_ids"].to(self.device)
|
||||||
dataset_ids = batch["dataset_ids"].to(self.device)
|
dataset_ids = batch["dataset_ids"].to(self.device)
|
||||||
|
|
||||||
|
ref_audio_tokens = None
|
||||||
|
if "ref_audio_tokens" in batch:
|
||||||
|
ref_audio_tokens = batch["ref_audio_tokens"].to(self.device)
|
||||||
|
|
||||||
packed = self.packer(
|
packed = self.packer(
|
||||||
audio_tokens=audio_tokens,
|
audio_tokens=audio_tokens,
|
||||||
text_tokens=text_tokens,
|
text_tokens=text_tokens,
|
||||||
task_ids=task_ids,
|
task_ids=task_ids,
|
||||||
dataset_ids=dataset_ids,
|
dataset_ids=dataset_ids,
|
||||||
is_prompts=batch["is_prompts"],
|
is_prompts=batch["is_prompts"],
|
||||||
|
ref_audio_tokens=ref_audio_tokens,
|
||||||
)
|
)
|
||||||
return packed
|
return packed
|
||||||
|
|
||||||
|
|||||||
+124
-14
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -14,7 +14,6 @@ class AudioFeatureProcessingPacker:
|
|||||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||||
self.audio_start_id = 101
|
self.audio_start_id = 101
|
||||||
self.audio_end_id = 102
|
self.audio_end_id = 102
|
||||||
# unused now
|
|
||||||
self.audio_prompt_start_id = 103
|
self.audio_prompt_start_id = 103
|
||||||
self.audio_prompt_end_id = 104
|
self.audio_prompt_end_id = 104
|
||||||
self.text_eos_token_id = 2
|
self.text_eos_token_id = 2
|
||||||
@@ -78,11 +77,16 @@ class AudioFeatureProcessingPacker:
|
|||||||
task_ids: torch.Tensor,
|
task_ids: torch.Tensor,
|
||||||
dataset_ids: torch.Tensor,
|
dataset_ids: torch.Tensor,
|
||||||
is_prompts: List[bool],
|
is_prompts: List[bool],
|
||||||
|
ref_audio_tokens: Optional[torch.Tensor] = None,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Padding-based batching: each sample in the input batch is processed
|
Padding-based batching: each sample in the input batch is processed
|
||||||
independently and then padded to a common length (capped by ``max_len``).
|
independently and then padded to a common length (capped by ``max_len``).
|
||||||
The result tensors all have shape [B, T, ...].
|
The result tensors all have shape [B, T, ...].
|
||||||
|
|
||||||
|
If ``ref_audio_tokens`` is provided (same batch dim as ``audio_tokens``),
|
||||||
|
samples whose unpadded ref_audio length > 0 will be processed with the
|
||||||
|
reference-audio path (tokens 103/104 prepended, loss only on target audio).
|
||||||
"""
|
"""
|
||||||
device = audio_tokens.device
|
device = audio_tokens.device
|
||||||
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
|
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
|
||||||
@@ -101,23 +105,43 @@ class AudioFeatureProcessingPacker:
|
|||||||
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||||
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
|
ref_iter = ref_audio_tokens if ref_audio_tokens is not None else [None] * audio_tokens.size(0)
|
||||||
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
|
|
||||||
|
for audio_token, text_token, task_id, dataset_idx, is_prompt, ref_token in zip(
|
||||||
|
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts, ref_iter
|
||||||
):
|
):
|
||||||
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
|
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
|
||||||
unpad_text_token = self.unpad_text_tokens(text_token)
|
unpad_text_token = self.unpad_text_tokens(text_token)
|
||||||
usage = self.id_to_task[task_id]
|
usage = self.id_to_task[task_id]
|
||||||
|
|
||||||
(
|
has_ref = False
|
||||||
packed_text,
|
if ref_token is not None:
|
||||||
audio_feat,
|
unpad_ref_token = self.unpad_audio_tokens(ref_token).to(torch.float32)
|
||||||
text_mask,
|
if unpad_ref_token.numel() > 0:
|
||||||
audio_mask,
|
has_ref = True
|
||||||
loss_mask,
|
|
||||||
labels,
|
if has_ref:
|
||||||
audio_duration,
|
(
|
||||||
text_token_count,
|
packed_text,
|
||||||
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
|
audio_feat,
|
||||||
|
text_mask,
|
||||||
|
audio_mask,
|
||||||
|
loss_mask,
|
||||||
|
labels,
|
||||||
|
audio_duration,
|
||||||
|
text_token_count,
|
||||||
|
) = self.process_tts_data_with_ref(unpad_ref_token, unpad_audio_token, unpad_text_token)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
packed_text,
|
||||||
|
audio_feat,
|
||||||
|
text_mask,
|
||||||
|
audio_mask,
|
||||||
|
loss_mask,
|
||||||
|
labels,
|
||||||
|
audio_duration,
|
||||||
|
text_token_count,
|
||||||
|
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
|
||||||
|
|
||||||
audio_duration_consumed[dataset_idx] += audio_duration
|
audio_duration_consumed[dataset_idx] += audio_duration
|
||||||
text_token_consumed[dataset_idx] += text_token_count
|
text_token_consumed[dataset_idx] += text_token_count
|
||||||
@@ -294,3 +318,89 @@ class AudioFeatureProcessingPacker:
|
|||||||
audio_duration,
|
audio_duration,
|
||||||
text_token_count,
|
text_token_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def process_tts_data_with_ref(
|
||||||
|
self,
|
||||||
|
ref_audio_token: torch.Tensor,
|
||||||
|
target_audio_token: torch.Tensor,
|
||||||
|
text_token: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Build a training sequence with reference audio prepended:
|
||||||
|
|
||||||
|
[103, ref_feats, 104, text, 101, target_feats, 102]
|
||||||
|
|
||||||
|
Loss is computed only on the target audio segment.
|
||||||
|
"""
|
||||||
|
device = text_token.device
|
||||||
|
txt_len = len(text_token)
|
||||||
|
|
||||||
|
ref_feats, ref_duration = self.extract_audio_feats(ref_audio_token)
|
||||||
|
ref_feats = ref_feats.squeeze(0) # [R, P, D]
|
||||||
|
ref_len = ref_feats.shape[0]
|
||||||
|
|
||||||
|
tgt_feats, tgt_duration = self.extract_audio_feats(target_audio_token)
|
||||||
|
tgt_feats = tgt_feats.squeeze(0) # [A, P, D]
|
||||||
|
tgt_len = tgt_feats.shape[0]
|
||||||
|
|
||||||
|
feat_shape = (self.patch_size, ref_feats.size(-1))
|
||||||
|
|
||||||
|
def _tok(ids):
|
||||||
|
return torch.tensor(ids, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
# -- text token track --
|
||||||
|
# [103, 0×R, 104, text_ids, 101, 0×A, 102]
|
||||||
|
text_token_info = torch.cat([
|
||||||
|
_tok([self.audio_prompt_start_id]),
|
||||||
|
torch.zeros(ref_len, dtype=torch.int32, device=device),
|
||||||
|
_tok([self.audio_prompt_end_id]),
|
||||||
|
text_token,
|
||||||
|
_tok([self.audio_start_id]),
|
||||||
|
torch.zeros(tgt_len, dtype=torch.int32, device=device),
|
||||||
|
_tok([self.audio_end_id]),
|
||||||
|
])
|
||||||
|
|
||||||
|
# -- audio feature track --
|
||||||
|
zero_1 = torch.zeros((1,) + feat_shape, dtype=torch.float32, device=device)
|
||||||
|
zero_txt = torch.zeros((txt_len,) + feat_shape, dtype=torch.float32, device=device)
|
||||||
|
audio_feat_info = torch.cat([
|
||||||
|
zero_1, ref_feats, zero_1, # 103, ref, 104
|
||||||
|
zero_txt, # text
|
||||||
|
zero_1, tgt_feats, zero_1, # 101, target, 102
|
||||||
|
], dim=0)
|
||||||
|
|
||||||
|
# -- masks --
|
||||||
|
text_mask = torch.cat([
|
||||||
|
torch.ones(1), torch.zeros(ref_len), torch.ones(1),
|
||||||
|
torch.ones(txt_len),
|
||||||
|
torch.ones(1), torch.zeros(tgt_len), torch.ones(1),
|
||||||
|
]).to(torch.int32).to(device)
|
||||||
|
|
||||||
|
audio_mask = torch.cat([
|
||||||
|
torch.zeros(1), torch.ones(ref_len), torch.zeros(1),
|
||||||
|
torch.zeros(txt_len),
|
||||||
|
torch.zeros(1), torch.ones(tgt_len), torch.zeros(1),
|
||||||
|
]).to(torch.int32).to(device)
|
||||||
|
|
||||||
|
loss_mask = torch.cat([
|
||||||
|
torch.zeros(1 + ref_len + 1), # ref part: no loss
|
||||||
|
torch.zeros(txt_len), # text: no loss
|
||||||
|
torch.zeros(1), # 101: no loss
|
||||||
|
torch.ones(tgt_len), # target audio: LOSS
|
||||||
|
torch.zeros(1), # 102: no loss
|
||||||
|
]).to(torch.int32).to(device)
|
||||||
|
|
||||||
|
total_len = 1 + ref_len + 1 + txt_len + 1 + tgt_len + 1
|
||||||
|
labels = torch.zeros(total_len, dtype=torch.int32, device=device)
|
||||||
|
labels[-2] = 1 # stop label at last target audio position
|
||||||
|
|
||||||
|
return (
|
||||||
|
text_token_info,
|
||||||
|
audio_feat_info,
|
||||||
|
text_mask,
|
||||||
|
audio_mask,
|
||||||
|
loss_mask,
|
||||||
|
labels,
|
||||||
|
ref_duration + tgt_duration,
|
||||||
|
txt_len,
|
||||||
|
)
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ def test_parser_defaults_to_voxcpm2():
|
|||||||
parser = cli._build_parser()
|
parser = cli._build_parser()
|
||||||
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
|
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
|
||||||
assert args.hf_model_id == "openbmb/VoxCPM2"
|
assert args.hf_model_id == "openbmb/VoxCPM2"
|
||||||
|
assert args.device == "auto"
|
||||||
assert args.no_optimize is False
|
assert args.no_optimize is False
|
||||||
|
|
||||||
|
|
||||||
@@ -85,6 +86,7 @@ def test_load_model_respects_no_optimize_for_local_model(monkeypatch):
|
|||||||
|
|
||||||
cli.load_model(args)
|
cli.load_model(args)
|
||||||
|
|
||||||
|
assert calls["kwargs"]["device"] == "auto"
|
||||||
assert calls["kwargs"]["optimize"] is False
|
assert calls["kwargs"]["optimize"] is False
|
||||||
|
|
||||||
|
|
||||||
@@ -110,6 +112,7 @@ def test_load_model_defaults_optimize_for_hf(monkeypatch):
|
|||||||
|
|
||||||
cli.load_model(args)
|
cli.load_model(args)
|
||||||
|
|
||||||
|
assert calls["kwargs"]["device"] == "auto"
|
||||||
assert calls["kwargs"]["optimize"] is True
|
assert calls["kwargs"]["optimize"] is True
|
||||||
|
|
||||||
|
|
||||||
@@ -136,9 +139,37 @@ def test_load_model_respects_no_optimize_for_hf(monkeypatch):
|
|||||||
|
|
||||||
cli.load_model(args)
|
cli.load_model(args)
|
||||||
|
|
||||||
|
assert calls["kwargs"]["device"] == "auto"
|
||||||
assert calls["kwargs"]["optimize"] is False
|
assert calls["kwargs"]["optimize"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_passes_explicit_device_to_hf(monkeypatch):
|
||||||
|
calls = {}
|
||||||
|
|
||||||
|
class FakeVoxCPM:
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, **kwargs):
|
||||||
|
calls["kwargs"] = kwargs
|
||||||
|
return DummyModel()
|
||||||
|
|
||||||
|
monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)
|
||||||
|
args = cli._build_parser().parse_args(
|
||||||
|
[
|
||||||
|
"design",
|
||||||
|
"--text",
|
||||||
|
"hello",
|
||||||
|
"--output",
|
||||||
|
"out.wav",
|
||||||
|
"--device",
|
||||||
|
"mps",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
cli.load_model(args)
|
||||||
|
|
||||||
|
assert calls["kwargs"]["device"] == "mps"
|
||||||
|
|
||||||
|
|
||||||
def test_design_subcommand_applies_control(monkeypatch, tmp_path):
|
def test_design_subcommand_applies_control(monkeypatch, tmp_path):
|
||||||
dummy_model = DummyModel()
|
dummy_model = DummyModel()
|
||||||
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
UTILS_PATH = ROOT / "src" / "voxcpm" / "model" / "utils.py"
|
||||||
|
|
||||||
|
transformers_stub = types.ModuleType("transformers")
|
||||||
|
transformers_stub.PreTrainedTokenizer = object
|
||||||
|
sys.modules.setdefault("transformers", transformers_stub)
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location("voxcpm.model.utils", UTILS_PATH)
|
||||||
|
utils = importlib.util.module_from_spec(spec)
|
||||||
|
assert spec.loader is not None
|
||||||
|
spec.loader.exec_module(utils)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_runtime_device_auto_falls_back_to_cpu(monkeypatch):
|
||||||
|
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
|
||||||
|
monkeypatch.setattr(utils, "_has_mps", lambda: False)
|
||||||
|
|
||||||
|
assert utils.resolve_runtime_device(None, "cuda") == "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_runtime_device_auto_uses_mps_when_available(monkeypatch):
|
||||||
|
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
|
||||||
|
monkeypatch.setattr(utils, "_has_mps", lambda: True)
|
||||||
|
|
||||||
|
assert utils.resolve_runtime_device("auto", "cuda") == "mps"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_runtime_device_respects_explicit_cpu(monkeypatch):
|
||||||
|
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: True)
|
||||||
|
monkeypatch.setattr(utils, "_has_mps", lambda: True)
|
||||||
|
|
||||||
|
assert utils.resolve_runtime_device("cpu", "cuda") == "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_runtime_device_rejects_unavailable_explicit_cuda(monkeypatch):
|
||||||
|
monkeypatch.setattr(utils.torch.cuda, "is_available", lambda: False)
|
||||||
|
monkeypatch.setattr(utils, "_has_mps", lambda: True)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="CUDA is not available"):
|
||||||
|
utils.resolve_runtime_device("cuda:0", "cuda")
|
||||||
Reference in New Issue
Block a user