Merge pull request #263 from Oumnya/fix/mps-bf16-dtype
fix(mps): force float32 on Apple Silicon to avoid bf16 quality loss
This commit is contained in:
@@ -44,7 +44,13 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens, next_and_close, resolve_runtime_device
|
||||
from .utils import (
|
||||
get_dtype,
|
||||
mask_multichar_chinese_tokens,
|
||||
next_and_close,
|
||||
pick_runtime_dtype,
|
||||
resolve_runtime_device,
|
||||
)
|
||||
|
||||
|
||||
class VoxCPMEncoderConfig(BaseModel):
|
||||
@@ -118,6 +124,13 @@ class VoxCPMModel(nn.Module):
|
||||
self.patch_size = config.patch_size
|
||||
self.device = resolve_runtime_device(device, config.device)
|
||||
self.config.device = self.device
|
||||
resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype)
|
||||
if resolved_dtype != self.config.dtype:
|
||||
print(
|
||||
f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
self.config.dtype = resolved_dtype
|
||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||
|
||||
# Text-Semantic LM
|
||||
|
||||
Reference in New Issue
Block a user