diff --git a/scripts/test_pick_runtime_dtype.py b/scripts/test_pick_runtime_dtype.py new file mode 100644 index 0000000..5bff4eb --- /dev/null +++ b/scripts/test_pick_runtime_dtype.py @@ -0,0 +1,108 @@ +"""Unit checks for pick_runtime_dtype / get_dtype consistency. + +Loads src/voxcpm/model/utils.py directly to avoid the heavy voxcpm package +init. Run with: `python scripts/test_pick_runtime_dtype.py`. +""" +import importlib.util +import os +import pathlib +import sys + +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent +UTILS = str(REPO_ROOT / "src" / "voxcpm" / "model" / "utils.py") +spec = importlib.util.spec_from_file_location("voxcpm_utils", UTILS) +utils = importlib.util.module_from_spec(spec) +spec.loader.exec_module(utils) + +_LOW_PRECISION_DTYPES = utils._LOW_PRECISION_DTYPES +_VALID_DTYPE_OVERRIDES = utils._VALID_DTYPE_OVERRIDES +get_dtype = utils.get_dtype +pick_runtime_dtype = utils.pick_runtime_dtype + + +def expect(actual, expected, label): + ok = actual == expected + mark = "OK " if ok else "FAIL" + print(f"[{mark}] {label}: got={actual!r} expected={expected!r}") + return ok + + +def expect_raises(fn, exc_type, label): + try: + fn() + except exc_type as e: + print(f"[OK ] {label}: raised {exc_type.__name__}: {e}") + return True + except Exception as e: + print(f"[FAIL] {label}: raised {type(e).__name__} not {exc_type.__name__}: {e}") + return False + print(f"[FAIL] {label}: no exception raised") + return False + + +results = [] + +print("=== override set sanity ===") +results.append(expect("half" not in _VALID_DTYPE_OVERRIDES, True, "half removed from _VALID_DTYPE_OVERRIDES")) +results.append(expect("half" not in _LOW_PRECISION_DTYPES, True, "half removed from _LOW_PRECISION_DTYPES")) + +print("\n=== every accepted override parses through get_dtype ===") +for dt in sorted(_VALID_DTYPE_OVERRIDES): + try: + torch_dtype = get_dtype(dt) + print(f"[OK ] get_dtype({dt!r}) -> {torch_dtype}") + results.append(True) + except Exception as e: + print(f"[FAIL] get_dtype({dt!r}) raised: {e}") + results.append(False) + +print("\n=== pick_runtime_dtype: non-mps is a no-op ===") +results.append(expect(pick_runtime_dtype("cuda", "bfloat16"), "bfloat16", "cuda/bf16 untouched")) +results.append(expect(pick_runtime_dtype("cpu", "float16"), "float16", "cpu/fp16 untouched")) +results.append(expect(pick_runtime_dtype("cuda", "float32"), "float32", "cuda/fp32 untouched")) + +print("\n=== pick_runtime_dtype: mps forces fp32 for low-precision ===") +os.environ.pop("VOXCPM_MPS_DTYPE", None) +results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "mps/bf16 -> fp32")) +results.append(expect(pick_runtime_dtype("mps", "bf16"), "float32", "mps/bf16-alias -> fp32")) +results.append(expect(pick_runtime_dtype("mps", "float16"), "float32", "mps/fp16 -> fp32")) +results.append(expect(pick_runtime_dtype("mps", "fp16"), "float32", "mps/fp16-alias -> fp32")) +results.append(expect(pick_runtime_dtype("mps", "float32"), "float32", "mps/fp32 stays")) +results.append(expect(pick_runtime_dtype("mps", "fp32"), "fp32", "mps/fp32-alias stays")) + +print("\n=== pick_runtime_dtype: VOXCPM_MPS_DTYPE override ===") +os.environ["VOXCPM_MPS_DTYPE"] = "bfloat16" +results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "bfloat16", "override bf16 honored")) + +os.environ["VOXCPM_MPS_DTYPE"] = "FP16" +results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "fp16", "override is case-insensitive")) + +os.environ["VOXCPM_MPS_DTYPE"] = " float32 " +results.append(expect(pick_runtime_dtype("mps", "bfloat16"), "float32", "override is whitespace-trimmed")) + +print("\n=== pick_runtime_dtype: 'half' is no longer a valid override ===") +os.environ["VOXCPM_MPS_DTYPE"] = "half" +results.append( + expect_raises( + lambda: pick_runtime_dtype("mps", "bfloat16"), + ValueError, + "override=half now rejected (was the bug)", + ) +) + +os.environ["VOXCPM_MPS_DTYPE"] = "garbage" +results.append( + expect_raises( + lambda: pick_runtime_dtype("mps", "bfloat16"), + ValueError, + "override=garbage still rejected", + ) +) + +os.environ.pop("VOXCPM_MPS_DTYPE", None) + +print("\n=== summary ===") +passed = sum(results) +total = len(results) +print(f"{passed}/{total} passed") +sys.exit(0 if passed == total else 1) diff --git a/src/voxcpm/model/utils.py b/src/voxcpm/model/utils.py index 2ca6068..940fc74 100644 --- a/src/voxcpm/model/utils.py +++ b/src/voxcpm/model/utils.py @@ -1,7 +1,15 @@ +import os from typing import List, Optional import torch from transformers import PreTrainedTokenizer +_LOW_PRECISION_DTYPES = {"bfloat16", "bf16", "float16", "fp16"} +_VALID_DTYPE_OVERRIDES = { + "bfloat16", "bf16", + "float16", "fp16", + "float32", "fp32", +} + # Ref: https://github.com/OpenBMB/VoxCPM/issues/256#issuecomment-4235252732 # Explicitly close partially-consumed generators so inference_mode cleanup @@ -135,6 +143,34 @@ def _has_mps() -> bool: return hasattr(torch.backends, "mps") and torch.backends.mps.is_available() +def pick_runtime_dtype(device: str, configured_dtype: str) -> str: + """Pick a safe runtime dtype for the resolved device. + + On Apple Silicon (MPS), bfloat16/float16 produce enough numerical drift + in the diffusion AR loop that the output is glitched and the model's + badcase detector triggers infinite retries. float32 is the only stable + option today. CUDA and CPU keep whatever the checkpoint was trained with. + + Users can override with ``VOXCPM_MPS_DTYPE`` (e.g. ``bfloat16``) when + they want to test future MPS improvements. + """ + if device != "mps": + return configured_dtype + + override = os.environ.get("VOXCPM_MPS_DTYPE", "").strip().lower() + if override: + if override not in _VALID_DTYPE_OVERRIDES: + raise ValueError( + f"VOXCPM_MPS_DTYPE='{override}' is not one of " + f"{sorted(_VALID_DTYPE_OVERRIDES)}" + ) + return override + + if (configured_dtype or "").lower() in _LOW_PRECISION_DTYPES: + return "float32" + return configured_dtype + + def auto_select_device(preferred_device: Optional[str] = "cuda") -> str: """ Choose a runtime device automatically. diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index c08e1e8..445618b 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT from ..modules.locenc import VoxCPMLocEnc from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel -from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device +from .utils import ( + get_dtype, + mask_multichar_chinese_tokens, + next_and_close, + pick_runtime_dtype, + resolve_runtime_device, +) class VoxCPMEncoderConfig(BaseModel): @@ -118,6 +124,13 @@ class VoxCPMModel(nn.Module): self.patch_size = config.patch_size self.device = resolve_runtime_device(device, config.device) self.config.device = self.device + resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype) + if resolved_dtype != self.config.dtype: + print( + f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}", + file=sys.stderr, + ) + self.config.dtype = resolved_dtype print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) # Text-Semantic LM diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index 8692ba0..90495f1 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -45,7 +45,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2 from ..modules.locenc import VoxCPMLocEnc from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel -from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device +from .utils import ( + get_dtype, + mask_multichar_chinese_tokens, + next_and_close, + pick_runtime_dtype, + resolve_runtime_device, +) # A simple function to trim audio silence using VAD, not used default @@ -160,6 +166,13 @@ class VoxCPM2Model(nn.Module): self.patch_size = config.patch_size self.device = resolve_runtime_device(device, config.device) self.config.device = self.device + resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype) + if resolved_dtype != self.config.dtype: + print( + f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}", + file=sys.stderr, + ) + self.config.dtype = resolved_dtype print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) # Text-Semantic LM