Update: VoxCPM1.5 and fine-tuning support
This commit is contained in:
+324
-61
@@ -19,19 +19,27 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Tuple, Union, Generator, List
|
||||
from typing import Tuple, Union, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import warnings
|
||||
from einops import rearrange
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from ..modules.audiovae import AudioVAE, AudioVAEConfig
|
||||
from ..modules.layers import ScalarQuantizationLayer
|
||||
from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
@@ -66,10 +74,31 @@ class VoxCPMConfig(BaseModel):
|
||||
|
||||
encoder_config: VoxCPMEncoderConfig
|
||||
dit_config: VoxCPMDitConfig
|
||||
audio_vae_config: Optional[AudioVAEConfig] = None
|
||||
|
||||
max_length: int = 4096
|
||||
device: str = "cuda"
|
||||
dtype: str = "bfloat16"
|
||||
dit_mean_mode: bool = False
|
||||
|
||||
|
||||
class LoRAConfig(BaseModel):
    """Settings that decide where LoRA adapters are injected and how they are sized."""

    # Per-component switches.
    enable_lm: bool = False  # adapt base_lm and residual_lm
    enable_dit: bool = False  # adapt the VoxCPMLocDiT estimator
    enable_proj: bool = False  # adapt the projection Linear layers

    # Adapter hyper-parameters; passed to the LoRA helpers as r / alpha / dropout.
    r: int = 8
    alpha: int = 16
    dropout: float = 0.0

    # Attribute names of the linear sub-modules to adapt inside the LMs and the DiT.
    target_modules_lm: List[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules_dit: List[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
    # Attribute names of projection layers, looked up directly on VoxCPMModel.
    target_proj_modules: List[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
|
||||
|
||||
|
||||
# Rebuild VoxCPMConfig's pydantic schema once the config classes above are defined.
# NOTE(review): presumably required so the Optional[AudioVAEConfig] field resolves — confirm.
VoxCPMConfig.model_rebuild()
|
||||
|
||||
|
||||
class VoxCPMModel(nn.Module):
|
||||
@@ -78,9 +107,11 @@ class VoxCPMModel(nn.Module):
|
||||
config: VoxCPMConfig,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAE,
|
||||
lora_config: LoRAConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
@@ -128,6 +159,7 @@ class VoxCPMModel(nn.Module):
|
||||
in_channels=config.feat_dim,
|
||||
cfm_params=config.dit_config.cfm_config,
|
||||
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
|
||||
mean_mode=config.dit_mean_mode,
|
||||
)
|
||||
|
||||
# Projection layers
|
||||
@@ -145,17 +177,46 @@ class VoxCPMModel(nn.Module):
|
||||
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
|
||||
self.stop_actn = nn.SiLU()
|
||||
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
|
||||
self.stop_loss = nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# Audio VAE
|
||||
self.audio_vae = audio_vae
|
||||
self.chunk_size = audio_vae.chunk_size
|
||||
self.sample_rate = audio_vae.sample_rate
|
||||
|
||||
|
||||
if self.lora_config is not None:
|
||||
self._apply_lora()
|
||||
|
||||
def _apply_lora(self):
    """Inject LoRA adapters into the LMs, the DiT estimator and the projection layers.

    The ``enable_*`` flags of ``self.lora_config`` select which parts are
    adapted; ``r``, ``alpha`` and ``dropout`` are shared by every adapter.
    """
    lora_cfg = self.lora_config
    adapter_kwargs = {"r": lora_cfg.r, "alpha": lora_cfg.alpha, "dropout": lora_cfg.dropout}

    # Language models: base_lm and residual_lm.
    if lora_cfg.enable_lm:
        for language_model in [self.base_lm, self.residual_lm]:
            apply_lora_to_named_linear_modules(
                language_model,
                target_submodule_names=lora_cfg.target_modules_lm,
                **adapter_kwargs,
            )

    # Diffusion transformer: feat_decoder.estimator.
    if lora_cfg.enable_dit:
        apply_lora_to_named_linear_modules(
            self.feat_decoder.estimator,
            target_submodule_names=lora_cfg.target_modules_dit,
            **adapter_kwargs,
        )

    # Projection layers that live directly on this module.
    if not lora_cfg.enable_proj:
        return

    from ..modules.layers.lora import LoRALinear

    for proj_name in lora_cfg.target_proj_modules:
        layer = getattr(self, proj_name, None)
        # Names that are missing or are not a plain nn.Linear are left untouched.
        if not isinstance(layer, nn.Linear):
            continue
        setattr(self, proj_name, LoRALinear(base=layer, **adapter_kwargs))
|
||||
|
||||
def optimize(self, disable: bool = False):
|
||||
if disable:
|
||||
return self
|
||||
try:
|
||||
if disable:
|
||||
raise ValueError("Optimization disabled by user")
|
||||
if self.device != "cuda":
|
||||
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
||||
try:
|
||||
@@ -164,17 +225,111 @@ class VoxCPMModel(nn.Module):
|
||||
raise ValueError("triton is not installed")
|
||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
|
||||
self.base_lm.forward_step = self.base_lm.forward_step
|
||||
self.residual_lm.forward_step = self.residual_lm.forward_step
|
||||
self.feat_encoder_step = self.feat_encoder
|
||||
self.feat_decoder.estimator = self.feat_decoder.estimator
|
||||
print(f"Warning: torch.compile disabled - {e}")
|
||||
return self
|
||||
|
||||
def forward(
    self,
    text_tokens: torch.Tensor,
    text_mask: torch.Tensor,
    audio_feats: torch.Tensor,
    audio_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    position_ids: torch.Tensor,
    labels: torch.Tensor,
    *,
    progress: float = 0.0,
    sample_generate: bool = False,
):
    """Training forward pass: compute the diffusion loss and the stop-prediction loss.

    Args:
        text_tokens: token ids; cast to ``torch.long`` and embedded with
            ``base_lm.embed_tokens``.
        text_mask: multiplies the text embeddings (and the un-quantized LM
            outputs). Presumably 1 at text positions — confirm with the data loader.
        audio_feats: 4-D tensor unpacked as ``(B, T, P, D)``.
        audio_mask: multiplies the audio embeddings (and the FSQ-quantized LM
            outputs). Presumably 1 at audio positions — confirm with the data loader.
        loss_mask: per-position weight used for both losses; indexed as ``(b, t)``.
        position_ids: accepted for interface compatibility and discarded.
        labels: class targets for the 2-way stop head; cast to ``torch.long``.
        progress: forwarded unchanged to ``feat_decoder.compute_loss``.
        sample_generate: when True, additionally run ``feat_decoder`` to
            produce predicted features.

    Returns:
        dict with keys ``"loss/diff"``, ``"loss/stop"``, ``"feat_gt"`` and
        ``"feat_pred"`` (``None`` unless ``sample_generate`` is True).
    """
    del position_ids  # not used yet

    # Move every input to the model device; masks and features take the model dtype.
    text_tokens = text_tokens.to(self.device, dtype=torch.long)
    text_mask = text_mask.to(self.device, dtype=self._dtype())
    audio_feats = audio_feats.to(self.device, dtype=self._dtype())
    audio_mask = audio_mask.to(self.device, dtype=self._dtype())
    loss_mask = loss_mask.to(self.device, dtype=self._dtype())
    labels = labels.to(self.device, dtype=torch.long)

    # Only B is used below; T, P and D are unpacked but otherwise unused.
    B, T, P, D = audio_feats.shape
    feat_embed = self.feat_encoder(audio_feats)
    feat_embed = self.enc_to_lm_proj(feat_embed)

    # scale_emb is applied to the text embeddings only when lm_config.use_mup is truthy.
    scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
    if not getattr(self.config.lm_config, "use_mup", False):
        scale_emb = 1.0
    text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
    # Mix text and audio embeddings position-wise using the two masks.
    combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed

    enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
    enc_outputs = enc_outputs.to(self._dtype())
    # fsq_layer output is used where audio_mask is set, the raw output where text_mask is set.
    enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
    # Shift right by one step along dim 1 (zeros at step 0) so step t sees the output of step t-1.
    lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)

    residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
    residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
    residual_outputs = residual_outputs.to(self._dtype())
    # Same one-step right shift as lm_hidden.
    residual_hidden = torch.cat(
        (torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
        dim=1,
    )

    # Conditioning for the diffusion decoder, flattened to one row per (batch, step).
    dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
    dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")

    # Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
    target_dtype = self._dtype()

    feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
    # The condition for step t is the ground-truth patch of step t-1 (zeros for step 0).
    feat_cond = torch.cat(
        (torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
        dim=1,
    )
    feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")

    # Repeat the per-step loss mask over the patch dimension.
    loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
    loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)

    # Dims 1 and 2 are swapped for every tensor handed to the decoder.
    diff_loss = self.feat_decoder.compute_loss(
        feat_gt.transpose(1, 2).contiguous(),
        dit_hidden,
        cond=feat_cond.transpose(1, 2).contiguous(),
        tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
        progress=progress,
    )

    # Stop prediction: unreduced cross-entropy per position, averaged over loss_mask.
    stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
    stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
    denom = torch.clamp(loss_mask.sum(), min=1.0)  # avoid dividing by zero on an empty mask
    stop_loss = (stop_losses * loss_mask).sum() / denom

    feat_pred = None
    if sample_generate:
        feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
        # NOTE(review): n_timesteps is read from `inference_cfg_rate`, which by its name
        # looks like a guidance rate rather than a step count — confirm this is intended.
        feat_pred_seq = self.feat_decoder(
            mu=dit_hidden,
            patch_size=self.patch_size,
            cond=feat_cond_for_sample,
            n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
            if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
            else 10,
        )
        feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)

    feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)

    return {
        "loss/diff": diff_loss,
        "loss/stop": stop_loss,
        "feat_gt": feat_gt_tensor,
        "feat_pred": feat_pred,
    }
