update voxcpm2

This commit is contained in:
刘鑫
2026-03-31 11:50:37 +08:00
parent 23ed7ffeee
commit d9cf376e16
36 changed files with 8163 additions and 834 deletions
+146 -87
View File
@@ -2,14 +2,15 @@ import os
import sys
import numpy as np
import torch
import gradio as gr
import spaces
import gradio as gr
import spaces # noqa: F401
from typing import Optional, Tuple
from funasr import AutoModel
from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
import voxcpm
@@ -24,13 +25,13 @@ class VoxCPMDemo:
self.asr_model: Optional[AutoModel] = AutoModel(
model=self.asr_model_id,
disable_update=True,
log_level='DEBUG',
log_level="DEBUG",
device="cuda:0" if self.device == "cuda" else "cpu",
)
# TTS model (lazy init)
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.default_local_model_dir = "./models/VoxCPM1.5"
self.default_local_model_dir = "/Users/xinliu/Downloads/VoxCPM2-0.5B-newaudiovae-6hz-0316"
# ---------- Model helpers ----------
def _resolve_model_dir(self) -> str:
@@ -49,6 +50,7 @@ class VoxCPMDemo:
if not os.path.isdir(target_dir):
try:
from huggingface_hub import snapshot_download # type: ignore
os.makedirs(target_dir, exist_ok=True)
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
@@ -64,7 +66,7 @@ class VoxCPMDemo:
print("Model not loaded, initializing...", file=sys.stderr)
model_dir = self._resolve_model_dir()
print(f"Using model dir: {model_dir}", file=sys.stderr)
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=False)
print("Model loaded successfully.", file=sys.stderr)
return self.voxcpm_model
@@ -73,21 +75,24 @@ class VoxCPMDemo:
if prompt_wav is None:
return ""
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
text = res[0]["text"].split('|>')[-1]
text = res[0]["text"].split("|>")[-1]
return text
def generate_tts_audio(
self,
text_input: str,
prompt_wav_path_input: Optional[str] = None,
prompt_text_input: Optional[str] = None,
control_instruction: str = "",
reference_wav_path_input: Optional[str] = None,
cfg_value_input: float = 2.0,
inference_timesteps_input: int = 10,
do_normalize: bool = True,
denoise: bool = True,
) -> Tuple[int, np.ndarray]:
"""
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
Generate speech from text using VoxCPM.
- If reference_wav provided: Prompt isolation mode (voice cloning)
- If no reference_wav: Voice design mode (use control_instruction to describe voice)
Returns (sample_rate, waveform_numpy)
"""
current_model = self.get_or_load_voxcpm()
@@ -96,14 +101,25 @@ class VoxCPMDemo:
if len(text) == 0:
raise ValueError("Please input text to synthesize.")
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
prompt_text = prompt_text_input if prompt_text_input else None
# 处理 control instruction
control = (control_instruction or "").strip()
if control:
final_text = f"({control}){text}"
else:
final_text = text
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
reference_wav_path = reference_wav_path_input if reference_wav_path_input else None
# 判断模式
if reference_wav_path:
print(f"[Prompt Isolation Mode] reference_wav: {reference_wav_path}", file=sys.stderr)
else:
print(f"[Voice Design Mode] control: {control[:50] if control else 'None'}...", file=sys.stderr)
print(f"Generating audio for text: '{final_text[:80]}...'", file=sys.stderr)
wav = current_model.generate(
text=text,
prompt_text=prompt_text,
prompt_wav_path=prompt_wav_path,
text=final_text,
reference_wav_path=reference_wav_path,
cfg_value=float(cfg_value_input),
inference_timesteps=int(inference_timesteps_input),
normalize=do_normalize,
@@ -114,46 +130,53 @@ class VoxCPMDemo:
# ---------- UI Builders ----------
THEME = gr.themes.Soft(
primary_hue="blue",
secondary_hue="gray",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
)
CSS = """
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 80px;
width: auto;
max-width: 200px;
display: inline-block;
}
/* Bold accordion labels */
#acc_quick > .label-wrap,
#acc_tips > .label-wrap,
#acc_quick > .label-wrap > span,
#acc_tips > .label-wrap > span,
#acc_quick summary,
#acc_tips summary {
font-weight: 600 !important;
font-size: 1.1em !important;
}
/* Bold labels for specific checkboxes */
#chk_denoise label,
#chk_denoise span,
#chk_normalize label,
#chk_normalize span {
font-weight: 600;
}
"""
def create_demo_interface(demo: VoxCPMDemo):
"""Build the Gradio UI for VoxCPM demo."""
# static assets (logo path)
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
with gr.Blocks(
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="gray",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
),
css="""
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 80px;
width: auto;
max-width: 200px;
display: inline-block;
}
/* Bold accordion labels */
#acc_quick details > summary,
#acc_tips details > summary {
font-weight: 600 !important;
font-size: 1.1em !important;
}
/* Bold labels for specific checkboxes */
#chk_denoise label,
#chk_denoise span,
#chk_normalize label,
#chk_normalize span {
font-weight: 600;
}
"""
) as interface:
# Header logo
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
with gr.Blocks() as interface:
gr.HTML(
'<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>',
padding=True,
)
# Quick Start
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
@@ -200,34 +223,56 @@ def create_demo_interface(demo: VoxCPMDemo):
# Main controls
with gr.Row():
with gr.Column():
prompt_wav = gr.Audio(
sources=["upload", 'microphone'],
# 1. Reference Audio
# gr.Markdown("### 🎤 Reference Audio (Optional)")
# gr.Markdown("*提供参考音频进行音色克隆;不提供则使用 Voice Design 模式*")
reference_wav = gr.Audio(
sources=["upload", "microphone"],
type="filepath",
label="Prompt Speech (Optional, or let VoxCPM improvise)",
value="./examples/example.wav",
label="Reference Audio (Optional)",
)
DoDenoisePromptAudio = gr.Checkbox(
value=False,
label="Prompt Speech Enhancement",
label="Reference Audio Enhancement",
elem_id="chk_denoise",
info="We use ZipEnhancer model to denoise the prompt audio."
info="Use ZipEnhancer to denoise the reference audio",
)
with gr.Row():
prompt_text = gr.Textbox(
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
label="Prompt Text",
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
)
run_btn = gr.Button("Generate Speech", variant="primary")
# 2. Control Instruction
# gr.Markdown("### 🎛️ Control Instruction (Optional)")
# gr.Markdown("*描述声音风格、情感等,格式:`(instruction) text`*")
control_instruction = gr.Textbox(
value="",
label="Control Instruction",
placeholder="*描述声音风格、情感等,格式:`(instruction) text`,例如:年轻女性,温柔甜美 / 悲伤地说 / an excited young man*",
lines=2,
)
# 3. Target Text
# gr.Markdown("### 📝 Target Text")
text = gr.Textbox(
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
label="Target Text",
lines=3,
)
DoNormalizeText = gr.Checkbox(
value=False,
label="Text Normalization",
elem_id="chk_normalize",
info="Use wetext library to normalize the input text",
)
run_btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg")
with gr.Column():
gr.Markdown("### ⚙️ Generation Settings")
cfg_value = gr.Slider(
minimum=1.0,
maximum=3.0,
value=2.0,
step=0.1,
label="CFG Value (Guidance Scale)",
info="Higher values increase adherence to prompt, lower values allow more creativity"
info="Higher = more adherence to prompt; Lower = more creativity",
)
inference_timesteps = gr.Slider(
minimum=4,
@@ -235,41 +280,55 @@ def create_demo_interface(demo: VoxCPMDemo):
value=10,
step=1,
label="Inference Timesteps",
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
info="Higher = better quality but slower",
)
with gr.Row():
text = gr.Textbox(
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
label="Target Text",
)
with gr.Row():
DoNormalizeText = gr.Checkbox(
value=False,
label="Text Normalization",
elem_id="chk_normalize",
info="We use wetext library to normalize the input text."
)
audio_output = gr.Audio(label="Output Audio")
gr.Markdown("### 🔈 Output")
audio_output = gr.Audio(label="Generated Audio")
gr.Markdown("""
---
**模式说明 / Mode Info:**
- **有 Reference Audio** → Prompt 隔离模式(音色克隆)
- **无 Reference Audio** → Voice Design 模式(用 Control Instruction 描述声音)
**Control Instruction 示例:**
- `年轻女性,温柔甜美`
- `悲伤地说`
- `an excited young man`
""")
# Wiring
run_btn.click(
fn=demo.generate_tts_audio,
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
inputs=[
text,
control_instruction,
reference_wav,
cfg_value,
inference_timesteps,
DoNormalizeText,
DoDenoisePromptAudio,
],
outputs=[audio_output],
show_progress=True,
api_name="generate",
)
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
return interface
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
def run_demo(server_name: str = "0.0.0.0", server_port: int = 7869, show_error: bool = True):
demo = VoxCPMDemo()
interface = create_demo_interface(demo)
# Recommended to enable queue on Spaces for better throughput
interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)
interface.queue(max_size=10, default_concurrency_limit=1).launch(
server_name=server_name,
server_port=server_port,
show_error=show_error,
theme=THEME,
css=CSS,
)
if __name__ == "__main__":
run_demo()
run_demo()
+2 -2
View File
@@ -25,8 +25,8 @@ lora:
enable_lm: true
enable_dit: true
enable_proj: false
r: 32
alpha: 16
r: 8
alpha: 16
dropout: 0.0
# Distribution options (optional)
+1 -1
View File
@@ -25,7 +25,7 @@ lora:
enable_lm: true
enable_dit: true
enable_proj: false
r: 32
r: 8
alpha: 16
dropout: 0.0
Binary file not shown.
+247 -260
View File
@@ -1,18 +1,14 @@
import os
import sys
import time
import glob
import json
import yaml
import shutil
import datetime
import subprocess
import threading
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
from typing import Optional, List
from typing import Optional
# Add src to sys.path
project_root = Path(__file__).parent
@@ -89,7 +85,7 @@ LANG_DICT = {
"lang_select": "Language / 语言",
"refresh": "刷新",
"output_name": "输出目录名称 (可选,若存在则继续训练)",
}
},
}
# Global variables
@@ -98,9 +94,11 @@ asr_model: Optional[AutoModel] = None
training_process: Optional[subprocess.Popen] = None
training_log = ""
def get_timestamp_str():
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def get_or_load_asr_model():
global asr_model
if asr_model is None:
@@ -109,44 +107,46 @@ def get_or_load_asr_model():
asr_model = AutoModel(
model="iic/SenseVoiceSmall",
disable_update=True,
log_level='ERROR',
log_level="ERROR",
device=device,
)
return asr_model
def recognize_audio(audio_path):
if not audio_path:
return ""
try:
model = get_or_load_asr_model()
res = model.generate(input=audio_path, language="auto", use_itn=True)
text = res[0]["text"].split('|>')[-1]
text = res[0]["text"].split("|>")[-1]
return text
except Exception as e:
print(f"ASR Error: {e}", file=sys.stderr)
return ""
def scan_lora_checkpoints(root_dir="lora", with_info=False):
"""
Scans for LoRA checkpoints in the lora directory.
Args:
root_dir: Directory to scan for LoRA checkpoints
with_info: If True, returns list of (path, base_model) tuples
Returns:
List of checkpoint paths, or list of (path, base_model) tuples if with_info=True
"""
checkpoints = []
if not os.path.exists(root_dir):
os.makedirs(root_dir, exist_ok=True)
# Look for lora_weights.safetensors recursively
for root, dirs, files in os.walk(root_dir):
if "lora_weights.safetensors" in files:
# Use the relative path from root_dir as the ID
rel_path = os.path.relpath(root, root_dir)
if with_info:
# Try to read base_model from lora_config.json
base_model = None
@@ -161,15 +161,16 @@ def scan_lora_checkpoints(root_dir="lora", with_info=False):
checkpoints.append((rel_path, base_model))
else:
checkpoints.append(rel_path)
# Also check for checkpoints in the default location if they exist
default_ckpt = "checkpoints/finetune_lora"
if os.path.exists(os.path.join(root_dir, default_ckpt)):
# This might be covered by the walk, but good to be sure
pass
# This might be covered by the walk, but good to be sure
pass
return sorted(checkpoints, reverse=True)
def load_lora_config_from_checkpoint(lora_path):
"""Load LoRA config from lora_config.json if available."""
lora_config_file = os.path.join(lora_path, "lora_config.json")
@@ -184,6 +185,7 @@ def load_lora_config_from_checkpoint(lora_path):
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
return None, None
def get_default_lora_config():
"""Return default LoRA config for hot-swapping support."""
return LoRAConfig(
@@ -192,16 +194,17 @@ def get_default_lora_config():
r=32,
alpha=16,
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"]
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
)
def load_model(pretrained_path, lora_path=None):
global current_model
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
lora_config = None
lora_weights_path = None
if lora_path:
full_lora_path = os.path.join("lora", lora_path)
if os.path.exists(full_lora_path):
@@ -214,7 +217,7 @@ def load_model(pretrained_path, lora_path=None):
# Fallback to default config for old checkpoints
lora_config = get_default_lora_config()
print("Using default LoRA config (lora_config.json not found)", file=sys.stderr)
# Always init with a default LoRA config to allow hot-swapping later
if lora_config is None:
lora_config = get_default_lora_config()
@@ -228,25 +231,24 @@ def load_model(pretrained_path, lora_path=None):
)
return "Model loaded successfully!"
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
global current_model
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
if current_model is None:
# 优先使用用户指定的预训练模型路径
base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path
# 如果选择了 LoRA,尝试从其 config 读取 base_model
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
lora_config_file = os.path.join(full_lora_path, "lora_config.json")
if os.path.exists(lora_config_file):
try:
with open(lora_config_file, "r", encoding="utf-8") as f:
lora_info = json.load(f)
saved_base_model = lora_info.get("base_model")
if saved_base_model:
# 优先使用保存的 base_model 路径
if os.path.exists(saved_base_model):
@@ -257,11 +259,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
print(f"Falling back to default: {base_model_path}", file=sys.stderr)
except Exception as e:
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
# 加载模型
try:
print(f"Loading base model: {base_model_path}", file=sys.stderr)
status_msg = load_model(base_model_path)
load_model(base_model_path)
if lora_selection and lora_selection != "None":
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
except Exception as e:
@@ -270,6 +272,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
return None, error_msg
# Handle LoRA hot-swapping
assert current_model is not None, "Model must be loaded before inference"
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
@@ -290,11 +293,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
# 处理 prompt 参数:必须同时为 None 或同时有值
final_prompt_wav = None
final_prompt_text = None
if prompt_wav and prompt_wav.strip():
# 有参考音频
final_prompt_wav = prompt_wav
# 如果没有提供参考文本,尝试自动识别
if not prompt_text or not prompt_text.strip():
print("参考音频已提供但缺少文本,自动识别中...", file=sys.stderr)
@@ -317,14 +320,16 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
prompt_text=final_prompt_text,
cfg_value=cfg_scale,
inference_timesteps=steps,
denoise=False
denoise=False,
)
return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
except Exception as e:
import traceback
traceback.print_exc()
return None, f"Error: {str(e)}"
def start_training(
pretrained_path,
train_manifest,
@@ -355,8 +360,8 @@ def start_training(
hf_model_id="",
distribute=False,
):
global training_process, training_log
global training_log
if training_process is not None and training_process.poll() is None:
return "Training is already running!"
@@ -368,7 +373,7 @@ def start_training(
save_dir = os.path.join("lora", timestamp)
checkpoints_dir = os.path.join(save_dir, "checkpoints")
logs_dir = os.path.join(save_dir, "logs")
os.makedirs(checkpoints_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)
@@ -394,10 +399,7 @@ def start_training(
"max_steps": resolved_max_steps,
"save_path": checkpoints_dir,
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
"lambdas": {
"loss/diff": 1.0,
"loss/stop": 1.0
},
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
"lora": {
"enable_lm": bool(enable_lm),
"enable_dit": bool(enable_dit),
@@ -406,10 +408,10 @@ def start_training(
"alpha": int(lora_alpha),
"dropout": float(dropout),
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"]
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"],
},
}
# Add distribution options if provided
if hf_model_id and hf_model_id.strip():
config["hf_model_id"] = hf_model_id.strip()
@@ -420,49 +422,42 @@ def start_training(
with open(config_path, "w") as f:
yaml.dump(config, f)
cmd = [
sys.executable,
"scripts/train_voxcpm_finetune.py",
"--config_path",
config_path
]
cmd = [sys.executable, "scripts/train_voxcpm_finetune.py", "--config_path", config_path]
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
def run_process():
global training_process, training_log
training_process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1
)
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
assert training_process.stdout is not None
for line in training_process.stdout:
training_log += line
# Keep log size manageable
if len(training_log) > 100000:
training_log = training_log[-100000:]
training_process.wait()
training_log += f"\nTraining finished with code {training_process.returncode}"
threading.Thread(target=run_process, daemon=True).start()
return f"Training started! Check 'lora/{timestamp}'"
def get_training_log():
return training_log
def stop_training():
global training_process, training_log
global training_log
if training_process is not None and training_process.poll() is None:
training_process.terminate()
training_log += "\nTraining terminated by user."
return "Training stopped."
return "No training running."
# --- GUI Layout ---
# 自定义CSS样式
@@ -830,14 +825,10 @@ label {
}
"""
with gr.Blocks(
title="VoxCPM LoRA WebUI",
theme=gr.themes.Soft(),
css=custom_css
) as app:
with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css) as app:
# State for language
lang_state = gr.State("zh") # Default to Chinese
lang_state = gr.State("zh") # Default to Chinese
# 标题区域
with gr.Row(elem_classes="title-section"):
@@ -850,10 +841,7 @@ with gr.Blocks(
""")
with gr.Column(scale=1):
lang_btn = gr.Radio(
choices=["en", "zh"],
value="zh",
label="🌐 Language / 语言",
elem_classes="lang-selector"
choices=["en", "zh"], value="zh", label="🌐 Language / 语言", elem_classes="lang-selector"
)
with gr.Tabs(elem_classes="tabs") as tabs:
@@ -869,79 +857,40 @@ with gr.Blocks(
gr.Markdown("#### 📁 基础配置")
train_pretrained_path = gr.Textbox(
label="📂 预训练模型路径",
value=default_pretrained_path,
elem_classes="input-field"
label="📂 预训练模型路径", value=default_pretrained_path, elem_classes="input-field"
)
train_manifest = gr.Textbox(
label="📋 训练数据清单 (jsonl)",
value="examples/train_data_example.jsonl",
elem_classes="input-field"
)
val_manifest = gr.Textbox(
label="📊 验证数据清单 (可选)",
value="",
elem_classes="input-field"
elem_classes="input-field",
)
val_manifest = gr.Textbox(label="📊 验证数据清单 (可选)", value="", elem_classes="input-field")
gr.Markdown("#### ⚙️ 训练参数")
with gr.Row():
lr = gr.Number(
label="📈 学习率 (Learning Rate)",
value=1e-4,
elem_classes="input-field"
)
lr = gr.Number(label="📈 学习率 (Learning Rate)", value=1e-4, elem_classes="input-field")
num_iters = gr.Number(
label="🔄 最大迭代次数",
value=2000,
precision=0,
elem_classes="input-field"
label="🔄 最大迭代次数", value=2000, precision=0, elem_classes="input-field"
)
batch_size = gr.Number(
label="📦 批次大小 (Batch Size)",
value=1,
precision=0,
elem_classes="input-field"
label="📦 批次大小 (Batch Size)", value=1, precision=0, elem_classes="input-field"
)
with gr.Row():
lora_rank = gr.Number(
label="🎯 LoRA Rank",
value=32,
precision=0,
elem_classes="input-field"
)
lora_alpha = gr.Number(
label="⚖️ LoRA Alpha",
value=16,
precision=0,
elem_classes="input-field"
)
lora_rank = gr.Number(label="🎯 LoRA Rank", value=32, precision=0, elem_classes="input-field")
lora_alpha = gr.Number(label="⚖️ LoRA Alpha", value=16, precision=0, elem_classes="input-field")
save_interval = gr.Number(
label="💾 保存间隔 (Steps)",
value=1000,
precision=0,
elem_classes="input-field"
label="💾 保存间隔 (Steps)", value=1000, precision=0, elem_classes="input-field"
)
output_name = gr.Textbox(
label="📁 输出目录名称 (可选,若存在则继续训练)",
value="",
elem_classes="input-field"
label="📁 输出目录名称 (可选,若存在则继续训练)", value="", elem_classes="input-field"
)
with gr.Row():
start_btn = gr.Button(
" 开始训练",
variant="primary",
elem_classes="button-primary"
)
stop_btn = gr.Button(
"⏹️ 停止训练",
variant="stop",
elem_classes="button-stop"
)
start_btn = gr.Button("▶️ 开始训练", variant="primary", elem_classes="button-primary")
stop_btn = gr.Button(" 停止训练", variant="stop", elem_classes="button-stop")
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
with gr.Row():
@@ -961,10 +910,12 @@ with gr.Blocks(
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
dropout = gr.Number(label="LoRA Dropout", value=0.0)
gr.Markdown("#### 分发选项 (Distribution)")
with gr.Row():
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5")
hf_model_id = gr.Textbox(
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
)
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
with gr.Column(scale=2, elem_classes="form-section"):
@@ -975,26 +926,44 @@ with gr.Blocks(
max_lines=30,
interactive=False,
elem_classes="input-field",
show_label=False
show_label=False,
)
start_btn.click(
start_training,
inputs=[
train_pretrained_path, train_manifest, val_manifest,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
train_pretrained_path,
train_manifest,
val_manifest,
lr,
num_iters,
batch_size,
lora_rank,
lora_alpha,
save_interval,
output_name,
# advanced
grad_accum_steps, num_workers, log_interval, valid_interval,
weight_decay, warmup_steps, max_steps, sample_rate,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
grad_accum_steps,
num_workers,
log_interval,
valid_interval,
weight_decay,
warmup_steps,
max_steps,
sample_rate,
enable_lm,
enable_dit,
enable_proj,
dropout,
tensorboard_path,
# distribution
hf_model_id, distribute
hf_model_id,
distribute,
],
outputs=[logs_out] # Initial message
outputs=[logs_out], # Initial message
)
stop_btn.click(stop_training, outputs=[logs_out])
# Log refresher
timer = gr.Timer(1)
timer.tick(get_training_log, outputs=logs_out)
@@ -1016,21 +985,17 @@ with gr.Blocks(
value="Hello, this is a test of the VoxCPM LoRA model.",
elem_classes="input-field",
lines=4,
placeholder="输入要合成的文本内容..."
placeholder="输入要合成的文本内容...",
)
gr.Markdown("**🎭 声音克隆(可选)**")
prompt_wav = gr.Audio(
label="🎵 参考音频",
type="filepath",
elem_classes="input-field"
)
prompt_wav = gr.Audio(label="🎵 参考音频", type="filepath", elem_classes="input-field")
prompt_text = gr.Textbox(
label="📝 参考文本(可选)",
elem_classes="input-field",
placeholder="如不填写,将自动识别参考音频内容"
placeholder="如不填写,将自动识别参考音频内容",
)
# 中栏:模型选择和参数配置 (35%)
@@ -1043,15 +1008,11 @@ with gr.Blocks(
value="None",
interactive=True,
elem_classes="input-field",
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型"
)
refresh_lora_btn = gr.Button(
"🔄 刷新模型列表",
elem_classes="button-refresh",
size="sm"
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型",
)
refresh_lora_btn = gr.Button("🔄 刷新模型列表", elem_classes="button-refresh", size="sm")
gr.Markdown("#### ⚙️ 生成参数")
cfg_scale = gr.Slider(
@@ -1060,59 +1021,50 @@ with gr.Blocks(
maximum=5.0,
value=2.0,
step=0.1,
info="引导系数,值越大越贴近提示"
info="引导系数,值越大越贴近提示",
)
steps = gr.Slider(
label="🔢 推理步数",
minimum=1,
maximum=50,
value=10,
step=1,
info="生成质量与步数成正比,但耗时更长"
info="生成质量与步数成正比,但耗时更长",
)
seed = gr.Number(
label="🎲 随机种子",
value=-1,
precision=0,
elem_classes="input-field",
info="-1 为随机,固定值可复现结果"
info="-1 为随机,固定值可复现结果",
)
generate_btn = gr.Button(
"🎵 生成音频",
variant="primary",
elem_classes="button-primary",
size="lg"
)
generate_btn = gr.Button("🎵 生成音频", variant="primary", elem_classes="button-primary", size="lg")
# 右栏:生成结果 (30%)
with gr.Column(scale=30, elem_classes="form-section"):
gr.Markdown("#### 🎧 生成结果")
audio_out = gr.Audio(
label="",
elem_classes="input-field",
show_label=False
)
audio_out = gr.Audio(label="", elem_classes="input-field", show_label=False)
gr.Markdown("#### 📋 状态信息")
status_out = gr.Textbox(
label="",
interactive=False,
elem_classes="input-field",
show_label=False,
lines=3,
placeholder="等待生成..."
placeholder="等待生成...",
)
def refresh_loras():
# 获取 LoRA checkpoints 及其 base model 信息
checkpoints_with_info = scan_lora_checkpoints(with_info=True)
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
# 输出调试信息
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点", file=sys.stderr)
for ckpt_path, base_model in checkpoints_with_info:
@@ -1120,22 +1072,27 @@ with gr.Blocks(
print(f" - {ckpt_path} (Base Model: {base_model})", file=sys.stderr)
else:
print(f" - {ckpt_path}", file=sys.stderr)
return gr.update(choices=choices, value="None")
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
# Auto-recognize audio when uploaded
prompt_wav.change(
fn=recognize_audio,
inputs=[prompt_wav],
outputs=[prompt_text]
)
prompt_wav.change(fn=recognize_audio, inputs=[prompt_wav], outputs=[prompt_text])
generate_btn.click(
run_inference,
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
outputs=[audio_out, status_out]
inputs=[
infer_text,
prompt_wav,
prompt_text,
lora_select,
cfg_scale,
steps,
seed,
train_pretrained_path,
],
outputs=[audio_out, status_out],
)
# --- Language Switching Logic ---
@@ -1144,111 +1101,141 @@ with gr.Blocks(
# Labels for advanced options
if lang == "zh":
adv = {
'grad_accum_steps': "梯度累积 (grad_accum_steps)",
'num_workers': "数据加载线程 (num_workers)",
'log_interval': "日志间隔 (log_interval)",
'valid_interval': "验证间隔 (valid_interval)",
'weight_decay': "权重衰减 (weight_decay)",
'warmup_steps': "warmup_steps",
'max_steps': "最大步数 (max_steps)",
'sample_rate': "采样率 (sample_rate)",
'enable_lm': "启用 LoRA LM (enable_lm)",
'enable_dit': "启用 LoRA DIT (enable_dit)",
'enable_proj': "启用投影 (enable_proj)",
'dropout': "LoRA Dropout",
'tensorboard_path': "Tensorboard 路径 (可选)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "分发模式 (distribute)",
"grad_accum_steps": "梯度累积 (grad_accum_steps)",
"num_workers": "数据加载线程 (num_workers)",
"log_interval": "日志间隔 (log_interval)",
"valid_interval": "验证间隔 (valid_interval)",
"weight_decay": "权重衰减 (weight_decay)",
"warmup_steps": "warmup_steps",
"max_steps": "最大步数 (max_steps)",
"sample_rate": "采样率 (sample_rate)",
"enable_lm": "启用 LoRA LM (enable_lm)",
"enable_dit": "启用 LoRA DIT (enable_dit)",
"enable_proj": "启用投影 (enable_proj)",
"dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard 路径 (可选)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
"distribute": "分发模式 (distribute)",
}
else:
adv = {
'grad_accum_steps': "Grad Accum Steps",
'num_workers': "Num Workers",
'log_interval': "Log Interval",
'valid_interval': "Valid Interval",
'weight_decay': "Weight Decay",
'warmup_steps': "Warmup Steps",
'max_steps': "Max Steps",
'sample_rate': "Sample Rate",
'enable_lm': "Enable LoRA LM",
'enable_dit': "Enable LoRA DIT",
'enable_proj': "Enable Projection",
'dropout': "LoRA Dropout",
'tensorboard_path': "Tensorboard Path (Optional)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "Distribute Mode",
"grad_accum_steps": "Grad Accum Steps",
"num_workers": "Num Workers",
"log_interval": "Log Interval",
"valid_interval": "Valid Interval",
"weight_decay": "Weight Decay",
"warmup_steps": "Warmup Steps",
"max_steps": "Max Steps",
"sample_rate": "Sample Rate",
"enable_lm": "Enable LoRA LM",
"enable_dit": "Enable LoRA DIT",
"enable_proj": "Enable Projection",
"dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard Path (Optional)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
"distribute": "Distribute Mode",
}
return (
gr.update(value=f"# {d['title']}"),
gr.update(label=d['tab_train']),
gr.update(label=d['tab_infer']),
gr.update(label=d['pretrained_path']),
gr.update(label=d['train_manifest']),
gr.update(label=d['val_manifest']),
gr.update(label=d['lr']),
gr.update(label=d['max_iters']),
gr.update(label=d['batch_size']),
gr.update(label=d['lora_rank']),
gr.update(label=d['lora_alpha']),
gr.update(label=d['save_interval']),
gr.update(label=d['output_name']),
gr.update(value=d['start_train']),
gr.update(value=d['stop_train']),
gr.update(label=d['train_logs']),
gr.update(label=d["tab_train"]),
gr.update(label=d["tab_infer"]),
gr.update(label=d["pretrained_path"]),
gr.update(label=d["train_manifest"]),
gr.update(label=d["val_manifest"]),
gr.update(label=d["lr"]),
gr.update(label=d["max_iters"]),
gr.update(label=d["batch_size"]),
gr.update(label=d["lora_rank"]),
gr.update(label=d["lora_alpha"]),
gr.update(label=d["save_interval"]),
gr.update(label=d["output_name"]),
gr.update(value=d["start_train"]),
gr.update(value=d["stop_train"]),
gr.update(label=d["train_logs"]),
# Advanced options (must match outputs order)
gr.update(label=adv['grad_accum_steps']),
gr.update(label=adv['num_workers']),
gr.update(label=adv['log_interval']),
gr.update(label=adv['valid_interval']),
gr.update(label=adv['weight_decay']),
gr.update(label=adv['warmup_steps']),
gr.update(label=adv['max_steps']),
gr.update(label=adv['sample_rate']),
gr.update(label=adv['enable_lm']),
gr.update(label=adv['enable_dit']),
gr.update(label=adv['enable_proj']),
gr.update(label=adv['dropout']),
gr.update(label=adv['tensorboard_path']),
gr.update(label=adv["grad_accum_steps"]),
gr.update(label=adv["num_workers"]),
gr.update(label=adv["log_interval"]),
gr.update(label=adv["valid_interval"]),
gr.update(label=adv["weight_decay"]),
gr.update(label=adv["warmup_steps"]),
gr.update(label=adv["max_steps"]),
gr.update(label=adv["sample_rate"]),
gr.update(label=adv["enable_lm"]),
gr.update(label=adv["enable_dit"]),
gr.update(label=adv["enable_proj"]),
gr.update(label=adv["dropout"]),
gr.update(label=adv["tensorboard_path"]),
# Distribution options
gr.update(label=adv['hf_model_id']),
gr.update(label=adv['distribute']),
gr.update(label=adv["hf_model_id"]),
gr.update(label=adv["distribute"]),
# Inference section
gr.update(label=d['text_to_synth']),
gr.update(label=d['ref_audio']),
gr.update(label=d['ref_text']),
gr.update(label=d['select_lora']),
gr.update(value=d['refresh']),
gr.update(label=d['cfg_scale']),
gr.update(label=d['infer_steps']),
gr.update(label=d['seed']),
gr.update(value=d['gen_audio']),
gr.update(label=d['gen_output']),
gr.update(label=d['status']),
gr.update(label=d["text_to_synth"]),
gr.update(label=d["ref_audio"]),
gr.update(label=d["ref_text"]),
gr.update(label=d["select_lora"]),
gr.update(value=d["refresh"]),
gr.update(label=d["cfg_scale"]),
gr.update(label=d["infer_steps"]),
gr.update(label=d["seed"]),
gr.update(value=d["gen_audio"]),
gr.update(label=d["gen_output"]),
gr.update(label=d["status"]),
)
lang_btn.change(
change_language,
inputs=[lang_btn],
outputs=[
title_md, tab_train, tab_infer,
train_pretrained_path, train_manifest, val_manifest,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
title_md,
tab_train,
tab_infer,
train_pretrained_path,
train_manifest,
val_manifest,
lr,
num_iters,
batch_size,
lora_rank,
lora_alpha,
save_interval,
output_name,
start_btn, stop_btn, logs_out,
start_btn,
stop_btn,
logs_out,
# advanced outputs
grad_accum_steps, num_workers, log_interval, valid_interval,
weight_decay, warmup_steps, max_steps, sample_rate,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
grad_accum_steps,
num_workers,
log_interval,
valid_interval,
weight_decay,
warmup_steps,
max_steps,
sample_rate,
enable_lm,
enable_dit,
enable_proj,
dropout,
tensorboard_path,
# distribution outputs
hf_model_id, distribute,
infer_text, prompt_wav, prompt_text,
lora_select, refresh_lora_btn, cfg_scale, steps, seed,
generate_btn, audio_out, status_out
]
hf_model_id,
distribute,
infer_text,
prompt_wav,
prompt_text,
lora_select,
refresh_lora_btn,
cfg_scale,
steps,
seed,
generate_btn,
audio_out,
status_out,
],
)
if __name__ == "__main__":
# Ensure lora directory exists
os.makedirs("lora", exist_ok=True)
app.queue().launch(server_name="0.0.0.0", server_port=7860)
app.queue().launch(server_name="0.0.0.0", server_port=7860)
+1 -3
View File
@@ -30,7 +30,7 @@ dependencies = [
"torchcodec",
"transformers>=4.36.2",
"einops",
"gradio<6",
"gradio>=6,<7",
"inflect",
"addict",
"wetext",
@@ -57,7 +57,6 @@ dev = [
"pytest-cov>=2.0",
"black>=21.0",
"flake8>=3.8",
"mypy>=0.800",
"pre-commit>=2.0",
]
@@ -90,7 +89,6 @@ extend-exclude = '''
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| build
+4 -1
View File
@@ -125,7 +125,10 @@ def main():
out_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
if __name__ == "__main__":
+35 -18
View File
@@ -112,22 +112,24 @@ def main():
f"lora_config.json not found in {ckpt_dir}. "
"Make sure the checkpoint was saved with the updated training script."
)
with open(lora_config_path, "r", encoding="utf-8") as f:
lora_info = json.load(f)
# Get base model path (command line arg overrides config)
pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
if not pretrained_path:
raise ValueError("base_model not found in lora_config.json and --base_model not provided")
# Get LoRA config
lora_cfg_dict = lora_info.get("lora_config", {})
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
print(f" Base model: {pretrained_path}", file=sys.stderr)
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
print(
f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr
)
# 3. Load model with LoRA (no denoiser)
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
@@ -146,10 +148,10 @@ def main():
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
print("\n[2/2] Starting synthesis tests...", file=sys.stderr)
# === Test 1: With LoRA ===
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
print("\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
@@ -162,10 +164,13 @@ def main():
)
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 2: Disable LoRA (via set_lora_enabled) ===
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
print("\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
model.set_lora_enabled(False)
audio_np = model.generate(
text=args.text,
@@ -179,10 +184,13 @@ def main():
)
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 3: Re-enable LoRA ===
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
print("\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
model.set_lora_enabled(True)
audio_np = model.generate(
text=args.text,
@@ -196,10 +204,13 @@ def main():
)
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 4: Unload LoRA (reset_lora_weights) ===
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
print("\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
model.unload_lora()
audio_np = model.generate(
text=args.text,
@@ -213,10 +224,13 @@ def main():
)
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 5: Hot-reload LoRA (load_lora) ===
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
print("\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
loaded, skipped = model.load_lora(ckpt_dir)
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
audio_np = model.generate(
@@ -231,9 +245,12 @@ def main():
)
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
print(f"\n[Done] All tests completed!", file=sys.stderr)
print("\n[Done] All tests completed!", file=sys.stderr)
print(f" - with_lora: {lora_output}", file=sys.stderr)
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
+194 -84
View File
@@ -7,7 +7,7 @@ project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "src"))
import contextlib
from typing import Dict, Optional
from typing import Dict
import argbind
import torch
@@ -17,16 +17,19 @@ from transformers import get_cosine_schedule_with_warmup
import signal
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
from safetensors.torch import save_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
from voxcpm.model import VoxCPMModel
import json
from voxcpm.model import VoxCPMModel, VoxCPM2Model
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import (
Accelerator,
@@ -61,15 +64,15 @@ def train(
lora: dict = None,
config_path: str = "",
# Distribution options (for LoRA checkpoints)
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
):
_ = config_path
# Validate distribution options
if lora is not None and distribute and not hf_model_id:
raise ValueError("hf_model_id is required when distribute=True")
accelerator = Accelerator(amp=True)
save_dir = Path(save_path)
@@ -84,7 +87,15 @@ def train(
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
# Auto-detect model architecture from config.json
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
_arch = json.load(_f).get("architecture", "voxcpm").lower()
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
if accelerator.rank == 0:
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
base_model = _model_cls.from_local(
pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None
)
tokenizer = base_model.text_tokenizer
train_ds, val_ds = load_audio_text_datasets(
@@ -166,7 +177,6 @@ def train(
unwrapped_model = accelerator.unwrap(model)
unwrapped_model.train()
# Only print param info on rank 0 to avoid cluttered output
if accelerator.rank == 0:
for name, param in model.named_parameters():
@@ -191,7 +201,7 @@ def train(
# All ranks load the same checkpoint to keep model and optimizer state in sync.
start_step = load_checkpoint(model, optimizer, scheduler, save_dir, rank=accelerator.rank)
accelerator.barrier()
if start_step > 0 and accelerator.rank == 0:
tracker.print(f"Resuming training from step {start_step}")
@@ -199,7 +209,19 @@ def train(
resume = {"step": start_step}
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
def _signal_handler(
signum,
frame,
_model=model,
_optim=optimizer,
_sched=scheduler,
_save_dir=save_dir,
_pretrained=pretrained_path,
_hf_id=hf_model_id,
_dist=distribute,
_resume=resume,
_rank=accelerator.rank,
):
try:
cur_step = int(_resume.get("step", start_step))
except Exception:
@@ -229,8 +251,8 @@ def train(
except StopIteration:
data_epoch += 1
# Key: set DistributedSampler epoch to ensure different data order each epoch
sampler = getattr(train_loader, 'sampler', None)
if hasattr(sampler, 'set_epoch'):
sampler = getattr(train_loader, "sampler", None)
if hasattr(sampler, "set_epoch"):
sampler.set_epoch(data_epoch)
train_iter = iter(train_loader)
return next(train_iter)
@@ -250,7 +272,7 @@ def train(
# Only sync gradients on the last micro-batch
# Use no_sync() for intermediate steps to reduce communication overhead
is_last_micro_step = (micro_step == grad_accum_steps - 1)
is_last_micro_step = micro_step == grad_accum_steps - 1
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
with sync_context:
@@ -299,10 +321,22 @@ def train(
tracker.log_metrics(loss_values, split="train")
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
valid_interval=valid_interval)
validate(
model,
val_loader,
batch_processor,
accelerator,
tracker,
lambdas,
writer=writer,
step=step,
val_ds=val_ds,
audio_vae=audio_vae_for_gen,
sample_rate=sample_rate,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
)
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
@@ -313,13 +347,26 @@ def train(
writer.close()
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
val_texts=None, tokenizer=None, valid_interval=1000):
def validate(
model,
val_loader,
batch_processor,
accelerator,
tracker,
lambdas,
writer=None,
step=0,
val_ds=None,
audio_vae=None,
sample_rate=22050,
val_texts=None,
tokenizer=None,
valid_interval=1000,
):
"""Validate and generate sample audio"""
import numpy as np
import numpy as np # noqa: F401
from collections import defaultdict
model.eval()
total_losses = []
sub_losses = defaultdict(list) # Track individual sub-losses
@@ -356,26 +403,37 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
# Compute mean total loss
mean_total_loss = torch.stack(total_losses).mean()
accelerator.all_reduce(mean_total_loss)
# Compute mean of each sub-loss
val_metrics = {"loss/total": mean_total_loss.item()}
for key, values in sub_losses.items():
mean_sub_loss = torch.stack(values).mean()
accelerator.all_reduce(mean_sub_loss)
val_metrics[key] = mean_sub_loss.item()
tracker.log_metrics(val_metrics, split="val")
# Generate sample audio for TensorBoard display
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
try:
generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
tracker=tracker)
generate_sample_audio(
model,
val_ds,
audio_vae,
writer,
step,
accelerator,
sample_rate,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
tracker=tracker,
)
except Exception as e:
tracker.print(f"[Warning] Failed to generate sample audio: {e}")
import traceback
import io
buf = io.StringIO()
traceback.print_exc(file=buf)
tracker.print(buf.getvalue())
@@ -390,7 +448,7 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
missing.append("audio_vae")
if missing and accelerator.rank == 0:
tracker.print(f"[Warning] Skip audio generation: missing {', '.join(missing)}")
model.train()
@@ -398,6 +456,7 @@ def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
"""Compute Mel Spectrogram (dB) using librosa"""
import numpy as np
import librosa
audio_np = audio_np.flatten().astype(np.float32)
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
return librosa.power_to_db(mel, ref=np.max)
@@ -408,31 +467,45 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
"""
import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import librosa.display
fmax = sample_rate // 2
step_str = f" @ Step {step}" if step is not None else ""
if ref_audio_np is not None and ref_mel is not None:
# Comparison mode: reference vs generated
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
img_ref = librosa.display.specshow(
ref_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_ref
)
ax_ref.set_title(
f"Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}",
fontsize=10,
fontweight="bold",
color="#28A745",
)
plt.colorbar(img_ref, ax=ax_ref, format="%+2.0f dB", pad=0.02)
img_gen = librosa.display.specshow(
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_gen
)
ax_gen.set_title(
f"Generated - {len(gen_audio_np)/sample_rate:.2f}s", fontsize=10, fontweight="bold", color="#DC3545"
)
plt.colorbar(img_gen, ax=ax_gen, format="%+2.0f dB", pad=0.02)
else:
# Single figure mode: show generated only
fig, ax = plt.subplots(figsize=(12, 4))
img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
img = librosa.display.specshow(
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax
)
ax.set_title(f"Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}", fontsize=11, fontweight="bold")
plt.colorbar(img, ax=ax, format="%+2.0f dB", pad=0.02)
plt.tight_layout()
return fig
@@ -440,26 +513,38 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
def normalize_audio(audio_np):
"""Normalize audio to [-0.9, 0.9]"""
import numpy as np
max_val = np.abs(audio_np).max()
return audio_np / max_val * 0.9 if max_val > 0 else audio_np
def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
tracker=None):
def generate_sample_audio(
model,
val_ds,
audio_vae,
writer,
step,
accelerator,
sample_rate=22050,
val_texts=None,
tokenizer=None,
pretrained_path=None,
valid_interval=1000,
tracker=None,
):
"""Select 2 fixed validation samples, generate audio and log to TensorBoard"""
import numpy as np
log = tracker.print if tracker else print
num_samples = min(2, len(val_ds))
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
unwrapped_model = accelerator.unwrap(model)
for i in range(num_samples):
sample = val_ds[i]
text = val_texts[i] if val_texts and i < len(val_texts) else "Hello, this is a test."
# Load reference audio
ref_audio_np = None
try:
@@ -468,7 +553,10 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
ref_sr = sample["audio"].get("sampling_rate", sample_rate)
if ref_sr != sample_rate:
import torchaudio.functional as F
ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
ref_audio_np = (
F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
)
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
except Exception as e:
log(f"[Warning] Failed to load reference audio: {e}")
@@ -480,7 +568,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
unwrapped_model.eval()
# unwrapped_model.to(torch.bfloat16)
unwrapped_model.audio_vae = audio_vae.to(torch.float32)
log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
autocast_ctx = (
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
@@ -490,27 +578,33 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
with torch.no_grad():
with autocast_ctx:
generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
# Restore training setup
# unwrapped_model.to(torch.float32)
# unwrapped_model.audio_vae = None
if generated is None or len(generated) == 0:
log(f"[Warning] Generated audio is empty for sample {i}")
continue
# Process generated audio
gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
gen_audio_np = (
generated.cpu().float().numpy().flatten()
if isinstance(generated, torch.Tensor)
else np.array(generated, dtype=np.float32).flatten()
)
gen_audio_np = normalize_audio(gen_audio_np)
tag = f"val_sample_{i}"
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
# Log reference audio
if ref_audio_np is not None:
writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
writer.add_audio(
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
)
# Generate mel spectrogram figure
try:
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
@@ -520,10 +614,11 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
log(f"[Audio] Created mel spectrogram figure for sample {i}")
except Exception as e:
log(f"[Warning] Failed to create mel spectrogram: {e}")
except Exception as e:
log(f"[Warning] Failed to generate audio for sample {i}: {e}")
import traceback
traceback.print_exc()
finally:
@@ -545,30 +640,29 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
Called by all ranks so that distributed state stays aligned.
Returns the step number to resume from, or 0 if no checkpoint found.
"""
import json
latest_folder = save_dir / "latest"
if not latest_folder.exists():
return 0
unwrapped = model.module if hasattr(model, "module") else model
lora_cfg = unwrapped.lora_config
# Load model weights
if lora_cfg is not None:
# LoRA: load lora_weights
lora_weights_path = latest_folder / "lora_weights.safetensors"
if not lora_weights_path.exists():
lora_weights_path = latest_folder / "lora_weights.ckpt"
if lora_weights_path.exists():
if lora_weights_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(lora_weights_path))
else:
ckpt = torch.load(lora_weights_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False)
if rank == 0:
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
@@ -577,33 +671,34 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
model_path = latest_folder / "model.safetensors"
if not model_path.exists():
model_path = latest_folder / "pytorch_model.bin"
if model_path.exists():
if model_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(model_path))
else:
ckpt = torch.load(model_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False)
if rank == 0:
print(f"Loaded model weights from {model_path}", file=sys.stderr)
# Load optimizer state
optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
if rank == 0:
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
# Load scheduler state
scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
if rank == 0:
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
state_path = latest_folder / "training_state.json"
if state_path.exists():
with open(state_path, "r", encoding="utf-8") as f:
@@ -621,28 +716,36 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
if rank == 0:
print(f"Resuming from step {resume_step}", file=sys.stderr)
return resume_step
return 0
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
def save_checkpoint(
model,
optimizer,
scheduler,
save_dir: Path,
step: int,
pretrained_path: str = None,
hf_model_id: str = "",
distribute: bool = False,
):
"""
Save checkpoint with different strategies for full finetune vs LoRA:
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
"""
import json
import shutil
save_dir.mkdir(parents=True, exist_ok=True)
tag = f"step_{step:07d}"
folder = save_dir / tag
folder.mkdir(parents=True, exist_ok=True)
unwrapped = model.module if hasattr(model, "module") else model
full_state = unwrapped.state_dict()
lora_cfg = unwrapped.lora_config
if lora_cfg is not None:
# LoRA finetune: save only lora_A/lora_B weights
state_dict = {k: v for k, v in full_state.items() if "lora_" in k}
@@ -650,7 +753,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
save_file(state_dict, folder / "lora_weights.safetensors")
else:
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
# Save LoRA config and base model path to a separate JSON file
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
@@ -667,16 +770,23 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
save_file(state_dict, folder / "model.safetensors")
else:
torch.save({"state_dict": state_dict}, folder / "pytorch_model.bin")
# Copy config files from pretrained path
if pretrained_path:
pretrained_dir = Path(pretrained_path)
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
files_to_copy = [
"config.json",
"audiovae.pth",
"audiovae.safetensors",
"tokenizer.json",
"special_tokens_map.json",
"tokenizer_config.json",
]
for fname in files_to_copy:
src = pretrained_dir / fname
if src.exists():
shutil.copy2(src, folder / fname)
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
with open(folder / "training_state.json", "w", encoding="utf-8") as f:
+46 -27
View File
@@ -13,11 +13,11 @@ import soundfile as sf
from voxcpm.core import VoxCPM
# -----------------------------
# Validators
# -----------------------------
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
path = Path(file_path)
if not path.exists():
@@ -53,12 +53,11 @@ def validate_ranges(args, parser):
# Model loading
# -----------------------------
def load_model(args) -> VoxCPM:
print("Loading VoxCPM model...", file=sys.stderr)
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
"ZIPENHANCER_MODEL_PATH", None
)
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get("ZIPENHANCER_MODEL_PATH", None)
# Build LoRA config if provided
lora_config = None
@@ -119,22 +118,29 @@ def load_model(args) -> VoxCPM:
# Commands
# -----------------------------
def cmd_clone(args):
if not args.text:
sys.exit("Error: Please provide --text for synthesis")
if not args.prompt_audio or not args.prompt_text:
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
has_prompt = args.prompt_audio and args.prompt_text
has_ref = args.reference_audio is not None
if not has_prompt and not has_ref:
sys.exit("Error: Voice cloning requires --prompt-audio + --prompt-text, or --reference-audio, or both")
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
if args.prompt_audio:
validate_file_exists(args.prompt_audio, "prompt audio file")
if args.reference_audio:
validate_file_exists(args.reference_audio, "reference audio file")
output_path = validate_output_path(args.output)
model = load_model(args)
audio_array = model.generate(
text=args.text,
prompt_wav_path=str(prompt_audio_path),
prompt_text=args.prompt_text,
prompt_wav_path=args.prompt_audio if has_prompt else None,
prompt_text=args.prompt_text if has_prompt else None,
reference_wav_path=args.reference_audio,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
normalize=args.normalize,
@@ -185,7 +191,11 @@ def cmd_batch(args):
prompt_audio_path = None
if args.prompt_audio:
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "prompt audio file"))
reference_audio_path = None
if args.reference_audio:
reference_audio_path = str(validate_file_exists(args.reference_audio, "reference audio file"))
success_count = 0
@@ -195,10 +205,11 @@ def cmd_batch(args):
text=text,
prompt_wav_path=prompt_audio_path,
prompt_text=args.prompt_text,
reference_wav_path=reference_audio_path,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
normalize=args.normalize,
denoise=args.denoise and prompt_audio_path is not None,
denoise=args.denoise and (prompt_audio_path is not None or reference_audio_path is not None),
)
output_file = output_dir / f"output_{i:03d}.wav"
@@ -218,6 +229,7 @@ def cmd_batch(args):
# Parser
# -----------------------------
def _build_unified_parser():
parser = argparse.ArgumentParser(
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
@@ -236,34 +248,40 @@ Examples:
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
# Prompt
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
# Prompt / Reference
parser.add_argument(
"--prompt-audio", "-pa", help="Prompt audio file path (continuation mode, requires --prompt-text)"
)
parser.add_argument("--prompt-text", "-pt", help="Text corresponding to the prompt audio")
parser.add_argument(
"--reference-audio", "-ra", help="Reference audio for voice cloning (isolated mode, VoxCPM2 only)"
)
parser.add_argument("--denoise", action="store_true", help="Enable prompt/reference speech enhancement")
# Generation parameters
parser.add_argument("--cfg-value", type=float, default=2.0,
help="CFG guidance scale (float, recommended 0.55.0, default: 2.0)")
parser.add_argument("--inference-timesteps", type=int, default=10,
help="Inference steps (int, 1100, default: 10)")
parser.add_argument(
"--cfg-value", type=float, default=2.0, help="CFG guidance scale (float, recommended 0.55.0, default: 2.0)"
)
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (int, 1100, default: 10)")
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
# Model loading
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
parser.add_argument(
"--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (default: openbmb/VoxCPM1.5)"
)
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str,
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
parser.add_argument(
"--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)"
)
# LoRA
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
parser.add_argument("--lora-dropout", type=float, default=0.0,
help="LoRA dropout rate (0.01.0, default: 0.0)")
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (0.01.0, default: 0.0)")
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
@@ -275,6 +293,7 @@ Examples:
# Entrypoint
# -----------------------------
def main():
parser = _build_unified_parser()
args = parser.parse_args()
@@ -296,8 +315,8 @@ def main():
if not args.text or not args.output:
parser.error("Single-sample mode requires --text and --output")
# Clone mode
if args.prompt_audio or args.prompt_text:
# Clone mode (prompt continuation, reference isolation, or both)
if args.prompt_audio or args.prompt_text or args.reference_audio:
return cmd_clone(args)
# Direct synthesis
+151 -100
View File
@@ -1,21 +1,25 @@
import os
import sys
import re
import json
import tempfile
import numpy as np
from typing import Generator, Optional
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel, LoRAConfig
from .model.voxcpm2 import VoxCPM2Model
class VoxCPM:
def __init__(self,
voxcpm_model_path : str,
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser : bool = True,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
def __init__(
self,
voxcpm_model_path: str,
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser: bool = True,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
"""Initialize VoxCPM TTS pipeline.
Args:
@@ -26,13 +30,16 @@ class VoxCPM:
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
"""
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
print(
f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
file=sys.stderr,
)
# If lora_weights_path is provided but no lora_config, create a default one
if lora_weights_path is not None and lora_config is None:
lora_config = LoRAConfig(
@@ -41,18 +48,33 @@ class VoxCPM:
enable_proj=False,
)
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
# Determine model type from config.json architecture field
config_path = os.path.join(voxcpm_model_path, "config.json")
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
arch = config.get("architecture", "voxcpm").lower()
if arch == "voxcpm2":
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
print("Loaded VoxCPM2Model", file=sys.stderr)
elif arch == "voxcpm":
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
print("Loaded VoxCPMModel", file=sys.stderr)
else:
raise ValueError(f"Unsupported architecture: {arch}")
# Load LoRA weights if path is provided
if lora_weights_path is not None:
print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
self.text_normalizer = None
self.denoiser = None
if enable_denoiser and zipenhancer_model_path is not None:
from .zipenhancer import ZipEnhancer
self.denoiser = ZipEnhancer(zipenhancer_model_path)
else:
self.denoiser = None
@@ -64,17 +86,18 @@ class VoxCPM:
)
@classmethod
def from_pretrained(cls,
hf_model_id: str = "openbmb/VoxCPM1.5",
load_denoiser: bool = True,
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
):
def from_pretrained(
cls,
hf_model_id: str = "openbmb/VoxCPM2",
load_denoiser: bool = True,
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
Args:
@@ -86,7 +109,7 @@ class VoxCPM:
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created with
enable_lm=True and enable_dit=True.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
@@ -106,7 +129,7 @@ class VoxCPM:
repo_id = hf_model_id
if not repo_id:
raise ValueError("You must provide hf_model_id")
# Load from local path if provided
if os.path.isdir(repo_id):
local_path = repo_id
@@ -134,118 +157,146 @@ class VoxCPM:
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs)
def _generate(self,
text : str,
prompt_wav_path : str = None,
prompt_text : str = None,
cfg_value : float = 2.0,
inference_timesteps : int = 10,
min_len : int = 2,
max_len : int = 4096,
normalize : bool = False,
denoise : bool = False,
retry_badcase : bool = True,
retry_badcase_max_times : int = 3,
retry_badcase_ratio_threshold : float = 6.0,
streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
def _generate(
self,
text: str,
prompt_wav_path: str = None,
prompt_text: str = None,
reference_wav_path: str = None,
cfg_value: float = 2.0,
inference_timesteps: int = 10,
min_len: int = 2,
max_len: int = 4096,
normalize: bool = False,
denoise: bool = False,
retry_badcase: bool = True,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
"""Synthesize speech for the given text and return a single waveform.
This method optionally builds and reuses a prompt cache. If an external
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
used for all sub-sentences. Otherwise, the prompt cache is built from
the first generated result and reused for the remaining text chunks.
Args:
text: Input text. Can include newlines; each non-empty line is
treated as a sub-sentence.
prompt_wav_path: Path to a reference audio file for prompting.
text: Input text to synthesize.
prompt_wav_path: Path to prompt audio for continuation mode.
Must be paired with ``prompt_text``.
prompt_text: Text content corresponding to the prompt audio.
reference_wav_path: Path to reference audio for voice cloning
(structurally isolated via ref_audio tokens). Can be used
alone or combined with ``prompt_wav_path`` + ``prompt_text``.
cfg_value: Guidance scale for the generation model.
inference_timesteps: Number of inference steps.
min_len: Minimum audio length.
max_len: Maximum token length during generation.
normalize: Whether to run text normalization before generation.
denoise: Whether to denoise the prompt audio if a denoiser is
available.
denoise: Whether to denoise the prompt/reference audio if a
denoiser is available.
retry_badcase: Whether to retry badcase.
retry_badcase_max_times: Maximum number of times to retry badcase.
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
streaming: Whether to return a generator of audio chunks.
Returns:
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
Yields audio chunks for each generations step if ``streaming=True``,
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
Yields audio chunks for each generation step if ``streaming=True``,
otherwise yields a single array containing the final audio.
"""
if not text.strip() or not isinstance(text, str):
raise ValueError("target text must be a non-empty string")
if prompt_wav_path is not None:
if not os.path.exists(prompt_wav_path):
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
if reference_wav_path is not None:
if not os.path.exists(reference_wav_path):
raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")
if (prompt_wav_path is None) != (prompt_text is None):
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
is_v2 = isinstance(self.tts_model, VoxCPM2Model)
if reference_wav_path is not None and not is_v2:
raise ValueError("reference_wav_path is only supported with VoxCPM2 models")
text = text.replace("\n", " ")
text = re.sub(r'\s+', ' ', text)
temp_prompt_wav_path = None
text = re.sub(r"\s+", " ", text)
temp_files = []
try:
if prompt_wav_path is not None and prompt_text is not None:
if denoise and self.denoiser is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
temp_prompt_wav_path = tmp_file.name
self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
prompt_wav_path = temp_prompt_wav_path
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text
)
actual_prompt_path = prompt_wav_path
actual_ref_path = reference_wav_path
if denoise and self.denoiser is not None:
if prompt_wav_path is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
temp_files.append(tmp.name)
self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
actual_prompt_path = temp_files[-1]
if reference_wav_path is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
temp_files.append(tmp.name)
self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
actual_ref_path = temp_files[-1]
if actual_prompt_path is not None or actual_ref_path is not None:
if is_v2:
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_text=prompt_text,
prompt_wav_path=actual_prompt_path,
reference_wav_path=actual_ref_path,
)
else:
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_text=prompt_text,
prompt_wav_path=actual_prompt_path,
)
else:
fixed_prompt_cache = None # will be built from the first inference
fixed_prompt_cache = None
if normalize:
if self.text_normalizer is None:
from .utils.text_normalize import TextNormalizer
self.text_normalizer = TextNormalizer()
text = self.text_normalizer.normalize(text)
generate_result = self.tts_model._generate_with_prompt_cache(
target_text=text,
prompt_cache=fixed_prompt_cache,
min_len=min_len,
max_len=max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
streaming=streaming,
)
target_text=text,
prompt_cache=fixed_prompt_cache,
min_len=min_len,
max_len=max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
streaming=streaming,
)
for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy()
finally:
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
try:
os.unlink(temp_prompt_wav_path)
except OSError:
pass
for tmp_path in temp_files:
if tmp_path and os.path.exists(tmp_path):
try:
os.unlink(tmp_path)
except OSError:
pass
# ------------------------------------------------------------------ #
# LoRA Interface (delegated to VoxCPMModel)
# ------------------------------------------------------------------ #
def load_lora(self, lora_weights_path: str) -> tuple:
"""Load LoRA weights from a checkpoint file.
Args:
lora_weights_path: Path to LoRA weights (.pth file or directory
containing lora_weights.ckpt).
Returns:
tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
Raises:
RuntimeError: If model was not initialized with LoRA config.
"""
@@ -259,23 +310,23 @@ class VoxCPM:
def unload_lora(self):
"""Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
self.tts_model.reset_lora_weights()
def set_lora_enabled(self, enabled: bool):
"""Enable or disable LoRA layers without unloading weights.
Args:
enabled: If True, LoRA layers are active; if False, only base model is used.
"""
self.tts_model.set_lora_enabled(enabled)
def get_lora_state_dict(self) -> dict:
"""Get current LoRA parameters state dict.
Returns:
dict: State dict containing all LoRA parameters (lora_A, lora_B).
"""
return self.tts_model.get_lora_state_dict()
@property
def lora_enabled(self) -> bool:
"""Check if LoRA is currently configured."""
+2 -1
View File
@@ -1,3 +1,4 @@
from .voxcpm import VoxCPMModel
from .voxcpm2 import VoxCPM2Model
__all__ = ["VoxCPMModel"]
__all__ = ["VoxCPMModel", "VoxCPM2Model"]
+18 -19
View File
@@ -5,17 +5,17 @@ from transformers import PreTrainedTokenizer
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
This function creates a wrapper around the provided tokenizer that automatically
splits multi-character Chinese tokens into individual characters. This is useful
for ensuring consistent tokenization of Chinese text.
Args:
tokenizer: The base tokenizer to wrap
Returns:
A CharTokenizerWrapper instance that handles multi-character Chinese tokens
Example:
>>> from transformers import LlamaTokenizerFast
>>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
@@ -24,20 +24,19 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
multichar_tokens = {
token for token in tokenizer.vocab.keys()
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
token for token in tokenizer.vocab.keys() if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
}
class CharTokenizerWrapper:
"""Wrapper class for tokenizers that handles multi-character Chinese tokens.
This wrapper automatically splits multi-character Chinese tokens into
individual characters while preserving the original tokenizer's interface.
"""
def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
"""Initialize the wrapper with a base tokenizer.
Args:
base_tokenizer: The tokenizer to wrap
"""
@@ -46,14 +45,14 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
def tokenize(self, text: str, **kwargs) -> List[str]:
"""Tokenize text and split multi-character Chinese tokens into single characters.
Args:
text: Input text to tokenize
**kwargs: Additional arguments passed to the base tokenizer
Returns:
List of processed tokens with multi-character Chinese tokens split
Example:
>>> wrapper = CharTokenizerWrapper(tokenizer)
>>> tokens = wrapper.tokenize("你好世界")
@@ -61,10 +60,10 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""
if not isinstance(text, str):
raise TypeError(f"Expected string input, got {type(text)}")
tokens = self.tokenizer.tokenize(text, **kwargs)
processed = []
for token in tokens:
# Remove possible subword prefix
clean_token = token.replace("", "")
@@ -75,22 +74,22 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
processed.extend(chars)
else:
processed.append(token)
return processed
def __call__(self, text: str, **kwargs) -> List[int]:
"""Call the tokenizer and return token IDs.
This method provides the same interface as the original tokenizer
but with multi-character Chinese token handling.
Args:
text: Input text to tokenize
**kwargs: Additional arguments passed to the base tokenizer
Returns:
List of token IDs
Raises:
TypeError: If input is not a string
ValueError: If tokenization fails
+128 -115
View File
@@ -24,7 +24,6 @@ from typing import Tuple, Union, Generator, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import warnings
from einops import rearrange
@@ -32,6 +31,7 @@ from pydantic import BaseModel
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
@@ -84,9 +84,9 @@ class VoxCPMConfig(BaseModel):
class LoRAConfig(BaseModel):
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
r: int = 8
alpha: int = 16
@@ -165,10 +165,10 @@ class VoxCPMModel(nn.Module):
# Projection layers
self.fsq_layer = ScalarQuantizationLayer(
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale,
)
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
@@ -196,9 +196,7 @@ class VoxCPMModel(nn.Module):
# LM: base_lm + residual_lm
if cfg.enable_lm:
for lm in [self.base_lm, self.residual_lm]:
apply_lora_to_named_linear_modules(
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
)
apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)
# DiT: feat_decoder.estimator
if cfg.enable_dit:
@@ -209,6 +207,7 @@ class VoxCPMModel(nn.Module):
# 投影层
if cfg.enable_proj:
from ..modules.layers.lora import LoRALinear
for attr_name in cfg.target_proj_modules:
module = getattr(self, attr_name, None)
if isinstance(module, nn.Linear):
@@ -221,13 +220,17 @@ class VoxCPMModel(nn.Module):
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
import triton
import triton # noqa: F401
except ImportError:
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
)
except Exception as e:
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
return self
@@ -313,9 +316,11 @@ class VoxCPMModel(nn.Module):
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10,
n_timesteps=(
self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10
),
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
@@ -331,7 +336,6 @@ class VoxCPMModel(nn.Module):
def _dtype(self):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
@@ -350,7 +354,7 @@ class VoxCPMModel(nn.Module):
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
streaming: bool = False,
) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming:
@@ -394,7 +398,7 @@ class VoxCPMModel(nn.Module):
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
@@ -435,7 +439,7 @@ class VoxCPMModel(nn.Module):
audio_mask = audio_mask.unsqueeze(0).to(self.device)
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times:
inference_result = self._inference(
@@ -444,7 +448,9 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -460,18 +466,21 @@ class VoxCPMModel(nn.Module):
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1
continue
else:
break
else:
break
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
yield decode_audio
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
yield decode_audio
@torch.inference_mode()
def build_prompt_cache(
self,
@@ -480,11 +489,11 @@ class VoxCPMModel(nn.Module):
):
"""
Build prompt cache for subsequent fast generation.
Args:
prompt_text: prompt text (required)
prompt_wav_path: prompt audio path (required)
Returns:
prompt_cache: dict with prompt_text (raw text) and audio features.
Text tokenization will be done during generation for consistency.
@@ -496,7 +505,7 @@ class VoxCPMModel(nn.Module):
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
@@ -514,16 +523,17 @@ class VoxCPMModel(nn.Module):
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0) # (D, T, P)
).permute(
1, 2, 0
) # (D, T, P)
# build prompt cache - only save raw text and audio features
prompt_cache = {
"prompt_text": prompt_text,
"audio_feat": audio_feat,
}
return prompt_cache
def merge_prompt_cache(
self,
original_cache: dict,
@@ -532,12 +542,12 @@ class VoxCPMModel(nn.Module):
):
"""
Merge original prompt cache with newly generated content to stabilize voice.
Args:
original_cache: original prompt cache
new_text: newly generated text
new_text: newly generated text
new_audio_feat: newly generated audio features
Returns:
merged_cache: merged cache with prompt_text and audio_feat
"""
@@ -557,20 +567,17 @@ class VoxCPMModel(nn.Module):
"prompt_text": merged_prompt_text,
"audio_feat": merged_audio_feat,
}
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
@torch.inference_mode()
def _generate_with_prompt_cache(
self,
@@ -588,7 +595,7 @@ class VoxCPMModel(nn.Module):
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""
Generate audio using pre-built prompt cache.
Args:
target_text: Text to convert to speech
prompt_cache: Cache built by build_prompt_cache (can be None)
@@ -601,7 +608,7 @@ class VoxCPMModel(nn.Module):
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
streaming: Whether to return a generator of audio chunks
streaming_prefix_len: Number of prefix audio patches to use for streaming mode
Returns:
Generator of Tuple containing:
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
@@ -619,7 +626,7 @@ class VoxCPMModel(nn.Module):
prompt_audio_feat = prompt_cache["audio_feat"]
prompt_text = prompt_cache["prompt_text"]
text = prompt_text + target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat(
[
@@ -632,7 +639,7 @@ class VoxCPMModel(nn.Module):
],
dim=-1,
)
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
audio_length = prompt_audio_feat.size(0)
@@ -645,14 +652,18 @@ class VoxCPMModel(nn.Module):
)
text_token = torch.cat([text_token, text_pad_token])
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
)
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
audio_mask = audio_mask.unsqueeze(0).to(self.device)
# run inference
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
@@ -663,7 +674,9 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -674,17 +687,16 @@ class VoxCPMModel(nn.Module):
for latent_pred, pred_audio_feat in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1
continue
else:
@@ -695,18 +707,14 @@ class VoxCPMModel(nn.Module):
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size
if audio_mask.sum().item() > 0:
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
else:
decode_audio = decode_audio[..., :].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
@@ -725,10 +733,10 @@ class VoxCPMModel(nn.Module):
streaming_prefix_len: int = 3,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""Core inference method for audio generation.
This is the main inference loop that generates audio features
using the language model and diffusion transformer.
Args:
text: Input text tokens
text_mask: Mask for text tokens
@@ -739,7 +747,7 @@ class VoxCPMModel(nn.Module):
inference_timesteps: Number of diffusion steps
cfg_value: Classifier-free guidance value
streaming: Whether to yield each step latent feature or just the final result
Returns:
Generator of Tuple containing:
- Predicted latent feature at the current step if ``streaming=True``, else final latent features
@@ -749,12 +757,12 @@ class VoxCPMModel(nn.Module):
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup:
scale_emb = self.config.lm_config.scale_emb
else:
scale_emb = 1.0
text_embed = self.base_lm.embed_tokens(text) * scale_emb
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
@@ -778,11 +786,10 @@ class VoxCPMModel(nn.Module):
is_causal=True,
)
self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = enc_outputs[:, -1, :]
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
is_causal=True,
@@ -790,7 +797,6 @@ class VoxCPMModel(nn.Module):
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
residual_hidden = residual_enc_outputs[:, -1, :]
for i in tqdm(range(max_len)):
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
@@ -805,10 +811,10 @@ class VoxCPMModel(nn.Module):
).transpose(
1, 2
) # [b, p, d]
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.enc_to_lm_proj(curr_embed)
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
prefix_feat_cond = pred_feat
@@ -816,58 +822,70 @@ class VoxCPMModel(nn.Module):
# return the last three predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1:
break
lm_hidden = self.base_lm.forward_step(
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
).clone()
lm_hidden = self.fsq_layer(lm_hidden)
residual_hidden = self.residual_lm.forward_step(
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
lm_hidden + curr_embed[:, 0, :],
torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
).clone()
if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, 'audio_vae_config', None)
audio_vae_config = getattr(config, "audio_vae_config", None)
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
vae_state_dict = torch.load(
os.path.join(path, "audiovae.pth"),
map_location="cpu",
weights_only=True,
)["state_dict"]
# Try to load AudioVAE from safetensors first, fallback to pytorch
audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
audiovae_pth_path = os.path.join(path, "audiovae.pth")
if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
elif os.path.exists(audiovae_pth_path):
print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
checkpoint = torch.load(
audiovae_pth_path,
map_location="cpu",
weights_only=True,
)
vae_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
else: # training mode
else: # training mode
for name, param in model.named_parameters():
if "audio_vae" in name: # freeze VAE weights
if "audio_vae" in name: # freeze VAE weights
param.requires_grad = False
continue
if lora_config is not None:
if "lora" not in name: # freeze non-LoRA weights
if "lora" not in name: # freeze non-LoRA weights
param.requires_grad = False
model.audio_vae = model.audio_vae.to(torch.float32)
# Try to load from safetensors first, fallback to pytorch_model.bin
safetensors_path = os.path.join(path, "model.safetensors")
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
model_state_dict = load_file(safetensors_path)
@@ -880,13 +898,11 @@ class VoxCPMModel(nn.Module):
)
model_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
)
raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val
# LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
model.load_state_dict(model_state_dict, strict=False)
@@ -900,6 +916,7 @@ class VoxCPMModel(nn.Module):
def _iter_lora_modules(self):
"""Iterate over all LoRA modules."""
from ..modules.layers.lora import LoRALinear
for module in self.modules():
if isinstance(module, LoRALinear):
yield module
@@ -909,7 +926,7 @@ class VoxCPMModel(nn.Module):
Load LoRA weights from file, supports calling after torch.compile.
Uses named_parameters() to handle compile's _orig_mod wrapper.
Supports both safetensors and pytorch formats.
Args:
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
device: Target device, defaults to model's current device
@@ -917,18 +934,18 @@ class VoxCPMModel(nn.Module):
tuple: (loaded_keys, skipped_keys)
"""
from pathlib import Path
device = device or self.device
lora_path = Path(lora_path)
lora_p = Path(lora_path)
# Try safetensors first, then fallback to .ckpt
if lora_path.is_dir():
safetensors_file = lora_path / "lora_weights.safetensors"
ckpt_file = lora_path / "lora_weights.ckpt"
if lora_p.is_dir():
safetensors_file = lora_p / "lora_weights.safetensors"
ckpt_file = lora_p / "lora_weights.ckpt"
else:
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
# Load from safetensors if available
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
@@ -936,14 +953,12 @@ class VoxCPMModel(nn.Module):
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
)
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
# Build param mapping (handle torch.compile's _orig_mod prefix)
model_params = dict(self.named_parameters())
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
loaded_keys, skipped_keys = [], []
for key, value in state_dict.items():
target_key = key if key in model_params else key_mapping.get(key)
@@ -952,7 +967,7 @@ class VoxCPMModel(nn.Module):
loaded_keys.append(key)
else:
skipped_keys.append(key)
return loaded_keys, skipped_keys
def set_lora_enabled(self, enabled: bool):
@@ -967,6 +982,4 @@ class VoxCPMModel(nn.Module):
def get_lora_state_dict(self) -> dict:
"""Get all LoRA parameters (lora_A/lora_B)."""
return {name: param.data.clone()
for name, param in self.named_parameters()
if "lora_" in name}
return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
File diff suppressed because it is too large Load Diff
+1
View File
@@ -1 +1,2 @@
from .audio_vae import AudioVAE, AudioVAEConfig
from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
+4 -4
View File
@@ -1,5 +1,5 @@
import math
from typing import List, Union, Optional
from typing import List
import numpy as np
import torch
@@ -285,12 +285,12 @@ class AudioVAE(nn.Module):
def __init__(
self,
config: Optional[AudioVAEConfig] = None,
config: AudioVAEConfig = None,
):
# 如果没有传入config,使用默认配置
if config is None:
config = AudioVAEConfig()
super().__init__()
encoder_dim = config.encoder_dim
@@ -301,7 +301,7 @@ class AudioVAE(nn.Module):
depthwise = config.depthwise
sample_rate = config.sample_rate
use_noise_block = config.use_noise_block
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
+486
View File
@@ -0,0 +1,486 @@
import math
from typing import List, Optional
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from pydantic import BaseModel
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class CausalConv1d(nn.Conv1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
return super().forward(x_pad)
class CausalTransposeConv1d(nn.ConvTranspose1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
def WNCausalConv1d(*args, **kwargs):
return weight_norm(CausalConv1d(*args, **kwargs))
def WNCausalTransposeConv1d(*args, **kwargs):
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class CausalResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNCausalConv1d(
dim,
dim,
kernel_size=kernel,
dilation=dilation,
padding=pad,
groups=groups,
),
Snake1d(dim),
WNCausalConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
assert pad == 0
if pad > 0:
x = x[..., pad:-pad]
return x + y
class CausalEncoderBlock(nn.Module):
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
super().__init__()
input_dim = input_dim or output_dim // 2
self.block = nn.Sequential(
CausalResidualUnit(input_dim, dilation=1, groups=groups),
CausalResidualUnit(input_dim, dilation=3, groups=groups),
CausalResidualUnit(input_dim, dilation=9, groups=groups),
Snake1d(input_dim),
WNCausalConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
)
def forward(self, x):
return self.block(x)
class CausalEncoder(nn.Module):
def __init__(
self,
d_model: int = 64,
latent_dim: int = 32,
strides: list = [2, 4, 8, 8],
depthwise: bool = False,
):
super().__init__()
# Create first convolution
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
groups = d_model // 2 if depthwise else 1
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
groups = d_model if depthwise else 1
# Create two convolution, for mu and logvar
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
hidden_state = self.block(x)
return {
"hidden_state": hidden_state,
"mu": self.fc_mu(hidden_state),
"logvar": self.fc_logvar(hidden_state),
}
class NoiseBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
def forward(self, x):
B, C, T = x.shape
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
h = self.linear(x)
n = noise * h
x = x + n
return x
class CausalDecoderBlock(nn.Module):
def __init__(
self,
input_dim: int = 16,
output_dim: int = 8,
stride: int = 1,
groups=1,
use_noise_block: bool = False,
):
super().__init__()
layers = [
Snake1d(input_dim),
WNCausalTransposeConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
]
if use_noise_block:
layers.append(NoiseBlock(output_dim))
layers.extend(
[
CausalResidualUnit(output_dim, dilation=1, groups=groups),
CausalResidualUnit(output_dim, dilation=3, groups=groups),
CausalResidualUnit(output_dim, dilation=9, groups=groups),
]
)
self.block = nn.Sequential(*layers)
self.input_channels = input_dim
def forward(self, x):
return self.block(x)
class TransposeLastTwoDim(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, -1, -2)
class SampleRateConditionLayer(nn.Module):
def __init__(
self,
input_dim: int,
sr_bin_buckets: int = None,
cond_type: str = "scale_bias",
cond_dim: int = 128,
out_layer: bool = False,
):
super().__init__()
self.cond_type, out_layer_in_dim = cond_type, input_dim
if cond_type == "scale_bias":
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.ones_(self.scale_embed.weight)
nn.init.zeros_(self.bias_embed.weight)
elif cond_type == "scale_bias_init":
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.normal_(self.scale_embed.weight, mean=1)
nn.init.normal_(self.bias_embed.weight)
elif cond_type == "add":
self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.normal_(self.cond_embed.weight)
elif cond_type == "concat":
self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
assert out_layer, "out_layer must be True for concat cond_type"
out_layer_in_dim = input_dim + cond_dim
else:
raise ValueError(f"Invalid cond_type: {cond_type}")
if out_layer:
self.out_layer = nn.Sequential(
Snake1d(out_layer_in_dim),
WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
)
else:
self.out_layer = nn.Identity()
def forward(self, x, sr_cond):
if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
elif self.cond_type == "add":
x = x + self.cond_embed(sr_cond).unsqueeze(-1)
elif self.cond_type == "concat":
x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
return self.out_layer(x)
class CausalDecoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
depthwise: bool = False,
d_out: int = 1,
use_noise_block: bool = False,
sr_bin_boundaries: List[int] = None,
cond_type: str = "scale_bias",
cond_dim: int = 128,
cond_out_layer: bool = False,
):
super().__init__()
# Add first conv layer
if depthwise:
layers = [
WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
WNCausalConv1d(input_channel, channels, kernel_size=1),
]
else:
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
groups = output_dim if depthwise else 1
layers += [
CausalDecoderBlock(
input_dim,
output_dim,
stride,
groups=groups,
use_noise_block=use_noise_block,
)
]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
if sr_bin_boundaries is None:
self.model = nn.Sequential(*layers)
self.sr_bin_boundaries = None
else:
self.model = nn.ModuleList(layers)
self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
self.sr_bin_buckets = len(sr_bin_boundaries) + 1
cond_layers = []
for layer in self.model:
if layer.__class__.__name__ == "CausalDecoderBlock":
cond_layers.append(
SampleRateConditionLayer(
input_dim=layer.input_channels,
sr_bin_buckets=self.sr_bin_buckets,
cond_type=cond_type,
cond_dim=cond_dim,
out_layer=cond_out_layer,
)
)
else:
cond_layers.append(None)
self.sr_cond_model = nn.ModuleList(cond_layers)
def get_sr_idx(self, sr):
return torch.bucketize(sr, self.sr_bin_boundaries)
def forward(self, x, sr_cond=None):
if self.sr_bin_boundaries is not None:
# assert sr_cond is not None
sr_cond = self.get_sr_idx(sr_cond)
for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
if sr_cond_layer is not None:
x = sr_cond_layer(x, sr_cond)
x = layer(x)
return x
else:
return self.model(x)
class AudioVAEConfig(BaseModel):
encoder_dim: int = 128
encoder_rates: List[int] = [2, 5, 8, 8]
latent_dim: int = 64
decoder_dim: int = 2048
decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
depthwise: bool = True
sample_rate: int = 16000
out_sample_rate: int = 48000
use_noise_block: bool = False
sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
cond_type: str = "scale_bias"
cond_dim: int = 128
cond_out_layer: bool = False
class AudioVAE(nn.Module):
"""
Args:
"""
def __init__(
self,
config: AudioVAEConfig = None,
):
# 如果没有传入config,使用默认配置
if config is None:
config = AudioVAEConfig()
super().__init__()
encoder_dim = config.encoder_dim
encoder_rates = config.encoder_rates
latent_dim = config.latent_dim
decoder_dim = config.decoder_dim
decoder_rates = config.decoder_rates
depthwise = config.depthwise
sample_rate = config.sample_rate
out_sample_rate = config.out_sample_rate
use_noise_block = config.use_noise_block
sr_bin_boundaries = config.sr_bin_boundaries
cond_type = config.cond_type
cond_dim = config.cond_dim
cond_out_layer = config.cond_out_layer
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.depthwise = depthwise
self.use_noise_block = use_noise_block
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = CausalEncoder(
encoder_dim,
latent_dim,
encoder_rates,
depthwise=depthwise,
)
self.decoder = CausalDecoder(
latent_dim,
decoder_dim,
decoder_rates,
depthwise=depthwise,
use_noise_block=use_noise_block,
sr_bin_boundaries=sr_bin_boundaries,
cond_type=cond_type,
cond_dim=cond_dim,
cond_out_layer=cond_out_layer,
)
self.sample_rate = sample_rate
self.out_sample_rate = out_sample_rate
self.sr_bin_boundaries = sr_bin_boundaries
self.chunk_size = math.prod(encoder_rates)
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
pad_to = self.hop_length
length = audio_data.shape[-1]
right_pad = math.ceil(length / pad_to) * pad_to - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
if self.sr_bin_boundaries is not None:
# use default output sample rate
if sr_cond is None:
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
return self.decoder(z, sr_cond)
def encode(self, audio_data: torch.Tensor, sample_rate: int):
"""
Args:
audio_data: Tensor[B x 1 x T]
sample_rate: int
Returns:
z: Tensor[B x D x T]
"""
if audio_data.ndim == 2:
audio_data = audio_data.unsqueeze(1)
audio_data = self.preprocess(audio_data, sample_rate)
return self.encoder(audio_data)["mu"]
+1 -1
View File
@@ -1 +1 @@
from .scalar_quantization_layer import ScalarQuantizationLayer
from .scalar_quantization_layer import ScalarQuantizationLayer
+1 -4
View File
@@ -34,7 +34,7 @@ class LoRALinear(nn.Module):
self.r = r
self.alpha = alpha
self._base_scaling = alpha / r if r > 0 else 0.0
# 使用 buffer 存储 scaling,这样修改值不会触发 torch.compile 重编译
# persistent=False 表示不保存到 state_dict,避免加载时 missing key
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
@@ -128,6 +128,3 @@ def apply_lora_to_named_linear_modules(
dropout=dropout,
)
setattr(parent, short_name, lora_layer)
@@ -12,7 +12,7 @@ class ScalarQuantizationLayer(nn.Module):
self.in_proj = nn.Linear(in_dim, latent_dim)
self.out_proj = nn.Linear(latent_dim, out_dim)
def forward(self, hidden):
hidden = self.in_proj(hidden)
hidden = torch.tanh(hidden)
@@ -23,4 +23,4 @@ class ScalarQuantizationLayer(nn.Module):
else:
hidden = torch.round(hidden * self.scale) / self.scale
return self.out_proj(hidden)
return self.out_proj(hidden)
+1
View File
@@ -1,2 +1,3 @@
from .unified_cfm import UnifiedCFM, CfmConfig
from .local_dit import VoxCPMLocDiT
from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
+116
View File
@@ -0,0 +1,116 @@
import torch
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
import torch.nn as nn
import math
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
out_dim: int = None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class VoxCPMLocDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
config: MiniCPM4Config,
in_channels: int = 64,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.config = config
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
self.time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
self.delta_time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
self.decoder = MiniCPMModel(config)
def forward(
self,
x: torch.Tensor,
mu: torch.Tensor,
t: torch.Tensor,
cond: torch.Tensor,
dt: torch.Tensor,
):
"""
Forward pass of DiT.
x: (N, C, T) tensor of inputs
mu: (N, C) tensor of hidden embedding
t: (N,) tensor of diffusion timesteps
cond: (N, C, T') tensor of prefix conditions
dt: (N,) used for mean velocity (may be supported in the future...)
"""
x = self.in_proj(x.transpose(1, 2).contiguous())
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
prefix = cond.size(1)
t = self.time_embeddings(t).to(x.dtype)
t = self.time_mlp(t)
dt = self.time_embeddings(dt).to(x.dtype)
dt = self.delta_time_mlp(dt)
t = t + dt
mu = mu.view(x.size(0), -1, x.size(-1))
x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)
hidden, _ = self.decoder(x, is_causal=False)
hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
hidden = self.out_proj(hidden)
return hidden.transpose(1, 2).contiguous()
+8 -7
View File
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Tuple
import torch
import torch.nn.functional as F
@@ -56,7 +56,7 @@ class UnifiedCFM(torch.nn.Module):
cond: torch.Tensor,
temperature: float = 1.0,
cfg_value: float = 1.0,
sway_sampling_coef: float = 1.0,
sway_sampling_coef: float = 1.0,
use_cfg_zero_star: bool = True,
):
b, _ = mu.shape
@@ -116,7 +116,7 @@ class UnifiedCFM(torch.nn.Module):
dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
if use_cfg_zero_star:
positive_flat = dphi_dt.view(b, -1)
negative_flat = cfg_dphi_dt.view(b, -1)
@@ -124,7 +124,7 @@ class UnifiedCFM(torch.nn.Module):
st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
else:
st_star = 1.0
dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)
x = x - dt * dphi_dt
@@ -138,7 +138,9 @@ class UnifiedCFM(torch.nn.Module):
# ------------------------------------------------------------------ #
# Training loss
# ------------------------------------------------------------------ #
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
def adaptive_loss_weighting(
self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
):
weights = 1.0 / ((losses + epsilon).pow(p))
if mask is not None:
weights = weights * mask
@@ -193,8 +195,7 @@ class UnifiedCFM(torch.nn.Module):
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
ratio_r_neq_t = (
self.ratio_r_neq_t_range[0]
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
if self.mean_mode
else 0.0
)
+1 -1
View File
@@ -26,4 +26,4 @@ class MiniCPM4Config(BaseModel):
dim_model_base: int
scale_depth: float
rope_theta: float
kv_channels: int = None
kv_channels: int = None
+13 -14
View File
@@ -64,10 +64,8 @@ class MiniCPMLongRoPE(nn.Module):
self.long_factor = config.rope_scaling.long_factor
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
self.scaling_factor = math.sqrt(
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
)
scale = self.max_position_embeddings / self.original_max_position_embeddings
self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -76,11 +74,7 @@ class MiniCPMLongRoPE(nn.Module):
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
self._set_cos_sin_cache(
seq_len=self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32
)
self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)
def _set_cos_sin_cache(self, seq_len, device, dtype):
"""设置cos和sin缓存"""
@@ -93,8 +87,7 @@ class MiniCPMLongRoPE(nn.Module):
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
freqs = torch.mul(
torch.outer(t, 1.0 / ext_factors).to(device=device),
self.inv_freq.to(device=device).to(dtype)
torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
)
# 创建embeddings
@@ -123,7 +116,9 @@ class MiniCPMAttention(nn.Module):
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
self.head_dim = (
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
@@ -153,7 +148,7 @@ class MiniCPMAttention(nn.Module):
cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# ref: https://github.com/pytorch/pytorch/issues/163597
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
query_states = query_states.contiguous()
@@ -413,7 +408,11 @@ class MiniCPMModel(nn.Module):
self.kv_cache = StaticKVCache(
num_layers=self.config.num_hidden_layers,
num_kv_heads=self.config.num_key_value_heads,
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
dim_kv_head=(
self.config.hidden_size // self.config.num_attention_heads
if self.config.kv_channels is None
else self.config.kv_channels
),
batch_size=batch_size,
device=device,
dtype=dtype,
-1
View File
@@ -25,4 +25,3 @@ __all__ = [
"load_audio_text_datasets",
"build_dataloader",
]
+2 -5
View File
@@ -47,9 +47,7 @@ class Accelerator:
pass
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
self.device_ctx = (
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
self._ddp_model = None # For no_sync support
def _set_seed(self, seed: int):
@@ -84,7 +82,7 @@ class Accelerator:
# Model helpers
# ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs):
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
model.device = self.device
model = model.to(self.device)
if self.world_size > 1:
@@ -163,4 +161,3 @@ class Accelerator:
@staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model
-2
View File
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args)
return cli_args
+4 -6
View File
@@ -1,5 +1,4 @@
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import argbind
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
from ..modules.audiovae import AudioVAE
from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_ID_COLUMN = "dataset_id"
@@ -36,7 +34,7 @@ def load_audio_text_datasets(
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
@@ -70,13 +68,13 @@ def compute_sample_lengths(
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
t_seq = ceil(t_vae / patch_size)
- 序列总长约为: text_len + t_seq + 2
Optimized: Use batch column access instead of iterating item by item.
"""
# Batch access columns - much faster than per-item access
text_ids_list = ds["text_ids"]
text_lens = [len(t) for t in text_ids_list]
has_duration = "duration" in ds.column_names
if has_duration:
durations = ds["duration"]
@@ -86,7 +84,7 @@ def compute_sample_lengths(
for i in range(len(ds)):
audio = ds[i][DEFAULT_AUDIO_COLUMN]
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
# Vectorized length computation
lengths = []
for text_len, duration in zip(text_lens, durations):
+29 -22
View File
@@ -1,5 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List
import torch
import torch.nn as nn
@@ -15,7 +14,7 @@ class AudioFeatureProcessingPacker:
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101
self.audio_end_id = 102
# unused now
# unused now
self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104
self.text_eos_token_id = 2
@@ -147,31 +146,26 @@ class AudioFeatureProcessingPacker:
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
if x.size(0) >= max_len:
return x[: max_len]
return x[:max_len]
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
def pad_3d(x: torch.Tensor) -> torch.Tensor:
# x: [T, P, D]
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.zeros(
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
)
return x[:max_len]
pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
if lengths:
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
audio_task_ids_batch = torch.stack(
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
)
audio_dataset_ids_batch = torch.stack(
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
)
labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
position_ids_list = []
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
)
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
text_token.device
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
.type(torch.int32)
.to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
.type(torch.int32)
.to(text_token.device)
)
loss_mask = (
torch.cat(
[
torch.zeros(text_length),
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
torch.zeros(1),
]
)
.type(torch.int32)
.to(text_token.device)
)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
torch.int32
).to(text_token.device)
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
labels[-2] = 1
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
audio_duration,
text_token_count,
)
-1
View File
@@ -18,4 +18,3 @@ class TrainingState:
val_loader: object
tracker: object
batch_processor: object
-1
View File
@@ -76,4 +76,3 @@ class TrainingTracker:
@contextlib.contextmanager
def live(self):
yield
+34 -31
View File
@@ -2,10 +2,10 @@
import re
import regex
import inflect
from functools import partial
from wetext import Normalizer
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
# whether contain chinese character
def contains_chinese(text):
@@ -14,19 +14,19 @@ def contains_chinese(text):
# replace special symbol
def replace_corner_mark(text):
text = text.replace('²', '平方')
text = text.replace('³', '立方')
text = text.replace('', '根号')
text = text.replace('', '约等于')
text = text.replace('<', '小于')
text = text.replace("²", "平方")
text = text.replace("³", "立方")
text = text.replace("", "根号")
text = text.replace("", "约等于")
text = text.replace("<", "小于")
return text
# remove meaningless symbol
def remove_bracket(text):
text = text.replace('', ' ').replace('', ' ')
text = text.replace('', ' ').replace('', ' ')
text = text.replace('`', '').replace('`', '')
text = text.replace("", " ").replace("", " ")
text = text.replace("", " ").replace("", " ")
text = text.replace("`", "").replace("`", "")
text = text.replace("——", " ")
return text
@@ -38,7 +38,7 @@ def spell_out_number(text: str, inflect_parser):
for i, c in enumerate(text):
if not c.isdigit():
if st is not None:
num_str = inflect_parser.number_to_words(text[st: i])
num_str = inflect_parser.number_to_words(text[st:i])
new_text.append(num_str)
st = None
new_text.append(c)
@@ -48,7 +48,7 @@ def spell_out_number(text: str, inflect_parser):
if st is not None and st < len(text):
num_str = inflect_parser.number_to_words(text[st:])
new_text.append(num_str)
return ''.join(new_text)
return "".join(new_text)
# split paragrah logic
@@ -69,18 +69,18 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
return len(tokenize(_text)) < merge_len
if lang == "zh":
pounc = ['', '', '', '', '', '', '.', '?', '!', ';']
pounc = ["", "", "", "", "", "", ".", "?", "!", ";"]
else:
pounc = ['.', '?', '!', ';', ':']
pounc = [".", "?", "!", ";", ":"]
if comma_split:
pounc.extend(['', ','])
pounc.extend(["", ","])
st = 0
utts = []
for i, c in enumerate(text):
if c in pounc:
if len(text[st: i]) > 0:
utts.append(text[st: i] + c)
if i + 1 < len(text) and text[i + 1] in ['"', '']:
if len(text[st:i]) > 0:
utts.append(text[st:i] + c)
if i + 1 < len(text) and text[i + 1] in ['"', ""]:
tmp = utts.pop(-1)
utts.append(tmp + text[i + 1])
st = i + 2
@@ -88,9 +88,9 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
st = i + 1
if len(utts) == 0:
if lang == "zh":
utts.append(text + '')
utts.append(text + "")
else:
utts.append(text + '.')
utts.append(text + ".")
final_utts = []
cur_utt = ""
for utt in utts:
@@ -112,13 +112,13 @@ def replace_blank(text: str):
out_str = []
for i, c in enumerate(text):
if c == " ":
if ((text[i + 1].isascii() and text[i + 1] != " ") and
(text[i - 1].isascii() and text[i - 1] != " ")):
if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
out_str.append(c)
else:
out_str.append(c)
return "".join(out_str)
def clean_markdown(md_text: str) -> str:
# 去除代码块 ``` ```(包括多行)
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
@@ -131,9 +131,9 @@ def clean_markdown(md_text: str) -> str:
# 去除链接但保留文本 [text](url) -> text
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
# 替换无序列表符号
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
md_text = re.sub(r"^(\s*)-\s+", r"\1", md_text, flags=re.MULTILINE)
# 去除HTML标签
md_text = re.sub(r"<[^>]+>", "", md_text)
@@ -152,28 +152,31 @@ def clean_text(text):
# 去除 Markdown 语法
text = clean_markdown(text)
# 匹配并移除表情符号
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
text = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE).sub("", text)
# 去除换行符
text = text.replace("\n", " ")
text = text.replace("\t", " ")
text = text.replace('"', "\")
text = text.replace("", '"').replace("", '"')
return text
class TextNormalizer:
def __init__(self, tokenizer=None):
self.tokenizer = tokenizer
self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
self.en_tn_model = Normalizer(lang="en", operator="tn")
self.inflect_parser = inflect.engine()
def normalize(self, text, split=False):
# 去除 Markdown 语法,去除表情符号,去除换行符
lang = "zh" if contains_chinese(text) else "en"
text = clean_text(text)
if lang == "zh":
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
text = text.replace(
"=", "等于"
) # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text): # 避免 英文连字符被错误正则为减
text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text) # 修复 x-2 被正则为 x负2
text = self.zh_tn_model.normalize(text)
text = replace_blank(text)
text = replace_corner_mark(text)
@@ -182,4 +185,4 @@ class TextNormalizer:
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
if split is False:
return text
return text
+10 -14
View File
@@ -7,15 +7,15 @@ Related dependencies are imported only when denoising functionality is needed.
import os
import tempfile
from typing import Optional, Union
from typing import Optional
import torchaudio
import torch
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
class ZipEnhancer:
"""ZipEnhancer Audio Denoising Enhancer"""
def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
"""
Initialize ZipEnhancer
@@ -23,25 +23,21 @@ class ZipEnhancer:
model_path: ModelScope model path or local path
"""
self.model_path = model_path
self._pipeline = pipeline(
Tasks.acoustic_noise_suppression,
model=self.model_path
)
self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)
def _normalize_loudness(self, wav_path: str):
"""
Audio loudness normalization
Args:
wav_path: Audio file path
"""
audio, sr = torchaudio.load(wav_path)
loudness = torchaudio.functional.loudness(audio, sr)
normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
torchaudio.save(wav_path, normalized_audio, sr)
def enhance(self, input_path: str, output_path: Optional[str] = None,
normalize_loudness: bool = True) -> str:
def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
"""
Audio denoising enhancement
Args:
@@ -57,7 +53,7 @@ class ZipEnhancer:
raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
# Create temporary file if no output path is specified
if output_path is None:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
output_path = tmp_file.name
try:
# Perform denoising processing
@@ -73,4 +69,4 @@ class ZipEnhancer:
os.unlink(output_path)
except OSError:
pass
raise RuntimeError(f"Audio denoising processing failed: {e}")
raise RuntimeError(f"Audio denoising processing failed: {e}")
Generated
+5263
View File
File diff suppressed because it is too large Load Diff