update voxcpm2
This commit is contained in:
+247
-260
@@ -1,18 +1,14 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import glob
|
||||
import json
|
||||
import yaml
|
||||
import shutil
|
||||
import datetime
|
||||
import subprocess
|
||||
import threading
|
||||
import gradio as gr
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
|
||||
# Add src to sys.path
|
||||
project_root = Path(__file__).parent
|
||||
@@ -89,7 +85,7 @@ LANG_DICT = {
|
||||
"lang_select": "Language / 语言",
|
||||
"refresh": "刷新",
|
||||
"output_name": "输出目录名称 (可选,若存在则继续训练)",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Global variables
|
||||
@@ -98,9 +94,11 @@ asr_model: Optional[AutoModel] = None
|
||||
training_process: Optional[subprocess.Popen] = None
|
||||
training_log = ""
|
||||
|
||||
|
||||
def get_timestamp_str():
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
def get_or_load_asr_model():
|
||||
global asr_model
|
||||
if asr_model is None:
|
||||
@@ -109,44 +107,46 @@ def get_or_load_asr_model():
|
||||
asr_model = AutoModel(
|
||||
model="iic/SenseVoiceSmall",
|
||||
disable_update=True,
|
||||
log_level='ERROR',
|
||||
log_level="ERROR",
|
||||
device=device,
|
||||
)
|
||||
return asr_model
|
||||
|
||||
|
||||
def recognize_audio(audio_path):
|
||||
if not audio_path:
|
||||
return ""
|
||||
try:
|
||||
model = get_or_load_asr_model()
|
||||
res = model.generate(input=audio_path, language="auto", use_itn=True)
|
||||
text = res[0]["text"].split('|>')[-1]
|
||||
text = res[0]["text"].split("|>")[-1]
|
||||
return text
|
||||
except Exception as e:
|
||||
print(f"ASR Error: {e}", file=sys.stderr)
|
||||
return ""
|
||||
|
||||
|
||||
def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
||||
"""
|
||||
Scans for LoRA checkpoints in the lora directory.
|
||||
|
||||
|
||||
Args:
|
||||
root_dir: Directory to scan for LoRA checkpoints
|
||||
with_info: If True, returns list of (path, base_model) tuples
|
||||
|
||||
|
||||
Returns:
|
||||
List of checkpoint paths, or list of (path, base_model) tuples if with_info=True
|
||||
"""
|
||||
checkpoints = []
|
||||
if not os.path.exists(root_dir):
|
||||
os.makedirs(root_dir, exist_ok=True)
|
||||
|
||||
|
||||
# Look for lora_weights.safetensors recursively
|
||||
for root, dirs, files in os.walk(root_dir):
|
||||
if "lora_weights.safetensors" in files:
|
||||
# Use the relative path from root_dir as the ID
|
||||
rel_path = os.path.relpath(root, root_dir)
|
||||
|
||||
|
||||
if with_info:
|
||||
# Try to read base_model from lora_config.json
|
||||
base_model = None
|
||||
@@ -161,15 +161,16 @@ def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
||||
checkpoints.append((rel_path, base_model))
|
||||
else:
|
||||
checkpoints.append(rel_path)
|
||||
|
||||
|
||||
# Also check for checkpoints in the default location if they exist
|
||||
default_ckpt = "checkpoints/finetune_lora"
|
||||
if os.path.exists(os.path.join(root_dir, default_ckpt)):
|
||||
# This might be covered by the walk, but good to be sure
|
||||
pass
|
||||
# This might be covered by the walk, but good to be sure
|
||||
pass
|
||||
|
||||
return sorted(checkpoints, reverse=True)
|
||||
|
||||
|
||||
def load_lora_config_from_checkpoint(lora_path):
|
||||
"""Load LoRA config from lora_config.json if available."""
|
||||
lora_config_file = os.path.join(lora_path, "lora_config.json")
|
||||
@@ -184,6 +185,7 @@ def load_lora_config_from_checkpoint(lora_path):
|
||||
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
|
||||
return None, None
|
||||
|
||||
|
||||
def get_default_lora_config():
|
||||
"""Return default LoRA config for hot-swapping support."""
|
||||
return LoRAConfig(
|
||||
@@ -192,16 +194,17 @@ def get_default_lora_config():
|
||||
r=32,
|
||||
alpha=16,
|
||||
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
)
|
||||
|
||||
|
||||
def load_model(pretrained_path, lora_path=None):
|
||||
global current_model
|
||||
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
|
||||
|
||||
|
||||
lora_config = None
|
||||
lora_weights_path = None
|
||||
|
||||
|
||||
if lora_path:
|
||||
full_lora_path = os.path.join("lora", lora_path)
|
||||
if os.path.exists(full_lora_path):
|
||||
@@ -214,7 +217,7 @@ def load_model(pretrained_path, lora_path=None):
|
||||
# Fallback to default config for old checkpoints
|
||||
lora_config = get_default_lora_config()
|
||||
print("Using default LoRA config (lora_config.json not found)", file=sys.stderr)
|
||||
|
||||
|
||||
# Always init with a default LoRA config to allow hot-swapping later
|
||||
if lora_config is None:
|
||||
lora_config = get_default_lora_config()
|
||||
@@ -228,25 +231,24 @@ def load_model(pretrained_path, lora_path=None):
|
||||
)
|
||||
return "Model loaded successfully!"
|
||||
|
||||
|
||||
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
|
||||
global current_model
|
||||
|
||||
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
|
||||
if current_model is None:
|
||||
# 优先使用用户指定的预训练模型路径
|
||||
base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path
|
||||
|
||||
|
||||
# 如果选择了 LoRA,尝试从其 config 读取 base_model
|
||||
if lora_selection and lora_selection != "None":
|
||||
full_lora_path = os.path.join("lora", lora_selection)
|
||||
lora_config_file = os.path.join(full_lora_path, "lora_config.json")
|
||||
|
||||
|
||||
if os.path.exists(lora_config_file):
|
||||
try:
|
||||
with open(lora_config_file, "r", encoding="utf-8") as f:
|
||||
lora_info = json.load(f)
|
||||
saved_base_model = lora_info.get("base_model")
|
||||
|
||||
|
||||
if saved_base_model:
|
||||
# 优先使用保存的 base_model 路径
|
||||
if os.path.exists(saved_base_model):
|
||||
@@ -257,11 +259,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
print(f"Falling back to default: {base_model_path}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
||||
|
||||
|
||||
# 加载模型
|
||||
try:
|
||||
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||
status_msg = load_model(base_model_path)
|
||||
load_model(base_model_path)
|
||||
if lora_selection and lora_selection != "None":
|
||||
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
@@ -270,6 +272,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
return None, error_msg
|
||||
|
||||
# Handle LoRA hot-swapping
|
||||
assert current_model is not None, "Model must be loaded before inference"
|
||||
if lora_selection and lora_selection != "None":
|
||||
full_lora_path = os.path.join("lora", lora_selection)
|
||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||
@@ -290,11 +293,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
# 处理 prompt 参数:必须同时为 None 或同时有值
|
||||
final_prompt_wav = None
|
||||
final_prompt_text = None
|
||||
|
||||
|
||||
if prompt_wav and prompt_wav.strip():
|
||||
# 有参考音频
|
||||
final_prompt_wav = prompt_wav
|
||||
|
||||
|
||||
# 如果没有提供参考文本,尝试自动识别
|
||||
if not prompt_text or not prompt_text.strip():
|
||||
print("参考音频已提供但缺少文本,自动识别中...", file=sys.stderr)
|
||||
@@ -317,14 +320,16 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
prompt_text=final_prompt_text,
|
||||
cfg_value=cfg_scale,
|
||||
inference_timesteps=steps,
|
||||
denoise=False
|
||||
denoise=False,
|
||||
)
|
||||
return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None, f"Error: {str(e)}"
|
||||
|
||||
|
||||
def start_training(
|
||||
pretrained_path,
|
||||
train_manifest,
|
||||
@@ -355,8 +360,8 @@ def start_training(
|
||||
hf_model_id="",
|
||||
distribute=False,
|
||||
):
|
||||
global training_process, training_log
|
||||
|
||||
global training_log
|
||||
|
||||
if training_process is not None and training_process.poll() is None:
|
||||
return "Training is already running!"
|
||||
|
||||
@@ -368,7 +373,7 @@ def start_training(
|
||||
save_dir = os.path.join("lora", timestamp)
|
||||
checkpoints_dir = os.path.join(save_dir, "checkpoints")
|
||||
logs_dir = os.path.join(save_dir, "logs")
|
||||
|
||||
|
||||
os.makedirs(checkpoints_dir, exist_ok=True)
|
||||
os.makedirs(logs_dir, exist_ok=True)
|
||||
|
||||
@@ -394,10 +399,7 @@ def start_training(
|
||||
"max_steps": resolved_max_steps,
|
||||
"save_path": checkpoints_dir,
|
||||
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
||||
"lambdas": {
|
||||
"loss/diff": 1.0,
|
||||
"loss/stop": 1.0
|
||||
},
|
||||
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||
"lora": {
|
||||
"enable_lm": bool(enable_lm),
|
||||
"enable_dit": bool(enable_dit),
|
||||
@@ -406,10 +408,10 @@ def start_training(
|
||||
"alpha": int(lora_alpha),
|
||||
"dropout": float(dropout),
|
||||
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Add distribution options if provided
|
||||
if hf_model_id and hf_model_id.strip():
|
||||
config["hf_model_id"] = hf_model_id.strip()
|
||||
@@ -420,49 +422,42 @@ def start_training(
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(config, f)
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"scripts/train_voxcpm_finetune.py",
|
||||
"--config_path",
|
||||
config_path
|
||||
]
|
||||
cmd = [sys.executable, "scripts/train_voxcpm_finetune.py", "--config_path", config_path]
|
||||
|
||||
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
|
||||
|
||||
|
||||
def run_process():
|
||||
global training_process, training_log
|
||||
training_process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1
|
||||
)
|
||||
|
||||
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
|
||||
|
||||
assert training_process.stdout is not None
|
||||
for line in training_process.stdout:
|
||||
training_log += line
|
||||
# Keep log size manageable
|
||||
if len(training_log) > 100000:
|
||||
training_log = training_log[-100000:]
|
||||
|
||||
|
||||
training_process.wait()
|
||||
training_log += f"\nTraining finished with code {training_process.returncode}"
|
||||
|
||||
threading.Thread(target=run_process, daemon=True).start()
|
||||
|
||||
|
||||
return f"Training started! Check 'lora/{timestamp}'"
|
||||
|
||||
|
||||
def get_training_log():
|
||||
return training_log
|
||||
|
||||
|
||||
def stop_training():
|
||||
global training_process, training_log
|
||||
global training_log
|
||||
if training_process is not None and training_process.poll() is None:
|
||||
training_process.terminate()
|
||||
training_log += "\nTraining terminated by user."
|
||||
return "Training stopped."
|
||||
return "No training running."
|
||||
|
||||
|
||||
# --- GUI Layout ---
|
||||
|
||||
# 自定义CSS样式
|
||||
@@ -830,14 +825,10 @@ label {
|
||||
}
|
||||
"""
|
||||
|
||||
with gr.Blocks(
|
||||
title="VoxCPM LoRA WebUI",
|
||||
theme=gr.themes.Soft(),
|
||||
css=custom_css
|
||||
) as app:
|
||||
|
||||
with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css) as app:
|
||||
|
||||
# State for language
|
||||
lang_state = gr.State("zh") # Default to Chinese
|
||||
lang_state = gr.State("zh") # Default to Chinese
|
||||
|
||||
# 标题区域
|
||||
with gr.Row(elem_classes="title-section"):
|
||||
@@ -850,10 +841,7 @@ with gr.Blocks(
|
||||
""")
|
||||
with gr.Column(scale=1):
|
||||
lang_btn = gr.Radio(
|
||||
choices=["en", "zh"],
|
||||
value="zh",
|
||||
label="🌐 Language / 语言",
|
||||
elem_classes="lang-selector"
|
||||
choices=["en", "zh"], value="zh", label="🌐 Language / 语言", elem_classes="lang-selector"
|
||||
)
|
||||
|
||||
with gr.Tabs(elem_classes="tabs") as tabs:
|
||||
@@ -869,79 +857,40 @@ with gr.Blocks(
|
||||
gr.Markdown("#### 📁 基础配置")
|
||||
|
||||
train_pretrained_path = gr.Textbox(
|
||||
label="📂 预训练模型路径",
|
||||
value=default_pretrained_path,
|
||||
elem_classes="input-field"
|
||||
label="📂 预训练模型路径", value=default_pretrained_path, elem_classes="input-field"
|
||||
)
|
||||
train_manifest = gr.Textbox(
|
||||
label="📋 训练数据清单 (jsonl)",
|
||||
value="examples/train_data_example.jsonl",
|
||||
elem_classes="input-field"
|
||||
)
|
||||
val_manifest = gr.Textbox(
|
||||
label="📊 验证数据清单 (可选)",
|
||||
value="",
|
||||
elem_classes="input-field"
|
||||
elem_classes="input-field",
|
||||
)
|
||||
val_manifest = gr.Textbox(label="📊 验证数据清单 (可选)", value="", elem_classes="input-field")
|
||||
|
||||
gr.Markdown("#### ⚙️ 训练参数")
|
||||
|
||||
with gr.Row():
|
||||
lr = gr.Number(
|
||||
label="📈 学习率 (Learning Rate)",
|
||||
value=1e-4,
|
||||
elem_classes="input-field"
|
||||
)
|
||||
lr = gr.Number(label="📈 学习率 (Learning Rate)", value=1e-4, elem_classes="input-field")
|
||||
num_iters = gr.Number(
|
||||
label="🔄 最大迭代次数",
|
||||
value=2000,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
label="🔄 最大迭代次数", value=2000, precision=0, elem_classes="input-field"
|
||||
)
|
||||
batch_size = gr.Number(
|
||||
label="📦 批次大小 (Batch Size)",
|
||||
value=1,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
label="📦 批次大小 (Batch Size)", value=1, precision=0, elem_classes="input-field"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
lora_rank = gr.Number(
|
||||
label="🎯 LoRA Rank",
|
||||
value=32,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
)
|
||||
lora_alpha = gr.Number(
|
||||
label="⚖️ LoRA Alpha",
|
||||
value=16,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
)
|
||||
lora_rank = gr.Number(label="🎯 LoRA Rank", value=32, precision=0, elem_classes="input-field")
|
||||
lora_alpha = gr.Number(label="⚖️ LoRA Alpha", value=16, precision=0, elem_classes="input-field")
|
||||
save_interval = gr.Number(
|
||||
label="💾 保存间隔 (Steps)",
|
||||
value=1000,
|
||||
precision=0,
|
||||
elem_classes="input-field"
|
||||
label="💾 保存间隔 (Steps)", value=1000, precision=0, elem_classes="input-field"
|
||||
)
|
||||
|
||||
output_name = gr.Textbox(
|
||||
label="📁 输出目录名称 (可选,若存在则继续训练)",
|
||||
value="",
|
||||
elem_classes="input-field"
|
||||
label="📁 输出目录名称 (可选,若存在则继续训练)", value="", elem_classes="input-field"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button(
|
||||
"▶️ 开始训练",
|
||||
variant="primary",
|
||||
elem_classes="button-primary"
|
||||
)
|
||||
stop_btn = gr.Button(
|
||||
"⏹️ 停止训练",
|
||||
variant="stop",
|
||||
elem_classes="button-stop"
|
||||
)
|
||||
start_btn = gr.Button("▶️ 开始训练", variant="primary", elem_classes="button-primary")
|
||||
stop_btn = gr.Button("⏹️ 停止训练", variant="stop", elem_classes="button-stop")
|
||||
|
||||
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
|
||||
with gr.Row():
|
||||
@@ -961,10 +910,12 @@ with gr.Blocks(
|
||||
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
|
||||
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
|
||||
dropout = gr.Number(label="LoRA Dropout", value=0.0)
|
||||
|
||||
|
||||
gr.Markdown("#### 分发选项 (Distribution)")
|
||||
with gr.Row():
|
||||
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5")
|
||||
hf_model_id = gr.Textbox(
|
||||
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
|
||||
)
|
||||
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
||||
|
||||
with gr.Column(scale=2, elem_classes="form-section"):
|
||||
@@ -975,26 +926,44 @@ with gr.Blocks(
|
||||
max_lines=30,
|
||||
interactive=False,
|
||||
elem_classes="input-field",
|
||||
show_label=False
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
|
||||
start_btn.click(
|
||||
start_training,
|
||||
inputs=[
|
||||
train_pretrained_path, train_manifest, val_manifest,
|
||||
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
|
||||
train_pretrained_path,
|
||||
train_manifest,
|
||||
val_manifest,
|
||||
lr,
|
||||
num_iters,
|
||||
batch_size,
|
||||
lora_rank,
|
||||
lora_alpha,
|
||||
save_interval,
|
||||
output_name,
|
||||
# advanced
|
||||
grad_accum_steps, num_workers, log_interval, valid_interval,
|
||||
weight_decay, warmup_steps, max_steps, sample_rate,
|
||||
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
|
||||
grad_accum_steps,
|
||||
num_workers,
|
||||
log_interval,
|
||||
valid_interval,
|
||||
weight_decay,
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
dropout,
|
||||
tensorboard_path,
|
||||
# distribution
|
||||
hf_model_id, distribute
|
||||
hf_model_id,
|
||||
distribute,
|
||||
],
|
||||
outputs=[logs_out] # Initial message
|
||||
outputs=[logs_out], # Initial message
|
||||
)
|
||||
stop_btn.click(stop_training, outputs=[logs_out])
|
||||
|
||||
|
||||
# Log refresher
|
||||
timer = gr.Timer(1)
|
||||
timer.tick(get_training_log, outputs=logs_out)
|
||||
@@ -1016,21 +985,17 @@ with gr.Blocks(
|
||||
value="Hello, this is a test of the VoxCPM LoRA model.",
|
||||
elem_classes="input-field",
|
||||
lines=4,
|
||||
placeholder="输入要合成的文本内容..."
|
||||
placeholder="输入要合成的文本内容...",
|
||||
)
|
||||
|
||||
gr.Markdown("**🎭 声音克隆(可选)**")
|
||||
|
||||
prompt_wav = gr.Audio(
|
||||
label="🎵 参考音频",
|
||||
type="filepath",
|
||||
elem_classes="input-field"
|
||||
)
|
||||
|
||||
|
||||
prompt_wav = gr.Audio(label="🎵 参考音频", type="filepath", elem_classes="input-field")
|
||||
|
||||
prompt_text = gr.Textbox(
|
||||
label="📝 参考文本(可选)",
|
||||
elem_classes="input-field",
|
||||
placeholder="如不填写,将自动识别参考音频内容"
|
||||
placeholder="如不填写,将自动识别参考音频内容",
|
||||
)
|
||||
|
||||
# 中栏:模型选择和参数配置 (35%)
|
||||
@@ -1043,15 +1008,11 @@ with gr.Blocks(
|
||||
value="None",
|
||||
interactive=True,
|
||||
elem_classes="input-field",
|
||||
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型"
|
||||
)
|
||||
|
||||
refresh_lora_btn = gr.Button(
|
||||
"🔄 刷新模型列表",
|
||||
elem_classes="button-refresh",
|
||||
size="sm"
|
||||
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型",
|
||||
)
|
||||
|
||||
refresh_lora_btn = gr.Button("🔄 刷新模型列表", elem_classes="button-refresh", size="sm")
|
||||
|
||||
gr.Markdown("#### ⚙️ 生成参数")
|
||||
|
||||
cfg_scale = gr.Slider(
|
||||
@@ -1060,59 +1021,50 @@ with gr.Blocks(
|
||||
maximum=5.0,
|
||||
value=2.0,
|
||||
step=0.1,
|
||||
info="引导系数,值越大越贴近提示"
|
||||
info="引导系数,值越大越贴近提示",
|
||||
)
|
||||
|
||||
|
||||
steps = gr.Slider(
|
||||
label="🔢 推理步数",
|
||||
minimum=1,
|
||||
maximum=50,
|
||||
value=10,
|
||||
step=1,
|
||||
info="生成质量与步数成正比,但耗时更长"
|
||||
info="生成质量与步数成正比,但耗时更长",
|
||||
)
|
||||
|
||||
|
||||
seed = gr.Number(
|
||||
label="🎲 随机种子",
|
||||
value=-1,
|
||||
precision=0,
|
||||
elem_classes="input-field",
|
||||
info="-1 为随机,固定值可复现结果"
|
||||
info="-1 为随机,固定值可复现结果",
|
||||
)
|
||||
|
||||
generate_btn = gr.Button(
|
||||
"🎵 生成音频",
|
||||
variant="primary",
|
||||
elem_classes="button-primary",
|
||||
size="lg"
|
||||
)
|
||||
generate_btn = gr.Button("🎵 生成音频", variant="primary", elem_classes="button-primary", size="lg")
|
||||
|
||||
# 右栏:生成结果 (30%)
|
||||
with gr.Column(scale=30, elem_classes="form-section"):
|
||||
gr.Markdown("#### 🎧 生成结果")
|
||||
|
||||
audio_out = gr.Audio(
|
||||
label="",
|
||||
elem_classes="input-field",
|
||||
show_label=False
|
||||
)
|
||||
|
||||
audio_out = gr.Audio(label="", elem_classes="input-field", show_label=False)
|
||||
|
||||
gr.Markdown("#### 📋 状态信息")
|
||||
|
||||
|
||||
status_out = gr.Textbox(
|
||||
label="",
|
||||
interactive=False,
|
||||
elem_classes="input-field",
|
||||
show_label=False,
|
||||
lines=3,
|
||||
placeholder="等待生成..."
|
||||
placeholder="等待生成...",
|
||||
)
|
||||
|
||||
def refresh_loras():
|
||||
# 获取 LoRA checkpoints 及其 base model 信息
|
||||
checkpoints_with_info = scan_lora_checkpoints(with_info=True)
|
||||
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
|
||||
|
||||
|
||||
# 输出调试信息
|
||||
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点", file=sys.stderr)
|
||||
for ckpt_path, base_model in checkpoints_with_info:
|
||||
@@ -1120,22 +1072,27 @@ with gr.Blocks(
|
||||
print(f" - {ckpt_path} (Base Model: {base_model})", file=sys.stderr)
|
||||
else:
|
||||
print(f" - {ckpt_path}", file=sys.stderr)
|
||||
|
||||
|
||||
return gr.update(choices=choices, value="None")
|
||||
|
||||
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
|
||||
|
||||
|
||||
# Auto-recognize audio when uploaded
|
||||
prompt_wav.change(
|
||||
fn=recognize_audio,
|
||||
inputs=[prompt_wav],
|
||||
outputs=[prompt_text]
|
||||
)
|
||||
|
||||
prompt_wav.change(fn=recognize_audio, inputs=[prompt_wav], outputs=[prompt_text])
|
||||
|
||||
generate_btn.click(
|
||||
run_inference,
|
||||
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
|
||||
outputs=[audio_out, status_out]
|
||||
inputs=[
|
||||
infer_text,
|
||||
prompt_wav,
|
||||
prompt_text,
|
||||
lora_select,
|
||||
cfg_scale,
|
||||
steps,
|
||||
seed,
|
||||
train_pretrained_path,
|
||||
],
|
||||
outputs=[audio_out, status_out],
|
||||
)
|
||||
|
||||
# --- Language Switching Logic ---
|
||||
@@ -1144,111 +1101,141 @@ with gr.Blocks(
|
||||
# Labels for advanced options
|
||||
if lang == "zh":
|
||||
adv = {
|
||||
'grad_accum_steps': "梯度累积 (grad_accum_steps)",
|
||||
'num_workers': "数据加载线程 (num_workers)",
|
||||
'log_interval': "日志间隔 (log_interval)",
|
||||
'valid_interval': "验证间隔 (valid_interval)",
|
||||
'weight_decay': "权重衰减 (weight_decay)",
|
||||
'warmup_steps': "warmup_steps",
|
||||
'max_steps': "最大步数 (max_steps)",
|
||||
'sample_rate': "采样率 (sample_rate)",
|
||||
'enable_lm': "启用 LoRA LM (enable_lm)",
|
||||
'enable_dit': "启用 LoRA DIT (enable_dit)",
|
||||
'enable_proj': "启用投影 (enable_proj)",
|
||||
'dropout': "LoRA Dropout",
|
||||
'tensorboard_path': "Tensorboard 路径 (可选)",
|
||||
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
'distribute': "分发模式 (distribute)",
|
||||
"grad_accum_steps": "梯度累积 (grad_accum_steps)",
|
||||
"num_workers": "数据加载线程 (num_workers)",
|
||||
"log_interval": "日志间隔 (log_interval)",
|
||||
"valid_interval": "验证间隔 (valid_interval)",
|
||||
"weight_decay": "权重衰减 (weight_decay)",
|
||||
"warmup_steps": "warmup_steps",
|
||||
"max_steps": "最大步数 (max_steps)",
|
||||
"sample_rate": "采样率 (sample_rate)",
|
||||
"enable_lm": "启用 LoRA LM (enable_lm)",
|
||||
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
||||
"enable_proj": "启用投影 (enable_proj)",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard 路径 (可选)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"distribute": "分发模式 (distribute)",
|
||||
}
|
||||
else:
|
||||
adv = {
|
||||
'grad_accum_steps': "Grad Accum Steps",
|
||||
'num_workers': "Num Workers",
|
||||
'log_interval': "Log Interval",
|
||||
'valid_interval': "Valid Interval",
|
||||
'weight_decay': "Weight Decay",
|
||||
'warmup_steps': "Warmup Steps",
|
||||
'max_steps': "Max Steps",
|
||||
'sample_rate': "Sample Rate",
|
||||
'enable_lm': "Enable LoRA LM",
|
||||
'enable_dit': "Enable LoRA DIT",
|
||||
'enable_proj': "Enable Projection",
|
||||
'dropout': "LoRA Dropout",
|
||||
'tensorboard_path': "Tensorboard Path (Optional)",
|
||||
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
'distribute': "Distribute Mode",
|
||||
"grad_accum_steps": "Grad Accum Steps",
|
||||
"num_workers": "Num Workers",
|
||||
"log_interval": "Log Interval",
|
||||
"valid_interval": "Valid Interval",
|
||||
"weight_decay": "Weight Decay",
|
||||
"warmup_steps": "Warmup Steps",
|
||||
"max_steps": "Max Steps",
|
||||
"sample_rate": "Sample Rate",
|
||||
"enable_lm": "Enable LoRA LM",
|
||||
"enable_dit": "Enable LoRA DIT",
|
||||
"enable_proj": "Enable Projection",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard Path (Optional)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"distribute": "Distribute Mode",
|
||||
}
|
||||
|
||||
return (
|
||||
gr.update(value=f"# {d['title']}"),
|
||||
gr.update(label=d['tab_train']),
|
||||
gr.update(label=d['tab_infer']),
|
||||
gr.update(label=d['pretrained_path']),
|
||||
gr.update(label=d['train_manifest']),
|
||||
gr.update(label=d['val_manifest']),
|
||||
gr.update(label=d['lr']),
|
||||
gr.update(label=d['max_iters']),
|
||||
gr.update(label=d['batch_size']),
|
||||
gr.update(label=d['lora_rank']),
|
||||
gr.update(label=d['lora_alpha']),
|
||||
gr.update(label=d['save_interval']),
|
||||
gr.update(label=d['output_name']),
|
||||
gr.update(value=d['start_train']),
|
||||
gr.update(value=d['stop_train']),
|
||||
gr.update(label=d['train_logs']),
|
||||
gr.update(label=d["tab_train"]),
|
||||
gr.update(label=d["tab_infer"]),
|
||||
gr.update(label=d["pretrained_path"]),
|
||||
gr.update(label=d["train_manifest"]),
|
||||
gr.update(label=d["val_manifest"]),
|
||||
gr.update(label=d["lr"]),
|
||||
gr.update(label=d["max_iters"]),
|
||||
gr.update(label=d["batch_size"]),
|
||||
gr.update(label=d["lora_rank"]),
|
||||
gr.update(label=d["lora_alpha"]),
|
||||
gr.update(label=d["save_interval"]),
|
||||
gr.update(label=d["output_name"]),
|
||||
gr.update(value=d["start_train"]),
|
||||
gr.update(value=d["stop_train"]),
|
||||
gr.update(label=d["train_logs"]),
|
||||
# Advanced options (must match outputs order)
|
||||
gr.update(label=adv['grad_accum_steps']),
|
||||
gr.update(label=adv['num_workers']),
|
||||
gr.update(label=adv['log_interval']),
|
||||
gr.update(label=adv['valid_interval']),
|
||||
gr.update(label=adv['weight_decay']),
|
||||
gr.update(label=adv['warmup_steps']),
|
||||
gr.update(label=adv['max_steps']),
|
||||
gr.update(label=adv['sample_rate']),
|
||||
gr.update(label=adv['enable_lm']),
|
||||
gr.update(label=adv['enable_dit']),
|
||||
gr.update(label=adv['enable_proj']),
|
||||
gr.update(label=adv['dropout']),
|
||||
gr.update(label=adv['tensorboard_path']),
|
||||
gr.update(label=adv["grad_accum_steps"]),
|
||||
gr.update(label=adv["num_workers"]),
|
||||
gr.update(label=adv["log_interval"]),
|
||||
gr.update(label=adv["valid_interval"]),
|
||||
gr.update(label=adv["weight_decay"]),
|
||||
gr.update(label=adv["warmup_steps"]),
|
||||
gr.update(label=adv["max_steps"]),
|
||||
gr.update(label=adv["sample_rate"]),
|
||||
gr.update(label=adv["enable_lm"]),
|
||||
gr.update(label=adv["enable_dit"]),
|
||||
gr.update(label=adv["enable_proj"]),
|
||||
gr.update(label=adv["dropout"]),
|
||||
gr.update(label=adv["tensorboard_path"]),
|
||||
# Distribution options
|
||||
gr.update(label=adv['hf_model_id']),
|
||||
gr.update(label=adv['distribute']),
|
||||
gr.update(label=adv["hf_model_id"]),
|
||||
gr.update(label=adv["distribute"]),
|
||||
# Inference section
|
||||
gr.update(label=d['text_to_synth']),
|
||||
gr.update(label=d['ref_audio']),
|
||||
gr.update(label=d['ref_text']),
|
||||
gr.update(label=d['select_lora']),
|
||||
gr.update(value=d['refresh']),
|
||||
gr.update(label=d['cfg_scale']),
|
||||
gr.update(label=d['infer_steps']),
|
||||
gr.update(label=d['seed']),
|
||||
gr.update(value=d['gen_audio']),
|
||||
gr.update(label=d['gen_output']),
|
||||
gr.update(label=d['status']),
|
||||
gr.update(label=d["text_to_synth"]),
|
||||
gr.update(label=d["ref_audio"]),
|
||||
gr.update(label=d["ref_text"]),
|
||||
gr.update(label=d["select_lora"]),
|
||||
gr.update(value=d["refresh"]),
|
||||
gr.update(label=d["cfg_scale"]),
|
||||
gr.update(label=d["infer_steps"]),
|
||||
gr.update(label=d["seed"]),
|
||||
gr.update(value=d["gen_audio"]),
|
||||
gr.update(label=d["gen_output"]),
|
||||
gr.update(label=d["status"]),
|
||||
)
|
||||
|
||||
lang_btn.change(
|
||||
change_language,
|
||||
inputs=[lang_btn],
|
||||
outputs=[
|
||||
title_md, tab_train, tab_infer,
|
||||
train_pretrained_path, train_manifest, val_manifest,
|
||||
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
|
||||
title_md,
|
||||
tab_train,
|
||||
tab_infer,
|
||||
train_pretrained_path,
|
||||
train_manifest,
|
||||
val_manifest,
|
||||
lr,
|
||||
num_iters,
|
||||
batch_size,
|
||||
lora_rank,
|
||||
lora_alpha,
|
||||
save_interval,
|
||||
output_name,
|
||||
start_btn, stop_btn, logs_out,
|
||||
start_btn,
|
||||
stop_btn,
|
||||
logs_out,
|
||||
# advanced outputs
|
||||
grad_accum_steps, num_workers, log_interval, valid_interval,
|
||||
weight_decay, warmup_steps, max_steps, sample_rate,
|
||||
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
|
||||
grad_accum_steps,
|
||||
num_workers,
|
||||
log_interval,
|
||||
valid_interval,
|
||||
weight_decay,
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
dropout,
|
||||
tensorboard_path,
|
||||
# distribution outputs
|
||||
hf_model_id, distribute,
|
||||
infer_text, prompt_wav, prompt_text,
|
||||
lora_select, refresh_lora_btn, cfg_scale, steps, seed,
|
||||
generate_btn, audio_out, status_out
|
||||
]
|
||||
hf_model_id,
|
||||
distribute,
|
||||
infer_text,
|
||||
prompt_wav,
|
||||
prompt_text,
|
||||
lora_select,
|
||||
refresh_lora_btn,
|
||||
cfg_scale,
|
||||
steps,
|
||||
seed,
|
||||
generate_btn,
|
||||
audio_out,
|
||||
status_out,
|
||||
],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Ensure lora directory exists
|
||||
os.makedirs("lora", exist_ok=True)
|
||||
app.queue().launch(server_name="0.0.0.0", server_port=7860)
|
||||
app.queue().launch(server_name="0.0.0.0", server_port=7860)
|
||||
|
||||
Reference in New Issue
Block a user