diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index ad4de96..930f80e 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -636,11 +636,11 @@ class VoxCPM2Model(nn.Module): streaming_prefix_len=streaming_prefix_len, ) if streaming: - decode_patch_len = self.patch_size * self._decode_chunk_size - for latent_pred, _, _ctx in inference_result: - decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) - decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu() - yield decode_audio + with self.audio_vae.streaming_decode() as vae_dec: + for latent_pred, _, _ctx in inference_result: + decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32)) + decode_audio = decode_audio.squeeze(1).cpu() + yield decode_audio break else: latent_pred, pred_audio_feat, context_len = next(inference_result) @@ -923,11 +923,11 @@ class VoxCPM2Model(nn.Module): streaming_prefix_len=streaming_prefix_len, ) if streaming: - decode_patch_len = self.patch_size * self._decode_chunk_size - for latent_pred, pred_audio_feat, _ctx in inference_result: - decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) - decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu() - yield (decode_audio, target_text_token, pred_audio_feat) + with self.audio_vae.streaming_decode() as vae_dec: + for latent_pred, pred_audio_feat, _ctx in inference_result: + decode_audio = vae_dec.decode_chunk(latent_pred.to(torch.float32)) + decode_audio = decode_audio.squeeze(1).cpu() + yield (decode_audio, target_text_token, pred_audio_feat) break else: latent_pred, pred_audio_feat, context_len = next(inference_result) @@ -1067,8 +1067,8 @@ class VoxCPM2Model(nn.Module): prefix_feat_cond = pred_feat if streaming: - pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1) - feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size) + # Yield only the newest patch latent for stateful VAE decode + feat_pred = rearrange(pred_feat.unsqueeze(1), "b t p d -> b d (t p)", b=B, p=self.patch_size) yield feat_pred, pred_feat_seq, context_len diff --git a/src/voxcpm/modules/audiovae/audio_vae_v2.py b/src/voxcpm/modules/audiovae/audio_vae_v2.py index 7ce5231..5232bef 100644 --- a/src/voxcpm/modules/audiovae/audio_vae_v2.py +++ b/src/voxcpm/modules/audiovae/audio_vae_v2.py @@ -472,6 +472,20 @@ class AudioVAE(nn.Module): sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32) return self.decoder(z, sr_cond) + def streaming_decode(self): + """Return a ``StreamingVAEDecoder`` context manager for stateful + chunk-by-chunk decoding. Each call to ``decode_chunk`` processes only + the new latent patch and carries causal-conv state internally, avoiding + the redundant overlap decode used previously. + + Usage:: + + with vae.streaming_decode() as dec: + for patch in patches: + audio_chunk = dec.decode_chunk(patch) + """ + return StreamingVAEDecoder(self) + def encode(self, audio_data: torch.Tensor, sample_rate: int): """ Args: @@ -485,3 +499,82 @@ class AudioVAE(nn.Module): audio_data = self.preprocess(audio_data, sample_rate) return self.encoder(audio_data)["mu"] + + +class StreamingVAEDecoder: + """Stateful streaming wrapper for :class:`AudioVAE`. + + Carries causal-convolution padding buffers between calls so that each + ``decode_chunk`` processes only the new latent patch — no overlap needed. + """ + + def __init__(self, vae: AudioVAE): + self._vae = vae + self._states: dict = {} + self._originals: list = [] + + # -- context manager -------------------------------------------------- + def __enter__(self): + self._states.clear() + self._install() + return self + + def __exit__(self, *exc): + self._restore() + self._states.clear() + + # -- public API -------------------------------------------------------- + def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor: + """Decode a single latent chunk and return the audio waveform.""" + return self._vae.decode(z_chunk) + + # -- internals --------------------------------------------------------- + def _install(self): + for name, mod in self._vae.decoder.named_modules(): + if isinstance(mod, CausalConv1d): + pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding + if pad > 0: + self._patch_causal_conv(mod, pad) + elif isinstance(mod, CausalTransposeConv1d): + trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding + ctx = mod.kernel_size[0] // mod.stride[0] - 1 + if ctx > 0: + self._patch_transpose_conv(mod, ctx, trim) + + def _patch_causal_conv(self, mod, pad_size): + states = self._states + key = id(mod) + orig = mod.forward + + def fwd(x, _k=key, _p=pad_size, _m=mod): + x_pad = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_p, 0)) + if x.shape[-1] >= _p: + states[_k] = x[:, :, -_p:].detach() + else: + prev = states.get(_k, torch.zeros(x.shape[0], x.shape[1], _p, + device=x.device, dtype=x.dtype)) + states[_k] = torch.cat([prev, x], dim=-1)[:, :, -_p:].detach() + return nn.Conv1d.forward(_m, x_pad) + + mod.forward = fwd + self._originals.append((mod, orig)) + + def _patch_transpose_conv(self, mod, ctx, trim): + states = self._states + key = id(mod) + orig = mod.forward + + def fwd(x, _k=key, _c=ctx, _t=trim, _m=mod): + x_full = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_c, 0)) + states[_k] = x[:, :, -_c:].detach() + out = nn.ConvTranspose1d.forward(_m, x_full) + left = _c * _m.stride[0] + return out[..., left:-_t] if _t > 0 else out[..., left:] + + mod.forward = fwd + self._originals.append((mod, orig)) + + def _restore(self): + for mod, orig in self._originals: + mod.forward = orig + self._originals.clear()