fix: use uncompiled feat_encoder for prefill to prevent CUDA Graph dynamic shape accumulation (#209)
This commit is contained in:
@@ -227,6 +227,7 @@ class VoxCPMModel(nn.Module):
|
||||
self.residual_lm.forward_step = torch.compile(
|
||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
+ self._feat_encoder_raw = self.feat_encoder
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(
|
||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||
@@ -755,7 +756,8 @@ class VoxCPMModel(nn.Module):
|
||||
"""
|
||||
B, T, P, D = feat.shape
|
||||
|
||||
- feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
+ prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
|
||||
+ feat_embed = prefill_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
|
||||
@@ -275,6 +275,7 @@ class VoxCPM2Model(nn.Module):
|
||||
self.residual_lm.forward_step = torch.compile(
|
||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
+ self._feat_encoder_raw = self.feat_encoder
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(
|
||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||
@@ -997,7 +998,8 @@ class VoxCPM2Model(nn.Module):
|
||||
"""
|
||||
B, T, P, D = feat.shape
|
||||
|
||||
- feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
+ prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
|
||||
+ feat_embed = prefill_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
|
||||
Reference in New Issue
Block a user