update voxcpm2

This commit is contained in:
刘鑫
2026-03-31 11:50:37 +08:00
parent 23ed7ffeee
commit d9cf376e16
36 changed files with 8163 additions and 834 deletions
+128 -115
View File
@@ -24,7 +24,6 @@ from typing import Tuple, Union, Generator, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import warnings
from einops import rearrange
@@ -32,6 +31,7 @@ from pydantic import BaseModel
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
@@ -84,9 +84,9 @@ class VoxCPMConfig(BaseModel):
class LoRAConfig(BaseModel):
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
r: int = 8
alpha: int = 16
@@ -165,10 +165,10 @@ class VoxCPMModel(nn.Module):
# Projection layers
self.fsq_layer = ScalarQuantizationLayer(
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale,
)
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
@@ -196,9 +196,7 @@ class VoxCPMModel(nn.Module):
# LM: base_lm + residual_lm
if cfg.enable_lm:
for lm in [self.base_lm, self.residual_lm]:
apply_lora_to_named_linear_modules(
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
)
apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)
# DiT: feat_decoder.estimator
if cfg.enable_dit:
@@ -209,6 +207,7 @@ class VoxCPMModel(nn.Module):
# 投影层
if cfg.enable_proj:
from ..modules.layers.lora import LoRALinear
for attr_name in cfg.target_proj_modules:
module = getattr(self, attr_name, None)
if isinstance(module, nn.Linear):
@@ -221,13 +220,17 @@ class VoxCPMModel(nn.Module):
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
import triton
import triton # noqa: F401
except ImportError:
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
)
except Exception as e:
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
return self
@@ -313,9 +316,11 @@ class VoxCPMModel(nn.Module):
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10,
n_timesteps=(
self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10
),
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
@@ -331,7 +336,6 @@ class VoxCPMModel(nn.Module):
def _dtype(self):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
@@ -350,7 +354,7 @@ class VoxCPMModel(nn.Module):
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
streaming: bool = False,
) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming:
@@ -394,7 +398,7 @@ class VoxCPMModel(nn.Module):
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
@@ -435,7 +439,7 @@ class VoxCPMModel(nn.Module):
audio_mask = audio_mask.unsqueeze(0).to(self.device)
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times:
inference_result = self._inference(
@@ -444,7 +448,9 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -460,18 +466,21 @@ class VoxCPMModel(nn.Module):
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1
continue
else:
break
else:
break
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
yield decode_audio
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
yield decode_audio
@torch.inference_mode()
def build_prompt_cache(
self,
@@ -480,11 +489,11 @@ class VoxCPMModel(nn.Module):
):
"""
Build prompt cache for subsequent fast generation.
Args:
prompt_text: prompt text (required)
prompt_wav_path: prompt audio path (required)
Returns:
prompt_cache: dict with prompt_text (raw text) and audio features.
Text tokenization will be done during generation for consistency.
@@ -496,7 +505,7 @@ class VoxCPMModel(nn.Module):
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
@@ -514,16 +523,17 @@ class VoxCPMModel(nn.Module):
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0) # (D, T, P)
).permute(
1, 2, 0
) # (D, T, P)
# build prompt cache - only save raw text and audio features
prompt_cache = {
"prompt_text": prompt_text,
"audio_feat": audio_feat,
}
return prompt_cache
def merge_prompt_cache(
self,
original_cache: dict,
@@ -532,12 +542,12 @@ class VoxCPMModel(nn.Module):
):
"""
Merge original prompt cache with newly generated content to stabilize voice.
Args:
original_cache: original prompt cache
new_text: newly generated text
new_text: newly generated text
new_audio_feat: newly generated audio features
Returns:
merged_cache: merged cache with prompt_text and audio_feat
"""
@@ -557,20 +567,17 @@ class VoxCPMModel(nn.Module):
"prompt_text": merged_prompt_text,
"audio_feat": merged_audio_feat,
}
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
@torch.inference_mode()
def _generate_with_prompt_cache(
self,
@@ -588,7 +595,7 @@ class VoxCPMModel(nn.Module):
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""
Generate audio using pre-built prompt cache.
Args:
target_text: Text to convert to speech
prompt_cache: Cache built by build_prompt_cache (can be None)
@@ -601,7 +608,7 @@ class VoxCPMModel(nn.Module):
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
streaming: Whether to return a generator of audio chunks
streaming_prefix_len: Number of prefix audio patches to use for streaming mode
Returns:
Generator of Tuple containing:
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
@@ -619,7 +626,7 @@ class VoxCPMModel(nn.Module):
prompt_audio_feat = prompt_cache["audio_feat"]
prompt_text = prompt_cache["prompt_text"]
text = prompt_text + target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat(
[
@@ -632,7 +639,7 @@ class VoxCPMModel(nn.Module):
],
dim=-1,
)
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
audio_length = prompt_audio_feat.size(0)
@@ -645,14 +652,18 @@ class VoxCPMModel(nn.Module):
)
text_token = torch.cat([text_token, text_pad_token])
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
)
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
audio_mask = audio_mask.unsqueeze(0).to(self.device)
# run inference
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
@@ -663,7 +674,9 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -674,17 +687,16 @@ class VoxCPMModel(nn.Module):
for latent_pred, pred_audio_feat in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1
continue
else:
@@ -695,18 +707,14 @@ class VoxCPMModel(nn.Module):
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size
if audio_mask.sum().item() > 0:
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
else:
decode_audio = decode_audio[..., :].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
@@ -725,10 +733,10 @@ class VoxCPMModel(nn.Module):
streaming_prefix_len: int = 3,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""Core inference method for audio generation.
This is the main inference loop that generates audio features
using the language model and diffusion transformer.
Args:
text: Input text tokens
text_mask: Mask for text tokens
@@ -739,7 +747,7 @@ class VoxCPMModel(nn.Module):
inference_timesteps: Number of diffusion steps
cfg_value: Classifier-free guidance value
streaming: Whether to yield each step latent feature or just the final result
Returns:
Generator of Tuple containing:
- Predicted latent feature at the current step if ``streaming=True``, else final latent features
@@ -749,12 +757,12 @@ class VoxCPMModel(nn.Module):
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup:
scale_emb = self.config.lm_config.scale_emb
else:
scale_emb = 1.0
text_embed = self.base_lm.embed_tokens(text) * scale_emb
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
@@ -778,11 +786,10 @@ class VoxCPMModel(nn.Module):
is_causal=True,
)
self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = enc_outputs[:, -1, :]
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
is_causal=True,
@@ -790,7 +797,6 @@ class VoxCPMModel(nn.Module):
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
residual_hidden = residual_enc_outputs[:, -1, :]
for i in tqdm(range(max_len)):
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
@@ -805,10 +811,10 @@ class VoxCPMModel(nn.Module):
).transpose(
1, 2
) # [b, p, d]
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.enc_to_lm_proj(curr_embed)
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
prefix_feat_cond = pred_feat
@@ -816,58 +822,70 @@ class VoxCPMModel(nn.Module):
# return the last three predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1:
break
lm_hidden = self.base_lm.forward_step(
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
).clone()
lm_hidden = self.fsq_layer(lm_hidden)
residual_hidden = self.residual_lm.forward_step(
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
lm_hidden + curr_embed[:, 0, :],
torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
).clone()
if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, 'audio_vae_config', None)
audio_vae_config = getattr(config, "audio_vae_config", None)
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
vae_state_dict = torch.load(
os.path.join(path, "audiovae.pth"),
map_location="cpu",
weights_only=True,
)["state_dict"]
# Try to load AudioVAE from safetensors first, fallback to pytorch
audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
audiovae_pth_path = os.path.join(path, "audiovae.pth")
if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
elif os.path.exists(audiovae_pth_path):
print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
checkpoint = torch.load(
audiovae_pth_path,
map_location="cpu",
weights_only=True,
)
vae_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
else: # training mode
else: # training mode
for name, param in model.named_parameters():
if "audio_vae" in name: # freeze VAE weights
if "audio_vae" in name: # freeze VAE weights
param.requires_grad = False
continue
if lora_config is not None:
if "lora" not in name: # freeze non-LoRA weights
if "lora" not in name: # freeze non-LoRA weights
param.requires_grad = False
model.audio_vae = model.audio_vae.to(torch.float32)
# Try to load from safetensors first, fallback to pytorch_model.bin
safetensors_path = os.path.join(path, "model.safetensors")
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
model_state_dict = load_file(safetensors_path)
@@ -880,13 +898,11 @@ class VoxCPMModel(nn.Module):
)
model_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
)
raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val
# LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
model.load_state_dict(model_state_dict, strict=False)
@@ -900,6 +916,7 @@ class VoxCPMModel(nn.Module):
def _iter_lora_modules(self):
"""Iterate over all LoRA modules."""
from ..modules.layers.lora import LoRALinear
for module in self.modules():
if isinstance(module, LoRALinear):
yield module
@@ -909,7 +926,7 @@ class VoxCPMModel(nn.Module):
Load LoRA weights from file, supports calling after torch.compile.
Uses named_parameters() to handle compile's _orig_mod wrapper.
Supports both safetensors and pytorch formats.
Args:
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
device: Target device, defaults to model's current device
@@ -917,18 +934,18 @@ class VoxCPMModel(nn.Module):
tuple: (loaded_keys, skipped_keys)
"""
from pathlib import Path
device = device or self.device
lora_path = Path(lora_path)
lora_p = Path(lora_path)
# Try safetensors first, then fallback to .ckpt
if lora_path.is_dir():
safetensors_file = lora_path / "lora_weights.safetensors"
ckpt_file = lora_path / "lora_weights.ckpt"
if lora_p.is_dir():
safetensors_file = lora_p / "lora_weights.safetensors"
ckpt_file = lora_p / "lora_weights.ckpt"
else:
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
# Load from safetensors if available
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
@@ -936,14 +953,12 @@ class VoxCPMModel(nn.Module):
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
)
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
# Build param mapping (handle torch.compile's _orig_mod prefix)
model_params = dict(self.named_parameters())
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
loaded_keys, skipped_keys = [], []
for key, value in state_dict.items():
target_key = key if key in model_params else key_mapping.get(key)
@@ -952,7 +967,7 @@ class VoxCPMModel(nn.Module):
loaded_keys.append(key)
else:
skipped_keys.append(key)
return loaded_keys, skipped_keys
def set_lora_enabled(self, enabled: bool):
@@ -967,6 +982,4 @@ class VoxCPMModel(nn.Module):
def get_lora_state_dict(self) -> dict:
"""Get all LoRA parameters (lora_A/lora_B)."""
return {name: param.data.clone()
for name, param in self.named_parameters()
if "lora_" in name}
return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}