Files
VoxCPM/src/voxcpm/model/voxcpm2.py
T
sharziki fb46aad9a5 fix: close file handles in from_local() config loading
Use context managers when reading config.json in VoxCPMModel.from_local()
and VoxCPM2Model.from_local() to prevent file descriptor leaks. Also add
explicit encoding="utf-8" to avoid locale-dependent decode errors.

Closes #235

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-11 00:01:14 -04:00

1244 lines
52 KiB
Python

"""
VoxCPM: A Tokenizer-free speech generation model
This module contains the main VoxCPM model implementation, including configuration classes
and the core VoxCPMModel for text-to-speech generation.
Copyright 2026 OpenBMB
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import sys
from typing import Tuple, Union, Generator, List, Optional
import torch
import torch.nn as nn
import warnings
import librosa
import numpy as np
from einops import rearrange
from pydantic import BaseModel
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
from tqdm import tqdm
from transformers import LlamaTokenizerFast
from ..modules.audiovae import AudioVAEV2, AudioVAEConfigV2
from ..modules.layers import ScalarQuantizationLayer
from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiTV2
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
# A simple function to trim audio silence using VAD; not used by default.
def _trim_audio_silence_vad(
    audio: torch.Tensor,
    sample_rate: int,
    max_silence_ms: float = 200.0,
    top_db: float = 35.0,
    frame_length: int = 2048,
    hop_length: int = 512,
) -> torch.Tensor:
    """Trim leading/trailing silence from a mono waveform, keeping a small margin.

    Start/end are first estimated with ``librosa.effects.trim`` (falling back to the
    whole signal if that call fails). The end is then tightened to the end of the
    last analysis frame whose RMS reaches ``peak * 10 ** (-top_db / 20)``, which
    removes long low-energy tails (background noise, etc.). Finally up to
    ``max_silence_ms`` of audio is kept on each side of the detected region.

    Args:
        audio: waveform tensor shaped ``(1, num_samples)``.
        sample_rate: sample rate of ``audio`` in Hz, used to convert ``max_silence_ms``.
        max_silence_ms: margin (milliseconds) kept before the start and after the end.
        top_db: threshold in dB below the peak amplitude treated as silence.
        frame_length: analysis frame size in samples.
        hop_length: hop between successive analysis frames in samples.

    Returns:
        ``audio[:, new_start:new_end]``; the input tensor itself when it is empty or
        all zeros.
    """
    if audio.numel() == 0:
        return audio
    # Analysis works on a host-side numpy copy; the final slice is taken on the original tensor.
    y = audio.detach().squeeze(0).cpu().numpy()
    n = len(y)
    ref = np.max(np.abs(y))
    if ref <= 0:
        return audio
    threshold = ref * (10.0 ** (-top_db / 20.0))
    try:
        _, (start, end) = librosa.effects.trim(
            y, top_db=top_db, ref=np.max, frame_length=frame_length, hop_length=hop_length
        )
    except Exception:
        # Deliberate best effort: keep the full signal if librosa cannot trim it.
        start, end = 0, n
    # Find the last frame with continuous energy and trim the long pseudo-silence at the end
    # (low-energy background noise, etc.). Only frames lying fully inside the signal are used.
    if n >= frame_length:
        frames = np.lib.stride_tricks.sliding_window_view(y, frame_length)[::hop_length]
        rms = np.sqrt(np.mean(frames**2, axis=-1))
        voiced = np.flatnonzero(rms >= threshold)
        if voiced.size > 0:
            last_voice_frame = int(voiced[-1])
            # Sample index just past the last voiced frame.
            end_by_vad = min(n, last_voice_frame * hop_length + frame_length)
            end = min(end, end_by_vad)
    max_silence_samples = int(max_silence_ms * sample_rate / 1000.0)
    new_start = max(0, start - max_silence_samples)
    new_end = min(n, end + max_silence_samples)
    return audio[:, new_start:new_end]
class VoxCPMEncoderConfig(BaseModel):
    # Hyperparameters of the local audio-patch encoder; VoxCPM2Model copies them onto a
    # MiniCPM4Config (hidden_size, intermediate_size, heads, layers, kv_channels).
    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # The default is None, so the annotation must admit None: with pydantic v2 a bare
    # `int` rejects an explicit `null` (e.g. coming from config.json) at validation time.
    kv_channels: Optional[int] = None
class VoxCPMDitConfig(BaseModel):
    # Hyperparameters of the local DiT used as the estimator of the feature decoder;
    # VoxCPM2Model copies the size fields onto a MiniCPM4Config.
    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # Default is None, so the annotation must admit None (pydantic v2 rejects `null` for `int`).
    kv_channels: Optional[int] = None
    # Forwarded to UnifiedCFM as `mean_mode`.
    dit_mean_mode: bool = False
    # Parameters of the UnifiedCFM decoder (passed as `cfm_params`).
    cfm_config: CfmConfig
# Top-level configuration consumed by VoxCPM2Model.__init__.
# Field order is part of the pydantic schema/serialization order; keep it stable.
class VoxCPMConfig(BaseModel):
    # Config of the text-semantic LM; deep-copied as the template for the residual LM,
    # the local encoder and the local DiT.
    lm_config: MiniCPM4Config
    # Number of latent frames grouped into one audio patch.
    patch_size: int = 4
    # Latent feature dimension per frame (input channels of the local encoder / decoder).
    feat_dim: int = 64
    # Depth of the residual acoustic LM.
    residual_lm_num_layers: int = 8
    # Copied onto the residual LM config as `no_rope`.
    residual_lm_no_rope: bool = False
    # Arguments of the ScalarQuantizationLayer applied to LM hidden states.
    scalar_quantization_latent_dim: int = 512
    scalar_quantization_scale: int = 9
    encoder_config: VoxCPMEncoderConfig
    dit_config: VoxCPMDitConfig
    audio_vae_config: Optional[AudioVAEConfigV2] = None
    # Cache length passed to setup_cache() for the base and residual LMs.
    max_length: int = 8192
    # Fallback device, resolved by resolve_runtime_device() together with the explicit device.
    device: str = "cuda"
    # dtype name, converted with get_dtype().
    dtype: str = "bfloat16"
# Which parts of VoxCPM2Model receive LoRA adapters, and with which hyperparameters.
# Consumed by VoxCPM2Model._apply_lora(). Comments (not a class docstring) are used so the
# pydantic JSON schema is unchanged.
class LoRAConfig(BaseModel):
    enable_lm: bool = False  # Apply LoRA to base_lm + residual_lm
    enable_dit: bool = False  # Apply LoRA to VoxCPMLocDiT
    enable_proj: bool = False  # Apply LoRA to projection Linear layers
    # Shared hyperparameters forwarded unchanged to every adapter created by _apply_lora().
    r: int = 8
    alpha: int = 16
    dropout: float = 0.0
    # Target linear layer names for LM & DiT (matched by attribute name)
    target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    # Projection layer attribute names to find on VoxCPM2Model; attributes that are not
    # nn.Linear instances are skipped.
    target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj", "fusion_concat_proj"]
# Rebuild the pydantic schema of VoxCPMConfig after all config classes above are defined.
# NOTE(review): presumably needed to resolve forward references in the nested configs — confirm.
VoxCPMConfig.model_rebuild()
class VoxCPM2Model(nn.Module):
def __init__(
    self,
    config: VoxCPMConfig,
    tokenizer: LlamaTokenizerFast,
    audio_vae: AudioVAEV2,
    lora_config: Optional[LoRAConfig] = None,
    device: str | None = None,
):
    """Build all VoxCPM2 sub-modules from ``config``.

    Args:
        config: model configuration. ``config.device`` is overwritten in place with
            the resolved runtime device.
        tokenizer: text tokenizer; stored after ``mask_multichar_chinese_tokens``.
        audio_vae: audio VAE; its chunk size and sample rates are cached on the model.
        lora_config: when not None, LoRA adapters are injected by ``_apply_lora``
            after every layer has been created.
        device: requested device, resolved together with ``config.device`` by
            ``resolve_runtime_device``.
    """
    super().__init__()
    self.config = config
    self.lora_config = lora_config
    self.feat_dim = config.feat_dim
    self.patch_size = config.patch_size
    self.device = resolve_runtime_device(device, config.device)
    # Keep the config in sync with the device actually used.
    self.config.device = self.device
    print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
    # Text-Semantic LM
    self.base_lm = MiniCPMModel(config.lm_config)
    # NOTE(review): first setup_cache argument presumably is the batch size (1) — confirm.
    self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
    self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
    # Special token ids: audio_start is appended after the text in _generate; the ref_audio_*
    # pair delimits the reference-audio prefix built by _make_ref_prefix.
    self.audio_start_token = 101
    self.audio_end_token = 102
    self.ref_audio_start_token = 103
    self.ref_audio_end_token = 104
    # Residual Acoustic LM: same config as the base LM but with its own depth/rope setting.
    # vocab_size=0 because it is only fed inputs_embeds (see forward).
    residual_lm_config = config.lm_config.model_copy(deep=True)
    residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
    residual_lm_config.vocab_size = 0
    residual_lm_config.no_rope = config.residual_lm_no_rope
    self.residual_lm = MiniCPMModel(residual_lm_config)
    self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
    # Local Encoder: turns each audio patch into one embedding (sizes from encoder_config).
    encoder_config = config.lm_config.model_copy(deep=True)
    encoder_config.hidden_size = config.encoder_config.hidden_dim
    encoder_config.intermediate_size = config.encoder_config.ffn_dim
    encoder_config.num_attention_heads = config.encoder_config.num_heads
    encoder_config.num_hidden_layers = config.encoder_config.num_layers
    encoder_config.kv_channels = config.encoder_config.kv_channels
    encoder_config.vocab_size = 0
    self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)
    # Local DiT: UnifiedCFM decoder whose estimator is a VoxCPMLocDiTV2 (sizes from dit_config).
    decoder_config = config.lm_config.model_copy(deep=True)
    decoder_config.hidden_size = config.dit_config.hidden_dim
    decoder_config.intermediate_size = config.dit_config.ffn_dim
    decoder_config.num_attention_heads = config.dit_config.num_heads
    decoder_config.num_hidden_layers = config.dit_config.num_layers
    decoder_config.kv_channels = config.dit_config.kv_channels
    decoder_config.vocab_size = 0
    self.feat_decoder = UnifiedCFM(
        in_channels=config.feat_dim,
        cfm_params=config.dit_config.cfm_config,
        estimator=VoxCPMLocDiTV2(decoder_config, in_channels=config.feat_dim),
        mean_mode=config.dit_config.dit_mean_mode,
    )
    # Projection layers: fsq_layer quantizes LM hidden states at audio positions (see forward);
    # the Linear layers bridge the encoder / LM / DiT widths.
    self.fsq_layer = ScalarQuantizationLayer(
        config.lm_config.hidden_size,
        config.lm_config.hidden_size,
        config.scalar_quantization_latent_dim,
        config.scalar_quantization_scale,
    )
    self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
    self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
    self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
    self.fusion_concat_proj = nn.Linear(config.lm_config.hidden_size * 2, config.lm_config.hidden_size)
    # Stop Predictor: 2-way classifier on LM hidden states, trained with per-position cross-entropy.
    self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
    self.stop_actn = nn.SiLU()
    self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
    self.stop_loss = nn.CrossEntropyLoss(reduction="none")
    # Audio VAE: decode chunk size and output sample rate fall back to chunk_size / sample_rate
    # when the VAE does not define them.
    self.audio_vae = audio_vae
    self.chunk_size = audio_vae.chunk_size
    self._decode_chunk_size = getattr(audio_vae, "decode_chunk_size", audio_vae.chunk_size)
    self._encode_sample_rate = audio_vae.sample_rate
    self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)
    # LoRA is applied last so it can wrap the layers created above.
    if self.lora_config is not None:
        self._apply_lora()
def _apply_lora(self):
    """Inject LoRA adapters into the LMs, the DiT and/or the projection layers.

    Each part is wrapped only when its ``enable_*`` flag in ``self.lora_config`` is
    set; all adapters share the config's ``r``, ``alpha`` and ``dropout``.
    """
    cfg = self.lora_config
    shared = {"r": cfg.r, "alpha": cfg.alpha, "dropout": cfg.dropout}

    # Text-semantic LM and residual acoustic LM use the same target layer names.
    if cfg.enable_lm:
        for language_model in (self.base_lm, self.residual_lm):
            apply_lora_to_named_linear_modules(
                language_model, target_submodule_names=cfg.target_modules_lm, **shared
            )

    # DiT estimator inside the feature decoder.
    if cfg.enable_dit:
        apply_lora_to_named_linear_modules(
            self.feat_decoder.estimator,
            target_submodule_names=cfg.target_modules_dit,
            **shared,
        )

    # Projection layers: swap each plain nn.Linear attribute for a LoRA wrapper.
    if not cfg.enable_proj:
        return
    from ..modules.layers.lora import LoRALinear

    for attribute in cfg.target_proj_modules:
        layer = getattr(self, attribute, None)
        if not isinstance(layer, nn.Linear):
            continue
        setattr(self, attribute, LoRALinear(base=layer, **shared))
def optimize(self, disable: bool = False):
    """Compile the hot inference paths with ``torch.compile`` (best effort).

    Compiles ``base_lm.forward_step``, ``residual_lm.forward_step``, the feature
    encoder and the DiT estimator with ``mode="reduce-overhead"`` and
    ``fullgraph=True``. The uncompiled encoder is kept in ``_feat_encoder_raw``.
    Any failure (non-CUDA device, missing triton, compile error) is reported on
    stderr and the model stays usable.

    Args:
        disable: when True, return immediately without compiling.

    Returns:
        ``self``, to allow chaining.
    """
    if disable:
        return self
    try:
        # Compare the device *type* so indexed devices like "cuda:1" (or torch.device
        # objects) qualify, not only the bare string "cuda".
        if torch.device(self.device).type != "cuda":
            raise ValueError("VoxCPM2Model can only be optimized on CUDA device")
        try:
            import triton  # noqa: F401
        except ImportError as err:
            raise ValueError("triton is not installed") from err
        self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
        self.residual_lm.forward_step = torch.compile(
            self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
        )
        # Keep the eager encoder around; _inference uses it for the variable-length prefill.
        self._feat_encoder_raw = self.feat_encoder
        self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
        self.feat_decoder.estimator = torch.compile(
            self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
        )
    except Exception as e:
        # Deliberate best effort: compilation is an optimization, never a hard requirement.
        print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
    return self
def forward(
    self,
    text_tokens: torch.Tensor,
    text_mask: torch.Tensor,
    audio_feats: torch.Tensor,
    audio_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    position_ids: torch.Tensor,
    labels: torch.Tensor,
    *,
    progress: float = 0.0,
    sample_generate: bool = False,
):
    """Training forward pass: diffusion loss on audio patches plus stop-prediction loss.

    Text and audio positions are merged into one embedding sequence with
    ``text_mask`` / ``audio_mask``, run through the text-semantic LM and the residual
    acoustic LM; the right-shifted hidden states condition the feature decoder on
    each audio patch.

    Args:
        text_tokens: token ids, cast to long. Assumed (B, T) like the masks — TODO confirm.
        text_mask: 1 at text positions; cast to the model dtype.
        audio_feats: (B, T, P, D) latent patches with P == patch_size.
        audio_mask: 1 at audio positions; cast to the model dtype.
        loss_mask: per-position weights applied to both losses.
        position_ids: currently unused.
        labels: 2-class targets for the stop head, cast to long.
        progress: forwarded to ``feat_decoder.compute_loss``.
        sample_generate: if True, also sample features with 10 decoder steps and
            return them as ``feat_pred``.

    Returns:
        dict with ``"loss/diff"``, ``"loss/stop"``, ``"feat_gt"`` laid out (B, D, T*P)
        and ``"feat_pred"`` (same layout, or None when ``sample_generate`` is False).
    """
    del position_ids # not used yet
    text_tokens = text_tokens.to(self.device, dtype=torch.long)
    text_mask = text_mask.to(self.device, dtype=self._dtype())
    audio_feats = audio_feats.to(self.device, dtype=self._dtype())
    audio_mask = audio_mask.to(self.device, dtype=self._dtype())
    loss_mask = loss_mask.to(self.device, dtype=self._dtype())
    labels = labels.to(self.device, dtype=torch.long)
    B, T, P, D = audio_feats.shape
    feat_embed = self.feat_encoder(audio_feats)
    feat_embed = self.enc_to_lm_proj(feat_embed)
    # scale_emb only applies when muP is enabled on the LM config.
    scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
    if not getattr(self.config.lm_config, "use_mup", False):
        scale_emb = 1.0
    text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
    # Merge text and audio embeddings position-wise according to the masks.
    combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
    enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
    enc_outputs = enc_outputs.to(self._dtype())
    # Quantize hidden states at audio positions only; text positions pass through unchanged.
    enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
    # Shift right by one step: position t is conditioned on LM outputs up to t-1.
    lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)
    # Residual LM input: LM outputs fused with the audio embedding at the same position.
    residual_inputs = self.fusion_concat_proj(
        torch.cat((enc_outputs, audio_mask.unsqueeze(-1) * feat_embed), dim=-1)
    )
    residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
    residual_outputs = residual_outputs.to(self._dtype())
    # Same one-step right shift for the residual stream.
    residual_hidden = torch.cat(
        (torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
        dim=1,
    )
    # Decoder conditioning: projected LM + residual states, flattened over (batch, time).
    dit_hidden = torch.cat((self.lm_to_dit_proj(lm_hidden), self.res_to_dit_proj(residual_hidden)), dim=-1)
    dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")
    # Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
    target_dtype = self._dtype()
    feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
    # Condition on the previous patch (zeros for the first position).
    feat_cond = torch.cat(
        (torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
        dim=1,
    )
    feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")
    # Expand the per-position loss mask to every frame of a patch.
    loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
    loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)
    diff_loss = self.feat_decoder.compute_loss(
        feat_gt.transpose(1, 2).contiguous(),
        dit_hidden,
        cond=feat_cond.transpose(1, 2).contiguous(),
        tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
        progress=progress,
    )
    stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
    stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
    # Masked mean of the stop loss; the denominator is clamped to avoid division by zero.
    denom = torch.clamp(loss_mask.sum(), min=1.0)
    stop_loss = (stop_losses * loss_mask).sum() / denom
    feat_pred = None
    if sample_generate:
        # NOTE(review): sampling runs without torch.no_grad() here — confirm callers wrap it.
        feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
        feat_pred_seq = self.feat_decoder(
            mu=dit_hidden,
            patch_size=self.patch_size,
            cond=feat_cond_for_sample,
            n_timesteps=10,
        )
        feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
    feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)
    return {
        "loss/diff": diff_loss,
        "loss/stop": stop_loss,
        "feat_gt": feat_gt_tensor,
        "feat_pred": feat_pred,
    }
def _dtype(self):
    """Return the torch dtype named by ``self.config.dtype`` (resolved through ``get_dtype``)."""
    dtype_name = self.config.dtype
    return get_dtype(dtype_name)
def _encode_wav(
    self,
    wav_path: str,
    padding_mode: str = "right",
    trim_silence_vad: bool = False,
) -> torch.Tensor:
    """Load, trim, pad and VAE-encode an audio file.

    Args:
        wav_path: path to the audio file (loaded mono at the VAE's encode sample rate).
        padding_mode: "right" (default) or "left"; the side on which zeros are added so
            the length becomes a multiple of ``patch_size * chunk_size`` samples.
        trim_silence_vad: whether to apply VAD-based silence trimming.

    Returns:
        audio_feat: (T, P, D) CPU tensor of latent patches, P == patch_size and
        D == audio_vae.latent_dim.

    Raises:
        ValueError: if ``padding_mode`` is neither "left" nor "right".
    """
    # Validate before touching the file: a typo such as "Left" used to fall back to
    # right padding silently.
    if padding_mode not in ("left", "right"):
        raise ValueError(f"padding_mode must be 'left' or 'right', got {padding_mode!r}")
    audio, _ = librosa.load(wav_path, sr=self._encode_sample_rate, mono=True)
    audio = torch.from_numpy(audio).unsqueeze(0)  # (1, num_samples)
    if trim_silence_vad:
        audio = _trim_audio_silence_vad(audio, self._encode_sample_rate, max_silence_ms=200.0)
    patch_len = self.patch_size * self.chunk_size
    remainder = audio.size(1) % patch_len
    if remainder != 0:
        padding_size = patch_len - remainder
        pad = (padding_size, 0) if padding_mode == "left" else (0, padding_size)
        audio = torch.nn.functional.pad(audio, pad)
    feat = self.audio_vae.encode(audio.to(self.device), self._encode_sample_rate).cpu()
    # Reshape to (latent_dim, T, patch_size) and move channels last -> (T, P, D).
    # NOTE(review): assumes encode() returns a channel-first latent layout — confirm.
    return feat.view(self.audio_vae.latent_dim, -1, self.patch_size).permute(1, 2, 0)
def _make_ref_prefix(self, ref_feat: torch.Tensor, device: torch.device):
    """Assemble the ``[ref_start, ref_audio..., ref_end]`` prefix segments.

    The two boundary positions are text tokens (``ref_audio_start_token`` and
    ``ref_audio_end_token``) with all-zero audio features; every position in
    between carries one reference latent patch and token id 0.

    Returns:
        tokens: int32 ids, length ``ref_len + 2``.
        feats: ``ref_feat`` framed by one zero patch on each side (dim 0).
        t_mask: int32, 1 at the two boundary positions, 0 elsewhere.
        a_mask: int32, complement of ``t_mask`` (1 over the reference audio).
    """
    n_ref = ref_feat.size(0)
    tokens = torch.tensor(
        [self.ref_audio_start_token] + [0] * n_ref + [self.ref_audio_end_token],
        dtype=torch.int32,
        device=device,
    )
    boundary = torch.zeros((1, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32, device=device)
    feats = torch.cat([boundary, ref_feat, boundary], dim=0)
    t_mask = torch.zeros(n_ref + 2, dtype=torch.int32, device=device)
    t_mask[0] = 1
    t_mask[-1] = 1
    a_mask = 1 - t_mask
    return tokens, feats, t_mask, a_mask
def generate(self, *args, **kwargs) -> torch.Tensor:
    """Synthesize speech in one shot and return the decoded waveform.

    Accepts the same arguments as ``_generate`` (``streaming`` is forced to False)
    and returns the single tensor that generator yields.
    """
    one_shot = self._generate(*args, streaming=False, **kwargs)
    return next(one_shot)
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
    """Return a generator of decoded audio chunks, one per generation step.

    Same arguments as ``_generate``; ``streaming`` is forced to True.
    """
    chunks = self._generate(*args, streaming=True, **kwargs)
    return chunks
@torch.inference_mode()
def _generate(
    self,
    target_text: str,
    prompt_text: str = "",
    prompt_wav_path: Optional[str] = "",
    reference_wav_path: Optional[str] = "",
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    retry_badcase: bool = False,
    retry_badcase_max_times: int = 3,
    retry_badcase_ratio_threshold: float = 6.0,
    trim_silence_vad: bool = False,
    streaming: bool = False,
    streaming_prefix_len: int = 4,
) -> Generator[torch.Tensor, None, None]:
    """Synthesize ``target_text`` and yield decoded audio.

    The input sequence depends on which audio paths are given (empty or None
    disables a path):

    - reference and prompt: ``[ref prefix][prompt_text + target_text, audio_start][prompt audio]``
    - reference only: ``[ref prefix][target_text, audio_start]``
    - neither: zero-shot, ``[target_text, audio_start]``
    - prompt only: continuation, ``[prompt_text + target_text, audio_start][prompt audio]``

    Args:
        target_text: text to synthesize.
        prompt_text: transcript of the prompt audio (continuation modes).
        prompt_wav_path: prompt audio to continue from.
        reference_wav_path: reference audio placed in the isolated ref prefix.
        min_len, inference_timesteps, cfg_value, streaming_prefix_len: forwarded to ``_inference``.
        max_len: upper bound on generated patches; additionally capped at
            ``int(target_text_tokens * retry_badcase_ratio_threshold + 10)``.
        retry_badcase: regenerate when the output is too long relative to the text
            (disabled in streaming mode).
        retry_badcase_max_times: maximum number of generation attempts; values below 1
            are treated as 1 so that at least one attempt is always made.
        retry_badcase_ratio_threshold: output-patches / text-tokens ratio treated as a bad case.
        trim_silence_vad: apply VAD silence trimming to prompt/reference audio.
        streaming: yield one decoded chunk per step instead of one final tensor.

    Yields:
        CPU audio tensors: one chunk per step when ``streaming=True``; otherwise a
        single tensor with the samples of the first ``context_len`` patches removed.
    """
    if retry_badcase and streaming:
        warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
        retry_badcase = False
    if reference_wav_path and prompt_wav_path:
        # Combined mode: reference isolation prefix + continuation suffix
        text = prompt_text + target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]
        ref_feat = self._encode_wav(
            reference_wav_path,
            padding_mode="right",
            trim_silence_vad=trim_silence_vad,
        )
        prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left", trim_silence_vad=trim_silence_vad)
        prompt_audio_length = prompt_feat.size(0)
        ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)
        # Audio positions get token id 0; text positions get all-zero audio features.
        prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([ref_tokens, text_token, prompt_pad_token])
        audio_feat = torch.cat([ref_feats, text_pad_feat, prompt_feat], dim=0)
        text_mask = torch.cat(
            [
                ref_t_mask,
                torch.ones(text_length, dtype=torch.int32).to(text_token.device),
                torch.zeros(prompt_audio_length, dtype=torch.int32).to(text_token.device),
            ]
        )
        audio_mask = torch.cat(
            [
                ref_a_mask,
                torch.zeros(text_length, dtype=torch.int32).to(text_token.device),
                torch.ones(prompt_audio_length, dtype=torch.int32).to(text_token.device),
            ]
        )
    elif reference_wav_path:
        # Reference-only mode (prompt isolation)
        text = target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]
        ref_feat = self._encode_wav(
            reference_wav_path,
            padding_mode="right",
            trim_silence_vad=trim_silence_vad,
        )
        ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([ref_tokens, text_token])
        audio_feat = torch.cat([ref_feats, text_pad_feat], dim=0)
        text_mask = torch.cat(
            [
                ref_t_mask,
                torch.ones(text_length, dtype=torch.int32).to(text_token.device),
            ]
        )
        audio_mask = torch.cat(
            [
                ref_a_mask,
                torch.zeros(text_length, dtype=torch.int32).to(text_token.device),
            ]
        )
    elif not prompt_wav_path:
        # Zero-shot mode. Truthiness (rather than len()) also accepts prompt_wav_path=None.
        text = target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]
        audio_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_mask = torch.ones(text_length, dtype=torch.int32).to(text_token.device)
        audio_mask = torch.zeros(text_length, dtype=torch.int32).to(text_token.device)
    else:
        # Continuation-only mode
        text = prompt_text + target_text
        text_token = torch.LongTensor(self.text_tokenizer(text))
        text_token = torch.cat(
            [
                text_token,
                torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
            ],
            dim=-1,
        )
        text_length = text_token.shape[0]
        prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left", trim_silence_vad=trim_silence_vad)
        prompt_audio_length = prompt_feat.size(0)
        prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([text_token, prompt_pad_token])
        audio_feat = torch.cat([text_pad_feat, prompt_feat], dim=0)
        text_mask = torch.cat(
            [
                torch.ones(text_length, dtype=torch.int32),
                torch.zeros(prompt_audio_length, dtype=torch.int32),
            ]
        ).to(text_token.device)
        audio_mask = torch.cat(
            [
                torch.zeros(text_length, dtype=torch.int32),
                torch.ones(prompt_audio_length, dtype=torch.int32),
            ]
        ).to(text_token.device)
    # Add the batch dimension and move inputs to the model device (features also to model dtype).
    text_token = text_token.unsqueeze(0).to(self.device)
    text_mask = text_mask.unsqueeze(0).to(self.device)
    audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
    audio_mask = audio_mask.unsqueeze(0).to(self.device)
    target_text_length = len(self.text_tokenizer(target_text))
    # Always make at least one attempt: with retry_badcase_max_times <= 0 the loop would never
    # run, leaving `latent_pred` unbound (non-streaming) or yielding nothing (streaming).
    max_attempts = max(1, retry_badcase_max_times)
    retry_badcase_times = 0
    while retry_badcase_times < max_attempts:
        inference_result = self._inference(
            text_token,
            text_mask,
            audio_feat,
            audio_mask,
            min_len=min_len,
            max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len),
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
            streaming=streaming,
            streaming_prefix_len=streaming_prefix_len,
        )
        if streaming:
            decode_patch_len = self.patch_size * self._decode_chunk_size
            for latent_pred, _, _ctx in inference_result:
                decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                # Emit only the trailing decode_patch_len samples (the newest patch).
                # NOTE(review): presumably earlier samples belong to context patches — confirm.
                decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
                yield decode_audio
            break
        else:
            latent_pred, pred_audio_feat, context_len = next(inference_result)
            if retry_badcase:
                # Output far longer than the text suggests a bad case: try again.
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(
                        f"  Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
                        file=sys.stderr,
                    )
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break
    if not streaming:
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
        decode_patch_len = self.patch_size * self._decode_chunk_size
        # Drop the samples of the first `context_len` patches.
        if context_len > 0:
            decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
        else:
            decode_audio = decode_audio.squeeze(1).cpu()
        yield decode_audio
@torch.inference_mode()
def build_prompt_cache(
    self,
    prompt_text: Optional[str] = None,
    prompt_wav_path: Optional[str] = None,
    reference_wav_path: Optional[str] = None,
    trim_silence_vad: bool = False,
):
    """
    Build prompt cache for subsequent generation.

    Supports the same parameter combinations as ``generate()``:

    - ``reference_wav_path`` only -> reference mode (voice cloning, isolated)
    - ``prompt_text`` + ``prompt_wav_path`` -> continuation mode
    - all three -> combined ref + continuation mode

    Args:
        prompt_text: prompt text for continuation mode.
            Must be paired with ``prompt_wav_path``.
        prompt_wav_path: prompt audio path for continuation mode.
            Must be paired with ``prompt_text``.
        reference_wav_path: reference audio path for voice cloning
            (structurally isolated via ref_audio tokens).
        trim_silence_vad: whether to apply VAD-based silence trimming
            before encoding prompt/reference audio.

    Returns:
        prompt_cache: dict used by ``_generate_with_prompt_cache`` with key ``mode`` and
        ``ref_audio_feat`` and/or ``prompt_text`` + ``audio_feat``.

    Raises:
        ValueError: if only one of ``prompt_text`` / ``prompt_wav_path`` is given, or if
            no non-empty prompt or reference audio path is provided.
    """
    if (prompt_wav_path is None) != (prompt_text is None):
        raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
    if prompt_wav_path is None and reference_wav_path is None:
        raise ValueError("At least one of prompt_wav_path or reference_wav_path must be provided")
    cache = {}
    if reference_wav_path:
        # Right-padded: the reference sits in an isolated prefix.
        cache["ref_audio_feat"] = self._encode_wav(
            reference_wav_path,
            padding_mode="right",
            trim_silence_vad=trim_silence_vad,
        )
    if prompt_wav_path and prompt_text is not None:
        cache["prompt_text"] = prompt_text
        # Left-padded: generation continues directly after the prompt audio.
        cache["audio_feat"] = self._encode_wav(
            prompt_wav_path,
            padding_mode="left",
            trim_silence_vad=trim_silence_vad,
        )
    has_ref = "ref_audio_feat" in cache
    has_prompt = "audio_feat" in cache
    if not has_ref and not has_prompt:
        # Empty-string paths pass the None checks above but encode nothing; such a cache
        # would be labelled "continuation" without "audio_feat" and fail later with KeyError.
        raise ValueError("At least one of prompt_wav_path or reference_wav_path must be provided")
    if has_ref and has_prompt:
        cache["mode"] = "ref_continuation"
    elif has_ref:
        cache["mode"] = "reference"
    else:
        cache["mode"] = "continuation"
    return cache
def merge_prompt_cache(
    self,
    original_cache: dict,
    new_text: str,
    new_audio_feat: torch.Tensor,
):
    """Append freshly generated text/audio to a prompt cache to keep the voice stable.

    Args:
        original_cache: cache from ``build_prompt_cache`` (any mode), or None.
        new_text: text that was just synthesized.
        new_audio_feat: audio features generated for ``new_text``; appended along dim 0.

    Returns:
        A new dict with ``prompt_text``, ``audio_feat`` and ``mode`` (plus
        ``ref_audio_feat`` when the original had one). ``original_cache`` is not mutated.
    """
    if original_cache is None:
        return {"prompt_text": new_text, "audio_feat": new_audio_feat, "mode": "continuation"}

    has_ref = "ref_audio_feat" in original_cache
    merged = {}
    if has_ref:
        merged["ref_audio_feat"] = original_cache["ref_audio_feat"]
    merged["prompt_text"] = original_cache.get("prompt_text", "") + new_text
    # Zero-length stand-in (same trailing dims/dtype/device) for caches without prompt audio,
    # e.g. reference-only mode.
    empty_feat = new_audio_feat.new_empty(0, *new_audio_feat.shape[1:])
    previous_feat = original_cache.get("audio_feat", empty_feat)
    merged["audio_feat"] = torch.cat([previous_feat, new_audio_feat], dim=0)
    merged["mode"] = "ref_continuation" if has_ref else "continuation"
    return merged
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Non-streaming wrapper around ``_generate_with_prompt_cache``.

    Returns the single ``(audio, target_text_tokens, generated_audio_feat)`` tuple
    that the underlying generator yields; ``streaming`` is forced to False.
    """
    results = self._generate_with_prompt_cache(*args, streaming=False, **kwargs)
    return next(results)
def generate_with_prompt_cache_streaming(
    self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
    """Streaming wrapper around ``_generate_with_prompt_cache``.

    Returns the generator itself; it yields one
    ``(audio_chunk, target_text_tokens, generated_feats_so_far)`` tuple per step.
    """
    step_stream = self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
    return step_stream
@torch.inference_mode()
def _generate_with_prompt_cache(
    self,
    target_text: str,
    prompt_cache: Optional[dict],
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    retry_badcase: bool = False,
    retry_badcase_max_times: int = 3,
    retry_badcase_ratio_threshold: float = 6.0,
    streaming: bool = False,
    streaming_prefix_len: int = 4,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
    """
    Generate audio using pre-built prompt cache.

    Args:
        target_text: Text to convert to speech
        prompt_cache: Cache built by ``build_prompt_cache()``. Can be None
            for zero-shot generation.
        min_len: Minimum audio length to avoid very short audio
        max_len: Maximum audio length (additionally capped relative to the target text length)
        inference_timesteps: Number of diffusion sampling steps
        cfg_value: Classifier-free guidance value
        retry_badcase: Whether to retry on bad cases (disabled in streaming mode)
        retry_badcase_max_times: Maximum number of generation attempts; values below 1 are
            treated as 1 so that at least one attempt is always made
        retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
        streaming: Whether to return a generator of audio chunks
        streaming_prefix_len: Number of prefix audio patches to use for streaming mode

    Returns:
        Generator of Tuple containing:
        - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
        - Tensor of new text tokens
        - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
    """
    if retry_badcase and streaming:
        warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
        retry_badcase = False
    # Determine mode from cache
    if prompt_cache is None:
        mode = "zero_shot"
        text = target_text
    else:
        mode = prompt_cache.get("mode", "continuation")
        if mode in ("continuation", "ref_continuation"):
            prompt_text = prompt_cache.get("prompt_text", "")
            text = prompt_text + target_text
        else:
            text = target_text
    text_token = torch.LongTensor(self.text_tokenizer(text))
    text_token = torch.cat(
        [
            text_token,
            torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
        ],
        dim=-1,
    )
    target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
    text_length = text_token.shape[0]
    if mode in ("zero_shot", "continuation"):
        # Zero-shot uses an empty prompt so both modes share the same layout:
        # [text, audio_start][prompt audio].
        prompt_audio_feat = (
            prompt_cache["audio_feat"]
            if prompt_cache
            else torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
        )
        audio_length = prompt_audio_feat.size(0)
        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([text_token, text_pad_token])
        audio_feat = torch.cat([text_pad_feat, prompt_audio_feat], dim=0)
        text_mask = torch.cat(
            [torch.ones(text_length, dtype=torch.int32), torch.zeros(audio_length, dtype=torch.int32)]
        ).to(text_token.device)
        audio_mask = torch.cat(
            [torch.zeros(text_length, dtype=torch.int32), torch.ones(audio_length, dtype=torch.int32)]
        ).to(text_token.device)
    elif mode == "reference":
        # [ref prefix][text, audio_start]
        ref_audio_feat = prompt_cache["ref_audio_feat"]
        ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_audio_feat, text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([ref_tokens, text_token])
        audio_feat = torch.cat([ref_feats, text_pad_feat], dim=0)
        text_mask = torch.cat([ref_t_mask, torch.ones(text_length, dtype=torch.int32).to(text_token.device)])
        audio_mask = torch.cat([ref_a_mask, torch.zeros(text_length, dtype=torch.int32).to(text_token.device)])
    else:
        # ref_continuation mode: [ref prefix][text, audio_start][prompt audio]
        ref_audio_feat = prompt_cache["ref_audio_feat"]
        prompt_audio_feat = prompt_cache["audio_feat"]
        prompt_audio_length = prompt_audio_feat.size(0)
        ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_audio_feat, text_token.device)
        prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
        text_pad_feat = torch.zeros(
            (text_length, self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([ref_tokens, text_token, prompt_pad_token])
        audio_feat = torch.cat([ref_feats, text_pad_feat, prompt_audio_feat], dim=0)
        text_mask = torch.cat(
            [
                ref_t_mask,
                torch.ones(text_length, dtype=torch.int32).to(text_token.device),
                torch.zeros(prompt_audio_length, dtype=torch.int32).to(text_token.device),
            ]
        )
        audio_mask = torch.cat(
            [
                ref_a_mask,
                torch.zeros(text_length, dtype=torch.int32).to(text_token.device),
                torch.ones(prompt_audio_length, dtype=torch.int32).to(text_token.device),
            ]
        )
    # Add the batch dimension and move inputs to the model device (features also to model dtype).
    text_token = text_token.unsqueeze(0).to(self.device)
    text_mask = text_mask.unsqueeze(0).to(self.device)
    audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
    audio_mask = audio_mask.unsqueeze(0).to(self.device)
    # run inference
    target_text_length = len(self.text_tokenizer(target_text))
    # Always make at least one attempt: with retry_badcase_max_times <= 0 the loop would never
    # run, leaving `latent_pred` unbound (non-streaming) or yielding nothing (streaming).
    max_attempts = max(1, retry_badcase_max_times)
    retry_badcase_times = 0
    while retry_badcase_times < max_attempts:
        inference_result = self._inference(
            text_token,
            text_mask,
            audio_feat,
            audio_mask,
            min_len=min_len,
            max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len),
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
            streaming=streaming,
            streaming_prefix_len=streaming_prefix_len,
        )
        if streaming:
            decode_patch_len = self.patch_size * self._decode_chunk_size
            for latent_pred, pred_audio_feat, _ctx in inference_result:
                decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                # Emit only the trailing decode_patch_len samples (the newest patch).
                decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
                yield (decode_audio, target_text_token, pred_audio_feat)
            break
        else:
            latent_pred, pred_audio_feat, context_len = next(inference_result)
            if retry_badcase:
                # Output far longer than the text suggests a bad case: try again.
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(
                        f"  Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
                        file=sys.stderr,
                    )
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break
    if not streaming:
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
        decode_patch_len = self.patch_size * self._decode_chunk_size
        # Drop the samples of the first `context_len` patches.
        if context_len > 0:
            decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
        else:
            decode_audio = decode_audio.squeeze(1).cpu()
        yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run non-streaming inference and return ``(feat_pred, generated_feat)``.

    Thin wrapper over ``_inference(..., streaming=False)``; the third element of the
    yielded triple (the context length) is discarded.
    """
    final_step = next(self._inference(*args, streaming=False, **kwargs))
    feat_pred, generated_feat, _context_len = final_step
    return feat_pred, generated_feat
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
    """Yield ``(feat_pred, pred_feat_seq)`` for every step of streaming inference.

    Wraps ``_inference(..., streaming=True)`` and drops the context length that it
    yields as the third element.
    """
    step_results = self._inference(*args, streaming=True, **kwargs)
    for step in step_results:
        feat_pred, pred_feat_seq, _context_len = step
        yield feat_pred, pred_feat_seq
@torch.inference_mode()
def _inference(
    self,
    text: torch.Tensor,
    text_mask: torch.Tensor,
    feat: torch.Tensor,
    feat_mask: torch.Tensor,
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    streaming: bool = False,
    streaming_prefix_len: int = 4,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]], int], None, None]:
    """Core inference method for audio generation.

    This is the main inference loop that generates audio features
    using the language model and diffusion transformer.

    NOTE(review): the continuation check, the prompt-context selection and the
    stop decision only inspect batch index 0, so this presumably assumes B == 1.

    Args:
        text: Input text tokens
        text_mask: Mask for text tokens
        feat: Input audio features, unpacked as ``(B, T, P, D)``
        feat_mask: Mask for audio features
        min_len: The stop head is ignored while the step index ``i <= min_len``
        max_len: Maximum number of generation steps (patches)
        inference_timesteps: Number of diffusion steps
        cfg_value: Classifier-free guidance value
        streaming: Whether to yield each step latent feature or just the final result
        streaming_prefix_len: Size of the sliding window of recent patches that is
            concatenated into each streamed chunk. When the prompt ends with audio,
            up to ``streaming_prefix_len - 1`` trailing prompt patches are prepended
            as decoding context.

    Returns:
        Generator of Tuple containing:
        - Predicted latent feature at the current step (over the sliding window) if
          ``streaming=True``, else final latent features (including the context patches)
        - Recent patch list (a sliding window, not the full history) if ``streaming=True``,
          else the generated features as a concatenated Tensor with the context removed
        - ``context_len``: number of prompt audio patches prepended as context
          (0 when the prompt does not end with audio or ``streaming_prefix_len <= 1``)
    """
    B, T, P, D = feat.shape
    # Prefer ``_feat_encoder_raw`` for the prompt prefill when it exists
    # (presumably an unwrapped encoder for variable-length input; verify where it is set).
    prefill_encoder = getattr(self, "_feat_encoder_raw", self.feat_encoder)
    feat_embed = prefill_encoder(feat)  # [b, t, h_feat]
    feat_embed = self.enc_to_lm_proj(feat_embed)

    if self.config.lm_config.use_mup:
        scale_emb = self.config.lm_config.scale_emb
    else:
        scale_emb = 1.0

    text_embed = self.base_lm.embed_tokens(text) * scale_emb
    # The masks select, per position, between the text and audio embeddings.
    combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed

    prefix_feat_cond = feat[:, -1, ...]  # b, p, d

    # Prepare prompt context patches
    # - Continuation modes (feat_mask ends with 1): use the last (streaming_prefix_len - 1)
    #   trailing audio patches as initial context so the VAE can decode smoothly.
    # - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
    has_continuation_audio = feat_mask[0, -1].item() == 1
    context_len = 0
    pred_feat_seq = []  # list of [b, 1, p, d]
    if has_continuation_audio:
        audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
        # Clamp at 0 and only slice for a positive length: ``audio_indices[-0:]``
        # would select *every* prompt patch while reporting context_len == 0.
        context_len = max(0, min(streaming_prefix_len - 1, len(audio_indices)))
        if context_len > 0:
            last_audio_indices = audio_indices[-context_len:]
            pred_feat_seq = list(feat[:, last_audio_indices, :, :].split(1, dim=1))

    enc_outputs, kv_cache_tuple = self.base_lm(
        inputs_embeds=combined_embed,
        is_causal=True,
    )
    self.base_lm.kv_cache.fill_caches(kv_cache_tuple)

    # ``fsq_layer`` output is used at audio positions; text positions keep the raw LM output.
    enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
    lm_hidden = enc_outputs[:, -1, :]

    residual_enc_inputs = self.fusion_concat_proj(
        torch.cat((enc_outputs, feat_mask.unsqueeze(-1) * feat_embed), dim=-1)
    )
    residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
        inputs_embeds=residual_enc_inputs,
        is_causal=True,
    )
    self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
    residual_hidden = residual_enc_outputs[:, -1, :]

    for i in tqdm(range(max_len)):
        dit_hidden_1 = self.lm_to_dit_proj(lm_hidden)  # [b, h_dit]
        dit_hidden_2 = self.res_to_dit_proj(residual_hidden)  # [b, h_dit]
        dit_hidden = torch.cat((dit_hidden_1, dit_hidden_2), dim=-1)

        # Predict the next patch, conditioned on the previous one.
        pred_feat = self.feat_decoder(
            mu=dit_hidden,
            patch_size=self.patch_size,
            cond=prefix_feat_cond.transpose(1, 2).contiguous(),
            n_timesteps=inference_timesteps,
            cfg_value=cfg_value,
        ).transpose(
            1, 2
        )  # [b, p, d]

        curr_embed = self.feat_encoder(pred_feat.unsqueeze(1))  # b, 1, c
        curr_embed = self.enc_to_lm_proj(curr_embed)

        pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
        prefix_feat_cond = pred_feat

        if streaming:
            pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
            feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
            yield feat_pred, pred_feat_seq, context_len
            # Keep only the sliding window. Rebinding (rather than mutating) leaves
            # the list that was just yielded untouched.
            if len(pred_feat_seq) > streaming_prefix_len:
                pred_feat_seq = pred_feat_seq[-streaming_prefix_len:]

        stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
        if i > min_len and stop_flag == 1:
            break

        # Advance both LMs by one step using the embedding of the new patch.
        lm_hidden = self.base_lm.forward_step(
            curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
        ).clone()
        lm_hidden = self.fsq_layer(lm_hidden)
        curr_residual_input = self.fusion_concat_proj(torch.cat((lm_hidden, curr_embed[:, 0, :]), dim=-1))
        residual_hidden = self.residual_lm.forward_step(
            curr_residual_input, torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
        ).clone()

    if not streaming:
        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
        # Drop the prompt context patches so only newly generated patches are returned.
        generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu()
        yield feat_pred, generated_feat, context_len
@classmethod
def from_local(
    cls,
    path: str,
    optimize: bool = True,
    training: bool = False,
    device: str | None = None,
    lora_config: Optional[LoRAConfig] = None,
):
    """Build a model from a local checkpoint directory.

    Reads ``config.json`` and the tokenizer from ``path``, loads the AudioVAE weights
    (``audiovae.safetensors`` preferred, else ``audiovae.pth``) and the model weights
    (``model.safetensors`` preferred, else ``pytorch_model.bin``). The safetensors
    variants are used only when the ``safetensors`` package is importable.

    Args:
        path: Checkpoint directory.
        optimize: In inference mode, forwarded as ``optimize(disable=not optimize)``.
        training: If True, skip the LM dtype cast, ``eval()`` and ``optimize()``, and
            freeze the AudioVAE weights (plus all non-LoRA weights when ``lora_config``
            is given). The AudioVAE is cast to float32 in both modes.
        device: Forwarded to the model constructor.
        lora_config: Optional LoRA configuration forwarded to the constructor.

    Returns:
        The loaded model, in ``eval()`` mode unless ``training`` is True.

    Raises:
        FileNotFoundError: If no usable AudioVAE or model checkpoint is found. This
            includes the case where only a safetensors file exists but the
            ``safetensors`` package is not installed.
    """

    def _safetensors_hint(safetensors_file: str) -> str:
        # An existing safetensors file is skipped when the package is missing;
        # say so instead of reporting a bare "not found".
        if os.path.exists(safetensors_file) and not SAFETENSORS_AVAILABLE:
            return f" ({safetensors_file} exists, but the 'safetensors' package is not installed)"
        return ""

    with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as _cfg_f:
        config = VoxCPMConfig.model_validate_json(_cfg_f.read())
    tokenizer = LlamaTokenizerFast.from_pretrained(path)

    audio_vae_config = getattr(config, "audio_vae_config", None)
    audio_vae = AudioVAEV2(config=audio_vae_config) if audio_vae_config else AudioVAEV2()

    # Try to load AudioVAE from safetensors first, fallback to pytorch
    audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
    audiovae_pth_path = os.path.join(path, "audiovae.pth")
    if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
        print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
        vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
    elif os.path.exists(audiovae_pth_path):
        print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
        checkpoint = torch.load(
            audiovae_pth_path,
            map_location="cpu",
            weights_only=True,
        )
        vae_state_dict = checkpoint.get("state_dict", checkpoint)
    else:
        raise FileNotFoundError(
            f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
            + _safetensors_hint(audiovae_safetensors_path)
        )

    model = cls(config, tokenizer, audio_vae, lora_config, device=device)
    if not training:
        lm_dtype = get_dtype(model.config.dtype)
        model = model.to(lm_dtype)
    else:  # training mode
        for name, param in model.named_parameters():
            if "audio_vae" in name:  # freeze VAE weights
                param.requires_grad = False
                continue
            if lora_config is not None and "lora" not in name:  # freeze non-LoRA weights
                param.requires_grad = False
    # The AudioVAE is kept in float32 regardless of the LM dtype.
    model.audio_vae = model.audio_vae.to(torch.float32)

    # Try to load from safetensors first, fallback to pytorch_model.bin
    safetensors_path = os.path.join(path, "model.safetensors")
    pytorch_model_path = os.path.join(path, "pytorch_model.bin")
    if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
        print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
        model_state_dict = load_file(safetensors_path)
    elif os.path.exists(pytorch_model_path):
        print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
        checkpoint = torch.load(
            pytorch_model_path,
            map_location="cpu",
            weights_only=True,
        )
        model_state_dict = checkpoint.get("state_dict", checkpoint)
    else:
        raise FileNotFoundError(
            f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
            + _safetensors_hint(safetensors_path)
        )

    # Merge the AudioVAE weights under the "audio_vae." prefix so one load call covers both.
    for kw, val in vae_state_dict.items():
        model_state_dict[f"audio_vae.{kw}"] = val

    # LoRALinear keeps weight/bias compatible with nn.Linear but adds
    # lora_A/lora_B, which are absent from base pretrained checkpoints.
    # strict=False tolerates that, but it would also hide genuinely missing
    # weights, so report every non-LoRA key the checkpoint did not provide.
    load_result = model.load_state_dict(model_state_dict, strict=False)
    missing_keys = [k for k in getattr(load_result, "missing_keys", []) if "lora_" not in k]
    if missing_keys:
        warnings.warn(
            f"{len(missing_keys)} non-LoRA parameter(s)/buffer(s) were not found in the checkpoint "
            f"and keep their initial values, e.g. {missing_keys[:5]}",
            stacklevel=2,
        )

    if training:
        return model
    return model.to(model.device).eval().optimize(disable=not optimize)
# ------------------------------------------------------------------ #
# LoRA Weight Management
# ------------------------------------------------------------------ #
def _iter_lora_modules(self):
    """Yield each ``LoRALinear`` instance found by walking ``self.modules()``.

    ``LoRALinear`` is imported lazily, on the first ``next()``.
    """
    from ..modules.layers.lora import LoRALinear

    yield from (submodule for submodule in self.modules() if isinstance(submodule, LoRALinear))
def load_lora_weights(self, lora_path: str, device: str = None):
    """
    Load LoRA weights from file, supports calling after torch.compile.

    Uses named_parameters() to handle compile's _orig_mod wrapper.
    Checkpoint keys are matched with or without the ``._orig_mod.`` segment,
    so weights saved from a compiled model also load into an uncompiled one
    and vice versa. Supports both safetensors and pytorch formats.

    Args:
        lora_path: Checkpoint path. Either a directory containing
            ``lora_weights.safetensors`` / ``lora_weights.ckpt``, or a
            ``.safetensors`` / ``.ckpt`` / ``.pth`` file.
        device: Target device, defaults to model's current device

    Returns:
        tuple: (loaded_keys, skipped_keys). Both hold checkpoint keys; skipped keys
        matched no model parameter.

    Raises:
        FileNotFoundError: If no supported checkpoint file exists at ``lora_path``.
    """
    from pathlib import Path

    device = device or self.device
    lora_p = Path(lora_path)

    # Try safetensors first, then fallback to .ckpt
    if lora_p.is_dir():
        safetensors_file = lora_p / "lora_weights.safetensors"
        ckpt_file = lora_p / "lora_weights.ckpt"
    else:
        safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
        ckpt_file = lora_p if lora_p.suffix in (".ckpt", ".pth") else None

    if safetensors_file is not None and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
        state_dict = load_file(str(safetensors_file), device=device)
    elif ckpt_file is not None and ckpt_file.exists():
        # NOTE(security): weights_only=False unpickles arbitrary objects and can execute
        # code. Only load .ckpt/.pth LoRA files from trusted sources.
        ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
        state_dict = ckpt.get("state_dict", ckpt)
    elif safetensors_file is None and ckpt_file is None:
        raise FileNotFoundError(
            f"LoRA checkpoint not found or unsupported: {lora_p} (expected a directory containing "
            f"lora_weights.safetensors/lora_weights.ckpt, or a .safetensors/.ckpt/.pth file)"
        )
    else:
        raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")

    # torch.compile exposes wrapped submodules as "<prefix>._orig_mod.<name>".
    # Map the plain names to the compiled names so either checkpoint style matches.
    model_params = dict(self.named_parameters())
    key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}

    loaded_keys, skipped_keys = [], []
    for key, value in state_dict.items():
        if key in model_params:
            target_key = key
        else:
            plain_key = key.replace("._orig_mod.", ".")
            target_key = plain_key if plain_key in model_params else key_mapping.get(plain_key)

        if target_key:
            model_params[target_key].data.copy_(value.to(device))
            loaded_keys.append(key)
        else:
            skipped_keys.append(key)
    return loaded_keys, skipped_keys
def set_lora_enabled(self, enabled: bool):
    """Switch every LoRA layer on or off.

    Args:
        enabled: Flag passed to each layer's ``set_enabled``.
    """
    lora_layers = self._iter_lora_modules()
    for layer in lora_layers:
        layer.set_enabled(enabled)
def reset_lora_weights(self):
    """Re-initialise every LoRA layer via its ``reset_lora_parameters()``.

    Intended to effectively unload LoRA (A: kaiming, B: zeros). The actual
    initialisation lives in the LoRA layer class; verify it there.
    """
    lora_layers = self._iter_lora_modules()
    for layer in lora_layers:
        layer.reset_lora_parameters()
def get_lora_state_dict(self) -> dict:
    """Snapshot all LoRA parameters.

    Selects every parameter whose qualified name contains ``"lora_"`` (e.g.
    ``lora_A`` / ``lora_B``) and stores a clone of its ``.data``. The snapshot
    therefore does not share storage with the live parameters.
    """
    lora_params = {}
    for param_name, tensor in self.named_parameters():
        if "lora_" in param_name:
            lora_params[param_name] = tensor.data.clone()
    return lora_params