fix: streaming decode
This commit is contained in:
@@ -238,8 +238,8 @@ voxcpm --help
|
||||
|
||||
### Web Demo
|
||||
|
||||
```bash
|
||||
python app.py # then open http://localhost:7860
|
||||
```bash
|
||||
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, open http://localhost:8808
|
||||
```
|
||||
|
||||
### 🚢 Production Deployment (Nano-vLLM)
|
||||
|
||||
+1
-1
@@ -238,7 +238,7 @@ voxcpm --help
|
||||
### Web Demo
|
||||
|
||||
```bash
|
||||
python app.py # 然后打开 http://localhost:7860
|
||||
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # 指定本地模型路径,然后打开 http://localhost:8808
|
||||
```
|
||||
|
||||
### 🚢 生产部署(Nano-vLLM)
|
||||
|
||||
+24
-40
@@ -48,25 +48,8 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||
|
||||
|
||||
def _trim_audio_silence_vad(
|
||||
audio: torch.Tensor,
|
||||
sample_rate: int,
|
||||
max_silence_ms: float = 200.0,
|
||||
top_db: float = 35.0,
|
||||
) -> torch.Tensor:
|
||||
"""使用能量阈值(VAD 方式)截取首尾静音及尾部长段伪静音,首尾各最多保留 max_silence_ms 毫秒静音。
|
||||
|
||||
会同时截掉末尾的长段伪静音(低能量但非完全静音的段落,如长时间底噪)。
|
||||
|
||||
Args:
|
||||
audio: (1, T) 的音频 tensor
|
||||
sample_rate: 采样率
|
||||
max_silence_ms: 首尾允许保留的最大静音长度(毫秒)
|
||||
top_db: 低于参考电平多少 dB 视为静音
|
||||
|
||||
Returns:
|
||||
截取后的 (1, T') tensor
|
||||
"""
|
||||
# A simple function to trim audio silence using VAD, not used default
|
||||
def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
|
||||
if audio.numel() == 0:
|
||||
return audio
|
||||
y = audio.squeeze(0).numpy()
|
||||
@@ -85,7 +68,7 @@ def _trim_audio_silence_vad(
|
||||
except Exception:
|
||||
start, end = 0, n
|
||||
|
||||
# 用逐帧 RMS 找「最后一段有持续能量的位置」,截掉末尾长伪静音(低能量底噪等)
|
||||
# Find the last frame with continuous energy, trim the long pseudo-silence at the end (low energy background noise, etc.)
|
||||
n_frames = max(0, (n - frame_length) // hop_length + 1)
|
||||
last_voice_frame = -1
|
||||
for j in range(n_frames):
|
||||
@@ -383,11 +366,7 @@ class VoxCPM2Model(nn.Module):
|
||||
mu=dit_hidden,
|
||||
patch_size=self.patch_size,
|
||||
cond=feat_cond_for_sample,
|
||||
n_timesteps=(
|
||||
self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||
else 10
|
||||
),
|
||||
n_timesteps=10,
|
||||
)
|
||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
@@ -658,13 +637,13 @@ class VoxCPM2Model(nn.Module):
|
||||
)
|
||||
if streaming:
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
for latent_pred, _ in inference_result:
|
||||
for latent_pred, _, _ctx in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -681,9 +660,8 @@ class VoxCPM2Model(nn.Module):
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
has_continuation = bool(prompt_wav_path)
|
||||
if has_continuation:
|
||||
decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
||||
if context_len > 0:
|
||||
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||
else:
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
@@ -946,13 +924,13 @@ class VoxCPM2Model(nn.Module):
|
||||
)
|
||||
if streaming:
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
for latent_pred, pred_audio_feat in inference_result:
|
||||
for latent_pred, pred_audio_feat, _ctx in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(
|
||||
@@ -968,17 +946,19 @@ class VoxCPM2Model(nn.Module):
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||
if mode in ("continuation", "ref_continuation"):
|
||||
decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
||||
if context_len > 0:
|
||||
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||
else:
|
||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
||||
decode_audio = decode_audio.squeeze(1).cpu()
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return next(self._inference(*args, streaming=False, **kwargs))
|
||||
feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs))
|
||||
return feat_pred, generated_feat
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._inference(*args, streaming=True, **kwargs)
|
||||
for feat_pred, pred_feat_seq, _ in self._inference(*args, streaming=True, **kwargs):
|
||||
yield feat_pred, pred_feat_seq
|
||||
|
||||
@torch.inference_mode()
|
||||
def _inference(
|
||||
@@ -1037,6 +1017,7 @@ class VoxCPM2Model(nn.Module):
|
||||
# trailing audio patches as initial context so the VAE can decode smoothly.
|
||||
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
|
||||
has_continuation_audio = feat_mask[0, -1].item() == 1
|
||||
context_len = 0
|
||||
if has_continuation_audio:
|
||||
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
|
||||
context_len = min(streaming_prefix_len - 1, len(audio_indices))
|
||||
@@ -1086,11 +1067,13 @@ class VoxCPM2Model(nn.Module):
|
||||
prefix_feat_cond = pred_feat
|
||||
|
||||
if streaming:
|
||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
yield feat_pred, pred_feat_seq
|
||||
yield feat_pred, pred_feat_seq, context_len
|
||||
|
||||
if len(pred_feat_seq) > streaming_prefix_len:
|
||||
pred_feat_seq = pred_feat_seq[-streaming_prefix_len:]
|
||||
|
||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||
if i > min_len and stop_flag == 1:
|
||||
@@ -1109,7 +1092,8 @@ class VoxCPM2Model(nn.Module):
|
||||
if not streaming:
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu()
|
||||
yield feat_pred, generated_feat, context_len
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
|
||||
Reference in New Issue
Block a user