diff --git a/src/voxcpm/modules/minicpm4/model.py b/src/voxcpm/modules/minicpm4/model.py index 99d6f0b..6075807 100644 --- a/src/voxcpm/modules/minicpm4/model.py +++ b/src/voxcpm/modules/minicpm4/model.py @@ -196,7 +196,9 @@ class MiniCPMAttention(nn.Module): key_cache[:, :, position_id, :] = key_states value_cache[:, :, position_id, :] = value_states - attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id + # Use an explicit broadcastable mask shape for SDPA. A 1D mask can + # trigger a CPU-side dimension bug in some PyTorch versions. + attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1) # ref: https://github.com/pytorch/pytorch/issues/163597 # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous