fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss

VoxCPM checkpoints default to bfloat16. Since commit e4e0496 added MPS
device routing, running with `device=mps` selects bf16 on Apple
Silicon. On Metal, bf16 introduces enough numerical drift in the
diffusion AR loop that the synthesized audio is glitched and trips the
model's badcase detector, which retries until the per-call retry budget
is exhausted. Effectively MPS support is unusable in the default config.

This patch adds a single helper, `pick_runtime_dtype(device, dtype)`,
that promotes any low-precision dtype to float32 when the resolved
device is `mps`. CUDA and CPU paths are untouched. An opt-out env var
`VOXCPM_MPS_DTYPE` lets users force a specific dtype on MPS once future
PyTorch / macOS releases improve bf16 stability.

Both VoxCPMModel and VoxCPM2Model adopt the helper in their __init__,
replacing what would otherwise be duplicated inline checks.

Verified locally on Apple M5 Max, PyTorch 2.11, macOS 15:
- VoxCPM2 (2B): clean output, RTF ~0.78 steady state
- VoxCPM 0.5B: clean output, RTF ~0.92
- No badcase retries fired in any test
- VOXCPM_MPS_DTYPE=bfloat16 round-trips and reproduces the original
  glitched output, confirming the override path.
This commit is contained in:
oumnya
2026-04-15 12:22:56 +08:00
parent 1565e83efe
commit 38d61cdf03
3 changed files with 64 additions and 2 deletions
+36
View File
@@ -1,7 +1,15 @@
import os
from typing import List, Optional
import torch
from transformers import PreTrainedTokenizer
# Dtype spellings that count as reduced precision (bf16 / fp16 families).
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16", "half"}

# Spellings accepted by the VOXCPM_MPS_DTYPE override: every reduced-precision
# name plus the full-precision ones.
_VALID_DTYPE_OVERRIDES = _LOW_PRECISION_DTYPES | {"float32", "fp32"}
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
# Explicitly close partially-consumed generators so inference_mode cleanup
@@ -135,6 +143,34 @@ def _has_mps() -> bool:
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
    """Pick a safe runtime dtype for the resolved device.

    On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift
    in the diffusion AR loop that the output is glitched and the model's
    badcase detector keeps triggering retries. float32 is the only stable
    option today. Every other device keeps the configured dtype unchanged.

    Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when
    they want to test future MPS improvements.

    Args:
        device: Resolved device name such as ``"cuda"``, ``"cpu"`` or
            ``"mps"``. An indexed form like ``"mps:0"`` is treated as MPS.
        configured_dtype: Dtype name from the configuration. May be empty
            or ``None``, in which case it is returned as-is.

    Returns:
        ``configured_dtype`` unchanged for non-MPS devices. On MPS: the
        normalized (stripped, lower-cased) ``VOXCPM_MPS_DTYPE`` value when
        that variable is set, otherwise ``"float32"`` if the configured
        dtype is a low-precision one, otherwise ``configured_dtype``.

    Raises:
        ValueError: If ``VOXCPM_MPS_DTYPE`` is set to a value that is not
            in ``_VALID_DTYPE_OVERRIDES``.
    """
    # Compare on the device *type* only, so "mps:0" gets the same protection
    # as plain "mps" instead of silently keeping a low-precision dtype.
    device_type = str(device).split(":", 1)[0].strip().lower()
    if device_type != "mps":
        return configured_dtype

    # The environment override wins over both the config and the promotion.
    override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
    if override:
        if override not in _VALID_DTYPE_OVERRIDES:
            raise ValueError(
                f"VOXCPM_MPS_DTYPE='{override}' is not one of "
                f"{sorted(_VALID_DTYPE_OVERRIDES)}"
            )
        return override

    # `or ""` tolerates a missing/None configured dtype.
    if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES:
        return "float32"
    return configured_dtype
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
"""
Choose a runtime device automatically.