From 42c428164c4b944e852f65f037457a3fbf9de0c4 Mon Sep 17 00:00:00 2001 From: Labmem-Zhouyx <913703649@qq.com> Date: Tue, 31 Mar 2026 17:07:33 +0800 Subject: [PATCH] feat: add no_rope support for residual LM and fix streaming continuation decoding - Add `residual_lm_no_rope` config option in VoxCPMConfig and propagate to MiniCPMModel - Add `no_rope` field to MiniCPM4Config; make RoPE embedding optional in MiniCPMModel and MiniCPMAttention - Add `streaming_prefix_len` parameter to generation interface - Fix non-streaming audio decode in continuation mode to trim leading prefix patches consistently - Refactor streaming prefix context preparation: distinguish continuation vs. zero-shot via feat_mask trailing bit instead of audio_mask sum Made-with: Cursor --- src/voxcpm/model/voxcpm2.py | 30 +++++++++++++++++++-------- src/voxcpm/modules/minicpm4/config.py | 1 + src/voxcpm/modules/minicpm4/model.py | 29 +++++++++++++++++--------- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index 90f116c..8819424 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -129,6 +129,7 @@ class VoxCPMConfig(BaseModel): patch_size: int = 4 feat_dim: int = 64 residual_lm_num_layers: int = 8 + residual_lm_no_rope: bool = False scalar_quantization_latent_dim: int = 512 scalar_quantization_scale: int = 9 @@ -195,6 +196,7 @@ class VoxCPM2Model(nn.Module): residual_lm_config = config.lm_config.model_copy(deep=True) residual_lm_config.num_hidden_layers = config.residual_lm_num_layers residual_lm_config.vocab_size = 0 + residual_lm_config.no_rope = config.residual_lm_no_rope self.residual_lm = MiniCPMModel(residual_lm_config) self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype)) @@ -474,6 +476,7 @@ class VoxCPM2Model(nn.Module): retry_badcase_max_times: int = 3, retry_badcase_ratio_threshold: float = 6.0, streaming: bool = False, + streaming_prefix_len: int = 3, ) -> Generator[torch.Tensor, None, None]: if retry_badcase and streaming: warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.") @@ -634,6 +637,7 @@ class VoxCPM2Model(nn.Module): inference_timesteps=inference_timesteps, cfg_value=cfg_value, streaming=streaming, + streaming_prefix_len=streaming_prefix_len, ) if streaming: patch_len = self.patch_size * self.chunk_size @@ -658,7 +662,13 @@ class VoxCPM2Model(nn.Module): break if not streaming: - decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() + decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) + patch_len = self.patch_size * self.chunk_size + has_continuation = bool(prompt_wav_path) + if has_continuation: + decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu() + else: + decode_audio = decode_audio.squeeze(1).cpu() yield decode_audio @torch.inference_mode() @@ -930,7 +940,7 @@ class VoxCPM2Model(nn.Module): if not streaming: decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) patch_len = self.patch_size * self.chunk_size - if audio_mask.sum().item() > 0: + if mode in ("continuation", "ref_continuation"): decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu() else: decode_audio = decode_audio[..., :].squeeze(1).cpu() @@ -995,15 +1005,17 @@ class VoxCPM2Model(nn.Module): curr_embed = None # Prepare prompt context patches for streaming mode - # When there's a prompt audio, use its last (streaming_prefix_len - 1) patches as initial context - prompt_context_patches = [] - audio_patch_count = int(feat_mask.sum().item()) - if audio_patch_count > 0: - context_len = min(streaming_prefix_len - 1, audio_patch_count) + # - Continuation modes (feat_mask ends with 1): use the last (streaming_prefix_len - 1) + # trailing audio patches as initial context so the VAE can decode smoothly. + # - Reference-only / zero-shot (feat_mask ends with 0): start from scratch. + has_continuation_audio = feat_mask[0, -1].item() == 1 + if has_continuation_audio: audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0] + context_len = min(streaming_prefix_len - 1, len(audio_indices)) last_audio_indices = audio_indices[-context_len:] - prompt_context_patches = list(feat[:, last_audio_indices, :, :].split(1, dim=1)) - pred_feat_seq = prompt_context_patches + pred_feat_seq + pred_feat_seq = list(feat[:, last_audio_indices, :, :].split(1, dim=1)) + else: + pred_feat_seq = [] enc_outputs, kv_cache_tuple = self.base_lm( inputs_embeds=combined_embed, diff --git a/src/voxcpm/modules/minicpm4/config.py b/src/voxcpm/modules/minicpm4/config.py index 8932821..332dd11 100644 --- a/src/voxcpm/modules/minicpm4/config.py +++ b/src/voxcpm/modules/minicpm4/config.py @@ -27,3 +27,4 @@ class MiniCPM4Config(BaseModel): scale_depth: float rope_theta: float kv_channels: int = None + no_rope: bool = False diff --git a/src/voxcpm/modules/minicpm4/model.py b/src/voxcpm/modules/minicpm4/model.py index 61c2122..99d6f0b 100644 --- a/src/voxcpm/modules/minicpm4/model.py +++ b/src/voxcpm/modules/minicpm4/model.py @@ -145,9 +145,9 @@ class MiniCPMAttention(nn.Module): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = position_emb - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if position_emb is not None: + cos, sin = position_emb + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # ref: https://github.com/pytorch/pytorch/issues/163597 # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous @@ -187,9 +187,9 @@ class MiniCPMAttention(nn.Module): key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = position_emb - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if position_emb is not None: + cos, sin = position_emb + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) key_cache, value_cache = kv_cache @@ -343,7 +343,10 @@ class MiniCPMModel(nn.Module): ) self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rope_emb = MiniCPMLongRoPE(config) + if config.no_rope: + self.rope_emb = None + else: + self.rope_emb = MiniCPMLongRoPE(config) self.kv_cache = None @@ -360,8 +363,11 @@ class MiniCPMModel(nn.Module): hidden_states: Tensor(batch_size, seq_length, hidden_size) next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)] """ - position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device) - position_emb = self.rope_emb(position_ids) + if self.rope_emb is not None: + position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device) + position_emb = self.rope_emb(position_ids) + else: + position_emb = None hidden_states = inputs_embeds next_decoder_cache = [] @@ -390,7 +396,10 @@ class MiniCPMModel(nn.Module): """ assert self.kv_cache is not None, "KV cache is not setup" - position_emb = self.rope_emb(position_id) + if self.rope_emb is not None: + position_emb = self.rope_emb(position_id) + else: + position_emb = None hidden_states = inputs_embeds for i, decoder_layer in enumerate(self.layers):