Update: VoxCPM1.5 and fine-tuning supprt
This commit is contained in:
@@ -1 +1 @@
|
||||
from .audio_vae import AudioVAE
|
||||
from .audio_vae import AudioVAE, AudioVAEConfig
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
@@ -266,6 +267,17 @@ class CausalDecoder(nn.Module):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class AudioVAEConfig(BaseModel):
|
||||
encoder_dim: int = 128
|
||||
encoder_rates: List[int] = [2, 5, 8, 8]
|
||||
latent_dim: int = 64
|
||||
decoder_dim: int = 1536
|
||||
decoder_rates: List[int] = [8, 8, 5, 2]
|
||||
depthwise: bool = True
|
||||
sample_rate: int = 16000
|
||||
use_noise_block: bool = False
|
||||
|
||||
|
||||
class AudioVAE(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
@@ -273,17 +285,23 @@ class AudioVAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int = 128,
|
||||
encoder_rates: List[int] = [2, 5, 8, 8],
|
||||
latent_dim: int = 64,
|
||||
decoder_dim: int = 1536,
|
||||
decoder_rates: List[int] = [8, 8, 5, 2],
|
||||
depthwise: bool = True,
|
||||
sample_rate: int = 16000,
|
||||
use_noise_block: bool = False,
|
||||
config: Optional[AudioVAEConfig] = None,
|
||||
):
|
||||
# 如果没有传入config,使用默认配置
|
||||
if config is None:
|
||||
config = AudioVAEConfig()
|
||||
|
||||
super().__init__()
|
||||
|
||||
encoder_dim = config.encoder_dim
|
||||
encoder_rates = config.encoder_rates
|
||||
latent_dim = config.latent_dim
|
||||
decoder_dim = config.decoder_dim
|
||||
decoder_rates = config.decoder_rates
|
||||
depthwise = config.depthwise
|
||||
sample_rate = config.sample_rate
|
||||
use_noise_block = config.use_noise_block
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
"""
|
||||
LoRA 线性层:直接持有 weight/bias,保持与 nn.Linear 相同的 state_dict key 结构。
|
||||
|
||||
state_dict 结构:
|
||||
- weight: 原始权重(与 nn.Linear 一致)
|
||||
- bias: 原始偏置(与 nn.Linear 一致)
|
||||
- lora_A: LoRA 低秩矩阵 A
|
||||
- lora_B: LoRA 低秩矩阵 B
|
||||
|
||||
这样设计的好处:加载预训练权重时无需做 key 转换。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base: nn.Linear,
|
||||
r: int,
|
||||
alpha: float = 1.0,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."
|
||||
|
||||
self.in_features = base.in_features
|
||||
self.out_features = base.out_features
|
||||
self.r = r
|
||||
self.alpha = alpha
|
||||
self._base_scaling = alpha / r if r > 0 else 0.0
|
||||
|
||||
# 使用 buffer 存储 scaling,这样修改值不会触发 torch.compile 重编译
|
||||
# persistent=False 表示不保存到 state_dict,避免加载时 missing key
|
||||
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
|
||||
|
||||
# 直接持有 weight 和 bias(从原始 Linear 转移过来)
|
||||
self.weight = base.weight
|
||||
self.bias = base.bias # 可能是 None
|
||||
|
||||
# LoRA 参数
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
|
||||
self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
else:
|
||||
self.register_parameter("lora_A", None)
|
||||
self.register_parameter("lora_B", None)
|
||||
|
||||
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# 基础 Linear 计算
|
||||
result = F.linear(x, self.weight, self.bias)
|
||||
if self.r <= 0 or self.lora_A is None:
|
||||
return result
|
||||
# LoRA: result + dropout(x @ A^T @ B^T) * scaling
|
||||
lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
|
||||
return result + self.dropout(lora_out) * self.scaling
|
||||
|
||||
def reset_lora_parameters(self):
|
||||
"""重置 LoRA 参数到初始状态"""
|
||||
if self.r > 0 and self.lora_A is not None:
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def set_enabled(self, enabled: bool):
|
||||
"""启用/禁用 LoRA(通过 scaling 控制,兼容 torch.compile)"""
|
||||
# 使用 fill_ 原地修改 buffer 值,不会触发重编译
|
||||
self.scaling.fill_(self._base_scaling if enabled else 0.0)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.scaling.item() != 0.0
|
||||
|
||||
|
||||
def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
|
||||
"""
|
||||
根据类似 'layers.0.self_attn.q_proj' 的全名,返回 parent module(即 q_proj 的上一级)。
|
||||
"""
|
||||
parts = name.split(".")
|
||||
if len(parts) == 1:
|
||||
return root
|
||||
parent = root
|
||||
for p in parts[:-1]:
|
||||
if not hasattr(parent, p):
|
||||
return None
|
||||
parent = getattr(parent, p)
|
||||
return parent
|
||||
|
||||
|
||||
def apply_lora_to_named_linear_modules(
|
||||
root: nn.Module,
|
||||
*,
|
||||
target_submodule_names: list[str],
|
||||
r: int,
|
||||
alpha: float,
|
||||
dropout: float,
|
||||
) -> None:
|
||||
"""
|
||||
在给定模块及其子模块中,对名字以 target_submodule_names 结尾的 Linear 层注入 LoRA。
|
||||
|
||||
例如 target_submodule_names=["q_proj", "v_proj"] 时,
|
||||
会在所有名为 *.q_proj / *.v_proj 的 nn.Linear 上替换为 LoRALinear。
|
||||
"""
|
||||
for full_name, module in list(root.named_modules()):
|
||||
if not isinstance(module, nn.Linear):
|
||||
continue
|
||||
short_name = full_name.split(".")[-1]
|
||||
if short_name not in target_submodule_names:
|
||||
continue
|
||||
|
||||
parent = _get_parent_module(root, full_name)
|
||||
if parent is None:
|
||||
continue
|
||||
|
||||
# 用 LoRALinear 替换原始 Linear
|
||||
lora_layer = LoRALinear(
|
||||
base=module,
|
||||
r=r,
|
||||
alpha=alpha,
|
||||
dropout=dropout,
|
||||
)
|
||||
setattr(parent, short_name, lora_layer)
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +1,29 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from typing import List
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
from torch.func import jvp
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
|
||||
|
||||
class CfmConfig(BaseModel):
|
||||
sigma_min: float = 1e-06
|
||||
sigma_min: float = 1e-6
|
||||
solver: str = "euler"
|
||||
t_scheduler: str = "log-norm"
|
||||
training_cfg_rate: float = 0.1
|
||||
inference_cfg_rate: float = 1.0
|
||||
reg_loss_type: str = "l1"
|
||||
ratio_r_neq_t_range: Tuple[float, float] = (0.25, 0.75)
|
||||
noise_cond_prob_range: Tuple[float, float] = (0.0, 0.0)
|
||||
noise_cond_scale: float = 0.0
|
||||
|
||||
|
||||
class UnifiedCFM(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
in_channels: int,
|
||||
cfm_params: CfmConfig,
|
||||
estimator: VoxCPMLocDiT,
|
||||
mean_mode: bool = False,
|
||||
@@ -23,12 +32,21 @@ class UnifiedCFM(torch.nn.Module):
|
||||
self.solver = cfm_params.solver
|
||||
self.sigma_min = cfm_params.sigma_min
|
||||
self.t_scheduler = cfm_params.t_scheduler
|
||||
self.training_cfg_rate = cfm_params.training_cfg_rate
|
||||
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
||||
self.reg_loss_type = cfm_params.reg_loss_type
|
||||
self.ratio_r_neq_t_range = cfm_params.ratio_r_neq_t_range
|
||||
self.noise_cond_prob_range = cfm_params.noise_cond_prob_range
|
||||
self.noise_cond_scale = cfm_params.noise_cond_scale
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mean_mode = mean_mode
|
||||
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Inference
|
||||
# ------------------------------------------------------------------ #
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
@@ -41,33 +59,25 @@ class UnifiedCFM(torch.nn.Module):
|
||||
sway_sampling_coef: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats)
|
||||
n_timesteps (int): number of diffusion steps
|
||||
cond: Not used but kept for future purposes
|
||||
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
b, c = mu.shape
|
||||
b, _ = mu.shape
|
||||
t = patch_size
|
||||
z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
|
||||
|
||||
t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
# Sway sampling strategy
|
||||
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
|
||||
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)
|
||||
return self.solve_euler(
|
||||
x=z,
|
||||
t_span=t_span,
|
||||
mu=mu,
|
||||
cond=cond,
|
||||
cfg_value=cfg_value,
|
||||
use_cfg_zero_star=use_cfg_zero_star,
|
||||
)
|
||||
|
||||
def optimized_scale(self, positive_flat, negative_flat):
|
||||
def optimized_scale(self, positive_flat: torch.Tensor, negative_flat: torch.Tensor):
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
|
||||
squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
|
||||
st_star = dot_product / squared_norm
|
||||
return st_star
|
||||
|
||||
@@ -80,24 +90,13 @@ class UnifiedCFM(torch.nn.Module):
|
||||
cfg_value: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
x (torch.Tensor): random noise
|
||||
t_span (torch.Tensor): n_timesteps interpolated
|
||||
shape: (n_timesteps + 1,)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats)
|
||||
cond: condition -- prefix prompt
|
||||
cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
|
||||
|
||||
sol = []
|
||||
zero_init_steps = max(1, int(len(t_span) * 0.04))
|
||||
for step in range(1, len(t_span)):
|
||||
if use_cfg_zero_star and step <= zero_init_steps:
|
||||
dphi_dt = 0.
|
||||
dphi_dt = torch.zeros_like(x)
|
||||
else:
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
b = x.size(0)
|
||||
@@ -105,7 +104,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
|
||||
dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2 * b, self.in_channels, cond.size(2)], device=x.device, dtype=x.dtype)
|
||||
x_in[:b], x_in[b:] = x, x
|
||||
mu_in[:b] = mu
|
||||
t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
|
||||
@@ -135,3 +134,98 @@ class UnifiedCFM(torch.nn.Module):
|
||||
dt = t - t_span[step + 1]
|
||||
|
||||
return sol[-1]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Training loss
|
||||
# ------------------------------------------------------------------ #
|
||||
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
|
||||
weights = 1.0 / ((losses + epsilon).pow(p))
|
||||
if mask is not None:
|
||||
weights = weights * mask
|
||||
return weights.detach()
|
||||
|
||||
def sample_r_t(self, x: torch.Tensor, mu: float = -0.4, sigma: float = 1.0, ratio_r_neq_t: float = 0.0):
|
||||
batch_size = x.shape[0]
|
||||
if self.t_scheduler == "log-norm":
|
||||
s_r = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
|
||||
s_t = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
|
||||
r = torch.sigmoid(s_r)
|
||||
t = torch.sigmoid(s_t)
|
||||
elif self.t_scheduler == "uniform":
|
||||
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported t_scheduler: {self.t_scheduler}")
|
||||
|
||||
mask = torch.rand(batch_size, device=x.device, dtype=x.dtype) < ratio_r_neq_t
|
||||
r, t = torch.where(
|
||||
mask,
|
||||
torch.stack([torch.min(r, t), torch.max(r, t)], dim=0),
|
||||
torch.stack([t, t], dim=0),
|
||||
)
|
||||
|
||||
return r.squeeze(), t.squeeze()
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
x1: torch.Tensor,
|
||||
mu: torch.Tensor,
|
||||
cond: torch.Tensor | None = None,
|
||||
tgt_mask: torch.Tensor | None = None,
|
||||
progress: float = 0.0,
|
||||
):
|
||||
b, _, _ = x1.shape
|
||||
|
||||
if self.training_cfg_rate > 0:
|
||||
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
||||
mu = mu * cfg_mask.view(-1, 1)
|
||||
|
||||
if cond is None:
|
||||
cond = torch.zeros_like(x1)
|
||||
|
||||
noisy_mask = torch.rand(b, device=x1.device) > (
|
||||
1.0
|
||||
- (
|
||||
self.noise_cond_prob_range[0]
|
||||
+ progress * (self.noise_cond_prob_range[1] - self.noise_cond_prob_range[0])
|
||||
)
|
||||
)
|
||||
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
|
||||
|
||||
ratio_r_neq_t = (
|
||||
self.ratio_r_neq_t_range[0]
|
||||
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||
if self.mean_mode
|
||||
else 0.0
|
||||
)
|
||||
|
||||
r, t = self.sample_r_t(x1, ratio_r_neq_t=ratio_r_neq_t)
|
||||
r_ = r.detach().clone()
|
||||
t_ = t.detach().clone()
|
||||
z = torch.randn_like(x1)
|
||||
y = (1 - t_.view(-1, 1, 1)) * x1 + t_.view(-1, 1, 1) * z
|
||||
v = z - x1
|
||||
|
||||
def model_fn(z_sample, r_sample, t_sample):
|
||||
return self.estimator(z_sample, mu, t_sample, cond, dt=t_sample - r_sample)
|
||||
|
||||
if self.mean_mode:
|
||||
v_r = torch.zeros_like(r)
|
||||
v_t = torch.ones_like(t)
|
||||
from torch.backends.cuda import sdp_kernel
|
||||
|
||||
with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
|
||||
u_pred, dudt = jvp(model_fn, (y, r, t), (v, v_r, v_t))
|
||||
u_tgt = v - (t_ - r_).view(-1, 1, 1) * dudt
|
||||
else:
|
||||
u_pred = model_fn(y, r, t)
|
||||
u_tgt = v
|
||||
|
||||
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
|
||||
if tgt_mask is not None:
|
||||
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
|
||||
loss = (weights * losses).sum() / torch.sum(tgt_mask)
|
||||
else:
|
||||
loss = losses.mean()
|
||||
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user