fix: stabilize CPU SDPA mask broadcasting
Use an explicit broadcastable attention mask shape during MiniCPM incremental decoding so CPU runtimes avoid a PyTorch SDPA dimension error without changing attention semantics. Made-with: Cursor
This commit is contained in:
@@ -196,7 +196,9 @@ class MiniCPMAttention(nn.Module):
|
|||||||
key_cache[:, :, position_id, :] = key_states
|
key_cache[:, :, position_id, :] = key_states
|
||||||
value_cache[:, :, position_id, :] = value_states
|
value_cache[:, :, position_id, :] = value_states
|
||||||
|
|
||||||
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
|
# Use an explicit broadcastable mask shape for SDPA. A 1D mask can
|
||||||
|
# trigger a CPU-side dimension bug in some PyTorch versions.
|
||||||
|
attn_mask = (torch.arange(key_cache.size(2), device=key_cache.device) <= position_id).view(1, 1, 1, -1)
|
||||||
|
|
||||||
# ref: https://github.com/pytorch/pytorch/issues/163597
|
# ref: https://github.com/pytorch/pytorch/issues/163597
|
||||||
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
||||||
|
|||||||
Reference in New Issue
Block a user