Merge pull request #263 from Oumnya/fix/mps-bf16-dtype

fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss
This commit is contained in:
ZGY
2026-04-21 18:49:48 +08:00
committed by GitHub
4 changed files with 172 additions and 2 deletions
+108
View File
@@ -0,0 +1,108 @@
"""Unit checks for pick_runtime_dtype / get_dtype consistency.
Loads src/voxcpm/model/utils.py directly to avoid the heavy voxcpm package
init. Run with: `python scripts/test_pick_runtime_dtype.py`.
"""
import importlib.util
import os
import pathlib
import sys
# This script lives one directory below the repository root, so the root is
# the second ancestor of the resolved script path.
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
UTILS = str(REPO_ROOT.joinpath("src", "voxcpm", "model", "utils.py"))

# Execute utils.py as a standalone module named "voxcpm_utils" instead of
# importing it through the voxcpm package.
spec = importlib.util.spec_from_file_location("voxcpm_utils", UTILS)
utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(utils)

# Bind the names under test at module level so the checks can use them bare.
_LOW_PRECISION_DTYPES = utils._LOW_PRECISION_DTYPES
_VALID_DTYPE_OVERRIDES = utils._VALID_DTYPE_OVERRIDES
get_dtype = utils.get_dtype
pick_runtime_dtype = utils.pick_runtime_dtype
def expect(actual, expected, label):
    """Compare two values, print a one-line verdict, and return the result.

    Args:
        actual: Value produced by the code under test.
        expected: Value it is compared against with ``==``.
        label: Short description printed in the verdict line.

    Returns:
        The result of ``actual == expected``.
    """
    matched = actual == expected
    if matched:
        status = "OK "  # trailing space keeps the tag the same width as "FAIL"
    else:
        status = "FAIL"
    print(f"[{status}] {label}: got={actual!r} expected={expected!r}")
    return matched
def expect_raises(fn, exc_type, label):
    """Call ``fn`` and check that it raises ``exc_type``; print the verdict.

    Args:
        fn: Zero-argument callable to invoke.
        exc_type: Exception class the call is expected to raise.
        label: Short description printed in the verdict line.

    Returns:
        True when ``fn`` raised ``exc_type``; False when it raised a
        different ``Exception`` or raised nothing at all.
    """
    try:
        fn()
    except exc_type as err:
        outcome = True
        print(f"[OK ] {label}: raised {exc_type.__name__}: {err}")
    except Exception as err:
        # Some other exception type escaped: report what was actually raised.
        outcome = False
        print(f"[FAIL] {label}: raised {type(err).__name__} not {exc_type.__name__}: {err}")
    else:
        outcome = False
        print(f"[FAIL] {label}: no exception raised")
    return outcome
def main():
    """Run every dtype check, print one line per check, and return an exit code.

    The checks mutate the ``VOXCPM_MPS_DTYPE`` environment variable; whatever
    value it had on entry (including "unset") is put back before returning.

    Returns:
        0 when every check passed, 1 otherwise.
    """
    results = []
    # Remember the caller's setting so the checks can overwrite it freely.
    saved_override = os.environ.get("VOXCPM_MPS_DTYPE")
    try:
        print("=== override set sanity ===")
        results.append(expect("half" not in _VALID_DTYPE_OVERRIDES, True, "half removed from _VALID_DTYPE_OVERRIDES"))
        results.append(expect("half" not in _LOW_PRECISION_DTYPES, True, "half removed from _LOW_PRECISION_DTYPES"))

        print("\n=== every accepted override parses through get_dtype ===")
        for dt in sorted(_VALID_DTYPE_OVERRIDES):
            try:
                torch_dtype = get_dtype(dt)
                print(f"[OK ] get_dtype({dt!r}) -> {torch_dtype}")
                results.append(True)
            except Exception as e:
                print(f"[FAIL] get_dtype({dt!r}) raised: {e}")
                results.append(False)

        print("\n=== pick_runtime_dtype: non-mps is a no-op ===")
        results.append(expect(pick_runtime_dtype("cuda", "bfloat16"), "bfloat16", "cuda/bf16 untouched"))
        results.append(expect(pick_runtime_dtype("cpu", "float16"), "float16", "cpu/fp16 untouched"))
        results.append(expect(pick_runtime_dtype("cuda", "float32"), "float32", "cuda/fp32 untouched"))

        print("\n=== pick_runtime_dtype: mps forces fp32 for low-precision ===")
        # The default-behaviour checks must run with no override in effect.
        os.environ.pop("VOXCPM_MPS_DTYPE", None)
        results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "mps/bf16 -> fp32"))
        results.append(expect(pick_runtime_dtype("mps", "bf16"), "float32", "mps/bf16-alias -> fp32"))
        results.append(expect(pick_runtime_dtype("mps", "float16"), "float32", "mps/fp16 -> fp32"))
        results.append(expect(pick_runtime_dtype("mps", "fp16"), "float32", "mps/fp16-alias -> fp32"))
        results.append(expect(pick_runtime_dtype("mps", "float32"), "float32", "mps/fp32 stays"))
        results.append(expect(pick_runtime_dtype("mps", "fp32"), "fp32", "mps/fp32-alias stays"))

        print("\n=== pick_runtime_dtype: VOXCPM_MPS_DTYPE override ===")
        os.environ["VOXCPM_MPS_DTYPE"] = "bfloat16"
        results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "bfloat16", "override bf16 honored"))
        os.environ["VOXCPM_MPS_DTYPE"] = "FP16"
        results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "fp16", "override is case-insensitive"))
        os.environ["VOXCPM_MPS_DTYPE"] = "  float32  "
        results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "override is whitespace-trimmed"))

        print("\n=== pick_runtime_dtype: 'half' is no longer a valid override ===")
        os.environ["VOXCPM_MPS_DTYPE"] = "half"
        results.append(
            expect_raises(
                lambda: pick_runtime_dtype("mps", "bfloat16"),
                ValueError,
                "override=half now rejected (was the bug)",
            )
        )
        os.environ["VOXCPM_MPS_DTYPE"] = "garbage"
        results.append(
            expect_raises(
                lambda: pick_runtime_dtype("mps", "bfloat16"),
                ValueError,
                "override=garbage still rejected",
            )
        )
    finally:
        # Leave the environment exactly as it was found, even if a check blew up.
        if saved_override is None:
            os.environ.pop("VOXCPM_MPS_DTYPE", None)
        else:
            os.environ["VOXCPM_MPS_DTYPE"] = saved_override

    print("\n=== summary ===")
    passed = sum(results)
    total = len(results)
    print(f"{passed}/{total} passed")
    return 0 if passed == total else 1


