update voxcpm2
This commit is contained in:
@@ -3,13 +3,14 @@ import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
import spaces
|
||||
import spaces # noqa: F401
|
||||
from typing import Optional, Tuple
|
||||
from funasr import AutoModel
|
||||
from pathlib import Path
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
|
||||
|
||||
import voxcpm
|
||||
|
||||
@@ -24,13 +25,13 @@ class VoxCPMDemo:
|
||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||
model=self.asr_model_id,
|
||||
disable_update=True,
|
||||
log_level='DEBUG',
|
||||
log_level="DEBUG",
|
||||
device="cuda:0" if self.device == "cuda" else "cpu",
|
||||
)
|
||||
|
||||
# TTS model (lazy init)
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.default_local_model_dir = "./models/VoxCPM1.5"
|
||||
self.default_local_model_dir = "/Users/xinliu/Downloads/VoxCPM2-0.5B-newaudiovae-6hz-0316"
|
||||
|
||||
# ---------- Model helpers ----------
|
||||
def _resolve_model_dir(self) -> str:
|
||||
@@ -49,6 +50,7 @@ class VoxCPMDemo:
|
||||
if not os.path.isdir(target_dir):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
|
||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||
@@ -64,7 +66,7 @@ class VoxCPMDemo:
|
||||
print("Model not loaded, initializing...", file=sys.stderr)
|
||||
model_dir = self._resolve_model_dir()
|
||||
print(f"Using model dir: {model_dir}", file=sys.stderr)
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=False)
|
||||
print("Model loaded successfully.", file=sys.stderr)
|
||||
return self.voxcpm_model
|
||||
|
||||
@@ -73,21 +75,24 @@ class VoxCPMDemo:
|
||||
if prompt_wav is None:
|
||||
return ""
|
||||
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
||||
text = res[0]["text"].split('|>')[-1]
|
||||
text = res[0]["text"].split("|>")[-1]
|
||||
return text
|
||||
|
||||
def generate_tts_audio(
|
||||
self,
|
||||
text_input: str,
|
||||
prompt_wav_path_input: Optional[str] = None,
|
||||
prompt_text_input: Optional[str] = None,
|
||||
control_instruction: str = "",
|
||||
reference_wav_path_input: Optional[str] = None,
|
||||
cfg_value_input: float = 2.0,
|
||||
inference_timesteps_input: int = 10,
|
||||
do_normalize: bool = True,
|
||||
denoise: bool = True,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
"""
|
||||
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
|
||||
Generate speech from text using VoxCPM.
|
||||
- If reference_wav provided: Prompt isolation mode (voice cloning)
|
||||
- If no reference_wav: Voice design mode (use control_instruction to describe voice)
|
||||
|
||||
Returns (sample_rate, waveform_numpy)
|
||||
"""
|
||||
current_model = self.get_or_load_voxcpm()
|
||||
@@ -96,14 +101,25 @@ class VoxCPMDemo:
|
||||
if len(text) == 0:
|
||||
raise ValueError("Please input text to synthesize.")
|
||||
|
||||
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
|
||||
prompt_text = prompt_text_input if prompt_text_input else None
|
||||
# 处理 control instruction
|
||||
control = (control_instruction or "").strip()
|
||||
if control:
|
||||
final_text = f"({control}){text}"
|
||||
else:
|
||||
final_text = text
|
||||
|
||||
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
|
||||
reference_wav_path = reference_wav_path_input if reference_wav_path_input else None
|
||||
|
||||
# 判断模式
|
||||
if reference_wav_path:
|
||||
print(f"[Prompt Isolation Mode] reference_wav: {reference_wav_path}", file=sys.stderr)
|
||||
else:
|
||||
print(f"[Voice Design Mode] control: {control[:50] if control else 'None'}...", file=sys.stderr)
|
||||
|
||||
print(f"Generating audio for text: '{final_text[:80]}...'", file=sys.stderr)
|
||||
wav = current_model.generate(
|
||||
text=text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
text=final_text,
|
||||
reference_wav_path=reference_wav_path,
|
||||
cfg_value=float(cfg_value_input),
|
||||
inference_timesteps=int(inference_timesteps_input),
|
||||
normalize=do_normalize,
|
||||
@@ -114,19 +130,14 @@ class VoxCPMDemo:
|
||||
|
||||
# ---------- UI Builders ----------
|
||||
|
||||
def create_demo_interface(demo: VoxCPMDemo):
|
||||
"""Build the Gradio UI for VoxCPM demo."""
|
||||
# static assets (logo path)
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
|
||||
|
||||
with gr.Blocks(
|
||||
theme=gr.themes.Soft(
|
||||
THEME = gr.themes.Soft(
|
||||
primary_hue="blue",
|
||||
secondary_hue="gray",
|
||||
neutral_hue="slate",
|
||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
|
||||
),
|
||||
css="""
|
||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
|
||||
)
|
||||
|
||||
CSS = """
|
||||
.logo-container {
|
||||
text-align: center;
|
||||
margin: 0.5rem 0 1rem 0;
|
||||
@@ -138,8 +149,12 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
display: inline-block;
|
||||
}
|
||||
/* Bold accordion labels */
|
||||
#acc_quick details > summary,
|
||||
#acc_tips details > summary {
|
||||
#acc_quick > .label-wrap,
|
||||
#acc_tips > .label-wrap,
|
||||
#acc_quick > .label-wrap > span,
|
||||
#acc_tips > .label-wrap > span,
|
||||
#acc_quick summary,
|
||||
#acc_tips summary {
|
||||
font-weight: 600 !important;
|
||||
font-size: 1.1em !important;
|
||||
}
|
||||
@@ -151,9 +166,17 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
font-weight: 600;
|
||||
}
|
||||
"""
|
||||
) as interface:
|
||||
# Header logo
|
||||
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
|
||||
|
||||
|
||||
def create_demo_interface(demo: VoxCPMDemo):
|
||||
"""Build the Gradio UI for VoxCPM demo."""
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
|
||||
|
||||
with gr.Blocks() as interface:
|
||||
gr.HTML(
|
||||
'<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>',
|
||||
padding=True,
|
||||
)
|
||||
|
||||
# Quick Start
|
||||
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
|
||||
@@ -200,34 +223,56 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
# Main controls
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
prompt_wav = gr.Audio(
|
||||
sources=["upload", 'microphone'],
|
||||
# 1. Reference Audio
|
||||
# gr.Markdown("### 🎤 Reference Audio (Optional)")
|
||||
# gr.Markdown("*提供参考音频进行音色克隆;不提供则使用 Voice Design 模式*")
|
||||
reference_wav = gr.Audio(
|
||||
sources=["upload", "microphone"],
|
||||
type="filepath",
|
||||
label="Prompt Speech (Optional, or let VoxCPM improvise)",
|
||||
value="./examples/example.wav",
|
||||
label="Reference Audio (Optional)",
|
||||
)
|
||||
DoDenoisePromptAudio = gr.Checkbox(
|
||||
value=False,
|
||||
label="Prompt Speech Enhancement",
|
||||
label="Reference Audio Enhancement",
|
||||
elem_id="chk_denoise",
|
||||
info="We use ZipEnhancer model to denoise the prompt audio."
|
||||
info="Use ZipEnhancer to denoise the reference audio",
|
||||
)
|
||||
with gr.Row():
|
||||
prompt_text = gr.Textbox(
|
||||
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
|
||||
label="Prompt Text",
|
||||
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
|
||||
|
||||
# 2. Control Instruction
|
||||
# gr.Markdown("### 🎛️ Control Instruction (Optional)")
|
||||
# gr.Markdown("*描述声音风格、情感等,格式:`(instruction) text`*")
|
||||
control_instruction = gr.Textbox(
|
||||
value="",
|
||||
label="Control Instruction",
|
||||
placeholder="*描述声音风格、情感等,格式:`(instruction) text`,例如:年轻女性,温柔甜美 / 悲伤地说 / an excited young man*",
|
||||
lines=2,
|
||||
)
|
||||
run_btn = gr.Button("Generate Speech", variant="primary")
|
||||
|
||||
# 3. Target Text
|
||||
# gr.Markdown("### 📝 Target Text")
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
lines=3,
|
||||
)
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="Use wetext library to normalize the input text",
|
||||
)
|
||||
|
||||
run_btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg")
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("### ⚙️ Generation Settings")
|
||||
cfg_value = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=3.0,
|
||||
value=2.0,
|
||||
step=0.1,
|
||||
label="CFG Value (Guidance Scale)",
|
||||
info="Higher values increase adherence to prompt, lower values allow more creativity"
|
||||
info="Higher = more adherence to prompt; Lower = more creativity",
|
||||
)
|
||||
inference_timesteps = gr.Slider(
|
||||
minimum=4,
|
||||
@@ -235,40 +280,54 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
value=10,
|
||||
step=1,
|
||||
label="Inference Timesteps",
|
||||
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
|
||||
info="Higher = better quality but slower",
|
||||
)
|
||||
with gr.Row():
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
)
|
||||
with gr.Row():
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="We use wetext library to normalize the input text."
|
||||
)
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
gr.Markdown("### 🔈 Output")
|
||||
audio_output = gr.Audio(label="Generated Audio")
|
||||
|
||||
gr.Markdown("""
|
||||
---
|
||||
**模式说明 / Mode Info:**
|
||||
- **有 Reference Audio** → Prompt 隔离模式(音色克隆)
|
||||
- **无 Reference Audio** → Voice Design 模式(用 Control Instruction 描述声音)
|
||||
|
||||
**Control Instruction 示例:**
|
||||
- `年轻女性,温柔甜美`
|
||||
- `悲伤地说`
|
||||
- `an excited young man`
|
||||
""")
|
||||
|
||||
# Wiring
|
||||
run_btn.click(
|
||||
fn=demo.generate_tts_audio,
|
||||
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
|
||||
inputs=[
|
||||
text,
|
||||
control_instruction,
|
||||
reference_wav,
|
||||
cfg_value,
|
||||
inference_timesteps,
|
||||
DoNormalizeText,
|
||||
DoDenoisePromptAudio,
|
||||
],
|
||||
outputs=[audio_output],
|
||||
show_progress=True,
|
||||
api_name="generate",
|
||||
)
|
||||
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
|
||||
def run_demo(server_name: str = "0.0.0.0", server_port: int = 7869, show_error: bool = True):
|
||||
demo = VoxCPMDemo()
|
||||
interface = create_demo_interface(demo)
|
||||
# Recommended to enable queue on Spaces for better throughput
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
||||
server_name=server_name,
|
||||
server_port=server_port,
|
||||
show_error=show_error,
|
||||
theme=THEME,
|
||||
css=CSS,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -25,7 +25,7 @@ lora:
|
||||
enable_lm: true
|
||||
enable_dit: true
|
||||
enable_proj: false
|
||||
r: 32
|
||||
r: 8
|
||||
alpha: 16
|
||||
dropout: 0.0
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ lora:
|
||||
enable_lm: true
|
||||
enable_dit: true
|
||||
enable_proj: false
|
||||
r: 32
|
||||
r: 8
|
||||
alpha: 16
|
||||
dropout: 0.0
|
||||
|
||||
|
||||
Binary file not shown.
+207
-220
@@ -1,18 +1,14 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import glob
|
||||
import json
|
||||
import yaml
|
||||
import shutil
|
||||
import datetime
|
||||
import subprocess
|
||||
import threading
|
||||
import gradio as gr
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
|
||||
# Add src to sys.path
|
||||
project_root = Path(__file__).parent
|
||||
@@ -89,7 +85,7 @@ LANG_DICT = {
|
||||
"lang_select": "Language / 语言",
|
||||
"refresh": "刷新",
|
||||
"output_name": "输出目录名称 (可选,若存在则继续训练)",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Global variables
|
||||
@@ -98,9 +94,11 @@ asr_model: Optional[AutoModel] = None
|
||||
training_process: Optional[subprocess.Popen] = None
|
||||
training_log = ""
|
||||
|
||||
|
||||
def get_timestamp_str():
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
def get_or_load_asr_model():
|
||||
global asr_model
|
||||
if asr_model is None:
|
||||
@@ -109,23 +107,25 @@ def get_or_load_asr_model():
|
||||
asr_model = AutoModel(
|
||||
model="iic/SenseVoiceSmall",
|
||||
disable_update=True,
|
||||
log_level='ERROR',
|
||||
log_level="ERROR",
|
||||
device=device,
|
||||
)
|
||||
return asr_model
|
||||
|
||||
|
||||
def recognize_audio(audio_path):
|
||||
if not audio_path:
|
||||
return ""
|
||||
try:
|
||||
model = get_or_load_asr_model()
|
||||
res = model.generate(input=audio_path, language="auto", use_itn=True)
|
||||
text = res[0]["text"].split('|>')[-1]
|
||||
text = res[0]["text"].split("|>")[-1]
|
||||
return text
|
||||
except Exception as e:
|
||||
print(f"ASR Error: {e}", file=sys.stderr)
|
||||
return ""
|
||||
|
||||
|
||||
def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
||||
"""
|
||||
Scans for LoRA checkpoints in the lora directory.
|
||||
@@ -170,6 +170,7 @@ def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
||||
|
||||
return sorted(checkpoints, reverse=True)
|
||||
|
||||
|
||||
def load_lora_config_from_checkpoint(lora_path):
|
||||
"""Load LoRA config from lora_config.json if available."""
|
||||
lora_config_file = os.path.join(lora_path, "lora_config.json")
|
||||
@@ -184,6 +185,7 @@ def load_lora_config_from_checkpoint(lora_path):
|
||||
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
|
||||
return None, None
|
||||
|
||||
|
||||
def get_default_lora_config():
|
||||
"""Return default LoRA config for hot-swapping support."""
|
||||
return LoRAConfig(
|
||||
@@ -192,9 +194,10 @@ def get_default_lora_config():
|
||||
r=32,
|
||||
alpha=16,
|
||||
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
)
|
||||
|
||||
|
||||
def load_model(pretrained_path, lora_path=None):
|
||||
global current_model
|
||||
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
|
||||
@@ -228,9 +231,8 @@ def load_model(pretrained_path, lora_path=None):
|
||||
)
|
||||
return "Model loaded successfully!"
|
||||
|
||||
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
|
||||
global current_model
|
||||
|
||||
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
|
||||
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
|
||||
if current_model is None:
|
||||
# 优先使用用户指定的预训练模型路径
|
||||
@@ -261,7 +263,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
# 加载模型
|
||||
try:
|
||||
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||
status_msg = load_model(base_model_path)
|
||||
load_model(base_model_path)
|
||||
if lora_selection and lora_selection != "None":
|
||||
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
@@ -270,6 +272,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
return None, error_msg
|
||||
|
||||
# Handle LoRA hot-swapping
|
||||
assert current_model is not None, "Model must be loaded before inference"
|
||||
if lora_selection and lora_selection != "None":
|
||||
full_lora_path = os.path.join("lora", lora_selection)
|
||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||
@@ -317,14 +320,16 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
prompt_text=final_prompt_text,
|
||||
cfg_value=cfg_scale,
|
||||
inference_timesteps=steps,
|
||||
denoise=False
|
||||
denoise=False,
|
||||
)
|
||||
return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None, f"Error: {str(e)}"
|
||||
|
||||
|
||||
def start_training(
|
||||
pretrained_path,
|
||||
train_manifest,
|
||||
@@ -355,7 +360,7 @@ def start_training(
|
||||
hf_model_id="",
|
||||
distribute=False,
|
||||
):
|
||||
global training_process, training_log
|
||||
global training_log
|
||||
|
||||
if training_process is not None and training_process.poll() is None:
|
||||
return "Training is already running!"
|
||||
@@ -394,10 +399,7 @@ def start_training(
|
||||
"max_steps": resolved_max_steps,
|
||||
"save_path": checkpoints_dir,
|
||||
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
||||
"lambdas": {
|
||||
"loss/diff": 1.0,
|
||||
"loss/stop": 1.0
|
||||
},
|
||||
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||
"lora": {
|
||||
"enable_lm": bool(enable_lm),
|
||||
"enable_dit": bool(enable_dit),
|
||||
@@ -406,7 +408,7 @@ def start_training(
|
||||
"alpha": int(lora_alpha),
|
||||
"dropout": float(dropout),
|
||||
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
},
|
||||
}
|
||||
|
||||
@@ -420,25 +422,15 @@ def start_training(
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(config, f)
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"scripts/train_voxcpm_finetune.py",
|
||||
"--config_path",
|
||||
config_path
|
||||
]
|
||||
cmd = [sys.executable, "scripts/train_voxcpm_finetune.py", "--config_path", config_path]
|
||||
|
||||
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
|
||||
|
||||
def run_process():
|
||||
global training_process, training_log
|
||||
training_process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1
|
||||
)
|
||||
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
|
||||
|
||||
assert training_process.stdout is not None
|
||||
for line in training_process.stdout:
|
||||
training_log += line
|
||||
# Keep log size manageable
|
||||
@@ -452,17 +444,20 @@ def start_training(
|
||||
|
||||
return f"Training started! Check 'lora/{timestamp}'"
|
||||
|
||||
|
||||
def get_training_log():
|
||||
return training_log
|
||||
|
||||
|
||||
def stop_training():
|
||||
global training_process, training_log
|
||||
global training_log
|
||||
if training_process is not None and training_process.poll() is None:
|
||||
training_process.terminate()
|
||||
training_log += "\nTraining terminated by user."
|
||||
return "Training stopped."
|
||||
return "No training running."
|
||||
|
||||
|
||||
# --- GUI Layout ---
|
||||
|
||||
# 自定义CSS样式
|
||||
@@ -830,11 +825,7 @@ label {
|
||||
}
|
||||
"""
|
||||
|
||||
with gr.Blocks(
|
||||
title="VoxCPM LoRA WebUI",
|
||||
theme=gr.themes.Soft(),
|
||||
css=custom_css
|
||||
) as app:
|
||||
with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css) as app:
|
||||
|
||||
# State for language
|
||||
lang_state = gr.State("zh") # Default to Chinese
|
||||
@@ -850,10 +841,7 @@ with gr.Blocks(
|
||||
""")
|
||||
with gr.Column(scale=1):
|
||||
lang_btn = gr.Radio(
|
||||
choices=["en", "zh"],
|
||||
value="zh",
|
||||
label="🌐 Language / 语言",
|
||||
elem_classes="lang-selector"
|
||||
choices=["en", "zh"], value="zh", label="🌐 Language / 语言", elem_classes="lang-selector"
|
||||
)
|
||||
|
||||
with gr.Tabs(elem_classes="tabs") as tabs:
|
||||
@@ -869,79 +857,40 @@ with gr.Blocks(
|
||||
gr.Markdown("#### 📁 基础配置")
|
||||
|
||||
train_pretrained_path = gr.Textbox(
|
||||
label="📂 预训练模型路径",
|
||||
value=default_pretrained_path,
|
||||
elem_classes="input-field"
|
||||
label="📂 预训练模型路径", value=default_pretrained_path, elem_classes="input-field"
|
||||
)
|
||||
train_manifest = gr.Textbox(
|
||||
label="📋 训练数据清单 (jsonl)",
|
||||
value="examples/train_data_example.jsonl",
|
||||
elem_classes="input-field"
|
||||
)
|
||||
val_manifest = gr.Textbox(
|
||||
label="📊 验证数据清单 (可选)",
|
||||
value="",
|
||||
elem_classes="input-field"
|
||||
elem_classes="input-field",
|
||||
)
|
||||
val_manifest = gr.Textbox(label="📊 验证数据清单 (可选)", value="", elem_classes="input-field")
|
||||
|
||||
gr.Markdown("#### ⚙️ 训练参数")
|
||||
|
||||
with gr.Row():
|
||||
lr = gr.Number(
|
||||
label="📈 学习率 (Learning Rate)",
|
||||
value=1e-4,
|
||||
elem_classes="input-field"
|
||||
)
|
||||
lr = gr.Number(label="📈 学习率 (Learning Rate)", value=1e-4, elem_classes="input-field")
|
||||
num_iters = gr.Number(
|
||||
label="🔄 最大迭代次数",
|
||||
value=2000,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
label="🔄 最大迭代次数", value=2000, precision=0, elem_classes="input-field"
|
||||
)
|
||||
batch_size = gr.Number(
|
||||
label="📦 批次大小 (Batch Size)",
|
||||
value=1,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
label="📦 批次大小 (Batch Size)", value=1, precision=0, elem_classes="input-field"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
lora_rank = gr.Number(
|
||||
label="🎯 LoRA Rank",
|
||||
value=32,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
)
|
||||
lora_alpha = gr.Number(
|
||||
label="⚖️ LoRA Alpha",
|
||||
value=16,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
)
|
||||
lora_rank = gr.Number(label="🎯 LoRA Rank", value=32, precision=0, elem_classes="input-field")
|
||||
lora_alpha = gr.Number(label="⚖️ LoRA Alpha", value=16, precision=0, elem_classes="input-field")
|
||||
save_interval = gr.Number(
|
||||
label="💾 保存间隔 (Steps)",
|
||||
value=1000,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
label="💾 保存间隔 (Steps)", value=1000, precision=0, elem_classes="input-field"
|
||||
)
|
||||
|
||||
output_name = gr.Textbox(
|
||||
label="📁 输出目录名称 (可选,若存在则继续训练)",
|
||||
value="",
|
||||
elem_classes="input-field"
|
||||
label="📁 输出目录名称 (可选,若存在则继续训练)", value="", elem_classes="input-field"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button(
|
||||
"▶️ 开始训练",
|
||||
variant="primary",
|
||||
elem_classes="button-primary"
|
||||
)
|
||||
stop_btn = gr.Button(
|
||||
"⏹️ 停止训练",
|
||||
variant="stop",
|
||||
elem_classes="button-stop"
|
||||
)
|
||||
start_btn = gr.Button("▶️ 开始训练", variant="primary", elem_classes="button-primary")
|
||||
stop_btn = gr.Button("⏹️ 停止训练", variant="stop", elem_classes="button-stop")
|
||||
|
||||
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
|
||||
with gr.Row():
|
||||
@@ -964,7 +913,9 @@ with gr.Blocks(
|
||||
|
||||
gr.Markdown("#### 分发选项 (Distribution)")
|
||||
with gr.Row():
|
||||
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5")
|
||||
hf_model_id = gr.Textbox(
|
||||
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
|
||||
)
|
||||
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
||||
|
||||
with gr.Column(scale=2, elem_classes="form-section"):
|
||||
@@ -975,23 +926,41 @@ with gr.Blocks(
|
||||
max_lines=30,
|
||||
interactive=False,
|
||||
elem_classes="input-field",
|
||||
show_label=False
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
start_btn.click(
|
||||
start_training,
|
||||
inputs=[
|
||||
train_pretrained_path, train_manifest, val_manifest,
|
||||
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
|
||||
train_pretrained_path,
|
||||
train_manifest,
|
||||
val_manifest,
|
||||
lr,
|
||||
num_iters,
|
||||
batch_size,
|
||||
lora_rank,
|
||||
lora_alpha,
|
||||
save_interval,
|
||||
output_name,
|
||||
# advanced
|
||||
grad_accum_steps, num_workers, log_interval, valid_interval,
|
||||
weight_decay, warmup_steps, max_steps, sample_rate,
|
||||
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
|
||||
grad_accum_steps,
|
||||
num_workers,
|
||||
log_interval,
|
||||
valid_interval,
|
||||
weight_decay,
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
dropout,
|
||||
tensorboard_path,
|
||||
# distribution
|
||||
hf_model_id, distribute
|
||||
hf_model_id,
|
||||
distribute,
|
||||
],
|
||||
outputs=[logs_out] # Initial message
|
||||
outputs=[logs_out], # Initial message
|
||||
)
|
||||
stop_btn.click(stop_training, outputs=[logs_out])
|
||||
|
||||
@@ -1016,21 +985,17 @@ with gr.Blocks(
|
||||
value="Hello, this is a test of the VoxCPM LoRA model.",
|
||||
elem_classes="input-field",
|
||||
lines=4,
|
||||
placeholder="输入要合成的文本内容..."
|
||||
placeholder="输入要合成的文本内容...",
|
||||
)
|
||||
|
||||
gr.Markdown("**🎭 声音克隆(可选)**")
|
||||
|
||||
prompt_wav = gr.Audio(
|
||||
label="🎵 参考音频",
|
||||
type="filepath",
|
||||
elem_classes="input-field"
|
||||
)
|
||||
prompt_wav = gr.Audio(label="🎵 参考音频", type="filepath", elem_classes="input-field")
|
||||
|
||||
prompt_text = gr.Textbox(
|
||||
label="📝 参考文本(可选)",
|
||||
elem_classes="input-field",
|
||||
placeholder="如不填写,将自动识别参考音频内容"
|
||||
placeholder="如不填写,将自动识别参考音频内容",
|
||||
)
|
||||
|
||||
# 中栏:模型选择和参数配置 (35%)
|
||||
@@ -1043,14 +1008,10 @@ with gr.Blocks(
|
||||
value="None",
|
||||
interactive=True,
|
||||
elem_classes="input-field",
|
||||
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型"
|
||||
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型",
|
||||
)
|
||||
|
||||
refresh_lora_btn = gr.Button(
|
||||
"🔄 刷新模型列表",
|
||||
elem_classes="button-refresh",
|
||||
size="sm"
|
||||
)
|
||||
refresh_lora_btn = gr.Button("🔄 刷新模型列表", elem_classes="button-refresh", size="sm")
|
||||
|
||||
gr.Markdown("#### ⚙️ 生成参数")
|
||||
|
||||
@@ -1060,7 +1021,7 @@ with gr.Blocks(
|
||||
maximum=5.0,
|
||||
value=2.0,
|
||||
step=0.1,
|
||||
info="引导系数,值越大越贴近提示"
|
||||
info="引导系数,值越大越贴近提示",
|
||||
)
|
||||
|
||||
steps = gr.Slider(
|
||||
@@ -1069,7 +1030,7 @@ with gr.Blocks(
|
||||
maximum=50,
|
||||
value=10,
|
||||
step=1,
|
||||
info="生成质量与步数成正比,但耗时更长"
|
||||
info="生成质量与步数成正比,但耗时更长",
|
||||
)
|
||||
|
||||
seed = gr.Number(
|
||||
@@ -1077,25 +1038,16 @@ with gr.Blocks(
|
||||
value=-1,
|
||||
precision=0,
|
||||
elem_classes="input-field",
|
||||
info="-1 为随机,固定值可复现结果"
|
||||
info="-1 为随机,固定值可复现结果",
|
||||
)
|
||||
|
||||
generate_btn = gr.Button(
|
||||
"🎵 生成音频",
|
||||
variant="primary",
|
||||
elem_classes="button-primary",
|
||||
size="lg"
|
||||
)
|
||||
generate_btn = gr.Button("🎵 生成音频", variant="primary", elem_classes="button-primary", size="lg")
|
||||
|
||||
# 右栏:生成结果 (30%)
|
||||
with gr.Column(scale=30, elem_classes="form-section"):
|
||||
gr.Markdown("#### 🎧 生成结果")
|
||||
|
||||
audio_out = gr.Audio(
|
||||
label="",
|
||||
elem_classes="input-field",
|
||||
show_label=False
|
||||
)
|
||||
audio_out = gr.Audio(label="", elem_classes="input-field", show_label=False)
|
||||
|
||||
gr.Markdown("#### 📋 状态信息")
|
||||
|
||||
@@ -1105,7 +1057,7 @@ with gr.Blocks(
|
||||
elem_classes="input-field",
|
||||
show_label=False,
|
||||
lines=3,
|
||||
placeholder="等待生成..."
|
||||
placeholder="等待生成...",
|
||||
)
|
||||
|
||||
def refresh_loras():
|
||||
@@ -1126,16 +1078,21 @@ with gr.Blocks(
|
||||
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
|
||||
|
||||
# Auto-recognize audio when uploaded
|
||||
prompt_wav.change(
|
||||
fn=recognize_audio,
|
||||
inputs=[prompt_wav],
|
||||
outputs=[prompt_text]
|
||||
)
|
||||
prompt_wav.change(fn=recognize_audio, inputs=[prompt_wav], outputs=[prompt_text])
|
||||
|
||||
generate_btn.click(
|
||||
run_inference,
|
||||
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
|
||||
outputs=[audio_out, status_out]
|
||||
inputs=[
|
||||
infer_text,
|
||||
prompt_wav,
|
||||
prompt_text,
|
||||
lora_select,
|
||||
cfg_scale,
|
||||
steps,
|
||||
seed,
|
||||
train_pretrained_path,
|
||||
],
|
||||
outputs=[audio_out, status_out],
|
||||
)
|
||||
|
||||
# --- Language Switching Logic ---
|
||||
@@ -1144,108 +1101,138 @@ with gr.Blocks(
|
||||
# Labels for advanced options
|
||||
if lang == "zh":
|
||||
adv = {
|
||||
'grad_accum_steps': "梯度累积 (grad_accum_steps)",
|
||||
'num_workers': "数据加载线程 (num_workers)",
|
||||
'log_interval': "日志间隔 (log_interval)",
|
||||
'valid_interval': "验证间隔 (valid_interval)",
|
||||
'weight_decay': "权重衰减 (weight_decay)",
|
||||
'warmup_steps': "warmup_steps",
|
||||
'max_steps': "最大步数 (max_steps)",
|
||||
'sample_rate': "采样率 (sample_rate)",
|
||||
'enable_lm': "启用 LoRA LM (enable_lm)",
|
||||
'enable_dit': "启用 LoRA DIT (enable_dit)",
|
||||
'enable_proj': "启用投影 (enable_proj)",
|
||||
'dropout': "LoRA Dropout",
|
||||
'tensorboard_path': "Tensorboard 路径 (可选)",
|
||||
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
'distribute': "分发模式 (distribute)",
|
||||
"grad_accum_steps": "梯度累积 (grad_accum_steps)",
|
||||
"num_workers": "数据加载线程 (num_workers)",
|
||||
"log_interval": "日志间隔 (log_interval)",
|
||||
"valid_interval": "验证间隔 (valid_interval)",
|
||||
"weight_decay": "权重衰减 (weight_decay)",
|
||||
"warmup_steps": "warmup_steps",
|
||||
"max_steps": "最大步数 (max_steps)",
|
||||
"sample_rate": "采样率 (sample_rate)",
|
||||
"enable_lm": "启用 LoRA LM (enable_lm)",
|
||||
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
||||
"enable_proj": "启用投影 (enable_proj)",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard 路径 (可选)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"distribute": "分发模式 (distribute)",
|
||||
}
|
||||
else:
|
||||
adv = {
|
||||
'grad_accum_steps': "Grad Accum Steps",
|
||||
'num_workers': "Num Workers",
|
||||
'log_interval': "Log Interval",
|
||||
'valid_interval': "Valid Interval",
|
||||
'weight_decay': "Weight Decay",
|
||||
'warmup_steps': "Warmup Steps",
|
||||
'max_steps': "Max Steps",
|
||||
'sample_rate': "Sample Rate",
|
||||
'enable_lm': "Enable LoRA LM",
|
||||
'enable_dit': "Enable LoRA DIT",
|
||||
'enable_proj': "Enable Projection",
|
||||
'dropout': "LoRA Dropout",
|
||||
'tensorboard_path': "Tensorboard Path (Optional)",
|
||||
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
'distribute': "Distribute Mode",
|
||||
"grad_accum_steps": "Grad Accum Steps",
|
||||
"num_workers": "Num Workers",
|
||||
"log_interval": "Log Interval",
|
||||
"valid_interval": "Valid Interval",
|
||||
"weight_decay": "Weight Decay",
|
||||
"warmup_steps": "Warmup Steps",
|
||||
"max_steps": "Max Steps",
|
||||
"sample_rate": "Sample Rate",
|
||||
"enable_lm": "Enable LoRA LM",
|
||||
"enable_dit": "Enable LoRA DIT",
|
||||
"enable_proj": "Enable Projection",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard Path (Optional)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"distribute": "Distribute Mode",
|
||||
}
|
||||
|
||||
return (
|
||||
gr.update(value=f"# {d['title']}"),
|
||||
gr.update(label=d['tab_train']),
|
||||
gr.update(label=d['tab_infer']),
|
||||
gr.update(label=d['pretrained_path']),
|
||||
gr.update(label=d['train_manifest']),
|
||||
gr.update(label=d['val_manifest']),
|
||||
gr.update(label=d['lr']),
|
||||
gr.update(label=d['max_iters']),
|
||||
gr.update(label=d['batch_size']),
|
||||
gr.update(label=d['lora_rank']),
|
||||
gr.update(label=d['lora_alpha']),
|
||||
gr.update(label=d['save_interval']),
|
||||
gr.update(label=d['output_name']),
|
||||
gr.update(value=d['start_train']),
|
||||
gr.update(value=d['stop_train']),
|
||||
gr.update(label=d['train_logs']),
|
||||
gr.update(label=d["tab_train"]),
|
||||
gr.update(label=d["tab_infer"]),
|
||||
gr.update(label=d["pretrained_path"]),
|
||||
gr.update(label=d["train_manifest"]),
|
||||
gr.update(label=d["val_manifest"]),
|
||||
gr.update(label=d["lr"]),
|
||||
gr.update(label=d["max_iters"]),
|
||||
gr.update(label=d["batch_size"]),
|
||||
gr.update(label=d["lora_rank"]),
|
||||
gr.update(label=d["lora_alpha"]),
|
||||
gr.update(label=d["save_interval"]),
|
||||
gr.update(label=d["output_name"]),
|
||||
gr.update(value=d["start_train"]),
|
||||
gr.update(value=d["stop_train"]),
|
||||
gr.update(label=d["train_logs"]),
|
||||
# Advanced options (must match outputs order)
|
||||
gr.update(label=adv['grad_accum_steps']),
|
||||
gr.update(label=adv['num_workers']),
|
||||
gr.update(label=adv['log_interval']),
|
||||
gr.update(label=adv['valid_interval']),
|
||||
gr.update(label=adv['weight_decay']),
|
||||
gr.update(label=adv['warmup_steps']),
|
||||
gr.update(label=adv['max_steps']),
|
||||
gr.update(label=adv['sample_rate']),
|
||||
gr.update(label=adv['enable_lm']),
|
||||
gr.update(label=adv['enable_dit']),
|
||||
gr.update(label=adv['enable_proj']),
|
||||
gr.update(label=adv['dropout']),
|
||||
gr.update(label=adv['tensorboard_path']),
|
||||
gr.update(label=adv["grad_accum_steps"]),
|
||||
gr.update(label=adv["num_workers"]),
|
||||
gr.update(label=adv["log_interval"]),
|
||||
gr.update(label=adv["valid_interval"]),
|
||||
gr.update(label=adv["weight_decay"]),
|
||||
gr.update(label=adv["warmup_steps"]),
|
||||
gr.update(label=adv["max_steps"]),
|
||||
gr.update(label=adv["sample_rate"]),
|
||||
gr.update(label=adv["enable_lm"]),
|
||||
gr.update(label=adv["enable_dit"]),
|
||||
gr.update(label=adv["enable_proj"]),
|
||||
gr.update(label=adv["dropout"]),
|
||||
gr.update(label=adv["tensorboard_path"]),
|
||||
# Distribution options
|
||||
gr.update(label=adv['hf_model_id']),
|
||||
gr.update(label=adv['distribute']),
|
||||
gr.update(label=adv["hf_model_id"]),
|
||||
gr.update(label=adv["distribute"]),
|
||||
# Inference section
|
||||
gr.update(label=d['text_to_synth']),
|
||||
gr.update(label=d['ref_audio']),
|
||||
gr.update(label=d['ref_text']),
|
||||
gr.update(label=d['select_lora']),
|
||||
gr.update(value=d['refresh']),
|
||||
gr.update(label=d['cfg_scale']),
|
||||
gr.update(label=d['infer_steps']),
|
||||
gr.update(label=d['seed']),
|
||||
gr.update(value=d['gen_audio']),
|
||||
gr.update(label=d['gen_output']),
|
||||
gr.update(label=d['status']),
|
||||
gr.update(label=d["text_to_synth"]),
|
||||
gr.update(label=d["ref_audio"]),
|
||||
gr.update(label=d["ref_text"]),
|
||||
gr.update(label=d["select_lora"]),
|
||||
gr.update(value=d["refresh"]),
|
||||
gr.update(label=d["cfg_scale"]),
|
||||
gr.update(label=d["infer_steps"]),
|
||||
gr.update(label=d["seed"]),
|
||||
gr.update(value=d["gen_audio"]),
|
||||
gr.update(label=d["gen_output"]),
|
||||
gr.update(label=d["status"]),
|
||||
)
|
||||
|
||||
lang_btn.change(
|
||||
change_language,
|
||||
inputs=[lang_btn],
|
||||
outputs=[
|
||||
title_md, tab_train, tab_infer,
|
||||
train_pretrained_path, train_manifest, val_manifest,
|
||||
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
|
||||
title_md,
|
||||
tab_train,
|
||||
tab_infer,
|
||||
train_pretrained_path,
|
||||
train_manifest,
|
||||
val_manifest,
|
||||
lr,
|
||||
num_iters,
|
||||
batch_size,
|
||||
lora_rank,
|
||||
lora_alpha,
|
||||
save_interval,
|
||||
output_name,
|
||||
start_btn, stop_btn, logs_out,
|
||||
start_btn,
|
||||
stop_btn,
|
||||
logs_out,
|
||||
# advanced outputs
|
||||
grad_accum_steps, num_workers, log_interval, valid_interval,
|
||||
weight_decay, warmup_steps, max_steps, sample_rate,
|
||||
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
|
||||
grad_accum_steps,
|
||||
num_workers,
|
||||
log_interval,
|
||||
valid_interval,
|
||||
weight_decay,
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
dropout,
|
||||
tensorboard_path,
|
||||
# distribution outputs
|
||||
hf_model_id, distribute,
|
||||
infer_text, prompt_wav, prompt_text,
|
||||
lora_select, refresh_lora_btn, cfg_scale, steps, seed,
|
||||
generate_btn, audio_out, status_out
|
||||
]
|
||||
hf_model_id,
|
||||
distribute,
|
||||
infer_text,
|
||||
prompt_wav,
|
||||
prompt_text,
|
||||
lora_select,
|
||||
refresh_lora_btn,
|
||||
cfg_scale,
|
||||
steps,
|
||||
seed,
|
||||
generate_btn,
|
||||
audio_out,
|
||||
status_out,
|
||||
],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+1
-3
@@ -30,7 +30,7 @@ dependencies = [
|
||||
"torchcodec",
|
||||
"transformers>=4.36.2",
|
||||
"einops",
|
||||
"gradio<6",
|
||||
"gradio>=6,<7",
|
||||
"inflect",
|
||||
"addict",
|
||||
"wetext",
|
||||
@@ -57,7 +57,6 @@ dev = [
|
||||
"pytest-cov>=2.0",
|
||||
"black>=21.0",
|
||||
"flake8>=3.8",
|
||||
"mypy>=0.800",
|
||||
"pre-commit>=2.0",
|
||||
]
|
||||
|
||||
@@ -90,7 +89,6 @@ extend-exclude = '''
|
||||
\.eggs
|
||||
| \.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| build
|
||||
|
||||
@@ -125,7 +125,10 @@ def main():
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
|
||||
|
||||
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -127,7 +127,9 @@ def main():
|
||||
|
||||
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
|
||||
print(f" Base model: {pretrained_path}", file=sys.stderr)
|
||||
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
|
||||
print(
|
||||
f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr
|
||||
)
|
||||
|
||||
# 3. Load model with LoRA (no denoiser)
|
||||
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
|
||||
@@ -146,10 +148,10 @@ def main():
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
||||
print("\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
||||
|
||||
# === Test 1: With LoRA ===
|
||||
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
||||
print("\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
@@ -162,10 +164,13 @@ def main():
|
||||
)
|
||||
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
||||
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
||||
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
||||
print("\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
||||
model.set_lora_enabled(False)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@@ -179,10 +184,13 @@ def main():
|
||||
)
|
||||
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
||||
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 3: Re-enable LoRA ===
|
||||
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
||||
print("\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
||||
model.set_lora_enabled(True)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@@ -196,10 +204,13 @@ def main():
|
||||
)
|
||||
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
||||
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 4: Unload LoRA (reset_lora_weights) ===
|
||||
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
||||
print("\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
||||
model.unload_lora()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@@ -213,10 +224,13 @@ def main():
|
||||
)
|
||||
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
||||
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 5: Hot-reload LoRA (load_lora) ===
|
||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
||||
print("\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
||||
loaded, skipped = model.load_lora(ckpt_dir)
|
||||
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
|
||||
audio_np = model.generate(
|
||||
@@ -231,9 +245,12 @@ def main():
|
||||
)
|
||||
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
||||
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
print(f"\n[Done] All tests completed!", file=sys.stderr)
|
||||
print("\n[Done] All tests completed!", file=sys.stderr)
|
||||
print(f" - with_lora: {lora_output}", file=sys.stderr)
|
||||
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
|
||||
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
|
||||
|
||||
@@ -7,7 +7,7 @@ project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root / "src"))
|
||||
|
||||
import contextlib
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
import argbind
|
||||
import torch
|
||||
@@ -17,16 +17,19 @@ from transformers import get_cosine_schedule_with_warmup
|
||||
import signal
|
||||
import os
|
||||
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
try:
|
||||
from safetensors.torch import save_file
|
||||
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
|
||||
|
||||
from voxcpm.model import VoxCPMModel
|
||||
import json
|
||||
|
||||
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
from voxcpm.training import (
|
||||
Accelerator,
|
||||
@@ -84,7 +87,15 @@ def train(
|
||||
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
|
||||
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
|
||||
|
||||
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
|
||||
# Auto-detect model architecture from config.json
|
||||
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
||||
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
||||
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
||||
if accelerator.rank == 0:
|
||||
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
||||
base_model = _model_cls.from_local(
|
||||
pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None
|
||||
)
|
||||
tokenizer = base_model.text_tokenizer
|
||||
|
||||
train_ds, val_ds = load_audio_text_datasets(
|
||||
@@ -166,7 +177,6 @@ def train(
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
unwrapped_model.train()
|
||||
|
||||
|
||||
# Only print param info on rank 0 to avoid cluttered output
|
||||
if accelerator.rank == 0:
|
||||
for name, param in model.named_parameters():
|
||||
@@ -199,7 +209,19 @@ def train(
|
||||
resume = {"step": start_step}
|
||||
|
||||
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
|
||||
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
|
||||
def _signal_handler(
|
||||
signum,
|
||||
frame,
|
||||
_model=model,
|
||||
_optim=optimizer,
|
||||
_sched=scheduler,
|
||||
_save_dir=save_dir,
|
||||
_pretrained=pretrained_path,
|
||||
_hf_id=hf_model_id,
|
||||
_dist=distribute,
|
||||
_resume=resume,
|
||||
_rank=accelerator.rank,
|
||||
):
|
||||
try:
|
||||
cur_step = int(_resume.get("step", start_step))
|
||||
except Exception:
|
||||
@@ -229,8 +251,8 @@ def train(
|
||||
except StopIteration:
|
||||
data_epoch += 1
|
||||
# Key: set DistributedSampler epoch to ensure different data order each epoch
|
||||
sampler = getattr(train_loader, 'sampler', None)
|
||||
if hasattr(sampler, 'set_epoch'):
|
||||
sampler = getattr(train_loader, "sampler", None)
|
||||
if hasattr(sampler, "set_epoch"):
|
||||
sampler.set_epoch(data_epoch)
|
||||
train_iter = iter(train_loader)
|
||||
return next(train_iter)
|
||||
@@ -250,7 +272,7 @@ def train(
|
||||
|
||||
# Only sync gradients on the last micro-batch
|
||||
# Use no_sync() for intermediate steps to reduce communication overhead
|
||||
is_last_micro_step = (micro_step == grad_accum_steps - 1)
|
||||
is_last_micro_step = micro_step == grad_accum_steps - 1
|
||||
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
|
||||
|
||||
with sync_context:
|
||||
@@ -299,10 +321,22 @@ def train(
|
||||
tracker.log_metrics(loss_values, split="train")
|
||||
|
||||
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
|
||||
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
|
||||
sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
|
||||
valid_interval=valid_interval)
|
||||
validate(
|
||||
model,
|
||||
val_loader,
|
||||
batch_processor,
|
||||
accelerator,
|
||||
tracker,
|
||||
lambdas,
|
||||
writer=writer,
|
||||
step=step,
|
||||
val_ds=val_ds,
|
||||
audio_vae=audio_vae_for_gen,
|
||||
sample_rate=sample_rate,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
)
|
||||
|
||||
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
|
||||
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
|
||||
@@ -313,11 +347,24 @@ def train(
|
||||
writer.close()
|
||||
|
||||
|
||||
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
|
||||
val_texts=None, tokenizer=None, valid_interval=1000):
|
||||
def validate(
|
||||
model,
|
||||
val_loader,
|
||||
batch_processor,
|
||||
accelerator,
|
||||
tracker,
|
||||
lambdas,
|
||||
writer=None,
|
||||
step=0,
|
||||
val_ds=None,
|
||||
audio_vae=None,
|
||||
sample_rate=22050,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
valid_interval=1000,
|
||||
):
|
||||
"""Validate and generate sample audio"""
|
||||
import numpy as np
|
||||
import numpy as np # noqa: F401
|
||||
from collections import defaultdict
|
||||
|
||||
model.eval()
|
||||
@@ -369,13 +416,24 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
# Generate sample audio for TensorBoard display
|
||||
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
|
||||
try:
|
||||
generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
|
||||
val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
|
||||
tracker=tracker)
|
||||
generate_sample_audio(
|
||||
model,
|
||||
val_ds,
|
||||
audio_vae,
|
||||
writer,
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
tracker=tracker,
|
||||
)
|
||||
except Exception as e:
|
||||
tracker.print(f"[Warning] Failed to generate sample audio: {e}")
|
||||
import traceback
|
||||
import io
|
||||
|
||||
buf = io.StringIO()
|
||||
traceback.print_exc(file=buf)
|
||||
tracker.print(buf.getvalue())
|
||||
@@ -398,6 +456,7 @@ def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
|
||||
"""Compute Mel Spectrogram (dB) using librosa"""
|
||||
import numpy as np
|
||||
import librosa
|
||||
|
||||
audio_np = audio_np.flatten().astype(np.float32)
|
||||
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
|
||||
return librosa.power_to_db(mel, ref=np.max)
|
||||
@@ -408,7 +467,8 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
||||
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
|
||||
"""
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa.display
|
||||
|
||||
@@ -419,19 +479,32 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
||||
# Comparison mode: reference vs generated
|
||||
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
|
||||
|
||||
img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
|
||||
ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
|
||||
plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
|
||||
img_ref = librosa.display.specshow(
|
||||
ref_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_ref
|
||||
)
|
||||
ax_ref.set_title(
|
||||
f"Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
color="#28A745",
|
||||
)
|
||||
plt.colorbar(img_ref, ax=ax_ref, format="%+2.0f dB", pad=0.02)
|
||||
|
||||
img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
|
||||
ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
|
||||
plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
|
||||
img_gen = librosa.display.specshow(
|
||||
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_gen
|
||||
)
|
||||
ax_gen.set_title(
|
||||
f"Generated - {len(gen_audio_np)/sample_rate:.2f}s", fontsize=10, fontweight="bold", color="#DC3545"
|
||||
)
|
||||
plt.colorbar(img_gen, ax=ax_gen, format="%+2.0f dB", pad=0.02)
|
||||
else:
|
||||
# Single figure mode: show generated only
|
||||
fig, ax = plt.subplots(figsize=(12, 4))
|
||||
img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
|
||||
ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
|
||||
plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
|
||||
img = librosa.display.specshow(
|
||||
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax
|
||||
)
|
||||
ax.set_title(f"Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}", fontsize=11, fontweight="bold")
|
||||
plt.colorbar(img, ax=ax, format="%+2.0f dB", pad=0.02)
|
||||
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
@@ -440,13 +513,25 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
||||
def normalize_audio(audio_np):
|
||||
"""Normalize audio to [-0.9, 0.9]"""
|
||||
import numpy as np
|
||||
|
||||
max_val = np.abs(audio_np).max()
|
||||
return audio_np / max_val * 0.9 if max_val > 0 else audio_np
|
||||
|
||||
|
||||
def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
|
||||
val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
|
||||
tracker=None):
|
||||
def generate_sample_audio(
|
||||
model,
|
||||
val_ds,
|
||||
audio_vae,
|
||||
writer,
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate=22050,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
pretrained_path=None,
|
||||
valid_interval=1000,
|
||||
tracker=None,
|
||||
):
|
||||
"""Select 2 fixed validation samples, generate audio and log to TensorBoard"""
|
||||
import numpy as np
|
||||
|
||||
@@ -468,7 +553,10 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
ref_sr = sample["audio"].get("sampling_rate", sample_rate)
|
||||
if ref_sr != sample_rate:
|
||||
import torchaudio.functional as F
|
||||
ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
|
||||
|
||||
ref_audio_np = (
|
||||
F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
|
||||
)
|
||||
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
|
||||
except Exception as e:
|
||||
log(f"[Warning] Failed to load reference audio: {e}")
|
||||
@@ -500,7 +588,11 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
continue
|
||||
|
||||
# Process generated audio
|
||||
gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
|
||||
gen_audio_np = (
|
||||
generated.cpu().float().numpy().flatten()
|
||||
if isinstance(generated, torch.Tensor)
|
||||
else np.array(generated, dtype=np.float32).flatten()
|
||||
)
|
||||
gen_audio_np = normalize_audio(gen_audio_np)
|
||||
|
||||
tag = f"val_sample_{i}"
|
||||
@@ -509,7 +601,9 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
|
||||
# Log reference audio
|
||||
if ref_audio_np is not None:
|
||||
writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
|
||||
writer.add_audio(
|
||||
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
||||
)
|
||||
|
||||
# Generate mel spectrogram figure
|
||||
try:
|
||||
@@ -524,6 +618,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
except Exception as e:
|
||||
log(f"[Warning] Failed to generate audio for sample {i}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
@@ -545,8 +640,6 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
Called by all ranks so that distributed state stays aligned.
|
||||
Returns the step number to resume from, or 0 if no checkpoint found.
|
||||
"""
|
||||
import json
|
||||
|
||||
latest_folder = save_dir / "latest"
|
||||
if not latest_folder.exists():
|
||||
return 0
|
||||
@@ -564,6 +657,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
if lora_weights_path.exists():
|
||||
if lora_weights_path.suffix == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(str(lora_weights_path))
|
||||
else:
|
||||
ckpt = torch.load(lora_weights_path, map_location="cpu")
|
||||
@@ -581,6 +675,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
if model_path.exists():
|
||||
if model_path.suffix == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(str(model_path))
|
||||
else:
|
||||
ckpt = torch.load(model_path, map_location="cpu")
|
||||
@@ -625,13 +720,21 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
return 0
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
|
||||
def save_checkpoint(
|
||||
model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
save_dir: Path,
|
||||
step: int,
|
||||
pretrained_path: str = None,
|
||||
hf_model_id: str = "",
|
||||
distribute: bool = False,
|
||||
):
|
||||
"""
|
||||
Save checkpoint with different strategies for full finetune vs LoRA:
|
||||
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
||||
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
|
||||
"""
|
||||
import json
|
||||
import shutil
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -671,7 +774,14 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
# Copy config files from pretrained path
|
||||
if pretrained_path:
|
||||
pretrained_dir = Path(pretrained_path)
|
||||
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
|
||||
files_to_copy = [
|
||||
"config.json",
|
||||
"audiovae.pth",
|
||||
"audiovae.safetensors",
|
||||
"tokenizer.json",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer_config.json",
|
||||
]
|
||||
for fname in files_to_copy:
|
||||
src = pretrained_dir / fname
|
||||
if src.exists():
|
||||
|
||||
+46
-27
@@ -13,11 +13,11 @@ import soundfile as sf
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Validators
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
@@ -53,12 +53,11 @@ def validate_ranges(args, parser):
|
||||
# Model loading
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def load_model(args) -> VoxCPM:
|
||||
print("Loading VoxCPM model...", file=sys.stderr)
|
||||
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||
"ZIPENHANCER_MODEL_PATH", None
|
||||
)
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get("ZIPENHANCER_MODEL_PATH", None)
|
||||
|
||||
# Build LoRA config if provided
|
||||
lora_config = None
|
||||
@@ -119,22 +118,29 @@ def load_model(args) -> VoxCPM:
|
||||
# Commands
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def cmd_clone(args):
|
||||
if not args.text:
|
||||
sys.exit("Error: Please provide --text for synthesis")
|
||||
|
||||
if not args.prompt_audio or not args.prompt_text:
|
||||
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
||||
has_prompt = args.prompt_audio and args.prompt_text
|
||||
has_ref = args.reference_audio is not None
|
||||
if not has_prompt and not has_ref:
|
||||
sys.exit("Error: Voice cloning requires --prompt-audio + --prompt-text, or --reference-audio, or both")
|
||||
|
||||
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
||||
if args.prompt_audio:
|
||||
validate_file_exists(args.prompt_audio, "prompt audio file")
|
||||
if args.reference_audio:
|
||||
validate_file_exists(args.reference_audio, "reference audio file")
|
||||
output_path = validate_output_path(args.output)
|
||||
|
||||
model = load_model(args)
|
||||
|
||||
audio_array = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=str(prompt_audio_path),
|
||||
prompt_text=args.prompt_text,
|
||||
prompt_wav_path=args.prompt_audio if has_prompt else None,
|
||||
prompt_text=args.prompt_text if has_prompt else None,
|
||||
reference_wav_path=args.reference_audio,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
@@ -185,7 +191,11 @@ def cmd_batch(args):
|
||||
|
||||
prompt_audio_path = None
|
||||
if args.prompt_audio:
|
||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "prompt audio file"))
|
||||
|
||||
reference_audio_path = None
|
||||
if args.reference_audio:
|
||||
reference_audio_path = str(validate_file_exists(args.reference_audio, "reference audio file"))
|
||||
|
||||
success_count = 0
|
||||
|
||||
@@ -195,10 +205,11 @@ def cmd_batch(args):
|
||||
text=text,
|
||||
prompt_wav_path=prompt_audio_path,
|
||||
prompt_text=args.prompt_text,
|
||||
reference_wav_path=reference_audio_path,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
denoise=args.denoise and prompt_audio_path is not None,
|
||||
denoise=args.denoise and (prompt_audio_path is not None or reference_audio_path is not None),
|
||||
)
|
||||
|
||||
output_file = output_dir / f"output_{i:03d}.wav"
|
||||
@@ -218,6 +229,7 @@ def cmd_batch(args):
|
||||
# Parser
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _build_unified_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
|
||||
@@ -236,34 +248,40 @@ Examples:
|
||||
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
|
||||
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
|
||||
|
||||
# Prompt
|
||||
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
|
||||
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
|
||||
# Prompt / Reference
|
||||
parser.add_argument(
|
||||
"--prompt-audio", "-pa", help="Prompt audio file path (continuation mode, requires --prompt-text)"
|
||||
)
|
||||
parser.add_argument("--prompt-text", "-pt", help="Text corresponding to the prompt audio")
|
||||
parser.add_argument(
|
||||
"--reference-audio", "-ra", help="Reference audio for voice cloning (isolated mode, VoxCPM2 only)"
|
||||
)
|
||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt/reference speech enhancement")
|
||||
|
||||
# Generation parameters
|
||||
parser.add_argument("--cfg-value", type=float, default=2.0,
|
||||
help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)")
|
||||
parser.add_argument("--inference-timesteps", type=int, default=10,
|
||||
help="Inference steps (int, 1–100, default: 10)")
|
||||
parser.add_argument(
|
||||
"--cfg-value", type=float, default=2.0, help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)"
|
||||
)
|
||||
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (int, 1–100, default: 10)")
|
||||
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
||||
|
||||
# Model loading
|
||||
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
|
||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
|
||||
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
|
||||
parser.add_argument(
|
||||
"--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (default: openbmb/VoxCPM1.5)"
|
||||
)
|
||||
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
||||
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
|
||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||
parser.add_argument("--zipenhancer-path", type=str,
|
||||
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
|
||||
parser.add_argument(
|
||||
"--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)"
|
||||
)
|
||||
|
||||
# LoRA
|
||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
|
||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
|
||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0,
|
||||
help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
||||
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
|
||||
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
|
||||
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
|
||||
@@ -275,6 +293,7 @@ Examples:
|
||||
# Entrypoint
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def main():
|
||||
parser = _build_unified_parser()
|
||||
args = parser.parse_args()
|
||||
@@ -296,8 +315,8 @@ def main():
|
||||
if not args.text or not args.output:
|
||||
parser.error("Single-sample mode requires --text and --output")
|
||||
|
||||
# Clone mode
|
||||
if args.prompt_audio or args.prompt_text:
|
||||
# Clone mode (prompt continuation, reference isolation, or both)
|
||||
if args.prompt_audio or args.prompt_text or args.reference_audio:
|
||||
return cmd_clone(args)
|
||||
|
||||
# Direct synthesis
|
||||
|
||||
+80
-29
@@ -1,16 +1,20 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import json
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from typing import Generator, Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||
from .model.voxcpm2 import VoxCPM2Model
|
||||
|
||||
|
||||
class VoxCPM:
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
voxcpm_model_path: str,
|
||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser: bool = True,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
@@ -31,7 +35,10 @@ class VoxCPM:
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||
"""
|
||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
|
||||
print(
|
||||
f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# If lora_weights_path is provided but no lora_config, create a default one
|
||||
if lora_weights_path is not None and lora_config is None:
|
||||
@@ -42,7 +49,20 @@ class VoxCPM:
|
||||
)
|
||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
|
||||
|
||||
# Determine model type from config.json architecture field
|
||||
config_path = os.path.join(voxcpm_model_path, "config.json")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
arch = config.get("architecture", "voxcpm").lower()
|
||||
|
||||
if arch == "voxcpm2":
|
||||
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
print("Loaded VoxCPM2Model", file=sys.stderr)
|
||||
elif arch == "voxcpm":
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
print("Loaded VoxCPMModel", file=sys.stderr)
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {arch}")
|
||||
|
||||
# Load LoRA weights if path is provided
|
||||
if lora_weights_path is not None:
|
||||
@@ -51,8 +71,10 @@ class VoxCPM:
|
||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
|
||||
|
||||
self.text_normalizer = None
|
||||
self.denoiser = None
|
||||
if enable_denoiser and zipenhancer_model_path is not None:
|
||||
from .zipenhancer import ZipEnhancer
|
||||
|
||||
self.denoiser = ZipEnhancer(zipenhancer_model_path)
|
||||
else:
|
||||
self.denoiser = None
|
||||
@@ -64,8 +86,9 @@ class VoxCPM:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
hf_model_id: str = "openbmb/VoxCPM1.5",
|
||||
def from_pretrained(
|
||||
cls,
|
||||
hf_model_id: str = "openbmb/VoxCPM2",
|
||||
load_denoiser: bool = True,
|
||||
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
cache_dir: str = None,
|
||||
@@ -134,10 +157,12 @@ class VoxCPM:
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
|
||||
def _generate(self,
|
||||
def _generate(
|
||||
self,
|
||||
text: str,
|
||||
prompt_wav_path: str = None,
|
||||
prompt_text: str = None,
|
||||
reference_wav_path: str = None,
|
||||
cfg_value: float = 2.0,
|
||||
inference_timesteps: int = 10,
|
||||
min_len: int = 2,
|
||||
@@ -151,29 +176,28 @@ class VoxCPM:
|
||||
) -> Generator[np.ndarray, None, None]:
|
||||
"""Synthesize speech for the given text and return a single waveform.
|
||||
|
||||
This method optionally builds and reuses a prompt cache. If an external
|
||||
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
|
||||
used for all sub-sentences. Otherwise, the prompt cache is built from
|
||||
the first generated result and reused for the remaining text chunks.
|
||||
|
||||
Args:
|
||||
text: Input text. Can include newlines; each non-empty line is
|
||||
treated as a sub-sentence.
|
||||
prompt_wav_path: Path to a reference audio file for prompting.
|
||||
text: Input text to synthesize.
|
||||
prompt_wav_path: Path to prompt audio for continuation mode.
|
||||
Must be paired with ``prompt_text``.
|
||||
prompt_text: Text content corresponding to the prompt audio.
|
||||
reference_wav_path: Path to reference audio for voice cloning
|
||||
(structurally isolated via ref_audio tokens). Can be used
|
||||
alone or combined with ``prompt_wav_path`` + ``prompt_text``.
|
||||
cfg_value: Guidance scale for the generation model.
|
||||
inference_timesteps: Number of inference steps.
|
||||
min_len: Minimum audio length.
|
||||
max_len: Maximum token length during generation.
|
||||
normalize: Whether to run text normalization before generation.
|
||||
denoise: Whether to denoise the prompt audio if a denoiser is
|
||||
available.
|
||||
denoise: Whether to denoise the prompt/reference audio if a
|
||||
denoiser is available.
|
||||
retry_badcase: Whether to retry badcase.
|
||||
retry_badcase_max_times: Maximum number of times to retry badcase.
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
|
||||
streaming: Whether to return a generator of audio chunks.
|
||||
Returns:
|
||||
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||
Yields audio chunks for each generations step if ``streaming=True``,
|
||||
Yields audio chunks for each generation step if ``streaming=True``,
|
||||
otherwise yields a single array containing the final audio.
|
||||
"""
|
||||
if not text.strip() or not isinstance(text, str):
|
||||
@@ -183,30 +207,56 @@ class VoxCPM:
|
||||
if not os.path.exists(prompt_wav_path):
|
||||
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
|
||||
|
||||
if reference_wav_path is not None:
|
||||
if not os.path.exists(reference_wav_path):
|
||||
raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")
|
||||
|
||||
if (prompt_wav_path is None) != (prompt_text is None):
|
||||
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
|
||||
|
||||
is_v2 = isinstance(self.tts_model, VoxCPM2Model)
|
||||
if reference_wav_path is not None and not is_v2:
|
||||
raise ValueError("reference_wav_path is only supported with VoxCPM2 models")
|
||||
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
temp_prompt_wav_path = None
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
temp_files = []
|
||||
|
||||
try:
|
||||
if prompt_wav_path is not None and prompt_text is not None:
|
||||
actual_prompt_path = prompt_wav_path
|
||||
actual_ref_path = reference_wav_path
|
||||
|
||||
if denoise and self.denoiser is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
temp_prompt_wav_path = tmp_file.name
|
||||
self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
|
||||
prompt_wav_path = temp_prompt_wav_path
|
||||
if prompt_wav_path is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
temp_files.append(tmp.name)
|
||||
self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
|
||||
actual_prompt_path = temp_files[-1]
|
||||
if reference_wav_path is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
temp_files.append(tmp.name)
|
||||
self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
|
||||
actual_ref_path = temp_files[-1]
|
||||
|
||||
if actual_prompt_path is not None or actual_ref_path is not None:
|
||||
if is_v2:
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=actual_prompt_path,
|
||||
reference_wav_path=actual_ref_path,
|
||||
)
|
||||
else:
|
||||
fixed_prompt_cache = None # will be built from the first inference
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=actual_prompt_path,
|
||||
)
|
||||
else:
|
||||
fixed_prompt_cache = None
|
||||
|
||||
if normalize:
|
||||
if self.text_normalizer is None:
|
||||
from .utils.text_normalize import TextNormalizer
|
||||
|
||||
self.text_normalizer = TextNormalizer()
|
||||
text = self.text_normalizer.normalize(text)
|
||||
|
||||
@@ -227,9 +277,10 @@ class VoxCPM:
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
|
||||
finally:
|
||||
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
|
||||
for tmp_path in temp_files:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
try:
|
||||
os.unlink(temp_prompt_wav_path)
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .voxcpm import VoxCPMModel
|
||||
from .voxcpm2 import VoxCPM2Model
|
||||
|
||||
__all__ = ["VoxCPMModel"]
|
||||
__all__ = ["VoxCPMModel", "VoxCPM2Model"]
|
||||
|
||||
@@ -24,8 +24,7 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""
|
||||
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
|
||||
multichar_tokens = {
|
||||
token for token in tokenizer.vocab.keys()
|
||||
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
||||
token for token in tokenizer.vocab.keys() if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
||||
}
|
||||
|
||||
class CharTokenizerWrapper:
|
||||
|
||||
+69
-56
@@ -24,7 +24,6 @@ from typing import Tuple, Union, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import warnings
|
||||
from einops import rearrange
|
||||
@@ -32,6 +31,7 @@ from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
@@ -168,7 +168,7 @@ class VoxCPMModel(nn.Module):
|
||||
config.lm_config.hidden_size,
|
||||
config.lm_config.hidden_size,
|
||||
config.scalar_quantization_latent_dim,
|
||||
config.scalar_quantization_scale
|
||||
config.scalar_quantization_scale,
|
||||
)
|
||||
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
|
||||
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||
@@ -196,9 +196,7 @@ class VoxCPMModel(nn.Module):
|
||||
# LM: base_lm + residual_lm
|
||||
if cfg.enable_lm:
|
||||
for lm in [self.base_lm, self.residual_lm]:
|
||||
apply_lora_to_named_linear_modules(
|
||||
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
|
||||
)
|
||||
apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)
|
||||
|
||||
# DiT: feat_decoder.estimator
|
||||
if cfg.enable_dit:
|
||||
@@ -209,6 +207,7 @@ class VoxCPMModel(nn.Module):
|
||||
# 投影层
|
||||
if cfg.enable_proj:
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
|
||||
for attr_name in cfg.target_proj_modules:
|
||||
module = getattr(self, attr_name, None)
|
||||
if isinstance(module, nn.Linear):
|
||||
@@ -221,13 +220,17 @@ class VoxCPMModel(nn.Module):
|
||||
if self.device != "cuda":
|
||||
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
||||
try:
|
||||
import triton
|
||||
import triton # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError("triton is not installed")
|
||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(
|
||||
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(
|
||||
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
|
||||
return self
|
||||
@@ -313,9 +316,11 @@ class VoxCPMModel(nn.Module):
|
||||
mu=dit_hidden,
|
||||
patch_size=self.patch_size,
|
||||
cond=feat_cond_for_sample,
|
||||
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
n_timesteps=(
|
||||
self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||
else 10,
|
||||
else 10
|
||||
),
|
||||
)
|
||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
@@ -331,7 +336,6 @@ class VoxCPMModel(nn.Module):
|
||||
def _dtype(self):
|
||||
return get_dtype(self.config.dtype)
|
||||
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
@@ -444,7 +448,9 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
max_len=min(
|
||||
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
|
||||
), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -460,7 +466,10 @@ class VoxCPMModel(nn.Module):
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||
print(
|
||||
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
@@ -514,7 +523,9 @@ class VoxCPMModel(nn.Module):
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0) # (D, T, P)
|
||||
).permute(
|
||||
1, 2, 0
|
||||
) # (D, T, P)
|
||||
# build prompt cache - only save raw text and audio features
|
||||
prompt_cache = {
|
||||
"prompt_text": prompt_text,
|
||||
@@ -523,7 +534,6 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
return prompt_cache
|
||||
|
||||
|
||||
def merge_prompt_cache(
|
||||
self,
|
||||
original_cache: dict,
|
||||
@@ -560,17 +570,14 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
return merged_cache
|
||||
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _generate_with_prompt_cache(
|
||||
self,
|
||||
@@ -645,8 +652,12 @@ class VoxCPMModel(nn.Module):
|
||||
)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
text_mask = (
|
||||
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
audio_mask = (
|
||||
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
|
||||
text_token = text_token.unsqueeze(0).to(self.device)
|
||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||
@@ -663,7 +674,9 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
max_len=min(
|
||||
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
|
||||
), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -674,17 +687,16 @@ class VoxCPMModel(nn.Module):
|
||||
for latent_pred, pred_audio_feat in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
||||
yield (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||
print(
|
||||
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
@@ -698,11 +710,7 @@ class VoxCPMModel(nn.Module):
|
||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
||||
else:
|
||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
||||
yield (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return next(self._inference(*args, streaming=False, **kwargs))
|
||||
@@ -782,7 +790,6 @@ class VoxCPMModel(nn.Module):
|
||||
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
|
||||
lm_hidden = enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
|
||||
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
|
||||
is_causal=True,
|
||||
@@ -790,7 +797,6 @@ class VoxCPMModel(nn.Module):
|
||||
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
|
||||
residual_hidden = residual_enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
for i in tqdm(range(max_len)):
|
||||
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
|
||||
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
|
||||
@@ -827,10 +833,10 @@ class VoxCPMModel(nn.Module):
|
||||
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
|
||||
).clone()
|
||||
|
||||
|
||||
lm_hidden = self.fsq_layer(lm_hidden)
|
||||
residual_hidden = self.residual_lm.forward_step(
|
||||
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
|
||||
lm_hidden + curr_embed[:, 0, :],
|
||||
torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
|
||||
).clone()
|
||||
|
||||
if not streaming:
|
||||
@@ -838,18 +844,30 @@ class VoxCPMModel(nn.Module):
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
audio_vae_config = getattr(config, 'audio_vae_config', None)
|
||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
||||
vae_state_dict = torch.load(
|
||||
os.path.join(path, "audiovae.pth"),
|
||||
# Try to load AudioVAE from safetensors first, fallback to pytorch
|
||||
audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
|
||||
audiovae_pth_path = os.path.join(path, "audiovae.pth")
|
||||
if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||
print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
|
||||
vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
|
||||
elif os.path.exists(audiovae_pth_path):
|
||||
print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
|
||||
checkpoint = torch.load(
|
||||
audiovae_pth_path,
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
)
|
||||
vae_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||
)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
@@ -880,9 +898,7 @@ class VoxCPMModel(nn.Module):
|
||||
)
|
||||
model_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
|
||||
)
|
||||
raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
|
||||
|
||||
for kw, val in vae_state_dict.items():
|
||||
model_state_dict[f"audio_vae.{kw}"] = val
|
||||
@@ -900,6 +916,7 @@ class VoxCPMModel(nn.Module):
|
||||
def _iter_lora_modules(self):
|
||||
"""Iterate over all LoRA modules."""
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
yield module
|
||||
@@ -919,15 +936,15 @@ class VoxCPMModel(nn.Module):
|
||||
from pathlib import Path
|
||||
|
||||
device = device or self.device
|
||||
lora_path = Path(lora_path)
|
||||
lora_p = Path(lora_path)
|
||||
|
||||
# Try safetensors first, then fallback to .ckpt
|
||||
if lora_path.is_dir():
|
||||
safetensors_file = lora_path / "lora_weights.safetensors"
|
||||
ckpt_file = lora_path / "lora_weights.ckpt"
|
||||
if lora_p.is_dir():
|
||||
safetensors_file = lora_p / "lora_weights.safetensors"
|
||||
ckpt_file = lora_p / "lora_weights.ckpt"
|
||||
else:
|
||||
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
|
||||
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
|
||||
safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
|
||||
ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
|
||||
|
||||
# Load from safetensors if available
|
||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||
@@ -936,9 +953,7 @@ class VoxCPMModel(nn.Module):
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
|
||||
)
|
||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||
|
||||
# Build param mapping (handle torch.compile's _orig_mod prefix)
|
||||
model_params = dict(self.named_parameters())
|
||||
@@ -967,6 +982,4 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
|
||||
"""Get all LoRA parameters (lora_A/lora_B)."""
|
||||
return {name: param.data.clone()
|
||||
for name, param in self.named_parameters()
|
||||
if "lora_" in name}
|
||||
return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1 +1,2 @@
|
||||
from .audio_vae import AudioVAE, AudioVAEConfig
|
||||
from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import List, Union, Optional
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -285,7 +285,7 @@ class AudioVAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[AudioVAEConfig] = None,
|
||||
config: AudioVAEConfig = None,
|
||||
):
|
||||
# 如果没有传入config,使用默认配置
|
||||
if config is None:
|
||||
|
||||
@@ -0,0 +1,486 @@
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
class CausalConv1d(nn.Conv1d):
|
||||
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
self.__output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
|
||||
return super().forward(x_pad)
|
||||
|
||||
|
||||
class CausalTransposeConv1d(nn.ConvTranspose1d):
|
||||
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
self.__output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
|
||||
|
||||
|
||||
def WNCausalConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNCausalTransposeConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
# Scripting this brings model speed up 1.4x
|
||||
@torch.jit.script
|
||||
def snake(x, alpha):
|
||||
shape = x.shape
|
||||
x = x.reshape(shape[0], shape[1], -1)
|
||||
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
||||
x = x.reshape(shape)
|
||||
return x
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return snake(x, self.alpha)
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class CausalResidualUnit(nn.Module):
|
||||
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel,
|
||||
dilation=dilation,
|
||||
padding=pad,
|
||||
groups=groups,
|
||||
),
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(dim, dim, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.block(x)
|
||||
pad = (x.shape[-1] - y.shape[-1]) // 2
|
||||
assert pad == 0
|
||||
if pad > 0:
|
||||
x = x[..., pad:-pad]
|
||||
return x + y
|
||||
|
||||
|
||||
class CausalEncoderBlock(nn.Module):
|
||||
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
|
||||
super().__init__()
|
||||
input_dim = input_dim or output_dim // 2
|
||||
self.block = nn.Sequential(
|
||||
CausalResidualUnit(input_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=9, groups=groups),
|
||||
Snake1d(input_dim),
|
||||
WNCausalConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class CausalEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 64,
|
||||
latent_dim: int = 32,
|
||||
strides: list = [2, 4, 8, 8],
|
||||
depthwise: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Create first convolution
|
||||
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
|
||||
|
||||
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||
for stride in strides:
|
||||
d_model *= 2
|
||||
groups = d_model // 2 if depthwise else 1
|
||||
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
|
||||
|
||||
groups = d_model if depthwise else 1
|
||||
|
||||
# Create two convolution, for mu and logvar
|
||||
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
|
||||
# Wrap black into nn.Sequential
|
||||
self.block = nn.Sequential(*self.block)
|
||||
self.enc_dim = d_model
|
||||
|
||||
def forward(self, x):
|
||||
hidden_state = self.block(x)
|
||||
return {
|
||||
"hidden_state": hidden_state,
|
||||
"mu": self.fc_mu(hidden_state),
|
||||
"logvar": self.fc_logvar(hidden_state),
|
||||
}
|
||||
|
||||
|
||||
class NoiseBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, T = x.shape
|
||||
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
|
||||
h = self.linear(x)
|
||||
n = noise * h
|
||||
x = x + n
|
||||
return x
|
||||
|
||||
|
||||
class CausalDecoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 16,
|
||||
output_dim: int = 8,
|
||||
stride: int = 1,
|
||||
groups=1,
|
||||
use_noise_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
layers = [
|
||||
Snake1d(input_dim),
|
||||
WNCausalTransposeConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
]
|
||||
if use_noise_block:
|
||||
layers.append(NoiseBlock(output_dim))
|
||||
layers.extend(
|
||||
[
|
||||
CausalResidualUnit(output_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=9, groups=groups),
|
||||
]
|
||||
)
|
||||
self.block = nn.Sequential(*layers)
|
||||
self.input_channels = input_dim
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class TransposeLastTwoDim(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.transpose(x, -1, -2)
|
||||
|
||||
|
||||
class SampleRateConditionLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
sr_bin_buckets: int = None,
|
||||
cond_type: str = "scale_bias",
|
||||
cond_dim: int = 128,
|
||||
out_layer: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cond_type, out_layer_in_dim = cond_type, input_dim
|
||||
|
||||
if cond_type == "scale_bias":
|
||||
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
nn.init.ones_(self.scale_embed.weight)
|
||||
nn.init.zeros_(self.bias_embed.weight)
|
||||
elif cond_type == "scale_bias_init":
|
||||
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
nn.init.normal_(self.scale_embed.weight, mean=1)
|
||||
nn.init.normal_(self.bias_embed.weight)
|
||||
elif cond_type == "add":
|
||||
self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||
nn.init.normal_(self.cond_embed.weight)
|
||||
elif cond_type == "concat":
|
||||
self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
|
||||
assert out_layer, "out_layer must be True for concat cond_type"
|
||||
out_layer_in_dim = input_dim + cond_dim
|
||||
else:
|
||||
raise ValueError(f"Invalid cond_type: {cond_type}")
|
||||
|
||||
if out_layer:
|
||||
self.out_layer = nn.Sequential(
|
||||
Snake1d(out_layer_in_dim),
|
||||
WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
|
||||
)
|
||||
else:
|
||||
self.out_layer = nn.Identity()
|
||||
|
||||
def forward(self, x, sr_cond):
|
||||
if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
|
||||
x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
|
||||
elif self.cond_type == "add":
|
||||
x = x + self.cond_embed(sr_cond).unsqueeze(-1)
|
||||
elif self.cond_type == "concat":
|
||||
x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
|
||||
|
||||
return self.out_layer(x)
|
||||
|
||||
|
||||
class CausalDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channel,
|
||||
channels,
|
||||
rates,
|
||||
depthwise: bool = False,
|
||||
d_out: int = 1,
|
||||
use_noise_block: bool = False,
|
||||
sr_bin_boundaries: List[int] = None,
|
||||
cond_type: str = "scale_bias",
|
||||
cond_dim: int = 128,
|
||||
cond_out_layer: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add first conv layer
|
||||
if depthwise:
|
||||
layers = [
|
||||
WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
|
||||
WNCausalConv1d(input_channel, channels, kernel_size=1),
|
||||
]
|
||||
else:
|
||||
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
||||
|
||||
# Add upsampling + MRF blocks
|
||||
for i, stride in enumerate(rates):
|
||||
input_dim = channels // 2**i
|
||||
output_dim = channels // 2 ** (i + 1)
|
||||
groups = output_dim if depthwise else 1
|
||||
layers += [
|
||||
CausalDecoderBlock(
|
||||
input_dim,
|
||||
output_dim,
|
||||
stride,
|
||||
groups=groups,
|
||||
use_noise_block=use_noise_block,
|
||||
)
|
||||
]
|
||||
|
||||
# Add final conv layer
|
||||
layers += [
|
||||
Snake1d(output_dim),
|
||||
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
||||
nn.Tanh(),
|
||||
]
|
||||
|
||||
if sr_bin_boundaries is None:
|
||||
self.model = nn.Sequential(*layers)
|
||||
self.sr_bin_boundaries = None
|
||||
else:
|
||||
self.model = nn.ModuleList(layers)
|
||||
|
||||
self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
|
||||
self.sr_bin_buckets = len(sr_bin_boundaries) + 1
|
||||
|
||||
cond_layers = []
|
||||
for layer in self.model:
|
||||
if layer.__class__.__name__ == "CausalDecoderBlock":
|
||||
cond_layers.append(
|
||||
SampleRateConditionLayer(
|
||||
input_dim=layer.input_channels,
|
||||
sr_bin_buckets=self.sr_bin_buckets,
|
||||
cond_type=cond_type,
|
||||
cond_dim=cond_dim,
|
||||
out_layer=cond_out_layer,
|
||||
)
|
||||
)
|
||||
else:
|
||||
cond_layers.append(None)
|
||||
self.sr_cond_model = nn.ModuleList(cond_layers)
|
||||
|
||||
def get_sr_idx(self, sr):
|
||||
return torch.bucketize(sr, self.sr_bin_boundaries)
|
||||
|
||||
def forward(self, x, sr_cond=None):
|
||||
if self.sr_bin_boundaries is not None:
|
||||
# assert sr_cond is not None
|
||||
sr_cond = self.get_sr_idx(sr_cond)
|
||||
|
||||
for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
|
||||
if sr_cond_layer is not None:
|
||||
x = sr_cond_layer(x, sr_cond)
|
||||
x = layer(x)
|
||||
return x
|
||||
else:
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class AudioVAEConfig(BaseModel):
|
||||
encoder_dim: int = 128
|
||||
encoder_rates: List[int] = [2, 5, 8, 8]
|
||||
latent_dim: int = 64
|
||||
decoder_dim: int = 2048
|
||||
decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
|
||||
depthwise: bool = True
|
||||
sample_rate: int = 16000
|
||||
out_sample_rate: int = 48000
|
||||
use_noise_block: bool = False
|
||||
sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
|
||||
cond_type: str = "scale_bias"
|
||||
cond_dim: int = 128
|
||||
cond_out_layer: bool = False
|
||||
|
||||
|
||||
class AudioVAE(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AudioVAEConfig = None,
|
||||
):
|
||||
# 如果没有传入config,使用默认配置
|
||||
if config is None:
|
||||
config = AudioVAEConfig()
|
||||
|
||||
super().__init__()
|
||||
|
||||
encoder_dim = config.encoder_dim
|
||||
encoder_rates = config.encoder_rates
|
||||
latent_dim = config.latent_dim
|
||||
decoder_dim = config.decoder_dim
|
||||
decoder_rates = config.decoder_rates
|
||||
depthwise = config.depthwise
|
||||
sample_rate = config.sample_rate
|
||||
out_sample_rate = config.out_sample_rate
|
||||
use_noise_block = config.use_noise_block
|
||||
sr_bin_boundaries = config.sr_bin_boundaries
|
||||
cond_type = config.cond_type
|
||||
cond_dim = config.cond_dim
|
||||
cond_out_layer = config.cond_out_layer
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
self.decoder_rates = decoder_rates
|
||||
self.depthwise = depthwise
|
||||
|
||||
self.use_noise_block = use_noise_block
|
||||
|
||||
if latent_dim is None:
|
||||
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.hop_length = np.prod(encoder_rates)
|
||||
self.encoder = CausalEncoder(
|
||||
encoder_dim,
|
||||
latent_dim,
|
||||
encoder_rates,
|
||||
depthwise=depthwise,
|
||||
)
|
||||
|
||||
self.decoder = CausalDecoder(
|
||||
latent_dim,
|
||||
decoder_dim,
|
||||
decoder_rates,
|
||||
depthwise=depthwise,
|
||||
use_noise_block=use_noise_block,
|
||||
sr_bin_boundaries=sr_bin_boundaries,
|
||||
cond_type=cond_type,
|
||||
cond_dim=cond_dim,
|
||||
cond_out_layer=cond_out_layer,
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.out_sample_rate = out_sample_rate
|
||||
self.sr_bin_boundaries = sr_bin_boundaries
|
||||
self.chunk_size = math.prod(encoder_rates)
|
||||
|
||||
def preprocess(self, audio_data, sample_rate):
|
||||
if sample_rate is None:
|
||||
sample_rate = self.sample_rate
|
||||
assert sample_rate == self.sample_rate
|
||||
pad_to = self.hop_length
|
||||
length = audio_data.shape[-1]
|
||||
right_pad = math.ceil(length / pad_to) * pad_to - length
|
||||
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
||||
|
||||
return audio_data
|
||||
|
||||
def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
|
||||
"""Decode given latent codes and return audio data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
length : int, optional
|
||||
Number of samples in output audio, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"audio" : Tensor[B x 1 x length]
|
||||
Decoded audio data.
|
||||
"""
|
||||
if self.sr_bin_boundaries is not None:
|
||||
# use default output sample rate
|
||||
if sr_cond is None:
|
||||
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
|
||||
return self.decoder(z, sr_cond)
|
||||
|
||||
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
||||
"""
|
||||
Args:
|
||||
audio_data: Tensor[B x 1 x T]
|
||||
sample_rate: int
|
||||
Returns:
|
||||
z: Tensor[B x D x T]
|
||||
"""
|
||||
if audio_data.ndim == 2:
|
||||
audio_data = audio_data.unsqueeze(1)
|
||||
|
||||
audio_data = self.preprocess(audio_data, sample_rate)
|
||||
return self.encoder(audio_data)["mu"]
|
||||
@@ -128,6 +128,3 @@ def apply_lora_to_named_linear_modules(
|
||||
dropout=dropout,
|
||||
)
|
||||
setattr(parent, short_name, lora_layer)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .unified_cfm import UnifiedCFM, CfmConfig
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
|
||||
class SinusoidalPosEmb(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
||||
|
||||
def forward(self, x, scale=1000):
|
||||
if x.ndim < 1:
|
||||
x = x.unsqueeze(0)
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
|
||||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
out_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
|
||||
self.act = nn.SiLU()
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class VoxCPMLocDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MiniCPM4Config,
|
||||
in_channels: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.config = config
|
||||
|
||||
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
|
||||
|
||||
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=config.hidden_size,
|
||||
time_embed_dim=config.hidden_size,
|
||||
)
|
||||
self.delta_time_mlp = TimestepEmbedding(
|
||||
in_channels=config.hidden_size,
|
||||
time_embed_dim=config.hidden_size,
|
||||
)
|
||||
|
||||
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
|
||||
self.decoder = MiniCPMModel(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mu: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond: torch.Tensor,
|
||||
dt: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, T) tensor of inputs
|
||||
mu: (N, C) tensor of hidden embedding
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
cond: (N, C, T') tensor of prefix conditions
|
||||
dt: (N,) used for mean velocity (may be supported in the future...)
|
||||
"""
|
||||
x = self.in_proj(x.transpose(1, 2).contiguous())
|
||||
|
||||
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
|
||||
prefix = cond.size(1)
|
||||
|
||||
t = self.time_embeddings(t).to(x.dtype)
|
||||
t = self.time_mlp(t)
|
||||
dt = self.time_embeddings(dt).to(x.dtype)
|
||||
dt = self.delta_time_mlp(dt)
|
||||
t = t + dt
|
||||
|
||||
mu = mu.view(x.size(0), -1, x.size(-1))
|
||||
x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)
|
||||
|
||||
hidden, _ = self.decoder(x, is_causal=False)
|
||||
hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
|
||||
hidden = self.out_proj(hidden)
|
||||
|
||||
return hidden.transpose(1, 2).contiguous()
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -138,7 +138,9 @@ class UnifiedCFM(torch.nn.Module):
|
||||
# ------------------------------------------------------------------ #
|
||||
# Training loss
|
||||
# ------------------------------------------------------------------ #
|
||||
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
|
||||
def adaptive_loss_weighting(
|
||||
self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
|
||||
):
|
||||
weights = 1.0 / ((losses + epsilon).pow(p))
|
||||
if mask is not None:
|
||||
weights = weights * mask
|
||||
@@ -193,8 +195,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
|
||||
|
||||
ratio_r_neq_t = (
|
||||
self.ratio_r_neq_t_range[0]
|
||||
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||
self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||
if self.mean_mode
|
||||
else 0.0
|
||||
)
|
||||
|
||||
@@ -64,10 +64,8 @@ class MiniCPMLongRoPE(nn.Module):
|
||||
self.long_factor = config.rope_scaling.long_factor
|
||||
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
|
||||
|
||||
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
|
||||
self.scaling_factor = math.sqrt(
|
||||
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
|
||||
)
|
||||
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
||||
self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
@@ -76,11 +74,7 @@ class MiniCPMLongRoPE(nn.Module):
|
||||
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
|
||||
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
|
||||
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=self.max_position_embeddings,
|
||||
device=self.inv_freq.device,
|
||||
dtype=torch.float32
|
||||
)
|
||||
self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
"""设置cos和sin缓存"""
|
||||
@@ -93,8 +87,7 @@ class MiniCPMLongRoPE(nn.Module):
|
||||
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
|
||||
|
||||
freqs = torch.mul(
|
||||
torch.outer(t, 1.0 / ext_factors).to(device=device),
|
||||
self.inv_freq.to(device=device).to(dtype)
|
||||
torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
|
||||
)
|
||||
|
||||
# 创建embeddings
|
||||
@@ -123,7 +116,9 @@ class MiniCPMAttention(nn.Module):
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||
self.head_dim = (
|
||||
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||
)
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
@@ -413,7 +408,11 @@ class MiniCPMModel(nn.Module):
|
||||
self.kv_cache = StaticKVCache(
|
||||
num_layers=self.config.num_hidden_layers,
|
||||
num_kv_heads=self.config.num_key_value_heads,
|
||||
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
|
||||
dim_kv_head=(
|
||||
self.config.hidden_size // self.config.num_attention_heads
|
||||
if self.config.kv_channels is None
|
||||
else self.config.kv_channels
|
||||
),
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
||||
@@ -25,4 +25,3 @@ __all__ = [
|
||||
"load_audio_text_datasets",
|
||||
"build_dataloader",
|
||||
]
|
||||
|
||||
|
||||
@@ -47,9 +47,7 @@ class Accelerator:
|
||||
pass
|
||||
|
||||
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
|
||||
self.device_ctx = (
|
||||
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
)
|
||||
self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
self._ddp_model = None # For no_sync support
|
||||
|
||||
def _set_seed(self, seed: int):
|
||||
@@ -84,7 +82,7 @@ class Accelerator:
|
||||
# Model helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
||||
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
|
||||
if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
|
||||
model.device = self.device
|
||||
model = model.to(self.device)
|
||||
if self.world_size > 1:
|
||||
@@ -163,4 +161,3 @@ class Accelerator:
|
||||
@staticmethod
|
||||
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
|
||||
return model.module if hasattr(model, "module") else model
|
||||
|
||||
|
||||
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
|
||||
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
|
||||
cli_args.update(yaml_args)
|
||||
return cli_args
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import argbind
|
||||
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -155,23 +154,18 @@ class AudioFeatureProcessingPacker:
|
||||
# x: [T, P, D]
|
||||
if x.size(0) >= max_len:
|
||||
return x[:max_len]
|
||||
pad = torch.zeros(
|
||||
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
|
||||
)
|
||||
pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
|
||||
return torch.cat([x, pad], dim=0)
|
||||
|
||||
if lengths:
|
||||
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
|
||||
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
|
||||
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
|
||||
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
|
||||
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
|
||||
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
|
||||
audio_task_ids_batch = torch.stack(
|
||||
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
|
||||
)
|
||||
audio_dataset_ids_batch = torch.stack(
|
||||
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
|
||||
)
|
||||
labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
|
||||
audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
|
||||
audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
|
||||
|
||||
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
|
||||
position_ids_list = []
|
||||
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
|
||||
)
|
||||
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
|
||||
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
|
||||
text_token.device
|
||||
text_mask = (
|
||||
torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
audio_mask = (
|
||||
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
loss_mask = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(text_length),
|
||||
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
|
||||
torch.zeros(1),
|
||||
]
|
||||
)
|
||||
.type(torch.int32)
|
||||
.to(text_token.device)
|
||||
)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
|
||||
torch.int32
|
||||
).to(text_token.device)
|
||||
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
|
||||
|
||||
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
|
||||
labels[-2] = 1
|
||||
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
|
||||
audio_duration,
|
||||
text_token_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,4 +18,3 @@ class TrainingState:
|
||||
val_loader: object
|
||||
tracker: object
|
||||
batch_processor: object
|
||||
|
||||
|
||||
@@ -76,4 +76,3 @@ class TrainingTracker:
|
||||
@contextlib.contextmanager
|
||||
def live(self):
|
||||
yield
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
import re
|
||||
import regex
|
||||
import inflect
|
||||
from functools import partial
|
||||
from wetext import Normalizer
|
||||
|
||||
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
||||
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
|
||||
# whether contain chinese character
|
||||
def contains_chinese(text):
|
||||
@@ -14,19 +14,19 @@ def contains_chinese(text):
|
||||
|
||||
# replace special symbol
|
||||
def replace_corner_mark(text):
|
||||
text = text.replace('²', '平方')
|
||||
text = text.replace('³', '立方')
|
||||
text = text.replace('√', '根号')
|
||||
text = text.replace('≈', '约等于')
|
||||
text = text.replace('<', '小于')
|
||||
text = text.replace("²", "平方")
|
||||
text = text.replace("³", "立方")
|
||||
text = text.replace("√", "根号")
|
||||
text = text.replace("≈", "约等于")
|
||||
text = text.replace("<", "小于")
|
||||
return text
|
||||
|
||||
|
||||
# remove meaningless symbol
|
||||
def remove_bracket(text):
|
||||
text = text.replace('(', ' ').replace(')', ' ')
|
||||
text = text.replace('【', ' ').replace('】', ' ')
|
||||
text = text.replace('`', '').replace('`', '')
|
||||
text = text.replace("(", " ").replace(")", " ")
|
||||
text = text.replace("【", " ").replace("】", " ")
|
||||
text = text.replace("`", "").replace("`", "")
|
||||
text = text.replace("——", " ")
|
||||
return text
|
||||
|
||||
@@ -48,7 +48,7 @@ def spell_out_number(text: str, inflect_parser):
|
||||
if st is not None and st < len(text):
|
||||
num_str = inflect_parser.number_to_words(text[st:])
|
||||
new_text.append(num_str)
|
||||
return ''.join(new_text)
|
||||
return "".join(new_text)
|
||||
|
||||
|
||||
# split paragrah logic:
|
||||
@@ -69,18 +69,18 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
||||
return len(tokenize(_text)) < merge_len
|
||||
|
||||
if lang == "zh":
|
||||
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
|
||||
pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
|
||||
else:
|
||||
pounc = ['.', '?', '!', ';', ':']
|
||||
pounc = [".", "?", "!", ";", ":"]
|
||||
if comma_split:
|
||||
pounc.extend([',', ','])
|
||||
pounc.extend([",", ","])
|
||||
st = 0
|
||||
utts = []
|
||||
for i, c in enumerate(text):
|
||||
if c in pounc:
|
||||
if len(text[st:i]) > 0:
|
||||
utts.append(text[st:i] + c)
|
||||
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
|
||||
if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
|
||||
tmp = utts.pop(-1)
|
||||
utts.append(tmp + text[i + 1])
|
||||
st = i + 2
|
||||
@@ -88,9 +88,9 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
||||
st = i + 1
|
||||
if len(utts) == 0:
|
||||
if lang == "zh":
|
||||
utts.append(text + '。')
|
||||
utts.append(text + "。")
|
||||
else:
|
||||
utts.append(text + '.')
|
||||
utts.append(text + ".")
|
||||
final_utts = []
|
||||
cur_utt = ""
|
||||
for utt in utts:
|
||||
@@ -112,13 +112,13 @@ def replace_blank(text: str):
|
||||
out_str = []
|
||||
for i, c in enumerate(text):
|
||||
if c == " ":
|
||||
if ((text[i + 1].isascii() and text[i + 1] != " ") and
|
||||
(text[i - 1].isascii() and text[i - 1] != " ")):
|
||||
if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
|
||||
out_str.append(c)
|
||||
else:
|
||||
out_str.append(c)
|
||||
return "".join(out_str)
|
||||
|
||||
|
||||
def clean_markdown(md_text: str) -> str:
|
||||
# 去除代码块 ``` ```(包括多行)
|
||||
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
|
||||
@@ -133,7 +133,7 @@ def clean_markdown(md_text: str) -> str:
|
||||
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
|
||||
|
||||
# 替换无序列表符号
|
||||
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r"^(\s*)-\s+", r"\1", md_text, flags=re.MULTILINE)
|
||||
|
||||
# 去除HTML标签
|
||||
md_text = re.sub(r"<[^>]+>", "", md_text)
|
||||
@@ -152,13 +152,14 @@ def clean_text(text):
|
||||
# 去除 Markdown 语法
|
||||
text = clean_markdown(text)
|
||||
# 匹配并移除表情符号
|
||||
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
|
||||
text = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE).sub("", text)
|
||||
# 去除换行符
|
||||
text = text.replace("\n", " ")
|
||||
text = text.replace("\t", " ")
|
||||
text = text.replace('"', "\“")
|
||||
text = text.replace("“", '"').replace("”", '"')
|
||||
return text
|
||||
|
||||
|
||||
class TextNormalizer:
|
||||
def __init__(self, tokenizer=None):
|
||||
self.tokenizer = tokenizer
|
||||
@@ -171,9 +172,11 @@ class TextNormalizer:
|
||||
lang = "zh" if contains_chinese(text) else "en"
|
||||
text = clean_text(text)
|
||||
if lang == "zh":
|
||||
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
|
||||
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
|
||||
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
|
||||
text = text.replace(
|
||||
"=", "等于"
|
||||
) # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
|
||||
if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text): # 避免 英文连字符被错误正则为减
|
||||
text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text) # 修复 x-2 被正则为 x负2
|
||||
text = self.zh_tn_model.normalize(text)
|
||||
text = replace_blank(text)
|
||||
text = replace_corner_mark(text)
|
||||
|
||||
@@ -7,15 +7,15 @@ Related dependencies are imported only when denoising functionality is needed.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
import torchaudio
|
||||
import torch
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
class ZipEnhancer:
|
||||
"""ZipEnhancer Audio Denoising Enhancer"""
|
||||
|
||||
def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
|
||||
"""
|
||||
Initialize ZipEnhancer
|
||||
@@ -23,10 +23,7 @@ class ZipEnhancer:
|
||||
model_path: ModelScope model path or local path
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self._pipeline = pipeline(
|
||||
Tasks.acoustic_noise_suppression,
|
||||
model=self.model_path
|
||||
)
|
||||
self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)
|
||||
|
||||
def _normalize_loudness(self, wav_path: str):
|
||||
"""
|
||||
@@ -40,8 +37,7 @@ class ZipEnhancer:
|
||||
normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
|
||||
torchaudio.save(wav_path, normalized_audio, sr)
|
||||
|
||||
def enhance(self, input_path: str, output_path: Optional[str] = None,
|
||||
normalize_loudness: bool = True) -> str:
|
||||
def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
|
||||
"""
|
||||
Audio denoising enhancement
|
||||
Args:
|
||||
@@ -57,7 +53,7 @@ class ZipEnhancer:
|
||||
raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
|
||||
# Create temporary file if no output path is specified
|
||||
if output_path is None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
||||
output_path = tmp_file.name
|
||||
try:
|
||||
# Perform denoising processing
|
||||
|
||||
Reference in New Issue
Block a user