feat: add no_rope support for residual LM and fix streaming continuation decoding

- Add `residual_lm_no_rope` config option in VoxCPMConfig and propagate to MiniCPMModel
- Add `no_rope` field to MiniCPM4Config; make RoPE embedding optional in MiniCPMModel and MiniCPMAttention
- Add `streaming_prefix_len` parameter to generation interface
- Fix non-streaming audio decode in continuation mode to trim leading prefix patches consistently
- Refactor streaming prefix context preparation: distinguish continuation vs. zero-shot via feat_mask trailing bit instead of audio_mask sum

Made-with: Cursor
This commit is contained in:
Labmem-Zhouyx
2026-03-31 17:07:33 +08:00
parent d9cf376e16
commit 42c428164c
3 changed files with 41 additions and 19 deletions
+21 -9
View File
@@ -129,6 +129,7 @@ class VoxCPMConfig(BaseModel):
patch_size: int = 4 patch_size: int = 4
feat_dim: int = 64 feat_dim: int = 64
residual_lm_num_layers: int = 8 residual_lm_num_layers: int = 8
residual_lm_no_rope: bool = False
scalar_quantization_latent_dim: int = 512 scalar_quantization_latent_dim: int = 512
scalar_quantization_scale: int = 9 scalar_quantization_scale: int = 9
@@ -195,6 +196,7 @@ class VoxCPM2Model(nn.Module):
residual_lm_config = config.lm_config.model_copy(deep=True) residual_lm_config = config.lm_config.model_copy(deep=True)
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
residual_lm_config.vocab_size = 0 residual_lm_config.vocab_size = 0
residual_lm_config.no_rope = config.residual_lm_no_rope
self.residual_lm = MiniCPMModel(residual_lm_config) self.residual_lm = MiniCPMModel(residual_lm_config)
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype)) self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
@@ -474,6 +476,7 @@ class VoxCPM2Model(nn.Module):
retry_badcase_max_times: int = 3, retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, retry_badcase_ratio_threshold: float = 6.0,
streaming: bool = False, streaming: bool = False,
streaming_prefix_len: int = 3,
) -> Generator[torch.Tensor, None, None]: ) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming: if retry_badcase and streaming:
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.") warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
@@ -634,6 +637,7 @@ class VoxCPM2Model(nn.Module):
inference_timesteps=inference_timesteps, inference_timesteps=inference_timesteps,
cfg_value=cfg_value, cfg_value=cfg_value,
streaming=streaming, streaming=streaming,
streaming_prefix_len=streaming_prefix_len,
) )
if streaming: if streaming:
patch_len = self.patch_size * self.chunk_size patch_len = self.patch_size * self.chunk_size
@@ -658,7 +662,13 @@ class VoxCPM2Model(nn.Module):
break break
if not streaming: if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size
has_continuation = bool(prompt_wav_path)
if has_continuation:
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
else:
decode_audio = decode_audio.squeeze(1).cpu()
yield decode_audio yield decode_audio
@torch.inference_mode() @torch.inference_mode()
@@ -930,7 +940,7 @@ class VoxCPM2Model(nn.Module):
if not streaming: if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size patch_len = self.patch_size * self.chunk_size
if audio_mask.sum().item() > 0: if mode in ("continuation", "ref_continuation"):
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu() decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
else: else:
decode_audio = decode_audio[..., :].squeeze(1).cpu() decode_audio = decode_audio[..., :].squeeze(1).cpu()
@@ -995,15 +1005,17 @@ class VoxCPM2Model(nn.Module):
curr_embed = None curr_embed = None
# Prepare prompt context patches for streaming mode # Prepare prompt context patches for streaming mode
# When there's a prompt audio, use its last (streaming_prefix_len - 1) patches as initial context # - Continuation modes (feat_mask ends with 1): use the last (streaming_prefix_len - 1)
prompt_context_patches = [] # trailing audio patches as initial context so the VAE can decode smoothly.
audio_patch_count = int(feat_mask.sum().item()) # - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
if audio_patch_count > 0: has_continuation_audio = feat_mask[0, -1].item() == 1
context_len = min(streaming_prefix_len - 1, audio_patch_count) if has_continuation_audio:
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0] audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
context_len = min(streaming_prefix_len - 1, len(audio_indices))
last_audio_indices = audio_indices[-context_len:] last_audio_indices = audio_indices[-context_len:]
prompt_context_patches = list(feat[:, last_audio_indices, :, :].split(1, dim=1)) pred_feat_seq = list(feat[:, last_audio_indices, :, :].split(1, dim=1))
pred_feat_seq = prompt_context_patches + pred_feat_seq else:
pred_feat_seq = []
enc_outputs, kv_cache_tuple = self.base_lm( enc_outputs, kv_cache_tuple = self.base_lm(
inputs_embeds=combined_embed, inputs_embeds=combined_embed,
+1
View File
@@ -27,3 +27,4 @@ class MiniCPM4Config(BaseModel):
scale_depth: float scale_depth: float
rope_theta: float rope_theta: float
kv_channels: int = None kv_channels: int = None
no_rope: bool = False
+11 -2
View File
@@ -145,8 +145,8 @@ class MiniCPMAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_emb is not None:
cos, sin = position_emb cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# ref: https://github.com/pytorch/pytorch/issues/163597 # ref: https://github.com/pytorch/pytorch/issues/163597
@@ -187,8 +187,8 @@ class MiniCPMAttention(nn.Module):
key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_emb is not None:
cos, sin = position_emb cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
@@ -343,6 +343,9 @@ class MiniCPMModel(nn.Module):
) )
self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.no_rope:
self.rope_emb = None
else:
self.rope_emb = MiniCPMLongRoPE(config) self.rope_emb = MiniCPMLongRoPE(config)
self.kv_cache = None self.kv_cache = None
@@ -360,8 +363,11 @@ class MiniCPMModel(nn.Module):
hidden_states: Tensor(batch_size, seq_length, hidden_size) hidden_states: Tensor(batch_size, seq_length, hidden_size)
next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)] next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)]
""" """
if self.rope_emb is not None:
position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device) position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device)
position_emb = self.rope_emb(position_ids) position_emb = self.rope_emb(position_ids)
else:
position_emb = None
hidden_states = inputs_embeds hidden_states = inputs_embeds
next_decoder_cache = [] next_decoder_cache = []
@@ -390,7 +396,10 @@ class MiniCPMModel(nn.Module):
""" """
assert self.kv_cache is not None, "KV cache is not setup" assert self.kv_cache is not None, "KV cache is not setup"
if self.rope_emb is not None:
position_emb = self.rope_emb(position_id) position_emb = self.rope_emb(position_id)
else:
position_emb = None
hidden_states = inputs_embeds hidden_states = inputs_embeds
for i, decoder_layer in enumerate(self.layers): for i, decoder_layer in enumerate(self.layers):