# Only run (and exit) when executed as a script. The file name matches
# pytest's default ``test_*.py`` pattern, and exiting at import time would
# break collection.
if __name__ == "__main__":
    sys.exit(main())
+36
View File
@@ -1,7 +1,15 @@
import os
from typing import List, Optional
import torch
from transformers import PreTrainedTokenizer
# dtype names (long form and short alias) for bfloat16 and float16.
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"}
# Values accepted from the VOXCPM_MPS_DTYPE environment variable: every
# low-precision name plus float32 and its alias. The union builds a separate
# set object, so the two constants stay independent.
_VALID_DTYPE_OVERRIDES = _LOW_PRECISION_DTYPES | {"float32", "fp32"}
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
# Explicitly close partially-consumed generators so inference_mode cleanup
@@ -135,6 +143,34 @@ def _has_mps() -> bool:
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
    """Pick a safe runtime dtype for the resolved device.

    On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift
    in the diffusion AR loop that the output is glitched and the model's
    badcase detector triggers infinite retries. float32 is the only stable
    option today. CUDA and CPU keep whatever the checkpoint was trained with.

    Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when
    they want to test future MPS improvements.

    Args:
        device: Resolved device string such as ``"cuda"``, ``"cpu"``,
            ``"mps"`` or an indexed form like ``"mps:0"``.
        configured_dtype: dtype name taken from the model configuration.

    Returns:
        ``configured_dtype`` unchanged for non-MPS devices. On MPS: the
        normalised (stripped, lower-cased) ``VOXCPM_MPS_DTYPE`` value when
        it is set, otherwise ``"float32"`` if ``configured_dtype`` is a
        low-precision name, otherwise ``configured_dtype`` unchanged.

    Raises:
        ValueError: If ``VOXCPM_MPS_DTYPE`` is set to a value that is not in
            ``_VALID_DTYPE_OVERRIDES``.
    """
    # Compare only the device *type* so an indexed spec ("mps:0") gets the
    # same safeguard as plain "mps".
    device_type = str(device).split(":", 1)[0]
    if device_type != "mps":
        return configured_dtype

    override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
    if override:
        if override not in _VALID_DTYPE_OVERRIDES:
            raise ValueError(
                f"VOXCPM_MPS_DTYPE='{override}' is not one of "
                f"{sorted(_VALID_DTYPE_OVERRIDES)}"
            )
        return override

    # ``or ""`` tolerates a missing (None) configured dtype.
    if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES:
        return "float32"
    return configured_dtype
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
"""
Choose a runtime device automatically.
+14 -1
View File
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
from .utils import (
get_dtype,
mask_multichar_chinese_tokens,
next_and_close,
pick_runtime_dtype,
resolve_runtime_device,
)
class VoxCPMEncoderConfig(BaseModel):
@@ -118,6 +124,13 @@ class VoxCPMModel(nn.Module):
self.patch_size = config.patch_size
self.device = resolve_runtime_device(device, config.device)
self.config.device = self.device
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
if resolved_dtype != self.config.dtype:
print(
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
file=sys.stderr,
)
self.config.dtype = resolved_dtype
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM
+14 -1
View File
@@ -45,7 +45,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
from .utils import (
get_dtype,
mask_multichar_chinese_tokens,
next_and_close,
pick_runtime_dtype,
resolve_runtime_device,
)
# A simple function to trim audio silence using VAD, not used default
@@ -160,6 +166,13 @@ class VoxCPM2Model(nn.Module):
self.patch_size = config.patch_size
self.device = resolve_runtime_device(device, config.device)
self.config.device = self.device
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
if resolved_dtype != self.config.dtype:
print(
f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
file=sys.stderr,
)
self.config.dtype = resolved_dtype
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM