diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index 45d6fc1..29a7741 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -246,6 +246,7 @@ class VoxCPM2Model(nn.Module): # Audio VAE self.audio_vae = audio_vae self.chunk_size = audio_vae.chunk_size + self._decode_chunk_size = getattr(audio_vae, "decode_chunk_size", audio_vae.chunk_size) self._encode_sample_rate = audio_vae.sample_rate self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate) @@ -656,10 +657,10 @@ class VoxCPM2Model(nn.Module): streaming_prefix_len=streaming_prefix_len, ) if streaming: - patch_len = self.patch_size * self.chunk_size + decode_patch_len = self.patch_size * self._decode_chunk_size for latent_pred, _ in inference_result: decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) - decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() + decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu() yield decode_audio break else: @@ -679,10 +680,10 @@ class VoxCPM2Model(nn.Module): if not streaming: decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) - patch_len = self.patch_size * self.chunk_size + decode_patch_len = self.patch_size * self._decode_chunk_size has_continuation = bool(prompt_wav_path) if has_continuation: - decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu() + decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu() else: decode_audio = decode_audio.squeeze(1).cpu() yield decode_audio @@ -944,10 +945,10 @@ class VoxCPM2Model(nn.Module): streaming_prefix_len=streaming_prefix_len, ) if streaming: - patch_len = self.patch_size * self.chunk_size + decode_patch_len = self.patch_size * self._decode_chunk_size for latent_pred, pred_audio_feat in inference_result: decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) - decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() + decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu() yield (decode_audio, target_text_token, pred_audio_feat) break else: @@ -966,9 +967,9 @@ class VoxCPM2Model(nn.Module): break if not streaming: decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) - patch_len = self.patch_size * self.chunk_size + decode_patch_len = self.patch_size * self._decode_chunk_size if mode in ("continuation", "ref_continuation"): - decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu() + decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu() else: decode_audio = decode_audio[..., :].squeeze(1).cpu() yield (decode_audio, target_text_token, pred_audio_feat) diff --git a/src/voxcpm/modules/audiovae/audio_vae_v2.py b/src/voxcpm/modules/audiovae/audio_vae_v2.py index 2c04f18..7ce5231 100644 --- a/src/voxcpm/modules/audiovae/audio_vae_v2.py +++ b/src/voxcpm/modules/audiovae/audio_vae_v2.py @@ -436,6 +436,7 @@ class AudioVAE(nn.Module): self.out_sample_rate = out_sample_rate self.sr_bin_boundaries = sr_bin_boundaries self.chunk_size = math.prod(encoder_rates) + self.decode_chunk_size = math.prod(decoder_rates) def preprocess(self, audio_data, sample_rate): if sample_rate is None: