fix: use uncompiled feat_encoder for prefill to prevent CUDA Graph dynamic shape accumulation (#209)

This commit is contained in:
刘鑫
2026-04-09 16:00:17 +08:00
parent 5611bd08a0
commit 75cfa3e9b8
3 changed files with 536 additions and 2 deletions
+3 -1
View File
@@ -227,6 +227,7 @@ class VoxCPMModel(nn.Module):
self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self._feat_encoder_raw = self.feat_encoder
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
@@ -755,7 +756,8 @@ class VoxCPMModel(nn.Module):
"""
B, T, P, D = feat.shape
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
feat_embed = prefill_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup: