update voxcpm2

This commit is contained in:
刘鑫
2026-03-31 11:50:37 +08:00
parent 23ed7ffeee
commit d9cf376e16
36 changed files with 8163 additions and 834 deletions
+144 -85
View File
@@ -3,13 +3,14 @@ import sys
import numpy as np import numpy as np
import torch import torch
import gradio as gr import gradio as gr
import spaces import spaces # noqa: F401
from typing import Optional, Tuple from typing import Optional, Tuple
from funasr import AutoModel from funasr import AutoModel
from pathlib import Path from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "": if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5" os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
import voxcpm import voxcpm
@@ -24,13 +25,13 @@ class VoxCPMDemo:
self.asr_model: Optional[AutoModel] = AutoModel( self.asr_model: Optional[AutoModel] = AutoModel(
model=self.asr_model_id, model=self.asr_model_id,
disable_update=True, disable_update=True,
log_level='DEBUG', log_level="DEBUG",
device="cuda:0" if self.device == "cuda" else "cpu", device="cuda:0" if self.device == "cuda" else "cpu",
) )
# TTS model (lazy init) # TTS model (lazy init)
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.default_local_model_dir = "./models/VoxCPM1.5" self.default_local_model_dir = "/Users/xinliu/Downloads/VoxCPM2-0.5B-newaudiovae-6hz-0316"
# ---------- Model helpers ---------- # ---------- Model helpers ----------
def _resolve_model_dir(self) -> str: def _resolve_model_dir(self) -> str:
@@ -49,6 +50,7 @@ class VoxCPMDemo:
if not os.path.isdir(target_dir): if not os.path.isdir(target_dir):
try: try:
from huggingface_hub import snapshot_download # type: ignore from huggingface_hub import snapshot_download # type: ignore
os.makedirs(target_dir, exist_ok=True) os.makedirs(target_dir, exist_ok=True)
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr) print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False) snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
@@ -64,7 +66,7 @@ class VoxCPMDemo:
print("Model not loaded, initializing...", file=sys.stderr) print("Model not loaded, initializing...", file=sys.stderr)
model_dir = self._resolve_model_dir() model_dir = self._resolve_model_dir()
print(f"Using model dir: {model_dir}", file=sys.stderr) print(f"Using model dir: {model_dir}", file=sys.stderr)
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir) self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=False)
print("Model loaded successfully.", file=sys.stderr) print("Model loaded successfully.", file=sys.stderr)
return self.voxcpm_model return self.voxcpm_model
@@ -73,21 +75,24 @@ class VoxCPMDemo:
if prompt_wav is None: if prompt_wav is None:
return "" return ""
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True) res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
text = res[0]["text"].split('|>')[-1] text = res[0]["text"].split("|>")[-1]
return text return text
def generate_tts_audio( def generate_tts_audio(
self, self,
text_input: str, text_input: str,
prompt_wav_path_input: Optional[str] = None, control_instruction: str = "",
prompt_text_input: Optional[str] = None, reference_wav_path_input: Optional[str] = None,
cfg_value_input: float = 2.0, cfg_value_input: float = 2.0,
inference_timesteps_input: int = 10, inference_timesteps_input: int = 10,
do_normalize: bool = True, do_normalize: bool = True,
denoise: bool = True, denoise: bool = True,
) -> Tuple[int, np.ndarray]: ) -> Tuple[int, np.ndarray]:
""" """
Generate speech from text using VoxCPM; optional reference audio for voice style guidance. Generate speech from text using VoxCPM.
- If reference_wav provided: Prompt isolation mode (voice cloning)
- If no reference_wav: Voice design mode (use control_instruction to describe voice)
Returns (sample_rate, waveform_numpy) Returns (sample_rate, waveform_numpy)
""" """
current_model = self.get_or_load_voxcpm() current_model = self.get_or_load_voxcpm()
@@ -96,14 +101,25 @@ class VoxCPMDemo:
if len(text) == 0: if len(text) == 0:
raise ValueError("Please input text to synthesize.") raise ValueError("Please input text to synthesize.")
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None # 处理 control instruction
prompt_text = prompt_text_input if prompt_text_input else None control = (control_instruction or "").strip()
if control:
final_text = f"({control}){text}"
else:
final_text = text
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr) reference_wav_path = reference_wav_path_input if reference_wav_path_input else None
# 判断模式
if reference_wav_path:
print(f"[Prompt Isolation Mode] reference_wav: {reference_wav_path}", file=sys.stderr)
else:
print(f"[Voice Design Mode] control: {control[:50] if control else 'None'}...", file=sys.stderr)
print(f"Generating audio for text: '{final_text[:80]}...'", file=sys.stderr)
wav = current_model.generate( wav = current_model.generate(
text=text, text=final_text,
prompt_text=prompt_text, reference_wav_path=reference_wav_path,
prompt_wav_path=prompt_wav_path,
cfg_value=float(cfg_value_input), cfg_value=float(cfg_value_input),
inference_timesteps=int(inference_timesteps_input), inference_timesteps=int(inference_timesteps_input),
normalize=do_normalize, normalize=do_normalize,
@@ -114,46 +130,53 @@ class VoxCPMDemo:
# ---------- UI Builders ---------- # ---------- UI Builders ----------
THEME = gr.themes.Soft(
primary_hue="blue",
secondary_hue="gray",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
)
CSS = """
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 80px;
width: auto;
max-width: 200px;
display: inline-block;
}
/* Bold accordion labels */
#acc_quick > .label-wrap,
#acc_tips > .label-wrap,
#acc_quick > .label-wrap > span,
#acc_tips > .label-wrap > span,
#acc_quick summary,
#acc_tips summary {
font-weight: 600 !important;
font-size: 1.1em !important;
}
/* Bold labels for specific checkboxes */
#chk_denoise label,
#chk_denoise span,
#chk_normalize label,
#chk_normalize span {
font-weight: 600;
}
"""
def create_demo_interface(demo: VoxCPMDemo): def create_demo_interface(demo: VoxCPMDemo):
"""Build the Gradio UI for VoxCPM demo.""" """Build the Gradio UI for VoxCPM demo."""
# static assets (logo path) gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
with gr.Blocks( with gr.Blocks() as interface:
theme=gr.themes.Soft( gr.HTML(
primary_hue="blue", '<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>',
secondary_hue="gray", padding=True,
neutral_hue="slate", )
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
),
css="""
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 80px;
width: auto;
max-width: 200px;
display: inline-block;
}
/* Bold accordion labels */
#acc_quick details > summary,
#acc_tips details > summary {
font-weight: 600 !important;
font-size: 1.1em !important;
}
/* Bold labels for specific checkboxes */
#chk_denoise label,
#chk_denoise span,
#chk_normalize label,
#chk_normalize span {
font-weight: 600;
}
"""
) as interface:
# Header logo
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
# Quick Start # Quick Start
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"): with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
@@ -200,34 +223,56 @@ def create_demo_interface(demo: VoxCPMDemo):
# Main controls # Main controls
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
prompt_wav = gr.Audio( # 1. Reference Audio
sources=["upload", 'microphone'], # gr.Markdown("### 🎤 Reference Audio (Optional)")
# gr.Markdown("*提供参考音频进行音色克隆;不提供则使用 Voice Design 模式*")
reference_wav = gr.Audio(
sources=["upload", "microphone"],
type="filepath", type="filepath",
label="Prompt Speech (Optional, or let VoxCPM improvise)", label="Reference Audio (Optional)",
value="./examples/example.wav",
) )
DoDenoisePromptAudio = gr.Checkbox( DoDenoisePromptAudio = gr.Checkbox(
value=False, value=False,
label="Prompt Speech Enhancement", label="Reference Audio Enhancement",
elem_id="chk_denoise", elem_id="chk_denoise",
info="We use ZipEnhancer model to denoise the prompt audio." info="Use ZipEnhancer to denoise the reference audio",
) )
with gr.Row():
prompt_text = gr.Textbox( # 2. Control Instruction
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.", # gr.Markdown("### 🎛️ Control Instruction (Optional)")
label="Prompt Text", # gr.Markdown("*描述声音风格、情感等,格式:`(instruction) text`*")
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..." control_instruction = gr.Textbox(
) value="",
run_btn = gr.Button("Generate Speech", variant="primary") label="Control Instruction",
placeholder="*描述声音风格、情感等,格式:`(instruction) text`,例如:年轻女性,温柔甜美 / 悲伤地说 / an excited young man*",
lines=2,
)
# 3. Target Text
# gr.Markdown("### 📝 Target Text")
text = gr.Textbox(
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
label="Target Text",
lines=3,
)
DoNormalizeText = gr.Checkbox(
value=False,
label="Text Normalization",
elem_id="chk_normalize",
info="Use wetext library to normalize the input text",
)
run_btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg")
with gr.Column(): with gr.Column():
gr.Markdown("### ⚙️ Generation Settings")
cfg_value = gr.Slider( cfg_value = gr.Slider(
minimum=1.0, minimum=1.0,
maximum=3.0, maximum=3.0,
value=2.0, value=2.0,
step=0.1, step=0.1,
label="CFG Value (Guidance Scale)", label="CFG Value (Guidance Scale)",
info="Higher values increase adherence to prompt, lower values allow more creativity" info="Higher = more adherence to prompt; Lower = more creativity",
) )
inference_timesteps = gr.Slider( inference_timesteps = gr.Slider(
minimum=4, minimum=4,
@@ -235,40 +280,54 @@ def create_demo_interface(demo: VoxCPMDemo):
value=10, value=10,
step=1, step=1,
label="Inference Timesteps", label="Inference Timesteps",
info="Number of inference timesteps for generation (higher values may improve quality but slower)" info="Higher = better quality but slower",
) )
with gr.Row():
text = gr.Textbox( gr.Markdown("### 🔈 Output")
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.", audio_output = gr.Audio(label="Generated Audio")
label="Target Text",
) gr.Markdown("""
with gr.Row(): ---
DoNormalizeText = gr.Checkbox( **模式说明 / Mode Info:**
value=False, - **有 Reference Audio** → Prompt 隔离模式(音色克隆)
label="Text Normalization", - **无 Reference Audio** → Voice Design 模式(用 Control Instruction 描述声音)
elem_id="chk_normalize",
info="We use wetext library to normalize the input text." **Control Instruction 示例:**
) - `年轻女性,温柔甜美`
audio_output = gr.Audio(label="Output Audio") - `悲伤地说`
- `an excited young man`
""")
# Wiring # Wiring
run_btn.click( run_btn.click(
fn=demo.generate_tts_audio, fn=demo.generate_tts_audio,
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio], inputs=[
text,
control_instruction,
reference_wav,
cfg_value,
inference_timesteps,
DoNormalizeText,
DoDenoisePromptAudio,
],
outputs=[audio_output], outputs=[audio_output],
show_progress=True, show_progress=True,
api_name="generate", api_name="generate",
) )
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
return interface return interface
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True): def run_demo(server_name: str = "0.0.0.0", server_port: int = 7869, show_error: bool = True):
demo = VoxCPMDemo() demo = VoxCPMDemo()
interface = create_demo_interface(demo) interface = create_demo_interface(demo)
# Recommended to enable queue on Spaces for better throughput interface.queue(max_size=10, default_concurrency_limit=1).launch(
interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error) server_name=server_name,
server_port=server_port,
show_error=show_error,
theme=THEME,
css=CSS,
)
if __name__ == "__main__": if __name__ == "__main__":
+1 -1
View File
@@ -25,7 +25,7 @@ lora:
enable_lm: true enable_lm: true
enable_dit: true enable_dit: true
enable_proj: false enable_proj: false
r: 32 r: 8
alpha: 16 alpha: 16
dropout: 0.0 dropout: 0.0
+1 -1
View File
@@ -25,7 +25,7 @@ lora:
enable_lm: true enable_lm: true
enable_dit: true enable_dit: true
enable_proj: false enable_proj: false
r: 32 r: 8
alpha: 16 alpha: 16
dropout: 0.0 dropout: 0.0
Binary file not shown.
+210 -223
View File
@@ -1,18 +1,14 @@
import os import os
import sys import sys
import time
import glob
import json import json
import yaml import yaml
import shutil
import datetime import datetime
import subprocess import subprocess
import threading import threading
import gradio as gr import gradio as gr
import torch import torch
import soundfile as sf
from pathlib import Path from pathlib import Path
from typing import Optional, List from typing import Optional
# Add src to sys.path # Add src to sys.path
project_root = Path(__file__).parent project_root = Path(__file__).parent
@@ -89,7 +85,7 @@ LANG_DICT = {
"lang_select": "Language / 语言", "lang_select": "Language / 语言",
"refresh": "刷新", "refresh": "刷新",
"output_name": "输出目录名称 (可选,若存在则继续训练)", "output_name": "输出目录名称 (可选,若存在则继续训练)",
} },
} }
# Global variables # Global variables
@@ -98,9 +94,11 @@ asr_model: Optional[AutoModel] = None
training_process: Optional[subprocess.Popen] = None training_process: Optional[subprocess.Popen] = None
training_log = "" training_log = ""
def get_timestamp_str(): def get_timestamp_str():
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def get_or_load_asr_model(): def get_or_load_asr_model():
global asr_model global asr_model
if asr_model is None: if asr_model is None:
@@ -109,23 +107,25 @@ def get_or_load_asr_model():
asr_model = AutoModel( asr_model = AutoModel(
model="iic/SenseVoiceSmall", model="iic/SenseVoiceSmall",
disable_update=True, disable_update=True,
log_level='ERROR', log_level="ERROR",
device=device, device=device,
) )
return asr_model return asr_model
def recognize_audio(audio_path): def recognize_audio(audio_path):
if not audio_path: if not audio_path:
return "" return ""
try: try:
model = get_or_load_asr_model() model = get_or_load_asr_model()
res = model.generate(input=audio_path, language="auto", use_itn=True) res = model.generate(input=audio_path, language="auto", use_itn=True)
text = res[0]["text"].split('|>')[-1] text = res[0]["text"].split("|>")[-1]
return text return text
except Exception as e: except Exception as e:
print(f"ASR Error: {e}", file=sys.stderr) print(f"ASR Error: {e}", file=sys.stderr)
return "" return ""
def scan_lora_checkpoints(root_dir="lora", with_info=False): def scan_lora_checkpoints(root_dir="lora", with_info=False):
""" """
Scans for LoRA checkpoints in the lora directory. Scans for LoRA checkpoints in the lora directory.
@@ -165,11 +165,12 @@ def scan_lora_checkpoints(root_dir="lora", with_info=False):
# Also check for checkpoints in the default location if they exist # Also check for checkpoints in the default location if they exist
default_ckpt = "checkpoints/finetune_lora" default_ckpt = "checkpoints/finetune_lora"
if os.path.exists(os.path.join(root_dir, default_ckpt)): if os.path.exists(os.path.join(root_dir, default_ckpt)):
# This might be covered by the walk, but good to be sure # This might be covered by the walk, but good to be sure
pass pass
return sorted(checkpoints, reverse=True) return sorted(checkpoints, reverse=True)
def load_lora_config_from_checkpoint(lora_path): def load_lora_config_from_checkpoint(lora_path):
"""Load LoRA config from lora_config.json if available.""" """Load LoRA config from lora_config.json if available."""
lora_config_file = os.path.join(lora_path, "lora_config.json") lora_config_file = os.path.join(lora_path, "lora_config.json")
@@ -184,6 +185,7 @@ def load_lora_config_from_checkpoint(lora_path):
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr) print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
return None, None return None, None
def get_default_lora_config(): def get_default_lora_config():
"""Return default LoRA config for hot-swapping support.""" """Return default LoRA config for hot-swapping support."""
return LoRAConfig( return LoRAConfig(
@@ -192,9 +194,10 @@ def get_default_lora_config():
r=32, r=32,
alpha=16, alpha=16,
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"], target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"] target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
) )
def load_model(pretrained_path, lora_path=None): def load_model(pretrained_path, lora_path=None):
global current_model global current_model
print(f"Loading model from {pretrained_path}...", file=sys.stderr) print(f"Loading model from {pretrained_path}...", file=sys.stderr)
@@ -228,9 +231,8 @@ def load_model(pretrained_path, lora_path=None):
) )
return "Model loaded successfully!" return "Model loaded successfully!"
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
global current_model
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model # 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
if current_model is None: if current_model is None:
# 优先使用用户指定的预训练模型路径 # 优先使用用户指定的预训练模型路径
@@ -261,7 +263,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
# 加载模型 # 加载模型
try: try:
print(f"Loading base model: {base_model_path}", file=sys.stderr) print(f"Loading base model: {base_model_path}", file=sys.stderr)
status_msg = load_model(base_model_path) load_model(base_model_path)
if lora_selection and lora_selection != "None": if lora_selection and lora_selection != "None":
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr) print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
except Exception as e: except Exception as e:
@@ -270,6 +272,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
return None, error_msg return None, error_msg
# Handle LoRA hot-swapping # Handle LoRA hot-swapping
assert current_model is not None, "Model must be loaded before inference"
if lora_selection and lora_selection != "None": if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection) full_lora_path = os.path.join("lora", lora_selection)
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr) print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
@@ -317,14 +320,16 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
prompt_text=final_prompt_text, prompt_text=final_prompt_text,
cfg_value=cfg_scale, cfg_value=cfg_scale,
inference_timesteps=steps, inference_timesteps=steps,
denoise=False denoise=False,
) )
return (current_model.tts_model.sample_rate, audio_np), "Generation Success" return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return None, f"Error: {str(e)}" return None, f"Error: {str(e)}"
def start_training( def start_training(
pretrained_path, pretrained_path,
train_manifest, train_manifest,
@@ -355,7 +360,7 @@ def start_training(
hf_model_id="", hf_model_id="",
distribute=False, distribute=False,
): ):
global training_process, training_log global training_log
if training_process is not None and training_process.poll() is None: if training_process is not None and training_process.poll() is None:
return "Training is already running!" return "Training is already running!"
@@ -394,10 +399,7 @@ def start_training(
"max_steps": resolved_max_steps, "max_steps": resolved_max_steps,
"save_path": checkpoints_dir, "save_path": checkpoints_dir,
"tensorboard": tensorboard_path if tensorboard_path else logs_dir, "tensorboard": tensorboard_path if tensorboard_path else logs_dir,
"lambdas": { "lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
"loss/diff": 1.0,
"loss/stop": 1.0
},
"lora": { "lora": {
"enable_lm": bool(enable_lm), "enable_lm": bool(enable_lm),
"enable_dit": bool(enable_dit), "enable_dit": bool(enable_dit),
@@ -406,7 +408,7 @@ def start_training(
"alpha": int(lora_alpha), "alpha": int(lora_alpha),
"dropout": float(dropout), "dropout": float(dropout),
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"], "target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"] "target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"],
}, },
} }
@@ -420,25 +422,15 @@ def start_training(
with open(config_path, "w") as f: with open(config_path, "w") as f:
yaml.dump(config, f) yaml.dump(config, f)
cmd = [ cmd = [sys.executable, "scripts/train_voxcpm_finetune.py", "--config_path", config_path]
sys.executable,
"scripts/train_voxcpm_finetune.py",
"--config_path",
config_path
]
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n" training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
def run_process(): def run_process():
global training_process, training_log global training_process, training_log
training_process = subprocess.Popen( training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1
)
assert training_process.stdout is not None
for line in training_process.stdout: for line in training_process.stdout:
training_log += line training_log += line
# Keep log size manageable # Keep log size manageable
@@ -452,17 +444,20 @@ def start_training(
return f"Training started! Check 'lora/{timestamp}'" return f"Training started! Check 'lora/{timestamp}'"
def get_training_log(): def get_training_log():
return training_log return training_log
def stop_training(): def stop_training():
global training_process, training_log global training_log
if training_process is not None and training_process.poll() is None: if training_process is not None and training_process.poll() is None:
training_process.terminate() training_process.terminate()
training_log += "\nTraining terminated by user." training_log += "\nTraining terminated by user."
return "Training stopped." return "Training stopped."
return "No training running." return "No training running."
# --- GUI Layout --- # --- GUI Layout ---
# 自定义CSS样式 # 自定义CSS样式
@@ -830,14 +825,10 @@ label {
} }
""" """
with gr.Blocks( with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css) as app:
title="VoxCPM LoRA WebUI",
theme=gr.themes.Soft(),
css=custom_css
) as app:
# State for language # State for language
lang_state = gr.State("zh") # Default to Chinese lang_state = gr.State("zh") # Default to Chinese
# 标题区域 # 标题区域
with gr.Row(elem_classes="title-section"): with gr.Row(elem_classes="title-section"):
@@ -850,10 +841,7 @@ with gr.Blocks(
""") """)
with gr.Column(scale=1): with gr.Column(scale=1):
lang_btn = gr.Radio( lang_btn = gr.Radio(
choices=["en", "zh"], choices=["en", "zh"], value="zh", label="🌐 Language / 语言", elem_classes="lang-selector"
value="zh",
label="🌐 Language / 语言",
elem_classes="lang-selector"
) )
with gr.Tabs(elem_classes="tabs") as tabs: with gr.Tabs(elem_classes="tabs") as tabs:
@@ -869,79 +857,40 @@ with gr.Blocks(
gr.Markdown("#### 📁 基础配置") gr.Markdown("#### 📁 基础配置")
train_pretrained_path = gr.Textbox( train_pretrained_path = gr.Textbox(
label="📂 预训练模型路径", label="📂 预训练模型路径", value=default_pretrained_path, elem_classes="input-field"
value=default_pretrained_path,
elem_classes="input-field"
) )
train_manifest = gr.Textbox( train_manifest = gr.Textbox(
label="📋 训练数据清单 (jsonl)", label="📋 训练数据清单 (jsonl)",
value="examples/train_data_example.jsonl", value="examples/train_data_example.jsonl",
elem_classes="input-field" elem_classes="input-field",
)
val_manifest = gr.Textbox(
label="📊 验证数据清单 (可选)",
value="",
elem_classes="input-field"
) )
val_manifest = gr.Textbox(label="📊 验证数据清单 (可选)", value="", elem_classes="input-field")
gr.Markdown("#### ⚙️ 训练参数") gr.Markdown("#### ⚙️ 训练参数")
with gr.Row(): with gr.Row():
lr = gr.Number( lr = gr.Number(label="📈 学习率 (Learning Rate)", value=1e-4, elem_classes="input-field")
label="📈 学习率 (Learning Rate)",
value=1e-4,
elem_classes="input-field"
)
num_iters = gr.Number( num_iters = gr.Number(
label="🔄 最大迭代次数", label="🔄 最大迭代次数", value=2000, precision=0, elem_classes="input-field"
value=2000,
precision=0,
elem_classes="input-field"
) )
batch_size = gr.Number( batch_size = gr.Number(
label="📦 批次大小 (Batch Size)", label="📦 批次大小 (Batch Size)", value=1, precision=0, elem_classes="input-field"
value=1,
precision=0,
elem_classes="input-field"
) )
with gr.Row(): with gr.Row():
lora_rank = gr.Number( lora_rank = gr.Number(label="🎯 LoRA Rank", value=32, precision=0, elem_classes="input-field")
label="🎯 LoRA Rank", lora_alpha = gr.Number(label="⚖️ LoRA Alpha", value=16, precision=0, elem_classes="input-field")
value=32,
precision=0,
elem_classes="input-field"
)
lora_alpha = gr.Number(
label="⚖️ LoRA Alpha",
value=16,
precision=0,
elem_classes="input-field"
)
save_interval = gr.Number( save_interval = gr.Number(
label="💾 保存间隔 (Steps)", label="💾 保存间隔 (Steps)", value=1000, precision=0, elem_classes="input-field"
value=1000,
precision=0,
elem_classes="input-field"
) )
output_name = gr.Textbox( output_name = gr.Textbox(
label="📁 输出目录名称 (可选,若存在则继续训练)", label="📁 输出目录名称 (可选,若存在则继续训练)", value="", elem_classes="input-field"
value="",
elem_classes="input-field"
) )
with gr.Row(): with gr.Row():
start_btn = gr.Button( start_btn = gr.Button("▶️ 开始训练", variant="primary", elem_classes="button-primary")
" 开始训练", stop_btn = gr.Button(" 停止训练", variant="stop", elem_classes="button-stop")
variant="primary",
elem_classes="button-primary"
)
stop_btn = gr.Button(
"⏹️ 停止训练",
variant="stop",
elem_classes="button-stop"
)
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"): with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
with gr.Row(): with gr.Row():
@@ -964,7 +913,9 @@ with gr.Blocks(
gr.Markdown("#### 分发选项 (Distribution)") gr.Markdown("#### 分发选项 (Distribution)")
with gr.Row(): with gr.Row():
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5") hf_model_id = gr.Textbox(
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
)
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False) distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
with gr.Column(scale=2, elem_classes="form-section"): with gr.Column(scale=2, elem_classes="form-section"):
@@ -975,23 +926,41 @@ with gr.Blocks(
max_lines=30, max_lines=30,
interactive=False, interactive=False,
elem_classes="input-field", elem_classes="input-field",
show_label=False show_label=False,
) )
start_btn.click( start_btn.click(
start_training, start_training,
inputs=[ inputs=[
train_pretrained_path, train_manifest, val_manifest, train_pretrained_path,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval, train_manifest,
val_manifest,
lr,
num_iters,
batch_size,
lora_rank,
lora_alpha,
save_interval,
output_name, output_name,
# advanced # advanced
grad_accum_steps, num_workers, log_interval, valid_interval, grad_accum_steps,
weight_decay, warmup_steps, max_steps, sample_rate, num_workers,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path, log_interval,
valid_interval,
weight_decay,
warmup_steps,
max_steps,
sample_rate,
enable_lm,
enable_dit,
enable_proj,
dropout,
tensorboard_path,
# distribution # distribution
hf_model_id, distribute hf_model_id,
distribute,
], ],
outputs=[logs_out] # Initial message outputs=[logs_out], # Initial message
) )
stop_btn.click(stop_training, outputs=[logs_out]) stop_btn.click(stop_training, outputs=[logs_out])
@@ -1016,21 +985,17 @@ with gr.Blocks(
value="Hello, this is a test of the VoxCPM LoRA model.", value="Hello, this is a test of the VoxCPM LoRA model.",
elem_classes="input-field", elem_classes="input-field",
lines=4, lines=4,
placeholder="输入要合成的文本内容..." placeholder="输入要合成的文本内容...",
) )
gr.Markdown("**🎭 声音克隆(可选)**") gr.Markdown("**🎭 声音克隆(可选)**")
prompt_wav = gr.Audio( prompt_wav = gr.Audio(label="🎵 参考音频", type="filepath", elem_classes="input-field")
label="🎵 参考音频",
type="filepath",
elem_classes="input-field"
)
prompt_text = gr.Textbox( prompt_text = gr.Textbox(
label="📝 参考文本(可选)", label="📝 参考文本(可选)",
elem_classes="input-field", elem_classes="input-field",
placeholder="如不填写,将自动识别参考音频内容" placeholder="如不填写,将自动识别参考音频内容",
) )
# 中栏:模型选择和参数配置 (35%) # 中栏:模型选择和参数配置 (35%)
@@ -1043,14 +1008,10 @@ with gr.Blocks(
value="None", value="None",
interactive=True, interactive=True,
elem_classes="input-field", elem_classes="input-field",
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型" info="选择训练好的 LoRA 模型,或选择 None 使用基础模型",
) )
refresh_lora_btn = gr.Button( refresh_lora_btn = gr.Button("🔄 刷新模型列表", elem_classes="button-refresh", size="sm")
"🔄 刷新模型列表",
elem_classes="button-refresh",
size="sm"
)
gr.Markdown("#### ⚙️ 生成参数") gr.Markdown("#### ⚙️ 生成参数")
@@ -1060,7 +1021,7 @@ with gr.Blocks(
maximum=5.0, maximum=5.0,
value=2.0, value=2.0,
step=0.1, step=0.1,
info="引导系数,值越大越贴近提示" info="引导系数,值越大越贴近提示",
) )
steps = gr.Slider( steps = gr.Slider(
@@ -1069,7 +1030,7 @@ with gr.Blocks(
maximum=50, maximum=50,
value=10, value=10,
step=1, step=1,
info="生成质量与步数成正比,但耗时更长" info="生成质量与步数成正比,但耗时更长",
) )
seed = gr.Number( seed = gr.Number(
@@ -1077,25 +1038,16 @@ with gr.Blocks(
value=-1, value=-1,
precision=0, precision=0,
elem_classes="input-field", elem_classes="input-field",
info="-1 为随机,固定值可复现结果" info="-1 为随机,固定值可复现结果",
) )
generate_btn = gr.Button( generate_btn = gr.Button("🎵 生成音频", variant="primary", elem_classes="button-primary", size="lg")
"🎵 生成音频",
variant="primary",
elem_classes="button-primary",
size="lg"
)
# 右栏:生成结果 (30%) # 右栏:生成结果 (30%)
with gr.Column(scale=30, elem_classes="form-section"): with gr.Column(scale=30, elem_classes="form-section"):
gr.Markdown("#### 🎧 生成结果") gr.Markdown("#### 🎧 生成结果")
audio_out = gr.Audio( audio_out = gr.Audio(label="", elem_classes="input-field", show_label=False)
label="",
elem_classes="input-field",
show_label=False
)
gr.Markdown("#### 📋 状态信息") gr.Markdown("#### 📋 状态信息")
@@ -1105,7 +1057,7 @@ with gr.Blocks(
elem_classes="input-field", elem_classes="input-field",
show_label=False, show_label=False,
lines=3, lines=3,
placeholder="等待生成..." placeholder="等待生成...",
) )
def refresh_loras(): def refresh_loras():
@@ -1126,16 +1078,21 @@ with gr.Blocks(
refresh_lora_btn.click(refresh_loras, outputs=[lora_select]) refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
# Auto-recognize audio when uploaded # Auto-recognize audio when uploaded
prompt_wav.change( prompt_wav.change(fn=recognize_audio, inputs=[prompt_wav], outputs=[prompt_text])
fn=recognize_audio,
inputs=[prompt_wav],
outputs=[prompt_text]
)
generate_btn.click( generate_btn.click(
run_inference, run_inference,
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path], inputs=[
outputs=[audio_out, status_out] infer_text,
prompt_wav,
prompt_text,
lora_select,
cfg_scale,
steps,
seed,
train_pretrained_path,
],
outputs=[audio_out, status_out],
) )
# --- Language Switching Logic --- # --- Language Switching Logic ---
@@ -1144,108 +1101,138 @@ with gr.Blocks(
# Labels for advanced options # Labels for advanced options
if lang == "zh": if lang == "zh":
adv = { adv = {
'grad_accum_steps': "梯度累积 (grad_accum_steps)", "grad_accum_steps": "梯度累积 (grad_accum_steps)",
'num_workers': "数据加载线程 (num_workers)", "num_workers": "数据加载线程 (num_workers)",
'log_interval': "日志间隔 (log_interval)", "log_interval": "日志间隔 (log_interval)",
'valid_interval': "验证间隔 (valid_interval)", "valid_interval": "验证间隔 (valid_interval)",
'weight_decay': "权重衰减 (weight_decay)", "weight_decay": "权重衰减 (weight_decay)",
'warmup_steps': "warmup_steps", "warmup_steps": "warmup_steps",
'max_steps': "最大步数 (max_steps)", "max_steps": "最大步数 (max_steps)",
'sample_rate': "采样率 (sample_rate)", "sample_rate": "采样率 (sample_rate)",
'enable_lm': "启用 LoRA LM (enable_lm)", "enable_lm": "启用 LoRA LM (enable_lm)",
'enable_dit': "启用 LoRA DIT (enable_dit)", "enable_dit": "启用 LoRA DIT (enable_dit)",
'enable_proj': "启用投影 (enable_proj)", "enable_proj": "启用投影 (enable_proj)",
'dropout': "LoRA Dropout", "dropout": "LoRA Dropout",
'tensorboard_path': "Tensorboard 路径 (可选)", "tensorboard_path": "Tensorboard 路径 (可选)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", "hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "分发模式 (distribute)", "distribute": "分发模式 (distribute)",
} }
else: else:
adv = { adv = {
'grad_accum_steps': "Grad Accum Steps", "grad_accum_steps": "Grad Accum Steps",
'num_workers': "Num Workers", "num_workers": "Num Workers",
'log_interval': "Log Interval", "log_interval": "Log Interval",
'valid_interval': "Valid Interval", "valid_interval": "Valid Interval",
'weight_decay': "Weight Decay", "weight_decay": "Weight Decay",
'warmup_steps': "Warmup Steps", "warmup_steps": "Warmup Steps",
'max_steps': "Max Steps", "max_steps": "Max Steps",
'sample_rate': "Sample Rate", "sample_rate": "Sample Rate",
'enable_lm': "Enable LoRA LM", "enable_lm": "Enable LoRA LM",
'enable_dit': "Enable LoRA DIT", "enable_dit": "Enable LoRA DIT",
'enable_proj': "Enable Projection", "enable_proj": "Enable Projection",
'dropout': "LoRA Dropout", "dropout": "LoRA Dropout",
'tensorboard_path': "Tensorboard Path (Optional)", "tensorboard_path": "Tensorboard Path (Optional)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", "hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "Distribute Mode", "distribute": "Distribute Mode",
} }
return ( return (
gr.update(value=f"# {d['title']}"), gr.update(value=f"# {d['title']}"),
gr.update(label=d['tab_train']), gr.update(label=d["tab_train"]),
gr.update(label=d['tab_infer']), gr.update(label=d["tab_infer"]),
gr.update(label=d['pretrained_path']), gr.update(label=d["pretrained_path"]),
gr.update(label=d['train_manifest']), gr.update(label=d["train_manifest"]),
gr.update(label=d['val_manifest']), gr.update(label=d["val_manifest"]),
gr.update(label=d['lr']), gr.update(label=d["lr"]),
gr.update(label=d['max_iters']), gr.update(label=d["max_iters"]),
gr.update(label=d['batch_size']), gr.update(label=d["batch_size"]),
gr.update(label=d['lora_rank']), gr.update(label=d["lora_rank"]),
gr.update(label=d['lora_alpha']), gr.update(label=d["lora_alpha"]),
gr.update(label=d['save_interval']), gr.update(label=d["save_interval"]),
gr.update(label=d['output_name']), gr.update(label=d["output_name"]),
gr.update(value=d['start_train']), gr.update(value=d["start_train"]),
gr.update(value=d['stop_train']), gr.update(value=d["stop_train"]),
gr.update(label=d['train_logs']), gr.update(label=d["train_logs"]),
# Advanced options (must match outputs order) # Advanced options (must match outputs order)
gr.update(label=adv['grad_accum_steps']), gr.update(label=adv["grad_accum_steps"]),
gr.update(label=adv['num_workers']), gr.update(label=adv["num_workers"]),
gr.update(label=adv['log_interval']), gr.update(label=adv["log_interval"]),
gr.update(label=adv['valid_interval']), gr.update(label=adv["valid_interval"]),
gr.update(label=adv['weight_decay']), gr.update(label=adv["weight_decay"]),
gr.update(label=adv['warmup_steps']), gr.update(label=adv["warmup_steps"]),
gr.update(label=adv['max_steps']), gr.update(label=adv["max_steps"]),
gr.update(label=adv['sample_rate']), gr.update(label=adv["sample_rate"]),
gr.update(label=adv['enable_lm']), gr.update(label=adv["enable_lm"]),
gr.update(label=adv['enable_dit']), gr.update(label=adv["enable_dit"]),
gr.update(label=adv['enable_proj']), gr.update(label=adv["enable_proj"]),
gr.update(label=adv['dropout']), gr.update(label=adv["dropout"]),
gr.update(label=adv['tensorboard_path']), gr.update(label=adv["tensorboard_path"]),
# Distribution options # Distribution options
gr.update(label=adv['hf_model_id']), gr.update(label=adv["hf_model_id"]),
gr.update(label=adv['distribute']), gr.update(label=adv["distribute"]),
# Inference section # Inference section
gr.update(label=d['text_to_synth']), gr.update(label=d["text_to_synth"]),
gr.update(label=d['ref_audio']), gr.update(label=d["ref_audio"]),
gr.update(label=d['ref_text']), gr.update(label=d["ref_text"]),
gr.update(label=d['select_lora']), gr.update(label=d["select_lora"]),
gr.update(value=d['refresh']), gr.update(value=d["refresh"]),
gr.update(label=d['cfg_scale']), gr.update(label=d["cfg_scale"]),
gr.update(label=d['infer_steps']), gr.update(label=d["infer_steps"]),
gr.update(label=d['seed']), gr.update(label=d["seed"]),
gr.update(value=d['gen_audio']), gr.update(value=d["gen_audio"]),
gr.update(label=d['gen_output']), gr.update(label=d["gen_output"]),
gr.update(label=d['status']), gr.update(label=d["status"]),
) )
lang_btn.change( lang_btn.change(
change_language, change_language,
inputs=[lang_btn], inputs=[lang_btn],
outputs=[ outputs=[
title_md, tab_train, tab_infer, title_md,
train_pretrained_path, train_manifest, val_manifest, tab_train,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval, tab_infer,
train_pretrained_path,
train_manifest,
val_manifest,
lr,
num_iters,
batch_size,
lora_rank,
lora_alpha,
save_interval,
output_name, output_name,
start_btn, stop_btn, logs_out, start_btn,
stop_btn,
logs_out,
# advanced outputs # advanced outputs
grad_accum_steps, num_workers, log_interval, valid_interval, grad_accum_steps,
weight_decay, warmup_steps, max_steps, sample_rate, num_workers,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path, log_interval,
valid_interval,
weight_decay,
warmup_steps,
max_steps,
sample_rate,
enable_lm,
enable_dit,
enable_proj,
dropout,
tensorboard_path,
# distribution outputs # distribution outputs
hf_model_id, distribute, hf_model_id,
infer_text, prompt_wav, prompt_text, distribute,
lora_select, refresh_lora_btn, cfg_scale, steps, seed, infer_text,
generate_btn, audio_out, status_out prompt_wav,
] prompt_text,
lora_select,
refresh_lora_btn,
cfg_scale,
steps,
seed,
generate_btn,
audio_out,
status_out,
],
) )
if __name__ == "__main__": if __name__ == "__main__":
+1 -3
View File
@@ -30,7 +30,7 @@ dependencies = [
"torchcodec", "torchcodec",
"transformers>=4.36.2", "transformers>=4.36.2",
"einops", "einops",
"gradio<6", "gradio>=6,<7",
"inflect", "inflect",
"addict", "addict",
"wetext", "wetext",
@@ -57,7 +57,6 @@ dev = [
"pytest-cov>=2.0", "pytest-cov>=2.0",
"black>=21.0", "black>=21.0",
"flake8>=3.8", "flake8>=3.8",
"mypy>=0.800",
"pre-commit>=2.0", "pre-commit>=2.0",
] ]
@@ -90,7 +89,6 @@ extend-exclude = '''
\.eggs \.eggs
| \.git | \.git
| \.hg | \.hg
| \.mypy_cache
| \.tox | \.tox
| \.venv | \.venv
| build | build
+4 -1
View File
@@ -125,7 +125,10 @@ def main():
out_path.parent.mkdir(parents=True, exist_ok=True) out_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(out_path), audio_np, model.tts_model.sample_rate) sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr) print(
f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
if __name__ == "__main__": if __name__ == "__main__":
+30 -13
View File
@@ -127,7 +127,9 @@ def main():
print(f"Loaded config from: {lora_config_path}", file=sys.stderr) print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
print(f" Base model: {pretrained_path}", file=sys.stderr) print(f" Base model: {pretrained_path}", file=sys.stderr)
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr) print(
f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr
)
# 3. Load model with LoRA (no denoiser) # 3. Load model with LoRA (no denoiser)
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr) print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
@@ -146,10 +148,10 @@ def main():
out_path = Path(args.output) out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True) out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr) print("\n[2/2] Starting synthesis tests...", file=sys.stderr)
# === Test 1: With LoRA === # === Test 1: With LoRA ===
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr) print("\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
prompt_wav_path=prompt_wav_path, prompt_wav_path=prompt_wav_path,
@@ -162,10 +164,13 @@ def main():
) )
lora_output = out_path.with_stem(out_path.stem + "_with_lora") lora_output = out_path.with_stem(out_path.stem + "_with_lora")
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate) sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr) print(
f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 2: Disable LoRA (via set_lora_enabled) === # === Test 2: Disable LoRA (via set_lora_enabled) ===
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr) print("\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
model.set_lora_enabled(False) model.set_lora_enabled(False)
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
@@ -179,10 +184,13 @@ def main():
) )
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled") disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate) sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr) print(
f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 3: Re-enable LoRA === # === Test 3: Re-enable LoRA ===
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr) print("\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
model.set_lora_enabled(True) model.set_lora_enabled(True)
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
@@ -196,10 +204,13 @@ def main():
) )
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled") reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate) sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr) print(
f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 4: Unload LoRA (reset_lora_weights) === # === Test 4: Unload LoRA (reset_lora_weights) ===
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr) print("\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
model.unload_lora() model.unload_lora()
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
@@ -213,10 +224,13 @@ def main():
) )
reset_output = out_path.with_stem(out_path.stem + "_lora_reset") reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate) sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr) print(
f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 5: Hot-reload LoRA (load_lora) === # === Test 5: Hot-reload LoRA (load_lora) ===
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr) print("\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
loaded, skipped = model.load_lora(ckpt_dir) loaded, skipped = model.load_lora(ckpt_dir)
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr) print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
audio_np = model.generate( audio_np = model.generate(
@@ -231,9 +245,12 @@ def main():
) )
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded") reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate) sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr) print(
f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
print(f"\n[Done] All tests completed!", file=sys.stderr) print("\n[Done] All tests completed!", file=sys.stderr)
print(f" - with_lora: {lora_output}", file=sys.stderr) print(f" - with_lora: {lora_output}", file=sys.stderr)
print(f" - lora_disabled: {disabled_output}", file=sys.stderr) print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr) print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
+153 -43
View File
@@ -7,7 +7,7 @@ project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "src")) sys.path.insert(0, str(project_root / "src"))
import contextlib import contextlib
from typing import Dict, Optional from typing import Dict
import argbind import argbind
import torch import torch
@@ -17,16 +17,19 @@ from transformers import get_cosine_schedule_with_warmup
import signal import signal
import os import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false' os.environ["TOKENIZERS_PARALLELISM"] = "false"
try: try:
from safetensors.torch import save_file from safetensors.torch import save_file
SAFETENSORS_AVAILABLE = True SAFETENSORS_AVAILABLE = True
except ImportError: except ImportError:
SAFETENSORS_AVAILABLE = False SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr) print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
from voxcpm.model import VoxCPMModel import json
from voxcpm.model import VoxCPMModel, VoxCPM2Model
from voxcpm.model.voxcpm import LoRAConfig from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import ( from voxcpm.training import (
Accelerator, Accelerator,
@@ -61,8 +64,8 @@ def train(
lora: dict = None, lora: dict = None,
config_path: str = "", config_path: str = "",
# Distribution options (for LoRA checkpoints) # Distribution options (for LoRA checkpoints)
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5") hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
): ):
_ = config_path _ = config_path
@@ -84,7 +87,15 @@ def train(
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank) tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None) # Auto-detect model architecture from config.json
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
_arch = json.load(_f).get("architecture", "voxcpm").lower()
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
if accelerator.rank == 0:
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
base_model = _model_cls.from_local(
pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None
)
tokenizer = base_model.text_tokenizer tokenizer = base_model.text_tokenizer
train_ds, val_ds = load_audio_text_datasets( train_ds, val_ds = load_audio_text_datasets(
@@ -166,7 +177,6 @@ def train(
unwrapped_model = accelerator.unwrap(model) unwrapped_model = accelerator.unwrap(model)
unwrapped_model.train() unwrapped_model.train()
# Only print param info on rank 0 to avoid cluttered output # Only print param info on rank 0 to avoid cluttered output
if accelerator.rank == 0: if accelerator.rank == 0:
for name, param in model.named_parameters(): for name, param in model.named_parameters():
@@ -199,7 +209,19 @@ def train(
resume = {"step": start_step} resume = {"step": start_step}
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT) # Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank): def _signal_handler(
signum,
frame,
_model=model,
_optim=optimizer,
_sched=scheduler,
_save_dir=save_dir,
_pretrained=pretrained_path,
_hf_id=hf_model_id,
_dist=distribute,
_resume=resume,
_rank=accelerator.rank,
):
try: try:
cur_step = int(_resume.get("step", start_step)) cur_step = int(_resume.get("step", start_step))
except Exception: except Exception:
@@ -229,8 +251,8 @@ def train(
except StopIteration: except StopIteration:
data_epoch += 1 data_epoch += 1
# Key: set DistributedSampler epoch to ensure different data order each epoch # Key: set DistributedSampler epoch to ensure different data order each epoch
sampler = getattr(train_loader, 'sampler', None) sampler = getattr(train_loader, "sampler", None)
if hasattr(sampler, 'set_epoch'): if hasattr(sampler, "set_epoch"):
sampler.set_epoch(data_epoch) sampler.set_epoch(data_epoch)
train_iter = iter(train_loader) train_iter = iter(train_loader)
return next(train_iter) return next(train_iter)
@@ -250,7 +272,7 @@ def train(
# Only sync gradients on the last micro-batch # Only sync gradients on the last micro-batch
# Use no_sync() for intermediate steps to reduce communication overhead # Use no_sync() for intermediate steps to reduce communication overhead
is_last_micro_step = (micro_step == grad_accum_steps - 1) is_last_micro_step = micro_step == grad_accum_steps - 1
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync() sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
with sync_context: with sync_context:
@@ -299,10 +321,22 @@ def train(
tracker.log_metrics(loss_values, split="train") tracker.log_metrics(loss_values, split="train")
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1): if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas, validate(
writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen, model,
sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer, val_loader,
valid_interval=valid_interval) batch_processor,
accelerator,
tracker,
lambdas,
writer=writer,
step=step,
val_ds=val_ds,
audio_vae=audio_vae_for_gen,
sample_rate=sample_rate,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
)
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0: if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute) save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
@@ -313,11 +347,24 @@ def train(
writer.close() writer.close()
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas, def validate(
writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050, model,
val_texts=None, tokenizer=None, valid_interval=1000): val_loader,
batch_processor,
accelerator,
tracker,
lambdas,
writer=None,
step=0,
val_ds=None,
audio_vae=None,
sample_rate=22050,
val_texts=None,
tokenizer=None,
valid_interval=1000,
):
"""Validate and generate sample audio""" """Validate and generate sample audio"""
import numpy as np import numpy as np # noqa: F401
from collections import defaultdict from collections import defaultdict
model.eval() model.eval()
@@ -369,13 +416,24 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
# Generate sample audio for TensorBoard display # Generate sample audio for TensorBoard display
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0: if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
try: try:
generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate, generate_sample_audio(
val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval, model,
tracker=tracker) val_ds,
audio_vae,
writer,
step,
accelerator,
sample_rate,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
tracker=tracker,
)
except Exception as e: except Exception as e:
tracker.print(f"[Warning] Failed to generate sample audio: {e}") tracker.print(f"[Warning] Failed to generate sample audio: {e}")
import traceback import traceback
import io import io
buf = io.StringIO() buf = io.StringIO()
traceback.print_exc(file=buf) traceback.print_exc(file=buf)
tracker.print(buf.getvalue()) tracker.print(buf.getvalue())
@@ -398,6 +456,7 @@ def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
"""Compute Mel Spectrogram (dB) using librosa""" """Compute Mel Spectrogram (dB) using librosa"""
import numpy as np import numpy as np
import librosa import librosa
audio_np = audio_np.flatten().astype(np.float32) audio_np = audio_np.flatten().astype(np.float32)
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2) mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
return librosa.power_to_db(mel, ref=np.max) return librosa.power_to_db(mel, ref=np.max)
@@ -408,7 +467,8 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
""" """
import matplotlib import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import librosa.display import librosa.display
@@ -419,19 +479,32 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
# Comparison mode: reference vs generated # Comparison mode: reference vs generated
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8)) fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref) img_ref = librosa.display.specshow(
ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745') ref_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_ref
plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02) )
ax_ref.set_title(
f"Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}",
fontsize=10,
fontweight="bold",
color="#28A745",
)
plt.colorbar(img_ref, ax=ax_ref, format="%+2.0f dB", pad=0.02)
img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen) img_gen = librosa.display.specshow(
ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545') gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_gen
plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02) )
ax_gen.set_title(
f"Generated - {len(gen_audio_np)/sample_rate:.2f}s", fontsize=10, fontweight="bold", color="#DC3545"
)
plt.colorbar(img_gen, ax=ax_gen, format="%+2.0f dB", pad=0.02)
else: else:
# Single figure mode: show generated only # Single figure mode: show generated only
fig, ax = plt.subplots(figsize=(12, 4)) fig, ax = plt.subplots(figsize=(12, 4))
img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax) img = librosa.display.specshow(
ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold') gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax
plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02) )
ax.set_title(f"Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}", fontsize=11, fontweight="bold")
plt.colorbar(img, ax=ax, format="%+2.0f dB", pad=0.02)
plt.tight_layout() plt.tight_layout()
return fig return fig
@@ -440,13 +513,25 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
def normalize_audio(audio_np): def normalize_audio(audio_np):
"""Normalize audio to [-0.9, 0.9]""" """Normalize audio to [-0.9, 0.9]"""
import numpy as np import numpy as np
max_val = np.abs(audio_np).max() max_val = np.abs(audio_np).max()
return audio_np / max_val * 0.9 if max_val > 0 else audio_np return audio_np / max_val * 0.9 if max_val > 0 else audio_np
def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050, def generate_sample_audio(
val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000, model,
tracker=None): val_ds,
audio_vae,
writer,
step,
accelerator,
sample_rate=22050,
val_texts=None,
tokenizer=None,
pretrained_path=None,
valid_interval=1000,
tracker=None,
):
"""Select 2 fixed validation samples, generate audio and log to TensorBoard""" """Select 2 fixed validation samples, generate audio and log to TensorBoard"""
import numpy as np import numpy as np
@@ -468,7 +553,10 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
ref_sr = sample["audio"].get("sampling_rate", sample_rate) ref_sr = sample["audio"].get("sampling_rate", sample_rate)
if ref_sr != sample_rate: if ref_sr != sample_rate:
import torchaudio.functional as F import torchaudio.functional as F
ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
ref_audio_np = (
F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
)
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s") log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
except Exception as e: except Exception as e:
log(f"[Warning] Failed to load reference audio: {e}") log(f"[Warning] Failed to load reference audio: {e}")
@@ -500,7 +588,11 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
continue continue
# Process generated audio # Process generated audio
gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten() gen_audio_np = (
generated.cpu().float().numpy().flatten()
if isinstance(generated, torch.Tensor)
else np.array(generated, dtype=np.float32).flatten()
)
gen_audio_np = normalize_audio(gen_audio_np) gen_audio_np = normalize_audio(gen_audio_np)
tag = f"val_sample_{i}" tag = f"val_sample_{i}"
@@ -509,7 +601,9 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
# Log reference audio # Log reference audio
if ref_audio_np is not None: if ref_audio_np is not None:
writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate) writer.add_audio(
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
)
# Generate mel spectrogram figure # Generate mel spectrogram figure
try: try:
@@ -524,6 +618,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
except Exception as e: except Exception as e:
log(f"[Warning] Failed to generate audio for sample {i}: {e}") log(f"[Warning] Failed to generate audio for sample {i}: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
finally: finally:
@@ -545,8 +640,6 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
Called by all ranks so that distributed state stays aligned. Called by all ranks so that distributed state stays aligned.
Returns the step number to resume from, or 0 if no checkpoint found. Returns the step number to resume from, or 0 if no checkpoint found.
""" """
import json
latest_folder = save_dir / "latest" latest_folder = save_dir / "latest"
if not latest_folder.exists(): if not latest_folder.exists():
return 0 return 0
@@ -564,6 +657,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
if lora_weights_path.exists(): if lora_weights_path.exists():
if lora_weights_path.suffix == ".safetensors": if lora_weights_path.suffix == ".safetensors":
from safetensors.torch import load_file from safetensors.torch import load_file
state_dict = load_file(str(lora_weights_path)) state_dict = load_file(str(lora_weights_path))
else: else:
ckpt = torch.load(lora_weights_path, map_location="cpu") ckpt = torch.load(lora_weights_path, map_location="cpu")
@@ -581,6 +675,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
if model_path.exists(): if model_path.exists():
if model_path.suffix == ".safetensors": if model_path.suffix == ".safetensors":
from safetensors.torch import load_file from safetensors.torch import load_file
state_dict = load_file(str(model_path)) state_dict = load_file(str(model_path))
else: else:
ckpt = torch.load(model_path, map_location="cpu") ckpt = torch.load(model_path, map_location="cpu")
@@ -625,13 +720,21 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
return 0 return 0
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False): def save_checkpoint(
model,
optimizer,
scheduler,
save_dir: Path,
step: int,
pretrained_path: str = None,
hf_model_id: str = "",
distribute: bool = False,
):
""" """
Save checkpoint with different strategies for full finetune vs LoRA: Save checkpoint with different strategies for full finetune vs LoRA:
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable) - Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable) - LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
""" """
import json
import shutil import shutil
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
@@ -671,7 +774,14 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
# Copy config files from pretrained path # Copy config files from pretrained path
if pretrained_path: if pretrained_path:
pretrained_dir = Path(pretrained_path) pretrained_dir = Path(pretrained_path)
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"] files_to_copy = [
"config.json",
"audiovae.pth",
"audiovae.safetensors",
"tokenizer.json",
"special_tokens_map.json",
"tokenizer_config.json",
]
for fname in files_to_copy: for fname in files_to_copy:
src = pretrained_dir / fname src = pretrained_dir / fname
if src.exists(): if src.exists():
+46 -27
View File
@@ -13,11 +13,11 @@ import soundfile as sf
from voxcpm.core import VoxCPM from voxcpm.core import VoxCPM
# ----------------------------- # -----------------------------
# Validators # Validators
# ----------------------------- # -----------------------------
def validate_file_exists(file_path: str, file_type: str = "file") -> Path: def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
path = Path(file_path) path = Path(file_path)
if not path.exists(): if not path.exists():
@@ -53,12 +53,11 @@ def validate_ranges(args, parser):
# Model loading # Model loading
# ----------------------------- # -----------------------------
def load_model(args) -> VoxCPM: def load_model(args) -> VoxCPM:
print("Loading VoxCPM model...", file=sys.stderr) print("Loading VoxCPM model...", file=sys.stderr)
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get( zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get("ZIPENHANCER_MODEL_PATH", None)
"ZIPENHANCER_MODEL_PATH", None
)
# Build LoRA config if provided # Build LoRA config if provided
lora_config = None lora_config = None
@@ -119,22 +118,29 @@ def load_model(args) -> VoxCPM:
# Commands # Commands
# ----------------------------- # -----------------------------
def cmd_clone(args): def cmd_clone(args):
if not args.text: if not args.text:
sys.exit("Error: Please provide --text for synthesis") sys.exit("Error: Please provide --text for synthesis")
if not args.prompt_audio or not args.prompt_text: has_prompt = args.prompt_audio and args.prompt_text
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text") has_ref = args.reference_audio is not None
if not has_prompt and not has_ref:
sys.exit("Error: Voice cloning requires --prompt-audio + --prompt-text, or --reference-audio, or both")
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file") if args.prompt_audio:
validate_file_exists(args.prompt_audio, "prompt audio file")
if args.reference_audio:
validate_file_exists(args.reference_audio, "reference audio file")
output_path = validate_output_path(args.output) output_path = validate_output_path(args.output)
model = load_model(args) model = load_model(args)
audio_array = model.generate( audio_array = model.generate(
text=args.text, text=args.text,
prompt_wav_path=str(prompt_audio_path), prompt_wav_path=args.prompt_audio if has_prompt else None,
prompt_text=args.prompt_text, prompt_text=args.prompt_text if has_prompt else None,
reference_wav_path=args.reference_audio,
cfg_value=args.cfg_value, cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps, inference_timesteps=args.inference_timesteps,
normalize=args.normalize, normalize=args.normalize,
@@ -185,7 +191,11 @@ def cmd_batch(args):
prompt_audio_path = None prompt_audio_path = None
if args.prompt_audio: if args.prompt_audio:
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file")) prompt_audio_path = str(validate_file_exists(args.prompt_audio, "prompt audio file"))
reference_audio_path = None
if args.reference_audio:
reference_audio_path = str(validate_file_exists(args.reference_audio, "reference audio file"))
success_count = 0 success_count = 0
@@ -195,10 +205,11 @@ def cmd_batch(args):
text=text, text=text,
prompt_wav_path=prompt_audio_path, prompt_wav_path=prompt_audio_path,
prompt_text=args.prompt_text, prompt_text=args.prompt_text,
reference_wav_path=reference_audio_path,
cfg_value=args.cfg_value, cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps, inference_timesteps=args.inference_timesteps,
normalize=args.normalize, normalize=args.normalize,
denoise=args.denoise and prompt_audio_path is not None, denoise=args.denoise and (prompt_audio_path is not None or reference_audio_path is not None),
) )
output_file = output_dir / f"output_{i:03d}.wav" output_file = output_dir / f"output_{i:03d}.wav"
@@ -218,6 +229,7 @@ def cmd_batch(args):
# Parser # Parser
# ----------------------------- # -----------------------------
def _build_unified_parser(): def _build_unified_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing", description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
@@ -236,34 +248,40 @@ Examples:
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)") parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)") parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
# Prompt # Prompt / Reference
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)") parser.add_argument(
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio") "--prompt-audio", "-pa", help="Prompt audio file path (continuation mode, requires --prompt-text)"
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement") )
parser.add_argument("--prompt-text", "-pt", help="Text corresponding to the prompt audio")
parser.add_argument(
"--reference-audio", "-ra", help="Reference audio for voice cloning (isolated mode, VoxCPM2 only)"
)
parser.add_argument("--denoise", action="store_true", help="Enable prompt/reference speech enhancement")
# Generation parameters # Generation parameters
parser.add_argument("--cfg-value", type=float, default=2.0, parser.add_argument(
help="CFG guidance scale (float, recommended 0.55.0, default: 2.0)") "--cfg-value", type=float, default=2.0, help="CFG guidance scale (float, recommended 0.55.0, default: 2.0)"
parser.add_argument("--inference-timesteps", type=int, default=10, )
help="Inference steps (int, 1100, default: 10)") parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (int, 1100, default: 10)")
parser.add_argument("--normalize", action="store_true", help="Enable text normalization") parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
# Model loading # Model loading
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path") parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", parser.add_argument(
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)") "--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (default: openbmb/VoxCPM1.5)"
)
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads") parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Disable network access") parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading") parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str, parser.add_argument(
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)") "--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)"
)
# LoRA # LoRA
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights") parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)") parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)") parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
parser.add_argument("--lora-dropout", type=float, default=0.0, parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (0.01.0, default: 0.0)")
help="LoRA dropout rate (0.01.0, default: 0.0)")
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers") parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers") parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers") parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
@@ -275,6 +293,7 @@ Examples:
# Entrypoint # Entrypoint
# ----------------------------- # -----------------------------
def main(): def main():
parser = _build_unified_parser() parser = _build_unified_parser()
args = parser.parse_args() args = parser.parse_args()
@@ -296,8 +315,8 @@ def main():
if not args.text or not args.output: if not args.text or not args.output:
parser.error("Single-sample mode requires --text and --output") parser.error("Single-sample mode requires --text and --output")
# Clone mode # Clone mode (prompt continuation, reference isolation, or both)
if args.prompt_audio or args.prompt_text: if args.prompt_audio or args.prompt_text or args.reference_audio:
return cmd_clone(args) return cmd_clone(args)
# Direct synthesis # Direct synthesis
+127 -76
View File
@@ -1,21 +1,25 @@
import os import os
import sys import sys
import re import re
import json
import tempfile import tempfile
import numpy as np import numpy as np
from typing import Generator, Optional from typing import Generator, Optional
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel, LoRAConfig from .model.voxcpm import VoxCPMModel, LoRAConfig
from .model.voxcpm2 import VoxCPM2Model
class VoxCPM: class VoxCPM:
def __init__(self, def __init__(
voxcpm_model_path : str, self,
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base", voxcpm_model_path: str,
enable_denoiser : bool = True, zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
optimize: bool = True, enable_denoiser: bool = True,
lora_config: Optional[LoRAConfig] = None, optimize: bool = True,
lora_weights_path: Optional[str] = None, lora_config: Optional[LoRAConfig] = None,
): lora_weights_path: Optional[str] = None,
):
"""Initialize VoxCPM TTS pipeline. """Initialize VoxCPM TTS pipeline.
Args: Args:
@@ -31,7 +35,10 @@ class VoxCPM:
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
containing lora_weights.ckpt). If provided, LoRA weights will be loaded. containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
""" """
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr) print(
f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
file=sys.stderr,
)
# If lora_weights_path is provided but no lora_config, create a default one # If lora_weights_path is provided but no lora_config, create a default one
if lora_weights_path is not None and lora_config is None: if lora_weights_path is not None and lora_config is None:
@@ -42,7 +49,20 @@ class VoxCPM:
) )
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr) print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config) # Determine model type from config.json architecture field
config_path = os.path.join(voxcpm_model_path, "config.json")
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
arch = config.get("architecture", "voxcpm").lower()
if arch == "voxcpm2":
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
print("Loaded VoxCPM2Model", file=sys.stderr)
elif arch == "voxcpm":
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
print("Loaded VoxCPMModel", file=sys.stderr)
else:
raise ValueError(f"Unsupported architecture: {arch}")
# Load LoRA weights if path is provided # Load LoRA weights if path is provided
if lora_weights_path is not None: if lora_weights_path is not None:
@@ -51,8 +71,10 @@ class VoxCPM:
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr) print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
self.text_normalizer = None self.text_normalizer = None
self.denoiser = None
if enable_denoiser and zipenhancer_model_path is not None: if enable_denoiser and zipenhancer_model_path is not None:
from .zipenhancer import ZipEnhancer from .zipenhancer import ZipEnhancer
self.denoiser = ZipEnhancer(zipenhancer_model_path) self.denoiser = ZipEnhancer(zipenhancer_model_path)
else: else:
self.denoiser = None self.denoiser = None
@@ -64,17 +86,18 @@ class VoxCPM:
) )
@classmethod @classmethod
def from_pretrained(cls, def from_pretrained(
hf_model_id: str = "openbmb/VoxCPM1.5", cls,
load_denoiser: bool = True, hf_model_id: str = "openbmb/VoxCPM2",
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base", load_denoiser: bool = True,
cache_dir: str = None, zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
local_files_only: bool = False, cache_dir: str = None,
optimize: bool = True, local_files_only: bool = False,
lora_config: Optional[LoRAConfig] = None, optimize: bool = True,
lora_weights_path: Optional[str] = None, lora_config: Optional[LoRAConfig] = None,
**kwargs, lora_weights_path: Optional[str] = None,
): **kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot. """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
Args: Args:
@@ -134,46 +157,47 @@ class VoxCPM:
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]: def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs) return self._generate(*args, streaming=True, **kwargs)
def _generate(self, def _generate(
text : str, self,
prompt_wav_path : str = None, text: str,
prompt_text : str = None, prompt_wav_path: str = None,
cfg_value : float = 2.0, prompt_text: str = None,
inference_timesteps : int = 10, reference_wav_path: str = None,
min_len : int = 2, cfg_value: float = 2.0,
max_len : int = 4096, inference_timesteps: int = 10,
normalize : bool = False, min_len: int = 2,
denoise : bool = False, max_len: int = 4096,
retry_badcase : bool = True, normalize: bool = False,
retry_badcase_max_times : int = 3, denoise: bool = False,
retry_badcase_ratio_threshold : float = 6.0, retry_badcase: bool = True,
streaming: bool = False, retry_badcase_max_times: int = 3,
) -> Generator[np.ndarray, None, None]: retry_badcase_ratio_threshold: float = 6.0,
streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
"""Synthesize speech for the given text and return a single waveform. """Synthesize speech for the given text and return a single waveform.
This method optionally builds and reuses a prompt cache. If an external
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
used for all sub-sentences. Otherwise, the prompt cache is built from
the first generated result and reused for the remaining text chunks.
Args: Args:
text: Input text. Can include newlines; each non-empty line is text: Input text to synthesize.
treated as a sub-sentence. prompt_wav_path: Path to prompt audio for continuation mode.
prompt_wav_path: Path to a reference audio file for prompting. Must be paired with ``prompt_text``.
prompt_text: Text content corresponding to the prompt audio. prompt_text: Text content corresponding to the prompt audio.
reference_wav_path: Path to reference audio for voice cloning
(structurally isolated via ref_audio tokens). Can be used
alone or combined with ``prompt_wav_path`` + ``prompt_text``.
cfg_value: Guidance scale for the generation model. cfg_value: Guidance scale for the generation model.
inference_timesteps: Number of inference steps. inference_timesteps: Number of inference steps.
min_len: Minimum audio length.
max_len: Maximum token length during generation. max_len: Maximum token length during generation.
normalize: Whether to run text normalization before generation. normalize: Whether to run text normalization before generation.
denoise: Whether to denoise the prompt audio if a denoiser is denoise: Whether to denoise the prompt/reference audio if a
available. denoiser is available.
retry_badcase: Whether to retry badcase. retry_badcase: Whether to retry badcase.
retry_badcase_max_times: Maximum number of times to retry badcase. retry_badcase_max_times: Maximum number of times to retry badcase.
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio. retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
streaming: Whether to return a generator of audio chunks. streaming: Whether to return a generator of audio chunks.
Returns: Returns:
Generator of numpy.ndarray: 1D waveform array (float32) on CPU. Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
Yields audio chunks for each generations step if ``streaming=True``, Yields audio chunks for each generation step if ``streaming=True``,
otherwise yields a single array containing the final audio. otherwise yields a single array containing the final audio.
""" """
if not text.strip() or not isinstance(text, str): if not text.strip() or not isinstance(text, str):
@@ -183,55 +207,82 @@ class VoxCPM:
if not os.path.exists(prompt_wav_path): if not os.path.exists(prompt_wav_path):
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}") raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
if reference_wav_path is not None:
if not os.path.exists(reference_wav_path):
raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")
if (prompt_wav_path is None) != (prompt_text is None): if (prompt_wav_path is None) != (prompt_text is None):
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None") raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
is_v2 = isinstance(self.tts_model, VoxCPM2Model)
if reference_wav_path is not None and not is_v2:
raise ValueError("reference_wav_path is only supported with VoxCPM2 models")
text = text.replace("\n", " ") text = text.replace("\n", " ")
text = re.sub(r'\s+', ' ', text) text = re.sub(r"\s+", " ", text)
temp_prompt_wav_path = None temp_files = []
try: try:
if prompt_wav_path is not None and prompt_text is not None: actual_prompt_path = prompt_wav_path
if denoise and self.denoiser is not None: actual_ref_path = reference_wav_path
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
temp_prompt_wav_path = tmp_file.name if denoise and self.denoiser is not None:
self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path) if prompt_wav_path is not None:
prompt_wav_path = temp_prompt_wav_path with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
fixed_prompt_cache = self.tts_model.build_prompt_cache( temp_files.append(tmp.name)
prompt_wav_path=prompt_wav_path, self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
prompt_text=prompt_text actual_prompt_path = temp_files[-1]
) if reference_wav_path is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
temp_files.append(tmp.name)
self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
actual_ref_path = temp_files[-1]
if actual_prompt_path is not None or actual_ref_path is not None:
if is_v2:
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_text=prompt_text,
prompt_wav_path=actual_prompt_path,
reference_wav_path=actual_ref_path,
)
else:
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_text=prompt_text,
prompt_wav_path=actual_prompt_path,
)
else: else:
fixed_prompt_cache = None # will be built from the first inference fixed_prompt_cache = None
if normalize: if normalize:
if self.text_normalizer is None: if self.text_normalizer is None:
from .utils.text_normalize import TextNormalizer from .utils.text_normalize import TextNormalizer
self.text_normalizer = TextNormalizer() self.text_normalizer = TextNormalizer()
text = self.text_normalizer.normalize(text) text = self.text_normalizer.normalize(text)
generate_result = self.tts_model._generate_with_prompt_cache( generate_result = self.tts_model._generate_with_prompt_cache(
target_text=text, target_text=text,
prompt_cache=fixed_prompt_cache, prompt_cache=fixed_prompt_cache,
min_len=min_len, min_len=min_len,
max_len=max_len, max_len=max_len,
inference_timesteps=inference_timesteps, inference_timesteps=inference_timesteps,
cfg_value=cfg_value, cfg_value=cfg_value,
retry_badcase=retry_badcase, retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times, retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
streaming=streaming, streaming=streaming,
) )
for wav, _, _ in generate_result: for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy() yield wav.squeeze(0).cpu().numpy()
finally: finally:
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path): for tmp_path in temp_files:
try: if tmp_path and os.path.exists(tmp_path):
os.unlink(temp_prompt_wav_path) try:
except OSError: os.unlink(tmp_path)
pass except OSError:
pass
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# LoRA Interface (delegated to VoxCPMModel) # LoRA Interface (delegated to VoxCPMModel)
+2 -1
View File
@@ -1,3 +1,4 @@
from .voxcpm import VoxCPMModel from .voxcpm import VoxCPMModel
from .voxcpm2 import VoxCPM2Model
__all__ = ["VoxCPMModel"] __all__ = ["VoxCPMModel", "VoxCPM2Model"]
+1 -2
View File
@@ -24,8 +24,7 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
""" """
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters) # Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
multichar_tokens = { multichar_tokens = {
token for token in tokenizer.vocab.keys() token for token in tokenizer.vocab.keys() if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
} }
class CharTokenizerWrapper: class CharTokenizerWrapper:
+80 -67
View File
@@ -24,7 +24,6 @@ from typing import Tuple, Union, Generator, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import torchaudio import torchaudio
import warnings import warnings
from einops import rearrange from einops import rearrange
@@ -32,6 +31,7 @@ from pydantic import BaseModel
try: try:
from safetensors.torch import load_file from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True SAFETENSORS_AVAILABLE = True
except ImportError: except ImportError:
SAFETENSORS_AVAILABLE = False SAFETENSORS_AVAILABLE = False
@@ -84,9 +84,9 @@ class VoxCPMConfig(BaseModel):
class LoRAConfig(BaseModel): class LoRAConfig(BaseModel):
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers enable_proj: bool = False # Apply LoRA to projection Linear layers
r: int = 8 r: int = 8
alpha: int = 16 alpha: int = 16
@@ -168,7 +168,7 @@ class VoxCPMModel(nn.Module):
config.lm_config.hidden_size, config.lm_config.hidden_size,
config.lm_config.hidden_size, config.lm_config.hidden_size,
config.scalar_quantization_latent_dim, config.scalar_quantization_latent_dim,
config.scalar_quantization_scale config.scalar_quantization_scale,
) )
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size) self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim) self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
@@ -196,9 +196,7 @@ class VoxCPMModel(nn.Module):
# LM: base_lm + residual_lm # LM: base_lm + residual_lm
if cfg.enable_lm: if cfg.enable_lm:
for lm in [self.base_lm, self.residual_lm]: for lm in [self.base_lm, self.residual_lm]:
apply_lora_to_named_linear_modules( apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
)
# DiT: feat_decoder.estimator # DiT: feat_decoder.estimator
if cfg.enable_dit: if cfg.enable_dit:
@@ -209,6 +207,7 @@ class VoxCPMModel(nn.Module):
# 投影层 # 投影层
if cfg.enable_proj: if cfg.enable_proj:
from ..modules.layers.lora import LoRALinear from ..modules.layers.lora import LoRALinear
for attr_name in cfg.target_proj_modules: for attr_name in cfg.target_proj_modules:
module = getattr(self, attr_name, None) module = getattr(self, attr_name, None)
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
@@ -221,13 +220,17 @@ class VoxCPMModel(nn.Module):
if self.device != "cuda": if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device") raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try: try:
import triton import triton # noqa: F401
except ImportError: except ImportError:
raise ValueError("triton is not installed") raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True) self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True) self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True) self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True) self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
)
except Exception as e: except Exception as e:
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr) print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
return self return self
@@ -313,9 +316,11 @@ class VoxCPMModel(nn.Module):
mu=dit_hidden, mu=dit_hidden,
patch_size=self.patch_size, patch_size=self.patch_size,
cond=feat_cond_for_sample, cond=feat_cond_for_sample,
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate n_timesteps=(
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate") self.config.dit_config.cfm_config.inference_cfg_rate
else 10, if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10
),
) )
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size) feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
@@ -331,7 +336,6 @@ class VoxCPMModel(nn.Module):
def _dtype(self): def _dtype(self):
return get_dtype(self.config.dtype) return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor: def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs)) return next(self._generate(*args, streaming=False, **kwargs))
@@ -350,7 +354,7 @@ class VoxCPMModel(nn.Module):
cfg_value: float = 2.0, cfg_value: float = 2.0,
retry_badcase: bool = False, retry_badcase: bool = False,
retry_badcase_max_times: int = 3, retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection) retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
streaming: bool = False, streaming: bool = False,
) -> Generator[torch.Tensor, None, None]: ) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming: if retry_badcase and streaming:
@@ -444,7 +448,9 @@ class VoxCPMModel(nn.Module):
audio_feat, audio_feat,
audio_mask, audio_mask,
min_len=min_len, min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps, inference_timesteps=inference_timesteps,
cfg_value=cfg_value, cfg_value=cfg_value,
streaming=streaming, streaming=streaming,
@@ -460,7 +466,10 @@ class VoxCPMModel(nn.Module):
latent_pred, pred_audio_feat = next(inference_result) latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase: if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr) print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1 retry_badcase_times += 1
continue continue
else: else:
@@ -514,7 +523,9 @@ class VoxCPMModel(nn.Module):
self.audio_vae.latent_dim, self.audio_vae.latent_dim,
-1, -1,
self.patch_size, self.patch_size,
).permute(1, 2, 0) # (D, T, P) ).permute(
1, 2, 0
) # (D, T, P)
# build prompt cache - only save raw text and audio features # build prompt cache - only save raw text and audio features
prompt_cache = { prompt_cache = {
"prompt_text": prompt_text, "prompt_text": prompt_text,
@@ -523,7 +534,6 @@ class VoxCPMModel(nn.Module):
return prompt_cache return prompt_cache
def merge_prompt_cache( def merge_prompt_cache(
self, self,
original_cache: dict, original_cache: dict,
@@ -560,17 +570,14 @@ class VoxCPMModel(nn.Module):
return merged_cache return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming( def generate_with_prompt_cache_streaming(
self, *args, **kwargs self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]: ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs) return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
@torch.inference_mode() @torch.inference_mode()
def _generate_with_prompt_cache( def _generate_with_prompt_cache(
self, self,
@@ -645,8 +652,12 @@ class VoxCPMModel(nn.Module):
) )
text_token = torch.cat([text_token, text_pad_token]) text_token = torch.cat([text_token, text_pad_token])
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0) audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device) text_mask = (
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device) torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
)
text_token = text_token.unsqueeze(0).to(self.device) text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device) text_mask = text_mask.unsqueeze(0).to(self.device)
@@ -663,7 +674,9 @@ class VoxCPMModel(nn.Module):
audio_feat, audio_feat,
audio_mask, audio_mask,
min_len=min_len, min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps, inference_timesteps=inference_timesteps,
cfg_value=cfg_value, cfg_value=cfg_value,
streaming=streaming, streaming=streaming,
@@ -674,17 +687,16 @@ class VoxCPMModel(nn.Module):
for latent_pred, pred_audio_feat in inference_result: for latent_pred, pred_audio_feat in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield ( yield (decode_audio, target_text_token, pred_audio_feat)
decode_audio,
target_text_token,
pred_audio_feat
)
break break
else: else:
latent_pred, pred_audio_feat = next(inference_result) latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase: if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr) print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1 retry_badcase_times += 1
continue continue
else: else:
@@ -695,14 +707,10 @@ class VoxCPMModel(nn.Module):
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size patch_len = self.patch_size * self.chunk_size
if audio_mask.sum().item() > 0: if audio_mask.sum().item() > 0:
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu() decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
else: else:
decode_audio = decode_audio[..., :].squeeze(1).cpu() decode_audio = decode_audio[..., :].squeeze(1).cpu()
yield ( yield (decode_audio, target_text_token, pred_audio_feat)
decode_audio,
target_text_token,
pred_audio_feat
)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs)) return next(self._inference(*args, streaming=False, **kwargs))
@@ -782,7 +790,6 @@ class VoxCPMModel(nn.Module):
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1) enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = enc_outputs[:, -1, :] lm_hidden = enc_outputs[:, -1, :]
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm( residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed, inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
is_causal=True, is_causal=True,
@@ -790,7 +797,6 @@ class VoxCPMModel(nn.Module):
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple) self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
residual_hidden = residual_enc_outputs[:, -1, :] residual_hidden = residual_enc_outputs[:, -1, :]
for i in tqdm(range(max_len)): for i in tqdm(range(max_len)):
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit] dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit] dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
@@ -827,10 +833,10 @@ class VoxCPMModel(nn.Module):
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device) curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
).clone() ).clone()
lm_hidden = self.fsq_layer(lm_hidden) lm_hidden = self.fsq_layer(lm_hidden)
residual_hidden = self.residual_lm.forward_step( residual_hidden = self.residual_lm.forward_step(
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device) lm_hidden + curr_embed[:, 0, :],
torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
).clone() ).clone()
if not streaming: if not streaming:
@@ -838,29 +844,41 @@ class VoxCPMModel(nn.Module):
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size) feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu() yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod @classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None): def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read()) config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path) tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, 'audio_vae_config', None) audio_vae_config = getattr(config, "audio_vae_config", None)
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE() audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
vae_state_dict = torch.load( # Try to load AudioVAE from safetensors first, fallback to pytorch
os.path.join(path, "audiovae.pth"), audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
map_location="cpu", audiovae_pth_path = os.path.join(path, "audiovae.pth")
weights_only=True, if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
)["state_dict"] print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
elif os.path.exists(audiovae_pth_path):
print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
checkpoint = torch.load(
audiovae_pth_path,
map_location="cpu",
weights_only=True,
)
vae_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config) model = cls(config, tokenizer, audio_vae, lora_config)
if not training: if not training:
lm_dtype = get_dtype(model.config.dtype) lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype) model = model.to(lm_dtype)
else: # training mode else: # training mode
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if "audio_vae" in name: # freeze VAE weights if "audio_vae" in name: # freeze VAE weights
param.requires_grad = False param.requires_grad = False
continue continue
if lora_config is not None: if lora_config is not None:
if "lora" not in name: # freeze non-LoRA weights if "lora" not in name: # freeze non-LoRA weights
param.requires_grad = False param.requires_grad = False
model.audio_vae = model.audio_vae.to(torch.float32) model.audio_vae = model.audio_vae.to(torch.float32)
@@ -880,9 +898,7 @@ class VoxCPMModel(nn.Module):
) )
model_state_dict = checkpoint.get("state_dict", checkpoint) model_state_dict = checkpoint.get("state_dict", checkpoint)
else: else:
raise FileNotFoundError( raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
)
for kw, val in vae_state_dict.items(): for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val model_state_dict[f"audio_vae.{kw}"] = val
@@ -900,6 +916,7 @@ class VoxCPMModel(nn.Module):
def _iter_lora_modules(self): def _iter_lora_modules(self):
"""Iterate over all LoRA modules.""" """Iterate over all LoRA modules."""
from ..modules.layers.lora import LoRALinear from ..modules.layers.lora import LoRALinear
for module in self.modules(): for module in self.modules():
if isinstance(module, LoRALinear): if isinstance(module, LoRALinear):
yield module yield module
@@ -919,15 +936,15 @@ class VoxCPMModel(nn.Module):
from pathlib import Path from pathlib import Path
device = device or self.device device = device or self.device
lora_path = Path(lora_path) lora_p = Path(lora_path)
# Try safetensors first, then fallback to .ckpt # Try safetensors first, then fallback to .ckpt
if lora_path.is_dir(): if lora_p.is_dir():
safetensors_file = lora_path / "lora_weights.safetensors" safetensors_file = lora_p / "lora_weights.safetensors"
ckpt_file = lora_path / "lora_weights.ckpt" ckpt_file = lora_p / "lora_weights.ckpt"
else: else:
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
# Load from safetensors if available # Load from safetensors if available
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE: if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
@@ -936,9 +953,7 @@ class VoxCPMModel(nn.Module):
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False) ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt) state_dict = ckpt.get("state_dict", ckpt)
else: else:
raise FileNotFoundError( raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
)
# Build param mapping (handle torch.compile's _orig_mod prefix) # Build param mapping (handle torch.compile's _orig_mod prefix)
model_params = dict(self.named_parameters()) model_params = dict(self.named_parameters())
@@ -967,6 +982,4 @@ class VoxCPMModel(nn.Module):
def get_lora_state_dict(self) -> dict: def get_lora_state_dict(self) -> dict:
"""Get all LoRA parameters (lora_A/lora_B).""" """Get all LoRA parameters (lora_A/lora_B)."""
return {name: param.data.clone() return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
for name, param in self.named_parameters()
if "lora_" in name}
File diff suppressed because it is too large Load Diff
+1
View File
@@ -1 +1,2 @@
from .audio_vae import AudioVAE, AudioVAEConfig from .audio_vae import AudioVAE, AudioVAEConfig
from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
+2 -2
View File
@@ -1,5 +1,5 @@
import math import math
from typing import List, Union, Optional from typing import List
import numpy as np import numpy as np
import torch import torch
@@ -285,7 +285,7 @@ class AudioVAE(nn.Module):
def __init__( def __init__(
self, self,
config: Optional[AudioVAEConfig] = None, config: AudioVAEConfig = None,
): ):
# 如果没有传入config,使用默认配置 # 如果没有传入config,使用默认配置
if config is None: if config is None:
+486
View File
@@ -0,0 +1,486 @@
import math
from typing import List, Optional
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from pydantic import BaseModel
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class CausalConv1d(nn.Conv1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
return super().forward(x_pad)
class CausalTransposeConv1d(nn.ConvTranspose1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
def WNCausalConv1d(*args, **kwargs):
return weight_norm(CausalConv1d(*args, **kwargs))
def WNCausalTransposeConv1d(*args, **kwargs):
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class CausalResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNCausalConv1d(
dim,
dim,
kernel_size=kernel,
dilation=dilation,
padding=pad,
groups=groups,
),
Snake1d(dim),
WNCausalConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
assert pad == 0
if pad > 0:
x = x[..., pad:-pad]
return x + y
class CausalEncoderBlock(nn.Module):
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
super().__init__()
input_dim = input_dim or output_dim // 2
self.block = nn.Sequential(
CausalResidualUnit(input_dim, dilation=1, groups=groups),
CausalResidualUnit(input_dim, dilation=3, groups=groups),
CausalResidualUnit(input_dim, dilation=9, groups=groups),
Snake1d(input_dim),
WNCausalConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
)
def forward(self, x):
return self.block(x)
class CausalEncoder(nn.Module):
def __init__(
self,
d_model: int = 64,
latent_dim: int = 32,
strides: list = [2, 4, 8, 8],
depthwise: bool = False,
):
super().__init__()
# Create first convolution
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
groups = d_model // 2 if depthwise else 1
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
groups = d_model if depthwise else 1
# Create two convolution, for mu and logvar
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
hidden_state = self.block(x)
return {
"hidden_state": hidden_state,
"mu": self.fc_mu(hidden_state),
"logvar": self.fc_logvar(hidden_state),
}
class NoiseBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
def forward(self, x):
B, C, T = x.shape
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
h = self.linear(x)
n = noise * h
x = x + n
return x
class CausalDecoderBlock(nn.Module):
def __init__(
self,
input_dim: int = 16,
output_dim: int = 8,
stride: int = 1,
groups=1,
use_noise_block: bool = False,
):
super().__init__()
layers = [
Snake1d(input_dim),
WNCausalTransposeConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
]
if use_noise_block:
layers.append(NoiseBlock(output_dim))
layers.extend(
[
CausalResidualUnit(output_dim, dilation=1, groups=groups),
CausalResidualUnit(output_dim, dilation=3, groups=groups),
CausalResidualUnit(output_dim, dilation=9, groups=groups),
]
)
self.block = nn.Sequential(*layers)
self.input_channels = input_dim
def forward(self, x):
return self.block(x)
class TransposeLastTwoDim(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, -1, -2)
class SampleRateConditionLayer(nn.Module):
def __init__(
self,
input_dim: int,
sr_bin_buckets: int = None,
cond_type: str = "scale_bias",
cond_dim: int = 128,
out_layer: bool = False,
):
super().__init__()
self.cond_type, out_layer_in_dim = cond_type, input_dim
if cond_type == "scale_bias":
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.ones_(self.scale_embed.weight)
nn.init.zeros_(self.bias_embed.weight)
elif cond_type == "scale_bias_init":
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.normal_(self.scale_embed.weight, mean=1)
nn.init.normal_(self.bias_embed.weight)
elif cond_type == "add":
self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.normal_(self.cond_embed.weight)
elif cond_type == "concat":
self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
assert out_layer, "out_layer must be True for concat cond_type"
out_layer_in_dim = input_dim + cond_dim
else:
raise ValueError(f"Invalid cond_type: {cond_type}")
if out_layer:
self.out_layer = nn.Sequential(
Snake1d(out_layer_in_dim),
WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
)
else:
self.out_layer = nn.Identity()
def forward(self, x, sr_cond):
if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
elif self.cond_type == "add":
x = x + self.cond_embed(sr_cond).unsqueeze(-1)
elif self.cond_type == "concat":
x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
return self.out_layer(x)
class CausalDecoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
depthwise: bool = False,
d_out: int = 1,
use_noise_block: bool = False,
sr_bin_boundaries: List[int] = None,
cond_type: str = "scale_bias",
cond_dim: int = 128,
cond_out_layer: bool = False,
):
super().__init__()
# Add first conv layer
if depthwise:
layers = [
WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
WNCausalConv1d(input_channel, channels, kernel_size=1),
]
else:
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
groups = output_dim if depthwise else 1
layers += [
CausalDecoderBlock(
input_dim,
output_dim,
stride,
groups=groups,
use_noise_block=use_noise_block,
)
]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
if sr_bin_boundaries is None:
self.model = nn.Sequential(*layers)
self.sr_bin_boundaries = None
else:
self.model = nn.ModuleList(layers)
self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
self.sr_bin_buckets = len(sr_bin_boundaries) + 1
cond_layers = []
for layer in self.model:
if layer.__class__.__name__ == "CausalDecoderBlock":
cond_layers.append(
SampleRateConditionLayer(
input_dim=layer.input_channels,
sr_bin_buckets=self.sr_bin_buckets,
cond_type=cond_type,
cond_dim=cond_dim,
out_layer=cond_out_layer,
)
)
else:
cond_layers.append(None)
self.sr_cond_model = nn.ModuleList(cond_layers)
def get_sr_idx(self, sr):
return torch.bucketize(sr, self.sr_bin_boundaries)
def forward(self, x, sr_cond=None):
if self.sr_bin_boundaries is not None:
# assert sr_cond is not None
sr_cond = self.get_sr_idx(sr_cond)
for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
if sr_cond_layer is not None:
x = sr_cond_layer(x, sr_cond)
x = layer(x)
return x
else:
return self.model(x)
class AudioVAEConfig(BaseModel):
encoder_dim: int = 128
encoder_rates: List[int] = [2, 5, 8, 8]
latent_dim: int = 64
decoder_dim: int = 2048
decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
depthwise: bool = True
sample_rate: int = 16000
out_sample_rate: int = 48000
use_noise_block: bool = False
sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
cond_type: str = "scale_bias"
cond_dim: int = 128
cond_out_layer: bool = False
class AudioVAE(nn.Module):
"""
Args:
"""
def __init__(
self,
config: AudioVAEConfig = None,
):
# 如果没有传入config,使用默认配置
if config is None:
config = AudioVAEConfig()
super().__init__()
encoder_dim = config.encoder_dim
encoder_rates = config.encoder_rates
latent_dim = config.latent_dim
decoder_dim = config.decoder_dim
decoder_rates = config.decoder_rates
depthwise = config.depthwise
sample_rate = config.sample_rate
out_sample_rate = config.out_sample_rate
use_noise_block = config.use_noise_block
sr_bin_boundaries = config.sr_bin_boundaries
cond_type = config.cond_type
cond_dim = config.cond_dim
cond_out_layer = config.cond_out_layer
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.depthwise = depthwise
self.use_noise_block = use_noise_block
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = CausalEncoder(
encoder_dim,
latent_dim,
encoder_rates,
depthwise=depthwise,
)
self.decoder = CausalDecoder(
latent_dim,
decoder_dim,
decoder_rates,
depthwise=depthwise,
use_noise_block=use_noise_block,
sr_bin_boundaries=sr_bin_boundaries,
cond_type=cond_type,
cond_dim=cond_dim,
cond_out_layer=cond_out_layer,
)
self.sample_rate = sample_rate
self.out_sample_rate = out_sample_rate
self.sr_bin_boundaries = sr_bin_boundaries
self.chunk_size = math.prod(encoder_rates)
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
pad_to = self.hop_length
length = audio_data.shape[-1]
right_pad = math.ceil(length / pad_to) * pad_to - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
if self.sr_bin_boundaries is not None:
# use default output sample rate
if sr_cond is None:
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
return self.decoder(z, sr_cond)
def encode(self, audio_data: torch.Tensor, sample_rate: int):
"""
Args:
audio_data: Tensor[B x 1 x T]
sample_rate: int
Returns:
z: Tensor[B x D x T]
"""
if audio_data.ndim == 2:
audio_data = audio_data.unsqueeze(1)
audio_data = self.preprocess(audio_data, sample_rate)
return self.encoder(audio_data)["mu"]
-3
View File
@@ -128,6 +128,3 @@ def apply_lora_to_named_linear_modules(
dropout=dropout, dropout=dropout,
) )
setattr(parent, short_name, lora_layer) setattr(parent, short_name, lora_layer)
+1
View File
@@ -1,2 +1,3 @@
from .unified_cfm import UnifiedCFM, CfmConfig from .unified_cfm import UnifiedCFM, CfmConfig
from .local_dit import VoxCPMLocDiT from .local_dit import VoxCPMLocDiT
from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
+116
View File
@@ -0,0 +1,116 @@
import torch
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
import torch.nn as nn
import math
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
out_dim: int = None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class VoxCPMLocDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
config: MiniCPM4Config,
in_channels: int = 64,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.config = config
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
self.time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
self.delta_time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
self.decoder = MiniCPMModel(config)
def forward(
self,
x: torch.Tensor,
mu: torch.Tensor,
t: torch.Tensor,
cond: torch.Tensor,
dt: torch.Tensor,
):
"""
Forward pass of DiT.
x: (N, C, T) tensor of inputs
mu: (N, C) tensor of hidden embedding
t: (N,) tensor of diffusion timesteps
cond: (N, C, T') tensor of prefix conditions
dt: (N,) used for mean velocity (may be supported in the future...)
"""
x = self.in_proj(x.transpose(1, 2).contiguous())
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
prefix = cond.size(1)
t = self.time_embeddings(t).to(x.dtype)
t = self.time_mlp(t)
dt = self.time_embeddings(dt).to(x.dtype)
dt = self.delta_time_mlp(dt)
t = t + dt
mu = mu.view(x.size(0), -1, x.size(-1))
x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)
hidden, _ = self.decoder(x, is_causal=False)
hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
hidden = self.out_proj(hidden)
return hidden.transpose(1, 2).contiguous()
+5 -4
View File
@@ -1,4 +1,4 @@
from typing import List, Tuple from typing import Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -138,7 +138,9 @@ class UnifiedCFM(torch.nn.Module):
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# Training loss # Training loss
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3): def adaptive_loss_weighting(
self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
):
weights = 1.0 / ((losses + epsilon).pow(p)) weights = 1.0 / ((losses + epsilon).pow(p))
if mask is not None: if mask is not None:
weights = weights * mask weights = weights * mask
@@ -193,8 +195,7 @@ class UnifiedCFM(torch.nn.Module):
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
ratio_r_neq_t = ( ratio_r_neq_t = (
self.ratio_r_neq_t_range[0] self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
if self.mean_mode if self.mean_mode
else 0.0 else 0.0
) )
+12 -13
View File
@@ -64,10 +64,8 @@ class MiniCPMLongRoPE(nn.Module):
self.long_factor = config.rope_scaling.long_factor self.long_factor = config.rope_scaling.long_factor
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
scale = (self.max_position_embeddings / self.original_max_position_embeddings) scale = self.max_position_embeddings / self.original_max_position_embeddings
self.scaling_factor = math.sqrt( self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
)
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -76,11 +74,7 @@ class MiniCPMLongRoPE(nn.Module):
self.register_buffer("cos_cached", torch.empty(0), persistent=False) self.register_buffer("cos_cached", torch.empty(0), persistent=False)
self.register_buffer("sin_cached", torch.empty(0), persistent=False) self.register_buffer("sin_cached", torch.empty(0), persistent=False)
self._set_cos_sin_cache( self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)
seq_len=self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32
)
def _set_cos_sin_cache(self, seq_len, device, dtype): def _set_cos_sin_cache(self, seq_len, device, dtype):
"""设置cos和sin缓存""" """设置cos和sin缓存"""
@@ -93,8 +87,7 @@ class MiniCPMLongRoPE(nn.Module):
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device) ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
freqs = torch.mul( freqs = torch.mul(
torch.outer(t, 1.0 / ext_factors).to(device=device), torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
self.inv_freq.to(device=device).to(dtype)
) )
# 创建embeddings # 创建embeddings
@@ -123,7 +116,9 @@ class MiniCPMAttention(nn.Module):
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels self.head_dim = (
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
)
self.num_key_value_heads = config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
@@ -413,7 +408,11 @@ class MiniCPMModel(nn.Module):
self.kv_cache = StaticKVCache( self.kv_cache = StaticKVCache(
num_layers=self.config.num_hidden_layers, num_layers=self.config.num_hidden_layers,
num_kv_heads=self.config.num_key_value_heads, num_kv_heads=self.config.num_key_value_heads,
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels, dim_kv_head=(
self.config.hidden_size // self.config.num_attention_heads
if self.config.kv_channels is None
else self.config.kv_channels
),
batch_size=batch_size, batch_size=batch_size,
device=device, device=device,
dtype=dtype, dtype=dtype,
-1
View File
@@ -25,4 +25,3 @@ __all__ = [
"load_audio_text_datasets", "load_audio_text_datasets",
"build_dataloader", "build_dataloader",
] ]
+2 -5
View File
@@ -47,9 +47,7 @@ class Accelerator:
pass pass
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler() self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
self.device_ctx = ( self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
self._ddp_model = None # For no_sync support self._ddp_model = None # For no_sync support
def _set_seed(self, seed: int): def _set_seed(self, seed: int):
@@ -84,7 +82,7 @@ class Accelerator:
# Model helpers # Model helpers
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs): def prepare_model(self, model: torch.nn.Module, **kwargs):
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
model.device = self.device model.device = self.device
model = model.to(self.device) model = model.to(self.device)
if self.world_size > 1: if self.world_size > 1:
@@ -163,4 +161,3 @@ class Accelerator:
@staticmethod @staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module: def unwrap(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model return model.module if hasattr(model, "module") else model
-2
View File
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[]) yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args) cli_args.update(yaml_args)
return cli_args return cli_args
-2
View File
@@ -1,5 +1,4 @@
import math import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import argbind import argbind
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
from ..modules.audiovae import AudioVAE from ..modules.audiovae import AudioVAE
from .packers import AudioFeatureProcessingPacker from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text" DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio" DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_ID_COLUMN = "dataset_id" DEFAULT_ID_COLUMN = "dataset_id"
+28 -21
View File
@@ -1,5 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -147,31 +146,26 @@ class AudioFeatureProcessingPacker:
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor: def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
if x.size(0) >= max_len: if x.size(0) >= max_len:
return x[: max_len] return x[:max_len]
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device) pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0) return torch.cat([x, pad], dim=0)
def pad_3d(x: torch.Tensor) -> torch.Tensor: def pad_3d(x: torch.Tensor) -> torch.Tensor:
# x: [T, P, D] # x: [T, P, D]
if x.size(0) >= max_len: if x.size(0) >= max_len:
return x[: max_len] return x[:max_len]
pad = torch.zeros( pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
)
return torch.cat([x, pad], dim=0) return torch.cat([x, pad], dim=0)
if lengths: if lengths:
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0) text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0) text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0) audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0) audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0) loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0) labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
audio_task_ids_batch = torch.stack( audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0 audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
)
audio_dataset_ids_batch = torch.stack(
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
)
# Position ids: [B, T], simple 0..L_i-1 then padded with 0 # Position ids: [B, T], simple 0..L_i-1 then padded with 0
position_ids_list = [] position_ids_list = []
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
) )
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0) audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to( text_mask = (
text_token.device torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
.type(torch.int32)
.to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
.type(torch.int32)
.to(text_token.device)
)
loss_mask = (
torch.cat(
[
torch.zeros(text_length),
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
torch.zeros(1),
]
)
.type(torch.int32)
.to(text_token.device)
) )
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
torch.int32
).to(text_token.device)
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device) labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
labels[-2] = 1 labels[-2] = 1
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
audio_duration, audio_duration,
text_token_count, text_token_count,
) )
-1
View File
@@ -18,4 +18,3 @@ class TrainingState:
val_loader: object val_loader: object
tracker: object tracker: object
batch_processor: object batch_processor: object
-1
View File
@@ -76,4 +76,3 @@ class TrainingTracker:
@contextlib.contextmanager @contextlib.contextmanager
def live(self): def live(self):
yield yield
+31 -28
View File
@@ -2,10 +2,10 @@
import re import re
import regex import regex
import inflect import inflect
from functools import partial
from wetext import Normalizer from wetext import Normalizer
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+') chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
# whether contain chinese character # whether contain chinese character
def contains_chinese(text): def contains_chinese(text):
@@ -14,19 +14,19 @@ def contains_chinese(text):
# replace special symbol # replace special symbol
def replace_corner_mark(text): def replace_corner_mark(text):
text = text.replace('²', '平方') text = text.replace("²", "平方")
text = text.replace('³', '立方') text = text.replace("³", "立方")
text = text.replace('', '根号') text = text.replace("", "根号")
text = text.replace('', '约等于') text = text.replace("", "约等于")
text = text.replace('<', '小于') text = text.replace("<", "小于")
return text return text
# remove meaningless symbol # remove meaningless symbol
def remove_bracket(text): def remove_bracket(text):
text = text.replace('', ' ').replace('', ' ') text = text.replace("", " ").replace("", " ")
text = text.replace('', ' ').replace('', ' ') text = text.replace("", " ").replace("", " ")
text = text.replace('`', '').replace('`', '') text = text.replace("`", "").replace("`", "")
text = text.replace("——", " ") text = text.replace("——", " ")
return text return text
@@ -38,7 +38,7 @@ def spell_out_number(text: str, inflect_parser):
for i, c in enumerate(text): for i, c in enumerate(text):
if not c.isdigit(): if not c.isdigit():
if st is not None: if st is not None:
num_str = inflect_parser.number_to_words(text[st: i]) num_str = inflect_parser.number_to_words(text[st:i])
new_text.append(num_str) new_text.append(num_str)
st = None st = None
new_text.append(c) new_text.append(c)
@@ -48,7 +48,7 @@ def spell_out_number(text: str, inflect_parser):
if st is not None and st < len(text): if st is not None and st < len(text):
num_str = inflect_parser.number_to_words(text[st:]) num_str = inflect_parser.number_to_words(text[st:])
new_text.append(num_str) new_text.append(num_str)
return ''.join(new_text) return "".join(new_text)
# split paragrah logic # split paragrah logic
@@ -69,18 +69,18 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
return len(tokenize(_text)) < merge_len return len(tokenize(_text)) < merge_len
if lang == "zh": if lang == "zh":
pounc = ['', '', '', '', '', '', '.', '?', '!', ';'] pounc = ["", "", "", "", "", "", ".", "?", "!", ";"]
else: else:
pounc = ['.', '?', '!', ';', ':'] pounc = [".", "?", "!", ";", ":"]
if comma_split: if comma_split:
pounc.extend(['', ',']) pounc.extend(["", ","])
st = 0 st = 0
utts = [] utts = []
for i, c in enumerate(text): for i, c in enumerate(text):
if c in pounc: if c in pounc:
if len(text[st: i]) > 0: if len(text[st:i]) > 0:
utts.append(text[st: i] + c) utts.append(text[st:i] + c)
if i + 1 < len(text) and text[i + 1] in ['"', '']: if i + 1 < len(text) and text[i + 1] in ['"', ""]:
tmp = utts.pop(-1) tmp = utts.pop(-1)
utts.append(tmp + text[i + 1]) utts.append(tmp + text[i + 1])
st = i + 2 st = i + 2
@@ -88,9 +88,9 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
st = i + 1 st = i + 1
if len(utts) == 0: if len(utts) == 0:
if lang == "zh": if lang == "zh":
utts.append(text + '') utts.append(text + "")
else: else:
utts.append(text + '.') utts.append(text + ".")
final_utts = [] final_utts = []
cur_utt = "" cur_utt = ""
for utt in utts: for utt in utts:
@@ -112,13 +112,13 @@ def replace_blank(text: str):
out_str = [] out_str = []
for i, c in enumerate(text): for i, c in enumerate(text):
if c == " ": if c == " ":
if ((text[i + 1].isascii() and text[i + 1] != " ") and if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
(text[i - 1].isascii() and text[i - 1] != " ")):
out_str.append(c) out_str.append(c)
else: else:
out_str.append(c) out_str.append(c)
return "".join(out_str) return "".join(out_str)
def clean_markdown(md_text: str) -> str: def clean_markdown(md_text: str) -> str:
# 去除代码块 ``` ```(包括多行) # 去除代码块 ``` ```(包括多行)
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL) md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
@@ -133,7 +133,7 @@ def clean_markdown(md_text: str) -> str:
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text) md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
# 替换无序列表符号 # 替换无序列表符号
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE) md_text = re.sub(r"^(\s*)-\s+", r"\1", md_text, flags=re.MULTILINE)
# 去除HTML标签 # 去除HTML标签
md_text = re.sub(r"<[^>]+>", "", md_text) md_text = re.sub(r"<[^>]+>", "", md_text)
@@ -152,13 +152,14 @@ def clean_text(text):
# 去除 Markdown 语法 # 去除 Markdown 语法
text = clean_markdown(text) text = clean_markdown(text)
# 匹配并移除表情符号 # 匹配并移除表情符号
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text) text = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE).sub("", text)
# 去除换行符 # 去除换行符
text = text.replace("\n", " ") text = text.replace("\n", " ")
text = text.replace("\t", " ") text = text.replace("\t", " ")
text = text.replace('"', "\") text = text.replace("", '"').replace("", '"')
return text return text
class TextNormalizer: class TextNormalizer:
def __init__(self, tokenizer=None): def __init__(self, tokenizer=None):
self.tokenizer = tokenizer self.tokenizer = tokenizer
@@ -171,9 +172,11 @@ class TextNormalizer:
lang = "zh" if contains_chinese(text) else "en" lang = "zh" if contains_chinese(text) else "en"
text = clean_text(text) text = clean_text(text)
if lang == "zh": if lang == "zh":
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“ text = text.replace(
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减 "=", "等于"
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2 ) # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text): # 避免 英文连字符被错误正则为减
text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text) # 修复 x-2 被正则为 x负2
text = self.zh_tn_model.normalize(text) text = self.zh_tn_model.normalize(text)
text = replace_blank(text) text = replace_blank(text)
text = replace_corner_mark(text) text = replace_corner_mark(text)
+6 -10
View File
@@ -7,15 +7,15 @@ Related dependencies are imported only when denoising functionality is needed.
import os import os
import tempfile import tempfile
from typing import Optional, Union from typing import Optional
import torchaudio import torchaudio
import torch
from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
class ZipEnhancer: class ZipEnhancer:
"""ZipEnhancer Audio Denoising Enhancer""" """ZipEnhancer Audio Denoising Enhancer"""
def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"): def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
""" """
Initialize ZipEnhancer Initialize ZipEnhancer
@@ -23,10 +23,7 @@ class ZipEnhancer:
model_path: ModelScope model path or local path model_path: ModelScope model path or local path
""" """
self.model_path = model_path self.model_path = model_path
self._pipeline = pipeline( self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)
Tasks.acoustic_noise_suppression,
model=self.model_path
)
def _normalize_loudness(self, wav_path: str): def _normalize_loudness(self, wav_path: str):
""" """
@@ -37,11 +34,10 @@ class ZipEnhancer:
""" """
audio, sr = torchaudio.load(wav_path) audio, sr = torchaudio.load(wav_path)
loudness = torchaudio.functional.loudness(audio, sr) loudness = torchaudio.functional.loudness(audio, sr)
normalized_audio = torchaudio.functional.gain(audio, -20-loudness) normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
torchaudio.save(wav_path, normalized_audio, sr) torchaudio.save(wav_path, normalized_audio, sr)
def enhance(self, input_path: str, output_path: Optional[str] = None, def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
normalize_loudness: bool = True) -> str:
""" """
Audio denoising enhancement Audio denoising enhancement
Args: Args:
@@ -57,7 +53,7 @@ class ZipEnhancer:
raise FileNotFoundError(f"Input audio file does not exist: {input_path}") raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
# Create temporary file if no output path is specified # Create temporary file if no output path is specified
if output_path is None: if output_path is None:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file: with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
output_path = tmp_file.name output_path = tmp_file.name
try: try:
# Perform denoising processing # Perform denoising processing
Generated
+5263
View File
File diff suppressed because it is too large Load Diff