diff --git a/README.md b/README.md index d6e6467..a460c31 100644 --- a/README.md +++ b/README.md @@ -238,8 +238,8 @@ voxcpm --help ### Web Demo -```bash -python app.py # then open http://localhost:7860 +```bash +python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, open http://localhost:8808 ``` ### 🚢 Production Deployment (Nano-vLLM) diff --git a/README_zh.md b/README_zh.md index 89e1f43..2c184e5 100644 --- a/README_zh.md +++ b/README_zh.md @@ -238,7 +238,7 @@ voxcpm --help ### Web Demo ```bash -python app.py # 然后打开 http://localhost:7860 +python app.py --model-dir /path/to/VoxCPM2 --port 8808 # 指定本地模型路径,然后打开 http://localhost:8808 ``` ### 🚢 生产部署(Nano-vLLM) diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index 29a7741..ad4de96 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -48,25 +48,8 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel from .utils import get_dtype, mask_multichar_chinese_tokens -def _trim_audio_silence_vad( - audio: torch.Tensor, - sample_rate: int, - max_silence_ms: float = 200.0, - top_db: float = 35.0, -) -> torch.Tensor: - """使用能量阈值(VAD 方式)截取首尾静音及尾部长段伪静音,首尾各最多保留 max_silence_ms 毫秒静音。 - - 会同时截掉末尾的长段伪静音(低能量但非完全静音的段落,如长时间底噪)。 - - Args: - audio: (1, T) 的音频 tensor - sample_rate: 采样率 - max_silence_ms: 首尾允许保留的最大静音长度(毫秒) - top_db: 低于参考电平多少 dB 视为静音 - - Returns: - 截取后的 (1, T') tensor - """ +# A simple function to trim audio silence using VAD, not used default +def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor: if audio.numel() == 0: return audio y = audio.squeeze(0).numpy() @@ -85,7 +68,7 @@ def _trim_audio_silence_vad( except Exception: start, end = 0, n - # 用逐帧 RMS 找「最后一段有持续能量的位置」,截掉末尾长伪静音(低能量底噪等) + # Find the last frame with continuous energy, trim the long pseudo-silence at the end (low energy background noise, etc.) n_frames = max(0, (n - frame_length) // hop_length + 1) last_voice_frame = -1 for j in range(n_frames): @@ -383,11 +366,7 @@ class VoxCPM2Model(nn.Module): mu=dit_hidden, patch_size=self.patch_size, cond=feat_cond_for_sample, - n_timesteps=( - self.config.dit_config.cfm_config.inference_cfg_rate - if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate") - else 10 - ), + n_timesteps=10, ) feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size) @@ -658,13 +637,13 @@ class VoxCPM2Model(nn.Module): ) if streaming: decode_patch_len = self.patch_size * self._decode_chunk_size - for latent_pred, _ in inference_result: + for latent_pred, _, _ctx in inference_result: decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu() yield decode_audio break else: - latent_pred, pred_audio_feat = next(inference_result) + latent_pred, pred_audio_feat, context_len = next(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -681,9 +660,8 @@ class VoxCPM2Model(nn.Module): if not streaming: decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_patch_len = self.patch_size * self._decode_chunk_size - has_continuation = bool(prompt_wav_path) - if has_continuation: - decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu() + if context_len > 0: + decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu() else: decode_audio = decode_audio.squeeze(1).cpu() yield decode_audio @@ -946,13 +924,13 @@ class VoxCPM2Model(nn.Module): ) if streaming: decode_patch_len = self.patch_size * self._decode_chunk_size - for latent_pred, pred_audio_feat in inference_result: + for latent_pred, pred_audio_feat, _ctx in inference_result: decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu() yield (decode_audio, target_text_token, pred_audio_feat) break else: - latent_pred, pred_audio_feat = next(inference_result) + latent_pred, pred_audio_feat, context_len = next(inference_result) if retry_badcase: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: print( @@ -968,17 +946,19 @@ class VoxCPM2Model(nn.Module): if not streaming: decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_patch_len = self.patch_size * self._decode_chunk_size - if mode in ("continuation", "ref_continuation"): - decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu() + if context_len > 0: + decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu() else: - decode_audio = decode_audio[..., :].squeeze(1).cpu() + decode_audio = decode_audio.squeeze(1).cpu() yield (decode_audio, target_text_token, pred_audio_feat) def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - return next(self._inference(*args, streaming=False, **kwargs)) + feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs)) + return feat_pred, generated_feat def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: - return self._inference(*args, streaming=True, **kwargs) + for feat_pred, pred_feat_seq, _ in self._inference(*args, streaming=True, **kwargs): + yield feat_pred, pred_feat_seq @torch.inference_mode() def _inference( @@ -1037,6 +1017,7 @@ class VoxCPM2Model(nn.Module): # trailing audio patches as initial context so the VAE can decode smoothly. # - Reference-only / zero-shot (feat_mask ends with 0): start from scratch. has_continuation_audio = feat_mask[0, -1].item() == 1 + context_len = 0 if has_continuation_audio: audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0] context_len = min(streaming_prefix_len - 1, len(audio_indices)) @@ -1086,11 +1067,13 @@ class VoxCPM2Model(nn.Module): prefix_feat_cond = pred_feat if streaming: - # return the last three predicted latent features to provide enough context for smooth decoding pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1) feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size) - yield feat_pred, pred_feat_seq + yield feat_pred, pred_feat_seq, context_len + + if len(pred_feat_seq) > streaming_prefix_len: + pred_feat_seq = pred_feat_seq[-streaming_prefix_len:] stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item() if i > min_len and stop_flag == 1: @@ -1109,7 +1092,8 @@ class VoxCPM2Model(nn.Module): if not streaming: pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size) - yield feat_pred, pred_feat_seq.squeeze(0).cpu() + generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu() + yield feat_pred, generated_feat, context_len @classmethod def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):