fix: streaming decode

This commit is contained in:
Labmem-Zhouyx
2026-04-08 17:25:54 +08:00
parent 82d77d445c
commit ee3649c1b3
3 changed files with 27 additions and 43 deletions
+1 -1
View File
@@ -239,7 +239,7 @@ voxcpm --help
### Web Demo
```bash
python app.py # then open http://localhost:7860
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, open http://localhost:8808
```
### 🚢 Production Deployment (Nano-vLLM)
+1 -1
View File
@@ -238,7 +238,7 @@ voxcpm --help
### Web Demo
```bash
python app.py # 然后打开 http://localhost:7860
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # 指定本地模型路径,然后打开 http://localhost:8808
```
### 🚢 生产部署(Nano-vLLM
+24 -40
View File
@@ -48,25 +48,8 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
def _trim_audio_silence_vad(
audio: torch.Tensor,
sample_rate: int,
max_silence_ms: float = 200.0,
top_db: float = 35.0,
) -> torch.Tensor:
"""使用能量阈值(VAD 方式)截取首尾静音及尾部长段伪静音,首尾各最多保留 max_silence_ms 毫秒静音。
会同时截掉末尾的长段伪静音(低能量但非完全静音的段落,如长时间底噪)。
Args:
audio: (1, T) 的音频 tensor
sample_rate: 采样率
max_silence_ms: 首尾允许保留的最大静音长度(毫秒)
top_db: 低于参考电平多少 dB 视为静音
Returns:
截取后的 (1, T') tensor
"""
# A simple function to trim audio silence using VAD, not used default
def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
if audio.numel() == 0:
return audio
y = audio.squeeze(0).numpy()
@@ -85,7 +68,7 @@ def _trim_audio_silence_vad(
except Exception:
start, end = 0, n
# 用逐帧 RMS 找「最后一段有持续能量的位置」,截掉末尾长伪静音(低能量底噪等)
# Find the last frame with continuous energy, trim the long pseudo-silence at the end (low energy background noise, etc.)
n_frames = max(0, (n - frame_length) // hop_length + 1)
last_voice_frame = -1
for j in range(n_frames):
@@ -383,11 +366,7 @@ class VoxCPM2Model(nn.Module):
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=(
self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10
),
n_timesteps=10,
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
@@ -658,13 +637,13 @@ class VoxCPM2Model(nn.Module):
)
if streaming:
decode_patch_len = self.patch_size * self._decode_chunk_size
for latent_pred, _ in inference_result:
for latent_pred, _, _ctx in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
yield decode_audio
break
else:
latent_pred, pred_audio_feat = next(inference_result)
latent_pred, pred_audio_feat, context_len = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -681,9 +660,8 @@ class VoxCPM2Model(nn.Module):
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_patch_len = self.patch_size * self._decode_chunk_size
has_continuation = bool(prompt_wav_path)
if has_continuation:
decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
if context_len > 0:
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
else:
decode_audio = decode_audio.squeeze(1).cpu()
yield decode_audio
@@ -946,13 +924,13 @@ class VoxCPM2Model(nn.Module):
)
if streaming:
decode_patch_len = self.patch_size * self._decode_chunk_size
for latent_pred, pred_audio_feat in inference_result:
for latent_pred, pred_audio_feat, _ctx in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
latent_pred, pred_audio_feat, context_len = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(
@@ -968,17 +946,19 @@ class VoxCPM2Model(nn.Module):
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_patch_len = self.patch_size * self._decode_chunk_size
if mode in ("continuation", "ref_continuation"):
decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
if context_len > 0:
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
else:
decode_audio = decode_audio[..., :].squeeze(1).cpu()
decode_audio = decode_audio.squeeze(1).cpu()
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs))
return feat_pred, generated_feat
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
for feat_pred, pred_feat_seq, _ in self._inference(*args, streaming=True, **kwargs):
yield feat_pred, pred_feat_seq
@torch.inference_mode()
def _inference(
@@ -1037,6 +1017,7 @@ class VoxCPM2Model(nn.Module):
# trailing audio patches as initial context so the VAE can decode smoothly.
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
has_continuation_audio = feat_mask[0, -1].item() == 1
context_len = 0
if has_continuation_audio:
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
context_len = min(streaming_prefix_len - 1, len(audio_indices))
@@ -1086,11 +1067,13 @@ class VoxCPM2Model(nn.Module):
prefix_feat_cond = pred_feat
if streaming:
# return the last three predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
yield feat_pred, pred_feat_seq, context_len
if len(pred_feat_seq) > streaming_prefix_len:
pred_feat_seq = pred_feat_seq[-streaming_prefix_len:]
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1:
@@ -1109,7 +1092,8 @@ class VoxCPM2Model(nn.Module):
if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu()
yield feat_pred, generated_feat, context_len
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):