fix: streaming decode
This commit is contained in:
@@ -239,7 +239,7 @@ voxcpm --help
|
|||||||
### Web Demo
|
### Web Demo
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python app.py # then open http://localhost:7860
|
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, open http://localhost:8808
|
||||||
```
|
```
|
||||||
|
|
||||||
### 🚢 Production Deployment (Nano-vLLM)
|
### 🚢 Production Deployment (Nano-vLLM)
|
||||||
|
|||||||
+1
-1
@@ -238,7 +238,7 @@ voxcpm --help
|
|||||||
### Web Demo
|
### Web Demo
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python app.py # 然后打开 http://localhost:7860
|
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # 指定本地模型路径,然后打开 http://localhost:8808
|
||||||
```
|
```
|
||||||
|
|
||||||
### 🚢 生产部署(Nano-vLLM)
|
### 🚢 生产部署(Nano-vLLM)
|
||||||
|
|||||||
+24
-40
@@ -48,25 +48,8 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
|||||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||||
|
|
||||||
|
|
||||||
def _trim_audio_silence_vad(
|
# A simple function to trim audio silence using VAD; not used by default
|
||||||
audio: torch.Tensor,
|
def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
|
||||||
sample_rate: int,
|
|
||||||
max_silence_ms: float = 200.0,
|
|
||||||
top_db: float = 35.0,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""使用能量阈值(VAD 方式)截取首尾静音及尾部长段伪静音,首尾各最多保留 max_silence_ms 毫秒静音。
|
|
||||||
|
|
||||||
会同时截掉末尾的长段伪静音(低能量但非完全静音的段落,如长时间底噪)。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio: (1, T) 的音频 tensor
|
|
||||||
sample_rate: 采样率
|
|
||||||
max_silence_ms: 首尾允许保留的最大静音长度(毫秒)
|
|
||||||
top_db: 低于参考电平多少 dB 视为静音
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
截取后的 (1, T') tensor
|
|
||||||
"""
|
|
||||||
if audio.numel() == 0:
|
if audio.numel() == 0:
|
||||||
return audio
|
return audio
|
||||||
y = audio.squeeze(0).numpy()
|
y = audio.squeeze(0).numpy()
|
||||||
@@ -85,7 +68,7 @@ def _trim_audio_silence_vad(
|
|||||||
except Exception:
|
except Exception:
|
||||||
start, end = 0, n
|
start, end = 0, n
|
||||||
|
|
||||||
# 用逐帧 RMS 找「最后一段有持续能量的位置」,截掉末尾长伪静音(低能量底噪等)
|
# Find the last frame with continuous energy, trim the long pseudo-silence at the end (low energy background noise, etc.)
|
||||||
n_frames = max(0, (n - frame_length) // hop_length + 1)
|
n_frames = max(0, (n - frame_length) // hop_length + 1)
|
||||||
last_voice_frame = -1
|
last_voice_frame = -1
|
||||||
for j in range(n_frames):
|
for j in range(n_frames):
|
||||||
@@ -383,11 +366,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
mu=dit_hidden,
|
mu=dit_hidden,
|
||||||
patch_size=self.patch_size,
|
patch_size=self.patch_size,
|
||||||
cond=feat_cond_for_sample,
|
cond=feat_cond_for_sample,
|
||||||
n_timesteps=(
|
n_timesteps=10,
|
||||||
self.config.dit_config.cfm_config.inference_cfg_rate
|
|
||||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
|
||||||
else 10
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||||
|
|
||||||
@@ -658,13 +637,13 @@ class VoxCPM2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
if streaming:
|
if streaming:
|
||||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
for latent_pred, _ in inference_result:
|
for latent_pred, _, _ctx in inference_result:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||||
yield decode_audio
|
yield decode_audio
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -681,9 +660,8 @@ class VoxCPM2Model(nn.Module):
|
|||||||
if not streaming:
|
if not streaming:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
has_continuation = bool(prompt_wav_path)
|
if context_len > 0:
|
||||||
if has_continuation:
|
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||||
decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
|
||||||
else:
|
else:
|
||||||
decode_audio = decode_audio.squeeze(1).cpu()
|
decode_audio = decode_audio.squeeze(1).cpu()
|
||||||
yield decode_audio
|
yield decode_audio
|
||||||
@@ -946,13 +924,13 @@ class VoxCPM2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
if streaming:
|
if streaming:
|
||||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
for latent_pred, pred_audio_feat in inference_result:
|
for latent_pred, pred_audio_feat, _ctx in inference_result:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -968,17 +946,19 @@ class VoxCPM2Model(nn.Module):
|
|||||||
if not streaming:
|
if not streaming:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
decode_patch_len = self.patch_size * self._decode_chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
if mode in ("continuation", "ref_continuation"):
|
if context_len > 0:
|
||||||
decode_audio = decode_audio[..., decode_patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||||
else:
|
else:
|
||||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
decode_audio = decode_audio.squeeze(1).cpu()
|
||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
|
|
||||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
return next(self._inference(*args, streaming=False, **kwargs))
|
feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs))
|
||||||
|
return feat_pred, generated_feat
|
||||||
|
|
||||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||||
return self._inference(*args, streaming=True, **kwargs)
|
for feat_pred, pred_feat_seq, _ in self._inference(*args, streaming=True, **kwargs):
|
||||||
|
yield feat_pred, pred_feat_seq
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _inference(
|
def _inference(
|
||||||
@@ -1037,6 +1017,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
# trailing audio patches as initial context so the VAE can decode smoothly.
|
# trailing audio patches as initial context so the VAE can decode smoothly.
|
||||||
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
|
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
|
||||||
has_continuation_audio = feat_mask[0, -1].item() == 1
|
has_continuation_audio = feat_mask[0, -1].item() == 1
|
||||||
|
context_len = 0
|
||||||
if has_continuation_audio:
|
if has_continuation_audio:
|
||||||
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
|
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
|
||||||
context_len = min(streaming_prefix_len - 1, len(audio_indices))
|
context_len = min(streaming_prefix_len - 1, len(audio_indices))
|
||||||
@@ -1086,11 +1067,13 @@ class VoxCPM2Model(nn.Module):
|
|||||||
prefix_feat_cond = pred_feat
|
prefix_feat_cond = pred_feat
|
||||||
|
|
||||||
if streaming:
|
if streaming:
|
||||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
|
||||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||||
|
|
||||||
yield feat_pred, pred_feat_seq
|
yield feat_pred, pred_feat_seq, context_len
|
||||||
|
|
||||||
|
if len(pred_feat_seq) > streaming_prefix_len:
|
||||||
|
pred_feat_seq = pred_feat_seq[-streaming_prefix_len:]
|
||||||
|
|
||||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||||
if i > min_len and stop_flag == 1:
|
if i > min_len and stop_flag == 1:
|
||||||
@@ -1109,7 +1092,8 @@ class VoxCPM2Model(nn.Module):
|
|||||||
if not streaming:
|
if not streaming:
|
||||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu()
|
||||||
|
yield feat_pred, generated_feat, context_len
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||||
|
|||||||
Reference in New Issue
Block a user