update voxcpm2

This commit is contained in:
刘鑫
2026-03-31 11:50:37 +08:00
parent 23ed7ffeee
commit d9cf376e16
36 changed files with 8163 additions and 834 deletions
+46 -27
View File
@@ -13,11 +13,11 @@ import soundfile as sf
from voxcpm.core import VoxCPM
# -----------------------------
# Validators
# -----------------------------
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
path = Path(file_path)
if not path.exists():
@@ -53,12 +53,11 @@ def validate_ranges(args, parser):
# Model loading
# -----------------------------
def load_model(args) -> VoxCPM:
print("Loading VoxCPM model...", file=sys.stderr)
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
"ZIPENHANCER_MODEL_PATH", None
)
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get("ZIPENHANCER_MODEL_PATH", None)
# Build LoRA config if provided
lora_config = None
@@ -119,22 +118,29 @@ def load_model(args) -> VoxCPM:
# Commands
# -----------------------------
def cmd_clone(args):
if not args.text:
sys.exit("Error: Please provide --text for synthesis")
if not args.prompt_audio or not args.prompt_text:
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
has_prompt = args.prompt_audio and args.prompt_text
has_ref = args.reference_audio is not None
if not has_prompt and not has_ref:
sys.exit("Error: Voice cloning requires --prompt-audio + --prompt-text, or --reference-audio, or both")
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
if args.prompt_audio:
validate_file_exists(args.prompt_audio, "prompt audio file")
if args.reference_audio:
validate_file_exists(args.reference_audio, "reference audio file")
output_path = validate_output_path(args.output)
model = load_model(args)
audio_array = model.generate(
text=args.text,
prompt_wav_path=str(prompt_audio_path),
prompt_text=args.prompt_text,
prompt_wav_path=args.prompt_audio if has_prompt else None,
prompt_text=args.prompt_text if has_prompt else None,
reference_wav_path=args.reference_audio,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
normalize=args.normalize,
@@ -185,7 +191,11 @@ def cmd_batch(args):
prompt_audio_path = None
if args.prompt_audio:
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "prompt audio file"))
reference_audio_path = None
if args.reference_audio:
reference_audio_path = str(validate_file_exists(args.reference_audio, "reference audio file"))
success_count = 0
@@ -195,10 +205,11 @@ def cmd_batch(args):
text=text,
prompt_wav_path=prompt_audio_path,
prompt_text=args.prompt_text,
reference_wav_path=reference_audio_path,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
normalize=args.normalize,
denoise=args.denoise and prompt_audio_path is not None,
denoise=args.denoise and (prompt_audio_path is not None or reference_audio_path is not None),
)
output_file = output_dir / f"output_{i:03d}.wav"
@@ -218,6 +229,7 @@ def cmd_batch(args):
# Parser
# -----------------------------
def _build_unified_parser():
parser = argparse.ArgumentParser(
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
@@ -236,34 +248,40 @@ Examples:
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
# Prompt
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
# Prompt / Reference
parser.add_argument(
"--prompt-audio", "-pa", help="Prompt audio file path (continuation mode, requires --prompt-text)"
)
parser.add_argument("--prompt-text", "-pt", help="Text corresponding to the prompt audio")
parser.add_argument(
"--reference-audio", "-ra", help="Reference audio for voice cloning (isolated mode, VoxCPM2 only)"
)
parser.add_argument("--denoise", action="store_true", help="Enable prompt/reference speech enhancement")
# Generation parameters
parser.add_argument("--cfg-value", type=float, default=2.0,
help="CFG guidance scale (float, recommended 0.55.0, default: 2.0)")
parser.add_argument("--inference-timesteps", type=int, default=10,
help="Inference steps (int, 1100, default: 10)")
parser.add_argument(
"--cfg-value", type=float, default=2.0, help="CFG guidance scale (float, recommended 0.55.0, default: 2.0)"
)
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (int, 1100, default: 10)")
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
# Model loading
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
parser.add_argument(
"--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (default: openbmb/VoxCPM1.5)"
)
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str,
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
parser.add_argument(
"--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)"
)
# LoRA
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
parser.add_argument("--lora-dropout", type=float, default=0.0,
help="LoRA dropout rate (0.01.0, default: 0.0)")
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (0.01.0, default: 0.0)")
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
@@ -275,6 +293,7 @@ Examples:
# Entrypoint
# -----------------------------
def main():
parser = _build_unified_parser()
args = parser.parse_args()
@@ -296,8 +315,8 @@ def main():
if not args.text or not args.output:
parser.error("Single-sample mode requires --text and --output")
# Clone mode
if args.prompt_audio or args.prompt_text:
# Clone mode (prompt continuation, reference isolation, or both)
if args.prompt_audio or args.prompt_text or args.reference_audio:
return cmd_clone(args)
# Direct synthesis
+151 -100
View File
@@ -1,21 +1,25 @@
import os
import sys
import re
import json
import tempfile
import numpy as np
from typing import Generator, Optional
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel, LoRAConfig
from .model.voxcpm2 import VoxCPM2Model
class VoxCPM:
def __init__(self,
voxcpm_model_path : str,
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser : bool = True,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
def __init__(
self,
voxcpm_model_path: str,
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser: bool = True,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
"""Initialize VoxCPM TTS pipeline.
Args:
@@ -26,13 +30,16 @@ class VoxCPM:
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
"""
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
print(
f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
file=sys.stderr,
)
# If lora_weights_path is provided but no lora_config, create a default one
if lora_weights_path is not None and lora_config is None:
lora_config = LoRAConfig(
@@ -41,18 +48,33 @@ class VoxCPM:
enable_proj=False,
)
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
# Determine model type from config.json architecture field
config_path = os.path.join(voxcpm_model_path, "config.json")
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
arch = config.get("architecture", "voxcpm").lower()
if arch == "voxcpm2":
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
print("Loaded VoxCPM2Model", file=sys.stderr)
elif arch == "voxcpm":
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
print("Loaded VoxCPMModel", file=sys.stderr)
else:
raise ValueError(f"Unsupported architecture: {arch}")
# Load LoRA weights if path is provided
if lora_weights_path is not None:
print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
self.text_normalizer = None
self.denoiser = None
if enable_denoiser and zipenhancer_model_path is not None:
from .zipenhancer import ZipEnhancer
self.denoiser = ZipEnhancer(zipenhancer_model_path)
else:
self.denoiser = None
@@ -64,17 +86,18 @@ class VoxCPM:
)
@classmethod
def from_pretrained(cls,
hf_model_id: str = "openbmb/VoxCPM1.5",
load_denoiser: bool = True,
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
):
def from_pretrained(
cls,
hf_model_id: str = "openbmb/VoxCPM2",
load_denoiser: bool = True,
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
Args:
@@ -86,7 +109,7 @@ class VoxCPM:
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created with
enable_lm=True and enable_dit=True.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
@@ -106,7 +129,7 @@ class VoxCPM:
repo_id = hf_model_id
if not repo_id:
raise ValueError("You must provide hf_model_id")
# Load from local path if provided
if os.path.isdir(repo_id):
local_path = repo_id
@@ -134,118 +157,146 @@ class VoxCPM:
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs)
def _generate(self,
text : str,
prompt_wav_path : str = None,
prompt_text : str = None,
cfg_value : float = 2.0,
inference_timesteps : int = 10,
min_len : int = 2,
max_len : int = 4096,
normalize : bool = False,
denoise : bool = False,
retry_badcase : bool = True,
retry_badcase_max_times : int = 3,
retry_badcase_ratio_threshold : float = 6.0,
streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
def _generate(
self,
text: str,
prompt_wav_path: str = None,
prompt_text: str = None,
reference_wav_path: str = None,
cfg_value: float = 2.0,
inference_timesteps: int = 10,
min_len: int = 2,
max_len: int = 4096,
normalize: bool = False,
denoise: bool = False,
retry_badcase: bool = True,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
"""Synthesize speech for the given text and return a single waveform.
This method optionally builds and reuses a prompt cache. If an external
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
used for all sub-sentences. Otherwise, the prompt cache is built from
the first generated result and reused for the remaining text chunks.
Args:
text: Input text. Can include newlines; each non-empty line is
treated as a sub-sentence.
prompt_wav_path: Path to a reference audio file for prompting.
text: Input text to synthesize.
prompt_wav_path: Path to prompt audio for continuation mode.
Must be paired with ``prompt_text``.
prompt_text: Text content corresponding to the prompt audio.
reference_wav_path: Path to reference audio for voice cloning
(structurally isolated via ref_audio tokens). Can be used
alone or combined with ``prompt_wav_path`` + ``prompt_text``.
cfg_value: Guidance scale for the generation model.
inference_timesteps: Number of inference steps.
min_len: Minimum audio length.
max_len: Maximum token length during generation.
normalize: Whether to run text normalization before generation.
denoise: Whether to denoise the prompt audio if a denoiser is
available.
denoise: Whether to denoise the prompt/reference audio if a
denoiser is available.
retry_badcase: Whether to retry badcase.
retry_badcase_max_times: Maximum number of times to retry badcase.
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
streaming: Whether to return a generator of audio chunks.
Returns:
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
Yields audio chunks for each generations step if ``streaming=True``,
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
Yields audio chunks for each generation step if ``streaming=True``,
otherwise yields a single array containing the final audio.
"""
if not text.strip() or not isinstance(text, str):
raise ValueError("target text must be a non-empty string")
if prompt_wav_path is not None:
if not os.path.exists(prompt_wav_path):
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
if reference_wav_path is not None:
if not os.path.exists(reference_wav_path):
raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")
if (prompt_wav_path is None) != (prompt_text is None):
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
is_v2 = isinstance(self.tts_model, VoxCPM2Model)
if reference_wav_path is not None and not is_v2:
raise ValueError("reference_wav_path is only supported with VoxCPM2 models")
text = text.replace("\n", " ")
text = re.sub(r'\s+', ' ', text)
temp_prompt_wav_path = None
text = re.sub(r"\s+", " ", text)
temp_files = []
try:
if prompt_wav_path is not None and prompt_text is not None:
if denoise and self.denoiser is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
temp_prompt_wav_path = tmp_file.name
self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
prompt_wav_path = temp_prompt_wav_path
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text
)
actual_prompt_path = prompt_wav_path
actual_ref_path = reference_wav_path
if denoise and self.denoiser is not None:
if prompt_wav_path is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
temp_files.append(tmp.name)
self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
actual_prompt_path = temp_files[-1]
if reference_wav_path is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
temp_files.append(tmp.name)
self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
actual_ref_path = temp_files[-1]
if actual_prompt_path is not None or actual_ref_path is not None:
if is_v2:
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_text=prompt_text,
prompt_wav_path=actual_prompt_path,
reference_wav_path=actual_ref_path,
)
else:
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_text=prompt_text,
prompt_wav_path=actual_prompt_path,
)
else:
fixed_prompt_cache = None # will be built from the first inference
fixed_prompt_cache = None
if normalize:
if self.text_normalizer is None:
from .utils.text_normalize import TextNormalizer
self.text_normalizer = TextNormalizer()
text = self.text_normalizer.normalize(text)
generate_result = self.tts_model._generate_with_prompt_cache(
target_text=text,
prompt_cache=fixed_prompt_cache,
min_len=min_len,
max_len=max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
streaming=streaming,
)
target_text=text,
prompt_cache=fixed_prompt_cache,
min_len=min_len,
max_len=max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
streaming=streaming,
)
for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy()
finally:
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
try:
os.unlink(temp_prompt_wav_path)
except OSError:
pass
for tmp_path in temp_files:
if tmp_path and os.path.exists(tmp_path):
try:
os.unlink(tmp_path)
except OSError:
pass
# ------------------------------------------------------------------ #
# LoRA Interface (delegated to VoxCPMModel)
# ------------------------------------------------------------------ #
def load_lora(self, lora_weights_path: str) -> tuple:
"""Load LoRA weights from a checkpoint file.
Args:
lora_weights_path: Path to LoRA weights (.pth file or directory
containing lora_weights.ckpt).
Returns:
tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
Raises:
RuntimeError: If model was not initialized with LoRA config.
"""
@@ -259,23 +310,23 @@ class VoxCPM:
def unload_lora(self):
"""Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
self.tts_model.reset_lora_weights()
def set_lora_enabled(self, enabled: bool):
"""Enable or disable LoRA layers without unloading weights.
Args:
enabled: If True, LoRA layers are active; if False, only base model is used.
"""
self.tts_model.set_lora_enabled(enabled)
def get_lora_state_dict(self) -> dict:
"""Get current LoRA parameters state dict.
Returns:
dict: State dict containing all LoRA parameters (lora_A, lora_B).
"""
return self.tts_model.get_lora_state_dict()
@property
def lora_enabled(self) -> bool:
"""Check if LoRA is currently configured."""
+2 -1
View File
@@ -1,3 +1,4 @@
from .voxcpm import VoxCPMModel
from .voxcpm2 import VoxCPM2Model
__all__ = ["VoxCPMModel"]
__all__ = ["VoxCPMModel", "VoxCPM2Model"]
+18 -19
View File
@@ -5,17 +5,17 @@ from transformers import PreTrainedTokenizer
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
This function creates a wrapper around the provided tokenizer that automatically
splits multi-character Chinese tokens into individual characters. This is useful
for ensuring consistent tokenization of Chinese text.
Args:
tokenizer: The base tokenizer to wrap
Returns:
A CharTokenizerWrapper instance that handles multi-character Chinese tokens
Example:
>>> from transformers import LlamaTokenizerFast
>>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
@@ -24,20 +24,19 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
multichar_tokens = {
token for token in tokenizer.vocab.keys()
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
token for token in tokenizer.vocab.keys() if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
}
class CharTokenizerWrapper:
"""Wrapper class for tokenizers that handles multi-character Chinese tokens.
This wrapper automatically splits multi-character Chinese tokens into
individual characters while preserving the original tokenizer's interface.
"""
def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
"""Initialize the wrapper with a base tokenizer.
Args:
base_tokenizer: The tokenizer to wrap
"""
@@ -46,14 +45,14 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
def tokenize(self, text: str, **kwargs) -> List[str]:
"""Tokenize text and split multi-character Chinese tokens into single characters.
Args:
text: Input text to tokenize
**kwargs: Additional arguments passed to the base tokenizer
Returns:
List of processed tokens with multi-character Chinese tokens split
Example:
>>> wrapper = CharTokenizerWrapper(tokenizer)
>>> tokens = wrapper.tokenize("你好世界")
@@ -61,10 +60,10 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""
if not isinstance(text, str):
raise TypeError(f"Expected string input, got {type(text)}")
tokens = self.tokenizer.tokenize(text, **kwargs)
processed = []
for token in tokens:
# Remove possible subword prefix
clean_token = token.replace("", "")
@@ -75,22 +74,22 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
processed.extend(chars)
else:
processed.append(token)
return processed
def __call__(self, text: str, **kwargs) -> List[int]:
"""Call the tokenizer and return token IDs.
This method provides the same interface as the original tokenizer
but with multi-character Chinese token handling.
Args:
text: Input text to tokenize
**kwargs: Additional arguments passed to the base tokenizer
Returns:
List of token IDs
Raises:
TypeError: If input is not a string
ValueError: If tokenization fails
+128 -115
View File
@@ -24,7 +24,6 @@ from typing import Tuple, Union, Generator, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import warnings
from einops import rearrange
@@ -32,6 +31,7 @@ from pydantic import BaseModel
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
@@ -84,9 +84,9 @@ class VoxCPMConfig(BaseModel):
class LoRAConfig(BaseModel):
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
r: int = 8
alpha: int = 16
@@ -165,10 +165,10 @@ class VoxCPMModel(nn.Module):
# Projection layers
self.fsq_layer = ScalarQuantizationLayer(
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale,
)
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
@@ -196,9 +196,7 @@ class VoxCPMModel(nn.Module):
# LM: base_lm + residual_lm
if cfg.enable_lm:
for lm in [self.base_lm, self.residual_lm]:
apply_lora_to_named_linear_modules(
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
)
apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)
# DiT: feat_decoder.estimator
if cfg.enable_dit:
@@ -209,6 +207,7 @@ class VoxCPMModel(nn.Module):
# 投影层
if cfg.enable_proj:
from ..modules.layers.lora import LoRALinear
for attr_name in cfg.target_proj_modules:
module = getattr(self, attr_name, None)
if isinstance(module, nn.Linear):
@@ -221,13 +220,17 @@ class VoxCPMModel(nn.Module):
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
import triton
import triton # noqa: F401
except ImportError:
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
)
except Exception as e:
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
return self
@@ -313,9 +316,11 @@ class VoxCPMModel(nn.Module):
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10,
n_timesteps=(
self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10
),
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
@@ -331,7 +336,6 @@ class VoxCPMModel(nn.Module):
def _dtype(self):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
@@ -350,7 +354,7 @@ class VoxCPMModel(nn.Module):
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
streaming: bool = False,
) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming:
@@ -394,7 +398,7 @@ class VoxCPMModel(nn.Module):
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
@@ -435,7 +439,7 @@ class VoxCPMModel(nn.Module):
audio_mask = audio_mask.unsqueeze(0).to(self.device)
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times:
inference_result = self._inference(
@@ -444,7 +448,9 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -460,18 +466,21 @@ class VoxCPMModel(nn.Module):
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1
continue
else:
break
else:
break
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
yield decode_audio
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
yield decode_audio
@torch.inference_mode()
def build_prompt_cache(
self,
@@ -480,11 +489,11 @@ class VoxCPMModel(nn.Module):
):
"""
Build prompt cache for subsequent fast generation.
Args:
prompt_text: prompt text (required)
prompt_wav_path: prompt audio path (required)
Returns:
prompt_cache: dict with prompt_text (raw text) and audio features.
Text tokenization will be done during generation for consistency.
@@ -496,7 +505,7 @@ class VoxCPMModel(nn.Module):
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
@@ -514,16 +523,17 @@ class VoxCPMModel(nn.Module):
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0) # (D, T, P)
).permute(
1, 2, 0
) # (D, T, P)
# build prompt cache - only save raw text and audio features
prompt_cache = {
"prompt_text": prompt_text,
"audio_feat": audio_feat,
}
return prompt_cache
def merge_prompt_cache(
self,
original_cache: dict,
@@ -532,12 +542,12 @@ class VoxCPMModel(nn.Module):
):
"""
Merge original prompt cache with newly generated content to stabilize voice.
Args:
original_cache: original prompt cache
new_text: newly generated text
new_text: newly generated text
new_audio_feat: newly generated audio features
Returns:
merged_cache: merged cache with prompt_text and audio_feat
"""
@@ -557,20 +567,17 @@ class VoxCPMModel(nn.Module):
"prompt_text": merged_prompt_text,
"audio_feat": merged_audio_feat,
}
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
@torch.inference_mode()
def _generate_with_prompt_cache(
self,
@@ -588,7 +595,7 @@ class VoxCPMModel(nn.Module):
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""
Generate audio using pre-built prompt cache.
Args:
target_text: Text to convert to speech
prompt_cache: Cache built by build_prompt_cache (can be None)
@@ -601,7 +608,7 @@ class VoxCPMModel(nn.Module):
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
streaming: Whether to return a generator of audio chunks
streaming_prefix_len: Number of prefix audio patches to use for streaming mode
Returns:
Generator of Tuple containing:
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
@@ -619,7 +626,7 @@ class VoxCPMModel(nn.Module):
prompt_audio_feat = prompt_cache["audio_feat"]
prompt_text = prompt_cache["prompt_text"]
text = prompt_text + target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat(
[
@@ -632,7 +639,7 @@ class VoxCPMModel(nn.Module):
],
dim=-1,
)
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
audio_length = prompt_audio_feat.size(0)
@@ -645,14 +652,18 @@ class VoxCPMModel(nn.Module):
)
text_token = torch.cat([text_token, text_pad_token])
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
)
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
audio_mask = audio_mask.unsqueeze(0).to(self.device)
# run inference
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
@@ -663,7 +674,9 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -674,17 +687,16 @@ class VoxCPMModel(nn.Module):
for latent_pred, pred_audio_feat in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1
continue
else:
@@ -695,18 +707,14 @@ class VoxCPMModel(nn.Module):
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size
if audio_mask.sum().item() > 0:
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
else:
decode_audio = decode_audio[..., :].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
@@ -725,10 +733,10 @@ class VoxCPMModel(nn.Module):
streaming_prefix_len: int = 3,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""Core inference method for audio generation.
This is the main inference loop that generates audio features
using the language model and diffusion transformer.
Args:
text: Input text tokens
text_mask: Mask for text tokens
@@ -739,7 +747,7 @@ class VoxCPMModel(nn.Module):
inference_timesteps: Number of diffusion steps
cfg_value: Classifier-free guidance value
streaming: Whether to yield each step latent feature or just the final result
Returns:
Generator of Tuple containing:
- Predicted latent feature at the current step if ``streaming=True``, else final latent features
@@ -749,12 +757,12 @@ class VoxCPMModel(nn.Module):
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup:
scale_emb = self.config.lm_config.scale_emb
else:
scale_emb = 1.0
text_embed = self.base_lm.embed_tokens(text) * scale_emb
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
@@ -778,11 +786,10 @@ class VoxCPMModel(nn.Module):
is_causal=True,
)
self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = enc_outputs[:, -1, :]
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
is_causal=True,
@@ -790,7 +797,6 @@ class VoxCPMModel(nn.Module):
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
residual_hidden = residual_enc_outputs[:, -1, :]
for i in tqdm(range(max_len)):
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
@@ -805,10 +811,10 @@ class VoxCPMModel(nn.Module):
).transpose(
1, 2
) # [b, p, d]
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.enc_to_lm_proj(curr_embed)
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
prefix_feat_cond = pred_feat
@@ -816,58 +822,70 @@ class VoxCPMModel(nn.Module):
# return the last three predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1:
break
lm_hidden = self.base_lm.forward_step(
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
).clone()
lm_hidden = self.fsq_layer(lm_hidden)
residual_hidden = self.residual_lm.forward_step(
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
lm_hidden + curr_embed[:, 0, :],
torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
).clone()
if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, 'audio_vae_config', None)
audio_vae_config = getattr(config, "audio_vae_config", None)
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
vae_state_dict = torch.load(
os.path.join(path, "audiovae.pth"),
map_location="cpu",
weights_only=True,
)["state_dict"]
# Try to load AudioVAE from safetensors first, fallback to pytorch
audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
audiovae_pth_path = os.path.join(path, "audiovae.pth")
if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
elif os.path.exists(audiovae_pth_path):
print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
checkpoint = torch.load(
audiovae_pth_path,
map_location="cpu",
weights_only=True,
)
vae_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
else: # training mode
else: # training mode
for name, param in model.named_parameters():
if "audio_vae" in name: # freeze VAE weights
if "audio_vae" in name: # freeze VAE weights
param.requires_grad = False
continue
if lora_config is not None:
if "lora" not in name: # freeze non-LoRA weights
if "lora" not in name: # freeze non-LoRA weights
param.requires_grad = False
model.audio_vae = model.audio_vae.to(torch.float32)
# Try to load from safetensors first, fallback to pytorch_model.bin
safetensors_path = os.path.join(path, "model.safetensors")
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
model_state_dict = load_file(safetensors_path)
@@ -880,13 +898,11 @@ class VoxCPMModel(nn.Module):
)
model_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
)
raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val
# LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
model.load_state_dict(model_state_dict, strict=False)
@@ -900,6 +916,7 @@ class VoxCPMModel(nn.Module):
def _iter_lora_modules(self):
"""Iterate over all LoRA modules."""
from ..modules.layers.lora import LoRALinear
for module in self.modules():
if isinstance(module, LoRALinear):
yield module
@@ -909,7 +926,7 @@ class VoxCPMModel(nn.Module):
Load LoRA weights from file, supports calling after torch.compile.
Uses named_parameters() to handle compile's _orig_mod wrapper.
Supports both safetensors and pytorch formats.
Args:
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
device: Target device, defaults to model's current device
@@ -917,18 +934,18 @@ class VoxCPMModel(nn.Module):
tuple: (loaded_keys, skipped_keys)
"""
from pathlib import Path
device = device or self.device
lora_path = Path(lora_path)
lora_p = Path(lora_path)
# Try safetensors first, then fallback to .ckpt
if lora_path.is_dir():
safetensors_file = lora_path / "lora_weights.safetensors"
ckpt_file = lora_path / "lora_weights.ckpt"
if lora_p.is_dir():
safetensors_file = lora_p / "lora_weights.safetensors"
ckpt_file = lora_p / "lora_weights.ckpt"
else:
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
# Load from safetensors if available
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
@@ -936,14 +953,12 @@ class VoxCPMModel(nn.Module):
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
)
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
# Build param mapping (handle torch.compile's _orig_mod prefix)
model_params = dict(self.named_parameters())
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
loaded_keys, skipped_keys = [], []
for key, value in state_dict.items():
target_key = key if key in model_params else key_mapping.get(key)
@@ -952,7 +967,7 @@ class VoxCPMModel(nn.Module):
loaded_keys.append(key)
else:
skipped_keys.append(key)
return loaded_keys, skipped_keys
def set_lora_enabled(self, enabled: bool):
@@ -967,6 +982,4 @@ class VoxCPMModel(nn.Module):
def get_lora_state_dict(self) -> dict:
"""Get all LoRA parameters (lora_A/lora_B)."""
return {name: param.data.clone()
for name, param in self.named_parameters()
if "lora_" in name}
return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
File diff suppressed because it is too large Load Diff
+1
View File
@@ -1 +1,2 @@
from .audio_vae import AudioVAE, AudioVAEConfig
from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
+4 -4
View File
@@ -1,5 +1,5 @@
import math
from typing import List, Union, Optional
from typing import List
import numpy as np
import torch
@@ -285,12 +285,12 @@ class AudioVAE(nn.Module):
def __init__(
self,
config: Optional[AudioVAEConfig] = None,
config: AudioVAEConfig = None,
):
# 如果没有传入config,使用默认配置
if config is None:
config = AudioVAEConfig()
super().__init__()
encoder_dim = config.encoder_dim
@@ -301,7 +301,7 @@ class AudioVAE(nn.Module):
depthwise = config.depthwise
sample_rate = config.sample_rate
use_noise_block = config.use_noise_block
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
+486
View File
@@ -0,0 +1,486 @@
import math
from typing import List, Optional
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from pydantic import BaseModel
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class CausalConv1d(nn.Conv1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
return super().forward(x_pad)
class CausalTransposeConv1d(nn.ConvTranspose1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
def WNCausalConv1d(*args, **kwargs):
return weight_norm(CausalConv1d(*args, **kwargs))
def WNCausalTransposeConv1d(*args, **kwargs):
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class CausalResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNCausalConv1d(
dim,
dim,
kernel_size=kernel,
dilation=dilation,
padding=pad,
groups=groups,
),
Snake1d(dim),
WNCausalConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
assert pad == 0
if pad > 0:
x = x[..., pad:-pad]
return x + y
class CausalEncoderBlock(nn.Module):
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
super().__init__()
input_dim = input_dim or output_dim // 2
self.block = nn.Sequential(
CausalResidualUnit(input_dim, dilation=1, groups=groups),
CausalResidualUnit(input_dim, dilation=3, groups=groups),
CausalResidualUnit(input_dim, dilation=9, groups=groups),
Snake1d(input_dim),
WNCausalConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
)
def forward(self, x):
return self.block(x)
class CausalEncoder(nn.Module):
def __init__(
self,
d_model: int = 64,
latent_dim: int = 32,
strides: list = [2, 4, 8, 8],
depthwise: bool = False,
):
super().__init__()
# Create first convolution
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
groups = d_model // 2 if depthwise else 1
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
groups = d_model if depthwise else 1
# Create two convolution, for mu and logvar
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
hidden_state = self.block(x)
return {
"hidden_state": hidden_state,
"mu": self.fc_mu(hidden_state),
"logvar": self.fc_logvar(hidden_state),
}
class NoiseBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
def forward(self, x):
B, C, T = x.shape
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
h = self.linear(x)
n = noise * h
x = x + n
return x
class CausalDecoderBlock(nn.Module):
def __init__(
self,
input_dim: int = 16,
output_dim: int = 8,
stride: int = 1,
groups=1,
use_noise_block: bool = False,
):
super().__init__()
layers = [
Snake1d(input_dim),
WNCausalTransposeConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
]
if use_noise_block:
layers.append(NoiseBlock(output_dim))
layers.extend(
[
CausalResidualUnit(output_dim, dilation=1, groups=groups),
CausalResidualUnit(output_dim, dilation=3, groups=groups),
CausalResidualUnit(output_dim, dilation=9, groups=groups),
]
)
self.block = nn.Sequential(*layers)
self.input_channels = input_dim
def forward(self, x):
return self.block(x)
class TransposeLastTwoDim(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, -1, -2)
class SampleRateConditionLayer(nn.Module):
def __init__(
self,
input_dim: int,
sr_bin_buckets: int = None,
cond_type: str = "scale_bias",
cond_dim: int = 128,
out_layer: bool = False,
):
super().__init__()
self.cond_type, out_layer_in_dim = cond_type, input_dim
if cond_type == "scale_bias":
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.ones_(self.scale_embed.weight)
nn.init.zeros_(self.bias_embed.weight)
elif cond_type == "scale_bias_init":
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.normal_(self.scale_embed.weight, mean=1)
nn.init.normal_(self.bias_embed.weight)
elif cond_type == "add":
self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.normal_(self.cond_embed.weight)
elif cond_type == "concat":
self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
assert out_layer, "out_layer must be True for concat cond_type"
out_layer_in_dim = input_dim + cond_dim
else:
raise ValueError(f"Invalid cond_type: {cond_type}")
if out_layer:
self.out_layer = nn.Sequential(
Snake1d(out_layer_in_dim),
WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
)
else:
self.out_layer = nn.Identity()
def forward(self, x, sr_cond):
if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
elif self.cond_type == "add":
x = x + self.cond_embed(sr_cond).unsqueeze(-1)
elif self.cond_type == "concat":
x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
return self.out_layer(x)
class CausalDecoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
depthwise: bool = False,
d_out: int = 1,
use_noise_block: bool = False,
sr_bin_boundaries: List[int] = None,
cond_type: str = "scale_bias",
cond_dim: int = 128,
cond_out_layer: bool = False,
):
super().__init__()
# Add first conv layer
if depthwise:
layers = [
WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
WNCausalConv1d(input_channel, channels, kernel_size=1),
]
else:
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
groups = output_dim if depthwise else 1
layers += [
CausalDecoderBlock(
input_dim,
output_dim,
stride,
groups=groups,
use_noise_block=use_noise_block,
)
]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
if sr_bin_boundaries is None:
self.model = nn.Sequential(*layers)
self.sr_bin_boundaries = None
else:
self.model = nn.ModuleList(layers)
self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
self.sr_bin_buckets = len(sr_bin_boundaries) + 1
cond_layers = []
for layer in self.model:
if layer.__class__.__name__ == "CausalDecoderBlock":
cond_layers.append(
SampleRateConditionLayer(
input_dim=layer.input_channels,
sr_bin_buckets=self.sr_bin_buckets,
cond_type=cond_type,
cond_dim=cond_dim,
out_layer=cond_out_layer,
)
)
else:
cond_layers.append(None)
self.sr_cond_model = nn.ModuleList(cond_layers)
def get_sr_idx(self, sr):
return torch.bucketize(sr, self.sr_bin_boundaries)
def forward(self, x, sr_cond=None):
if self.sr_bin_boundaries is not None:
# assert sr_cond is not None
sr_cond = self.get_sr_idx(sr_cond)
for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
if sr_cond_layer is not None:
x = sr_cond_layer(x, sr_cond)
x = layer(x)
return x
else:
return self.model(x)
class AudioVAEConfig(BaseModel):
encoder_dim: int = 128
encoder_rates: List[int] = [2, 5, 8, 8]
latent_dim: int = 64
decoder_dim: int = 2048
decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
depthwise: bool = True
sample_rate: int = 16000
out_sample_rate: int = 48000
use_noise_block: bool = False
sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
cond_type: str = "scale_bias"
cond_dim: int = 128
cond_out_layer: bool = False
class AudioVAE(nn.Module):
"""
Args:
"""
def __init__(
self,
config: AudioVAEConfig = None,
):
# 如果没有传入config,使用默认配置
if config is None:
config = AudioVAEConfig()
super().__init__()
encoder_dim = config.encoder_dim
encoder_rates = config.encoder_rates
latent_dim = config.latent_dim
decoder_dim = config.decoder_dim
decoder_rates = config.decoder_rates
depthwise = config.depthwise
sample_rate = config.sample_rate
out_sample_rate = config.out_sample_rate
use_noise_block = config.use_noise_block
sr_bin_boundaries = config.sr_bin_boundaries
cond_type = config.cond_type
cond_dim = config.cond_dim
cond_out_layer = config.cond_out_layer
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.depthwise = depthwise
self.use_noise_block = use_noise_block
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = CausalEncoder(
encoder_dim,
latent_dim,
encoder_rates,
depthwise=depthwise,
)
self.decoder = CausalDecoder(
latent_dim,
decoder_dim,
decoder_rates,
depthwise=depthwise,
use_noise_block=use_noise_block,
sr_bin_boundaries=sr_bin_boundaries,
cond_type=cond_type,
cond_dim=cond_dim,
cond_out_layer=cond_out_layer,
)
self.sample_rate = sample_rate
self.out_sample_rate = out_sample_rate
self.sr_bin_boundaries = sr_bin_boundaries
self.chunk_size = math.prod(encoder_rates)
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
pad_to = self.hop_length
length = audio_data.shape[-1]
right_pad = math.ceil(length / pad_to) * pad_to - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
if self.sr_bin_boundaries is not None:
# use default output sample rate
if sr_cond is None:
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
return self.decoder(z, sr_cond)
def encode(self, audio_data: torch.Tensor, sample_rate: int):
"""
Args:
audio_data: Tensor[B x 1 x T]
sample_rate: int
Returns:
z: Tensor[B x D x T]
"""
if audio_data.ndim == 2:
audio_data = audio_data.unsqueeze(1)
audio_data = self.preprocess(audio_data, sample_rate)
return self.encoder(audio_data)["mu"]
+1 -1
View File
@@ -1 +1 @@
from .scalar_quantization_layer import ScalarQuantizationLayer
from .scalar_quantization_layer import ScalarQuantizationLayer
+1 -4
View File
@@ -34,7 +34,7 @@ class LoRALinear(nn.Module):
self.r = r
self.alpha = alpha
self._base_scaling = alpha / r if r > 0 else 0.0
# 使用 buffer 存储 scaling,这样修改值不会触发 torch.compile 重编译
# persistent=False 表示不保存到 state_dict,避免加载时 missing key
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
@@ -128,6 +128,3 @@ def apply_lora_to_named_linear_modules(
dropout=dropout,
)
setattr(parent, short_name, lora_layer)
@@ -12,7 +12,7 @@ class ScalarQuantizationLayer(nn.Module):
self.in_proj = nn.Linear(in_dim, latent_dim)
self.out_proj = nn.Linear(latent_dim, out_dim)
def forward(self, hidden):
hidden = self.in_proj(hidden)
hidden = torch.tanh(hidden)
@@ -23,4 +23,4 @@ class ScalarQuantizationLayer(nn.Module):
else:
hidden = torch.round(hidden * self.scale) / self.scale
return self.out_proj(hidden)
return self.out_proj(hidden)
+1
View File
@@ -1,2 +1,3 @@
from .unified_cfm import UnifiedCFM, CfmConfig
from .local_dit import VoxCPMLocDiT
from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
+116
View File
@@ -0,0 +1,116 @@
import torch
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
import torch.nn as nn
import math
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
out_dim: int = None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class VoxCPMLocDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
config: MiniCPM4Config,
in_channels: int = 64,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.config = config
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
self.time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
self.delta_time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
self.decoder = MiniCPMModel(config)
def forward(
self,
x: torch.Tensor,
mu: torch.Tensor,
t: torch.Tensor,
cond: torch.Tensor,
dt: torch.Tensor,
):
"""
Forward pass of DiT.
x: (N, C, T) tensor of inputs
mu: (N, C) tensor of hidden embedding
t: (N,) tensor of diffusion timesteps
cond: (N, C, T') tensor of prefix conditions
dt: (N,) used for mean velocity (may be supported in the future...)
"""
x = self.in_proj(x.transpose(1, 2).contiguous())
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
prefix = cond.size(1)
t = self.time_embeddings(t).to(x.dtype)
t = self.time_mlp(t)
dt = self.time_embeddings(dt).to(x.dtype)
dt = self.delta_time_mlp(dt)
t = t + dt
mu = mu.view(x.size(0), -1, x.size(-1))
x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)
hidden, _ = self.decoder(x, is_causal=False)
hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
hidden = self.out_proj(hidden)
return hidden.transpose(1, 2).contiguous()
+8 -7
View File
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Tuple
import torch
import torch.nn.functional as F
@@ -56,7 +56,7 @@ class UnifiedCFM(torch.nn.Module):
cond: torch.Tensor,
temperature: float = 1.0,
cfg_value: float = 1.0,
sway_sampling_coef: float = 1.0,
sway_sampling_coef: float = 1.0,
use_cfg_zero_star: bool = True,
):
b, _ = mu.shape
@@ -116,7 +116,7 @@ class UnifiedCFM(torch.nn.Module):
dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
if use_cfg_zero_star:
positive_flat = dphi_dt.view(b, -1)
negative_flat = cfg_dphi_dt.view(b, -1)
@@ -124,7 +124,7 @@ class UnifiedCFM(torch.nn.Module):
st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
else:
st_star = 1.0
dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)
x = x - dt * dphi_dt
@@ -138,7 +138,9 @@ class UnifiedCFM(torch.nn.Module):
# ------------------------------------------------------------------ #
# Training loss
# ------------------------------------------------------------------ #
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
def adaptive_loss_weighting(
self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
):
weights = 1.0 / ((losses + epsilon).pow(p))
if mask is not None:
weights = weights * mask
@@ -193,8 +195,7 @@ class UnifiedCFM(torch.nn.Module):
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
ratio_r_neq_t = (
self.ratio_r_neq_t_range[0]
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
if self.mean_mode
else 0.0
)
+1 -1
View File
@@ -26,4 +26,4 @@ class MiniCPM4Config(BaseModel):
dim_model_base: int
scale_depth: float
rope_theta: float
kv_channels: int = None
kv_channels: int = None
+13 -14
View File
@@ -64,10 +64,8 @@ class MiniCPMLongRoPE(nn.Module):
self.long_factor = config.rope_scaling.long_factor
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
self.scaling_factor = math.sqrt(
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
)
scale = self.max_position_embeddings / self.original_max_position_embeddings
self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -76,11 +74,7 @@ class MiniCPMLongRoPE(nn.Module):
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
self._set_cos_sin_cache(
seq_len=self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32
)
self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)
def _set_cos_sin_cache(self, seq_len, device, dtype):
"""设置cos和sin缓存"""
@@ -93,8 +87,7 @@ class MiniCPMLongRoPE(nn.Module):
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
freqs = torch.mul(
torch.outer(t, 1.0 / ext_factors).to(device=device),
self.inv_freq.to(device=device).to(dtype)
torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
)
# 创建embeddings
@@ -123,7 +116,9 @@ class MiniCPMAttention(nn.Module):
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
self.head_dim = (
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
@@ -153,7 +148,7 @@ class MiniCPMAttention(nn.Module):
cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# ref: https://github.com/pytorch/pytorch/issues/163597
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
query_states = query_states.contiguous()
@@ -413,7 +408,11 @@ class MiniCPMModel(nn.Module):
self.kv_cache = StaticKVCache(
num_layers=self.config.num_hidden_layers,
num_kv_heads=self.config.num_key_value_heads,
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
dim_kv_head=(
self.config.hidden_size // self.config.num_attention_heads
if self.config.kv_channels is None
else self.config.kv_channels
),
batch_size=batch_size,
device=device,
dtype=dtype,
-1
View File
@@ -25,4 +25,3 @@ __all__ = [
"load_audio_text_datasets",
"build_dataloader",
]
+2 -5
View File
@@ -47,9 +47,7 @@ class Accelerator:
pass
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
self.device_ctx = (
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
self._ddp_model = None # For no_sync support
def _set_seed(self, seed: int):
@@ -84,7 +82,7 @@ class Accelerator:
# Model helpers
# ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs):
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
model.device = self.device
model = model.to(self.device)
if self.world_size > 1:
@@ -163,4 +161,3 @@ class Accelerator:
@staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model
-2
View File
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args)
return cli_args
+4 -6
View File
@@ -1,5 +1,4 @@
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import argbind
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
from ..modules.audiovae import AudioVAE
from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_ID_COLUMN = "dataset_id"
@@ -36,7 +34,7 @@ def load_audio_text_datasets(
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
@@ -70,13 +68,13 @@ def compute_sample_lengths(
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
t_seq = ceil(t_vae / patch_size)
- 序列总长约为: text_len + t_seq + 2
Optimized: Use batch column access instead of iterating item by item.
"""
# Batch access columns - much faster than per-item access
text_ids_list = ds["text_ids"]
text_lens = [len(t) for t in text_ids_list]
has_duration = "duration" in ds.column_names
if has_duration:
durations = ds["duration"]
@@ -86,7 +84,7 @@ def compute_sample_lengths(
for i in range(len(ds)):
audio = ds[i][DEFAULT_AUDIO_COLUMN]
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
# Vectorized length computation
lengths = []
for text_len, duration in zip(text_lens, durations):
+29 -22
View File
@@ -1,5 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List
import torch
import torch.nn as nn
@@ -15,7 +14,7 @@ class AudioFeatureProcessingPacker:
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101
self.audio_end_id = 102
# unused now
# unused now
self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104
self.text_eos_token_id = 2
@@ -147,31 +146,26 @@ class AudioFeatureProcessingPacker:
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
if x.size(0) >= max_len:
return x[: max_len]
return x[:max_len]
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
def pad_3d(x: torch.Tensor) -> torch.Tensor:
# x: [T, P, D]
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.zeros(
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
)
return x[:max_len]
pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
if lengths:
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
audio_task_ids_batch = torch.stack(
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
)
audio_dataset_ids_batch = torch.stack(
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
)
labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
position_ids_list = []
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
)
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
text_token.device
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
.type(torch.int32)
.to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
.type(torch.int32)
.to(text_token.device)
)
loss_mask = (
torch.cat(
[
torch.zeros(text_length),
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
torch.zeros(1),
]
)
.type(torch.int32)
.to(text_token.device)
)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
torch.int32
).to(text_token.device)
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
labels[-2] = 1
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
audio_duration,
text_token_count,
)
-1
View File
@@ -18,4 +18,3 @@ class TrainingState:
val_loader: object
tracker: object
batch_processor: object
-1
View File
@@ -76,4 +76,3 @@ class TrainingTracker:
@contextlib.contextmanager
def live(self):
yield
+34 -31
View File
@@ -2,10 +2,10 @@
import re
import regex
import inflect
from functools import partial
from wetext import Normalizer
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
# whether contain chinese character
def contains_chinese(text):
@@ -14,19 +14,19 @@ def contains_chinese(text):
# replace special symbol
def replace_corner_mark(text):
text = text.replace('²', '平方')
text = text.replace('³', '立方')
text = text.replace('', '根号')
text = text.replace('', '约等于')
text = text.replace('<', '小于')
text = text.replace("²", "平方")
text = text.replace("³", "立方")
text = text.replace("", "根号")
text = text.replace("", "约等于")
text = text.replace("<", "小于")
return text
# remove meaningless symbol
def remove_bracket(text):
text = text.replace('', ' ').replace('', ' ')
text = text.replace('', ' ').replace('', ' ')
text = text.replace('`', '').replace('`', '')
text = text.replace("", " ").replace("", " ")
text = text.replace("", " ").replace("", " ")
text = text.replace("`", "").replace("`", "")
text = text.replace("——", " ")
return text
@@ -38,7 +38,7 @@ def spell_out_number(text: str, inflect_parser):
for i, c in enumerate(text):
if not c.isdigit():
if st is not None:
num_str = inflect_parser.number_to_words(text[st: i])
num_str = inflect_parser.number_to_words(text[st:i])
new_text.append(num_str)
st = None
new_text.append(c)
@@ -48,7 +48,7 @@ def spell_out_number(text: str, inflect_parser):
if st is not None and st < len(text):
num_str = inflect_parser.number_to_words(text[st:])
new_text.append(num_str)
return ''.join(new_text)
return "".join(new_text)
# split paragrah logic
@@ -69,18 +69,18 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
return len(tokenize(_text)) < merge_len
if lang == "zh":
pounc = ['', '', '', '', '', '', '.', '?', '!', ';']
pounc = ["", "", "", "", "", "", ".", "?", "!", ";"]
else:
pounc = ['.', '?', '!', ';', ':']
pounc = [".", "?", "!", ";", ":"]
if comma_split:
pounc.extend(['', ','])
pounc.extend(["", ","])
st = 0
utts = []
for i, c in enumerate(text):
if c in pounc:
if len(text[st: i]) > 0:
utts.append(text[st: i] + c)
if i + 1 < len(text) and text[i + 1] in ['"', '']:
if len(text[st:i]) > 0:
utts.append(text[st:i] + c)
if i + 1 < len(text) and text[i + 1] in ['"', ""]:
tmp = utts.pop(-1)
utts.append(tmp + text[i + 1])
st = i + 2
@@ -88,9 +88,9 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
st = i + 1
if len(utts) == 0:
if lang == "zh":
utts.append(text + '')
utts.append(text + "")
else:
utts.append(text + '.')
utts.append(text + ".")
final_utts = []
cur_utt = ""
for utt in utts:
@@ -112,13 +112,13 @@ def replace_blank(text: str):
out_str = []
for i, c in enumerate(text):
if c == " ":
if ((text[i + 1].isascii() and text[i + 1] != " ") and
(text[i - 1].isascii() and text[i - 1] != " ")):
if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
out_str.append(c)
else:
out_str.append(c)
return "".join(out_str)
def clean_markdown(md_text: str) -> str:
# 去除代码块 ``` ```(包括多行)
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
@@ -131,9 +131,9 @@ def clean_markdown(md_text: str) -> str:
# 去除链接但保留文本 [text](url) -> text
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
# 替换无序列表符号
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
md_text = re.sub(r"^(\s*)-\s+", r"\1", md_text, flags=re.MULTILINE)
# 去除HTML标签
md_text = re.sub(r"<[^>]+>", "", md_text)
@@ -152,28 +152,31 @@ def clean_text(text):
# 去除 Markdown 语法
text = clean_markdown(text)
# 匹配并移除表情符号
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
text = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE).sub("", text)
# 去除换行符
text = text.replace("\n", " ")
text = text.replace("\t", " ")
text = text.replace('"', "\")
text = text.replace("", '"').replace("", '"')
return text
class TextNormalizer:
def __init__(self, tokenizer=None):
self.tokenizer = tokenizer
self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
self.en_tn_model = Normalizer(lang="en", operator="tn")
self.inflect_parser = inflect.engine()
def normalize(self, text, split=False):
# 去除 Markdown 语法,去除表情符号,去除换行符
lang = "zh" if contains_chinese(text) else "en"
text = clean_text(text)
if lang == "zh":
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
text = text.replace(
"=", "等于"
) # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text): # 避免 英文连字符被错误正则为减
text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text) # 修复 x-2 被正则为 x负2
text = self.zh_tn_model.normalize(text)
text = replace_blank(text)
text = replace_corner_mark(text)
@@ -182,4 +185,4 @@ class TextNormalizer:
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
if split is False:
return text
return text
+10 -14
View File
@@ -7,15 +7,15 @@ Related dependencies are imported only when denoising functionality is needed.
import os
import tempfile
from typing import Optional, Union
from typing import Optional
import torchaudio
import torch
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
class ZipEnhancer:
"""ZipEnhancer Audio Denoising Enhancer"""
def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
"""
Initialize ZipEnhancer
@@ -23,25 +23,21 @@ class ZipEnhancer:
model_path: ModelScope model path or local path
"""
self.model_path = model_path
self._pipeline = pipeline(
Tasks.acoustic_noise_suppression,
model=self.model_path
)
self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)
def _normalize_loudness(self, wav_path: str):
"""
Audio loudness normalization
Args:
wav_path: Audio file path
"""
audio, sr = torchaudio.load(wav_path)
loudness = torchaudio.functional.loudness(audio, sr)
normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
torchaudio.save(wav_path, normalized_audio, sr)
def enhance(self, input_path: str, output_path: Optional[str] = None,
normalize_loudness: bool = True) -> str:
def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
"""
Audio denoising enhancement
Args:
@@ -57,7 +53,7 @@ class ZipEnhancer:
raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
# Create temporary file if no output path is specified
if output_path is None:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
output_path = tmp_file.name
try:
# Perform denoising processing
@@ -73,4 +69,4 @@ class ZipEnhancer:
os.unlink(output_path)
except OSError:
pass
raise RuntimeError(f"Audio denoising processing failed: {e}")
raise RuntimeError(f"Audio denoising processing failed: {e}")