From 38d61cdf035765e75bf204b47bc2f7464a20ecea Mon Sep 17 00:00:00 2001
From: oumnya
Date: Wed, 15 Apr 2026 12:22:56 +0800
Subject: [PATCH] fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss

VoxCPM checkpoints default to bfloat16. Following commit e4e0496, which
added MPS device routing, running with `device=mps` selects bf16 on
Apple Silicon. On Metal, bf16 introduces enough numerical drift in the
diffusion AR loop that the synthesized audio is glitched and trips the
model's badcase detector, which retries until the per-call retry budget
is exhausted. Effectively, MPS support is unusable in the default config.

This patch adds a single helper,
`pick_runtime_dtype(device, configured_dtype)`, that promotes any
low-precision dtype to float32 when the resolved device is `mps`. CUDA
and CPU paths are untouched. An opt-out env var `VOXCPM_MPS_DTYPE` lets
users force a specific dtype on MPS once future PyTorch / macOS releases
improve bf16 stability.

Both VoxCPMModel and VoxCPM2Model adopt the helper in their __init__,
replacing what would otherwise be duplicated inline checks.

Verified locally on Apple M5 Max, PyTorch 2.11, macOS 15:
- VoxCPM2 (2B): clean output, RTF ~0.78 steady state
- VoxCPM 0.5B: clean output, RTF ~0.92
- No badcase retries fired in any test
- VOXCPM_MPS_DTYPE=bfloat16 round-trips and reproduces the original
  glitched output, confirming the override path.
---
Note: hunks for src/voxcpm/model/utils.py were edited by hand after
format-patch (canonical dtype names for the override, "mps:0" handling,
docstring). The post-image blob id on its "index" line is therefore stale.

 src/voxcpm/model/utils.py   | 59 +++++++++++++++++++++++++++++++++++++
 src/voxcpm/model/voxcpm.py  | 15 +++++++++-
 src/voxcpm/model/voxcpm2.py | 15 +++++++++-
 3 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/src/voxcpm/model/utils.py b/src/voxcpm/model/utils.py
index 2ca6068..a26e1c9 100644
--- a/src/voxcpm/model/utils.py
+++ b/src/voxcpm/model/utils.py
@@ -1,7 +1,21 @@
+import os
 from typing import List, Optional
 import torch
 from transformers import PreTrainedTokenizer
 
+# Accepted dtype spellings mapped to the canonical name returned to callers.
+_DTYPE_ALIASES = {
+    "bfloat16": "bfloat16",
+    "bf16": "bfloat16",
+    "float16": "float16",
+    "fp16": "float16",
+    "half": "float16",
+    "float32": "float32",
+    "fp32": "float32",
+}
+_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16", "half"}
+_VALID_DTYPE_OVERRIDES = frozenset(_DTYPE_ALIASES)
+
 
 # Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
 # Explicitly close partially-consumed generators so inference_mode cleanup
@@ -135,6 +149,51 @@ def _has_mps() -> bool:
     return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
 
 
+def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
+    """Pick a safe runtime dtype for the resolved device.
+
+    On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift
+    in the diffusion AR loop that the output is glitched and the model's
+    badcase detector keeps retrying until the retry budget is exhausted.
+    float32 is the only stable option today. CUDA and CPU keep whatever
+    the checkpoint was trained with.
+
+    Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when
+    they want to test future MPS improvements. Aliases such as ``bf16``,
+    ``fp16``/``half`` and ``fp32`` are accepted and returned under their
+    canonical name (``bfloat16``, ``float16``, ``float32``).
+
+    Args:
+        device: Resolved runtime device, e.g. ``"cuda"``, ``"cpu"``,
+            ``"mps"`` or an indexed form such as ``"mps:0"``.
+        configured_dtype: Dtype name taken from the checkpoint config.
+
+    Returns:
+        The dtype name to run with. Unchanged for non-MPS devices.
+
+    Raises:
+        ValueError: If ``VOXCPM_MPS_DTYPE`` is set to an unknown dtype name.
+    """
+    # Compare the device type only, so indexed strings like "mps:0" match.
+    if str(device).partition(":")[0] != "mps":
+        return configured_dtype
+
+    override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
+    if override:
+        if override not in _VALID_DTYPE_OVERRIDES:
+            raise ValueError(
+                f"VOXCPM_MPS_DTYPE='{override}' is not one of "
+                f"{sorted(_VALID_DTYPE_OVERRIDES)}"
+            )
+        # Return the canonical spelling so the result does not depend on
+        # which aliases the downstream dtype lookup understands.
+        return _DTYPE_ALIASES[override]
+
+    if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES:
+        return "float32"
+    return configured_dtype
+
+
 def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
     """
     Choose a runtime device automatically.
diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py
index e1c50e9..2289713 100644
--- a/src/voxcpm/model/voxcpm.py
+++ b/src/voxcpm/model/voxcpm.py
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
 from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
 from ..modules.locenc import VoxCPMLocEnc
 from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
-from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
+from .utils import (
+    get_dtype,
+    mask_multichar_chinese_tokens,
+    next_and_close,
+    pick_runtime_dtype,
+    resolve_runtime_device,
+)
 
 
 class VoxCPMEncoderConfig(BaseModel):
@@ -118,6 +124,13 @@ class VoxCPMModel(nn.Module):
         self.patch_size = config.patch_size
         self.device = resolve_runtime_device(device, config.device)
         self.config.device = self.device
+        resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
+        if resolved_dtype != self.config.dtype:
+            print(
+                f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
+                file=sys.stderr,
+            )
+            self.config.dtype = resolved_dtype
         print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
 
         # Text-Semantic LM
diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py
index b1bc6a3..5d3b410 100644
--- a/src/voxcpm/model/voxcpm2.py
+++ b/src/voxcpm/model/voxcpm2.py
@@ -45,7 +45,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
 from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
 from ..modules.locenc import VoxCPMLocEnc
 from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
-from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
+from .utils import (
+    get_dtype,
+    mask_multichar_chinese_tokens,
+    next_and_close,
+    pick_runtime_dtype,
+    resolve_runtime_device,
+)
 
 
 # A simple function to trim audio silence using VAD, not used default
@@ -160,6 +166,13 @@ class VoxCPM2Model(nn.Module):
         self.patch_size = config.patch_size
         self.device = resolve_runtime_device(device, config.device)
         self.config.device = self.device
+        resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
+        if resolved_dtype != self.config.dtype:
+            print(
+                f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
+                file=sys.stderr,
+            )
+            self.config.dtype = resolved_dtype
         print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
 
         # Text-Semantic LM