update voxcpm2
This commit is contained in:
@@ -2,14 +2,15 @@ import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
import spaces
|
||||
import gradio as gr
|
||||
import spaces # noqa: F401
|
||||
from typing import Optional, Tuple
|
||||
from funasr import AutoModel
|
||||
from pathlib import Path
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
|
||||
|
||||
import voxcpm
|
||||
|
||||
@@ -24,13 +25,13 @@ class VoxCPMDemo:
|
||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||
model=self.asr_model_id,
|
||||
disable_update=True,
|
||||
log_level='DEBUG',
|
||||
log_level="DEBUG",
|
||||
device="cuda:0" if self.device == "cuda" else "cpu",
|
||||
)
|
||||
|
||||
# TTS model (lazy init)
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.default_local_model_dir = "./models/VoxCPM1.5"
|
||||
self.default_local_model_dir = "/Users/xinliu/Downloads/VoxCPM2-0.5B-newaudiovae-6hz-0316"
|
||||
|
||||
# ---------- Model helpers ----------
|
||||
def _resolve_model_dir(self) -> str:
|
||||
@@ -49,6 +50,7 @@ class VoxCPMDemo:
|
||||
if not os.path.isdir(target_dir):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
|
||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||
@@ -64,7 +66,7 @@ class VoxCPMDemo:
|
||||
print("Model not loaded, initializing...", file=sys.stderr)
|
||||
model_dir = self._resolve_model_dir()
|
||||
print(f"Using model dir: {model_dir}", file=sys.stderr)
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=False)
|
||||
print("Model loaded successfully.", file=sys.stderr)
|
||||
return self.voxcpm_model
|
||||
|
||||
@@ -73,21 +75,24 @@ class VoxCPMDemo:
|
||||
if prompt_wav is None:
|
||||
return ""
|
||||
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
||||
text = res[0]["text"].split('|>')[-1]
|
||||
text = res[0]["text"].split("|>")[-1]
|
||||
return text
|
||||
|
||||
def generate_tts_audio(
|
||||
self,
|
||||
text_input: str,
|
||||
prompt_wav_path_input: Optional[str] = None,
|
||||
prompt_text_input: Optional[str] = None,
|
||||
control_instruction: str = "",
|
||||
reference_wav_path_input: Optional[str] = None,
|
||||
cfg_value_input: float = 2.0,
|
||||
inference_timesteps_input: int = 10,
|
||||
do_normalize: bool = True,
|
||||
denoise: bool = True,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
"""
|
||||
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
|
||||
Generate speech from text using VoxCPM.
|
||||
- If reference_wav provided: Prompt isolation mode (voice cloning)
|
||||
- If no reference_wav: Voice design mode (use control_instruction to describe voice)
|
||||
|
||||
Returns (sample_rate, waveform_numpy)
|
||||
"""
|
||||
current_model = self.get_or_load_voxcpm()
|
||||
@@ -96,14 +101,25 @@ class VoxCPMDemo:
|
||||
if len(text) == 0:
|
||||
raise ValueError("Please input text to synthesize.")
|
||||
|
||||
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
|
||||
prompt_text = prompt_text_input if prompt_text_input else None
|
||||
# 处理 control instruction
|
||||
control = (control_instruction or "").strip()
|
||||
if control:
|
||||
final_text = f"({control}){text}"
|
||||
else:
|
||||
final_text = text
|
||||
|
||||
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
|
||||
reference_wav_path = reference_wav_path_input if reference_wav_path_input else None
|
||||
|
||||
# 判断模式
|
||||
if reference_wav_path:
|
||||
print(f"[Prompt Isolation Mode] reference_wav: {reference_wav_path}", file=sys.stderr)
|
||||
else:
|
||||
print(f"[Voice Design Mode] control: {control[:50] if control else 'None'}...", file=sys.stderr)
|
||||
|
||||
print(f"Generating audio for text: '{final_text[:80]}...'", file=sys.stderr)
|
||||
wav = current_model.generate(
|
||||
text=text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
text=final_text,
|
||||
reference_wav_path=reference_wav_path,
|
||||
cfg_value=float(cfg_value_input),
|
||||
inference_timesteps=int(inference_timesteps_input),
|
||||
normalize=do_normalize,
|
||||
@@ -114,46 +130,53 @@ class VoxCPMDemo:
|
||||
|
||||
# ---------- UI Builders ----------
|
||||
|
||||
THEME = gr.themes.Soft(
|
||||
primary_hue="blue",
|
||||
secondary_hue="gray",
|
||||
neutral_hue="slate",
|
||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
|
||||
)
|
||||
|
||||
CSS = """
|
||||
.logo-container {
|
||||
text-align: center;
|
||||
margin: 0.5rem 0 1rem 0;
|
||||
}
|
||||
.logo-container img {
|
||||
height: 80px;
|
||||
width: auto;
|
||||
max-width: 200px;
|
||||
display: inline-block;
|
||||
}
|
||||
/* Bold accordion labels */
|
||||
#acc_quick > .label-wrap,
|
||||
#acc_tips > .label-wrap,
|
||||
#acc_quick > .label-wrap > span,
|
||||
#acc_tips > .label-wrap > span,
|
||||
#acc_quick summary,
|
||||
#acc_tips summary {
|
||||
font-weight: 600 !important;
|
||||
font-size: 1.1em !important;
|
||||
}
|
||||
/* Bold labels for specific checkboxes */
|
||||
#chk_denoise label,
|
||||
#chk_denoise span,
|
||||
#chk_normalize label,
|
||||
#chk_normalize span {
|
||||
font-weight: 600;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def create_demo_interface(demo: VoxCPMDemo):
|
||||
"""Build the Gradio UI for VoxCPM demo."""
|
||||
# static assets (logo path)
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
|
||||
|
||||
with gr.Blocks(
|
||||
theme=gr.themes.Soft(
|
||||
primary_hue="blue",
|
||||
secondary_hue="gray",
|
||||
neutral_hue="slate",
|
||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
|
||||
),
|
||||
css="""
|
||||
.logo-container {
|
||||
text-align: center;
|
||||
margin: 0.5rem 0 1rem 0;
|
||||
}
|
||||
.logo-container img {
|
||||
height: 80px;
|
||||
width: auto;
|
||||
max-width: 200px;
|
||||
display: inline-block;
|
||||
}
|
||||
/* Bold accordion labels */
|
||||
#acc_quick details > summary,
|
||||
#acc_tips details > summary {
|
||||
font-weight: 600 !important;
|
||||
font-size: 1.1em !important;
|
||||
}
|
||||
/* Bold labels for specific checkboxes */
|
||||
#chk_denoise label,
|
||||
#chk_denoise span,
|
||||
#chk_normalize label,
|
||||
#chk_normalize span {
|
||||
font-weight: 600;
|
||||
}
|
||||
"""
|
||||
) as interface:
|
||||
# Header logo
|
||||
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
|
||||
with gr.Blocks() as interface:
|
||||
gr.HTML(
|
||||
'<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>',
|
||||
padding=True,
|
||||
)
|
||||
|
||||
# Quick Start
|
||||
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
|
||||
@@ -200,34 +223,56 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
# Main controls
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
prompt_wav = gr.Audio(
|
||||
sources=["upload", 'microphone'],
|
||||
# 1. Reference Audio
|
||||
# gr.Markdown("### 🎤 Reference Audio (Optional)")
|
||||
# gr.Markdown("*提供参考音频进行音色克隆;不提供则使用 Voice Design 模式*")
|
||||
reference_wav = gr.Audio(
|
||||
sources=["upload", "microphone"],
|
||||
type="filepath",
|
||||
label="Prompt Speech (Optional, or let VoxCPM improvise)",
|
||||
value="./examples/example.wav",
|
||||
label="Reference Audio (Optional)",
|
||||
)
|
||||
DoDenoisePromptAudio = gr.Checkbox(
|
||||
value=False,
|
||||
label="Prompt Speech Enhancement",
|
||||
label="Reference Audio Enhancement",
|
||||
elem_id="chk_denoise",
|
||||
info="We use ZipEnhancer model to denoise the prompt audio."
|
||||
info="Use ZipEnhancer to denoise the reference audio",
|
||||
)
|
||||
with gr.Row():
|
||||
prompt_text = gr.Textbox(
|
||||
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
|
||||
label="Prompt Text",
|
||||
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
|
||||
)
|
||||
run_btn = gr.Button("Generate Speech", variant="primary")
|
||||
|
||||
# 2. Control Instruction
|
||||
# gr.Markdown("### 🎛️ Control Instruction (Optional)")
|
||||
# gr.Markdown("*描述声音风格、情感等,格式:`(instruction) text`*")
|
||||
control_instruction = gr.Textbox(
|
||||
value="",
|
||||
label="Control Instruction",
|
||||
placeholder="*描述声音风格、情感等,格式:`(instruction) text`,例如:年轻女性,温柔甜美 / 悲伤地说 / an excited young man*",
|
||||
lines=2,
|
||||
)
|
||||
|
||||
# 3. Target Text
|
||||
# gr.Markdown("### 📝 Target Text")
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
lines=3,
|
||||
)
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="Use wetext library to normalize the input text",
|
||||
)
|
||||
|
||||
run_btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg")
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("### ⚙️ Generation Settings")
|
||||
cfg_value = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=3.0,
|
||||
value=2.0,
|
||||
step=0.1,
|
||||
label="CFG Value (Guidance Scale)",
|
||||
info="Higher values increase adherence to prompt, lower values allow more creativity"
|
||||
info="Higher = more adherence to prompt; Lower = more creativity",
|
||||
)
|
||||
inference_timesteps = gr.Slider(
|
||||
minimum=4,
|
||||
@@ -235,41 +280,55 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
value=10,
|
||||
step=1,
|
||||
label="Inference Timesteps",
|
||||
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
|
||||
info="Higher = better quality but slower",
|
||||
)
|
||||
with gr.Row():
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
)
|
||||
with gr.Row():
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="We use wetext library to normalize the input text."
|
||||
)
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
gr.Markdown("### 🔈 Output")
|
||||
audio_output = gr.Audio(label="Generated Audio")
|
||||
|
||||
gr.Markdown("""
|
||||
---
|
||||
**模式说明 / Mode Info:**
|
||||
- **有 Reference Audio** → Prompt 隔离模式(音色克隆)
|
||||
- **无 Reference Audio** → Voice Design 模式(用 Control Instruction 描述声音)
|
||||
|
||||
**Control Instruction 示例:**
|
||||
- `年轻女性,温柔甜美`
|
||||
- `悲伤地说`
|
||||
- `an excited young man`
|
||||
""")
|
||||
|
||||
# Wiring
|
||||
run_btn.click(
|
||||
fn=demo.generate_tts_audio,
|
||||
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
|
||||
inputs=[
|
||||
text,
|
||||
control_instruction,
|
||||
reference_wav,
|
||||
cfg_value,
|
||||
inference_timesteps,
|
||||
DoNormalizeText,
|
||||
DoDenoisePromptAudio,
|
||||
],
|
||||
outputs=[audio_output],
|
||||
show_progress=True,
|
||||
api_name="generate",
|
||||
)
|
||||
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
|
||||
def run_demo(server_name: str = "0.0.0.0", server_port: int = 7869, show_error: bool = True):
|
||||
demo = VoxCPMDemo()
|
||||
interface = create_demo_interface(demo)
|
||||
# Recommended to enable queue on Spaces for better throughput
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
||||
server_name=server_name,
|
||||
server_port=server_port,
|
||||
show_error=show_error,
|
||||
theme=THEME,
|
||||
css=CSS,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_demo()
|
||||
run_demo()
|
||||
|
||||
@@ -25,8 +25,8 @@ lora:
|
||||
enable_lm: true
|
||||
enable_dit: true
|
||||
enable_proj: false
|
||||
r: 32
|
||||
alpha: 16
|
||||
r: 8
|
||||
alpha: 16
|
||||
dropout: 0.0
|
||||
|
||||
# Distribution options (optional)
|
||||
|
||||
@@ -25,7 +25,7 @@ lora:
|
||||
enable_lm: true
|
||||
enable_dit: true
|
||||
enable_proj: false
|
||||
r: 32
|
||||
r: 8
|
||||
alpha: 16
|
||||
dropout: 0.0
|
||||
|
||||
|
||||
Binary file not shown.
+247
-260
@@ -1,18 +1,14 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import glob
|
||||
import json
|
||||
import yaml
|
||||
import shutil
|
||||
import datetime
|
||||
import subprocess
|
||||
import threading
|
||||
import gradio as gr
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
|
||||
# Add src to sys.path
|
||||
project_root = Path(__file__).parent
|
||||
@@ -89,7 +85,7 @@ LANG_DICT = {
|
||||
"lang_select": "Language / 语言",
|
||||
"refresh": "刷新",
|
||||
"output_name": "输出目录名称 (可选,若存在则继续训练)",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Global variables
|
||||
@@ -98,9 +94,11 @@ asr_model: Optional[AutoModel] = None
|
||||
training_process: Optional[subprocess.Popen] = None
|
||||
training_log = ""
|
||||
|
||||
|
||||
def get_timestamp_str():
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
def get_or_load_asr_model():
|
||||
global asr_model
|
||||
if asr_model is None:
|
||||
@@ -109,44 +107,46 @@ def get_or_load_asr_model():
|
||||
asr_model = AutoModel(
|
||||
model="iic/SenseVoiceSmall",
|
||||
disable_update=True,
|
||||
log_level='ERROR',
|
||||
log_level="ERROR",
|
||||
device=device,
|
||||
)
|
||||
return asr_model
|
||||
|
||||
|
||||
def recognize_audio(audio_path):
|
||||
if not audio_path:
|
||||
return ""
|
||||
try:
|
||||
model = get_or_load_asr_model()
|
||||
res = model.generate(input=audio_path, language="auto", use_itn=True)
|
||||
text = res[0]["text"].split('|>')[-1]
|
||||
text = res[0]["text"].split("|>")[-1]
|
||||
return text
|
||||
except Exception as e:
|
||||
print(f"ASR Error: {e}", file=sys.stderr)
|
||||
return ""
|
||||
|
||||
|
||||
def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
||||
"""
|
||||
Scans for LoRA checkpoints in the lora directory.
|
||||
|
||||
|
||||
Args:
|
||||
root_dir: Directory to scan for LoRA checkpoints
|
||||
with_info: If True, returns list of (path, base_model) tuples
|
||||
|
||||
|
||||
Returns:
|
||||
List of checkpoint paths, or list of (path, base_model) tuples if with_info=True
|
||||
"""
|
||||
checkpoints = []
|
||||
if not os.path.exists(root_dir):
|
||||
os.makedirs(root_dir, exist_ok=True)
|
||||
|
||||
|
||||
# Look for lora_weights.safetensors recursively
|
||||
for root, dirs, files in os.walk(root_dir):
|
||||
if "lora_weights.safetensors" in files:
|
||||
# Use the relative path from root_dir as the ID
|
||||
rel_path = os.path.relpath(root, root_dir)
|
||||
|
||||
|
||||
if with_info:
|
||||
# Try to read base_model from lora_config.json
|
||||
base_model = None
|
||||
@@ -161,15 +161,16 @@ def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
||||
checkpoints.append((rel_path, base_model))
|
||||
else:
|
||||
checkpoints.append(rel_path)
|
||||
|
||||
|
||||
# Also check for checkpoints in the default location if they exist
|
||||
default_ckpt = "checkpoints/finetune_lora"
|
||||
if os.path.exists(os.path.join(root_dir, default_ckpt)):
|
||||
# This might be covered by the walk, but good to be sure
|
||||
pass
|
||||
# This might be covered by the walk, but good to be sure
|
||||
pass
|
||||
|
||||
return sorted(checkpoints, reverse=True)
|
||||
|
||||
|
||||
def load_lora_config_from_checkpoint(lora_path):
|
||||
"""Load LoRA config from lora_config.json if available."""
|
||||
lora_config_file = os.path.join(lora_path, "lora_config.json")
|
||||
@@ -184,6 +185,7 @@ def load_lora_config_from_checkpoint(lora_path):
|
||||
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
|
||||
return None, None
|
||||
|
||||
|
||||
def get_default_lora_config():
|
||||
"""Return default LoRA config for hot-swapping support."""
|
||||
return LoRAConfig(
|
||||
@@ -192,16 +194,17 @@ def get_default_lora_config():
|
||||
r=32,
|
||||
alpha=16,
|
||||
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
)
|
||||
|
||||
|
||||
def load_model(pretrained_path, lora_path=None):
|
||||
global current_model
|
||||
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
|
||||
|
||||
|
||||
lora_config = None
|
||||
lora_weights_path = None
|
||||
|
||||
|
||||
if lora_path:
|
||||
full_lora_path = os.path.join("lora", lora_path)
|
||||
if os.path.exists(full_lora_path):
|
||||
@@ -214,7 +217,7 @@ def load_model(pretrained_path, lora_path=None):
|
||||
# Fallback to default config for old checkpoints
|
||||
lora_config = get_default_lora_config()
|
||||
print("Using default LoRA config (lora_config.json not found)", file=sys.stderr)
|
||||
|
||||
|
||||
# Always init with a default LoRA config to allow hot-swapping later
|
||||
if lora_config is None:
|
||||
lora_config = get_default_lora_config()
|
||||
@@ -228,25 +231,24 @@ def load_model(pretrained_path, lora_path=None):
|
||||
)
|
||||
return "Model loaded successfully!"
|
||||
|
||||
|
||||
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
|
||||
global current_model
|
||||
|
||||
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
|
||||
if current_model is None:
|
||||
# 优先使用用户指定的预训练模型路径
|
||||
base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path
|
||||
|
||||
|
||||
# 如果选择了 LoRA,尝试从其 config 读取 base_model
|
||||
if lora_selection and lora_selection != "None":
|
||||
full_lora_path = os.path.join("lora", lora_selection)
|
||||
lora_config_file = os.path.join(full_lora_path, "lora_config.json")
|
||||
|
||||
|
||||
if os.path.exists(lora_config_file):
|
||||
try:
|
||||
with open(lora_config_file, "r", encoding="utf-8") as f:
|
||||
lora_info = json.load(f)
|
||||
saved_base_model = lora_info.get("base_model")
|
||||
|
||||
|
||||
if saved_base_model:
|
||||
# 优先使用保存的 base_model 路径
|
||||
if os.path.exists(saved_base_model):
|
||||
@@ -257,11 +259,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
print(f"Falling back to default: {base_model_path}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
||||
|
||||
|
||||
# 加载模型
|
||||
try:
|
||||
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||
status_msg = load_model(base_model_path)
|
||||
load_model(base_model_path)
|
||||
if lora_selection and lora_selection != "None":
|
||||
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
@@ -270,6 +272,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
return None, error_msg
|
||||
|
||||
# Handle LoRA hot-swapping
|
||||
assert current_model is not None, "Model must be loaded before inference"
|
||||
if lora_selection and lora_selection != "None":
|
||||
full_lora_path = os.path.join("lora", lora_selection)
|
||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||
@@ -290,11 +293,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
# 处理 prompt 参数:必须同时为 None 或同时有值
|
||||
final_prompt_wav = None
|
||||
final_prompt_text = None
|
||||
|
||||
|
||||
if prompt_wav and prompt_wav.strip():
|
||||
# 有参考音频
|
||||
final_prompt_wav = prompt_wav
|
||||
|
||||
|
||||
# 如果没有提供参考文本,尝试自动识别
|
||||
if not prompt_text or not prompt_text.strip():
|
||||
print("参考音频已提供但缺少文本,自动识别中...", file=sys.stderr)
|
||||
@@ -317,14 +320,16 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
prompt_text=final_prompt_text,
|
||||
cfg_value=cfg_scale,
|
||||
inference_timesteps=steps,
|
||||
denoise=False
|
||||
denoise=False,
|
||||
)
|
||||
return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None, f"Error: {str(e)}"
|
||||
|
||||
|
||||
def start_training(
|
||||
pretrained_path,
|
||||
train_manifest,
|
||||
@@ -355,8 +360,8 @@ def start_training(
|
||||
hf_model_id="",
|
||||
distribute=False,
|
||||
):
|
||||
global training_process, training_log
|
||||
|
||||
global training_log
|
||||
|
||||
if training_process is not None and training_process.poll() is None:
|
||||
return "Training is already running!"
|
||||
|
||||
@@ -368,7 +373,7 @@ def start_training(
|
||||
save_dir = os.path.join("lora", timestamp)
|
||||
checkpoints_dir = os.path.join(save_dir, "checkpoints")
|
||||
logs_dir = os.path.join(save_dir, "logs")
|
||||
|
||||
|
||||
os.makedirs(checkpoints_dir, exist_ok=True)
|
||||
os.makedirs(logs_dir, exist_ok=True)
|
||||
|
||||
@@ -394,10 +399,7 @@ def start_training(
|
||||
"max_steps": resolved_max_steps,
|
||||
"save_path": checkpoints_dir,
|
||||
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
||||
"lambdas": {
|
||||
"loss/diff": 1.0,
|
||||
"loss/stop": 1.0
|
||||
},
|
||||
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||
"lora": {
|
||||
"enable_lm": bool(enable_lm),
|
||||
"enable_dit": bool(enable_dit),
|
||||
@@ -406,10 +408,10 @@ def start_training(
|
||||
"alpha": int(lora_alpha),
|
||||
"dropout": float(dropout),
|
||||
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Add distribution options if provided
|
||||
if hf_model_id and hf_model_id.strip():
|
||||
config["hf_model_id"] = hf_model_id.strip()
|
||||
@@ -420,49 +422,42 @@ def start_training(
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(config, f)
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"scripts/train_voxcpm_finetune.py",
|
||||
"--config_path",
|
||||
config_path
|
||||
]
|
||||
cmd = [sys.executable, "scripts/train_voxcpm_finetune.py", "--config_path", config_path]
|
||||
|
||||
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
|
||||
|
||||
|
||||
def run_process():
|
||||
global training_process, training_log
|
||||
training_process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1
|
||||
)
|
||||
|
||||
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
|
||||
|
||||
assert training_process.stdout is not None
|
||||
for line in training_process.stdout:
|
||||
training_log += line
|
||||
# Keep log size manageable
|
||||
if len(training_log) > 100000:
|
||||
training_log = training_log[-100000:]
|
||||
|
||||
|
||||
training_process.wait()
|
||||
training_log += f"\nTraining finished with code {training_process.returncode}"
|
||||
|
||||
threading.Thread(target=run_process, daemon=True).start()
|
||||
|
||||
|
||||
return f"Training started! Check 'lora/{timestamp}'"
|
||||
|
||||
|
||||
def get_training_log():
|
||||
return training_log
|
||||
|
||||
|
||||
def stop_training():
|
||||
global training_process, training_log
|
||||
global training_log
|
||||
if training_process is not None and training_process.poll() is None:
|
||||
training_process.terminate()
|
||||
training_log += "\nTraining terminated by user."
|
||||
return "Training stopped."
|
||||
return "No training running."
|
||||
|
||||
|
||||
# --- GUI Layout ---
|
||||
|
||||
# 自定义CSS样式
|
||||
@@ -830,14 +825,10 @@ label {
|
||||
}
|
||||
"""
|
||||
|
||||
with gr.Blocks(
|
||||
title="VoxCPM LoRA WebUI",
|
||||
theme=gr.themes.Soft(),
|
||||
css=custom_css
|
||||
) as app:
|
||||
|
||||
with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css) as app:
|
||||
|
||||
# State for language
|
||||
lang_state = gr.State("zh") # Default to Chinese
|
||||
lang_state = gr.State("zh") # Default to Chinese
|
||||
|
||||
# 标题区域
|
||||
with gr.Row(elem_classes="title-section"):
|
||||
@@ -850,10 +841,7 @@ with gr.Blocks(
|
||||
""")
|
||||
with gr.Column(scale=1):
|
||||
lang_btn = gr.Radio(
|
||||
choices=["en", "zh"],
|
||||
value="zh",
|
||||
label="🌐 Language / 语言",
|
||||
elem_classes="lang-selector"
|
||||
choices=["en", "zh"], value="zh", label="🌐 Language / 语言", elem_classes="lang-selector"
|
||||
)
|
||||
|
||||
with gr.Tabs(elem_classes="tabs") as tabs:
|
||||
@@ -869,79 +857,40 @@ with gr.Blocks(
|
||||
gr.Markdown("#### 📁 基础配置")
|
||||
|
||||
train_pretrained_path = gr.Textbox(
|
||||
label="📂 预训练模型路径",
|
||||
value=default_pretrained_path,
|
||||
elem_classes="input-field"
|
||||
label="📂 预训练模型路径", value=default_pretrained_path, elem_classes="input-field"
|
||||
)
|
||||
train_manifest = gr.Textbox(
|
||||
label="📋 训练数据清单 (jsonl)",
|
||||
value="examples/train_data_example.jsonl",
|
||||
elem_classes="input-field"
|
||||
)
|
||||
val_manifest = gr.Textbox(
|
||||
label="📊 验证数据清单 (可选)",
|
||||
value="",
|
||||
elem_classes="input-field"
|
||||
elem_classes="input-field",
|
||||
)
|
||||
val_manifest = gr.Textbox(label="📊 验证数据清单 (可选)", value="", elem_classes="input-field")
|
||||
|
||||
gr.Markdown("#### ⚙️ 训练参数")
|
||||
|
||||
with gr.Row():
|
||||
lr = gr.Number(
|
||||
label="📈 学习率 (Learning Rate)",
|
||||
value=1e-4,
|
||||
elem_classes="input-field"
|
||||
)
|
||||
lr = gr.Number(label="📈 学习率 (Learning Rate)", value=1e-4, elem_classes="input-field")
|
||||
num_iters = gr.Number(
|
||||
label="🔄 最大迭代次数",
|
||||
value=2000,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
label="🔄 最大迭代次数", value=2000, precision=0, elem_classes="input-field"
|
||||
)
|
||||
batch_size = gr.Number(
|
||||
label="📦 批次大小 (Batch Size)",
|
||||
value=1,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
label="📦 批次大小 (Batch Size)", value=1, precision=0, elem_classes="input-field"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
lora_rank = gr.Number(
|
||||
label="🎯 LoRA Rank",
|
||||
value=32,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
)
|
||||
lora_alpha = gr.Number(
|
||||
label="⚖️ LoRA Alpha",
|
||||
value=16,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
)
|
||||
lora_rank = gr.Number(label="🎯 LoRA Rank", value=32, precision=0, elem_classes="input-field")
|
||||
lora_alpha = gr.Number(label="⚖️ LoRA Alpha", value=16, precision=0, elem_classes="input-field")
|
||||
save_interval = gr.Number(
|
||||
label="💾 保存间隔 (Steps)",
|
||||
value=1000,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
label="💾 保存间隔 (Steps)", value=1000, precision=0, elem_classes="input-field"
|
||||
)
|
||||
|
||||
output_name = gr.Textbox(
|
||||
label="📁 输出目录名称 (可选,若存在则继续训练)",
|
||||
value="",
|
||||
elem_classes="input-field"
|
||||
label="📁 输出目录名称 (可选,若存在则继续训练)", value="", elem_classes="input-field"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button(
|
||||
"▶️ 开始训练",
|
||||
variant="primary",
|
||||
elem_classes="button-primary"
|
||||
)
|
||||
stop_btn = gr.Button(
|
||||
"⏹️ 停止训练",
|
||||
variant="stop",
|
||||
elem_classes="button-stop"
|
||||
)
|
||||
start_btn = gr.Button("▶️ 开始训练", variant="primary", elem_classes="button-primary")
|
||||
stop_btn = gr.Button("⏹️ 停止训练", variant="stop", elem_classes="button-stop")
|
||||
|
||||
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
|
||||
with gr.Row():
|
||||
@@ -961,10 +910,12 @@ with gr.Blocks(
|
||||
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
|
||||
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
|
||||
dropout = gr.Number(label="LoRA Dropout", value=0.0)
|
||||
|
||||
|
||||
gr.Markdown("#### 分发选项 (Distribution)")
|
||||
with gr.Row():
|
||||
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5")
|
||||
hf_model_id = gr.Textbox(
|
||||
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
|
||||
)
|
||||
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
||||
|
||||
with gr.Column(scale=2, elem_classes="form-section"):
|
||||
@@ -975,26 +926,44 @@ with gr.Blocks(
|
||||
max_lines=30,
|
||||
interactive=False,
|
||||
elem_classes="input-field",
|
||||
show_label=False
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
|
||||
start_btn.click(
|
||||
start_training,
|
||||
inputs=[
|
||||
train_pretrained_path, train_manifest, val_manifest,
|
||||
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
|
||||
train_pretrained_path,
|
||||
train_manifest,
|
||||
val_manifest,
|
||||
lr,
|
||||
num_iters,
|
||||
batch_size,
|
||||
lora_rank,
|
||||
lora_alpha,
|
||||
save_interval,
|
||||
output_name,
|
||||
# advanced
|
||||
grad_accum_steps, num_workers, log_interval, valid_interval,
|
||||
weight_decay, warmup_steps, max_steps, sample_rate,
|
||||
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
|
||||
grad_accum_steps,
|
||||
num_workers,
|
||||
log_interval,
|
||||
valid_interval,
|
||||
weight_decay,
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
dropout,
|
||||
tensorboard_path,
|
||||
# distribution
|
||||
hf_model_id, distribute
|
||||
hf_model_id,
|
||||
distribute,
|
||||
],
|
||||
outputs=[logs_out] # Initial message
|
||||
outputs=[logs_out], # Initial message
|
||||
)
|
||||
stop_btn.click(stop_training, outputs=[logs_out])
|
||||
|
||||
|
||||
# Log refresher
|
||||
timer = gr.Timer(1)
|
||||
timer.tick(get_training_log, outputs=logs_out)
|
||||
@@ -1016,21 +985,17 @@ with gr.Blocks(
|
||||
value="Hello, this is a test of the VoxCPM LoRA model.",
|
||||
elem_classes="input-field",
|
||||
lines=4,
|
||||
placeholder="输入要合成的文本内容..."
|
||||
placeholder="输入要合成的文本内容...",
|
||||
)
|
||||
|
||||
gr.Markdown("**🎭 声音克隆(可选)**")
|
||||
|
||||
prompt_wav = gr.Audio(
|
||||
label="🎵 参考音频",
|
||||
type="filepath",
|
||||
elem_classes="input-field"
|
||||
)
|
||||
|
||||
|
||||
prompt_wav = gr.Audio(label="🎵 参考音频", type="filepath", elem_classes="input-field")
|
||||
|
||||
prompt_text = gr.Textbox(
|
||||
label="📝 参考文本(可选)",
|
||||
elem_classes="input-field",
|
||||
placeholder="如不填写,将自动识别参考音频内容"
|
||||
placeholder="如不填写,将自动识别参考音频内容",
|
||||
)
|
||||
|
||||
# 中栏:模型选择和参数配置 (35%)
|
||||
@@ -1043,15 +1008,11 @@ with gr.Blocks(
|
||||
value="None",
|
||||
interactive=True,
|
||||
elem_classes="input-field",
|
||||
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型"
|
||||
)
|
||||
|
||||
refresh_lora_btn = gr.Button(
|
||||
"🔄 刷新模型列表",
|
||||
elem_classes="button-refresh",
|
||||
size="sm"
|
||||
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型",
|
||||
)
|
||||
|
||||
refresh_lora_btn = gr.Button("🔄 刷新模型列表", elem_classes="button-refresh", size="sm")
|
||||
|
||||
gr.Markdown("#### ⚙️ 生成参数")
|
||||
|
||||
cfg_scale = gr.Slider(
|
||||
@@ -1060,59 +1021,50 @@ with gr.Blocks(
|
||||
maximum=5.0,
|
||||
value=2.0,
|
||||
step=0.1,
|
||||
info="引导系数,值越大越贴近提示"
|
||||
info="引导系数,值越大越贴近提示",
|
||||
)
|
||||
|
||||
|
||||
steps = gr.Slider(
|
||||
label="🔢 推理步数",
|
||||
minimum=1,
|
||||
maximum=50,
|
||||
value=10,
|
||||
step=1,
|
||||
info="生成质量与步数成正比,但耗时更长"
|
||||
info="生成质量与步数成正比,但耗时更长",
|
||||
)
|
||||
|
||||
|
||||
seed = gr.Number(
|
||||
label="🎲 随机种子",
|
||||
value=-1,
|
||||
precision=0,
|
||||
elem_classes="input-field",
|
||||
info="-1 为随机,固定值可复现结果"
|
||||
info="-1 为随机,固定值可复现结果",
|
||||
)
|
||||
|
||||
generate_btn = gr.Button(
|
||||
"🎵 生成音频",
|
||||
variant="primary",
|
||||
elem_classes="button-primary",
|
||||
size="lg"
|
||||
)
|
||||
generate_btn = gr.Button("🎵 生成音频", variant="primary", elem_classes="button-primary", size="lg")
|
||||
|
||||
# 右栏:生成结果 (30%)
|
||||
with gr.Column(scale=30, elem_classes="form-section"):
|
||||
gr.Markdown("#### 🎧 生成结果")
|
||||
|
||||
audio_out = gr.Audio(
|
||||
label="",
|
||||
elem_classes="input-field",
|
||||
show_label=False
|
||||
)
|
||||
|
||||
audio_out = gr.Audio(label="", elem_classes="input-field", show_label=False)
|
||||
|
||||
gr.Markdown("#### 📋 状态信息")
|
||||
|
||||
|
||||
status_out = gr.Textbox(
|
||||
label="",
|
||||
interactive=False,
|
||||
elem_classes="input-field",
|
||||
show_label=False,
|
||||
lines=3,
|
||||
placeholder="等待生成..."
|
||||
placeholder="等待生成...",
|
||||
)
|
||||
|
||||
def refresh_loras():
|
||||
# 获取 LoRA checkpoints 及其 base model 信息
|
||||
checkpoints_with_info = scan_lora_checkpoints(with_info=True)
|
||||
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
|
||||
|
||||
|
||||
# 输出调试信息
|
||||
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点", file=sys.stderr)
|
||||
for ckpt_path, base_model in checkpoints_with_info:
|
||||
@@ -1120,22 +1072,27 @@ with gr.Blocks(
|
||||
print(f" - {ckpt_path} (Base Model: {base_model})", file=sys.stderr)
|
||||
else:
|
||||
print(f" - {ckpt_path}", file=sys.stderr)
|
||||
|
||||
|
||||
return gr.update(choices=choices, value="None")
|
||||
|
||||
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
|
||||
|
||||
|
||||
# Auto-recognize audio when uploaded
|
||||
prompt_wav.change(
|
||||
fn=recognize_audio,
|
||||
inputs=[prompt_wav],
|
||||
outputs=[prompt_text]
|
||||
)
|
||||
|
||||
prompt_wav.change(fn=recognize_audio, inputs=[prompt_wav], outputs=[prompt_text])
|
||||
|
||||
generate_btn.click(
|
||||
run_inference,
|
||||
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
|
||||
outputs=[audio_out, status_out]
|
||||
inputs=[
|
||||
infer_text,
|
||||
prompt_wav,
|
||||
prompt_text,
|
||||
lora_select,
|
||||
cfg_scale,
|
||||
steps,
|
||||
seed,
|
||||
train_pretrained_path,
|
||||
],
|
||||
outputs=[audio_out, status_out],
|
||||
)
|
||||
|
||||
# --- Language Switching Logic ---
|
||||
@@ -1144,111 +1101,141 @@ with gr.Blocks(
|
||||
# Labels for advanced options
|
||||
if lang == "zh":
|
||||
adv = {
|
||||
'grad_accum_steps': "梯度累积 (grad_accum_steps)",
|
||||
'num_workers': "数据加载线程 (num_workers)",
|
||||
'log_interval': "日志间隔 (log_interval)",
|
||||
'valid_interval': "验证间隔 (valid_interval)",
|
||||
'weight_decay': "权重衰减 (weight_decay)",
|
||||
'warmup_steps': "warmup_steps",
|
||||
'max_steps': "最大步数 (max_steps)",
|
||||
'sample_rate': "采样率 (sample_rate)",
|
||||
'enable_lm': "启用 LoRA LM (enable_lm)",
|
||||
'enable_dit': "启用 LoRA DIT (enable_dit)",
|
||||
'enable_proj': "启用投影 (enable_proj)",
|
||||
'dropout': "LoRA Dropout",
|
||||
'tensorboard_path': "Tensorboard 路径 (可选)",
|
||||
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
'distribute': "分发模式 (distribute)",
|
||||
"grad_accum_steps": "梯度累积 (grad_accum_steps)",
|
||||
"num_workers": "数据加载线程 (num_workers)",
|
||||
"log_interval": "日志间隔 (log_interval)",
|
||||
"valid_interval": "验证间隔 (valid_interval)",
|
||||
"weight_decay": "权重衰减 (weight_decay)",
|
||||
"warmup_steps": "warmup_steps",
|
||||
"max_steps": "最大步数 (max_steps)",
|
||||
"sample_rate": "采样率 (sample_rate)",
|
||||
"enable_lm": "启用 LoRA LM (enable_lm)",
|
||||
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
||||
"enable_proj": "启用投影 (enable_proj)",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard 路径 (可选)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"distribute": "分发模式 (distribute)",
|
||||
}
|
||||
else:
|
||||
adv = {
|
||||
'grad_accum_steps': "Grad Accum Steps",
|
||||
'num_workers': "Num Workers",
|
||||
'log_interval': "Log Interval",
|
||||
'valid_interval': "Valid Interval",
|
||||
'weight_decay': "Weight Decay",
|
||||
'warmup_steps': "Warmup Steps",
|
||||
'max_steps': "Max Steps",
|
||||
'sample_rate': "Sample Rate",
|
||||
'enable_lm': "Enable LoRA LM",
|
||||
'enable_dit': "Enable LoRA DIT",
|
||||
'enable_proj': "Enable Projection",
|
||||
'dropout': "LoRA Dropout",
|
||||
'tensorboard_path': "Tensorboard Path (Optional)",
|
||||
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
'distribute': "Distribute Mode",
|
||||
"grad_accum_steps": "Grad Accum Steps",
|
||||
"num_workers": "Num Workers",
|
||||
"log_interval": "Log Interval",
|
||||
"valid_interval": "Valid Interval",
|
||||
"weight_decay": "Weight Decay",
|
||||
"warmup_steps": "Warmup Steps",
|
||||
"max_steps": "Max Steps",
|
||||
"sample_rate": "Sample Rate",
|
||||
"enable_lm": "Enable LoRA LM",
|
||||
"enable_dit": "Enable LoRA DIT",
|
||||
"enable_proj": "Enable Projection",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard Path (Optional)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"distribute": "Distribute Mode",
|
||||
}
|
||||
|
||||
return (
|
||||
gr.update(value=f"# {d['title']}"),
|
||||
gr.update(label=d['tab_train']),
|
||||
gr.update(label=d['tab_infer']),
|
||||
gr.update(label=d['pretrained_path']),
|
||||
gr.update(label=d['train_manifest']),
|
||||
gr.update(label=d['val_manifest']),
|
||||
gr.update(label=d['lr']),
|
||||
gr.update(label=d['max_iters']),
|
||||
gr.update(label=d['batch_size']),
|
||||
gr.update(label=d['lora_rank']),
|
||||
gr.update(label=d['lora_alpha']),
|
||||
gr.update(label=d['save_interval']),
|
||||
gr.update(label=d['output_name']),
|
||||
gr.update(value=d['start_train']),
|
||||
gr.update(value=d['stop_train']),
|
||||
gr.update(label=d['train_logs']),
|
||||
gr.update(label=d["tab_train"]),
|
||||
gr.update(label=d["tab_infer"]),
|
||||
gr.update(label=d["pretrained_path"]),
|
||||
gr.update(label=d["train_manifest"]),
|
||||
gr.update(label=d["val_manifest"]),
|
||||
gr.update(label=d["lr"]),
|
||||
gr.update(label=d["max_iters"]),
|
||||
gr.update(label=d["batch_size"]),
|
||||
gr.update(label=d["lora_rank"]),
|
||||
gr.update(label=d["lora_alpha"]),
|
||||
gr.update(label=d["save_interval"]),
|
||||
gr.update(label=d["output_name"]),
|
||||
gr.update(value=d["start_train"]),
|
||||
gr.update(value=d["stop_train"]),
|
||||
gr.update(label=d["train_logs"]),
|
||||
# Advanced options (must match outputs order)
|
||||
gr.update(label=adv['grad_accum_steps']),
|
||||
gr.update(label=adv['num_workers']),
|
||||
gr.update(label=adv['log_interval']),
|
||||
gr.update(label=adv['valid_interval']),
|
||||
gr.update(label=adv['weight_decay']),
|
||||
gr.update(label=adv['warmup_steps']),
|
||||
gr.update(label=adv['max_steps']),
|
||||
gr.update(label=adv['sample_rate']),
|
||||
gr.update(label=adv['enable_lm']),
|
||||
gr.update(label=adv['enable_dit']),
|
||||
gr.update(label=adv['enable_proj']),
|
||||
gr.update(label=adv['dropout']),
|
||||
gr.update(label=adv['tensorboard_path']),
|
||||
gr.update(label=adv["grad_accum_steps"]),
|
||||
gr.update(label=adv["num_workers"]),
|
||||
gr.update(label=adv["log_interval"]),
|
||||
gr.update(label=adv["valid_interval"]),
|
||||
gr.update(label=adv["weight_decay"]),
|
||||
gr.update(label=adv["warmup_steps"]),
|
||||
gr.update(label=adv["max_steps"]),
|
||||
gr.update(label=adv["sample_rate"]),
|
||||
gr.update(label=adv["enable_lm"]),
|
||||
gr.update(label=adv["enable_dit"]),
|
||||
gr.update(label=adv["enable_proj"]),
|
||||
gr.update(label=adv["dropout"]),
|
||||
gr.update(label=adv["tensorboard_path"]),
|
||||
# Distribution options
|
||||
gr.update(label=adv['hf_model_id']),
|
||||
gr.update(label=adv['distribute']),
|
||||
gr.update(label=adv["hf_model_id"]),
|
||||
gr.update(label=adv["distribute"]),
|
||||
# Inference section
|
||||
gr.update(label=d['text_to_synth']),
|
||||
gr.update(label=d['ref_audio']),
|
||||
gr.update(label=d['ref_text']),
|
||||
gr.update(label=d['select_lora']),
|
||||
gr.update(value=d['refresh']),
|
||||
gr.update(label=d['cfg_scale']),
|
||||
gr.update(label=d['infer_steps']),
|
||||
gr.update(label=d['seed']),
|
||||
gr.update(value=d['gen_audio']),
|
||||
gr.update(label=d['gen_output']),
|
||||
gr.update(label=d['status']),
|
||||
gr.update(label=d["text_to_synth"]),
|
||||
gr.update(label=d["ref_audio"]),
|
||||
gr.update(label=d["ref_text"]),
|
||||
gr.update(label=d["select_lora"]),
|
||||
gr.update(value=d["refresh"]),
|
||||
gr.update(label=d["cfg_scale"]),
|
||||
gr.update(label=d["infer_steps"]),
|
||||
gr.update(label=d["seed"]),
|
||||
gr.update(value=d["gen_audio"]),
|
||||
gr.update(label=d["gen_output"]),
|
||||
gr.update(label=d["status"]),
|
||||
)
|
||||
|
||||
lang_btn.change(
|
||||
change_language,
|
||||
inputs=[lang_btn],
|
||||
outputs=[
|
||||
title_md, tab_train, tab_infer,
|
||||
train_pretrained_path, train_manifest, val_manifest,
|
||||
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
|
||||
title_md,
|
||||
tab_train,
|
||||
tab_infer,
|
||||
train_pretrained_path,
|
||||
train_manifest,
|
||||
val_manifest,
|
||||
lr,
|
||||
num_iters,
|
||||
batch_size,
|
||||
lora_rank,
|
||||
lora_alpha,
|
||||
save_interval,
|
||||
output_name,
|
||||
start_btn, stop_btn, logs_out,
|
||||
start_btn,
|
||||
stop_btn,
|
||||
logs_out,
|
||||
# advanced outputs
|
||||
grad_accum_steps, num_workers, log_interval, valid_interval,
|
||||
weight_decay, warmup_steps, max_steps, sample_rate,
|
||||
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
|
||||
grad_accum_steps,
|
||||
num_workers,
|
||||
log_interval,
|
||||
valid_interval,
|
||||
weight_decay,
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
dropout,
|
||||
tensorboard_path,
|
||||
# distribution outputs
|
||||
hf_model_id, distribute,
|
||||
infer_text, prompt_wav, prompt_text,
|
||||
lora_select, refresh_lora_btn, cfg_scale, steps, seed,
|
||||
generate_btn, audio_out, status_out
|
||||
]
|
||||
hf_model_id,
|
||||
distribute,
|
||||
infer_text,
|
||||
prompt_wav,
|
||||
prompt_text,
|
||||
lora_select,
|
||||
refresh_lora_btn,
|
||||
cfg_scale,
|
||||
steps,
|
||||
seed,
|
||||
generate_btn,
|
||||
audio_out,
|
||||
status_out,
|
||||
],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Ensure lora directory exists
|
||||
os.makedirs("lora", exist_ok=True)
|
||||
app.queue().launch(server_name="0.0.0.0", server_port=7860)
|
||||
app.queue().launch(server_name="0.0.0.0", server_port=7860)
|
||||
|
||||
+1
-3
@@ -30,7 +30,7 @@ dependencies = [
|
||||
"torchcodec",
|
||||
"transformers>=4.36.2",
|
||||
"einops",
|
||||
"gradio<6",
|
||||
"gradio>=6,<7",
|
||||
"inflect",
|
||||
"addict",
|
||||
"wetext",
|
||||
@@ -57,7 +57,6 @@ dev = [
|
||||
"pytest-cov>=2.0",
|
||||
"black>=21.0",
|
||||
"flake8>=3.8",
|
||||
"mypy>=0.800",
|
||||
"pre-commit>=2.0",
|
||||
]
|
||||
|
||||
@@ -90,7 +89,6 @@ extend-exclude = '''
|
||||
\.eggs
|
||||
| \.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| build
|
||||
|
||||
@@ -125,7 +125,10 @@ def main():
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
|
||||
|
||||
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -112,22 +112,24 @@ def main():
|
||||
f"lora_config.json not found in {ckpt_dir}. "
|
||||
"Make sure the checkpoint was saved with the updated training script."
|
||||
)
|
||||
|
||||
|
||||
with open(lora_config_path, "r", encoding="utf-8") as f:
|
||||
lora_info = json.load(f)
|
||||
|
||||
|
||||
# Get base model path (command line arg overrides config)
|
||||
pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
|
||||
if not pretrained_path:
|
||||
raise ValueError("base_model not found in lora_config.json and --base_model not provided")
|
||||
|
||||
|
||||
# Get LoRA config
|
||||
lora_cfg_dict = lora_info.get("lora_config", {})
|
||||
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
|
||||
|
||||
|
||||
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
|
||||
print(f" Base model: {pretrained_path}", file=sys.stderr)
|
||||
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
|
||||
print(
|
||||
f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr
|
||||
)
|
||||
|
||||
# 3. Load model with LoRA (no denoiser)
|
||||
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
|
||||
@@ -146,10 +148,10 @@ def main():
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
||||
|
||||
print("\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
||||
|
||||
# === Test 1: With LoRA ===
|
||||
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
||||
print("\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
@@ -162,10 +164,13 @@ def main():
|
||||
)
|
||||
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
||||
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
||||
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
||||
print("\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
||||
model.set_lora_enabled(False)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@@ -179,10 +184,13 @@ def main():
|
||||
)
|
||||
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
||||
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 3: Re-enable LoRA ===
|
||||
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
||||
print("\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
||||
model.set_lora_enabled(True)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@@ -196,10 +204,13 @@ def main():
|
||||
)
|
||||
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
||||
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 4: Unload LoRA (reset_lora_weights) ===
|
||||
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
||||
print("\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
||||
model.unload_lora()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@@ -213,10 +224,13 @@ def main():
|
||||
)
|
||||
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
||||
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 5: Hot-reload LoRA (load_lora) ===
|
||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
||||
print("\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
||||
loaded, skipped = model.load_lora(ckpt_dir)
|
||||
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
|
||||
audio_np = model.generate(
|
||||
@@ -231,9 +245,12 @@ def main():
|
||||
)
|
||||
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
||||
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
print(f"\n[Done] All tests completed!", file=sys.stderr)
|
||||
print("\n[Done] All tests completed!", file=sys.stderr)
|
||||
print(f" - with_lora: {lora_output}", file=sys.stderr)
|
||||
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
|
||||
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
|
||||
|
||||
@@ -7,7 +7,7 @@ project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root / "src"))
|
||||
|
||||
import contextlib
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
import argbind
|
||||
import torch
|
||||
@@ -17,16 +17,19 @@ from transformers import get_cosine_schedule_with_warmup
|
||||
import signal
|
||||
import os
|
||||
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
try:
|
||||
from safetensors.torch import save_file
|
||||
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
|
||||
|
||||
from voxcpm.model import VoxCPMModel
|
||||
import json
|
||||
|
||||
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
from voxcpm.training import (
|
||||
Accelerator,
|
||||
@@ -61,15 +64,15 @@ def train(
|
||||
lora: dict = None,
|
||||
config_path: str = "",
|
||||
# Distribution options (for LoRA checkpoints)
|
||||
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||
):
|
||||
_ = config_path
|
||||
|
||||
|
||||
# Validate distribution options
|
||||
if lora is not None and distribute and not hf_model_id:
|
||||
raise ValueError("hf_model_id is required when distribute=True")
|
||||
|
||||
|
||||
accelerator = Accelerator(amp=True)
|
||||
|
||||
save_dir = Path(save_path)
|
||||
@@ -84,7 +87,15 @@ def train(
|
||||
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
|
||||
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
|
||||
|
||||
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
|
||||
# Auto-detect model architecture from config.json
|
||||
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
||||
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
||||
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
||||
if accelerator.rank == 0:
|
||||
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
||||
base_model = _model_cls.from_local(
|
||||
pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None
|
||||
)
|
||||
tokenizer = base_model.text_tokenizer
|
||||
|
||||
train_ds, val_ds = load_audio_text_datasets(
|
||||
@@ -166,7 +177,6 @@ def train(
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
unwrapped_model.train()
|
||||
|
||||
|
||||
# Only print param info on rank 0 to avoid cluttered output
|
||||
if accelerator.rank == 0:
|
||||
for name, param in model.named_parameters():
|
||||
@@ -191,7 +201,7 @@ def train(
|
||||
# All ranks load the same checkpoint to keep model and optimizer state in sync.
|
||||
start_step = load_checkpoint(model, optimizer, scheduler, save_dir, rank=accelerator.rank)
|
||||
accelerator.barrier()
|
||||
|
||||
|
||||
if start_step > 0 and accelerator.rank == 0:
|
||||
tracker.print(f"Resuming training from step {start_step}")
|
||||
|
||||
@@ -199,7 +209,19 @@ def train(
|
||||
resume = {"step": start_step}
|
||||
|
||||
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
|
||||
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
|
||||
def _signal_handler(
|
||||
signum,
|
||||
frame,
|
||||
_model=model,
|
||||
_optim=optimizer,
|
||||
_sched=scheduler,
|
||||
_save_dir=save_dir,
|
||||
_pretrained=pretrained_path,
|
||||
_hf_id=hf_model_id,
|
||||
_dist=distribute,
|
||||
_resume=resume,
|
||||
_rank=accelerator.rank,
|
||||
):
|
||||
try:
|
||||
cur_step = int(_resume.get("step", start_step))
|
||||
except Exception:
|
||||
@@ -229,8 +251,8 @@ def train(
|
||||
except StopIteration:
|
||||
data_epoch += 1
|
||||
# Key: set DistributedSampler epoch to ensure different data order each epoch
|
||||
sampler = getattr(train_loader, 'sampler', None)
|
||||
if hasattr(sampler, 'set_epoch'):
|
||||
sampler = getattr(train_loader, "sampler", None)
|
||||
if hasattr(sampler, "set_epoch"):
|
||||
sampler.set_epoch(data_epoch)
|
||||
train_iter = iter(train_loader)
|
||||
return next(train_iter)
|
||||
@@ -250,7 +272,7 @@ def train(
|
||||
|
||||
# Only sync gradients on the last micro-batch
|
||||
# Use no_sync() for intermediate steps to reduce communication overhead
|
||||
is_last_micro_step = (micro_step == grad_accum_steps - 1)
|
||||
is_last_micro_step = micro_step == grad_accum_steps - 1
|
||||
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
|
||||
|
||||
with sync_context:
|
||||
@@ -299,10 +321,22 @@ def train(
|
||||
tracker.log_metrics(loss_values, split="train")
|
||||
|
||||
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
|
||||
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
|
||||
sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
|
||||
valid_interval=valid_interval)
|
||||
validate(
|
||||
model,
|
||||
val_loader,
|
||||
batch_processor,
|
||||
accelerator,
|
||||
tracker,
|
||||
lambdas,
|
||||
writer=writer,
|
||||
step=step,
|
||||
val_ds=val_ds,
|
||||
audio_vae=audio_vae_for_gen,
|
||||
sample_rate=sample_rate,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
)
|
||||
|
||||
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
|
||||
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
|
||||
@@ -313,13 +347,26 @@ def train(
|
||||
writer.close()
|
||||
|
||||
|
||||
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
|
||||
val_texts=None, tokenizer=None, valid_interval=1000):
|
||||
def validate(
|
||||
model,
|
||||
val_loader,
|
||||
batch_processor,
|
||||
accelerator,
|
||||
tracker,
|
||||
lambdas,
|
||||
writer=None,
|
||||
step=0,
|
||||
val_ds=None,
|
||||
audio_vae=None,
|
||||
sample_rate=22050,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
valid_interval=1000,
|
||||
):
|
||||
"""Validate and generate sample audio"""
|
||||
import numpy as np
|
||||
import numpy as np # noqa: F401
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
model.eval()
|
||||
total_losses = []
|
||||
sub_losses = defaultdict(list) # Track individual sub-losses
|
||||
@@ -356,26 +403,37 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
# Compute mean total loss
|
||||
mean_total_loss = torch.stack(total_losses).mean()
|
||||
accelerator.all_reduce(mean_total_loss)
|
||||
|
||||
|
||||
# Compute mean of each sub-loss
|
||||
val_metrics = {"loss/total": mean_total_loss.item()}
|
||||
for key, values in sub_losses.items():
|
||||
mean_sub_loss = torch.stack(values).mean()
|
||||
accelerator.all_reduce(mean_sub_loss)
|
||||
val_metrics[key] = mean_sub_loss.item()
|
||||
|
||||
|
||||
tracker.log_metrics(val_metrics, split="val")
|
||||
|
||||
|
||||
# Generate sample audio for TensorBoard display
|
||||
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
|
||||
try:
|
||||
generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
|
||||
val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
|
||||
tracker=tracker)
|
||||
generate_sample_audio(
|
||||
model,
|
||||
val_ds,
|
||||
audio_vae,
|
||||
writer,
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
tracker=tracker,
|
||||
)
|
||||
except Exception as e:
|
||||
tracker.print(f"[Warning] Failed to generate sample audio: {e}")
|
||||
import traceback
|
||||
import io
|
||||
|
||||
buf = io.StringIO()
|
||||
traceback.print_exc(file=buf)
|
||||
tracker.print(buf.getvalue())
|
||||
@@ -390,7 +448,7 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
missing.append("audio_vae")
|
||||
if missing and accelerator.rank == 0:
|
||||
tracker.print(f"[Warning] Skip audio generation: missing {', '.join(missing)}")
|
||||
|
||||
|
||||
model.train()
|
||||
|
||||
|
||||
@@ -398,6 +456,7 @@ def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
|
||||
"""Compute Mel Spectrogram (dB) using librosa"""
|
||||
import numpy as np
|
||||
import librosa
|
||||
|
||||
audio_np = audio_np.flatten().astype(np.float32)
|
||||
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
|
||||
return librosa.power_to_db(mel, ref=np.max)
|
||||
@@ -408,31 +467,45 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
||||
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
|
||||
"""
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa.display
|
||||
|
||||
|
||||
fmax = sample_rate // 2
|
||||
step_str = f" @ Step {step}" if step is not None else ""
|
||||
|
||||
|
||||
if ref_audio_np is not None and ref_mel is not None:
|
||||
# Comparison mode: reference vs generated
|
||||
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
|
||||
|
||||
img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
|
||||
ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
|
||||
plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
|
||||
|
||||
img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
|
||||
ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
|
||||
plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
|
||||
|
||||
img_ref = librosa.display.specshow(
|
||||
ref_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_ref
|
||||
)
|
||||
ax_ref.set_title(
|
||||
f"Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
color="#28A745",
|
||||
)
|
||||
plt.colorbar(img_ref, ax=ax_ref, format="%+2.0f dB", pad=0.02)
|
||||
|
||||
img_gen = librosa.display.specshow(
|
||||
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_gen
|
||||
)
|
||||
ax_gen.set_title(
|
||||
f"Generated - {len(gen_audio_np)/sample_rate:.2f}s", fontsize=10, fontweight="bold", color="#DC3545"
|
||||
)
|
||||
plt.colorbar(img_gen, ax=ax_gen, format="%+2.0f dB", pad=0.02)
|
||||
else:
|
||||
# Single figure mode: show generated only
|
||||
fig, ax = plt.subplots(figsize=(12, 4))
|
||||
img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
|
||||
ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
|
||||
plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
|
||||
|
||||
img = librosa.display.specshow(
|
||||
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax
|
||||
)
|
||||
ax.set_title(f"Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}", fontsize=11, fontweight="bold")
|
||||
plt.colorbar(img, ax=ax, format="%+2.0f dB", pad=0.02)
|
||||
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
|
||||
@@ -440,26 +513,38 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
||||
def normalize_audio(audio_np):
|
||||
"""Normalize audio to [-0.9, 0.9]"""
|
||||
import numpy as np
|
||||
|
||||
max_val = np.abs(audio_np).max()
|
||||
return audio_np / max_val * 0.9 if max_val > 0 else audio_np
|
||||
|
||||
|
||||
def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
|
||||
val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
|
||||
tracker=None):
|
||||
def generate_sample_audio(
|
||||
model,
|
||||
val_ds,
|
||||
audio_vae,
|
||||
writer,
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate=22050,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
pretrained_path=None,
|
||||
valid_interval=1000,
|
||||
tracker=None,
|
||||
):
|
||||
"""Select 2 fixed validation samples, generate audio and log to TensorBoard"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
log = tracker.print if tracker else print
|
||||
num_samples = min(2, len(val_ds))
|
||||
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
|
||||
|
||||
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
|
||||
|
||||
for i in range(num_samples):
|
||||
sample = val_ds[i]
|
||||
text = val_texts[i] if val_texts and i < len(val_texts) else "Hello, this is a test."
|
||||
|
||||
|
||||
# Load reference audio
|
||||
ref_audio_np = None
|
||||
try:
|
||||
@@ -468,7 +553,10 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
ref_sr = sample["audio"].get("sampling_rate", sample_rate)
|
||||
if ref_sr != sample_rate:
|
||||
import torchaudio.functional as F
|
||||
ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
|
||||
|
||||
ref_audio_np = (
|
||||
F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
|
||||
)
|
||||
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
|
||||
except Exception as e:
|
||||
log(f"[Warning] Failed to load reference audio: {e}")
|
||||
@@ -480,7 +568,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
unwrapped_model.eval()
|
||||
# unwrapped_model.to(torch.bfloat16)
|
||||
unwrapped_model.audio_vae = audio_vae.to(torch.float32)
|
||||
|
||||
|
||||
log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
|
||||
autocast_ctx = (
|
||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
@@ -490,27 +578,33 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
with torch.no_grad():
|
||||
with autocast_ctx:
|
||||
generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
|
||||
|
||||
|
||||
# Restore training setup
|
||||
# unwrapped_model.to(torch.float32)
|
||||
# unwrapped_model.audio_vae = None
|
||||
|
||||
|
||||
if generated is None or len(generated) == 0:
|
||||
log(f"[Warning] Generated audio is empty for sample {i}")
|
||||
continue
|
||||
|
||||
|
||||
# Process generated audio
|
||||
gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
|
||||
gen_audio_np = (
|
||||
generated.cpu().float().numpy().flatten()
|
||||
if isinstance(generated, torch.Tensor)
|
||||
else np.array(generated, dtype=np.float32).flatten()
|
||||
)
|
||||
gen_audio_np = normalize_audio(gen_audio_np)
|
||||
|
||||
|
||||
tag = f"val_sample_{i}"
|
||||
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
|
||||
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
|
||||
|
||||
|
||||
# Log reference audio
|
||||
if ref_audio_np is not None:
|
||||
writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
|
||||
|
||||
writer.add_audio(
|
||||
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
||||
)
|
||||
|
||||
# Generate mel spectrogram figure
|
||||
try:
|
||||
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
|
||||
@@ -520,10 +614,11 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
log(f"[Audio] Created mel spectrogram figure for sample {i}")
|
||||
except Exception as e:
|
||||
log(f"[Warning] Failed to create mel spectrogram: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
log(f"[Warning] Failed to generate audio for sample {i}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
@@ -545,30 +640,29 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
Called by all ranks so that distributed state stays aligned.
|
||||
Returns the step number to resume from, or 0 if no checkpoint found.
|
||||
"""
|
||||
import json
|
||||
|
||||
latest_folder = save_dir / "latest"
|
||||
if not latest_folder.exists():
|
||||
return 0
|
||||
|
||||
|
||||
unwrapped = model.module if hasattr(model, "module") else model
|
||||
lora_cfg = unwrapped.lora_config
|
||||
|
||||
|
||||
# Load model weights
|
||||
if lora_cfg is not None:
|
||||
# LoRA: load lora_weights
|
||||
lora_weights_path = latest_folder / "lora_weights.safetensors"
|
||||
if not lora_weights_path.exists():
|
||||
lora_weights_path = latest_folder / "lora_weights.ckpt"
|
||||
|
||||
|
||||
if lora_weights_path.exists():
|
||||
if lora_weights_path.suffix == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(str(lora_weights_path))
|
||||
else:
|
||||
ckpt = torch.load(lora_weights_path, map_location="cpu")
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
|
||||
|
||||
unwrapped.load_state_dict(state_dict, strict=False)
|
||||
if rank == 0:
|
||||
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
|
||||
@@ -577,33 +671,34 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
model_path = latest_folder / "model.safetensors"
|
||||
if not model_path.exists():
|
||||
model_path = latest_folder / "pytorch_model.bin"
|
||||
|
||||
|
||||
if model_path.exists():
|
||||
if model_path.suffix == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(str(model_path))
|
||||
else:
|
||||
ckpt = torch.load(model_path, map_location="cpu")
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
|
||||
|
||||
unwrapped.load_state_dict(state_dict, strict=False)
|
||||
if rank == 0:
|
||||
print(f"Loaded model weights from {model_path}", file=sys.stderr)
|
||||
|
||||
|
||||
# Load optimizer state
|
||||
optimizer_path = latest_folder / "optimizer.pth"
|
||||
if optimizer_path.exists():
|
||||
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
|
||||
if rank == 0:
|
||||
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
|
||||
|
||||
|
||||
# Load scheduler state
|
||||
scheduler_path = latest_folder / "scheduler.pth"
|
||||
if scheduler_path.exists():
|
||||
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
|
||||
if rank == 0:
|
||||
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
|
||||
|
||||
|
||||
state_path = latest_folder / "training_state.json"
|
||||
if state_path.exists():
|
||||
with open(state_path, "r", encoding="utf-8") as f:
|
||||
@@ -621,28 +716,36 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
if rank == 0:
|
||||
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||
return resume_step
|
||||
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
|
||||
def save_checkpoint(
|
||||
model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
save_dir: Path,
|
||||
step: int,
|
||||
pretrained_path: str = None,
|
||||
hf_model_id: str = "",
|
||||
distribute: bool = False,
|
||||
):
|
||||
"""
|
||||
Save checkpoint with different strategies for full finetune vs LoRA:
|
||||
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
||||
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
|
||||
"""
|
||||
import json
|
||||
import shutil
|
||||
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = f"step_{step:07d}"
|
||||
folder = save_dir / tag
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
unwrapped = model.module if hasattr(model, "module") else model
|
||||
full_state = unwrapped.state_dict()
|
||||
lora_cfg = unwrapped.lora_config
|
||||
|
||||
|
||||
if lora_cfg is not None:
|
||||
# LoRA finetune: save only lora_A/lora_B weights
|
||||
state_dict = {k: v for k, v in full_state.items() if "lora_" in k}
|
||||
@@ -650,7 +753,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
save_file(state_dict, folder / "lora_weights.safetensors")
|
||||
else:
|
||||
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
|
||||
|
||||
|
||||
# Save LoRA config and base model path to a separate JSON file
|
||||
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
|
||||
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
|
||||
@@ -667,16 +770,23 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
save_file(state_dict, folder / "model.safetensors")
|
||||
else:
|
||||
torch.save({"state_dict": state_dict}, folder / "pytorch_model.bin")
|
||||
|
||||
|
||||
# Copy config files from pretrained path
|
||||
if pretrained_path:
|
||||
pretrained_dir = Path(pretrained_path)
|
||||
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
|
||||
files_to_copy = [
|
||||
"config.json",
|
||||
"audiovae.pth",
|
||||
"audiovae.safetensors",
|
||||
"tokenizer.json",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer_config.json",
|
||||
]
|
||||
for fname in files_to_copy:
|
||||
src = pretrained_dir / fname
|
||||
if src.exists():
|
||||
shutil.copy2(src, folder / fname)
|
||||
|
||||
|
||||
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
||||
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
||||
with open(folder / "training_state.json", "w", encoding="utf-8") as f:
|
||||
|
||||
+46
-27
@@ -13,11 +13,11 @@ import soundfile as sf
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Validators
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
@@ -53,12 +53,11 @@ def validate_ranges(args, parser):
|
||||
# Model loading
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def load_model(args) -> VoxCPM:
|
||||
print("Loading VoxCPM model...", file=sys.stderr)
|
||||
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||
"ZIPENHANCER_MODEL_PATH", None
|
||||
)
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get("ZIPENHANCER_MODEL_PATH", None)
|
||||
|
||||
# Build LoRA config if provided
|
||||
lora_config = None
|
||||
@@ -119,22 +118,29 @@ def load_model(args) -> VoxCPM:
|
||||
# Commands
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def cmd_clone(args):
|
||||
if not args.text:
|
||||
sys.exit("Error: Please provide --text for synthesis")
|
||||
|
||||
if not args.prompt_audio or not args.prompt_text:
|
||||
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
||||
has_prompt = args.prompt_audio and args.prompt_text
|
||||
has_ref = args.reference_audio is not None
|
||||
if not has_prompt and not has_ref:
|
||||
sys.exit("Error: Voice cloning requires --prompt-audio + --prompt-text, or --reference-audio, or both")
|
||||
|
||||
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
||||
if args.prompt_audio:
|
||||
validate_file_exists(args.prompt_audio, "prompt audio file")
|
||||
if args.reference_audio:
|
||||
validate_file_exists(args.reference_audio, "reference audio file")
|
||||
output_path = validate_output_path(args.output)
|
||||
|
||||
model = load_model(args)
|
||||
|
||||
audio_array = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=str(prompt_audio_path),
|
||||
prompt_text=args.prompt_text,
|
||||
prompt_wav_path=args.prompt_audio if has_prompt else None,
|
||||
prompt_text=args.prompt_text if has_prompt else None,
|
||||
reference_wav_path=args.reference_audio,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
@@ -185,7 +191,11 @@ def cmd_batch(args):
|
||||
|
||||
prompt_audio_path = None
|
||||
if args.prompt_audio:
|
||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "prompt audio file"))
|
||||
|
||||
reference_audio_path = None
|
||||
if args.reference_audio:
|
||||
reference_audio_path = str(validate_file_exists(args.reference_audio, "reference audio file"))
|
||||
|
||||
success_count = 0
|
||||
|
||||
@@ -195,10 +205,11 @@ def cmd_batch(args):
|
||||
text=text,
|
||||
prompt_wav_path=prompt_audio_path,
|
||||
prompt_text=args.prompt_text,
|
||||
reference_wav_path=reference_audio_path,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
denoise=args.denoise and prompt_audio_path is not None,
|
||||
denoise=args.denoise and (prompt_audio_path is not None or reference_audio_path is not None),
|
||||
)
|
||||
|
||||
output_file = output_dir / f"output_{i:03d}.wav"
|
||||
@@ -218,6 +229,7 @@ def cmd_batch(args):
|
||||
# Parser
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _build_unified_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
|
||||
@@ -236,34 +248,40 @@ Examples:
|
||||
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
|
||||
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
|
||||
|
||||
# Prompt
|
||||
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
|
||||
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
|
||||
# Prompt / Reference
|
||||
parser.add_argument(
|
||||
"--prompt-audio", "-pa", help="Prompt audio file path (continuation mode, requires --prompt-text)"
|
||||
)
|
||||
parser.add_argument("--prompt-text", "-pt", help="Text corresponding to the prompt audio")
|
||||
parser.add_argument(
|
||||
"--reference-audio", "-ra", help="Reference audio for voice cloning (isolated mode, VoxCPM2 only)"
|
||||
)
|
||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt/reference speech enhancement")
|
||||
|
||||
# Generation parameters
|
||||
parser.add_argument("--cfg-value", type=float, default=2.0,
|
||||
help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)")
|
||||
parser.add_argument("--inference-timesteps", type=int, default=10,
|
||||
help="Inference steps (int, 1–100, default: 10)")
|
||||
parser.add_argument(
|
||||
"--cfg-value", type=float, default=2.0, help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)"
|
||||
)
|
||||
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (int, 1–100, default: 10)")
|
||||
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
||||
|
||||
# Model loading
|
||||
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
|
||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
|
||||
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
|
||||
parser.add_argument(
|
||||
"--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (default: openbmb/VoxCPM1.5)"
|
||||
)
|
||||
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
||||
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
|
||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||
parser.add_argument("--zipenhancer-path", type=str,
|
||||
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
|
||||
parser.add_argument(
|
||||
"--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)"
|
||||
)
|
||||
|
||||
# LoRA
|
||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
|
||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
|
||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0,
|
||||
help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
||||
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
|
||||
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
|
||||
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
|
||||
@@ -275,6 +293,7 @@ Examples:
|
||||
# Entrypoint
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def main():
|
||||
parser = _build_unified_parser()
|
||||
args = parser.parse_args()
|
||||
@@ -296,8 +315,8 @@ def main():
|
||||
if not args.text or not args.output:
|
||||
parser.error("Single-sample mode requires --text and --output")
|
||||
|
||||
# Clone mode
|
||||
if args.prompt_audio or args.prompt_text:
|
||||
# Clone mode (prompt continuation, reference isolation, or both)
|
||||
if args.prompt_audio or args.prompt_text or args.reference_audio:
|
||||
return cmd_clone(args)
|
||||
|
||||
# Direct synthesis
|
||||
|
||||
+151
-100
@@ -1,21 +1,25 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import json
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from typing import Generator, Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||
from .model.voxcpm2 import VoxCPM2Model
|
||||
|
||||
|
||||
class VoxCPM:
|
||||
def __init__(self,
|
||||
voxcpm_model_path : str,
|
||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser : bool = True,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
voxcpm_model_path: str,
|
||||
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser: bool = True,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
"""Initialize VoxCPM TTS pipeline.
|
||||
|
||||
Args:
|
||||
@@ -26,13 +30,16 @@ class VoxCPM:
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||
"""
|
||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
|
||||
|
||||
print(
|
||||
f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# If lora_weights_path is provided but no lora_config, create a default one
|
||||
if lora_weights_path is not None and lora_config is None:
|
||||
lora_config = LoRAConfig(
|
||||
@@ -41,18 +48,33 @@ class VoxCPM:
|
||||
enable_proj=False,
|
||||
)
|
||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
|
||||
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
|
||||
|
||||
# Determine model type from config.json architecture field
|
||||
config_path = os.path.join(voxcpm_model_path, "config.json")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
arch = config.get("architecture", "voxcpm").lower()
|
||||
|
||||
if arch == "voxcpm2":
|
||||
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
print("Loaded VoxCPM2Model", file=sys.stderr)
|
||||
elif arch == "voxcpm":
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
print("Loaded VoxCPMModel", file=sys.stderr)
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {arch}")
|
||||
|
||||
# Load LoRA weights if path is provided
|
||||
if lora_weights_path is not None:
|
||||
print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
|
||||
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
|
||||
|
||||
|
||||
self.text_normalizer = None
|
||||
self.denoiser = None
|
||||
if enable_denoiser and zipenhancer_model_path is not None:
|
||||
from .zipenhancer import ZipEnhancer
|
||||
|
||||
self.denoiser = ZipEnhancer(zipenhancer_model_path)
|
||||
else:
|
||||
self.denoiser = None
|
||||
@@ -64,17 +86,18 @@ class VoxCPM:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
hf_model_id: str = "openbmb/VoxCPM1.5",
|
||||
load_denoiser: bool = True,
|
||||
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
def from_pretrained(
|
||||
cls,
|
||||
hf_model_id: str = "openbmb/VoxCPM2",
|
||||
load_denoiser: bool = True,
|
||||
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
||||
|
||||
Args:
|
||||
@@ -86,7 +109,7 @@ class VoxCPM:
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created with
|
||||
enable_lm=True and enable_dit=True.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
@@ -106,7 +129,7 @@ class VoxCPM:
|
||||
repo_id = hf_model_id
|
||||
if not repo_id:
|
||||
raise ValueError("You must provide hf_model_id")
|
||||
|
||||
|
||||
# Load from local path if provided
|
||||
if os.path.isdir(repo_id):
|
||||
local_path = repo_id
|
||||
@@ -134,118 +157,146 @@ class VoxCPM:
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
|
||||
def _generate(self,
|
||||
text : str,
|
||||
prompt_wav_path : str = None,
|
||||
prompt_text : str = None,
|
||||
cfg_value : float = 2.0,
|
||||
inference_timesteps : int = 10,
|
||||
min_len : int = 2,
|
||||
max_len : int = 4096,
|
||||
normalize : bool = False,
|
||||
denoise : bool = False,
|
||||
retry_badcase : bool = True,
|
||||
retry_badcase_max_times : int = 3,
|
||||
retry_badcase_ratio_threshold : float = 6.0,
|
||||
streaming: bool = False,
|
||||
) -> Generator[np.ndarray, None, None]:
|
||||
def _generate(
|
||||
self,
|
||||
text: str,
|
||||
prompt_wav_path: str = None,
|
||||
prompt_text: str = None,
|
||||
reference_wav_path: str = None,
|
||||
cfg_value: float = 2.0,
|
||||
inference_timesteps: int = 10,
|
||||
min_len: int = 2,
|
||||
max_len: int = 4096,
|
||||
normalize: bool = False,
|
||||
denoise: bool = False,
|
||||
retry_badcase: bool = True,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,
|
||||
streaming: bool = False,
|
||||
) -> Generator[np.ndarray, None, None]:
|
||||
"""Synthesize speech for the given text and return a single waveform.
|
||||
|
||||
This method optionally builds and reuses a prompt cache. If an external
|
||||
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
|
||||
used for all sub-sentences. Otherwise, the prompt cache is built from
|
||||
the first generated result and reused for the remaining text chunks.
|
||||
|
||||
Args:
|
||||
text: Input text. Can include newlines; each non-empty line is
|
||||
treated as a sub-sentence.
|
||||
prompt_wav_path: Path to a reference audio file for prompting.
|
||||
text: Input text to synthesize.
|
||||
prompt_wav_path: Path to prompt audio for continuation mode.
|
||||
Must be paired with ``prompt_text``.
|
||||
prompt_text: Text content corresponding to the prompt audio.
|
||||
reference_wav_path: Path to reference audio for voice cloning
|
||||
(structurally isolated via ref_audio tokens). Can be used
|
||||
alone or combined with ``prompt_wav_path`` + ``prompt_text``.
|
||||
cfg_value: Guidance scale for the generation model.
|
||||
inference_timesteps: Number of inference steps.
|
||||
min_len: Minimum audio length.
|
||||
max_len: Maximum token length during generation.
|
||||
normalize: Whether to run text normalization before generation.
|
||||
denoise: Whether to denoise the prompt audio if a denoiser is
|
||||
available.
|
||||
denoise: Whether to denoise the prompt/reference audio if a
|
||||
denoiser is available.
|
||||
retry_badcase: Whether to retry badcase.
|
||||
retry_badcase_max_times: Maximum number of times to retry badcase.
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
|
||||
streaming: Whether to return a generator of audio chunks.
|
||||
Returns:
|
||||
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||
Yields audio chunks for each generations step if ``streaming=True``,
|
||||
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||
Yields audio chunks for each generation step if ``streaming=True``,
|
||||
otherwise yields a single array containing the final audio.
|
||||
"""
|
||||
if not text.strip() or not isinstance(text, str):
|
||||
raise ValueError("target text must be a non-empty string")
|
||||
|
||||
|
||||
if prompt_wav_path is not None:
|
||||
if not os.path.exists(prompt_wav_path):
|
||||
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
|
||||
|
||||
|
||||
if reference_wav_path is not None:
|
||||
if not os.path.exists(reference_wav_path):
|
||||
raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")
|
||||
|
||||
if (prompt_wav_path is None) != (prompt_text is None):
|
||||
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
|
||||
|
||||
|
||||
is_v2 = isinstance(self.tts_model, VoxCPM2Model)
|
||||
if reference_wav_path is not None and not is_v2:
|
||||
raise ValueError("reference_wav_path is only supported with VoxCPM2 models")
|
||||
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
temp_prompt_wav_path = None
|
||||
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
temp_files = []
|
||||
|
||||
try:
|
||||
if prompt_wav_path is not None and prompt_text is not None:
|
||||
if denoise and self.denoiser is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
temp_prompt_wav_path = tmp_file.name
|
||||
self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
|
||||
prompt_wav_path = temp_prompt_wav_path
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text
|
||||
)
|
||||
actual_prompt_path = prompt_wav_path
|
||||
actual_ref_path = reference_wav_path
|
||||
|
||||
if denoise and self.denoiser is not None:
|
||||
if prompt_wav_path is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
temp_files.append(tmp.name)
|
||||
self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
|
||||
actual_prompt_path = temp_files[-1]
|
||||
if reference_wav_path is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
temp_files.append(tmp.name)
|
||||
self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
|
||||
actual_ref_path = temp_files[-1]
|
||||
|
||||
if actual_prompt_path is not None or actual_ref_path is not None:
|
||||
if is_v2:
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=actual_prompt_path,
|
||||
reference_wav_path=actual_ref_path,
|
||||
)
|
||||
else:
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=actual_prompt_path,
|
||||
)
|
||||
else:
|
||||
fixed_prompt_cache = None # will be built from the first inference
|
||||
|
||||
fixed_prompt_cache = None
|
||||
|
||||
if normalize:
|
||||
if self.text_normalizer is None:
|
||||
from .utils.text_normalize import TextNormalizer
|
||||
|
||||
self.text_normalizer = TextNormalizer()
|
||||
text = self.text_normalizer.normalize(text)
|
||||
|
||||
|
||||
generate_result = self.tts_model._generate_with_prompt_cache(
|
||||
target_text=text,
|
||||
prompt_cache=fixed_prompt_cache,
|
||||
min_len=min_len,
|
||||
max_len=max_len,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
target_text=text,
|
||||
prompt_cache=fixed_prompt_cache,
|
||||
min_len=min_len,
|
||||
max_len=max_len,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
for wav, _, _ in generate_result:
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
|
||||
|
||||
finally:
|
||||
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
|
||||
try:
|
||||
os.unlink(temp_prompt_wav_path)
|
||||
except OSError:
|
||||
pass
|
||||
for tmp_path in temp_files:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# LoRA Interface (delegated to VoxCPMModel)
|
||||
# ------------------------------------------------------------------ #
|
||||
def load_lora(self, lora_weights_path: str) -> tuple:
|
||||
"""Load LoRA weights from a checkpoint file.
|
||||
|
||||
|
||||
Args:
|
||||
lora_weights_path: Path to LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt).
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model was not initialized with LoRA config.
|
||||
"""
|
||||
@@ -259,23 +310,23 @@ class VoxCPM:
|
||||
def unload_lora(self):
|
||||
"""Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
|
||||
self.tts_model.reset_lora_weights()
|
||||
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
|
||||
"""Enable or disable LoRA layers without unloading weights.
|
||||
|
||||
|
||||
Args:
|
||||
enabled: If True, LoRA layers are active; if False, only base model is used.
|
||||
"""
|
||||
self.tts_model.set_lora_enabled(enabled)
|
||||
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
|
||||
"""Get current LoRA parameters state dict.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: State dict containing all LoRA parameters (lora_A, lora_B).
|
||||
"""
|
||||
return self.tts_model.get_lora_state_dict()
|
||||
|
||||
|
||||
@property
|
||||
def lora_enabled(self) -> bool:
|
||||
"""Check if LoRA is currently configured."""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .voxcpm import VoxCPMModel
|
||||
from .voxcpm2 import VoxCPM2Model
|
||||
|
||||
__all__ = ["VoxCPMModel"]
|
||||
__all__ = ["VoxCPMModel", "VoxCPM2Model"]
|
||||
|
||||
+18
-19
@@ -5,17 +5,17 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
|
||||
|
||||
|
||||
This function creates a wrapper around the provided tokenizer that automatically
|
||||
splits multi-character Chinese tokens into individual characters. This is useful
|
||||
for ensuring consistent tokenization of Chinese text.
|
||||
|
||||
|
||||
Args:
|
||||
tokenizer: The base tokenizer to wrap
|
||||
|
||||
|
||||
Returns:
|
||||
A CharTokenizerWrapper instance that handles multi-character Chinese tokens
|
||||
|
||||
|
||||
Example:
|
||||
>>> from transformers import LlamaTokenizerFast
|
||||
>>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
|
||||
@@ -24,20 +24,19 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""
|
||||
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
|
||||
multichar_tokens = {
|
||||
token for token in tokenizer.vocab.keys()
|
||||
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
||||
token for token in tokenizer.vocab.keys() if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
||||
}
|
||||
|
||||
class CharTokenizerWrapper:
|
||||
"""Wrapper class for tokenizers that handles multi-character Chinese tokens.
|
||||
|
||||
|
||||
This wrapper automatically splits multi-character Chinese tokens into
|
||||
individual characters while preserving the original tokenizer's interface.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
|
||||
"""Initialize the wrapper with a base tokenizer.
|
||||
|
||||
|
||||
Args:
|
||||
base_tokenizer: The tokenizer to wrap
|
||||
"""
|
||||
@@ -46,14 +45,14 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
|
||||
def tokenize(self, text: str, **kwargs) -> List[str]:
|
||||
"""Tokenize text and split multi-character Chinese tokens into single characters.
|
||||
|
||||
|
||||
Args:
|
||||
text: Input text to tokenize
|
||||
**kwargs: Additional arguments passed to the base tokenizer
|
||||
|
||||
|
||||
Returns:
|
||||
List of processed tokens with multi-character Chinese tokens split
|
||||
|
||||
|
||||
Example:
|
||||
>>> wrapper = CharTokenizerWrapper(tokenizer)
|
||||
>>> tokens = wrapper.tokenize("你好世界")
|
||||
@@ -61,10 +60,10 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
raise TypeError(f"Expected string input, got {type(text)}")
|
||||
|
||||
|
||||
tokens = self.tokenizer.tokenize(text, **kwargs)
|
||||
processed = []
|
||||
|
||||
|
||||
for token in tokens:
|
||||
# Remove possible subword prefix
|
||||
clean_token = token.replace("▁", "")
|
||||
@@ -75,22 +74,22 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
processed.extend(chars)
|
||||
else:
|
||||
processed.append(token)
|
||||
|
||||
|
||||
return processed
|
||||
|
||||
def __call__(self, text: str, **kwargs) -> List[int]:
|
||||
"""Call the tokenizer and return token IDs.
|
||||
|
||||
|
||||
This method provides the same interface as the original tokenizer
|
||||
but with multi-character Chinese token handling.
|
||||
|
||||
|
||||
Args:
|
||||
text: Input text to tokenize
|
||||
**kwargs: Additional arguments passed to the base tokenizer
|
||||
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not a string
|
||||
ValueError: If tokenization fails
|
||||
|
||||
+128
-115
@@ -24,7 +24,6 @@ from typing import Tuple, Union, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import warnings
|
||||
from einops import rearrange
|
||||
@@ -32,6 +31,7 @@ from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
@@ -84,9 +84,9 @@ class VoxCPMConfig(BaseModel):
|
||||
|
||||
|
||||
class LoRAConfig(BaseModel):
|
||||
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
|
||||
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
|
||||
enable_proj: bool = False # Apply LoRA to projection Linear layers
|
||||
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
|
||||
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
|
||||
enable_proj: bool = False # Apply LoRA to projection Linear layers
|
||||
|
||||
r: int = 8
|
||||
alpha: int = 16
|
||||
@@ -165,10 +165,10 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
# Projection layers
|
||||
self.fsq_layer = ScalarQuantizationLayer(
|
||||
config.lm_config.hidden_size,
|
||||
config.lm_config.hidden_size,
|
||||
config.scalar_quantization_latent_dim,
|
||||
config.scalar_quantization_scale
|
||||
config.lm_config.hidden_size,
|
||||
config.lm_config.hidden_size,
|
||||
config.scalar_quantization_latent_dim,
|
||||
config.scalar_quantization_scale,
|
||||
)
|
||||
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
|
||||
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||
@@ -196,9 +196,7 @@ class VoxCPMModel(nn.Module):
|
||||
# LM: base_lm + residual_lm
|
||||
if cfg.enable_lm:
|
||||
for lm in [self.base_lm, self.residual_lm]:
|
||||
apply_lora_to_named_linear_modules(
|
||||
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
|
||||
)
|
||||
apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)
|
||||
|
||||
# DiT: feat_decoder.estimator
|
||||
if cfg.enable_dit:
|
||||
@@ -209,6 +207,7 @@ class VoxCPMModel(nn.Module):
|
||||
# 投影层
|
||||
if cfg.enable_proj:
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
|
||||
for attr_name in cfg.target_proj_modules:
|
||||
module = getattr(self, attr_name, None)
|
||||
if isinstance(module, nn.Linear):
|
||||
@@ -221,13 +220,17 @@ class VoxCPMModel(nn.Module):
|
||||
if self.device != "cuda":
|
||||
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
||||
try:
|
||||
import triton
|
||||
import triton # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError("triton is not installed")
|
||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(
|
||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(
|
||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
|
||||
return self
|
||||
@@ -313,9 +316,11 @@ class VoxCPMModel(nn.Module):
|
||||
mu=dit_hidden,
|
||||
patch_size=self.patch_size,
|
||||
cond=feat_cond_for_sample,
|
||||
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||
else 10,
|
||||
n_timesteps=(
|
||||
self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||
else 10
|
||||
),
|
||||
)
|
||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
@@ -331,7 +336,6 @@ class VoxCPMModel(nn.Module):
|
||||
def _dtype(self):
|
||||
return get_dtype(self.config.dtype)
|
||||
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
@@ -350,7 +354,7 @@ class VoxCPMModel(nn.Module):
|
||||
cfg_value: float = 2.0,
|
||||
retry_badcase: bool = False,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
|
||||
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
|
||||
streaming: bool = False,
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
if retry_badcase and streaming:
|
||||
@@ -394,7 +398,7 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
@@ -435,7 +439,7 @@ class VoxCPMModel(nn.Module):
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
|
||||
|
||||
retry_badcase_times = 0
|
||||
while retry_badcase_times < retry_badcase_max_times:
|
||||
inference_result = self._inference(
|
||||
@@ -444,7 +448,9 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
max_len=min(
|
||||
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
|
||||
), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -460,18 +466,21 @@ class VoxCPMModel(nn.Module):
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||
print(
|
||||
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
break
|
||||
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
|
||||
@torch.inference_mode()
|
||||
def build_prompt_cache(
|
||||
self,
|
||||
@@ -480,11 +489,11 @@ class VoxCPMModel(nn.Module):
|
||||
):
|
||||
"""
|
||||
Build prompt cache for subsequent fast generation.
|
||||
|
||||
|
||||
Args:
|
||||
prompt_text: prompt text (required)
|
||||
prompt_wav_path: prompt audio path (required)
|
||||
|
||||
|
||||
Returns:
|
||||
prompt_cache: dict with prompt_text (raw text) and audio features.
|
||||
Text tokenization will be done during generation for consistency.
|
||||
@@ -496,7 +505,7 @@ class VoxCPMModel(nn.Module):
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
|
||||
@@ -514,16 +523,17 @@ class VoxCPMModel(nn.Module):
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0) # (D, T, P)
|
||||
).permute(
|
||||
1, 2, 0
|
||||
) # (D, T, P)
|
||||
# build prompt cache - only save raw text and audio features
|
||||
prompt_cache = {
|
||||
"prompt_text": prompt_text,
|
||||
"audio_feat": audio_feat,
|
||||
}
|
||||
|
||||
|
||||
return prompt_cache
|
||||
|
||||
|
||||
def merge_prompt_cache(
|
||||
self,
|
||||
original_cache: dict,
|
||||
@@ -532,12 +542,12 @@ class VoxCPMModel(nn.Module):
|
||||
):
|
||||
"""
|
||||
Merge original prompt cache with newly generated content to stabilize voice.
|
||||
|
||||
|
||||
Args:
|
||||
original_cache: original prompt cache
|
||||
new_text: newly generated text
|
||||
new_text: newly generated text
|
||||
new_audio_feat: newly generated audio features
|
||||
|
||||
|
||||
Returns:
|
||||
merged_cache: merged cache with prompt_text and audio_feat
|
||||
"""
|
||||
@@ -557,20 +567,17 @@ class VoxCPMModel(nn.Module):
|
||||
"prompt_text": merged_prompt_text,
|
||||
"audio_feat": merged_audio_feat,
|
||||
}
|
||||
|
||||
|
||||
return merged_cache
|
||||
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _generate_with_prompt_cache(
|
||||
self,
|
||||
@@ -588,7 +595,7 @@ class VoxCPMModel(nn.Module):
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
||||
"""
|
||||
Generate audio using pre-built prompt cache.
|
||||
|
||||
|
||||
Args:
|
||||
target_text: Text to convert to speech
|
||||
prompt_cache: Cache built by build_prompt_cache (can be None)
|
||||
@@ -601,7 +608,7 @@ class VoxCPMModel(nn.Module):
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
|
||||
streaming: Whether to return a generator of audio chunks
|
||||
streaming_prefix_len: Number of prefix audio patches to use for streaming mode
|
||||
|
||||
|
||||
Returns:
|
||||
Generator of Tuple containing:
|
||||
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
|
||||
@@ -619,7 +626,7 @@ class VoxCPMModel(nn.Module):
|
||||
prompt_audio_feat = prompt_cache["audio_feat"]
|
||||
prompt_text = prompt_cache["prompt_text"]
|
||||
text = prompt_text + target_text
|
||||
|
||||
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
text_token = torch.cat(
|
||||
[
|
||||
@@ -632,7 +639,7 @@ class VoxCPMModel(nn.Module):
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
|
||||
audio_length = prompt_audio_feat.size(0)
|
||||
@@ -645,14 +652,18 @@ class VoxCPMModel(nn.Module):
|
||||
)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
text_mask = (
|
||||
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
audio_mask = (
|
||||
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
|
||||
text_token = text_token.unsqueeze(0).to(self.device)
|
||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
|
||||
# run inference
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
retry_badcase_times = 0
|
||||
@@ -663,7 +674,9 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
max_len=min(
|
||||
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
|
||||
), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -674,17 +687,16 @@ class VoxCPMModel(nn.Module):
|
||||
for latent_pred, pred_audio_feat in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
||||
yield (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||
print(
|
||||
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
@@ -695,18 +707,14 @@ class VoxCPMModel(nn.Module):
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
if audio_mask.sum().item() > 0:
|
||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
||||
else:
|
||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
||||
yield (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return next(self._inference(*args, streaming=False, **kwargs))
|
||||
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._inference(*args, streaming=True, **kwargs)
|
||||
|
||||
@@ -725,10 +733,10 @@ class VoxCPMModel(nn.Module):
|
||||
streaming_prefix_len: int = 3,
|
||||
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
||||
"""Core inference method for audio generation.
|
||||
|
||||
|
||||
This is the main inference loop that generates audio features
|
||||
using the language model and diffusion transformer.
|
||||
|
||||
|
||||
Args:
|
||||
text: Input text tokens
|
||||
text_mask: Mask for text tokens
|
||||
@@ -739,7 +747,7 @@ class VoxCPMModel(nn.Module):
|
||||
inference_timesteps: Number of diffusion steps
|
||||
cfg_value: Classifier-free guidance value
|
||||
streaming: Whether to yield each step latent feature or just the final result
|
||||
|
||||
|
||||
Returns:
|
||||
Generator of Tuple containing:
|
||||
- Predicted latent feature at the current step if ``streaming=True``, else final latent features
|
||||
@@ -749,12 +757,12 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
scale_emb = self.config.lm_config.scale_emb
|
||||
else:
|
||||
scale_emb = 1.0
|
||||
|
||||
|
||||
text_embed = self.base_lm.embed_tokens(text) * scale_emb
|
||||
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
|
||||
|
||||
@@ -778,11 +786,10 @@ class VoxCPMModel(nn.Module):
|
||||
is_causal=True,
|
||||
)
|
||||
self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
|
||||
|
||||
|
||||
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
|
||||
lm_hidden = enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
|
||||
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
|
||||
is_causal=True,
|
||||
@@ -790,7 +797,6 @@ class VoxCPMModel(nn.Module):
|
||||
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
|
||||
residual_hidden = residual_enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
for i in tqdm(range(max_len)):
|
||||
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
|
||||
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
|
||||
@@ -805,10 +811,10 @@ class VoxCPMModel(nn.Module):
|
||||
).transpose(
|
||||
1, 2
|
||||
) # [b, p, d]
|
||||
|
||||
|
||||
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.enc_to_lm_proj(curr_embed)
|
||||
|
||||
|
||||
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
|
||||
prefix_feat_cond = pred_feat
|
||||
|
||||
@@ -816,58 +822,70 @@ class VoxCPMModel(nn.Module):
|
||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
|
||||
yield feat_pred, pred_feat_seq
|
||||
|
||||
|
||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||
if i > min_len and stop_flag == 1:
|
||||
break
|
||||
|
||||
|
||||
lm_hidden = self.base_lm.forward_step(
|
||||
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
|
||||
).clone()
|
||||
|
||||
|
||||
lm_hidden = self.fsq_layer(lm_hidden)
|
||||
residual_hidden = self.residual_lm.forward_step(
|
||||
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
|
||||
lm_hidden + curr_embed[:, 0, :],
|
||||
torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
|
||||
).clone()
|
||||
|
||||
|
||||
if not streaming:
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
audio_vae_config = getattr(config, 'audio_vae_config', None)
|
||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
||||
vae_state_dict = torch.load(
|
||||
os.path.join(path, "audiovae.pth"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
# Try to load AudioVAE from safetensors first, fallback to pytorch
|
||||
audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
|
||||
audiovae_pth_path = os.path.join(path, "audiovae.pth")
|
||||
if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||
print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
|
||||
vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
|
||||
elif os.path.exists(audiovae_pth_path):
|
||||
print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
|
||||
checkpoint = torch.load(
|
||||
audiovae_pth_path,
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
vae_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||
)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
else: # training mode
|
||||
else: # training mode
|
||||
for name, param in model.named_parameters():
|
||||
if "audio_vae" in name: # freeze VAE weights
|
||||
if "audio_vae" in name: # freeze VAE weights
|
||||
param.requires_grad = False
|
||||
continue
|
||||
if lora_config is not None:
|
||||
if "lora" not in name: # freeze non-LoRA weights
|
||||
if "lora" not in name: # freeze non-LoRA weights
|
||||
param.requires_grad = False
|
||||
model.audio_vae = model.audio_vae.to(torch.float32)
|
||||
|
||||
|
||||
# Try to load from safetensors first, fallback to pytorch_model.bin
|
||||
safetensors_path = os.path.join(path, "model.safetensors")
|
||||
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
||||
|
||||
|
||||
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||
print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
|
||||
model_state_dict = load_file(safetensors_path)
|
||||
@@ -880,13 +898,11 @@ class VoxCPMModel(nn.Module):
|
||||
)
|
||||
model_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
|
||||
)
|
||||
|
||||
raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
|
||||
|
||||
for kw, val in vae_state_dict.items():
|
||||
model_state_dict[f"audio_vae.{kw}"] = val
|
||||
|
||||
|
||||
# LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
|
||||
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
|
||||
model.load_state_dict(model_state_dict, strict=False)
|
||||
@@ -900,6 +916,7 @@ class VoxCPMModel(nn.Module):
|
||||
def _iter_lora_modules(self):
|
||||
"""Iterate over all LoRA modules."""
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
yield module
|
||||
@@ -909,7 +926,7 @@ class VoxCPMModel(nn.Module):
|
||||
Load LoRA weights from file, supports calling after torch.compile.
|
||||
Uses named_parameters() to handle compile's _orig_mod wrapper.
|
||||
Supports both safetensors and pytorch formats.
|
||||
|
||||
|
||||
Args:
|
||||
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
|
||||
device: Target device, defaults to model's current device
|
||||
@@ -917,18 +934,18 @@ class VoxCPMModel(nn.Module):
|
||||
tuple: (loaded_keys, skipped_keys)
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
device = device or self.device
|
||||
lora_path = Path(lora_path)
|
||||
|
||||
lora_p = Path(lora_path)
|
||||
|
||||
# Try safetensors first, then fallback to .ckpt
|
||||
if lora_path.is_dir():
|
||||
safetensors_file = lora_path / "lora_weights.safetensors"
|
||||
ckpt_file = lora_path / "lora_weights.ckpt"
|
||||
if lora_p.is_dir():
|
||||
safetensors_file = lora_p / "lora_weights.safetensors"
|
||||
ckpt_file = lora_p / "lora_weights.ckpt"
|
||||
else:
|
||||
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
|
||||
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
|
||||
|
||||
safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
|
||||
ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
|
||||
|
||||
# Load from safetensors if available
|
||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||
state_dict = load_file(str(safetensors_file), device=device)
|
||||
@@ -936,14 +953,12 @@ class VoxCPMModel(nn.Module):
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
|
||||
)
|
||||
|
||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||
|
||||
# Build param mapping (handle torch.compile's _orig_mod prefix)
|
||||
model_params = dict(self.named_parameters())
|
||||
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
|
||||
|
||||
|
||||
loaded_keys, skipped_keys = [], []
|
||||
for key, value in state_dict.items():
|
||||
target_key = key if key in model_params else key_mapping.get(key)
|
||||
@@ -952,7 +967,7 @@ class VoxCPMModel(nn.Module):
|
||||
loaded_keys.append(key)
|
||||
else:
|
||||
skipped_keys.append(key)
|
||||
|
||||
|
||||
return loaded_keys, skipped_keys
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
|
||||
@@ -967,6 +982,4 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
|
||||
"""Get all LoRA parameters (lora_A/lora_B)."""
|
||||
return {name: param.data.clone()
|
||||
for name, param in self.named_parameters()
|
||||
if "lora_" in name}
|
||||
return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1 +1,2 @@
|
||||
from .audio_vae import AudioVAE, AudioVAEConfig
|
||||
from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import List, Union, Optional
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -285,12 +285,12 @@ class AudioVAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[AudioVAEConfig] = None,
|
||||
config: AudioVAEConfig = None,
|
||||
):
|
||||
# 如果没有传入config,使用默认配置
|
||||
if config is None:
|
||||
config = AudioVAEConfig()
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
encoder_dim = config.encoder_dim
|
||||
@@ -301,7 +301,7 @@ class AudioVAE(nn.Module):
|
||||
depthwise = config.depthwise
|
||||
sample_rate = config.sample_rate
|
||||
use_noise_block = config.use_noise_block
|
||||
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
|
||||
@@ -0,0 +1,486 @@
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
class CausalConv1d(nn.Conv1d):
|
||||
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
self.__output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
|
||||
return super().forward(x_pad)
|
||||
|
||||
|
||||
class CausalTransposeConv1d(nn.ConvTranspose1d):
|
||||
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
self.__output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
|
||||
|
||||
|
||||
def WNCausalConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNCausalTransposeConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
# Scripting this brings model speed up 1.4x
|
||||
@torch.jit.script
|
||||
def snake(x, alpha):
|
||||
shape = x.shape
|
||||
x = x.reshape(shape[0], shape[1], -1)
|
||||
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
||||
x = x.reshape(shape)
|
||||
return x
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return snake(x, self.alpha)
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class CausalResidualUnit(nn.Module):
|
||||
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel,
|
||||
dilation=dilation,
|
||||
padding=pad,
|
||||
groups=groups,
|
||||
),
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(dim, dim, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.block(x)
|
||||
pad = (x.shape[-1] - y.shape[-1]) // 2
|
||||
assert pad == 0
|
||||
if pad > 0:
|
||||
x = x[..., pad:-pad]
|
||||
return x + y
|
||||
|
||||
|
||||
class CausalEncoderBlock(nn.Module):
|
||||
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
|
||||
super().__init__()
|
||||
input_dim = input_dim or output_dim // 2
|
||||
self.block = nn.Sequential(
|
||||
CausalResidualUnit(input_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=9, groups=groups),
|
||||
Snake1d(input_dim),
|
||||
WNCausalConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class CausalEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 64,
|
||||
latent_dim: int = 32,
|
||||
strides: list = [2, 4, 8, 8],
|
||||
depthwise: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Create first convolution
|
||||
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
|
||||
|
||||
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||
for stride in strides:
|
||||
d_model *= 2
|
||||
groups = d_model // 2 if depthwise else 1
|
||||
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
|
||||
|
||||
groups = d_model if depthwise else 1
|
||||
|
||||
# Create two convolution, for mu and logvar
|
||||
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
|
||||
# Wrap black into nn.Sequential
|
||||
self.block = nn.Sequential(*self.block)
|
||||
self.enc_dim = d_model
|
||||
|
||||
def forward(self, x):
|
||||
hidden_state = self.block(x)
|
||||
return {
|
||||
"hidden_state": hidden_state,
|
||||
"mu": self.fc_mu(hidden_state),
|
||||
"logvar": self.fc_logvar(hidden_state),
|
||||
}
|
||||
|
||||
|
||||
class NoiseBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, T = x.shape
|
||||
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
|
||||
h = self.linear(x)
|
||||
n = noise * h
|
||||
x = x + n
|
||||
return x
|
||||
|
||||
|
||||
class CausalDecoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 16,
|
||||
output_dim: int = 8,
|
||||
stride: int = 1,
|
||||
groups=1,
|
||||
use_noise_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
layers = [
|
||||
Snake1d(input_dim),
|
||||
WNCausalTransposeConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
]
|
||||
if use_noise_block:
|
||||
layers.append(NoiseBlock(output_dim))
|
||||
layers.extend(
|
||||
[
|
||||
CausalResidualUnit(output_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=9, groups=groups),
|
||||
]
|
||||
)
|
||||
self.block = nn.Sequential(*layers)
|
||||
self.input_channels = input_dim
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class TransposeLastTwoDim(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.transpose(x, -1, -2)
|
||||
|
||||
|
||||
class SampleRateConditionLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
sr_bin_buckets: int = None,
|
||||
cond_type: str = "scale_bias",
|
||||
cond_dim: int = 128,
|
||||
out_layer: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cond_type, out_layer_in_dim = cond_type, input_dim
|
||||
|
||||
if cond_type == "scale_bias":
|
||||
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
nn.init.ones_(self.scale_embed.weight)
|
||||
nn.init.zeros_(self.bias_embed.weight)
|
||||
elif cond_type == "scale_bias_init":
|
||||
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
nn.init.normal_(self.scale_embed.weight, mean=1)
|
||||
nn.init.normal_(self.bias_embed.weight)
|
||||
elif cond_type == "add":
|
||||
self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
nn.init.normal_(self.cond_embed.weight)
|
||||
elif cond_type == "concat":
|
||||
self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
|
||||
assert out_layer, "out_layer must be True for concat cond_type"
|
||||
out_layer_in_dim = input_dim + cond_dim
|
||||
else:
|
||||
raise ValueError(f"Invalid cond_type: {cond_type}")
|
||||
|
||||
if out_layer:
|
||||
self.out_layer = nn.Sequential(
|
||||
Snake1d(out_layer_in_dim),
|
||||
WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
|
||||
)
|
||||
else:
|
||||
self.out_layer = nn.Identity()
|
||||
|
||||
def forward(self, x, sr_cond):
|
||||
if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
|
||||
x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
|
||||
elif self.cond_type == "add":
|
||||
x = x + self.cond_embed(sr_cond).unsqueeze(-1)
|
||||
elif self.cond_type == "concat":
|
||||
x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
|
||||
|
||||
return self.out_layer(x)
|
||||
|
||||
|
||||
class CausalDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channel,
|
||||
channels,
|
||||
rates,
|
||||
depthwise: bool = False,
|
||||
d_out: int = 1,
|
||||
use_noise_block: bool = False,
|
||||
sr_bin_boundaries: List[int] = None,
|
||||
cond_type: str = "scale_bias",
|
||||
cond_dim: int = 128,
|
||||
cond_out_layer: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add first conv layer
|
||||
if depthwise:
|
||||
layers = [
|
||||
WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
|
||||
WNCausalConv1d(input_channel, channels, kernel_size=1),
|
||||
]
|
||||
else:
|
||||
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
||||
|
||||
# Add upsampling + MRF blocks
|
||||
for i, stride in enumerate(rates):
|
||||
input_dim = channels // 2**i
|
||||
output_dim = channels // 2 ** (i + 1)
|
||||
groups = output_dim if depthwise else 1
|
||||
layers += [
|
||||
CausalDecoderBlock(
|
||||
input_dim,
|
||||
output_dim,
|
||||
stride,
|
||||
groups=groups,
|
||||
use_noise_block=use_noise_block,
|
||||
)
|
||||
]
|
||||
|
||||
# Add final conv layer
|
||||
layers += [
|
||||
Snake1d(output_dim),
|
||||
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
||||
nn.Tanh(),
|
||||
]
|
||||
|
||||
if sr_bin_boundaries is None:
|
||||
self.model = nn.Sequential(*layers)
|
||||
self.sr_bin_boundaries = None
|
||||
else:
|
||||
self.model = nn.ModuleList(layers)
|
||||
|
||||
self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
|
||||
self.sr_bin_buckets = len(sr_bin_boundaries) + 1
|
||||
|
||||
cond_layers = []
|
||||
for layer in self.model:
|
||||
if layer.__class__.__name__ == "CausalDecoderBlock":
|
||||
cond_layers.append(
|
||||
SampleRateConditionLayer(
|
||||
input_dim=layer.input_channels,
|
||||
sr_bin_buckets=self.sr_bin_buckets,
|
||||
cond_type=cond_type,
|
||||
cond_dim=cond_dim,
|
||||
out_layer=cond_out_layer,
|
||||
)
|
||||
)
|
||||
else:
|
||||
cond_layers.append(None)
|
||||
self.sr_cond_model = nn.ModuleList(cond_layers)
|
||||
|
||||
def get_sr_idx(self, sr):
|
||||
return torch.bucketize(sr, self.sr_bin_boundaries)
|
||||
|
||||
def forward(self, x, sr_cond=None):
|
||||
if self.sr_bin_boundaries is not None:
|
||||
# assert sr_cond is not None
|
||||
sr_cond = self.get_sr_idx(sr_cond)
|
||||
|
||||
for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
|
||||
if sr_cond_layer is not None:
|
||||
x = sr_cond_layer(x, sr_cond)
|
||||
x = layer(x)
|
||||
return x
|
||||
else:
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class AudioVAEConfig(BaseModel):
|
||||
encoder_dim: int = 128
|
||||
encoder_rates: List[int] = [2, 5, 8, 8]
|
||||
latent_dim: int = 64
|
||||
decoder_dim: int = 2048
|
||||
decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
|
||||
depthwise: bool = True
|
||||
sample_rate: int = 16000
|
||||
out_sample_rate: int = 48000
|
||||
use_noise_block: bool = False
|
||||
sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
|
||||
cond_type: str = "scale_bias"
|
||||
cond_dim: int = 128
|
||||
cond_out_layer: bool = False
|
||||
|
||||
|
||||
class AudioVAE(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AudioVAEConfig = None,
|
||||
):
|
||||
# 如果没有传入config,使用默认配置
|
||||
if config is None:
|
||||
config = AudioVAEConfig()
|
||||
|
||||
super().__init__()
|
||||
|
||||
encoder_dim = config.encoder_dim
|
||||
encoder_rates = config.encoder_rates
|
||||
latent_dim = config.latent_dim
|
||||
decoder_dim = config.decoder_dim
|
||||
decoder_rates = config.decoder_rates
|
||||
depthwise = config.depthwise
|
||||
sample_rate = config.sample_rate
|
||||
out_sample_rate = config.out_sample_rate
|
||||
use_noise_block = config.use_noise_block
|
||||
sr_bin_boundaries = config.sr_bin_boundaries
|
||||
cond_type = config.cond_type
|
||||
cond_dim = config.cond_dim
|
||||
cond_out_layer = config.cond_out_layer
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
self.decoder_rates = decoder_rates
|
||||
self.depthwise = depthwise
|
||||
|
||||
self.use_noise_block = use_noise_block
|
||||
|
||||
if latent_dim is None:
|
||||
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.hop_length = np.prod(encoder_rates)
|
||||
self.encoder = CausalEncoder(
|
||||
encoder_dim,
|
||||
latent_dim,
|
||||
encoder_rates,
|
||||
depthwise=depthwise,
|
||||
)
|
||||
|
||||
self.decoder = CausalDecoder(
|
||||
latent_dim,
|
||||
decoder_dim,
|
||||
decoder_rates,
|
||||
depthwise=depthwise,
|
||||
use_noise_block=use_noise_block,
|
||||
sr_bin_boundaries=sr_bin_boundaries,
|
||||
cond_type=cond_type,
|
||||
cond_dim=cond_dim,
|
||||
cond_out_layer=cond_out_layer,
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.out_sample_rate = out_sample_rate
|
||||
self.sr_bin_boundaries = sr_bin_boundaries
|
||||
self.chunk_size = math.prod(encoder_rates)
|
||||
|
||||
def preprocess(self, audio_data, sample_rate):
|
||||
if sample_rate is None:
|
||||
sample_rate = self.sample_rate
|
||||
assert sample_rate == self.sample_rate
|
||||
pad_to = self.hop_length
|
||||
length = audio_data.shape[-1]
|
||||
right_pad = math.ceil(length / pad_to) * pad_to - length
|
||||
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
||||
|
||||
return audio_data
|
||||
|
||||
def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
|
||||
"""Decode given latent codes and return audio data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
length : int, optional
|
||||
Number of samples in output audio, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"audio" : Tensor[B x 1 x length]
|
||||
Decoded audio data.
|
||||
"""
|
||||
if self.sr_bin_boundaries is not None:
|
||||
# use default output sample rate
|
||||
if sr_cond is None:
|
||||
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
|
||||
return self.decoder(z, sr_cond)
|
||||
|
||||
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
||||
"""
|
||||
Args:
|
||||
audio_data: Tensor[B x 1 x T]
|
||||
sample_rate: int
|
||||
Returns:
|
||||
z: Tensor[B x D x T]
|
||||
"""
|
||||
if audio_data.ndim == 2:
|
||||
audio_data = audio_data.unsqueeze(1)
|
||||
|
||||
audio_data = self.preprocess(audio_data, sample_rate)
|
||||
return self.encoder(audio_data)["mu"]
|
||||
@@ -1 +1 @@
|
||||
from .scalar_quantization_layer import ScalarQuantizationLayer
|
||||
from .scalar_quantization_layer import ScalarQuantizationLayer
|
||||
|
||||
@@ -34,7 +34,7 @@ class LoRALinear(nn.Module):
|
||||
self.r = r
|
||||
self.alpha = alpha
|
||||
self._base_scaling = alpha / r if r > 0 else 0.0
|
||||
|
||||
|
||||
# 使用 buffer 存储 scaling,这样修改值不会触发 torch.compile 重编译
|
||||
# persistent=False 表示不保存到 state_dict,避免加载时 missing key
|
||||
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
|
||||
@@ -128,6 +128,3 @@ def apply_lora_to_named_linear_modules(
|
||||
dropout=dropout,
|
||||
)
|
||||
setattr(parent, short_name, lora_layer)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class ScalarQuantizationLayer(nn.Module):
|
||||
|
||||
self.in_proj = nn.Linear(in_dim, latent_dim)
|
||||
self.out_proj = nn.Linear(latent_dim, out_dim)
|
||||
|
||||
|
||||
def forward(self, hidden):
|
||||
hidden = self.in_proj(hidden)
|
||||
hidden = torch.tanh(hidden)
|
||||
@@ -23,4 +23,4 @@ class ScalarQuantizationLayer(nn.Module):
|
||||
else:
|
||||
hidden = torch.round(hidden * self.scale) / self.scale
|
||||
|
||||
return self.out_proj(hidden)
|
||||
return self.out_proj(hidden)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .unified_cfm import UnifiedCFM, CfmConfig
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
|
||||
class SinusoidalPosEmb(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
||||
|
||||
def forward(self, x, scale=1000):
|
||||
if x.ndim < 1:
|
||||
x = x.unsqueeze(0)
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
|
||||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
out_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
|
||||
self.act = nn.SiLU()
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class VoxCPMLocDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MiniCPM4Config,
|
||||
in_channels: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.config = config
|
||||
|
||||
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
|
||||
|
||||
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=config.hidden_size,
|
||||
time_embed_dim=config.hidden_size,
|
||||
)
|
||||
self.delta_time_mlp = TimestepEmbedding(
|
||||
in_channels=config.hidden_size,
|
||||
time_embed_dim=config.hidden_size,
|
||||
)
|
||||
|
||||
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
|
||||
self.decoder = MiniCPMModel(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mu: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond: torch.Tensor,
|
||||
dt: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, T) tensor of inputs
|
||||
mu: (N, C) tensor of hidden embedding
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
cond: (N, C, T') tensor of prefix conditions
|
||||
dt: (N,) used for mean velocity (may be supported in the future...)
|
||||
"""
|
||||
x = self.in_proj(x.transpose(1, 2).contiguous())
|
||||
|
||||
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
|
||||
prefix = cond.size(1)
|
||||
|
||||
t = self.time_embeddings(t).to(x.dtype)
|
||||
t = self.time_mlp(t)
|
||||
dt = self.time_embeddings(dt).to(x.dtype)
|
||||
dt = self.delta_time_mlp(dt)
|
||||
t = t + dt
|
||||
|
||||
mu = mu.view(x.size(0), -1, x.size(-1))
|
||||
x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)
|
||||
|
||||
hidden, _ = self.decoder(x, is_causal=False)
|
||||
hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
|
||||
hidden = self.out_proj(hidden)
|
||||
|
||||
return hidden.transpose(1, 2).contiguous()
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -56,7 +56,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
cond: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
cfg_value: float = 1.0,
|
||||
sway_sampling_coef: float = 1.0,
|
||||
sway_sampling_coef: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
b, _ = mu.shape
|
||||
@@ -116,7 +116,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
|
||||
dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
|
||||
|
||||
if use_cfg_zero_star:
|
||||
positive_flat = dphi_dt.view(b, -1)
|
||||
negative_flat = cfg_dphi_dt.view(b, -1)
|
||||
@@ -124,7 +124,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
|
||||
else:
|
||||
st_star = 1.0
|
||||
|
||||
|
||||
dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)
|
||||
|
||||
x = x - dt * dphi_dt
|
||||
@@ -138,7 +138,9 @@ class UnifiedCFM(torch.nn.Module):
|
||||
# ------------------------------------------------------------------ #
|
||||
# Training loss
|
||||
# ------------------------------------------------------------------ #
|
||||
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
|
||||
def adaptive_loss_weighting(
|
||||
self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
|
||||
):
|
||||
weights = 1.0 / ((losses + epsilon).pow(p))
|
||||
if mask is not None:
|
||||
weights = weights * mask
|
||||
@@ -193,8 +195,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
|
||||
|
||||
ratio_r_neq_t = (
|
||||
self.ratio_r_neq_t_range[0]
|
||||
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||
self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||
if self.mean_mode
|
||||
else 0.0
|
||||
)
|
||||
|
||||
@@ -26,4 +26,4 @@ class MiniCPM4Config(BaseModel):
|
||||
dim_model_base: int
|
||||
scale_depth: float
|
||||
rope_theta: float
|
||||
kv_channels: int = None
|
||||
kv_channels: int = None
|
||||
|
||||
@@ -64,10 +64,8 @@ class MiniCPMLongRoPE(nn.Module):
|
||||
self.long_factor = config.rope_scaling.long_factor
|
||||
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
|
||||
|
||||
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
|
||||
self.scaling_factor = math.sqrt(
|
||||
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
|
||||
)
|
||||
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
||||
self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
@@ -76,11 +74,7 @@ class MiniCPMLongRoPE(nn.Module):
|
||||
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
|
||||
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
|
||||
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=self.max_position_embeddings,
|
||||
device=self.inv_freq.device,
|
||||
dtype=torch.float32
|
||||
)
|
||||
self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
"""设置cos和sin缓存"""
|
||||
@@ -93,8 +87,7 @@ class MiniCPMLongRoPE(nn.Module):
|
||||
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
|
||||
|
||||
freqs = torch.mul(
|
||||
torch.outer(t, 1.0 / ext_factors).to(device=device),
|
||||
self.inv_freq.to(device=device).to(dtype)
|
||||
torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
|
||||
)
|
||||
|
||||
# 创建embeddings
|
||||
@@ -123,7 +116,9 @@ class MiniCPMAttention(nn.Module):
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||
self.head_dim = (
|
||||
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||
)
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
@@ -153,7 +148,7 @@ class MiniCPMAttention(nn.Module):
|
||||
cos, sin = position_emb
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
|
||||
# ref: https://github.com/pytorch/pytorch/issues/163597
|
||||
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
||||
query_states = query_states.contiguous()
|
||||
@@ -413,7 +408,11 @@ class MiniCPMModel(nn.Module):
|
||||
self.kv_cache = StaticKVCache(
|
||||
num_layers=self.config.num_hidden_layers,
|
||||
num_kv_heads=self.config.num_key_value_heads,
|
||||
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
|
||||
dim_kv_head=(
|
||||
self.config.hidden_size // self.config.num_attention_heads
|
||||
if self.config.kv_channels is None
|
||||
else self.config.kv_channels
|
||||
),
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
||||
@@ -25,4 +25,3 @@ __all__ = [
|
||||
"load_audio_text_datasets",
|
||||
"build_dataloader",
|
||||
]
|
||||
|
||||
|
||||
@@ -47,9 +47,7 @@ class Accelerator:
|
||||
pass
|
||||
|
||||
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
|
||||
self.device_ctx = (
|
||||
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
)
|
||||
self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
self._ddp_model = None # For no_sync support
|
||||
|
||||
def _set_seed(self, seed: int):
|
||||
@@ -84,7 +82,7 @@ class Accelerator:
|
||||
# Model helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
||||
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
|
||||
if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
|
||||
model.device = self.device
|
||||
model = model.to(self.device)
|
||||
if self.world_size > 1:
|
||||
@@ -163,4 +161,3 @@ class Accelerator:
|
||||
@staticmethod
|
||||
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
|
||||
return model.module if hasattr(model, "module") else model
|
||||
|
||||
|
||||
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
|
||||
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
|
||||
cli_args.update(yaml_args)
|
||||
return cli_args
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import argbind
|
||||
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
@@ -36,7 +34,7 @@ def load_audio_text_datasets(
|
||||
def prepare(ds: Dataset) -> Dataset:
|
||||
if audio_column not in ds.column_names:
|
||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# but for length calculation we might need raw path or duration if available.
|
||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||
@@ -70,13 +68,13 @@ def compute_sample_lengths(
|
||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||
t_seq = ceil(t_vae / patch_size)
|
||||
- 序列总长约为: text_len + t_seq + 2
|
||||
|
||||
|
||||
Optimized: Use batch column access instead of iterating item by item.
|
||||
"""
|
||||
# Batch access columns - much faster than per-item access
|
||||
text_ids_list = ds["text_ids"]
|
||||
text_lens = [len(t) for t in text_ids_list]
|
||||
|
||||
|
||||
has_duration = "duration" in ds.column_names
|
||||
if has_duration:
|
||||
durations = ds["duration"]
|
||||
@@ -86,7 +84,7 @@ def compute_sample_lengths(
|
||||
for i in range(len(ds)):
|
||||
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||
|
||||
|
||||
# Vectorized length computation
|
||||
lengths = []
|
||||
for text_len, duration in zip(text_lens, durations):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -15,7 +14,7 @@ class AudioFeatureProcessingPacker:
|
||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||
self.audio_start_id = 101
|
||||
self.audio_end_id = 102
|
||||
# unused now
|
||||
# unused now
|
||||
self.audio_prompt_start_id = 103
|
||||
self.audio_prompt_end_id = 104
|
||||
self.text_eos_token_id = 2
|
||||
@@ -147,31 +146,26 @@ class AudioFeatureProcessingPacker:
|
||||
|
||||
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
|
||||
if x.size(0) >= max_len:
|
||||
return x[: max_len]
|
||||
return x[:max_len]
|
||||
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=0)
|
||||
|
||||
def pad_3d(x: torch.Tensor) -> torch.Tensor:
|
||||
# x: [T, P, D]
|
||||
if x.size(0) >= max_len:
|
||||
return x[: max_len]
|
||||
pad = torch.zeros(
|
||||
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
|
||||
)
|
||||
return x[:max_len]
|
||||
pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=0)
|
||||
|
||||
if lengths:
|
||||
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
|
||||
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
|
||||
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
|
||||
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
|
||||
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
|
||||
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
|
||||
audio_task_ids_batch = torch.stack(
|
||||
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
|
||||
)
|
||||
audio_dataset_ids_batch = torch.stack(
|
||||
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
|
||||
)
|
||||
labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
|
||||
audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
|
||||
audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
|
||||
|
||||
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
|
||||
position_ids_list = []
|
||||
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
|
||||
)
|
||||
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
|
||||
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
|
||||
text_token.device
|
||||
text_mask = (
|
||||
torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
audio_mask = (
|
||||
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
loss_mask = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(text_length),
|
||||
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
|
||||
torch.zeros(1),
|
||||
]
|
||||
)
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
|
||||
torch.int32
|
||||
).to(text_token.device)
|
||||
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
|
||||
|
||||
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
|
||||
labels[-2] = 1
|
||||
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,4 +18,3 @@ class TrainingState:
|
||||
val_loader: object
|
||||
tracker: object
|
||||
batch_processor: object
|
||||
|
||||
|
||||
@@ -76,4 +76,3 @@ class TrainingTracker:
|
||||
@contextlib.contextmanager
|
||||
def live(self):
|
||||
yield
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
import re
|
||||
import regex
|
||||
import inflect
|
||||
from functools import partial
|
||||
from wetext import Normalizer
|
||||
|
||||
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
||||
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
|
||||
# whether contain chinese character
|
||||
def contains_chinese(text):
|
||||
@@ -14,19 +14,19 @@ def contains_chinese(text):
|
||||
|
||||
# replace special symbol
|
||||
def replace_corner_mark(text):
|
||||
text = text.replace('²', '平方')
|
||||
text = text.replace('³', '立方')
|
||||
text = text.replace('√', '根号')
|
||||
text = text.replace('≈', '约等于')
|
||||
text = text.replace('<', '小于')
|
||||
text = text.replace("²", "平方")
|
||||
text = text.replace("³", "立方")
|
||||
text = text.replace("√", "根号")
|
||||
text = text.replace("≈", "约等于")
|
||||
text = text.replace("<", "小于")
|
||||
return text
|
||||
|
||||
|
||||
# remove meaningless symbol
|
||||
def remove_bracket(text):
|
||||
text = text.replace('(', ' ').replace(')', ' ')
|
||||
text = text.replace('【', ' ').replace('】', ' ')
|
||||
text = text.replace('`', '').replace('`', '')
|
||||
text = text.replace("(", " ").replace(")", " ")
|
||||
text = text.replace("【", " ").replace("】", " ")
|
||||
text = text.replace("`", "").replace("`", "")
|
||||
text = text.replace("——", " ")
|
||||
return text
|
||||
|
||||
@@ -38,7 +38,7 @@ def spell_out_number(text: str, inflect_parser):
|
||||
for i, c in enumerate(text):
|
||||
if not c.isdigit():
|
||||
if st is not None:
|
||||
num_str = inflect_parser.number_to_words(text[st: i])
|
||||
num_str = inflect_parser.number_to_words(text[st:i])
|
||||
new_text.append(num_str)
|
||||
st = None
|
||||
new_text.append(c)
|
||||
@@ -48,7 +48,7 @@ def spell_out_number(text: str, inflect_parser):
|
||||
if st is not None and st < len(text):
|
||||
num_str = inflect_parser.number_to_words(text[st:])
|
||||
new_text.append(num_str)
|
||||
return ''.join(new_text)
|
||||
return "".join(new_text)
|
||||
|
||||
|
||||
# split paragrah logic:
|
||||
@@ -69,18 +69,18 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
||||
return len(tokenize(_text)) < merge_len
|
||||
|
||||
if lang == "zh":
|
||||
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
|
||||
pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
|
||||
else:
|
||||
pounc = ['.', '?', '!', ';', ':']
|
||||
pounc = [".", "?", "!", ";", ":"]
|
||||
if comma_split:
|
||||
pounc.extend([',', ','])
|
||||
pounc.extend([",", ","])
|
||||
st = 0
|
||||
utts = []
|
||||
for i, c in enumerate(text):
|
||||
if c in pounc:
|
||||
if len(text[st: i]) > 0:
|
||||
utts.append(text[st: i] + c)
|
||||
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
|
||||
if len(text[st:i]) > 0:
|
||||
utts.append(text[st:i] + c)
|
||||
if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
|
||||
tmp = utts.pop(-1)
|
||||
utts.append(tmp + text[i + 1])
|
||||
st = i + 2
|
||||
@@ -88,9 +88,9 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
||||
st = i + 1
|
||||
if len(utts) == 0:
|
||||
if lang == "zh":
|
||||
utts.append(text + '。')
|
||||
utts.append(text + "。")
|
||||
else:
|
||||
utts.append(text + '.')
|
||||
utts.append(text + ".")
|
||||
final_utts = []
|
||||
cur_utt = ""
|
||||
for utt in utts:
|
||||
@@ -112,13 +112,13 @@ def replace_blank(text: str):
|
||||
out_str = []
|
||||
for i, c in enumerate(text):
|
||||
if c == " ":
|
||||
if ((text[i + 1].isascii() and text[i + 1] != " ") and
|
||||
(text[i - 1].isascii() and text[i - 1] != " ")):
|
||||
if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
|
||||
out_str.append(c)
|
||||
else:
|
||||
out_str.append(c)
|
||||
return "".join(out_str)
|
||||
|
||||
|
||||
def clean_markdown(md_text: str) -> str:
|
||||
# 去除代码块 ``` ```(包括多行)
|
||||
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
|
||||
@@ -131,9 +131,9 @@ def clean_markdown(md_text: str) -> str:
|
||||
|
||||
# 去除链接但保留文本 [text](url) -> text
|
||||
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
|
||||
|
||||
|
||||
# 替换无序列表符号
|
||||
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r"^(\s*)-\s+", r"\1", md_text, flags=re.MULTILINE)
|
||||
|
||||
# 去除HTML标签
|
||||
md_text = re.sub(r"<[^>]+>", "", md_text)
|
||||
@@ -152,28 +152,31 @@ def clean_text(text):
|
||||
# 去除 Markdown 语法
|
||||
text = clean_markdown(text)
|
||||
# 匹配并移除表情符号
|
||||
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
|
||||
text = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE).sub("", text)
|
||||
# 去除换行符
|
||||
text = text.replace("\n", " ")
|
||||
text = text.replace("\t", " ")
|
||||
text = text.replace('"', "\“")
|
||||
text = text.replace("“", '"').replace("”", '"')
|
||||
return text
|
||||
|
||||
|
||||
class TextNormalizer:
|
||||
def __init__(self, tokenizer=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
|
||||
self.en_tn_model = Normalizer(lang="en", operator="tn")
|
||||
self.inflect_parser = inflect.engine()
|
||||
|
||||
|
||||
def normalize(self, text, split=False):
|
||||
# 去除 Markdown 语法,去除表情符号,去除换行符
|
||||
lang = "zh" if contains_chinese(text) else "en"
|
||||
text = clean_text(text)
|
||||
if lang == "zh":
|
||||
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
|
||||
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
|
||||
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
|
||||
text = text.replace(
|
||||
"=", "等于"
|
||||
) # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
|
||||
if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text): # 避免 英文连字符被错误正则为减
|
||||
text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text) # 修复 x-2 被正则为 x负2
|
||||
text = self.zh_tn_model.normalize(text)
|
||||
text = replace_blank(text)
|
||||
text = replace_corner_mark(text)
|
||||
@@ -182,4 +185,4 @@ class TextNormalizer:
|
||||
text = self.en_tn_model.normalize(text)
|
||||
text = spell_out_number(text, self.inflect_parser)
|
||||
if split is False:
|
||||
return text
|
||||
return text
|
||||
|
||||
+10
-14
@@ -7,15 +7,15 @@ Related dependencies are imported only when denoising functionality is needed.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
import torchaudio
|
||||
import torch
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
class ZipEnhancer:
|
||||
"""ZipEnhancer Audio Denoising Enhancer"""
|
||||
|
||||
def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
|
||||
"""
|
||||
Initialize ZipEnhancer
|
||||
@@ -23,25 +23,21 @@ class ZipEnhancer:
|
||||
model_path: ModelScope model path or local path
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self._pipeline = pipeline(
|
||||
Tasks.acoustic_noise_suppression,
|
||||
model=self.model_path
|
||||
)
|
||||
|
||||
self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)
|
||||
|
||||
def _normalize_loudness(self, wav_path: str):
|
||||
"""
|
||||
Audio loudness normalization
|
||||
|
||||
|
||||
Args:
|
||||
wav_path: Audio file path
|
||||
"""
|
||||
audio, sr = torchaudio.load(wav_path)
|
||||
loudness = torchaudio.functional.loudness(audio, sr)
|
||||
normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
|
||||
normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
|
||||
torchaudio.save(wav_path, normalized_audio, sr)
|
||||
|
||||
def enhance(self, input_path: str, output_path: Optional[str] = None,
|
||||
normalize_loudness: bool = True) -> str:
|
||||
|
||||
def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
|
||||
"""
|
||||
Audio denoising enhancement
|
||||
Args:
|
||||
@@ -57,7 +53,7 @@ class ZipEnhancer:
|
||||
raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
|
||||
# Create temporary file if no output path is specified
|
||||
if output_path is None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
||||
output_path = tmp_file.name
|
||||
try:
|
||||
# Perform denoising processing
|
||||
@@ -73,4 +69,4 @@ class ZipEnhancer:
|
||||
os.unlink(output_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise RuntimeError(f"Audio denoising processing failed: {e}")
|
||||
raise RuntimeError(f"Audio denoising processing failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user