Merge pull request #263 from Oumnya/fix/mps-bf16-dtype
fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss
This commit is contained in:
@@ -0,0 +1,108 @@
|
|||||||
|
"""Unit checks for pick_runtime_dtype / get_dtype consistency.
|
||||||
|
|
||||||
|
Loads src/voxcpm/model/utils.py directly to avoid the heavy voxcpm package
|
||||||
|
init. Run with: `python scripts/test_pick_runtime_dtype.py`.
|
||||||
|
"""
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
|
||||||
|
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||||
|
UTILS = str(REPO_ROOT / "src" / "voxcpm" / "model" / "utils.py")
|
||||||
|
spec = importlib.util.spec_from_file_location("voxcpm_utils", UTILS)
|
||||||
|
utils = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(utils)
|
||||||
|
|
||||||
|
_LOW_PRECISION_DTYPES = utils._LOW_PRECISION_DTYPES
|
||||||
|
_VALID_DTYPE_OVERRIDES = utils._VALID_DTYPE_OVERRIDES
|
||||||
|
get_dtype = utils.get_dtype
|
||||||
|
pick_runtime_dtype = utils.pick_runtime_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def expect(actual, expected, label):
|
||||||
|
ok = actual == expected
|
||||||
|
mark = "OK " if ok else "FAIL"
|
||||||
|
print(f"[{mark}] {label}: got={actual!r} expected={expected!r}")
|
||||||
|
return ok
|
||||||
|
|
||||||
|
|
||||||
|
def expect_raises(fn, exc_type, label):
|
||||||
|
try:
|
||||||
|
fn()
|
||||||
|
except exc_type as e:
|
||||||
|
print(f"[OK ] {label}: raised {exc_type.__name__}: {e}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[FAIL] {label}: raised {type(e).__name__} not {exc_type.__name__}: {e}")
|
||||||
|
return False
|
||||||
|
print(f"[FAIL] {label}: no exception raised")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
print("=== override set sanity ===")
|
||||||
|
results.append(expect("half" not in _VALID_DTYPE_OVERRIDES, True, "half removed from _VALID_DTYPE_OVERRIDES"))
|
||||||
|
results.append(expect("half" not in _LOW_PRECISION_DTYPES, True, "half removed from _LOW_PRECISION_DTYPES"))
|
||||||
|
|
||||||
|
print("\n=== every accepted override parses through get_dtype ===")
|
||||||
|
for dt in sorted(_VALID_DTYPE_OVERRIDES):
|
||||||
|
try:
|
||||||
|
torch_dtype = get_dtype(dt)
|
||||||
|
print(f"[OK ] get_dtype({dt!r}) -> {torch_dtype}")
|
||||||
|
results.append(True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[FAIL] get_dtype({dt!r}) raised: {e}")
|
||||||
|
results.append(False)
|
||||||
|
|
||||||
|
print("\n=== pick_runtime_dtype: non-mps is a no-op ===")
|
||||||
|
results.append(expect(pick_runtime_dtype("cuda", "bfloat16"), "bfloat16", "cuda/bf16 untouched"))
|
||||||
|
results.append(expect(pick_runtime_dtype("cpu", "float16"), "float16", "cpu/fp16 untouched"))
|
||||||
|
results.append(expect(pick_runtime_dtype("cuda", "float32"), "float32", "cuda/fp32 untouched"))
|
||||||
|
|
||||||
|
print("\n=== pick_runtime_dtype: mps forces fp32 for low-precision ===")
|
||||||
|
os.environ.pop("VOXCPM_MPS_DTYPE", None)
|
||||||
|
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "mps/bf16 -> fp32"))
|
||||||
|
results.append(expect(pick_runtime_dtype("mps", "bf16"), "float32", "mps/bf16-alias -> fp32"))
|
||||||
|
results.append(expect(pick_runtime_dtype("mps", "float16"), "float32", "mps/fp16 -> fp32"))
|
||||||
|
results.append(expect(pick_runtime_dtype("mps", "fp16"), "float32", "mps/fp16-alias -> fp32"))
|
||||||
|
results.append(expect(pick_runtime_dtype("mps", "float32"), "float32", "mps/fp32 stays"))
|
||||||
|
results.append(expect(pick_runtime_dtype("mps", "fp32"), "fp32", "mps/fp32-alias stays"))
|
||||||
|
|
||||||
|
print("\n=== pick_runtime_dtype: VOXCPM_MPS_DTYPE override ===")
|
||||||
|
os.environ["VOXCPM_MPS_DTYPE"] = "bfloat16"
|
||||||
|
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "bfloat16", "override bf16 honored"))
|
||||||
|
|
||||||
|
os.environ["VOXCPM_MPS_DTYPE"] = "FP16"
|
||||||
|
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "fp16", "override is case-insensitive"))
|
||||||
|
|
||||||
|
os.environ["VOXCPM_MPS_DTYPE"] = " float32 "
|
||||||
|
results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "override is whitespace-trimmed"))
|
||||||
|
|
||||||
|
print("\n=== pick_runtime_dtype: 'half' is no longer a valid override ===")
|
||||||
|
os.environ["VOXCPM_MPS_DTYPE"] = "half"
|
||||||
|
results.append(
|
||||||
|
expect_raises(
|
||||||
|
lambda: pick_runtime_dtype("mps", "bfloat16"),
|
||||||
|
ValueError,
|
||||||
|
"override=half now rejected (was the bug)",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ["VOXCPM_MPS_DTYPE"] = "garbage"
|
||||||
|
results.append(
|
||||||
|
expect_raises(
|
||||||
|
lambda: pick_runtime_dtype("mps", "bfloat16"),
|
||||||
|
ValueError,
|
||||||
|
"override=garbage still rejected",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ.pop("VOXCPM_MPS_DTYPE", None)
|
||||||
|
|
||||||
|
print("\n=== summary ===")
|
||||||
|
passed = sum(results)
|
||||||
|
total = len(results)
|
||||||
|
print(f"{passed}/{total} passed")
|
||||||
|
sys.exit(0 if passed == total else 1)
|
||||||
@@ -1,7 +1,15 @@
|
|||||||
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import torch
|
import torch
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"}
|
||||||
|
_VALID_DTYPE_OVERRIDES = {
|
||||||
|
"bfloat16", "bf16",
|
||||||
|
"float16", "fp16",
|
||||||
|
"float32", "fp32",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
|
# Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732
|
||||||
# Explicitly close partially-consumed generators so inference_mode cleanup
|
# Explicitly close partially-consumed generators so inference_mode cleanup
|
||||||
@@ -135,6 +143,34 @@ def _has_mps() -> bool:
|
|||||||
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||||
|
|
||||||
|
|
||||||
|
def pick_runtime_dtype(device: str, configured_dtype: str) -> str:
|
||||||
|
"""Pick a safe runtime dtype for the resolved device.
|
||||||
|
|
||||||
|
On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift
|
||||||
|
in the diffusion AR loop that the output is glitched and the model's
|
||||||
|
badcase detector triggers infinite retries. float32 is the only stable
|
||||||
|
option today. CUDA and CPU keep whatever the checkpoint was trained with.
|
||||||
|
|
||||||
|
Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when
|
||||||
|
they want to test future MPS improvements.
|
||||||
|
"""
|
||||||
|
if device != "mps":
|
||||||
|
return configured_dtype
|
||||||
|
|
||||||
|
override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower()
|
||||||
|
if override:
|
||||||
|
if override not in _VALID_DTYPE_OVERRIDES:
|
||||||
|
raise ValueError(
|
||||||
|
f"VOXCPM_MPS_DTYPE='{override}' is not one of "
|
||||||
|
f"{sorted(_VALID_DTYPE_OVERRIDES)}"
|
||||||
|
)
|
||||||
|
return override
|
||||||
|
|
||||||
|
if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES:
|
||||||
|
return "float32"
|
||||||
|
return configured_dtype
|
||||||
|
|
||||||
|
|
||||||
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
|
def auto_select_device(preferred_device: Optional[str] = "cuda") -> str:
|
||||||
"""
|
"""
|
||||||
Choose a runtime device automatically.
|
Choose a runtime device automatically.
|
||||||
|
|||||||
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
|||||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||||
from ..modules.locenc import VoxCPMLocEnc
|
from ..modules.locenc import VoxCPMLocEnc
|
||||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||||
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
|
from .utils import (
|
||||||
|
get_dtype,
|
||||||
|
mask_multichar_chinese_tokens,
|
||||||
|
next_and_close,
|
||||||
|
pick_runtime_dtype,
|
||||||
|
resolve_runtime_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VoxCPMEncoderConfig(BaseModel):
|
class VoxCPMEncoderConfig(BaseModel):
|
||||||
@@ -118,6 +124,13 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
self.device = resolve_runtime_device(device, config.device)
|
self.device = resolve_runtime_device(device, config.device)
|
||||||
self.config.device = self.device
|
self.config.device = self.device
|
||||||
|
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||||
|
if resolved_dtype != self.config.dtype:
|
||||||
|
print(
|
||||||
|
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
self.config.dtype = resolved_dtype
|
||||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||||
|
|
||||||
# Text-Semantic LM
|
# Text-Semantic LM
|
||||||
|
|||||||
@@ -45,7 +45,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
|||||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
|
||||||
from ..modules.locenc import VoxCPMLocEnc
|
from ..modules.locenc import VoxCPMLocEnc
|
||||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||||
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
|
from .utils import (
|
||||||
|
get_dtype,
|
||||||
|
mask_multichar_chinese_tokens,
|
||||||
|
next_and_close,
|
||||||
|
pick_runtime_dtype,
|
||||||
|
resolve_runtime_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# A simple function to trim audio silence using VAD, not used default
|
# A simple function to trim audio silence using VAD, not used default
|
||||||
@@ -160,6 +166,13 @@ class VoxCPM2Model(nn.Module):
|
|||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
self.device = resolve_runtime_device(device, config.device)
|
self.device = resolve_runtime_device(device, config.device)
|
||||||
self.config.device = self.device
|
self.config.device = self.device
|
||||||
|
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||||
|
if resolved_dtype != self.config.dtype:
|
||||||
|
print(
|
||||||
|
f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
self.config.dtype = resolved_dtype
|
||||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||||
|
|
||||||
# Text-Semantic LM
|
# Text-Semantic LM
|
||||||
|
|||||||
Reference in New Issue
Block a user