perf: stateful streaming VAE decode — eliminate redundant overlap
Streaming decode previously re-decoded 4 overlapping patches through the VAE each step, discarding 75% of the output. Replace with stateful decode that carries causal conv padding buffers between calls — one patch in, one patch out, no overlap. Changes: - Add StreamingVAEDecoder to audiovae/audio_vae_v2.py — caches CausalConv1d and CausalTransposeConv1d left-pad state between calls - AudioVAE.streaming_decode() context manager for clean lifecycle - _inference yields single-patch latents in streaming mode - _generate and _generate_with_prompt_cache use StreamingVAEDecoder Streaming VAE decode time (isolated): 289ms → 148ms (2x faster) Stateful vs full decode: cosine 1.0000, max diff 0.0005 (more accurate than previous overlap approach at max diff 0.001) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+12
-12
@@ -636,11 +636,11 @@ class VoxCPM2Model(nn.Module):
|
||||
streaming_prefix_len=streaming_prefix_len,
|
||||
)
|
||||
if streaming:
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
for latent_pred, _, _ctx in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
with self.audio_vae.streaming_decode() as vae_dec:
|
||||
for latent_pred, _, _ctx in inference_result:
|
||||
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||
@@ -923,11 +923,11 @@ class VoxCPM2Model(nn.Module):
|
||||
streaming_prefix_len=streaming_prefix_len,
|
||||
)
|
||||
if streaming:
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
for latent_pred, pred_audio_feat, _ctx in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
with self.audio_vae.streaming_decode() as vae_dec:
|
||||
for latent_pred, pred_audio_feat, _ctx in inference_result:
|
||||
decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||
@@ -1067,8 +1067,8 @@ class VoxCPM2Model(nn.Module):
|
||||
prefix_feat_cond = pred_feat
|
||||
|
||||
if streaming:
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
# Yield only the newest patch latent for stateful VAE decode
|
||||
feat_pred = rearrange(pred_feat.unsqueeze(1), "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
yield feat_pred, pred_feat_seq, context_len
|
||||
|
||||
|
||||
Reference in New Issue
Block a user