|
||||
|
||||
def _dtype(self):
    """Return the dtype that ``get_dtype`` resolves from ``self.config.dtype``."""
    dtype_name = self.config.dtype
    return get_dtype(dtype_name)
|
||||
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
    """Run ``_generate`` in non-streaming mode and return the first item it yields."""
    result_stream = self._generate(*args, streaming=False, **kwargs)
    return next(result_stream)
|
||||
@@ -238,25 +393,25 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
# 左填充:在音频开头填充,保持有效音频数据在序列末尾
|
||||
padding_size = patch_len - audio.size(1) % patch_len
|
||||
audio = torch.nn.functional.pad(audio, (padding_size, 0))
|
||||
|
||||
# (B, D, T)
|
||||
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
|
||||
|
||||
audio_feat = audio_feat.view(
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
audio_length = audio_feat.size(0)
|
||||
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
@@ -288,7 +443,7 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -314,7 +469,6 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
|
||||
yield decode_audio
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -331,13 +485,11 @@ class VoxCPMModel(nn.Module):
|
||||
prompt_wav_path: prompt audio path (required)
|
||||
|
||||
Returns:
|
||||
prompt_cache: dict with text tokens and audio features
|
||||
prompt_cache: dict with prompt_text (raw text) and audio features.
|
||||
Text tokenization will be done during generation for consistency.
|
||||
"""
|
||||
if not prompt_text or not prompt_wav_path:
|
||||
raise ValueError("prompt_text and prompt_wav_path are required")
|
||||
|
||||
# build text tokens
|
||||
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
|
||||
|
||||
# load audio
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
@@ -350,7 +502,9 @@ class VoxCPMModel(nn.Module):
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
# Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
|
||||
padding_size = patch_len - audio.size(1) % patch_len
|
||||
audio = torch.nn.functional.pad(audio, (padding_size, 0))
|
||||
|
||||
# extract audio features
|
||||
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
|
||||
@@ -360,10 +514,9 @@ class VoxCPMModel(nn.Module):
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0) # (D, T, P)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
# build prompt cache
|
||||
# build prompt cache - only save raw text and audio features
|
||||
prompt_cache = {
|
||||
"text_token": text_token,
|
||||
"prompt_text": prompt_text,
|
||||
"audio_feat": audio_feat,
|
||||
}
|
||||
|
||||
@@ -373,7 +526,7 @@ class VoxCPMModel(nn.Module):
|
||||
def merge_prompt_cache(
|
||||
self,
|
||||
original_cache: dict,
|
||||
new_text_token: torch.Tensor,
|
||||
new_text: str,
|
||||
new_audio_feat: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
@@ -381,38 +534,42 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
Args:
|
||||
original_cache: original prompt cache
|
||||
new_text_token: newly generated text tokens
|
||||
new_text: newly generated text
|
||||
new_audio_feat: newly generated audio features
|
||||
|
||||
Returns:
|
||||
merged_cache: merged cache
|
||||
merged_cache: merged cache with prompt_text and audio_feat
|
||||
"""
|
||||
if original_cache is None:
|
||||
return {
|
||||
"text_token": new_text_token,
|
||||
"prompt_text": new_text,
|
||||
"audio_feat": new_audio_feat,
|
||||
}
|
||||
original_text_token = original_cache["text_token"]
|
||||
original_prompt_text = original_cache["prompt_text"]
|
||||
original_audio_feat = original_cache["audio_feat"]
|
||||
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
|
||||
# Merge text by concatenation
|
||||
merged_prompt_text = original_prompt_text + new_text
|
||||
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
|
||||
|
||||
# build new cache
|
||||
merged_cache = {
|
||||
"text_token": merged_text_token,
|
||||
"prompt_text": merged_prompt_text,
|
||||
"audio_feat": merged_audio_feat,
|
||||
}
|
||||
|
||||
return merged_cache
|
||||
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run ``_generate_with_prompt_cache`` in non-streaming mode and return its first item."""
    result_stream = self._generate_with_prompt_cache(*args, streaming=False, **kwargs)
    return next(result_stream)
|
||||
|
||||
|
||||
def generate_with_prompt_cache_streaming(
    self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
    """Return the generator produced by ``_generate_with_prompt_cache`` in streaming mode."""
    chunk_stream = self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
    return chunk_stream
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _generate_with_prompt_cache(
|
||||
self,
|
||||
@@ -453,14 +610,14 @@ class VoxCPMModel(nn.Module):
|
||||
retry_badcase = False
|
||||
# get prompt from cache
|
||||
if prompt_cache is None:
|
||||
prompt_text_token = torch.empty(0, dtype=torch.int32)
|
||||
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
|
||||
text = target_text
|
||||
else:
|
||||
prompt_text_token = prompt_cache["text_token"]
|
||||
prompt_audio_feat = prompt_cache["audio_feat"]
|
||||
# build target text tokens
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
|
||||
prompt_text = prompt_cache["prompt_text"]
|
||||
text = prompt_text + target_text
|
||||
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
text_token = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
@@ -472,6 +629,8 @@ class VoxCPMModel(nn.Module):
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
|
||||
audio_length = prompt_audio_feat.size(0)
|
||||
text_length = text_token.shape[0]
|
||||
@@ -501,7 +660,7 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -530,7 +689,6 @@ class VoxCPMModel(nn.Module):
|
||||
break
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
|
||||
|
||||
yield (
|
||||
decode_audio,
|
||||
@@ -556,6 +714,7 @@ class VoxCPMModel(nn.Module):
|
||||
inference_timesteps: int = 10,
|
||||
cfg_value: float = 2.0,
|
||||
streaming: bool = False,
|
||||
streaming_prefix_len: int = 3,
|
||||
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
||||
"""Core inference method for audio generation.
|
||||
|
||||
@@ -628,7 +787,7 @@ class VoxCPMModel(nn.Module):
|
||||
1, 2
|
||||
) # [b, p, d]
|
||||
|
||||
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.enc_to_lm_proj(curr_embed)
|
||||
|
||||
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
|
||||
@@ -636,8 +795,9 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
if streaming:
|
||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
yield feat_pred, pred_feat_seq
|
||||
|
||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||
@@ -656,35 +816,138 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
if not streaming:
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
|
||||
@classmethod
def from_local(
    cls,
    path: str,
    optimize: bool = True,
    training: bool = False,
    lora_config: Optional[LoRAConfig] = None,
):
    """Build a model from a local checkpoint directory.

    Args:
        path: directory holding ``config.json``, the tokenizer files,
            ``audiovae.pth`` and either ``model.safetensors`` or
            ``pytorch_model.bin``.
        optimize: in inference mode, whether to attempt ``optimize()``.
        training: when True the model is returned without dtype cast,
            ``eval()`` or ``optimize()``; audio VAE parameters are frozen and,
            if ``lora_config`` is given, so is every parameter whose name does
            not contain ``"lora"``.
        lora_config: optional LoRA settings passed to the constructor.

    Returns:
        The loaded model.

    Raises:
        FileNotFoundError: if neither supported weight file can be used.
    """
    # Read the config through a context manager so the file handle is closed.
    with open(os.path.join(path, "config.json"), encoding="utf-8") as config_file:
        config = VoxCPMConfig.model_validate_json(config_file.read())

    tokenizer = LlamaTokenizerFast.from_pretrained(path)

    # Use the VAE config from config.json when present, otherwise the AudioVAE defaults.
    audio_vae_config = getattr(config, "audio_vae_config", None)
    audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
    vae_state_dict = torch.load(
        os.path.join(path, "audiovae.pth"),
        map_location="cpu",
        weights_only=True,
    )["state_dict"]

    model = cls(config, tokenizer, audio_vae, lora_config)
    if not training:
        lm_dtype = get_dtype(model.config.dtype)
        model = model.to(lm_dtype)
    else:  # training mode
        for name, param in model.named_parameters():
            if "audio_vae" in name:  # freeze VAE weights
                param.requires_grad = False
                continue
            if lora_config is not None and "lora" not in name:  # freeze non-LoRA weights
                param.requires_grad = False
    # The audio VAE is kept in float32 in both modes.
    model.audio_vae = model.audio_vae.to(torch.float32)

    # Try to load from safetensors first, fall back to pytorch_model.bin.
    safetensors_path = os.path.join(path, "model.safetensors")
    pytorch_model_path = os.path.join(path, "pytorch_model.bin")

    if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
        print(f"Loading model from safetensors: {safetensors_path}")
        model_state_dict = load_file(safetensors_path)
    elif os.path.exists(pytorch_model_path):
        print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
        checkpoint = torch.load(
            pytorch_model_path,
            map_location="cpu",
            weights_only=True,
        )
        # Accept both a wrapped {"state_dict": ...} checkpoint and a bare state dict.
        model_state_dict = checkpoint.get("state_dict", checkpoint)
    else:
        raise FileNotFoundError(
            f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
        )

    # Merge the VAE weights under the "audio_vae." prefix used by this module.
    for kw, val in vae_state_dict.items():
        model_state_dict[f"audio_vae.{kw}"] = val

    # LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
    # Using strict=False since pretrained weights don't contain lora_A/lora_B.
    model.load_state_dict(model_state_dict, strict=False)
    if training:
        return model
    return model.to(model.device).eval().optimize(disable=not optimize)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# LoRA Weight Management
