perf: stateful streaming VAE decode — eliminate redundant overlap

Streaming decode previously re-decoded 4 overlapping patches through
the VAE each step, discarding 75% of the output. Replace with stateful
decode that carries causal conv padding buffers between calls — one
patch in, one patch out, no overlap.

Changes:
- Add StreamingVAEDecoder to audiovae/audio_vae_v2.py — caches
  CausalConv1d and CausalTransposeConv1d left-pad state between calls
- AudioVAE.streaming_decode() context manager for clean lifecycle
- _inference yields single-patch latents in streaming mode
- _generate and _generate_with_prompt_cache use StreamingVAEDecoder

Streaming VAE decode time (isolated): 289ms → 148ms (2x faster)
Stateful vs full decode: cosine 1.0000, max diff 0.0005
(more accurate than previous overlap approach at max diff 0.001)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kevin Knoedler
2026-04-08 09:06:13 -07:00
parent 364eff6840
commit 66205135fc
2 changed files with 105 additions and 12 deletions
+8 -8
View File
@@ -636,10 +636,10 @@ class VoxCPM2Model(nn.Module):
streaming_prefix_len=streaming_prefix_len, streaming_prefix_len=streaming_prefix_len,
) )
if streaming: if streaming:
decode_patch_len = self.patch_size * self._decode_chunk_size with self.audio_vae.streaming_decode() as vae_dec:
for latent_pred, _, _ctx in inference_result: for latent_pred, _, _ctx in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu() decode_audio = decode_audio.squeeze(1).cpu()
yield decode_audio yield decode_audio
break break
else: else:
@@ -923,10 +923,10 @@ class VoxCPM2Model(nn.Module):
streaming_prefix_len=streaming_prefix_len, streaming_prefix_len=streaming_prefix_len,
) )
if streaming: if streaming:
decode_patch_len = self.patch_size * self._decode_chunk_size with self.audio_vae.streaming_decode() as vae_dec:
for latent_pred, pred_audio_feat, _ctx in inference_result: for latent_pred, pred_audio_feat, _ctx in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu() decode_audio = decode_audio.squeeze(1).cpu()
yield (decode_audio, target_text_token, pred_audio_feat) yield (decode_audio, target_text_token, pred_audio_feat)
break break
else: else:
@@ -1067,8 +1067,8 @@ class VoxCPM2Model(nn.Module):
prefix_feat_cond = pred_feat prefix_feat_cond = pred_feat
if streaming: if streaming:
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1) # Yield only the newest patch latent for stateful VAE decode
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size) feat_pred = rearrange(pred_feat.unsqueeze(1), "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq, context_len yield feat_pred, pred_feat_seq, context_len
@@ -472,6 +472,20 @@ class AudioVAE(nn.Module):
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32) sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
return self.decoder(z, sr_cond) return self.decoder(z, sr_cond)
def streaming_decode(self):
"""Return a ``StreamingVAEDecoder`` context manager for stateful
chunk-by-chunk decoding. Each call to ``decode_chunk`` processes only
the new latent patch and carries causal-conv state internally, avoiding
the redundant overlap decode used previously.
Usage::
with vae.streaming_decode() as dec:
for patch in patches:
audio_chunk = dec.decode_chunk(patch)
"""
return StreamingVAEDecoder(self)
def encode(self, audio_data: torch.Tensor, sample_rate: int): def encode(self, audio_data: torch.Tensor, sample_rate: int):
""" """
Args: Args:
@@ -485,3 +499,82 @@ class AudioVAE(nn.Module):
audio_data = self.preprocess(audio_data, sample_rate) audio_data = self.preprocess(audio_data, sample_rate)
return self.encoder(audio_data)["mu"] return self.encoder(audio_data)["mu"]
class StreamingVAEDecoder:
"""Stateful streaming wrapper for :class:`AudioVAE`.
Carries causal-convolution padding buffers between calls so that each
``decode_chunk`` processes only the new latent patch — no overlap needed.
"""
def __init__(self, vae: AudioVAE):
self._vae = vae
self._states: dict = {}
self._originals: list = []
# -- context manager --------------------------------------------------
def __enter__(self):
self._states.clear()
self._install()
return self
def __exit__(self, *exc):
self._restore()
self._states.clear()
# -- public API --------------------------------------------------------
def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor:
"""Decode a single latent chunk and return the audio waveform."""
return self._vae.decode(z_chunk)
# -- internals ---------------------------------------------------------
def _install(self):
for name, mod in self._vae.decoder.named_modules():
if isinstance(mod, CausalConv1d):
pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding
if pad > 0:
self._patch_causal_conv(mod, pad)
elif isinstance(mod, CausalTransposeConv1d):
trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding
ctx = mod.kernel_size[0] // mod.stride[0] - 1
if ctx > 0:
self._patch_transpose_conv(mod, ctx, trim)
def _patch_causal_conv(self, mod, pad_size):
states = self._states
key = id(mod)
orig = mod.forward
def fwd(x, _k=key, _p=pad_size, _m=mod):
x_pad = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_p, 0))
if x.shape[-1] >= _p:
states[_k] = x[:, :, -_p:].detach()
else:
prev = states.get(_k, torch.zeros(x.shape[0], x.shape[1], _p,
device=x.device, dtype=x.dtype))
states[_k] = torch.cat([prev, x], dim=-1)[:, :, -_p:].detach()
return nn.Conv1d.forward(_m, x_pad)
mod.forward = fwd
self._originals.append((mod, orig))
def _patch_transpose_conv(self, mod, ctx, trim):
states = self._states
key = id(mod)
orig = mod.forward
def fwd(x, _k=key, _c=ctx, _t=trim, _m=mod):
x_full = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_c, 0))
states[_k] = x[:, :, -_c:].detach()
out = nn.ConvTranspose1d.forward(_m, x_full)
left = _c * _m.stride[0]
return out[..., left:-_t] if _t > 0 else out[..., left:]
mod.forward = fwd
self._originals.append((mod, orig))
def _restore(self):
for mod, orig in self._originals:
mod.forward = orig
self._originals.clear()