fix: stabilize CPU SDPA mask broadcasting
Use an explicit broadcastable attention mask shape during MiniCPM incremental decoding so CPU runtimes avoid a PyTorch SDPA dimension error without changing attention semantics. Made-with: Cursor
This commit is contained in:
@@ -196,7 +196,9 @@ class MiniCPMAttention(nn.Module):
|
||||
key_cache[:, :, position_id, :] = key_states
|
||||
value_cache[:, :, position_id, :] = value_states
|
||||
|
||||
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
|
||||
# Use an explicit broadcastable mask shape for SDPA. A 1D mask can
|
||||
# trigger a CPU-side dimension bug in some PyTorch versions.
|
||||
attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1)
|
||||
|
||||
# ref: https://github.com/pytorch/pytorch/issues/163597
|
||||
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
||||
|
||||
Reference in New Issue
Block a user