|
||||
# ------------------------------------------------------------------ #
|
||||
def _iter_lora_modules(self):
    """Yield every ``LoRALinear`` sub-module returned by ``self.modules()``."""
    from ..modules.layers.lora import LoRALinear

    yield from (candidate for candidate in self.modules() if isinstance(candidate, LoRALinear))
|
||||
|
||||
def load_lora_weights(self, lora_path: str, device: str = None):
    """Copy LoRA weights from a checkpoint into this model's parameters.

    Works after ``torch.compile``: parameter names are matched with the
    ``._orig_mod.`` segment removed on both the model and the checkpoint side,
    so a checkpoint saved from a compiled model also loads into an uncompiled
    one and vice versa.

    Args:
        lora_path: a directory containing ``lora_weights.safetensors`` or
            ``lora_weights.ckpt``, or a single ``.safetensors`` / ``.ckpt`` /
            ``.pth`` file.
        device: target device; defaults to ``self.device``.

    Returns:
        tuple: ``(loaded_keys, skipped_keys)``, both lists of checkpoint keys.

    Raises:
        FileNotFoundError: if no usable checkpoint file is found.
    """
    from pathlib import Path

    device = device or self.device
    lora_path = Path(lora_path)

    # Try safetensors first, then fall back to the pickled checkpoint.
    if lora_path.is_dir():
        safetensors_file = lora_path / "lora_weights.safetensors"
        ckpt_file = lora_path / "lora_weights.ckpt"
    else:
        safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
        ckpt_file = lora_path if lora_path.suffix in (".ckpt", ".pth") else None

    if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
        state_dict = load_file(str(safetensors_file), device=device)
    elif ckpt_file and ckpt_file.exists():
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only load checkpoints from a trusted source.
        ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
        state_dict = ckpt.get("state_dict", ckpt)
    else:
        raise FileNotFoundError(
            f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
        )

    def _canonical(name):
        # Drop the wrapper segment torch.compile inserts into parameter names.
        return name.replace("._orig_mod.", ".")

    model_params = dict(self.named_parameters())
    # Canonical name -> actual parameter name, for compiled and plain parameters alike.
    key_mapping = {_canonical(param_name): param_name for param_name in model_params}

    loaded_keys, skipped_keys = [], []
    for key, value in state_dict.items():
        # An exact match wins; otherwise compare canonical names.
        target_key = key if key in model_params else key_mapping.get(_canonical(key))
        if target_key is None:
            skipped_keys.append(key)
            continue
        model_params[target_key].data.copy_(value.to(device))
        loaded_keys.append(key)

    return loaded_keys, skipped_keys
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
    """Switch every LoRA layer on or off by calling its ``set_enabled(enabled)``."""
    lora_layers = self._iter_lora_modules()
    for layer in lora_layers:
        layer.set_enabled(enabled)
|
||||
|
||||
def reset_lora_weights(self):
    """Re-initialise every LoRA layer by calling its ``reset_lora_parameters()``.

    NOTE(review): the reset is expected to re-draw A (Kaiming) and zero B,
    which effectively unloads the adapter — the layer implementation is not
    visible here, confirm against ``LoRALinear``.
    """
    lora_layers = self._iter_lora_modules()
    for layer in lora_layers:
        layer.reset_lora_parameters()
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
    """Return cloned copies of every parameter whose name contains ``"lora_"``."""
    lora_state = {}
    for param_name, param in self.named_parameters():
        if "lora_" not in param_name:
            continue
        lora_state[param_name] = param.data.clone()
    return lora_state
|
||||
|
||||
Reference in New Issue
Block a user