Merge pull request #263 from Oumnya/fix/mps-bf16-dtype

fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss
This commit is contained in:
ZGY
2026-04-21 18:49:48 +08:00
committed by GitHub
4 changed files with 172 additions and 2 deletions
+14 -1
View File
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
from .utils import (
get_dtype,
mask_multichar_chinese_tokens,
next_and_close,
pick_runtime_dtype,
resolve_runtime_device,
)
class VoxCPMEncoderConfig(BaseModel):
@@ -118,6 +124,13 @@ class VoxCPMModel(nn.Module):
self.patch_size = config.patch_size
self.device = resolve_runtime_device(device, config.device)
self.config.device = self.device
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
if resolved_dtype != self.config.dtype:
print(
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
file=sys.stderr,
)
self.config.dtype = resolved_dtype
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM