update voxcpm2
This commit is contained in:
+46
-27
@@ -13,11 +13,11 @@ import soundfile as sf
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Validators
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
@@ -53,12 +53,11 @@ def validate_ranges(args, parser):
|
||||
# Model loading
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def load_model(args) -> VoxCPM:
|
||||
print("Loading VoxCPM model...", file=sys.stderr)
|
||||
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||
"ZIPENHANCER_MODEL_PATH", None
|
||||
)
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get("ZIPENHANCER_MODEL_PATH", None)
|
||||
|
||||
# Build LoRA config if provided
|
||||
lora_config = None
|
||||
@@ -119,22 +118,29 @@ def load_model(args) -> VoxCPM:
|
||||
# Commands
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def cmd_clone(args):
|
||||
if not args.text:
|
||||
sys.exit("Error: Please provide --text for synthesis")
|
||||
|
||||
if not args.prompt_audio or not args.prompt_text:
|
||||
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
||||
has_prompt = args.prompt_audio and args.prompt_text
|
||||
has_ref = args.reference_audio is not None
|
||||
if not has_prompt and not has_ref:
|
||||
sys.exit("Error: Voice cloning requires --prompt-audio + --prompt-text, or --reference-audio, or both")
|
||||
|
||||
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
||||
if args.prompt_audio:
|
||||
validate_file_exists(args.prompt_audio, "prompt audio file")
|
||||
if args.reference_audio:
|
||||
validate_file_exists(args.reference_audio, "reference audio file")
|
||||
output_path = validate_output_path(args.output)
|
||||
|
||||
model = load_model(args)
|
||||
|
||||
audio_array = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=str(prompt_audio_path),
|
||||
prompt_text=args.prompt_text,
|
||||
prompt_wav_path=args.prompt_audio if has_prompt else None,
|
||||
prompt_text=args.prompt_text if has_prompt else None,
|
||||
reference_wav_path=args.reference_audio,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
@@ -185,7 +191,11 @@ def cmd_batch(args):
|
||||
|
||||
prompt_audio_path = None
|
||||
if args.prompt_audio:
|
||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "prompt audio file"))
|
||||
|
||||
reference_audio_path = None
|
||||
if args.reference_audio:
|
||||
reference_audio_path = str(validate_file_exists(args.reference_audio, "reference audio file"))
|
||||
|
||||
success_count = 0
|
||||
|
||||
@@ -195,10 +205,11 @@ def cmd_batch(args):
|
||||
text=text,
|
||||
prompt_wav_path=prompt_audio_path,
|
||||
prompt_text=args.prompt_text,
|
||||
reference_wav_path=reference_audio_path,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
denoise=args.denoise and prompt_audio_path is not None,
|
||||
denoise=args.denoise and (prompt_audio_path is not None or reference_audio_path is not None),
|
||||
)
|
||||
|
||||
output_file = output_dir / f"output_{i:03d}.wav"
|
||||
@@ -218,6 +229,7 @@ def cmd_batch(args):
|
||||
# Parser
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _build_unified_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
|
||||
@@ -236,34 +248,40 @@ Examples:
|
||||
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
|
||||
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
|
||||
|
||||
# Prompt
|
||||
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
|
||||
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
|
||||
# Prompt / Reference
|
||||
parser.add_argument(
|
||||
"--prompt-audio", "-pa", help="Prompt audio file path (continuation mode, requires --prompt-text)"
|
||||
)
|
||||
parser.add_argument("--prompt-text", "-pt", help="Text corresponding to the prompt audio")
|
||||
parser.add_argument(
|
||||
"--reference-audio", "-ra", help="Reference audio for voice cloning (isolated mode, VoxCPM2 only)"
|
||||
)
|
||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt/reference speech enhancement")
|
||||
|
||||
# Generation parameters
|
||||
parser.add_argument("--cfg-value", type=float, default=2.0,
|
||||
help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)")
|
||||
parser.add_argument("--inference-timesteps", type=int, default=10,
|
||||
help="Inference steps (int, 1–100, default: 10)")
|
||||
parser.add_argument(
|
||||
"--cfg-value", type=float, default=2.0, help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)"
|
||||
)
|
||||
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (int, 1–100, default: 10)")
|
||||
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
||||
|
||||
# Model loading
|
||||
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
|
||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
|
||||
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
|
||||
parser.add_argument(
|
||||
"--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (default: openbmb/VoxCPM1.5)"
|
||||
)
|
||||
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
||||
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
|
||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||
parser.add_argument("--zipenhancer-path", type=str,
|
||||
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
|
||||
parser.add_argument(
|
||||
"--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)"
|
||||
)
|
||||
|
||||
# LoRA
|
||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
|
||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
|
||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0,
|
||||
help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
||||
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
|
||||
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
|
||||
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
|
||||
@@ -275,6 +293,7 @@ Examples:
|
||||
# Entrypoint
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def main():
|
||||
parser = _build_unified_parser()
|
||||
args = parser.parse_args()
|
||||
@@ -296,8 +315,8 @@ def main():
|
||||
if not args.text or not args.output:
|
||||
parser.error("Single-sample mode requires --text and --output")
|
||||
|
||||
# Clone mode
|
||||
if args.prompt_audio or args.prompt_text:
|
||||
# Clone mode (prompt continuation, reference isolation, or both)
|
||||
if args.prompt_audio or args.prompt_text or args.reference_audio:
|
||||
return cmd_clone(args)
|
||||
|
||||
# Direct synthesis
|
||||
|
||||
+151
-100
@@ -1,21 +1,25 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import json
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from typing import Generator, Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||
from .model.voxcpm2 import VoxCPM2Model
|
||||
|
||||
|
||||
class VoxCPM:
|
||||
def __init__(self,
|
||||
voxcpm_model_path : str,
|
||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser : bool = True,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
voxcpm_model_path: str,
|
||||
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser: bool = True,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
"""Initialize VoxCPM TTS pipeline.
|
||||
|
||||
Args:
|
||||
@@ -26,13 +30,16 @@ class VoxCPM:
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||
"""
|
||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
|
||||
|
||||
print(
|
||||
f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# If lora_weights_path is provided but no lora_config, create a default one
|
||||
if lora_weights_path is not None and lora_config is None:
|
||||
lora_config = LoRAConfig(
|
||||
@@ -41,18 +48,33 @@ class VoxCPM:
|
||||
enable_proj=False,
|
||||
)
|
||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
|
||||
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
|
||||
|
||||
# Determine model type from config.json architecture field
|
||||
config_path = os.path.join(voxcpm_model_path, "config.json")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
arch = config.get("architecture", "voxcpm").lower()
|
||||
|
||||
if arch == "voxcpm2":
|
||||
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
print("Loaded VoxCPM2Model", file=sys.stderr)
|
||||
elif arch == "voxcpm":
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
print("Loaded VoxCPMModel", file=sys.stderr)
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {arch}")
|
||||
|
||||
# Load LoRA weights if path is provided
|
||||
if lora_weights_path is not None:
|
||||
print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
|
||||
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
|
||||
|
||||
|
||||
self.text_normalizer = None
|
||||
self.denoiser = None
|
||||
if enable_denoiser and zipenhancer_model_path is not None:
|
||||
from .zipenhancer import ZipEnhancer
|
||||
|
||||
self.denoiser = ZipEnhancer(zipenhancer_model_path)
|
||||
else:
|
||||
self.denoiser = None
|
||||
@@ -64,17 +86,18 @@ class VoxCPM:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
hf_model_id: str = "openbmb/VoxCPM1.5",
|
||||
load_denoiser: bool = True,
|
||||
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
def from_pretrained(
|
||||
cls,
|
||||
hf_model_id: str = "openbmb/VoxCPM2",
|
||||
load_denoiser: bool = True,
|
||||
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
||||
|
||||
Args:
|
||||
@@ -86,7 +109,7 @@ class VoxCPM:
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created with
|
||||
enable_lm=True and enable_dit=True.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
@@ -106,7 +129,7 @@ class VoxCPM:
|
||||
repo_id = hf_model_id
|
||||
if not repo_id:
|
||||
raise ValueError("You must provide hf_model_id")
|
||||
|
||||
|
||||
# Load from local path if provided
|
||||
if os.path.isdir(repo_id):
|
||||
local_path = repo_id
|
||||
@@ -134,118 +157,146 @@ class VoxCPM:
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
|
||||
def _generate(self,
|
||||
text : str,
|
||||
prompt_wav_path : str = None,
|
||||
prompt_text : str = None,
|
||||
cfg_value : float = 2.0,
|
||||
inference_timesteps : int = 10,
|
||||
min_len : int = 2,
|
||||
max_len : int = 4096,
|
||||
normalize : bool = False,
|
||||
denoise : bool = False,
|
||||
retry_badcase : bool = True,
|
||||
retry_badcase_max_times : int = 3,
|
||||
retry_badcase_ratio_threshold : float = 6.0,
|
||||
streaming: bool = False,
|
||||
) -> Generator[np.ndarray, None, None]:
|
||||
def _generate(
|
||||
self,
|
||||
text: str,
|
||||
prompt_wav_path: str = None,
|
||||
prompt_text: str = None,
|
||||
reference_wav_path: str = None,
|
||||
cfg_value: float = 2.0,
|
||||
inference_timesteps: int = 10,
|
||||
min_len: int = 2,
|
||||
max_len: int = 4096,
|
||||
normalize: bool = False,
|
||||
denoise: bool = False,
|
||||
retry_badcase: bool = True,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,
|
||||
streaming: bool = False,
|
||||
) -> Generator[np.ndarray, None, None]:
|
||||
"""Synthesize speech for the given text and return a single waveform.
|
||||
|
||||
This method optionally builds and reuses a prompt cache. If an external
|
||||
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
|
||||
used for all sub-sentences. Otherwise, the prompt cache is built from
|
||||
the first generated result and reused for the remaining text chunks.
|
||||
|
||||
Args:
|
||||
text: Input text. Can include newlines; each non-empty line is
|
||||
treated as a sub-sentence.
|
||||
prompt_wav_path: Path to a reference audio file for prompting.
|
||||
text: Input text to synthesize.
|
||||
prompt_wav_path: Path to prompt audio for continuation mode.
|
||||
Must be paired with ``prompt_text``.
|
||||
prompt_text: Text content corresponding to the prompt audio.
|
||||
reference_wav_path: Path to reference audio for voice cloning
|
||||
(structurally isolated via ref_audio tokens). Can be used
|
||||
alone or combined with ``prompt_wav_path`` + ``prompt_text``.
|
||||
cfg_value: Guidance scale for the generation model.
|
||||
inference_timesteps: Number of inference steps.
|
||||
min_len: Minimum audio length.
|
||||
max_len: Maximum token length during generation.
|
||||
normalize: Whether to run text normalization before generation.
|
||||
denoise: Whether to denoise the prompt audio if a denoiser is
|
||||
available.
|
||||
denoise: Whether to denoise the prompt/reference audio if a
|
||||
denoiser is available.
|
||||
retry_badcase: Whether to retry badcase.
|
||||
retry_badcase_max_times: Maximum number of times to retry badcase.
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
|
||||
streaming: Whether to return a generator of audio chunks.
|
||||
Returns:
|
||||
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||
Yields audio chunks for each generations step if ``streaming=True``,
|
||||
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||
Yields audio chunks for each generation step if ``streaming=True``,
|
||||
otherwise yields a single array containing the final audio.
|
||||
"""
|
||||
if not text.strip() or not isinstance(text, str):
|
||||
raise ValueError("target text must be a non-empty string")
|
||||
|
||||
|
||||
if prompt_wav_path is not None:
|
||||
if not os.path.exists(prompt_wav_path):
|
||||
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
|
||||
|
||||
|
||||
if reference_wav_path is not None:
|
||||
if not os.path.exists(reference_wav_path):
|
||||
raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")
|
||||
|
||||
if (prompt_wav_path is None) != (prompt_text is None):
|
||||
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
|
||||
|
||||
|
||||
is_v2 = isinstance(self.tts_model, VoxCPM2Model)
|
||||
if reference_wav_path is not None and not is_v2:
|
||||
raise ValueError("reference_wav_path is only supported with VoxCPM2 models")
|
||||
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
temp_prompt_wav_path = None
|
||||
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
temp_files = []
|
||||
|
||||
try:
|
||||
if prompt_wav_path is not None and prompt_text is not None:
|
||||
if denoise and self.denoiser is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
temp_prompt_wav_path = tmp_file.name
|
||||
self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
|
||||
prompt_wav_path = temp_prompt_wav_path
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text
|
||||
)
|
||||
actual_prompt_path = prompt_wav_path
|
||||
actual_ref_path = reference_wav_path
|
||||
|
||||
if denoise and self.denoiser is not None:
|
||||
if prompt_wav_path is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
temp_files.append(tmp.name)
|
||||
self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
|
||||
actual_prompt_path = temp_files[-1]
|
||||
if reference_wav_path is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
temp_files.append(tmp.name)
|
||||
self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
|
||||
actual_ref_path = temp_files[-1]
|
||||
|
||||
if actual_prompt_path is not None or actual_ref_path is not None:
|
||||
if is_v2:
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=actual_prompt_path,
|
||||
reference_wav_path=actual_ref_path,
|
||||
)
|
||||
else:
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=actual_prompt_path,
|
||||
)
|
||||
else:
|
||||
fixed_prompt_cache = None # will be built from the first inference
|
||||
|
||||
fixed_prompt_cache = None
|
||||
|
||||
if normalize:
|
||||
if self.text_normalizer is None:
|
||||
from .utils.text_normalize import TextNormalizer
|
||||
|
||||
self.text_normalizer = TextNormalizer()
|
||||
text = self.text_normalizer.normalize(text)
|
||||
|
||||
|
||||
generate_result = self.tts_model._generate_with_prompt_cache(
|
||||
target_text=text,
|
||||
prompt_cache=fixed_prompt_cache,
|
||||
min_len=min_len,
|
||||
max_len=max_len,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
target_text=text,
|
||||
prompt_cache=fixed_prompt_cache,
|
||||
min_len=min_len,
|
||||
max_len=max_len,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
for wav, _, _ in generate_result:
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
|
||||
|
||||
finally:
|
||||
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
|
||||
try:
|
||||
os.unlink(temp_prompt_wav_path)
|
||||
except OSError:
|
||||
pass
|
||||
for tmp_path in temp_files:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# LoRA Interface (delegated to VoxCPMModel)
|
||||
# ------------------------------------------------------------------ #
|
||||
def load_lora(self, lora_weights_path: str) -> tuple:
|
||||
"""Load LoRA weights from a checkpoint file.
|
||||
|
||||
|
||||
Args:
|
||||
lora_weights_path: Path to LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt).
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model was not initialized with LoRA config.
|
||||
"""
|
||||
@@ -259,23 +310,23 @@ class VoxCPM:
|
||||
def unload_lora(self):
|
||||
"""Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
|
||||
self.tts_model.reset_lora_weights()
|
||||
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
|
||||
"""Enable or disable LoRA layers without unloading weights.
|
||||
|
||||
|
||||
Args:
|
||||
enabled: If True, LoRA layers are active; if False, only base model is used.
|
||||
"""
|
||||
self.tts_model.set_lora_enabled(enabled)
|
||||
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
|
||||
"""Get current LoRA parameters state dict.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: State dict containing all LoRA parameters (lora_A, lora_B).
|
||||
"""
|
||||
return self.tts_model.get_lora_state_dict()
|
||||
|
||||
|
||||
@property
|
||||
def lora_enabled(self) -> bool:
|
||||
"""Check if LoRA is currently configured."""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .voxcpm import VoxCPMModel
|
||||
from .voxcpm2 import VoxCPM2Model
|
||||
|
||||
__all__ = ["VoxCPMModel"]
|
||||
__all__ = ["VoxCPMModel", "VoxCPM2Model"]
|
||||
|
||||
+18
-19
@@ -5,17 +5,17 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
|
||||
|
||||
|
||||
This function creates a wrapper around the provided tokenizer that automatically
|
||||
splits multi-character Chinese tokens into individual characters. This is useful
|
||||
for ensuring consistent tokenization of Chinese text.
|
||||
|
||||
|
||||
Args:
|
||||
tokenizer: The base tokenizer to wrap
|
||||
|
||||
|
||||
Returns:
|
||||
A CharTokenizerWrapper instance that handles multi-character Chinese tokens
|
||||
|
||||
|
||||
Example:
|
||||
>>> from transformers import LlamaTokenizerFast
|
||||
>>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
|
||||
@@ -24,20 +24,19 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""
|
||||
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
|
||||
multichar_tokens = {
|
||||
token for token in tokenizer.vocab.keys()
|
||||
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
||||
token for token in tokenizer.vocab.keys() if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
||||
}
|
||||
|
||||
class CharTokenizerWrapper:
|
||||
"""Wrapper class for tokenizers that handles multi-character Chinese tokens.
|
||||
|
||||
|
||||
This wrapper automatically splits multi-character Chinese tokens into
|
||||
individual characters while preserving the original tokenizer's interface.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
|
||||
"""Initialize the wrapper with a base tokenizer.
|
||||
|
||||
|
||||
Args:
|
||||
base_tokenizer: The tokenizer to wrap
|
||||
"""
|
||||
@@ -46,14 +45,14 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
|
||||
def tokenize(self, text: str, **kwargs) -> List[str]:
|
||||
"""Tokenize text and split multi-character Chinese tokens into single characters.
|
||||
|
||||
|
||||
Args:
|
||||
text: Input text to tokenize
|
||||
**kwargs: Additional arguments passed to the base tokenizer
|
||||
|
||||
|
||||
Returns:
|
||||
List of processed tokens with multi-character Chinese tokens split
|
||||
|
||||
|
||||
Example:
|
||||
>>> wrapper = CharTokenizerWrapper(tokenizer)
|
||||
>>> tokens = wrapper.tokenize("你好世界")
|
||||
@@ -61,10 +60,10 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
raise TypeError(f"Expected string input, got {type(text)}")
|
||||
|
||||
|
||||
tokens = self.tokenizer.tokenize(text, **kwargs)
|
||||
processed = []
|
||||
|
||||
|
||||
for token in tokens:
|
||||
# Remove possible subword prefix
|
||||
clean_token = token.replace("▁", "")
|
||||
@@ -75,22 +74,22 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
processed.extend(chars)
|
||||
else:
|
||||
processed.append(token)
|
||||
|
||||
|
||||
return processed
|
||||
|
||||
def __call__(self, text: str, **kwargs) -> List[int]:
|
||||
"""Call the tokenizer and return token IDs.
|
||||
|
||||
|
||||
This method provides the same interface as the original tokenizer
|
||||
but with multi-character Chinese token handling.
|
||||
|
||||
|
||||
Args:
|
||||
text: Input text to tokenize
|
||||
**kwargs: Additional arguments passed to the base tokenizer
|
||||
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not a string
|
||||
ValueError: If tokenization fails
|
||||
|
||||
+128
-115
@@ -24,7 +24,6 @@ from typing import Tuple, Union, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import warnings
|
||||
from einops import rearrange
|
||||
@@ -32,6 +31,7 @@ from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
@@ -84,9 +84,9 @@ class VoxCPMConfig(BaseModel):
|
||||
|
||||
|
||||
class LoRAConfig(BaseModel):
|
||||
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
|
||||
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
|
||||
enable_proj: bool = False # Apply LoRA to projection Linear layers
|
||||
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
|
||||
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
|
||||
enable_proj: bool = False # Apply LoRA to projection Linear layers
|
||||
|
||||
r: int = 8
|
||||
alpha: int = 16
|
||||
@@ -165,10 +165,10 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
# Projection layers
|
||||
self.fsq_layer = ScalarQuantizationLayer(
|
||||
config.lm_config.hidden_size,
|
||||
config.lm_config.hidden_size,
|
||||
config.scalar_quantization_latent_dim,
|
||||
config.scalar_quantization_scale
|
||||
config.lm_config.hidden_size,
|
||||
config.lm_config.hidden_size,
|
||||
config.scalar_quantization_latent_dim,
|
||||
config.scalar_quantization_scale,
|
||||
)
|
||||
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
|
||||
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||
@@ -196,9 +196,7 @@ class VoxCPMModel(nn.Module):
|
||||
# LM: base_lm + residual_lm
|
||||
if cfg.enable_lm:
|
||||
for lm in [self.base_lm, self.residual_lm]:
|
||||
apply_lora_to_named_linear_modules(
|
||||
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
|
||||
)
|
||||
apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)
|
||||
|
||||
# DiT: feat_decoder.estimator
|
||||
if cfg.enable_dit:
|
||||
@@ -209,6 +207,7 @@ class VoxCPMModel(nn.Module):
|
||||
# 投影层
|
||||
if cfg.enable_proj:
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
|
||||
for attr_name in cfg.target_proj_modules:
|
||||
module = getattr(self, attr_name, None)
|
||||
if isinstance(module, nn.Linear):
|
||||
@@ -221,13 +220,17 @@ class VoxCPMModel(nn.Module):
|
||||
if self.device != "cuda":
|
||||
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
||||
try:
|
||||
import triton
|
||||
import triton # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError("triton is not installed")
|
||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(
|
||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(
|
||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
|
||||
return self
|
||||
@@ -313,9 +316,11 @@ class VoxCPMModel(nn.Module):
|
||||
mu=dit_hidden,
|
||||
patch_size=self.patch_size,
|
||||
cond=feat_cond_for_sample,
|
||||
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||
else 10,
|
||||
n_timesteps=(
|
||||
self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||
else 10
|
||||
),
|
||||
)
|
||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
@@ -331,7 +336,6 @@ class VoxCPMModel(nn.Module):
|
||||
def _dtype(self):
|
||||
return get_dtype(self.config.dtype)
|
||||
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
@@ -350,7 +354,7 @@ class VoxCPMModel(nn.Module):
|
||||
cfg_value: float = 2.0,
|
||||
retry_badcase: bool = False,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
|
||||
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
|
||||
streaming: bool = False,
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
if retry_badcase and streaming:
|
||||
@@ -394,7 +398,7 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
@@ -435,7 +439,7 @@ class VoxCPMModel(nn.Module):
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
|
||||
|
||||
retry_badcase_times = 0
|
||||
while retry_badcase_times < retry_badcase_max_times:
|
||||
inference_result = self._inference(
|
||||
@@ -444,7 +448,9 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
max_len=min(
|
||||
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
|
||||
), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -460,18 +466,21 @@ class VoxCPMModel(nn.Module):
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||
print(
|
||||
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
break
|
||||
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
|
||||
@torch.inference_mode()
|
||||
def build_prompt_cache(
|
||||
self,
|
||||
@@ -480,11 +489,11 @@ class VoxCPMModel(nn.Module):
|
||||
):
|
||||
"""
|
||||
Build prompt cache for subsequent fast generation.
|
||||
|
||||
|
||||
Args:
|
||||
prompt_text: prompt text (required)
|
||||
prompt_wav_path: prompt audio path (required)
|
||||
|
||||
|
||||
Returns:
|
||||
prompt_cache: dict with prompt_text (raw text) and audio features.
|
||||
Text tokenization will be done during generation for consistency.
|
||||
@@ -496,7 +505,7 @@ class VoxCPMModel(nn.Module):
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
|
||||
@@ -514,16 +523,17 @@ class VoxCPMModel(nn.Module):
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0) # (D, T, P)
|
||||
).permute(
|
||||
1, 2, 0
|
||||
) # (D, T, P)
|
||||
# build prompt cache - only save raw text and audio features
|
||||
prompt_cache = {
|
||||
"prompt_text": prompt_text,
|
||||
"audio_feat": audio_feat,
|
||||
}
|
||||
|
||||
|
||||
return prompt_cache
|
||||
|
||||
|
||||
def merge_prompt_cache(
|
||||
self,
|
||||
original_cache: dict,
|
||||
@@ -532,12 +542,12 @@ class VoxCPMModel(nn.Module):
|
||||
):
|
||||
"""
|
||||
Merge original prompt cache with newly generated content to stabilize voice.
|
||||
|
||||
|
||||
Args:
|
||||
original_cache: original prompt cache
|
||||
new_text: newly generated text
|
||||
new_text: newly generated text
|
||||
new_audio_feat: newly generated audio features
|
||||
|
||||
|
||||
Returns:
|
||||
merged_cache: merged cache with prompt_text and audio_feat
|
||||
"""
|
||||
@@ -557,20 +567,17 @@ class VoxCPMModel(nn.Module):
|
||||
"prompt_text": merged_prompt_text,
|
||||
"audio_feat": merged_audio_feat,
|
||||
}
|
||||
|
||||
|
||||
return merged_cache
|
||||
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _generate_with_prompt_cache(
|
||||
self,
|
||||
@@ -588,7 +595,7 @@ class VoxCPMModel(nn.Module):
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
||||
"""
|
||||
Generate audio using pre-built prompt cache.
|
||||
|
||||
|
||||
Args:
|
||||
target_text: Text to convert to speech
|
||||
prompt_cache: Cache built by build_prompt_cache (can be None)
|
||||
@@ -601,7 +608,7 @@ class VoxCPMModel(nn.Module):
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
|
||||
streaming: Whether to return a generator of audio chunks
|
||||
streaming_prefix_len: Number of prefix audio patches to use for streaming mode
|
||||
|
||||
|
||||
Returns:
|
||||
Generator of Tuple containing:
|
||||
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
|
||||
@@ -619,7 +626,7 @@ class VoxCPMModel(nn.Module):
|
||||
prompt_audio_feat = prompt_cache["audio_feat"]
|
||||
prompt_text = prompt_cache["prompt_text"]
|
||||
text = prompt_text + target_text
|
||||
|
||||
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
text_token = torch.cat(
|
||||
[
|
||||
@@ -632,7 +639,7 @@ class VoxCPMModel(nn.Module):
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
|
||||
audio_length = prompt_audio_feat.size(0)
|
||||
@@ -645,14 +652,18 @@ class VoxCPMModel(nn.Module):
|
||||
)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
text_mask = (
|
||||
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
audio_mask = (
|
||||
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
|
||||
text_token = text_token.unsqueeze(0).to(self.device)
|
||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
|
||||
# run inference
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
retry_badcase_times = 0
|
||||
@@ -663,7 +674,9 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
max_len=min(
|
||||
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
|
||||
), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -674,17 +687,16 @@ class VoxCPMModel(nn.Module):
|
||||
for latent_pred, pred_audio_feat in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
||||
yield (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||
print(
|
||||
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
@@ -695,18 +707,14 @@ class VoxCPMModel(nn.Module):
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
if audio_mask.sum().item() > 0:
|
||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
||||
else:
|
||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
||||
yield (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return next(self._inference(*args, streaming=False, **kwargs))
|
||||
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._inference(*args, streaming=True, **kwargs)
|
||||
|
||||
@@ -725,10 +733,10 @@ class VoxCPMModel(nn.Module):
|
||||
streaming_prefix_len: int = 3,
|
||||
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
||||
"""Core inference method for audio generation.
|
||||
|
||||
|
||||
This is the main inference loop that generates audio features
|
||||
using the language model and diffusion transformer.
|
||||
|
||||
|
||||
Args:
|
||||
text: Input text tokens
|
||||
text_mask: Mask for text tokens
|
||||
@@ -739,7 +747,7 @@ class VoxCPMModel(nn.Module):
|
||||
inference_timesteps: Number of diffusion steps
|
||||
cfg_value: Classifier-free guidance value
|
||||
streaming: Whether to yield each step latent feature or just the final result
|
||||
|
||||
|
||||
Returns:
|
||||
Generator of Tuple containing:
|
||||
- Predicted latent feature at the current step if ``streaming=True``, else final latent features
|
||||
@@ -749,12 +757,12 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
scale_emb = self.config.lm_config.scale_emb
|
||||
else:
|
||||
scale_emb = 1.0
|
||||
|
||||
|
||||
text_embed = self.base_lm.embed_tokens(text) * scale_emb
|
||||
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
|
||||
|
||||
@@ -778,11 +786,10 @@ class VoxCPMModel(nn.Module):
|
||||
is_causal=True,
|
||||
)
|
||||
self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
|
||||
|
||||
|
||||
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
|
||||
lm_hidden = enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
|
||||
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
|
||||
is_causal=True,
|
||||
@@ -790,7 +797,6 @@ class VoxCPMModel(nn.Module):
|
||||
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
|
||||
residual_hidden = residual_enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
for i in tqdm(range(max_len)):
|
||||
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
|
||||
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
|
||||
@@ -805,10 +811,10 @@ class VoxCPMModel(nn.Module):
|
||||
).transpose(
|
||||
1, 2
|
||||
) # [b, p, d]
|
||||
|
||||
|
||||
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.enc_to_lm_proj(curr_embed)
|
||||
|
||||
|
||||
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
|
||||
prefix_feat_cond = pred_feat
|
||||
|
||||
@@ -816,58 +822,70 @@ class VoxCPMModel(nn.Module):
|
||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
|
||||
yield feat_pred, pred_feat_seq
|
||||
|
||||
|
||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||
if i > min_len and stop_flag == 1:
|
||||
break
|
||||
|
||||
|
||||
lm_hidden = self.base_lm.forward_step(
|
||||
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
|
||||
).clone()
|
||||
|
||||
|
||||
lm_hidden = self.fsq_layer(lm_hidden)
|
||||
residual_hidden = self.residual_lm.forward_step(
|
||||
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
|
||||
lm_hidden + curr_embed[:, 0, :],
|
||||
torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
|
||||
).clone()
|
||||
|
||||
|
||||
if not streaming:
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
audio_vae_config = getattr(config, 'audio_vae_config', None)
|
||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
||||
vae_state_dict = torch.load(
|
||||
os.path.join(path, "audiovae.pth"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
# Try to load AudioVAE from safetensors first, fallback to pytorch
|
||||
audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
|
||||
audiovae_pth_path = os.path.join(path, "audiovae.pth")
|
||||
if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||
print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
|
||||
vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
|
||||
elif os.path.exists(audiovae_pth_path):
|
||||
print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
|
||||
checkpoint = torch.load(
|
||||
audiovae_pth_path,
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
vae_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||
)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
else: # training mode
|
||||
else: # training mode
|
||||
for name, param in model.named_parameters():
|
||||
if "audio_vae" in name: # freeze VAE weights
|
||||
if "audio_vae" in name: # freeze VAE weights
|
||||
param.requires_grad = False
|
||||
continue
|
||||
if lora_config is not None:
|
||||
if "lora" not in name: # freeze non-LoRA weights
|
||||
if "lora" not in name: # freeze non-LoRA weights
|
||||
param.requires_grad = False
|
||||
model.audio_vae = model.audio_vae.to(torch.float32)
|
||||
|
||||
|
||||
# Try to load from safetensors first, fallback to pytorch_model.bin
|
||||
safetensors_path = os.path.join(path, "model.safetensors")
|
||||
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
||||
|
||||
|
||||
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||
print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
|
||||
model_state_dict = load_file(safetensors_path)
|
||||
@@ -880,13 +898,11 @@ class VoxCPMModel(nn.Module):
|
||||
)
|
||||
model_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
|
||||
)
|
||||
|
||||
raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
|
||||
|
||||
for kw, val in vae_state_dict.items():
|
||||
model_state_dict[f"audio_vae.{kw}"] = val
|
||||
|
||||
|
||||
# LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
|
||||
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
|
||||
model.load_state_dict(model_state_dict, strict=False)
|
||||
@@ -900,6 +916,7 @@ class VoxCPMModel(nn.Module):
|
||||
def _iter_lora_modules(self):
|
||||
"""Iterate over all LoRA modules."""
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
yield module
|
||||
@@ -909,7 +926,7 @@ class VoxCPMModel(nn.Module):
|
||||
Load LoRA weights from file, supports calling after torch.compile.
|
||||
Uses named_parameters() to handle compile's _orig_mod wrapper.
|
||||
Supports both safetensors and pytorch formats.
|
||||
|
||||
|
||||
Args:
|
||||
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
|
||||
device: Target device, defaults to model's current device
|
||||
@@ -917,18 +934,18 @@ class VoxCPMModel(nn.Module):
|
||||
tuple: (loaded_keys, skipped_keys)
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
device = device or self.device
|
||||
lora_path = Path(lora_path)
|
||||
|
||||
lora_p = Path(lora_path)
|
||||
|
||||
# Try safetensors first, then fallback to .ckpt
|
||||
if lora_path.is_dir():
|
||||
safetensors_file = lora_path / "lora_weights.safetensors"
|
||||
ckpt_file = lora_path / "lora_weights.ckpt"
|
||||
if lora_p.is_dir():
|
||||
safetensors_file = lora_p / "lora_weights.safetensors"
|
||||
ckpt_file = lora_p / "lora_weights.ckpt"
|
||||
else:
|
||||
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
|
||||
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
|
||||
|
||||
safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
|
||||
ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
|
||||
|
||||
# Load from safetensors if available
|
||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||
state_dict = load_file(str(safetensors_file), device=device)
|
||||
@@ -936,14 +953,12 @@ class VoxCPMModel(nn.Module):
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
|
||||
)
|
||||
|
||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||
|
||||
# Build param mapping (handle torch.compile's _orig_mod prefix)
|
||||
model_params = dict(self.named_parameters())
|
||||
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
|
||||
|
||||
|
||||
loaded_keys, skipped_keys = [], []
|
||||
for key, value in state_dict.items():
|
||||
target_key = key if key in model_params else key_mapping.get(key)
|
||||
@@ -952,7 +967,7 @@ class VoxCPMModel(nn.Module):
|
||||
loaded_keys.append(key)
|
||||
else:
|
||||
skipped_keys.append(key)
|
||||
|
||||
|
||||
return loaded_keys, skipped_keys
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
|
||||
@@ -967,6 +982,4 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
|
||||
"""Get all LoRA parameters (lora_A/lora_B)."""
|
||||
return {name: param.data.clone()
|
||||
for name, param in self.named_parameters()
|
||||
if "lora_" in name}
|
||||
return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1 +1,2 @@
|
||||
from .audio_vae import AudioVAE, AudioVAEConfig
|
||||
from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import List, Union, Optional
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -285,12 +285,12 @@ class AudioVAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[AudioVAEConfig] = None,
|
||||
config: AudioVAEConfig = None,
|
||||
):
|
||||
# 如果没有传入config,使用默认配置
|
||||
if config is None:
|
||||
config = AudioVAEConfig()
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
encoder_dim = config.encoder_dim
|
||||
@@ -301,7 +301,7 @@ class AudioVAE(nn.Module):
|
||||
depthwise = config.depthwise
|
||||
sample_rate = config.sample_rate
|
||||
use_noise_block = config.use_noise_block
|
||||
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
|
||||
@@ -0,0 +1,486 @@
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
class CausalConv1d(nn.Conv1d):
|
||||
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
self.__output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
|
||||
return super().forward(x_pad)
|
||||
|
||||
|
||||
class CausalTransposeConv1d(nn.ConvTranspose1d):
|
||||
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
self.__output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
|
||||
|
||||
|
||||
def WNCausalConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNCausalTransposeConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
# Scripting this brings model speed up 1.4x
|
||||
@torch.jit.script
|
||||
def snake(x, alpha):
|
||||
shape = x.shape
|
||||
x = x.reshape(shape[0], shape[1], -1)
|
||||
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
||||
x = x.reshape(shape)
|
||||
return x
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return snake(x, self.alpha)
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class CausalResidualUnit(nn.Module):
|
||||
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel,
|
||||
dilation=dilation,
|
||||
padding=pad,
|
||||
groups=groups,
|
||||
),
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(dim, dim, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.block(x)
|
||||
pad = (x.shape[-1] - y.shape[-1]) // 2
|
||||
assert pad == 0
|
||||
if pad > 0:
|
||||
x = x[..., pad:-pad]
|
||||
return x + y
|
||||
|
||||
|
||||
class CausalEncoderBlock(nn.Module):
|
||||
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
|
||||
super().__init__()
|
||||
input_dim = input_dim or output_dim // 2
|
||||
self.block = nn.Sequential(
|
||||
CausalResidualUnit(input_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=9, groups=groups),
|
||||
Snake1d(input_dim),
|
||||
WNCausalConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class CausalEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 64,
|
||||
latent_dim: int = 32,
|
||||
strides: list = [2, 4, 8, 8],
|
||||
depthwise: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Create first convolution
|
||||
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
|
||||
|
||||
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||
for stride in strides:
|
||||
d_model *= 2
|
||||
groups = d_model // 2 if depthwise else 1
|
||||
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
|
||||
|
||||
groups = d_model if depthwise else 1
|
||||
|
||||
# Create two convolution, for mu and logvar
|
||||
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
|
||||
# Wrap black into nn.Sequential
|
||||
self.block = nn.Sequential(*self.block)
|
||||
self.enc_dim = d_model
|
||||
|
||||
def forward(self, x):
|
||||
hidden_state = self.block(x)
|
||||
return {
|
||||
"hidden_state": hidden_state,
|
||||
"mu": self.fc_mu(hidden_state),
|
||||
"logvar": self.fc_logvar(hidden_state),
|
||||
}
|
||||
|
||||
|
||||
class NoiseBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, T = x.shape
|
||||
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
|
||||
h = self.linear(x)
|
||||
n = noise * h
|
||||
x = x + n
|
||||
return x
|
||||
|
||||
|
||||
class CausalDecoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 16,
|
||||
output_dim: int = 8,
|
||||
stride: int = 1,
|
||||
groups=1,
|
||||
use_noise_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
layers = [
|
||||
Snake1d(input_dim),
|
||||
WNCausalTransposeConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
]
|
||||
if use_noise_block:
|
||||
layers.append(NoiseBlock(output_dim))
|
||||
layers.extend(
|
||||
[
|
||||
CausalResidualUnit(output_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=9, groups=groups),
|
||||
]
|
||||
)
|
||||
self.block = nn.Sequential(*layers)
|
||||
self.input_channels = input_dim
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class TransposeLastTwoDim(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.transpose(x, -1, -2)
|
||||
|
||||
|
||||
class SampleRateConditionLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
sr_bin_buckets: int = None,
|
||||
cond_type: str = "scale_bias",
|
||||
cond_dim: int = 128,
|
||||
out_layer: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cond_type, out_layer_in_dim = cond_type, input_dim
|
||||
|
||||
if cond_type == "scale_bias":
|
||||
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
nn.init.ones_(self.scale_embed.weight)
|
||||
nn.init.zeros_(self.bias_embed.weight)
|
||||
elif cond_type == "scale_bias_init":
|
||||
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
nn.init.normal_(self.scale_embed.weight, mean=1)
|
||||
nn.init.normal_(self.bias_embed.weight)
|
||||
elif cond_type == "add":
|
||||
self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
nn.init.normal_(self.cond_embed.weight)
|
||||
elif cond_type == "concat":
|
||||
self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
|
||||
assert out_layer, "out_layer must be True for concat cond_type"
|
||||
out_layer_in_dim = input_dim + cond_dim
|
||||
else:
|
||||
raise ValueError(f"Invalid cond_type: {cond_type}")
|
||||
|
||||
if out_layer:
|
||||
self.out_layer = nn.Sequential(
|
||||
Snake1d(out_layer_in_dim),
|
||||
WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
|
||||
)
|
||||
else:
|
||||
self.out_layer = nn.Identity()
|
||||
|
||||
def forward(self, x, sr_cond):
|
||||
if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
|
||||
x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
|
||||
elif self.cond_type == "add":
|
||||
x = x + self.cond_embed(sr_cond).unsqueeze(-1)
|
||||
elif self.cond_type == "concat":
|
||||
x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
|
||||
|
||||
return self.out_layer(x)
|
||||
|
||||
|
||||
class CausalDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channel,
|
||||
channels,
|
||||
rates,
|
||||
depthwise: bool = False,
|
||||
d_out: int = 1,
|
||||
use_noise_block: bool = False,
|
||||
sr_bin_boundaries: List[int] = None,
|
||||
cond_type: str = "scale_bias",
|
||||
cond_dim: int = 128,
|
||||
cond_out_layer: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add first conv layer
|
||||
if depthwise:
|
||||
layers = [
|
||||
WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
|
||||
WNCausalConv1d(input_channel, channels, kernel_size=1),
|
||||
]
|
||||
else:
|
||||
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
||||
|
||||
# Add upsampling + MRF blocks
|
||||
for i, stride in enumerate(rates):
|
||||
input_dim = channels // 2**i
|
||||
output_dim = channels // 2 ** (i + 1)
|
||||
groups = output_dim if depthwise else 1
|
||||
layers += [
|
||||
CausalDecoderBlock(
|
||||
input_dim,
|
||||
output_dim,
|
||||
stride,
|
||||
groups=groups,
|
||||
use_noise_block=use_noise_block,
|
||||
)
|
||||
]
|
||||
|
||||
# Add final conv layer
|
||||
layers += [
|
||||
Snake1d(output_dim),
|
||||
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
||||
nn.Tanh(),
|
||||
]
|
||||
|
||||
if sr_bin_boundaries is None:
|
||||
self.model = nn.Sequential(*layers)
|
||||
self.sr_bin_boundaries = None
|
||||
else:
|
||||
self.model = nn.ModuleList(layers)
|
||||
|
||||
self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
|
||||
self.sr_bin_buckets = len(sr_bin_boundaries) + 1
|
||||
|
||||
cond_layers = []
|
||||
for layer in self.model:
|
||||
if layer.__class__.__name__ == "CausalDecoderBlock":
|
||||
cond_layers.append(
|
||||
SampleRateConditionLayer(
|
||||
input_dim=layer.input_channels,
|
||||
sr_bin_buckets=self.sr_bin_buckets,
|
||||
cond_type=cond_type,
|
||||
cond_dim=cond_dim,
|
||||
out_layer=cond_out_layer,
|
||||
)
|
||||
)
|
||||
else:
|
||||
cond_layers.append(None)
|
||||
self.sr_cond_model = nn.ModuleList(cond_layers)
|
||||
|
||||
def get_sr_idx(self, sr):
|
||||
return torch.bucketize(sr, self.sr_bin_boundaries)
|
||||
|
||||
def forward(self, x, sr_cond=None):
|
||||
if self.sr_bin_boundaries is not None:
|
||||
# assert sr_cond is not None
|
||||
sr_cond = self.get_sr_idx(sr_cond)
|
||||
|
||||
for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
|
||||
if sr_cond_layer is not None:
|
||||
x = sr_cond_layer(x, sr_cond)
|
||||
x = layer(x)
|
||||
return x
|
||||
else:
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class AudioVAEConfig(BaseModel):
|
||||
encoder_dim: int = 128
|
||||
encoder_rates: List[int] = [2, 5, 8, 8]
|
||||
latent_dim: int = 64
|
||||
decoder_dim: int = 2048
|
||||
decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
|
||||
depthwise: bool = True
|
||||
sample_rate: int = 16000
|
||||
out_sample_rate: int = 48000
|
||||
use_noise_block: bool = False
|
||||
sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
|
||||
cond_type: str = "scale_bias"
|
||||
cond_dim: int = 128
|
||||
cond_out_layer: bool = False
|
||||
|
||||
|
||||
class AudioVAE(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AudioVAEConfig = None,
|
||||
):
|
||||
# 如果没有传入config,使用默认配置
|
||||
if config is None:
|
||||
config = AudioVAEConfig()
|
||||
|
||||
super().__init__()
|
||||
|
||||
encoder_dim = config.encoder_dim
|
||||
encoder_rates = config.encoder_rates
|
||||
latent_dim = config.latent_dim
|
||||
decoder_dim = config.decoder_dim
|
||||
decoder_rates = config.decoder_rates
|
||||
depthwise = config.depthwise
|
||||
sample_rate = config.sample_rate
|
||||
out_sample_rate = config.out_sample_rate
|
||||
use_noise_block = config.use_noise_block
|
||||
sr_bin_boundaries = config.sr_bin_boundaries
|
||||
cond_type = config.cond_type
|
||||
cond_dim = config.cond_dim
|
||||
cond_out_layer = config.cond_out_layer
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
self.decoder_rates = decoder_rates
|
||||
self.depthwise = depthwise
|
||||
|
||||
self.use_noise_block = use_noise_block
|
||||
|
||||
if latent_dim is None:
|
||||
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.hop_length = np.prod(encoder_rates)
|
||||
self.encoder = CausalEncoder(
|
||||
encoder_dim,
|
||||
latent_dim,
|
||||
encoder_rates,
|
||||
depthwise=depthwise,
|
||||
)
|
||||
|
||||
self.decoder = CausalDecoder(
|
||||
latent_dim,
|
||||
decoder_dim,
|
||||
decoder_rates,
|
||||
depthwise=depthwise,
|
||||
use_noise_block=use_noise_block,
|
||||
sr_bin_boundaries=sr_bin_boundaries,
|
||||
cond_type=cond_type,
|
||||
cond_dim=cond_dim,
|
||||
cond_out_layer=cond_out_layer,
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.out_sample_rate = out_sample_rate
|
||||
self.sr_bin_boundaries = sr_bin_boundaries
|
||||
self.chunk_size = math.prod(encoder_rates)
|
||||
|
||||
def preprocess(self, audio_data, sample_rate):
|
||||
if sample_rate is None:
|
||||
sample_rate = self.sample_rate
|
||||
assert sample_rate == self.sample_rate
|
||||
pad_to = self.hop_length
|
||||
length = audio_data.shape[-1]
|
||||
right_pad = math.ceil(length / pad_to) * pad_to - length
|
||||
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
||||
|
||||
return audio_data
|
||||
|
||||
def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
|
||||
"""Decode given latent codes and return audio data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
length : int, optional
|
||||
Number of samples in output audio, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"audio" : Tensor[B x 1 x length]
|
||||
Decoded audio data.
|
||||
"""
|
||||
if self.sr_bin_boundaries is not None:
|
||||
# use default output sample rate
|
||||
if sr_cond is None:
|
||||
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
|
||||
return self.decoder(z, sr_cond)
|
||||
|
||||
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
||||
"""
|
||||
Args:
|
||||
audio_data: Tensor[B x 1 x T]
|
||||
sample_rate: int
|
||||
Returns:
|
||||
z: Tensor[B x D x T]
|
||||
"""
|
||||
if audio_data.ndim == 2:
|
||||
audio_data = audio_data.unsqueeze(1)
|
||||
|
||||
audio_data = self.preprocess(audio_data, sample_rate)
|
||||
return self.encoder(audio_data)["mu"]
|
||||
@@ -1 +1 @@
|
||||
from .scalar_quantization_layer import ScalarQuantizationLayer
|
||||
from .scalar_quantization_layer import ScalarQuantizationLayer
|
||||
|
||||
@@ -34,7 +34,7 @@ class LoRALinear(nn.Module):
|
||||
self.r = r
|
||||
self.alpha = alpha
|
||||
self._base_scaling = alpha / r if r > 0 else 0.0
|
||||
|
||||
|
||||
# 使用 buffer 存储 scaling,这样修改值不会触发 torch.compile 重编译
|
||||
# persistent=False 表示不保存到 state_dict,避免加载时 missing key
|
||||
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
|
||||
@@ -128,6 +128,3 @@ def apply_lora_to_named_linear_modules(
|
||||
dropout=dropout,
|
||||
)
|
||||
setattr(parent, short_name, lora_layer)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class ScalarQuantizationLayer(nn.Module):
|
||||
|
||||
self.in_proj = nn.Linear(in_dim, latent_dim)
|
||||
self.out_proj = nn.Linear(latent_dim, out_dim)
|
||||
|
||||
|
||||
def forward(self, hidden):
|
||||
hidden = self.in_proj(hidden)
|
||||
hidden = torch.tanh(hidden)
|
||||
@@ -23,4 +23,4 @@ class ScalarQuantizationLayer(nn.Module):
|
||||
else:
|
||||
hidden = torch.round(hidden * self.scale) / self.scale
|
||||
|
||||
return self.out_proj(hidden)
|
||||
return self.out_proj(hidden)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .unified_cfm import UnifiedCFM, CfmConfig
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
|
||||
class SinusoidalPosEmb(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
||||
|
||||
def forward(self, x, scale=1000):
|
||||
if x.ndim < 1:
|
||||
x = x.unsqueeze(0)
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
|
||||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
out_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
|
||||
self.act = nn.SiLU()
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class VoxCPMLocDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MiniCPM4Config,
|
||||
in_channels: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.config = config
|
||||
|
||||
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
|
||||
|
||||
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=config.hidden_size,
|
||||
time_embed_dim=config.hidden_size,
|
||||
)
|
||||
self.delta_time_mlp = TimestepEmbedding(
|
||||
in_channels=config.hidden_size,
|
||||
time_embed_dim=config.hidden_size,
|
||||
)
|
||||
|
||||
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
|
||||
self.decoder = MiniCPMModel(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mu: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond: torch.Tensor,
|
||||
dt: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, T) tensor of inputs
|
||||
mu: (N, C) tensor of hidden embedding
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
cond: (N, C, T') tensor of prefix conditions
|
||||
dt: (N,) used for mean velocity (may be supported in the future...)
|
||||
"""
|
||||
x = self.in_proj(x.transpose(1, 2).contiguous())
|
||||
|
||||
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
|
||||
prefix = cond.size(1)
|
||||
|
||||
t = self.time_embeddings(t).to(x.dtype)
|
||||
t = self.time_mlp(t)
|
||||
dt = self.time_embeddings(dt).to(x.dtype)
|
||||
dt = self.delta_time_mlp(dt)
|
||||
t = t + dt
|
||||
|
||||
mu = mu.view(x.size(0), -1, x.size(-1))
|
||||
x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)
|
||||
|
||||
hidden, _ = self.decoder(x, is_causal=False)
|
||||
hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
|
||||
hidden = self.out_proj(hidden)
|
||||
|
||||
return hidden.transpose(1, 2).contiguous()
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -56,7 +56,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
cond: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
cfg_value: float = 1.0,
|
||||
sway_sampling_coef: float = 1.0,
|
||||
sway_sampling_coef: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
b, _ = mu.shape
|
||||
@@ -116,7 +116,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
|
||||
dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
|
||||
|
||||
if use_cfg_zero_star:
|
||||
positive_flat = dphi_dt.view(b, -1)
|
||||
negative_flat = cfg_dphi_dt.view(b, -1)
|
||||
@@ -124,7 +124,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
|
||||
else:
|
||||
st_star = 1.0
|
||||
|
||||
|
||||
dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)
|
||||
|
||||
x = x - dt * dphi_dt
|
||||
@@ -138,7 +138,9 @@ class UnifiedCFM(torch.nn.Module):
|
||||
# ------------------------------------------------------------------ #
|
||||
# Training loss
|
||||
# ------------------------------------------------------------------ #
|
||||
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
|
||||
def adaptive_loss_weighting(
|
||||
self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
|
||||
):
|
||||
weights = 1.0 / ((losses + epsilon).pow(p))
|
||||
if mask is not None:
|
||||
weights = weights * mask
|
||||
@@ -193,8 +195,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
|
||||
|
||||
ratio_r_neq_t = (
|
||||
self.ratio_r_neq_t_range[0]
|
||||
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||
self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||
if self.mean_mode
|
||||
else 0.0
|
||||
)
|
||||
|
||||
@@ -26,4 +26,4 @@ class MiniCPM4Config(BaseModel):
|
||||
dim_model_base: int
|
||||
scale_depth: float
|
||||
rope_theta: float
|
||||
kv_channels: int = None
|
||||
kv_channels: int = None
|
||||
|
||||
@@ -64,10 +64,8 @@ class MiniCPMLongRoPE(nn.Module):
|
||||
self.long_factor = config.rope_scaling.long_factor
|
||||
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
|
||||
|
||||
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
|
||||
self.scaling_factor = math.sqrt(
|
||||
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
|
||||
)
|
||||
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
||||
self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
@@ -76,11 +74,7 @@ class MiniCPMLongRoPE(nn.Module):
|
||||
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
|
||||
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
|
||||
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=self.max_position_embeddings,
|
||||
device=self.inv_freq.device,
|
||||
dtype=torch.float32
|
||||
)
|
||||
self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
"""设置cos和sin缓存"""
|
||||
@@ -93,8 +87,7 @@ class MiniCPMLongRoPE(nn.Module):
|
||||
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
|
||||
|
||||
freqs = torch.mul(
|
||||
torch.outer(t, 1.0 / ext_factors).to(device=device),
|
||||
self.inv_freq.to(device=device).to(dtype)
|
||||
torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
|
||||
)
|
||||
|
||||
# 创建embeddings
|
||||
@@ -123,7 +116,9 @@ class MiniCPMAttention(nn.Module):
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||
self.head_dim = (
|
||||
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||
)
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
@@ -153,7 +148,7 @@ class MiniCPMAttention(nn.Module):
|
||||
cos, sin = position_emb
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
|
||||
# ref: https://github.com/pytorch/pytorch/issues/163597
|
||||
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
||||
query_states = query_states.contiguous()
|
||||
@@ -413,7 +408,11 @@ class MiniCPMModel(nn.Module):
|
||||
self.kv_cache = StaticKVCache(
|
||||
num_layers=self.config.num_hidden_layers,
|
||||
num_kv_heads=self.config.num_key_value_heads,
|
||||
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
|
||||
dim_kv_head=(
|
||||
self.config.hidden_size // self.config.num_attention_heads
|
||||
if self.config.kv_channels is None
|
||||
else self.config.kv_channels
|
||||
),
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
||||
@@ -25,4 +25,3 @@ __all__ = [
|
||||
"load_audio_text_datasets",
|
||||
"build_dataloader",
|
||||
]
|
||||
|
||||
|
||||
@@ -47,9 +47,7 @@ class Accelerator:
|
||||
pass
|
||||
|
||||
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
|
||||
self.device_ctx = (
|
||||
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
)
|
||||
self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
self._ddp_model = None # For no_sync support
|
||||
|
||||
def _set_seed(self, seed: int):
|
||||
@@ -84,7 +82,7 @@ class Accelerator:
|
||||
# Model helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
||||
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
|
||||
if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
|
||||
model.device = self.device
|
||||
model = model.to(self.device)
|
||||
if self.world_size > 1:
|
||||
@@ -163,4 +161,3 @@ class Accelerator:
|
||||
@staticmethod
|
||||
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
|
||||
return model.module if hasattr(model, "module") else model
|
||||
|
||||
|
||||
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
|
||||
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
|
||||
cli_args.update(yaml_args)
|
||||
return cli_args
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import argbind
|
||||
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
@@ -36,7 +34,7 @@ def load_audio_text_datasets(
|
||||
def prepare(ds: Dataset) -> Dataset:
|
||||
if audio_column not in ds.column_names:
|
||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# but for length calculation we might need raw path or duration if available.
|
||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||
@@ -70,13 +68,13 @@ def compute_sample_lengths(
|
||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||
t_seq = ceil(t_vae / patch_size)
|
||||
- 序列总长约为: text_len + t_seq + 2
|
||||
|
||||
|
||||
Optimized: Use batch column access instead of iterating item by item.
|
||||
"""
|
||||
# Batch access columns - much faster than per-item access
|
||||
text_ids_list = ds["text_ids"]
|
||||
text_lens = [len(t) for t in text_ids_list]
|
||||
|
||||
|
||||
has_duration = "duration" in ds.column_names
|
||||
if has_duration:
|
||||
durations = ds["duration"]
|
||||
@@ -86,7 +84,7 @@ def compute_sample_lengths(
|
||||
for i in range(len(ds)):
|
||||
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||
|
||||
|
||||
# Vectorized length computation
|
||||
lengths = []
|
||||
for text_len, duration in zip(text_lens, durations):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -15,7 +14,7 @@ class AudioFeatureProcessingPacker:
|
||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||
self.audio_start_id = 101
|
||||
self.audio_end_id = 102
|
||||
# unused now
|
||||
# unused now
|
||||
self.audio_prompt_start_id = 103
|
||||
self.audio_prompt_end_id = 104
|
||||
self.text_eos_token_id = 2
|
||||
@@ -147,31 +146,26 @@ class AudioFeatureProcessingPacker:
|
||||
|
||||
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
|
||||
if x.size(0) >= max_len:
|
||||
return x[: max_len]
|
||||
return x[:max_len]
|
||||
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=0)
|
||||
|
||||
def pad_3d(x: torch.Tensor) -> torch.Tensor:
|
||||
# x: [T, P, D]
|
||||
if x.size(0) >= max_len:
|
||||
return x[: max_len]
|
||||
pad = torch.zeros(
|
||||
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
|
||||
)
|
||||
return x[:max_len]
|
||||
pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=0)
|
||||
|
||||
if lengths:
|
||||
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
|
||||
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
|
||||
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
|
||||
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
|
||||
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
|
||||
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
|
||||
audio_task_ids_batch = torch.stack(
|
||||
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
|
||||
)
|
||||
audio_dataset_ids_batch = torch.stack(
|
||||
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
|
||||
)
|
||||
labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
|
||||
audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
|
||||
audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
|
||||
|
||||
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
|
||||
position_ids_list = []
|
||||
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
|
||||
)
|
||||
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
|
||||
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
|
||||
text_token.device
|
||||
text_mask = (
|
||||
torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
audio_mask = (
|
||||
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
loss_mask = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(text_length),
|
||||
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
|
||||
torch.zeros(1),
|
||||
]
|
||||
)
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
|
||||
torch.int32
|
||||
).to(text_token.device)
|
||||
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
|
||||
|
||||
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
|
||||
labels[-2] = 1
|
||||
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,4 +18,3 @@ class TrainingState:
|
||||
val_loader: object
|
||||
tracker: object
|
||||
batch_processor: object
|
||||
|
||||
|
||||
@@ -76,4 +76,3 @@ class TrainingTracker:
|
||||
@contextlib.contextmanager
|
||||
def live(self):
|
||||
yield
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
import re
|
||||
import regex
|
||||
import inflect
|
||||
from functools import partial
|
||||
from wetext import Normalizer
|
||||
|
||||
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
||||
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
|
||||
# whether contain chinese character
|
||||
def contains_chinese(text):
|
||||
@@ -14,19 +14,19 @@ def contains_chinese(text):
|
||||
|
||||
# replace special symbol
|
||||
def replace_corner_mark(text):
|
||||
text = text.replace('²', '平方')
|
||||
text = text.replace('³', '立方')
|
||||
text = text.replace('√', '根号')
|
||||
text = text.replace('≈', '约等于')
|
||||
text = text.replace('<', '小于')
|
||||
text = text.replace("²", "平方")
|
||||
text = text.replace("³", "立方")
|
||||
text = text.replace("√", "根号")
|
||||
text = text.replace("≈", "约等于")
|
||||
text = text.replace("<", "小于")
|
||||
return text
|
||||
|
||||
|
||||
# remove meaningless symbol
|
||||
def remove_bracket(text):
|
||||
text = text.replace('(', ' ').replace(')', ' ')
|
||||
text = text.replace('【', ' ').replace('】', ' ')
|
||||
text = text.replace('`', '').replace('`', '')
|
||||
text = text.replace("(", " ").replace(")", " ")
|
||||
text = text.replace("【", " ").replace("】", " ")
|
||||
text = text.replace("`", "").replace("`", "")
|
||||
text = text.replace("——", " ")
|
||||
return text
|
||||
|
||||
@@ -38,7 +38,7 @@ def spell_out_number(text: str, inflect_parser):
|
||||
for i, c in enumerate(text):
|
||||
if not c.isdigit():
|
||||
if st is not None:
|
||||
num_str = inflect_parser.number_to_words(text[st: i])
|
||||
num_str = inflect_parser.number_to_words(text[st:i])
|
||||
new_text.append(num_str)
|
||||
st = None
|
||||
new_text.append(c)
|
||||
@@ -48,7 +48,7 @@ def spell_out_number(text: str, inflect_parser):
|
||||
if st is not None and st < len(text):
|
||||
num_str = inflect_parser.number_to_words(text[st:])
|
||||
new_text.append(num_str)
|
||||
return ''.join(new_text)
|
||||
return "".join(new_text)
|
||||
|
||||
|
||||
# split paragrah logic:
|
||||
@@ -69,18 +69,18 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
||||
return len(tokenize(_text)) < merge_len
|
||||
|
||||
if lang == "zh":
|
||||
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
|
||||
pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
|
||||
else:
|
||||
pounc = ['.', '?', '!', ';', ':']
|
||||
pounc = [".", "?", "!", ";", ":"]
|
||||
if comma_split:
|
||||
pounc.extend([',', ','])
|
||||
pounc.extend([",", ","])
|
||||
st = 0
|
||||
utts = []
|
||||
for i, c in enumerate(text):
|
||||
if c in pounc:
|
||||
if len(text[st: i]) > 0:
|
||||
utts.append(text[st: i] + c)
|
||||
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
|
||||
if len(text[st:i]) > 0:
|
||||
utts.append(text[st:i] + c)
|
||||
if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
|
||||
tmp = utts.pop(-1)
|
||||
utts.append(tmp + text[i + 1])
|
||||
st = i + 2
|
||||
@@ -88,9 +88,9 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
||||
st = i + 1
|
||||
if len(utts) == 0:
|
||||
if lang == "zh":
|
||||
utts.append(text + '。')
|
||||
utts.append(text + "。")
|
||||
else:
|
||||
utts.append(text + '.')
|
||||
utts.append(text + ".")
|
||||
final_utts = []
|
||||
cur_utt = ""
|
||||
for utt in utts:
|
||||
@@ -112,13 +112,13 @@ def replace_blank(text: str):
|
||||
out_str = []
|
||||
for i, c in enumerate(text):
|
||||
if c == " ":
|
||||
if ((text[i + 1].isascii() and text[i + 1] != " ") and
|
||||
(text[i - 1].isascii() and text[i - 1] != " ")):
|
||||
if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
|
||||
out_str.append(c)
|
||||
else:
|
||||
out_str.append(c)
|
||||
return "".join(out_str)
|
||||
|
||||
|
||||
def clean_markdown(md_text: str) -> str:
|
||||
# 去除代码块 ``` ```(包括多行)
|
||||
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
|
||||
@@ -131,9 +131,9 @@ def clean_markdown(md_text: str) -> str:
|
||||
|
||||
# 去除链接但保留文本 [text](url) -> text
|
||||
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
|
||||
|
||||
|
||||
# 替换无序列表符号
|
||||
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r"^(\s*)-\s+", r"\1", md_text, flags=re.MULTILINE)
|
||||
|
||||
# 去除HTML标签
|
||||
md_text = re.sub(r"<[^>]+>", "", md_text)
|
||||
@@ -152,28 +152,31 @@ def clean_text(text):
|
||||
# 去除 Markdown 语法
|
||||
text = clean_markdown(text)
|
||||
# 匹配并移除表情符号
|
||||
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
|
||||
text = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE).sub("", text)
|
||||
# 去除换行符
|
||||
text = text.replace("\n", " ")
|
||||
text = text.replace("\t", " ")
|
||||
text = text.replace('"', "\“")
|
||||
text = text.replace("“", '"').replace("”", '"')
|
||||
return text
|
||||
|
||||
|
||||
class TextNormalizer:
|
||||
def __init__(self, tokenizer=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
|
||||
self.en_tn_model = Normalizer(lang="en", operator="tn")
|
||||
self.inflect_parser = inflect.engine()
|
||||
|
||||
|
||||
def normalize(self, text, split=False):
|
||||
# 去除 Markdown 语法,去除表情符号,去除换行符
|
||||
lang = "zh" if contains_chinese(text) else "en"
|
||||
text = clean_text(text)
|
||||
if lang == "zh":
|
||||
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
|
||||
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
|
||||
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
|
||||
text = text.replace(
|
||||
"=", "等于"
|
||||
) # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
|
||||
if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text): # 避免 英文连字符被错误正则为减
|
||||
text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text) # 修复 x-2 被正则为 x负2
|
||||
text = self.zh_tn_model.normalize(text)
|
||||
text = replace_blank(text)
|
||||
text = replace_corner_mark(text)
|
||||
@@ -182,4 +185,4 @@ class TextNormalizer:
|
||||
text = self.en_tn_model.normalize(text)
|
||||
text = spell_out_number(text, self.inflect_parser)
|
||||
if split is False:
|
||||
return text
|
||||
return text
|
||||
|
||||
+10
-14
@@ -7,15 +7,15 @@ Related dependencies are imported only when denoising functionality is needed.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
import torchaudio
|
||||
import torch
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
class ZipEnhancer:
|
||||
"""ZipEnhancer Audio Denoising Enhancer"""
|
||||
|
||||
def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
|
||||
"""
|
||||
Initialize ZipEnhancer
|
||||
@@ -23,25 +23,21 @@ class ZipEnhancer:
|
||||
model_path: ModelScope model path or local path
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self._pipeline = pipeline(
|
||||
Tasks.acoustic_noise_suppression,
|
||||
model=self.model_path
|
||||
)
|
||||
|
||||
self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)
|
||||
|
||||
def _normalize_loudness(self, wav_path: str):
|
||||
"""
|
||||
Audio loudness normalization
|
||||
|
||||
|
||||
Args:
|
||||
wav_path: Audio file path
|
||||
"""
|
||||
audio, sr = torchaudio.load(wav_path)
|
||||
loudness = torchaudio.functional.loudness(audio, sr)
|
||||
normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
|
||||
normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
|
||||
torchaudio.save(wav_path, normalized_audio, sr)
|
||||
|
||||
def enhance(self, input_path: str, output_path: Optional[str] = None,
|
||||
normalize_loudness: bool = True) -> str:
|
||||
|
||||
def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
|
||||
"""
|
||||
Audio denoising enhancement
|
||||
Args:
|
||||
@@ -57,7 +53,7 @@ class ZipEnhancer:
|
||||
raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
|
||||
# Create temporary file if no output path is specified
|
||||
if output_path is None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
||||
output_path = tmp_file.name
|
||||
try:
|
||||
# Perform denoising processing
|
||||
@@ -73,4 +69,4 @@ class ZipEnhancer:
|
||||
os.unlink(output_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise RuntimeError(f"Audio denoising processing failed: {e}")
|
||||
raise RuntimeError(f"Audio denoising processing failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user