Merge pull request #263 from Oumnya/fix/mps-bf16-dtype
fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss
This commit is contained in:
@@ -0,0 +1,108 @@
|
||||
"""Unit checks for pick_runtime_dtype / get_dtype consistency.
|
||||
|
||||
Loads src/voxcpm/model/utils.py directly to avoid the heavy voxcpm package
|
||||
init. Run with: `python scripts/test_pick_runtime_dtype.py`.
|
||||
"""
|
||||
import importlib.util
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
UTILS = str(REPO_ROOT / "src" / "voxcpm" / "model" / "utils.py")
|
||||
spec = importlib.util.spec_from_file_location("voxcpm_utils", UTILS)
|
||||
utils = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(utils)
|
||||
|
||||
_LOW_PRECISION_DTYPES = utils._LOW_PRECISION_DTYPES
|
||||
_VALID_DTYPE_OVERRIDES = utils._VALID_DTYPE_OVERRIDES
|
||||
get_dtype = utils.get_dtype
|
||||
pick_runtime_dtype = utils.pick_runtime_dtype
|
||||
|
||||
|
||||
def expect(actual, expected, label):
|
||||
ok = actual == expected
|
||||
mark = "OK " if ok else "FAIL"
|
||||
print(f"[{mark}] {label}: got={actual!r} expected={expected!r}")
|
||||
return ok
|
||||
|
||||
|
||||
def expect_raises(fn, exc_type, label):
|
||||
try:
|
||||
fn()
|
||||
except exc_type as e:
|
||||
print(f"[OK ] {label}: raised {exc_type.__name__}: {e}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[FAIL] {label}: raised {type(e).__name__} not {exc_type.__name__}: {e}")
|
||||
return False
|
||||
print(f"[FAIL] {label}: no exception raised")
|
||||
return False
|
||||
|
||||
|
||||
results = []
|
||||
|
||||
print("=== override set sanity ===")
|
||||
results.append(expect("half" not in _VALID_DTYPE_OVERRIDES, True, "half removed from _VALID_DTYPE_OVERRIDES"))
|
||||
results.append(expect("half" not in _LOW_PRECISION_DTYPES, True, "half removed from _LOW_PRECISION_DTYPES"))
|
||||
|
||||
print("\n=== every accepted override parses through get_dtype ===")
|
||||
for dt in sorted(_VALID_DTYPE_OVERRIDES):
|
||||
try:
|
||||
torch_dtype = get_dtype(dt)
|
||||
print(f"[OK ] get_dtype({dt!r}) -> {torch_dtype}")
|
||||
results.append(True)
|
||||
except Exception as e:
|
||||
print(f"[FAIL] get_dtype({dt!r}) raised: {e}")
|
||||
results.append(False)
|
||||
|
||||
print("\n=== pick_runtime_dtype: non-mps is a no-op ===")
|
||||
results.append(expect(pick_runtime_dtype("cuda", "bfloat16"), "bfloat16", "cuda/bf16 untouched"))
|
||||
results.append(expect(pick_runtime_dtype("cpu", "float16"), "float16", "cpu/fp16 untouched"))
|
||||
results.append(expect(pick_runtime_dtype("cuda", "float32"), "float32", "cuda/fp32 untouched"))
|
||||
|
||||
print("\n=== pick_runtime_dtype: mps forces fp32 for low-precision ===")
|
||||
os.environ.pop("VOXCPM_MPS_DTYPE", None)
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "mps/bf16 -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "bf16"), "float32", "mps/bf16-alias -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "float16"), "float32", "mps/fp16 -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "fp16"), "float32", "mps/fp16-alias -> fp32"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "float32"), "float32", "mps/fp32 stays"))
|
||||
results.append(expect(pick_runtime_dtype("mps", "fp32"), "fp32", "mps/fp32-alias stays"))
|
||||
|
||||
print("\n=== pick_runtime_dtype: VOXCPM_MPS_DTYPE override ===")
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "bfloat16"
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "bfloat16", "override bf16 honored"))
|
||||
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "FP16"
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "fp16", "override is case-insensitive"))
|
||||
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = " float32 "
|
||||
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "override is whitespace-trimmed"))
|
||||
|
||||
print("\n=== pick_runtime_dtype: 'half' is no longer a valid override ===")
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "half"
|
||||
results.append(
|
||||
expect_raises(
|
||||
lambda: pick_runtime_dtype("mps", "bfloat16"),
|
||||
ValueError,
|
||||
"override=half now rejected (was the bug)",
|
||||
)
|
||||
)
|
||||
|
||||
os.environ["VOXCPM_MPS_DTYPE"] = "garbage"
|
||||
results.append(
|
||||
expect_raises(
|
||||
lambda: pick_runtime_dtype("mps", "bfloat16"),
|
||||
ValueError,
|
||||
"override=garbage still rejected",
|
||||
)
|
||||
)
|
||||
|
||||
os.environ.pop("VOXCPM_MPS_DTYPE", None)
|
||||
|
||||
print("\n=== summary ===")
|
||||
passed = sum(results)
|
||||
total = len(results)
|
||||
print(f"{passed}/{total} passed")
|
||||
sys.exit(0 if passed == total else 1)
|
||||
@@ -1,7 +1,15 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"}
|
||||
_VALID_DTYPE_OVERRIDES = {
|
||||
"bfloat16", "bf16",
|
||||
"float16", "fp16",
|
||||
"float32", "fp32",
|
||||
}
|
||||
|
||||
|
||||
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
|
||||
# Explicitly close partially-consumed generators so inference_mode cleanup
|
||||
@@ -135,6 +143,34 @@ def _has_mps() -> bool:
|
||||
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
|
||||
|
||||
def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
|
||||
"""Pick a safe runtime dtype for the resolved device.
|
||||
|
||||
On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift
|
||||
in the diffusion AR loop that the output is glitched and the model's
|
||||
badcase detector triggers infinite retries. float32 is the only stable
|
||||
option today. CUDA and CPU keep whatever the checkpoint was trained with.
|
||||
|
||||
Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when
|
||||
they want to test future MPS improvements.
|
||||
"""
|
||||
if device != "mps":
|
||||
return configured_dtype
|
||||
|
||||
override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
|
||||
if override:
|
||||
if override not in _VALID_DTYPE_OVERRIDES:
|
||||
raise ValueError(
|
||||
f"VOXCPM_MPS_DTYPE='{override}' is not one of "
|
||||
f"{sorted(_VALID_DTYPE_OVERRIDES)}"
|
||||
)
|
||||
return override
|
||||
|
||||
if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES:
|
||||
return "float32"
|
||||
return configured_dtype
|
||||
|
||||
|
||||
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
|
||||
"""
|
||||
Choose a runtime device automatically.
|
||||
|
||||
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
|
||||
from .utils import (
|
||||
get_dtype,
|
||||
mask_multichar_chinese_tokens,
|
||||
next_and_close,
|
||||
pick_runtime_dtype,
|
||||
resolve_runtime_device,
|
||||
)
|
||||
|
||||
|
||||
class VoxCPMEncoderConfig(BaseModel):
|
||||
@@ -118,6 +124,13 @@ class VoxCPMModel(nn.Module):
|
||||
self.patch_size = config.patch_size
|
||||
self.device = resolve_runtime_device(device, config.device)
|
||||
self.config.device = self.device
|
||||
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||
if resolved_dtype != self.config.dtype:
|
||||
print(
|
||||
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
self.config.dtype = resolved_dtype
|
||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||
|
||||
# Text-Semantic LM
|
||||
|
||||
@@ -45,7 +45,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
|
||||
from .utils import (
|
||||
get_dtype,
|
||||
mask_multichar_chinese_tokens,
|
||||
next_and_close,
|
||||
pick_runtime_dtype,
|
||||
resolve_runtime_device,
|
||||
)
|
||||
|
||||
|
||||
# A simple function to trim audio silence using VAD, not used default
|
||||
@@ -160,6 +166,13 @@ class VoxCPM2Model(nn.Module):
|
||||
self.patch_size = config.patch_size
|
||||
self.device = resolve_runtime_device(device, config.device)
|
||||
self.config.device = self.device
|
||||
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||
if resolved_dtype != self.config.dtype:
|
||||
print(
|
||||
f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
self.config.dtype = resolved_dtype
|
||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||
|
||||
# Text-Semantic LM
|
||||
|
||||
Reference in New Issue
Block a user