update voxcpm2

This commit is contained in:
刘鑫
2026-03-31 11:50:37 +08:00
parent 23ed7ffeee
commit d9cf376e16
36 changed files with 8163 additions and 834 deletions
+144 -85
View File
@@ -3,13 +3,14 @@ import sys
import numpy as np
import torch
import gradio as gr
import spaces
import spaces # noqa: F401
from typing import Optional, Tuple
from funasr import AutoModel
from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
import voxcpm
@@ -24,13 +25,13 @@ class VoxCPMDemo:
self.asr_model: Optional[AutoModel] = AutoModel(
model=self.asr_model_id,
disable_update=True,
log_level='DEBUG',
log_level="DEBUG",
device="cuda:0" if self.device == "cuda" else "cpu",
)
# TTS model (lazy init)
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.default_local_model_dir = "./models/VoxCPM1.5"
self.default_local_model_dir = "/Users/xinliu/Downloads/VoxCPM2-0.5B-newaudiovae-6hz-0316"
# ---------- Model helpers ----------
def _resolve_model_dir(self) -> str:
@@ -49,6 +50,7 @@ class VoxCPMDemo:
if not os.path.isdir(target_dir):
try:
from huggingface_hub import snapshot_download # type: ignore
os.makedirs(target_dir, exist_ok=True)
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
@@ -64,7 +66,7 @@ class VoxCPMDemo:
print("Model not loaded, initializing...", file=sys.stderr)
model_dir = self._resolve_model_dir()
print(f"Using model dir: {model_dir}", file=sys.stderr)
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=False)
print("Model loaded successfully.", file=sys.stderr)
return self.voxcpm_model
@@ -73,21 +75,24 @@ class VoxCPMDemo:
if prompt_wav is None:
return ""
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
text = res[0]["text"].split('|>')[-1]
text = res[0]["text"].split("|>")[-1]
return text
def generate_tts_audio(
self,
text_input: str,
prompt_wav_path_input: Optional[str] = None,
prompt_text_input: Optional[str] = None,
control_instruction: str = "",
reference_wav_path_input: Optional[str] = None,
cfg_value_input: float = 2.0,
inference_timesteps_input: int = 10,
do_normalize: bool = True,
denoise: bool = True,
) -> Tuple[int, np.ndarray]:
"""
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
Generate speech from text using VoxCPM.
- If reference_wav provided: Prompt isolation mode (voice cloning)
- If no reference_wav: Voice design mode (use control_instruction to describe voice)
Returns (sample_rate, waveform_numpy)
"""
current_model = self.get_or_load_voxcpm()
@@ -96,14 +101,25 @@ class VoxCPMDemo:
if len(text) == 0:
raise ValueError("Please input text to synthesize.")
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
prompt_text = prompt_text_input if prompt_text_input else None
# 处理 control instruction
control = (control_instruction or "").strip()
if control:
final_text = f"({control}){text}"
else:
final_text = text
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
reference_wav_path = reference_wav_path_input if reference_wav_path_input else None
# 判断模式
if reference_wav_path:
print(f"[Prompt Isolation Mode] reference_wav: {reference_wav_path}", file=sys.stderr)
else:
print(f"[Voice Design Mode] control: {control[:50] if control else 'None'}...", file=sys.stderr)
print(f"Generating audio for text: '{final_text[:80]}...'", file=sys.stderr)
wav = current_model.generate(
text=text,
prompt_text=prompt_text,
prompt_wav_path=prompt_wav_path,
text=final_text,
reference_wav_path=reference_wav_path,
cfg_value=float(cfg_value_input),
inference_timesteps=int(inference_timesteps_input),
normalize=do_normalize,
@@ -114,46 +130,53 @@ class VoxCPMDemo:
# ---------- UI Builders ----------
THEME = gr.themes.Soft(
primary_hue="blue",
secondary_hue="gray",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
)
CSS = """
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 80px;
width: auto;
max-width: 200px;
display: inline-block;
}
/* Bold accordion labels */
#acc_quick > .label-wrap,
#acc_tips > .label-wrap,
#acc_quick > .label-wrap > span,
#acc_tips > .label-wrap > span,
#acc_quick summary,
#acc_tips summary {
font-weight: 600 !important;
font-size: 1.1em !important;
}
/* Bold labels for specific checkboxes */
#chk_denoise label,
#chk_denoise span,
#chk_normalize label,
#chk_normalize span {
font-weight: 600;
}
"""
def create_demo_interface(demo: VoxCPMDemo):
"""Build the Gradio UI for VoxCPM demo."""
# static assets (logo path)
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
with gr.Blocks(
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="gray",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
),
css="""
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 80px;
width: auto;
max-width: 200px;
display: inline-block;
}
/* Bold accordion labels */
#acc_quick details > summary,
#acc_tips details > summary {
font-weight: 600 !important;
font-size: 1.1em !important;
}
/* Bold labels for specific checkboxes */
#chk_denoise label,
#chk_denoise span,
#chk_normalize label,
#chk_normalize span {
font-weight: 600;
}
"""
) as interface:
# Header logo
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
with gr.Blocks() as interface:
gr.HTML(
'<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>',
padding=True,
)
# Quick Start
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
@@ -200,34 +223,56 @@ def create_demo_interface(demo: VoxCPMDemo):
# Main controls
with gr.Row():
with gr.Column():
prompt_wav = gr.Audio(
sources=["upload", 'microphone'],
# 1. Reference Audio
# gr.Markdown("### 🎤 Reference Audio (Optional)")
# gr.Markdown("*提供参考音频进行音色克隆;不提供则使用 Voice Design 模式*")
reference_wav = gr.Audio(
sources=["upload", "microphone"],
type="filepath",
label="Prompt Speech (Optional, or let VoxCPM improvise)",
value="./examples/example.wav",
label="Reference Audio (Optional)",
)
DoDenoisePromptAudio = gr.Checkbox(
value=False,
label="Prompt Speech Enhancement",
label="Reference Audio Enhancement",
elem_id="chk_denoise",
info="We use ZipEnhancer model to denoise the prompt audio."
info="Use ZipEnhancer to denoise the reference audio",
)
with gr.Row():
prompt_text = gr.Textbox(
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
label="Prompt Text",
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
)
run_btn = gr.Button("Generate Speech", variant="primary")
# 2. Control Instruction
# gr.Markdown("### 🎛️ Control Instruction (Optional)")
# gr.Markdown("*描述声音风格、情感等,格式:`(instruction) text`*")
control_instruction = gr.Textbox(
value="",
label="Control Instruction",
placeholder="*描述声音风格、情感等,格式:`(instruction) text`,例如:年轻女性,温柔甜美 / 悲伤地说 / an excited young man*",
lines=2,
)
# 3. Target Text
# gr.Markdown("### 📝 Target Text")
text = gr.Textbox(
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
label="Target Text",
lines=3,
)
DoNormalizeText = gr.Checkbox(
value=False,
label="Text Normalization",
elem_id="chk_normalize",
info="Use wetext library to normalize the input text",
)
run_btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg")
with gr.Column():
gr.Markdown("### ⚙️ Generation Settings")
cfg_value = gr.Slider(
minimum=1.0,
maximum=3.0,
value=2.0,
step=0.1,
label="CFG Value (Guidance Scale)",
info="Higher values increase adherence to prompt, lower values allow more creativity"
info="Higher = more adherence to prompt; Lower = more creativity",
)
inference_timesteps = gr.Slider(
minimum=4,
@@ -235,40 +280,54 @@ def create_demo_interface(demo: VoxCPMDemo):
value=10,
step=1,
label="Inference Timesteps",
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
info="Higher = better quality but slower",
)
with gr.Row():
text = gr.Textbox(
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
label="Target Text",
)
with gr.Row():
DoNormalizeText = gr.Checkbox(
value=False,
label="Text Normalization",
elem_id="chk_normalize",
info="We use wetext library to normalize the input text."
)
audio_output = gr.Audio(label="Output Audio")
gr.Markdown("### 🔈 Output")
audio_output = gr.Audio(label="Generated Audio")
gr.Markdown("""
---
**模式说明 / Mode Info:**
- **有 Reference Audio** → Prompt 隔离模式(音色克隆)
- **无 Reference Audio** → Voice Design 模式(用 Control Instruction 描述声音)
**Control Instruction 示例:**
- `年轻女性,温柔甜美`
- `悲伤地说`
- `an excited young man`
""")
# Wiring
run_btn.click(
fn=demo.generate_tts_audio,
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
inputs=[
text,
control_instruction,
reference_wav,
cfg_value,
inference_timesteps,
DoNormalizeText,
DoDenoisePromptAudio,
],
outputs=[audio_output],
show_progress=True,
api_name="generate",
)
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
return interface
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
def run_demo(server_name: str = "0.0.0.0", server_port: int = 7869, show_error: bool = True):
demo = VoxCPMDemo()
interface = create_demo_interface(demo)
# Recommended to enable queue on Spaces for better throughput
interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)
interface.queue(max_size=10, default_concurrency_limit=1).launch(
server_name=server_name,
server_port=server_port,
show_error=show_error,
theme=THEME,
css=CSS,
)
if __name__ == "__main__":
+1 -1
View File
@@ -25,7 +25,7 @@ lora:
enable_lm: true
enable_dit: true
enable_proj: false
r: 32
r: 8
alpha: 16
dropout: 0.0
+1 -1
View File
@@ -25,7 +25,7 @@ lora:
enable_lm: true
enable_dit: true
enable_proj: false
r: 32
r: 8
alpha: 16
dropout: 0.0
Binary file not shown.
+210 -223
View File
@@ -1,18 +1,14 @@
import os
import sys
import time
import glob
import json
import yaml
import shutil
import datetime
import subprocess
import threading
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
from typing import Optional, List
from typing import Optional
# Add src to sys.path
project_root = Path(__file__).parent
@@ -89,7 +85,7 @@ LANG_DICT = {
"lang_select": "Language / 语言",
"refresh": "刷新",
"output_name": "输出目录名称 (可选,若存在则继续训练)",
}
},
}
# Global variables
@@ -98,9 +94,11 @@ asr_model: Optional[AutoModel] = None
training_process: Optional[subprocess.Popen] = None
training_log = ""
def get_timestamp_str():
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def get_or_load_asr_model():
global asr_model
if asr_model is None:
@@ -109,23 +107,25 @@ def get_or_load_asr_model():
asr_model = AutoModel(
model="iic/SenseVoiceSmall",
disable_update=True,
log_level='ERROR',
log_level="ERROR",
device=device,
)
return asr_model
def recognize_audio(audio_path):
if not audio_path:
return ""
try:
model = get_or_load_asr_model()
res = model.generate(input=audio_path, language="auto", use_itn=True)
text = res[0]["text"].split('|>')[-1]
text = res[0]["text"].split("|>")[-1]
return text
except Exception as e:
print(f"ASR Error: {e}", file=sys.stderr)
return ""
def scan_lora_checkpoints(root_dir="lora", with_info=False):
"""
Scans for LoRA checkpoints in the lora directory.
@@ -165,11 +165,12 @@ def scan_lora_checkpoints(root_dir="lora", with_info=False):
# Also check for checkpoints in the default location if they exist
default_ckpt = "checkpoints/finetune_lora"
if os.path.exists(os.path.join(root_dir, default_ckpt)):
# This might be covered by the walk, but good to be sure
pass
# This might be covered by the walk, but good to be sure
pass
return sorted(checkpoints, reverse=True)
def load_lora_config_from_checkpoint(lora_path):
"""Load LoRA config from lora_config.json if available."""
lora_config_file = os.path.join(lora_path, "lora_config.json")
@@ -184,6 +185,7 @@ def load_lora_config_from_checkpoint(lora_path):
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
return None, None
def get_default_lora_config():
"""Return default LoRA config for hot-swapping support."""
return LoRAConfig(
@@ -192,9 +194,10 @@ def get_default_lora_config():
r=32,
alpha=16,
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"]
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
)
def load_model(pretrained_path, lora_path=None):
global current_model
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
@@ -228,9 +231,8 @@ def load_model(pretrained_path, lora_path=None):
)
return "Model loaded successfully!"
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
global current_model
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
if current_model is None:
# 优先使用用户指定的预训练模型路径
@@ -261,7 +263,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
# 加载模型
try:
print(f"Loading base model: {base_model_path}", file=sys.stderr)
status_msg = load_model(base_model_path)
load_model(base_model_path)
if lora_selection and lora_selection != "None":
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
except Exception as e:
@@ -270,6 +272,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
return None, error_msg
# Handle LoRA hot-swapping
assert current_model is not None, "Model must be loaded before inference"
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
@@ -317,14 +320,16 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
prompt_text=final_prompt_text,
cfg_value=cfg_scale,
inference_timesteps=steps,
denoise=False
denoise=False,
)
return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
except Exception as e:
import traceback
traceback.print_exc()
return None, f"Error: {str(e)}"
def start_training(
pretrained_path,
train_manifest,
@@ -355,7 +360,7 @@ def start_training(
hf_model_id="",
distribute=False,
):
global training_process, training_log
global training_log
if training_process is not None and training_process.poll() is None:
return "Training is already running!"
@@ -394,10 +399,7 @@ def start_training(
"max_steps": resolved_max_steps,
"save_path": checkpoints_dir,
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
"lambdas": {
"loss/diff": 1.0,
"loss/stop": 1.0
},
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
"lora": {
"enable_lm": bool(enable_lm),
"enable_dit": bool(enable_dit),
@@ -406,7 +408,7 @@ def start_training(
"alpha": int(lora_alpha),
"dropout": float(dropout),
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"]
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"],
},
}
@@ -420,25 +422,15 @@ def start_training(
with open(config_path, "w") as f:
yaml.dump(config, f)
cmd = [
sys.executable,
"scripts/train_voxcpm_finetune.py",
"--config_path",
config_path
]
cmd = [sys.executable, "scripts/train_voxcpm_finetune.py", "--config_path", config_path]
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
def run_process():
global training_process, training_log
training_process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1
)
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
assert training_process.stdout is not None
for line in training_process.stdout:
training_log += line
# Keep log size manageable
@@ -452,17 +444,20 @@ def start_training(
return f"Training started! Check 'lora/{timestamp}'"
def get_training_log():
return training_log
def stop_training():
global training_process, training_log
global training_log
if training_process is not None and training_process.poll() is None:
training_process.terminate()
training_log += "\nTraining terminated by user."
return "Training stopped."
return "No training running."
# --- GUI Layout ---
# 自定义CSS样式
@@ -830,14 +825,10 @@ label {
}
"""
with gr.Blocks(
title="VoxCPM LoRA WebUI",
theme=gr.themes.Soft(),
css=custom_css
) as app:
with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css) as app:
# State for language
lang_state = gr.State("zh") # Default to Chinese
lang_state = gr.State("zh") # Default to Chinese
# 标题区域
with gr.Row(elem_classes="title-section"):
@@ -850,10 +841,7 @@ with gr.Blocks(
""")
with gr.Column(scale=1):
lang_btn = gr.Radio(
choices=["en", "zh"],
value="zh",
label="🌐 Language / 语言",
elem_classes="lang-selector"
choices=["en", "zh"], value="zh", label="🌐 Language / 语言", elem_classes="lang-selector"
)
with gr.Tabs(elem_classes="tabs") as tabs:
@@ -869,79 +857,40 @@ with gr.Blocks(
gr.Markdown("#### 📁 基础配置")
train_pretrained_path = gr.Textbox(
label="📂 预训练模型路径",
value=default_pretrained_path,
elem_classes="input-field"
label="📂 预训练模型路径", value=default_pretrained_path, elem_classes="input-field"
)
train_manifest = gr.Textbox(
label="📋 训练数据清单 (jsonl)",
value="examples/train_data_example.jsonl",
elem_classes="input-field"
)
val_manifest = gr.Textbox(
label="📊 验证数据清单 (可选)",
value="",
elem_classes="input-field"
elem_classes="input-field",
)
val_manifest = gr.Textbox(label="📊 验证数据清单 (可选)", value="", elem_classes="input-field")
gr.Markdown("#### ⚙️ 训练参数")
with gr.Row():
lr = gr.Number(
label="📈 学习率 (Learning Rate)",
value=1e-4,
elem_classes="input-field"
)
lr = gr.Number(label="📈 学习率 (Learning Rate)", value=1e-4, elem_classes="input-field")
num_iters = gr.Number(
label="🔄 最大迭代次数",
value=2000,
precision=0,
elem_classes="input-field"
label="🔄 最大迭代次数", value=2000, precision=0, elem_classes="input-field"
)
batch_size = gr.Number(
label="📦 批次大小 (Batch Size)",
value=1,
precision=0,
elem_classes="input-field"
label="📦 批次大小 (Batch Size)", value=1, precision=0, elem_classes="input-field"
)
with gr.Row():
lora_rank = gr.Number(
label="🎯 LoRA Rank",
value=32,
precision=0,
elem_classes="input-field"
)
lora_alpha = gr.Number(
label="⚖️ LoRA Alpha",
value=16,
precision=0,
elem_classes="input-field"
)
lora_rank = gr.Number(label="🎯 LoRA Rank", value=32, precision=0, elem_classes="input-field")
lora_alpha = gr.Number(label="⚖️ LoRA Alpha", value=16, precision=0, elem_classes="input-field")
save_interval = gr.Number(
label="💾 保存间隔 (Steps)",
value=1000,
precision=0,
elem_classes="input-field"
label="💾 保存间隔 (Steps)", value=1000, precision=0, elem_classes="input-field"
)
output_name = gr.Textbox(
label="📁 输出目录名称 (可选,若存在则继续训练)",
value="",
elem_classes="input-field"
label="📁 输出目录名称 (可选,若存在则继续训练)", value="", elem_classes="input-field"
)
with gr.Row():
start_btn = gr.Button(
" 开始训练",
variant="primary",
elem_classes="button-primary"
)
stop_btn = gr.Button(
"⏹️ 停止训练",
variant="stop",
elem_classes="button-stop"
)
start_btn = gr.Button("▶️ 开始训练", variant="primary", elem_classes="button-primary")
stop_btn = gr.Button(" 停止训练", variant="stop", elem_classes="button-stop")
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
with gr.Row():
@@ -964,7 +913,9 @@ with gr.Blocks(
gr.Markdown("#### 分发选项 (Distribution)")
with gr.Row():
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5")
hf_model_id = gr.Textbox(
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
)
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
with gr.Column(scale=2, elem_classes="form-section"):
@@ -975,23 +926,41 @@ with gr.Blocks(
max_lines=30,
interactive=False,
elem_classes="input-field",
show_label=False
show_label=False,
)
start_btn.click(
start_training,
inputs=[
train_pretrained_path, train_manifest, val_manifest,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
train_pretrained_path,
train_manifest,
val_manifest,
lr,
num_iters,
batch_size,
lora_rank,
lora_alpha,
save_interval,
output_name,
# advanced
grad_accum_steps, num_workers, log_interval, valid_interval,
weight_decay, warmup_steps, max_steps, sample_rate,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
grad_accum_steps,
num_workers,
log_interval,
valid_interval,
weight_decay,
warmup_steps,
max_steps,
sample_rate,
enable_lm,
enable_dit,
enable_proj,
dropout,
tensorboard_path,
# distribution
hf_model_id, distribute
hf_model_id,
distribute,
],
outputs=[logs_out] # Initial message
outputs=[logs_out], # Initial message
)
stop_btn.click(stop_training, outputs=[logs_out])
@@ -1016,21 +985,17 @@ with gr.Blocks(
value="Hello, this is a test of the VoxCPM LoRA model.",
elem_classes="input-field",
lines=4,
placeholder="输入要合成的文本内容..."
placeholder="输入要合成的文本内容...",
)
gr.Markdown("**🎭 声音克隆(可选)**")
prompt_wav = gr.Audio(
label="🎵 参考音频",
type="filepath",
elem_classes="input-field"
)
prompt_wav = gr.Audio(label="🎵 参考音频", type="filepath", elem_classes="input-field")
prompt_text = gr.Textbox(
label="📝 参考文本(可选)",
elem_classes="input-field",
placeholder="如不填写,将自动识别参考音频内容"
placeholder="如不填写,将自动识别参考音频内容",
)
# 中栏:模型选择和参数配置 (35%)
@@ -1043,14 +1008,10 @@ with gr.Blocks(
value="None",
interactive=True,
elem_classes="input-field",
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型"
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型",
)
refresh_lora_btn = gr.Button(
"🔄 刷新模型列表",
elem_classes="button-refresh",
size="sm"
)
refresh_lora_btn = gr.Button("🔄 刷新模型列表", elem_classes="button-refresh", size="sm")
gr.Markdown("#### ⚙️ 生成参数")
@@ -1060,7 +1021,7 @@ with gr.Blocks(
maximum=5.0,
value=2.0,
step=0.1,
info="引导系数,值越大越贴近提示"
info="引导系数,值越大越贴近提示",
)
steps = gr.Slider(
@@ -1069,7 +1030,7 @@ with gr.Blocks(
maximum=50,
value=10,
step=1,
info="生成质量与步数成正比,但耗时更长"
info="生成质量与步数成正比,但耗时更长",
)
seed = gr.Number(
@@ -1077,25 +1038,16 @@ with gr.Blocks(
value=-1,
precision=0,
elem_classes="input-field",
info="-1 为随机,固定值可复现结果"
info="-1 为随机,固定值可复现结果",
)
generate_btn = gr.Button(
"🎵 生成音频",
variant="primary",
elem_classes="button-primary",
size="lg"
)
generate_btn = gr.Button("🎵 生成音频", variant="primary", elem_classes="button-primary", size="lg")
# 右栏:生成结果 (30%)
with gr.Column(scale=30, elem_classes="form-section"):
gr.Markdown("#### 🎧 生成结果")
audio_out = gr.Audio(
label="",
elem_classes="input-field",
show_label=False
)
audio_out = gr.Audio(label="", elem_classes="input-field", show_label=False)
gr.Markdown("#### 📋 状态信息")
@@ -1105,7 +1057,7 @@ with gr.Blocks(
elem_classes="input-field",
show_label=False,
lines=3,
placeholder="等待生成..."
placeholder="等待生成...",
)
def refresh_loras():
@@ -1126,16 +1078,21 @@ with gr.Blocks(
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
# Auto-recognize audio when uploaded
prompt_wav.change(
fn=recognize_audio,
inputs=[prompt_wav],
outputs=[prompt_text]
)
prompt_wav.change(fn=recognize_audio, inputs=[prompt_wav], outputs=[prompt_text])
generate_btn.click(
run_inference,
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
outputs=[audio_out, status_out]
inputs=[
infer_text,
prompt_wav,
prompt_text,
lora_select,
cfg_scale,
steps,
seed,
train_pretrained_path,
],
outputs=[audio_out, status_out],
)
# --- Language Switching Logic ---
@@ -1144,108 +1101,138 @@ with gr.Blocks(
# Labels for advanced options
if lang == "zh":
adv = {
'grad_accum_steps': "梯度累积 (grad_accum_steps)",
'num_workers': "数据加载线程 (num_workers)",
'log_interval': "日志间隔 (log_interval)",
'valid_interval': "验证间隔 (valid_interval)",
'weight_decay': "权重衰减 (weight_decay)",
'warmup_steps': "warmup_steps",
'max_steps': "最大步数 (max_steps)",
'sample_rate': "采样率 (sample_rate)",
'enable_lm': "启用 LoRA LM (enable_lm)",
'enable_dit': "启用 LoRA DIT (enable_dit)",
'enable_proj': "启用投影 (enable_proj)",
'dropout': "LoRA Dropout",
'tensorboard_path': "Tensorboard 路径 (可选)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "分发模式 (distribute)",
"grad_accum_steps": "梯度累积 (grad_accum_steps)",
"num_workers": "数据加载线程 (num_workers)",
"log_interval": "日志间隔 (log_interval)",
"valid_interval": "验证间隔 (valid_interval)",
"weight_decay": "权重衰减 (weight_decay)",
"warmup_steps": "warmup_steps",
"max_steps": "最大步数 (max_steps)",
"sample_rate": "采样率 (sample_rate)",
"enable_lm": "启用 LoRA LM (enable_lm)",
"enable_dit": "启用 LoRA DIT (enable_dit)",
"enable_proj": "启用投影 (enable_proj)",
"dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard 路径 (可选)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
"distribute": "分发模式 (distribute)",
}
else:
adv = {
'grad_accum_steps': "Grad Accum Steps",
'num_workers': "Num Workers",
'log_interval': "Log Interval",
'valid_interval': "Valid Interval",
'weight_decay': "Weight Decay",
'warmup_steps': "Warmup Steps",
'max_steps': "Max Steps",
'sample_rate': "Sample Rate",
'enable_lm': "Enable LoRA LM",
'enable_dit': "Enable LoRA DIT",
'enable_proj': "Enable Projection",
'dropout': "LoRA Dropout",
'tensorboard_path': "Tensorboard Path (Optional)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "Distribute Mode",
"grad_accum_steps": "Grad Accum Steps",
"num_workers": "Num Workers",
"log_interval": "Log Interval",
"valid_interval": "Valid Interval",
"weight_decay": "Weight Decay",
"warmup_steps": "Warmup Steps",
"max_steps": "Max Steps",
"sample_rate": "Sample Rate",
"enable_lm": "Enable LoRA LM",
"enable_dit": "Enable LoRA DIT",
"enable_proj": "Enable Projection",
"dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard Path (Optional)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
"distribute": "Distribute Mode",
}
return (
gr.update(value=f"# {d['title']}"),
gr.update(label=d['tab_train']),
gr.update(label=d['tab_infer']),
gr.update(label=d['pretrained_path']),
gr.update(label=d['train_manifest']),
gr.update(label=d['val_manifest']),
gr.update(label=d['lr']),
gr.update(label=d['max_iters']),
gr.update(label=d['batch_size']),
gr.update(label=d['lora_rank']),
gr.update(label=d['lora_alpha']),
gr.update(label=d['save_interval']),
gr.update(label=d['output_name']),
gr.update(value=d['start_train']),
gr.update(value=d['stop_train']),
gr.update(label=d['train_logs']),
gr.update(label=d["tab_train"]),
gr.update(label=d["tab_infer"]),
gr.update(label=d["pretrained_path"]),
gr.update(label=d["train_manifest"]),
gr.update(label=d["val_manifest"]),
gr.update(label=d["lr"]),
gr.update(label=d["max_iters"]),
gr.update(label=d["batch_size"]),
gr.update(label=d["lora_rank"]),
gr.update(label=d["lora_alpha"]),
gr.update(label=d["save_interval"]),
gr.update(label=d["output_name"]),
gr.update(value=d["start_train"]),
gr.update(value=d["stop_train"]),
gr.update(label=d["train_logs"]),
# Advanced options (must match outputs order)
gr.update(label=adv['grad_accum_steps']),
gr.update(label=adv['num_workers']),
gr.update(label=adv['log_interval']),
gr.update(label=adv['valid_interval']),
gr.update(label=adv['weight_decay']),
gr.update(label=adv['warmup_steps']),
gr.update(label=adv['max_steps']),
gr.update(label=adv['sample_rate']),
gr.update(label=adv['enable_lm']),
gr.update(label=adv['enable_dit']),
gr.update(label=adv['enable_proj']),
gr.update(label=adv['dropout']),
gr.update(label=adv['tensorboard_path']),
gr.update(label=adv["grad_accum_steps"]),
gr.update(label=adv["num_workers"]),
gr.update(label=adv["log_interval"]),
gr.update(label=adv["valid_interval"]),
gr.update(label=adv["weight_decay"]),
gr.update(label=adv["warmup_steps"]),
gr.update(label=adv["max_steps"]),
gr.update(label=adv["sample_rate"]),
gr.update(label=adv["enable_lm"]),
gr.update(label=adv["enable_dit"]),
gr.update(label=adv["enable_proj"]),
gr.update(label=adv["dropout"]),
gr.update(label=adv["tensorboard_path"]),
# Distribution options
gr.update(label=adv['hf_model_id']),
gr.update(label=adv['distribute']),
gr.update(label=adv["hf_model_id"]),
gr.update(label=adv["distribute"]),
# Inference section
gr.update(label=d['text_to_synth']),
gr.update(label=d['ref_audio']),
gr.update(label=d['ref_text']),
gr.update(label=d['select_lora']),
gr.update(value=d['refresh']),
gr.update(label=d['cfg_scale']),
gr.update(label=d['infer_steps']),
gr.update(label=d['seed']),
gr.update(value=d['gen_audio']),
gr.update(label=d['gen_output']),
gr.update(label=d['status']),
gr.update(label=d["text_to_synth"]),
gr.update(label=d["ref_audio"]),
gr.update(label=d["ref_text"]),
gr.update(label=d["select_lora"]),
gr.update(value=d["refresh"]),
gr.update(label=d["cfg_scale"]),
gr.update(label=d["infer_steps"]),
gr.update(label=d["seed"]),
gr.update(value=d["gen_audio"]),
gr.update(label=d["gen_output"]),
gr.update(label=d["status"]),
)
lang_btn.change(
change_language,
inputs=[lang_btn],
outputs=[
title_md, tab_train, tab_infer,
train_pretrained_path, train_manifest, val_manifest,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
title_md,
tab_train,
tab_infer,
train_pretrained_path,
train_manifest,
val_manifest,
lr,
num_iters,
batch_size,
lora_rank,
lora_alpha,
save_interval,
output_name,
start_btn, stop_btn, logs_out,
start_btn,
stop_btn,
logs_out,
# advanced outputs
grad_accum_steps, num_workers, log_interval, valid_interval,
weight_decay, warmup_steps, max_steps, sample_rate,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
grad_accum_steps,
num_workers,
log_interval,
valid_interval,
weight_decay,
warmup_steps,
max_steps,
sample_rate,
enable_lm,
enable_dit,
enable_proj,
dropout,
tensorboard_path,
# distribution outputs
hf_model_id, distribute,
infer_text, prompt_wav, prompt_text,
lora_select, refresh_lora_btn, cfg_scale, steps, seed,
generate_btn, audio_out, status_out
]
hf_model_id,
distribute,
infer_text,
prompt_wav,
prompt_text,
lora_select,
refresh_lora_btn,
cfg_scale,
steps,
seed,
generate_btn,
audio_out,
status_out,
],
)
if __name__ == "__main__":
+1 -3
View File
@@ -30,7 +30,7 @@ dependencies = [
"torchcodec",
"transformers>=4.36.2",
"einops",
"gradio<6",
"gradio>=6,<7",
"inflect",
"addict",
"wetext",
@@ -57,7 +57,6 @@ dev = [
"pytest-cov>=2.0",
"black>=21.0",
"flake8>=3.8",
"mypy>=0.800",
"pre-commit>=2.0",
]
@@ -90,7 +89,6 @@ extend-exclude = '''
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| build
+4 -1
View File
@@ -125,7 +125,10 @@ def main():
out_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
if __name__ == "__main__":
+30 -13
View File
@@ -127,7 +127,9 @@ def main():
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
print(f" Base model: {pretrained_path}", file=sys.stderr)
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
print(
f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr
)
# 3. Load model with LoRA (no denoiser)
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
@@ -146,10 +148,10 @@ def main():
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
print("\n[2/2] Starting synthesis tests...", file=sys.stderr)
# === Test 1: With LoRA ===
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
print("\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
@@ -162,10 +164,13 @@ def main():
)
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 2: Disable LoRA (via set_lora_enabled) ===
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
print("\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
model.set_lora_enabled(False)
audio_np = model.generate(
text=args.text,
@@ -179,10 +184,13 @@ def main():
)
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 3: Re-enable LoRA ===
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
print("\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
model.set_lora_enabled(True)
audio_np = model.generate(
text=args.text,
@@ -196,10 +204,13 @@ def main():
)
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 4: Unload LoRA (reset_lora_weights) ===
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
print("\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
model.unload_lora()
audio_np = model.generate(
text=args.text,
@@ -213,10 +224,13 @@ def main():
)
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
# === Test 5: Hot-reload LoRA (load_lora) ===
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
print("\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
loaded, skipped = model.load_lora(ckpt_dir)
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
audio_np = model.generate(
@@ -231,9 +245,12 @@ def main():
)
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(
f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
print(f"\n[Done] All tests completed!", file=sys.stderr)
print("\n[Done] All tests completed!", file=sys.stderr)
print(f" - with_lora: {lora_output}", file=sys.stderr)
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
+153 -43
View File
@@ -7,7 +7,7 @@ project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "src"))
import contextlib
from typing import Dict, Optional
from typing import Dict
import argbind
import torch
@@ -17,16 +17,19 @@ from transformers import get_cosine_schedule_with_warmup
import signal
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
from safetensors.torch import save_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
from voxcpm.model import VoxCPMModel
import json
from voxcpm.model import VoxCPMModel, VoxCPM2Model
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import (
Accelerator,
@@ -61,8 +64,8 @@ def train(
lora: dict = None,
config_path: str = "",
# Distribution options (for LoRA checkpoints)
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
):
_ = config_path
@@ -84,7 +87,15 @@ def train(
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
# Auto-detect model architecture from config.json
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
_arch = json.load(_f).get("architecture", "voxcpm").lower()
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
if accelerator.rank == 0:
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
base_model = _model_cls.from_local(
pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None
)
tokenizer = base_model.text_tokenizer
train_ds, val_ds = load_audio_text_datasets(
@@ -166,7 +177,6 @@ def train(
unwrapped_model = accelerator.unwrap(model)
unwrapped_model.train()
# Only print param info on rank 0 to avoid cluttered output
if accelerator.rank == 0:
for name, param in model.named_parameters():
@@ -199,7 +209,19 @@ def train(
resume = {"step": start_step}
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
def _signal_handler(
signum,
frame,
_model=model,
_optim=optimizer,
_sched=scheduler,
_save_dir=save_dir,
_pretrained=pretrained_path,
_hf_id=hf_model_id,
_dist=distribute,
_resume=resume,
_rank=accelerator.rank,
):
try:
cur_step = int(_resume.get("step", start_step))
except Exception:
@@ -229,8 +251,8 @@ def train(
except StopIteration:
data_epoch += 1
# Key: set DistributedSampler epoch to ensure different data order each epoch
sampler = getattr(train_loader, 'sampler', None)
if hasattr(sampler, 'set_epoch'):
sampler = getattr(train_loader, "sampler", None)
if hasattr(sampler, "set_epoch"):
sampler.set_epoch(data_epoch)
train_iter = iter(train_loader)
return next(train_iter)
@@ -250,7 +272,7 @@ def train(
# Only sync gradients on the last micro-batch
# Use no_sync() for intermediate steps to reduce communication overhead
is_last_micro_step = (micro_step == grad_accum_steps - 1)
is_last_micro_step = micro_step == grad_accum_steps - 1
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
with sync_context:
@@ -299,10 +321,22 @@ def train(
tracker.log_metrics(loss_values, split="train")
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
valid_interval=valid_interval)
validate(
model,
val_loader,
batch_processor,
accelerator,
tracker,
lambdas,
writer=writer,
step=step,
val_ds=val_ds,
audio_vae=audio_vae_for_gen,
sample_rate=sample_rate,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
)
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
@@ -313,11 +347,24 @@ def train(
writer.close()
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
val_texts=None, tokenizer=None, valid_interval=1000):
def validate(
model,
val_loader,
batch_processor,
accelerator,
tracker,
lambdas,
writer=None,
step=0,
val_ds=None,
audio_vae=None,
sample_rate=22050,
val_texts=None,
tokenizer=None,
valid_interval=1000,
):
"""Validate and generate sample audio"""
import numpy as np
import numpy as np # noqa: F401
from collections import defaultdict
model.eval()
@@ -369,13 +416,24 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
# Generate sample audio for TensorBoard display
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
try:
generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
tracker=tracker)
generate_sample_audio(
model,
val_ds,
audio_vae,
writer,
step,
accelerator,
sample_rate,
val_texts=val_texts,
tokenizer=tokenizer,
valid_interval=valid_interval,
tracker=tracker,
)
except Exception as e:
tracker.print(f"[Warning] Failed to generate sample audio: {e}")
import traceback
import io
buf = io.StringIO()
traceback.print_exc(file=buf)
tracker.print(buf.getvalue())
@@ -398,6 +456,7 @@ def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
"""Compute Mel Spectrogram (dB) using librosa"""
import numpy as np
import librosa
audio_np = audio_np.flatten().astype(np.float32)
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
return librosa.power_to_db(mel, ref=np.max)
@@ -408,7 +467,8 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
"""
import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import librosa.display
@@ -419,19 +479,32 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
# Comparison mode: reference vs generated
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
img_ref = librosa.display.specshow(
ref_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_ref
)
ax_ref.set_title(
f"Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}",
fontsize=10,
fontweight="bold",
color="#28A745",
)
plt.colorbar(img_ref, ax=ax_ref, format="%+2.0f dB", pad=0.02)
img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
img_gen = librosa.display.specshow(
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_gen
)
ax_gen.set_title(
f"Generated - {len(gen_audio_np)/sample_rate:.2f}s", fontsize=10, fontweight="bold", color="#DC3545"
)
plt.colorbar(img_gen, ax=ax_gen, format="%+2.0f dB", pad=0.02)
else:
# Single figure mode: show generated only
fig, ax = plt.subplots(figsize=(12, 4))
img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
img = librosa.display.specshow(
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax
)
ax.set_title(f"Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}", fontsize=11, fontweight="bold")
plt.colorbar(img, ax=ax, format="%+2.0f dB", pad=0.02)
plt.tight_layout()
return fig
@@ -440,13 +513,25 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
def normalize_audio(audio_np):
"""Normalize audio to [-0.9, 0.9]"""
import numpy as np
max_val = np.abs(audio_np).max()
return audio_np / max_val * 0.9 if max_val > 0 else audio_np
def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
tracker=None):
def generate_sample_audio(
model,
val_ds,
audio_vae,
writer,
step,
accelerator,
sample_rate=22050,
val_texts=None,
tokenizer=None,
pretrained_path=None,
valid_interval=1000,
tracker=None,
):
"""Select 2 fixed validation samples, generate audio and log to TensorBoard"""
import numpy as np
@@ -468,7 +553,10 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
ref_sr = sample["audio"].get("sampling_rate", sample_rate)
if ref_sr != sample_rate:
import torchaudio.functional as F
ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
ref_audio_np = (
F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
)
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
except Exception as e:
log(f"[Warning] Failed to load reference audio: {e}")
@@ -500,7 +588,11 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
continue
# Process generated audio
gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
gen_audio_np = (
generated.cpu().float().numpy().flatten()
if isinstance(generated, torch.Tensor)
else np.array(generated, dtype=np.float32).flatten()
)
gen_audio_np = normalize_audio(gen_audio_np)
tag = f"val_sample_{i}"
@@ -509,7 +601,9 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
# Log reference audio
if ref_audio_np is not None:
writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
writer.add_audio(
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
)
# Generate mel spectrogram figure
try:
@@ -524,6 +618,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
except Exception as e:
log(f"[Warning] Failed to generate audio for sample {i}: {e}")
import traceback
traceback.print_exc()
finally:
@@ -545,8 +640,6 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
Called by all ranks so that distributed state stays aligned.
Returns the step number to resume from, or 0 if no checkpoint found.
"""
import json
latest_folder = save_dir / "latest"
if not latest_folder.exists():
return 0
@@ -564,6 +657,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
if lora_weights_path.exists():
if lora_weights_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(lora_weights_path))
else:
ckpt = torch.load(lora_weights_path, map_location="cpu")
@@ -581,6 +675,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
if model_path.exists():
if model_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(model_path))
else:
ckpt = torch.load(model_path, map_location="cpu")
@@ -625,13 +720,21 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
return 0
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
def save_checkpoint(
model,
optimizer,
scheduler,
save_dir: Path,
step: int,
pretrained_path: str = None,
hf_model_id: str = "",
distribute: bool = False,
):
"""
Save checkpoint with different strategies for full finetune vs LoRA:
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
"""
import json
import shutil
save_dir.mkdir(parents=True, exist_ok=True)
@@ -671,7 +774,14 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
# Copy config files from pretrained path
if pretrained_path:
pretrained_dir = Path(pretrained_path)
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
files_to_copy = [
"config.json",
"audiovae.pth",
"audiovae.safetensors",
"tokenizer.json",
"special_tokens_map.json",
"tokenizer_config.json",
]
for fname in files_to_copy:
src = pretrained_dir / fname
if src.exists():
+46 -27
View File
@@ -13,11 +13,11 @@ import soundfile as sf
from voxcpm.core import VoxCPM
# -----------------------------
# Validators
# -----------------------------
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
path = Path(file_path)
if not path.exists():
@@ -53,12 +53,11 @@ def validate_ranges(args, parser):
# Model loading
# -----------------------------
def load_model(args) -> VoxCPM:
print("Loading VoxCPM model...", file=sys.stderr)
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
"ZIPENHANCER_MODEL_PATH", None
)
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get("ZIPENHANCER_MODEL_PATH", None)
# Build LoRA config if provided
lora_config = None
@@ -119,22 +118,29 @@ def load_model(args) -> VoxCPM:
# Commands
# -----------------------------
def cmd_clone(args):
if not args.text:
sys.exit("Error: Please provide --text for synthesis")
if not args.prompt_audio or not args.prompt_text:
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
has_prompt = args.prompt_audio and args.prompt_text
has_ref = args.reference_audio is not None
if not has_prompt and not has_ref:
sys.exit("Error: Voice cloning requires --prompt-audio + --prompt-text, or --reference-audio, or both")
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
if args.prompt_audio:
validate_file_exists(args.prompt_audio, "prompt audio file")
if args.reference_audio:
validate_file_exists(args.reference_audio, "reference audio file")
output_path = validate_output_path(args.output)
model = load_model(args)
audio_array = model.generate(
text=args.text,
prompt_wav_path=str(prompt_audio_path),
prompt_text=args.prompt_text,
prompt_wav_path=args.prompt_audio if has_prompt else None,
prompt_text=args.prompt_text if has_prompt else None,
reference_wav_path=args.reference_audio,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
normalize=args.normalize,
@@ -185,7 +191,11 @@ def cmd_batch(args):
prompt_audio_path = None
if args.prompt_audio:
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "prompt audio file"))
reference_audio_path = None
if args.reference_audio:
reference_audio_path = str(validate_file_exists(args.reference_audio, "reference audio file"))
success_count = 0
@@ -195,10 +205,11 @@ def cmd_batch(args):
text=text,
prompt_wav_path=prompt_audio_path,
prompt_text=args.prompt_text,
reference_wav_path=reference_audio_path,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
normalize=args.normalize,
denoise=args.denoise and prompt_audio_path is not None,
denoise=args.denoise and (prompt_audio_path is not None or reference_audio_path is not None),
)
output_file = output_dir / f"output_{i:03d}.wav"
@@ -218,6 +229,7 @@ def cmd_batch(args):
# Parser
# -----------------------------
def _build_unified_parser():
parser = argparse.ArgumentParser(
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
@@ -236,34 +248,40 @@ Examples:
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
# Prompt
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
# Prompt / Reference
parser.add_argument(
"--prompt-audio", "-pa", help="Prompt audio file path (continuation mode, requires --prompt-text)"
)
parser.add_argument("--prompt-text", "-pt", help="Text corresponding to the prompt audio")
parser.add_argument(
"--reference-audio", "-ra", help="Reference audio for voice cloning (isolated mode, VoxCPM2 only)"
)
parser.add_argument("--denoise", action="store_true", help="Enable prompt/reference speech enhancement")
# Generation parameters
parser.add_argument("--cfg-value", type=float, default=2.0,
help="CFG guidance scale (float, recommended 0.55.0, default: 2.0)")
parser.add_argument("--inference-timesteps", type=int, default=10,
help="Inference steps (int, 1100, default: 10)")
parser.add_argument(
"--cfg-value", type=float, default=2.0, help="CFG guidance scale (float, recommended 0.55.0, default: 2.0)"
)
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (int, 1100, default: 10)")
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
# Model loading
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
parser.add_argument(
"--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (default: openbmb/VoxCPM1.5)"
)
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str,
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
parser.add_argument(
"--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)"
)
# LoRA
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
parser.add_argument("--lora-dropout", type=float, default=0.0,
help="LoRA dropout rate (0.01.0, default: 0.0)")
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (0.01.0, default: 0.0)")
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
@@ -275,6 +293,7 @@ Examples:
# Entrypoint
# -----------------------------
def main():
parser = _build_unified_parser()
args = parser.parse_args()
@@ -296,8 +315,8 @@ def main():
if not args.text or not args.output:
parser.error("Single-sample mode requires --text and --output")
# Clone mode
if args.prompt_audio or args.prompt_text:
# Clone mode (prompt continuation, reference isolation, or both)
if args.prompt_audio or args.prompt_text or args.reference_audio:
return cmd_clone(args)
# Direct synthesis
+127 -76
View File
@@ -1,21 +1,25 @@
import os
import sys
import re
import json
import tempfile
import numpy as np
from typing import Generator, Optional
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel, LoRAConfig
from .model.voxcpm2 import VoxCPM2Model
class VoxCPM:
def __init__(self,
voxcpm_model_path : str,
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser : bool = True,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
def __init__(
self,
voxcpm_model_path: str,
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser: bool = True,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
"""Initialize VoxCPM TTS pipeline.
Args:
@@ -31,7 +35,10 @@ class VoxCPM:
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
"""
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
print(
f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
file=sys.stderr,
)
# If lora_weights_path is provided but no lora_config, create a default one
if lora_weights_path is not None and lora_config is None:
@@ -42,7 +49,20 @@ class VoxCPM:
)
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
# Determine model type from config.json architecture field
config_path = os.path.join(voxcpm_model_path, "config.json")
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
arch = config.get("architecture", "voxcpm").lower()
if arch == "voxcpm2":
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
print("Loaded VoxCPM2Model", file=sys.stderr)
elif arch == "voxcpm":
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
print("Loaded VoxCPMModel", file=sys.stderr)
else:
raise ValueError(f"Unsupported architecture: {arch}")
# Load LoRA weights if path is provided
if lora_weights_path is not None:
@@ -51,8 +71,10 @@ class VoxCPM:
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
self.text_normalizer = None
self.denoiser = None
if enable_denoiser and zipenhancer_model_path is not None:
from .zipenhancer import ZipEnhancer
self.denoiser = ZipEnhancer(zipenhancer_model_path)
else:
self.denoiser = None
@@ -64,17 +86,18 @@ class VoxCPM:
)
@classmethod
def from_pretrained(cls,
hf_model_id: str = "openbmb/VoxCPM1.5",
load_denoiser: bool = True,
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
):
def from_pretrained(
cls,
hf_model_id: str = "openbmb/VoxCPM2",
load_denoiser: bool = True,
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
Args:
@@ -134,46 +157,47 @@ class VoxCPM:
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs)
def _generate(self,
text : str,
prompt_wav_path : str = None,
prompt_text : str = None,
cfg_value : float = 2.0,
inference_timesteps : int = 10,
min_len : int = 2,
max_len : int = 4096,
normalize : bool = False,
denoise : bool = False,
retry_badcase : bool = True,
retry_badcase_max_times : int = 3,
retry_badcase_ratio_threshold : float = 6.0,
streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
def _generate(
self,
text: str,
prompt_wav_path: str = None,
prompt_text: str = None,
reference_wav_path: str = None,
cfg_value: float = 2.0,
inference_timesteps: int = 10,
min_len: int = 2,
max_len: int = 4096,
normalize: bool = False,
denoise: bool = False,
retry_badcase: bool = True,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
"""Synthesize speech for the given text and return a single waveform.
This method optionally builds and reuses a prompt cache. If an external
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
used for all sub-sentences. Otherwise, the prompt cache is built from
the first generated result and reused for the remaining text chunks.
Args:
text: Input text. Can include newlines; each non-empty line is
treated as a sub-sentence.
prompt_wav_path: Path to a reference audio file for prompting.
text: Input text to synthesize.
prompt_wav_path: Path to prompt audio for continuation mode.
Must be paired with ``prompt_text``.
prompt_text: Text content corresponding to the prompt audio.
reference_wav_path: Path to reference audio for voice cloning
(structurally isolated via ref_audio tokens). Can be used
alone or combined with ``prompt_wav_path`` + ``prompt_text``.
cfg_value: Guidance scale for the generation model.
inference_timesteps: Number of inference steps.
min_len: Minimum audio length.
max_len: Maximum token length during generation.
normalize: Whether to run text normalization before generation.
denoise: Whether to denoise the prompt audio if a denoiser is
available.
denoise: Whether to denoise the prompt/reference audio if a
denoiser is available.
retry_badcase: Whether to retry badcase.
retry_badcase_max_times: Maximum number of times to retry badcase.
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
streaming: Whether to return a generator of audio chunks.
Returns:
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
Yields audio chunks for each generations step if ``streaming=True``,
Yields audio chunks for each generation step if ``streaming=True``,
otherwise yields a single array containing the final audio.
"""
if not text.strip() or not isinstance(text, str):
@@ -183,55 +207,82 @@ class VoxCPM:
if not os.path.exists(prompt_wav_path):
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
if reference_wav_path is not None:
if not os.path.exists(reference_wav_path):
raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")
if (prompt_wav_path is None) != (prompt_text is None):
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
is_v2 = isinstance(self.tts_model, VoxCPM2Model)
if reference_wav_path is not None and not is_v2:
raise ValueError("reference_wav_path is only supported with VoxCPM2 models")
text = text.replace("\n", " ")
text = re.sub(r'\s+', ' ', text)
temp_prompt_wav_path = None
text = re.sub(r"\s+", " ", text)
temp_files = []
try:
if prompt_wav_path is not None and prompt_text is not None:
if denoise and self.denoiser is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
temp_prompt_wav_path = tmp_file.name
self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
prompt_wav_path = temp_prompt_wav_path
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text
)
actual_prompt_path = prompt_wav_path
actual_ref_path = reference_wav_path
if denoise and self.denoiser is not None:
if prompt_wav_path is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
temp_files.append(tmp.name)
self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
actual_prompt_path = temp_files[-1]
if reference_wav_path is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
temp_files.append(tmp.name)
self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
actual_ref_path = temp_files[-1]
if actual_prompt_path is not None or actual_ref_path is not None:
if is_v2:
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_text=prompt_text,
prompt_wav_path=actual_prompt_path,
reference_wav_path=actual_ref_path,
)
else:
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_text=prompt_text,
prompt_wav_path=actual_prompt_path,
)
else:
fixed_prompt_cache = None # will be built from the first inference
fixed_prompt_cache = None
if normalize:
if self.text_normalizer is None:
from .utils.text_normalize import TextNormalizer
self.text_normalizer = TextNormalizer()
text = self.text_normalizer.normalize(text)
generate_result = self.tts_model._generate_with_prompt_cache(
target_text=text,
prompt_cache=fixed_prompt_cache,
min_len=min_len,
max_len=max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
streaming=streaming,
)
target_text=text,
prompt_cache=fixed_prompt_cache,
min_len=min_len,
max_len=max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
streaming=streaming,
)
for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy()
finally:
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
try:
os.unlink(temp_prompt_wav_path)
except OSError:
pass
for tmp_path in temp_files:
if tmp_path and os.path.exists(tmp_path):
try:
os.unlink(tmp_path)
except OSError:
pass
# ------------------------------------------------------------------ #
# LoRA Interface (delegated to VoxCPMModel)
+2 -1
View File
@@ -1,3 +1,4 @@
from .voxcpm import VoxCPMModel
from .voxcpm2 import VoxCPM2Model
__all__ = ["VoxCPMModel"]
__all__ = ["VoxCPMModel", "VoxCPM2Model"]
+1 -2
View File
@@ -24,8 +24,7 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
multichar_tokens = {
token for token in tokenizer.vocab.keys()
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
token for token in tokenizer.vocab.keys() if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
}
class CharTokenizerWrapper:
+80 -67
View File
@@ -24,7 +24,6 @@ from typing import Tuple, Union, Generator, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import warnings
from einops import rearrange
@@ -32,6 +31,7 @@ from pydantic import BaseModel
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
@@ -84,9 +84,9 @@ class VoxCPMConfig(BaseModel):
class LoRAConfig(BaseModel):
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
r: int = 8
alpha: int = 16
@@ -168,7 +168,7 @@ class VoxCPMModel(nn.Module):
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale
config.scalar_quantization_scale,
)
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
@@ -196,9 +196,7 @@ class VoxCPMModel(nn.Module):
# LM: base_lm + residual_lm
if cfg.enable_lm:
for lm in [self.base_lm, self.residual_lm]:
apply_lora_to_named_linear_modules(
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
)
apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)
# DiT: feat_decoder.estimator
if cfg.enable_dit:
@@ -209,6 +207,7 @@ class VoxCPMModel(nn.Module):
# 投影层
if cfg.enable_proj:
from ..modules.layers.lora import LoRALinear
for attr_name in cfg.target_proj_modules:
module = getattr(self, attr_name, None)
if isinstance(module, nn.Linear):
@@ -221,13 +220,17 @@ class VoxCPMModel(nn.Module):
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
import triton
import triton # noqa: F401
except ImportError:
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
)
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
)
except Exception as e:
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
return self
@@ -313,9 +316,11 @@ class VoxCPMModel(nn.Module):
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10,
n_timesteps=(
self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10
),
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
@@ -331,7 +336,6 @@ class VoxCPMModel(nn.Module):
def _dtype(self):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
@@ -350,7 +354,7 @@ class VoxCPMModel(nn.Module):
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
streaming: bool = False,
) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming:
@@ -444,7 +448,9 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -460,7 +466,10 @@ class VoxCPMModel(nn.Module):
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1
continue
else:
@@ -514,7 +523,9 @@ class VoxCPMModel(nn.Module):
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0) # (D, T, P)
).permute(
1, 2, 0
) # (D, T, P)
# build prompt cache - only save raw text and audio features
prompt_cache = {
"prompt_text": prompt_text,
@@ -523,7 +534,6 @@ class VoxCPMModel(nn.Module):
return prompt_cache
def merge_prompt_cache(
self,
original_cache: dict,
@@ -560,17 +570,14 @@ class VoxCPMModel(nn.Module):
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
@torch.inference_mode()
def _generate_with_prompt_cache(
self,
@@ -645,8 +652,12 @@ class VoxCPMModel(nn.Module):
)
text_token = torch.cat([text_token, text_pad_token])
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
)
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
@@ -663,7 +674,9 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
max_len=min(
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -674,17 +687,16 @@ class VoxCPMModel(nn.Module):
for latent_pred, pred_audio_feat in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
yield (decode_audio, target_text_token, pred_audio_feat)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
print(
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
file=sys.stderr,
)
retry_badcase_times += 1
continue
else:
@@ -695,14 +707,10 @@ class VoxCPMModel(nn.Module):
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
patch_len = self.patch_size * self.chunk_size
if audio_mask.sum().item() > 0:
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
else:
decode_audio = decode_audio[..., :].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
yield (decode_audio, target_text_token, pred_audio_feat)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
@@ -782,7 +790,6 @@ class VoxCPMModel(nn.Module):
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = enc_outputs[:, -1, :]
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
is_causal=True,
@@ -790,7 +797,6 @@ class VoxCPMModel(nn.Module):
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
residual_hidden = residual_enc_outputs[:, -1, :]
for i in tqdm(range(max_len)):
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
@@ -827,10 +833,10 @@ class VoxCPMModel(nn.Module):
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
).clone()
lm_hidden = self.fsq_layer(lm_hidden)
residual_hidden = self.residual_lm.forward_step(
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
lm_hidden + curr_embed[:, 0, :],
torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
).clone()
if not streaming:
@@ -838,29 +844,41 @@ class VoxCPMModel(nn.Module):
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, 'audio_vae_config', None)
audio_vae_config = getattr(config, "audio_vae_config", None)
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
vae_state_dict = torch.load(
os.path.join(path, "audiovae.pth"),
map_location="cpu",
weights_only=True,
)["state_dict"]
# Try to load AudioVAE from safetensors first, fallback to pytorch
audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
audiovae_pth_path = os.path.join(path, "audiovae.pth")
if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
elif os.path.exists(audiovae_pth_path):
print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
checkpoint = torch.load(
audiovae_pth_path,
map_location="cpu",
weights_only=True,
)
vae_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
else: # training mode
else: # training mode
for name, param in model.named_parameters():
if "audio_vae" in name: # freeze VAE weights
if "audio_vae" in name: # freeze VAE weights
param.requires_grad = False
continue
if lora_config is not None:
if "lora" not in name: # freeze non-LoRA weights
if "lora" not in name: # freeze non-LoRA weights
param.requires_grad = False
model.audio_vae = model.audio_vae.to(torch.float32)
@@ -880,9 +898,7 @@ class VoxCPMModel(nn.Module):
)
model_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
)
raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val
@@ -900,6 +916,7 @@ class VoxCPMModel(nn.Module):
def _iter_lora_modules(self):
"""Iterate over all LoRA modules."""
from ..modules.layers.lora import LoRALinear
for module in self.modules():
if isinstance(module, LoRALinear):
yield module
@@ -919,15 +936,15 @@ class VoxCPMModel(nn.Module):
from pathlib import Path
device = device or self.device
lora_path = Path(lora_path)
lora_p = Path(lora_path)
# Try safetensors first, then fallback to .ckpt
if lora_path.is_dir():
safetensors_file = lora_path / "lora_weights.safetensors"
ckpt_file = lora_path / "lora_weights.ckpt"
if lora_p.is_dir():
safetensors_file = lora_p / "lora_weights.safetensors"
ckpt_file = lora_p / "lora_weights.ckpt"
else:
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
# Load from safetensors if available
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
@@ -936,9 +953,7 @@ class VoxCPMModel(nn.Module):
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
)
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
# Build param mapping (handle torch.compile's _orig_mod prefix)
model_params = dict(self.named_parameters())
@@ -967,6 +982,4 @@ class VoxCPMModel(nn.Module):
def get_lora_state_dict(self) -> dict:
"""Get all LoRA parameters (lora_A/lora_B)."""
return {name: param.data.clone()
for name, param in self.named_parameters()
if "lora_" in name}
return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
File diff suppressed because it is too large Load Diff
+1
View File
@@ -1 +1,2 @@
from .audio_vae import AudioVAE, AudioVAEConfig
from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
+2 -2
View File
@@ -1,5 +1,5 @@
import math
from typing import List, Union, Optional
from typing import List
import numpy as np
import torch
@@ -285,7 +285,7 @@ class AudioVAE(nn.Module):
def __init__(
self,
config: Optional[AudioVAEConfig] = None,
config: AudioVAEConfig = None,
):
# 如果没有传入config,使用默认配置
if config is None:
+486
View File
@@ -0,0 +1,486 @@
import math
from typing import List, Optional
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from pydantic import BaseModel
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class CausalConv1d(nn.Conv1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
return super().forward(x_pad)
class CausalTransposeConv1d(nn.ConvTranspose1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
def WNCausalConv1d(*args, **kwargs):
return weight_norm(CausalConv1d(*args, **kwargs))
def WNCausalTransposeConv1d(*args, **kwargs):
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class CausalResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNCausalConv1d(
dim,
dim,
kernel_size=kernel,
dilation=dilation,
padding=pad,
groups=groups,
),
Snake1d(dim),
WNCausalConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
assert pad == 0
if pad > 0:
x = x[..., pad:-pad]
return x + y
class CausalEncoderBlock(nn.Module):
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
super().__init__()
input_dim = input_dim or output_dim // 2
self.block = nn.Sequential(
CausalResidualUnit(input_dim, dilation=1, groups=groups),
CausalResidualUnit(input_dim, dilation=3, groups=groups),
CausalResidualUnit(input_dim, dilation=9, groups=groups),
Snake1d(input_dim),
WNCausalConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
)
def forward(self, x):
return self.block(x)
class CausalEncoder(nn.Module):
def __init__(
self,
d_model: int = 64,
latent_dim: int = 32,
strides: list = [2, 4, 8, 8],
depthwise: bool = False,
):
super().__init__()
# Create first convolution
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
groups = d_model // 2 if depthwise else 1
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
groups = d_model if depthwise else 1
# Create two convolution, for mu and logvar
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
hidden_state = self.block(x)
return {
"hidden_state": hidden_state,
"mu": self.fc_mu(hidden_state),
"logvar": self.fc_logvar(hidden_state),
}
class NoiseBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
def forward(self, x):
B, C, T = x.shape
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
h = self.linear(x)
n = noise * h
x = x + n
return x
class CausalDecoderBlock(nn.Module):
def __init__(
self,
input_dim: int = 16,
output_dim: int = 8,
stride: int = 1,
groups=1,
use_noise_block: bool = False,
):
super().__init__()
layers = [
Snake1d(input_dim),
WNCausalTransposeConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
]
if use_noise_block:
layers.append(NoiseBlock(output_dim))
layers.extend(
[
CausalResidualUnit(output_dim, dilation=1, groups=groups),
CausalResidualUnit(output_dim, dilation=3, groups=groups),
CausalResidualUnit(output_dim, dilation=9, groups=groups),
]
)
self.block = nn.Sequential(*layers)
self.input_channels = input_dim
def forward(self, x):
return self.block(x)
class TransposeLastTwoDim(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, -1, -2)
class SampleRateConditionLayer(nn.Module):
def __init__(
self,
input_dim: int,
sr_bin_buckets: int = None,
cond_type: str = "scale_bias",
cond_dim: int = 128,
out_layer: bool = False,
):
super().__init__()
self.cond_type, out_layer_in_dim = cond_type, input_dim
if cond_type == "scale_bias":
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.ones_(self.scale_embed.weight)
nn.init.zeros_(self.bias_embed.weight)
elif cond_type == "scale_bias_init":
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.normal_(self.scale_embed.weight, mean=1)
nn.init.normal_(self.bias_embed.weight)
elif cond_type == "add":
self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
nn.init.normal_(self.cond_embed.weight)
elif cond_type == "concat":
self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
assert out_layer, "out_layer must be True for concat cond_type"
out_layer_in_dim = input_dim + cond_dim
else:
raise ValueError(f"Invalid cond_type: {cond_type}")
if out_layer:
self.out_layer = nn.Sequential(
Snake1d(out_layer_in_dim),
WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
)
else:
self.out_layer = nn.Identity()
def forward(self, x, sr_cond):
if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
elif self.cond_type == "add":
x = x + self.cond_embed(sr_cond).unsqueeze(-1)
elif self.cond_type == "concat":
x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
return self.out_layer(x)
class CausalDecoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
depthwise: bool = False,
d_out: int = 1,
use_noise_block: bool = False,
sr_bin_boundaries: List[int] = None,
cond_type: str = "scale_bias",
cond_dim: int = 128,
cond_out_layer: bool = False,
):
super().__init__()
# Add first conv layer
if depthwise:
layers = [
WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
WNCausalConv1d(input_channel, channels, kernel_size=1),
]
else:
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
groups = output_dim if depthwise else 1
layers += [
CausalDecoderBlock(
input_dim,
output_dim,
stride,
groups=groups,
use_noise_block=use_noise_block,
)
]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
if sr_bin_boundaries is None:
self.model = nn.Sequential(*layers)
self.sr_bin_boundaries = None
else:
self.model = nn.ModuleList(layers)
self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
self.sr_bin_buckets = len(sr_bin_boundaries) + 1
cond_layers = []
for layer in self.model:
if layer.__class__.__name__ == "CausalDecoderBlock":
cond_layers.append(
SampleRateConditionLayer(
input_dim=layer.input_channels,
sr_bin_buckets=self.sr_bin_buckets,
cond_type=cond_type,
cond_dim=cond_dim,
out_layer=cond_out_layer,
)
)
else:
cond_layers.append(None)
self.sr_cond_model = nn.ModuleList(cond_layers)
def get_sr_idx(self, sr):
return torch.bucketize(sr, self.sr_bin_boundaries)
def forward(self, x, sr_cond=None):
if self.sr_bin_boundaries is not None:
# assert sr_cond is not None
sr_cond = self.get_sr_idx(sr_cond)
for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
if sr_cond_layer is not None:
x = sr_cond_layer(x, sr_cond)
x = layer(x)
return x
else:
return self.model(x)
class AudioVAEConfig(BaseModel):
encoder_dim: int = 128
encoder_rates: List[int] = [2, 5, 8, 8]
latent_dim: int = 64
decoder_dim: int = 2048
decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
depthwise: bool = True
sample_rate: int = 16000
out_sample_rate: int = 48000
use_noise_block: bool = False
sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
cond_type: str = "scale_bias"
cond_dim: int = 128
cond_out_layer: bool = False
class AudioVAE(nn.Module):
"""
Args:
"""
def __init__(
self,
config: AudioVAEConfig = None,
):
# 如果没有传入config,使用默认配置
if config is None:
config = AudioVAEConfig()
super().__init__()
encoder_dim = config.encoder_dim
encoder_rates = config.encoder_rates
latent_dim = config.latent_dim
decoder_dim = config.decoder_dim
decoder_rates = config.decoder_rates
depthwise = config.depthwise
sample_rate = config.sample_rate
out_sample_rate = config.out_sample_rate
use_noise_block = config.use_noise_block
sr_bin_boundaries = config.sr_bin_boundaries
cond_type = config.cond_type
cond_dim = config.cond_dim
cond_out_layer = config.cond_out_layer
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.depthwise = depthwise
self.use_noise_block = use_noise_block
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = CausalEncoder(
encoder_dim,
latent_dim,
encoder_rates,
depthwise=depthwise,
)
self.decoder = CausalDecoder(
latent_dim,
decoder_dim,
decoder_rates,
depthwise=depthwise,
use_noise_block=use_noise_block,
sr_bin_boundaries=sr_bin_boundaries,
cond_type=cond_type,
cond_dim=cond_dim,
cond_out_layer=cond_out_layer,
)
self.sample_rate = sample_rate
self.out_sample_rate = out_sample_rate
self.sr_bin_boundaries = sr_bin_boundaries
self.chunk_size = math.prod(encoder_rates)
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
pad_to = self.hop_length
length = audio_data.shape[-1]
right_pad = math.ceil(length / pad_to) * pad_to - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
if self.sr_bin_boundaries is not None:
# use default output sample rate
if sr_cond is None:
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
return self.decoder(z, sr_cond)
def encode(self, audio_data: torch.Tensor, sample_rate: int):
"""
Args:
audio_data: Tensor[B x 1 x T]
sample_rate: int
Returns:
z: Tensor[B x D x T]
"""
if audio_data.ndim == 2:
audio_data = audio_data.unsqueeze(1)
audio_data = self.preprocess(audio_data, sample_rate)
return self.encoder(audio_data)["mu"]
-3
View File
@@ -128,6 +128,3 @@ def apply_lora_to_named_linear_modules(
dropout=dropout,
)
setattr(parent, short_name, lora_layer)
+1
View File
@@ -1,2 +1,3 @@
from .unified_cfm import UnifiedCFM, CfmConfig
from .local_dit import VoxCPMLocDiT
from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
+116
View File
@@ -0,0 +1,116 @@
import torch
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
import torch.nn as nn
import math
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
out_dim: int = None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class VoxCPMLocDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
config: MiniCPM4Config,
in_channels: int = 64,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.config = config
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
self.time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
self.delta_time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
self.decoder = MiniCPMModel(config)
def forward(
self,
x: torch.Tensor,
mu: torch.Tensor,
t: torch.Tensor,
cond: torch.Tensor,
dt: torch.Tensor,
):
"""
Forward pass of DiT.
x: (N, C, T) tensor of inputs
mu: (N, C) tensor of hidden embedding
t: (N,) tensor of diffusion timesteps
cond: (N, C, T') tensor of prefix conditions
dt: (N,) used for mean velocity (may be supported in the future...)
"""
x = self.in_proj(x.transpose(1, 2).contiguous())
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
prefix = cond.size(1)
t = self.time_embeddings(t).to(x.dtype)
t = self.time_mlp(t)
dt = self.time_embeddings(dt).to(x.dtype)
dt = self.delta_time_mlp(dt)
t = t + dt
mu = mu.view(x.size(0), -1, x.size(-1))
x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)
hidden, _ = self.decoder(x, is_causal=False)
hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
hidden = self.out_proj(hidden)
return hidden.transpose(1, 2).contiguous()
+5 -4
View File
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Tuple
import torch
import torch.nn.functional as F
@@ -138,7 +138,9 @@ class UnifiedCFM(torch.nn.Module):
# ------------------------------------------------------------------ #
# Training loss
# ------------------------------------------------------------------ #
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
def adaptive_loss_weighting(
self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
):
weights = 1.0 / ((losses + epsilon).pow(p))
if mask is not None:
weights = weights * mask
@@ -193,8 +195,7 @@ class UnifiedCFM(torch.nn.Module):
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
ratio_r_neq_t = (
self.ratio_r_neq_t_range[0]
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
if self.mean_mode
else 0.0
)
+12 -13
View File
@@ -64,10 +64,8 @@ class MiniCPMLongRoPE(nn.Module):
self.long_factor = config.rope_scaling.long_factor
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
self.scaling_factor = math.sqrt(
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
)
scale = self.max_position_embeddings / self.original_max_position_embeddings
self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -76,11 +74,7 @@ class MiniCPMLongRoPE(nn.Module):
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
self._set_cos_sin_cache(
seq_len=self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32
)
self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)
def _set_cos_sin_cache(self, seq_len, device, dtype):
"""设置cos和sin缓存"""
@@ -93,8 +87,7 @@ class MiniCPMLongRoPE(nn.Module):
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
freqs = torch.mul(
torch.outer(t, 1.0 / ext_factors).to(device=device),
self.inv_freq.to(device=device).to(dtype)
torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
)
# 创建embeddings
@@ -123,7 +116,9 @@ class MiniCPMAttention(nn.Module):
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
self.head_dim = (
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
@@ -413,7 +408,11 @@ class MiniCPMModel(nn.Module):
self.kv_cache = StaticKVCache(
num_layers=self.config.num_hidden_layers,
num_kv_heads=self.config.num_key_value_heads,
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
dim_kv_head=(
self.config.hidden_size // self.config.num_attention_heads
if self.config.kv_channels is None
else self.config.kv_channels
),
batch_size=batch_size,
device=device,
dtype=dtype,
-1
View File
@@ -25,4 +25,3 @@ __all__ = [
"load_audio_text_datasets",
"build_dataloader",
]
+2 -5
View File
@@ -47,9 +47,7 @@ class Accelerator:
pass
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
self.device_ctx = (
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
self._ddp_model = None # For no_sync support
def _set_seed(self, seed: int):
@@ -84,7 +82,7 @@ class Accelerator:
# Model helpers
# ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs):
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
model.device = self.device
model = model.to(self.device)
if self.world_size > 1:
@@ -163,4 +161,3 @@ class Accelerator:
@staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model
-2
View File
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args)
return cli_args
-2
View File
@@ -1,5 +1,4 @@
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import argbind
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
from ..modules.audiovae import AudioVAE
from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_ID_COLUMN = "dataset_id"
+28 -21
View File
@@ -1,5 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List
import torch
import torch.nn as nn
@@ -147,31 +146,26 @@ class AudioFeatureProcessingPacker:
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
if x.size(0) >= max_len:
return x[: max_len]
return x[:max_len]
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
def pad_3d(x: torch.Tensor) -> torch.Tensor:
# x: [T, P, D]
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.zeros(
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
)
return x[:max_len]
pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
if lengths:
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
audio_task_ids_batch = torch.stack(
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
)
audio_dataset_ids_batch = torch.stack(
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
)
labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
position_ids_list = []
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
)
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
text_token.device
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
.type(torch.int32)
.to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
.type(torch.int32)
.to(text_token.device)
)
loss_mask = (
torch.cat(
[
torch.zeros(text_length),
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
torch.zeros(1),
]
)
.type(torch.int32)
.to(text_token.device)
)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
torch.int32
).to(text_token.device)
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
labels[-2] = 1
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
audio_duration,
text_token_count,
)
-1
View File
@@ -18,4 +18,3 @@ class TrainingState:
val_loader: object
tracker: object
batch_processor: object
-1
View File
@@ -76,4 +76,3 @@ class TrainingTracker:
@contextlib.contextmanager
def live(self):
yield
+31 -28
View File
@@ -2,10 +2,10 @@
import re
import regex
import inflect
from functools import partial
from wetext import Normalizer
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
# whether contain chinese character
def contains_chinese(text):
@@ -14,19 +14,19 @@ def contains_chinese(text):
# replace special symbol
def replace_corner_mark(text):
text = text.replace('²', '平方')
text = text.replace('³', '立方')
text = text.replace('', '根号')
text = text.replace('', '约等于')
text = text.replace('<', '小于')
text = text.replace("²", "平方")
text = text.replace("³", "立方")
text = text.replace("", "根号")
text = text.replace("", "约等于")
text = text.replace("<", "小于")
return text
# remove meaningless symbol
def remove_bracket(text):
text = text.replace('', ' ').replace('', ' ')
text = text.replace('', ' ').replace('', ' ')
text = text.replace('`', '').replace('`', '')
text = text.replace("", " ").replace("", " ")
text = text.replace("", " ").replace("", " ")
text = text.replace("`", "").replace("`", "")
text = text.replace("——", " ")
return text
@@ -38,7 +38,7 @@ def spell_out_number(text: str, inflect_parser):
for i, c in enumerate(text):
if not c.isdigit():
if st is not None:
num_str = inflect_parser.number_to_words(text[st: i])
num_str = inflect_parser.number_to_words(text[st:i])
new_text.append(num_str)
st = None
new_text.append(c)
@@ -48,7 +48,7 @@ def spell_out_number(text: str, inflect_parser):
if st is not None and st < len(text):
num_str = inflect_parser.number_to_words(text[st:])
new_text.append(num_str)
return ''.join(new_text)
return "".join(new_text)
# split paragrah logic
@@ -69,18 +69,18 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
return len(tokenize(_text)) < merge_len
if lang == "zh":
pounc = ['', '', '', '', '', '', '.', '?', '!', ';']
pounc = ["", "", "", "", "", "", ".", "?", "!", ";"]
else:
pounc = ['.', '?', '!', ';', ':']
pounc = [".", "?", "!", ";", ":"]
if comma_split:
pounc.extend(['', ','])
pounc.extend(["", ","])
st = 0
utts = []
for i, c in enumerate(text):
if c in pounc:
if len(text[st: i]) > 0:
utts.append(text[st: i] + c)
if i + 1 < len(text) and text[i + 1] in ['"', '']:
if len(text[st:i]) > 0:
utts.append(text[st:i] + c)
if i + 1 < len(text) and text[i + 1] in ['"', ""]:
tmp = utts.pop(-1)
utts.append(tmp + text[i + 1])
st = i + 2
@@ -88,9 +88,9 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
st = i + 1
if len(utts) == 0:
if lang == "zh":
utts.append(text + '')
utts.append(text + "")
else:
utts.append(text + '.')
utts.append(text + ".")
final_utts = []
cur_utt = ""
for utt in utts:
@@ -112,13 +112,13 @@ def replace_blank(text: str):
out_str = []
for i, c in enumerate(text):
if c == " ":
if ((text[i + 1].isascii() and text[i + 1] != " ") and
(text[i - 1].isascii() and text[i - 1] != " ")):
if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
out_str.append(c)
else:
out_str.append(c)
return "".join(out_str)
def clean_markdown(md_text: str) -> str:
# 去除代码块 ``` ```(包括多行)
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
@@ -133,7 +133,7 @@ def clean_markdown(md_text: str) -> str:
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
# 替换无序列表符号
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
md_text = re.sub(r"^(\s*)-\s+", r"\1", md_text, flags=re.MULTILINE)
# 去除HTML标签
md_text = re.sub(r"<[^>]+>", "", md_text)
@@ -152,13 +152,14 @@ def clean_text(text):
# 去除 Markdown 语法
text = clean_markdown(text)
# 匹配并移除表情符号
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
text = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE).sub("", text)
# 去除换行符
text = text.replace("\n", " ")
text = text.replace("\t", " ")
text = text.replace('"', "\")
text = text.replace("", '"').replace("", '"')
return text
class TextNormalizer:
def __init__(self, tokenizer=None):
self.tokenizer = tokenizer
@@ -171,9 +172,11 @@ class TextNormalizer:
lang = "zh" if contains_chinese(text) else "en"
text = clean_text(text)
if lang == "zh":
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
text = text.replace(
"=", "等于"
) # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text): # 避免 英文连字符被错误正则为减
text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text) # 修复 x-2 被正则为 x负2
text = self.zh_tn_model.normalize(text)
text = replace_blank(text)
text = replace_corner_mark(text)
+6 -10
View File
@@ -7,15 +7,15 @@ Related dependencies are imported only when denoising functionality is needed.
import os
import tempfile
from typing import Optional, Union
from typing import Optional
import torchaudio
import torch
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
class ZipEnhancer:
"""ZipEnhancer Audio Denoising Enhancer"""
def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
"""
Initialize ZipEnhancer
@@ -23,10 +23,7 @@ class ZipEnhancer:
model_path: ModelScope model path or local path
"""
self.model_path = model_path
self._pipeline = pipeline(
Tasks.acoustic_noise_suppression,
model=self.model_path
)
self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)
def _normalize_loudness(self, wav_path: str):
"""
@@ -37,11 +34,10 @@ class ZipEnhancer:
"""
audio, sr = torchaudio.load(wav_path)
loudness = torchaudio.functional.loudness(audio, sr)
normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
torchaudio.save(wav_path, normalized_audio, sr)
def enhance(self, input_path: str, output_path: Optional[str] = None,
normalize_loudness: bool = True) -> str:
def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
"""
Audio denoising enhancement
Args:
@@ -57,7 +53,7 @@ class ZipEnhancer:
raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
# Create temporary file if no output path is specified
if output_path is None:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
output_path = tmp_file.name
try:
# Perform denoising processing
Generated
+5263
View File
File diff suppressed because it is too large Load Diff