feat: add no_rope support for residual LM and fix streaming continuation decoding

- Add `residual_lm_no_rope` config option in VoxCPMConfig and propagate to MiniCPMModel
- Add `no_rope` field to MiniCPM4Config; make RoPE embedding optional in MiniCPMModel and MiniCPMAttention
- Add `streaming_prefix_len` parameter to generation interface
- Fix non-streaming audio decode in continuation mode to trim leading prefix patches consistently
- Refactor streaming prefix context preparation: distinguish continuation vs. zero-shot via feat_mask trailing bit instead of audio_mask sum

Made-with: Cursor
This commit is contained in:
Labmem-Zhouyx
2026-03-31 17:07:33 +08:00
parent d9cf376e16
commit 42c428164c
3 changed files with 41 additions and 19 deletions
+21 -9
View File
@@ -129,6 +129,7 @@ class VoxCPMConfig(BaseModel):
patch_size: int = 4
feat_dim: int = 64
residual_lm_num_layers: int = 8
residual_lm_no_rope: bool = False
scalar_quantization_latent_dim: int = 512
scalar_quantization_scale: int = 9
@@ -195,6 +196,7 @@ class VoxCPM2Model(nn.Module):
residual_lm_config = config.lm_config.model_copy(deep=True)
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
residual_lm_config.vocab_size = 0
residual_lm_config.no_rope = config.residual_lm_no_rope
self.residual_lm = MiniCPMModel(residual_lm_config)
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
@@ -474,6 +476,7 @@ class VoxCPM2Model(nn.Module):
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
streaming: bool = False,
streaming_prefix_len: int = 3,
) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming:
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
@@ -634,6 +637,7 @@ class VoxCPM2Model(nn.Module):
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
streaming_prefix_len=streaming_prefix_len,
)
if streaming:
patch_len = self.patch_size * self.chunk_size
@@ -658,7 +662,13 @@ class VoxCPM2Model(nn.Module):
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size
has_continuation = bool(prompt_wav_path)
if has_continuation:
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
else:
decode_audio = decode_audio.squeeze(1).cpu()
yield decode_audio
@torch.inference_mode()
@@ -930,7 +940,7 @@ class VoxCPM2Model(nn.Module):
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size
if audio_mask.sum().item() > 0:
if mode in ("continuation", "ref_continuation"):
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
else:
decode_audio = decode_audio[..., :].squeeze(1).cpu()
@@ -995,15 +1005,17 @@ class VoxCPM2Model(nn.Module):
curr_embed = None
# Prepare prompt context patches for streaming mode
# When there's a prompt audio, use its last (streaming_prefix_len - 1) patches as initial context
prompt_context_patches = []
audio_patch_count = int(feat_mask.sum().item())
if audio_patch_count > 0:
context_len = min(streaming_prefix_len - 1, audio_patch_count)
# - Continuation modes (feat_mask ends with 1): use the last (streaming_prefix_len - 1)
# trailing audio patches as initial context so the VAE can decode smoothly.
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
has_continuation_audio = feat_mask[0, -1].item() == 1
if has_continuation_audio:
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
context_len = min(streaming_prefix_len - 1, len(audio_indices))
last_audio_indices = audio_indices[-context_len:]
prompt_context_patches = list(feat[:, last_audio_indices, :, :].split(1, dim=1))
pred_feat_seq = prompt_context_patches + pred_feat_seq
pred_feat_seq = list(feat[:, last_audio_indices, :, :].split(1, dim=1))
else:
pred_feat_seq = []
enc_outputs, kv_cache_tuple = self.base_lm(
inputs_embeds=combined_embed,