feat: add no_rope support for residual LM and fix streaming continuation decoding
- Add `residual_lm_no_rope` config option in VoxCPMConfig and propagate it to MiniCPMModel
- Add `no_rope` field to MiniCPM4Config; make RoPE embedding optional in MiniCPMModel and MiniCPMAttention
- Add `streaming_prefix_len` parameter to the generation interface
- Fix non-streaming audio decode in continuation mode to trim leading prefix patches consistently
- Refactor streaming prefix context preparation: distinguish continuation vs. zero-shot via the trailing bit of feat_mask instead of the audio_mask sum

Made-with: Cursor
This commit is contained in:
@@ -129,6 +129,7 @@ class VoxCPMConfig(BaseModel):
|
||||
patch_size: int = 4
|
||||
feat_dim: int = 64
|
||||
residual_lm_num_layers: int = 8
|
||||
residual_lm_no_rope: bool = False
|
||||
scalar_quantization_latent_dim: int = 512
|
||||
scalar_quantization_scale: int = 9
|
||||
|
||||
@@ -195,6 +196,7 @@ class VoxCPM2Model(nn.Module):
|
||||
residual_lm_config = config.lm_config.model_copy(deep=True)
|
||||
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
|
||||
residual_lm_config.vocab_size = 0
|
||||
residual_lm_config.no_rope = config.residual_lm_no_rope
|
||||
self.residual_lm = MiniCPMModel(residual_lm_config)
|
||||
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
|
||||
|
||||
@@ -474,6 +476,7 @@ class VoxCPM2Model(nn.Module):
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,
|
||||
streaming: bool = False,
|
||||
streaming_prefix_len: int = 3,
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
if retry_badcase and streaming:
|
||||
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
|
||||
@@ -634,6 +637,7 @@ class VoxCPM2Model(nn.Module):
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
streaming_prefix_len=streaming_prefix_len,
|
||||
)
|
||||
if streaming:
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
@@ -658,7 +662,13 @@ class VoxCPM2Model(nn.Module):
|
||||
break
|
||||
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
has_continuation = bool(prompt_wav_path)
|
||||
if has_continuation:
|
||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
||||
else:
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -930,7 +940,7 @@ class VoxCPM2Model(nn.Module):
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
if audio_mask.sum().item() > 0:
|
||||
if mode in ("continuation", "ref_continuation"):
|
||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
||||
else:
|
||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
||||
@@ -995,15 +1005,17 @@ class VoxCPM2Model(nn.Module):
|
||||
curr_embed = None
|
||||
|
||||
# Prepare prompt context patches for streaming mode
|
||||
# When there's a prompt audio, use its last (streaming_prefix_len - 1) patches as initial context
|
||||
prompt_context_patches = []
|
||||
audio_patch_count = int(feat_mask.sum().item())
|
||||
if audio_patch_count > 0:
|
||||
context_len = min(streaming_prefix_len - 1, audio_patch_count)
|
||||
# - Continuation modes (feat_mask ends with 1): use the last (streaming_prefix_len - 1)
|
||||
# trailing audio patches as initial context so the VAE can decode smoothly.
|
||||
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
|
||||
has_continuation_audio = feat_mask[0, -1].item() == 1
|
||||
if has_continuation_audio:
|
||||
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
|
||||
context_len = min(streaming_prefix_len - 1, len(audio_indices))
|
||||
last_audio_indices = audio_indices[-context_len:]
|
||||
prompt_context_patches = list(feat[:, last_audio_indices, :, :].split(1, dim=1))
|
||||
pred_feat_seq = prompt_context_patches + pred_feat_seq
|
||||
pred_feat_seq = list(feat[:, last_audio_indices, :, :].split(1, dim=1))
|
||||
else:
|
||||
pred_feat_seq = []
|
||||
|
||||
enc_outputs, kv_cache_tuple = self.base_lm(
|
||||
inputs_embeds=combined_embed,
|
||||
|
||||
Reference in New Issue
Block a user