import os

# dtype spellings treated as reduced precision; on MPS these are swapped for
# float32 unless the user explicitly overrides the choice.
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16", "half"}

# Every spelling accepted from the VOXCPM_MPS_DTYPE environment variable.
_VALID_DTYPE_OVERRIDES = {
    "bfloat16", "bf16",
    "float16", "fp16", "half",
    "float32", "fp32",
}


def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
    """Pick a safe runtime dtype for the resolved device.

    On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift
    in the diffusion AR loop that the output is glitched and the model's
    badcase detector triggers infinite retries. float32 is the only stable
    option today. Every other device keeps the configured dtype untouched.

    Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when
    they want to test future MPS improvements. The variable is only read
    when the device is MPS.

    Args:
        device: Resolved device, e.g. ``"mps"``, ``"mps:0"``, ``"cuda"`` or
            ``"cpu"``. Only the part before an optional ``":index"`` suffix
            is inspected, case-insensitively. Objects whose ``str()`` has
            that form are accepted as well.
        configured_dtype: dtype name from the model config. ``None`` or an
            empty string is passed through unchanged.

    Returns:
        The dtype name to run with: ``configured_dtype`` for non-MPS
        devices, the stripped and lowercased override when one is set, or
        ``"float32"`` when an MPS run was configured with a low-precision
        dtype.

    Raises:
        ValueError: If ``VOXCPM_MPS_DTYPE`` is set to a value that is not in
            ``_VALID_DTYPE_OVERRIDES``.
    """
    # Compare only the device *type*: an exact match against "mps" would let
    # "mps:0" or "MPS" slip past the float32 safeguard below.
    device_type = str(device).split(":", 1)[0].strip().lower()
    if device_type != "mps":
        return configured_dtype

    override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
    if override:
        if override not in _VALID_DTYPE_OVERRIDES:
            raise ValueError(
                f"VOXCPM_MPS_DTYPE='{override}' is not one of "
                f"{sorted(_VALID_DTYPE_OVERRIDES)}"
            )
        # NOTE(review): the alias is returned as typed (e.g. "half", "bf16"),
        # not canonicalised - confirm the downstream dtype lookup accepts
        # every spelling listed in _VALID_DTYPE_OVERRIDES.
        return override

    # `or ""` keeps a None/empty configured dtype from raising on .lower().
    if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES:
        return "float32"
    return configured_dtype
diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index e1c50e9..2289713 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT from ..modules.locenc import VoxCPMLocEnc from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel -from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device +from .utils import ( + get_dtype, + mask_multichar_chinese_tokens, + next_and_close, + pick_runtime_dtype, + resolve_runtime_device, +) class VoxCPMEncoderConfig(BaseModel): @@ -118,6 +124,13 @@ class VoxCPMModel(nn.Module): self.patch_size = config.patch_size self.device = resolve_runtime_device(device, config.device) self.config.device = self.device + resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype) + if resolved_dtype != self.config.dtype: + print( + f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}", + file=sys.stderr, + ) + self.config.dtype = resolved_dtype print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) # Text-Semantic LM diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index b1bc6a3..5d3b410 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -45,7 +45,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2 from ..modules.locenc import VoxCPMLocEnc from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel -from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device +from .utils import ( + get_dtype, + mask_multichar_chinese_tokens, + next_and_close, + pick_runtime_dtype, + resolve_runtime_device, +) # A simple function to trim audio silence using VAD, not used default @@ 
-160,6 +166,13 @@ class VoxCPM2Model(nn.Module): self.patch_size = config.patch_size self.device = resolve_runtime_device(device, config.device) self.config.device = self.device + resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype) + if resolved_dtype != self.config.dtype: + print( + f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}", + file=sys.stderr, + ) + self.config.dtype = resolved_dtype print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) # Text-Semantic LM