update voxcpm2

This commit is contained in:
刘鑫
2026-03-31 11:50:37 +08:00
parent 23ed7ffeee
commit d9cf376e16
36 changed files with 8163 additions and 834 deletions
+247 -260
View File
@@ -1,18 +1,14 @@
import os
import sys
import time
import glob
import json
import yaml
import shutil
import datetime
import subprocess
import threading
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
from typing import Optional, List
from typing import Optional
# Add src to sys.path
project_root = Path(__file__).parent
@@ -89,7 +85,7 @@ LANG_DICT = {
"lang_select": "Language / 语言",
"refresh": "刷新",
"output_name": "输出目录名称 (可选,若存在则继续训练)",
}
},
}
# Global variables
@@ -98,9 +94,11 @@ asr_model: Optional[AutoModel] = None
# Handle of the training subprocess; assigned by the worker thread that
# start_training() spawns, and polled/terminated by stop_training().
training_process: Optional[subprocess.Popen] = None
# Accumulated training output returned by get_training_log(); the worker
# thread appends to it and keeps only the last 100000 characters.
training_log = ""
def get_timestamp_str() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``.

    Returns:
        A 15-character string such as ``"20260331_115037"``.
    """
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def get_or_load_asr_model():
global asr_model
if asr_model is None:
@@ -109,44 +107,46 @@ def get_or_load_asr_model():
asr_model = AutoModel(
model="iic/SenseVoiceSmall",
disable_update=True,
log_level='ERROR',
log_level="ERROR",
device=device,
)
return asr_model
def recognize_audio(audio_path):
    """Transcribe an audio file with the shared ASR model.

    Args:
        audio_path: Path of the audio file to transcribe. A falsy value
            (``None`` or ``""``) skips recognition entirely.

    Returns:
        The recognized text, or ``""`` when no path was given or when
        recognition failed for any reason (the error is printed to stderr).
    """
    if not audio_path:
        return ""
    try:
        model = get_or_load_asr_model()
        res = model.generate(input=audio_path, language="auto", use_itn=True)
        # Keep only what follows the last "|>" marker of the first result
        # (presumably this drops the model's leading "<|...|>" tags -- TODO confirm).
        return res[0]["text"].split("|>")[-1]
    except Exception as e:
        # Deliberately best-effort: this is wired to a UI change event, so a
        # failed recognition yields an empty text box instead of an error.
        print(f"ASR Error: {e}", file=sys.stderr)
        return ""
def scan_lora_checkpoints(root_dir="lora", with_info=False):
"""
Scans for LoRA checkpoints in the lora directory.
Args:
root_dir: Directory to scan for LoRA checkpoints
with_info: If True, returns list of (path, base_model) tuples
Returns:
List of checkpoint paths, or list of (path, base_model) tuples if with_info=True
"""
checkpoints = []
if not os.path.exists(root_dir):
os.makedirs(root_dir, exist_ok=True)
# Look for lora_weights.safetensors recursively
for root, dirs, files in os.walk(root_dir):
if "lora_weights.safetensors" in files:
# Use the relative path from root_dir as the ID
rel_path = os.path.relpath(root, root_dir)
if with_info:
# Try to read base_model from lora_config.json
base_model = None
@@ -161,15 +161,16 @@ def scan_lora_checkpoints(root_dir="lora", with_info=False):
checkpoints.append((rel_path, base_model))
else:
checkpoints.append(rel_path)
# Also check for checkpoints in the default location if they exist
default_ckpt = "checkpoints/finetune_lora"
if os.path.exists(os.path.join(root_dir, default_ckpt)):
# This might be covered by the walk, but good to be sure
pass
# This might be covered by the walk, but good to be sure
pass
return sorted(checkpoints, reverse=True)
def load_lora_config_from_checkpoint(lora_path):
"""Load LoRA config from lora_config.json if available."""
lora_config_file = os.path.join(lora_path, "lora_config.json")
@@ -184,6 +185,7 @@ def load_lora_config_from_checkpoint(lora_path):
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
return None, None
def get_default_lora_config():
"""Return default LoRA config for hot-swapping support."""
return LoRAConfig(
@@ -192,16 +194,17 @@ def get_default_lora_config():
r=32,
alpha=16,
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"]
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
)
def load_model(pretrained_path, lora_path=None):
global current_model
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
lora_config = None
lora_weights_path = None
if lora_path:
full_lora_path = os.path.join("lora", lora_path)
if os.path.exists(full_lora_path):
@@ -214,7 +217,7 @@ def load_model(pretrained_path, lora_path=None):
# Fallback to default config for old checkpoints
lora_config = get_default_lora_config()
print("Using default LoRA config (lora_config.json not found)", file=sys.stderr)
# Always init with a default LoRA config to allow hot-swapping later
if lora_config is None:
lora_config = get_default_lora_config()
@@ -228,25 +231,24 @@ def load_model(pretrained_path, lora_path=None):
)
return "Model loaded successfully!"
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
global current_model
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
if current_model is None:
# 优先使用用户指定的预训练模型路径
base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path
# 如果选择了 LoRA,尝试从其 config 读取 base_model
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
lora_config_file = os.path.join(full_lora_path, "lora_config.json")
if os.path.exists(lora_config_file):
try:
with open(lora_config_file, "r", encoding="utf-8") as f:
lora_info = json.load(f)
saved_base_model = lora_info.get("base_model")
if saved_base_model:
# 优先使用保存的 base_model 路径
if os.path.exists(saved_base_model):
@@ -257,11 +259,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
print(f"Falling back to default: {base_model_path}", file=sys.stderr)
except Exception as e:
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
# 加载模型
try:
print(f"Loading base model: {base_model_path}", file=sys.stderr)
status_msg = load_model(base_model_path)
load_model(base_model_path)
if lora_selection and lora_selection != "None":
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
except Exception as e:
@@ -270,6 +272,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
return None, error_msg
# Handle LoRA hot-swapping
assert current_model is not None, "Model must be loaded before inference"
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
@@ -290,11 +293,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
# 处理 prompt 参数:必须同时为 None 或同时有值
final_prompt_wav = None
final_prompt_text = None
if prompt_wav and prompt_wav.strip():
# 有参考音频
final_prompt_wav = prompt_wav
# 如果没有提供参考文本,尝试自动识别
if not prompt_text or not prompt_text.strip():
print("参考音频已提供但缺少文本,自动识别中...", file=sys.stderr)
@@ -317,14 +320,16 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
prompt_text=final_prompt_text,
cfg_value=cfg_scale,
inference_timesteps=steps,
denoise=False
denoise=False,
)
return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
except Exception as e:
import traceback
traceback.print_exc()
return None, f"Error: {str(e)}"
def start_training(
pretrained_path,
train_manifest,
@@ -355,8 +360,8 @@ def start_training(
hf_model_id="",
distribute=False,
):
global training_process, training_log
global training_log
if training_process is not None and training_process.poll() is None:
return "Training is already running!"
@@ -368,7 +373,7 @@ def start_training(
save_dir = os.path.join("lora", timestamp)
checkpoints_dir = os.path.join(save_dir, "checkpoints")
logs_dir = os.path.join(save_dir, "logs")
os.makedirs(checkpoints_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)
@@ -394,10 +399,7 @@ def start_training(
"max_steps": resolved_max_steps,
"save_path": checkpoints_dir,
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
"lambdas": {
"loss/diff": 1.0,
"loss/stop": 1.0
},
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
"lora": {
"enable_lm": bool(enable_lm),
"enable_dit": bool(enable_dit),
@@ -406,10 +408,10 @@ def start_training(
"alpha": int(lora_alpha),
"dropout": float(dropout),
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"]
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"],
},
}
# Add distribution options if provided
if hf_model_id and hf_model_id.strip():
config["hf_model_id"] = hf_model_id.strip()
@@ -420,49 +422,42 @@ def start_training(
with open(config_path, "w") as f:
yaml.dump(config, f)
cmd = [
sys.executable,
"scripts/train_voxcpm_finetune.py",
"--config_path",
config_path
]
cmd = [sys.executable, "scripts/train_voxcpm_finetune.py", "--config_path", config_path]
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
def run_process():
global training_process, training_log
training_process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1
)
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
assert training_process.stdout is not None
for line in training_process.stdout:
training_log += line
# Keep log size manageable
if len(training_log) > 100000:
training_log = training_log[-100000:]
training_process.wait()
training_log += f"\nTraining finished with code {training_process.returncode}"
threading.Thread(target=run_process, daemon=True).start()
return f"Training started! Check 'lora/{timestamp}'"
def get_training_log() -> str:
    """Return the training output accumulated so far.

    Returns:
        The current value of the module-level ``training_log`` string.
    """
    return training_log
def stop_training():
    """Terminate the running training subprocess, if there is one.

    Returns:
        ``"Training stopped."`` when a live process was terminated, otherwise
        ``"No training running."``.
    """
    # Only training_log is rebound here; training_process is merely read,
    # so it needs no ``global`` declaration.
    global training_log

    # poll() returning None means the child has not exited yet.
    if training_process is not None and training_process.poll() is None:
        training_process.terminate()
        training_log += "\nTraining terminated by user."
        return "Training stopped."
    return "No training running."
# --- GUI Layout ---
# 自定义CSS样式
@@ -830,14 +825,10 @@ label {
}
"""
with gr.Blocks(
title="VoxCPM LoRA WebUI",
theme=gr.themes.Soft(),
css=custom_css
) as app:
with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css) as app:
# State for language
lang_state = gr.State("zh") # Default to Chinese
lang_state = gr.State("zh") # Default to Chinese
# 标题区域
with gr.Row(elem_classes="title-section"):
@@ -850,10 +841,7 @@ with gr.Blocks(
""")
with gr.Column(scale=1):
lang_btn = gr.Radio(
choices=["en", "zh"],
value="zh",
label="🌐 Language / 语言",
elem_classes="lang-selector"
choices=["en", "zh"], value="zh", label="🌐 Language / 语言", elem_classes="lang-selector"
)
with gr.Tabs(elem_classes="tabs") as tabs:
@@ -869,79 +857,40 @@ with gr.Blocks(
gr.Markdown("#### 📁 基础配置")
train_pretrained_path = gr.Textbox(
label="📂 预训练模型路径",
value=default_pretrained_path,
elem_classes="input-field"
label="📂 预训练模型路径", value=default_pretrained_path, elem_classes="input-field"
)
train_manifest = gr.Textbox(
label="📋 训练数据清单 (jsonl)",
value="examples/train_data_example.jsonl",
elem_classes="input-field"
)
val_manifest = gr.Textbox(
label="📊 验证数据清单 (可选)",
value="",
elem_classes="input-field"
elem_classes="input-field",
)
val_manifest = gr.Textbox(label="📊 验证数据清单 (可选)", value="", elem_classes="input-field")
gr.Markdown("#### ⚙️ 训练参数")
with gr.Row():
lr = gr.Number(
label="📈 学习率 (Learning Rate)",
value=1e-4,
elem_classes="input-field"
)
lr = gr.Number(label="📈 学习率 (Learning Rate)", value=1e-4, elem_classes="input-field")
num_iters = gr.Number(
label="🔄 最大迭代次数",
value=2000,
precision=0,
elem_classes="input-field"
label="🔄 最大迭代次数", value=2000, precision=0, elem_classes="input-field"
)
batch_size = gr.Number(
label="📦 批次大小 (Batch Size)",
value=1,
precision=0,
elem_classes="input-field"
label="📦 批次大小 (Batch Size)", value=1, precision=0, elem_classes="input-field"
)
with gr.Row():
lora_rank = gr.Number(
label="🎯 LoRA Rank",
value=32,
precision=0,
elem_classes="input-field"
)
lora_alpha = gr.Number(
label="⚖️ LoRA Alpha",
value=16,
precision=0,
elem_classes="input-field"
)
lora_rank = gr.Number(label="🎯 LoRA Rank", value=32, precision=0, elem_classes="input-field")
lora_alpha = gr.Number(label="⚖️ LoRA Alpha", value=16, precision=0, elem_classes="input-field")
save_interval = gr.Number(
label="💾 保存间隔 (Steps)",
value=1000,
precision=0,
elem_classes="input-field"
label="💾 保存间隔 (Steps)", value=1000, precision=0, elem_classes="input-field"
)
output_name = gr.Textbox(
label="📁 输出目录名称 (可选,若存在则继续训练)",
value="",
elem_classes="input-field"
label="📁 输出目录名称 (可选,若存在则继续训练)", value="", elem_classes="input-field"
)
with gr.Row():
start_btn = gr.Button(
" 开始训练",
variant="primary",
elem_classes="button-primary"
)
stop_btn = gr.Button(
"⏹️ 停止训练",
variant="stop",
elem_classes="button-stop"
)
start_btn = gr.Button("▶️ 开始训练", variant="primary", elem_classes="button-primary")
stop_btn = gr.Button(" 停止训练", variant="stop", elem_classes="button-stop")
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
with gr.Row():
@@ -961,10 +910,12 @@ with gr.Blocks(
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
dropout = gr.Number(label="LoRA Dropout", value=0.0)
gr.Markdown("#### 分发选项 (Distribution)")
with gr.Row():
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5")
hf_model_id = gr.Textbox(
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
)
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
with gr.Column(scale=2, elem_classes="form-section"):
@@ -975,26 +926,44 @@ with gr.Blocks(
max_lines=30,
interactive=False,
elem_classes="input-field",
show_label=False
show_label=False,
)
start_btn.click(
start_training,
inputs=[
train_pretrained_path, train_manifest, val_manifest,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
train_pretrained_path,
train_manifest,
val_manifest,
lr,
num_iters,
batch_size,
lora_rank,
lora_alpha,
save_interval,
output_name,
# advanced
grad_accum_steps, num_workers, log_interval, valid_interval,
weight_decay, warmup_steps, max_steps, sample_rate,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
grad_accum_steps,
num_workers,
log_interval,
valid_interval,
weight_decay,
warmup_steps,
max_steps,
sample_rate,
enable_lm,
enable_dit,
enable_proj,
dropout,
tensorboard_path,
# distribution
hf_model_id, distribute
hf_model_id,
distribute,
],
outputs=[logs_out] # Initial message
outputs=[logs_out], # Initial message
)
stop_btn.click(stop_training, outputs=[logs_out])
# Log refresher
timer = gr.Timer(1)
timer.tick(get_training_log, outputs=logs_out)
@@ -1016,21 +985,17 @@ with gr.Blocks(
value="Hello, this is a test of the VoxCPM LoRA model.",
elem_classes="input-field",
lines=4,
placeholder="输入要合成的文本内容..."
placeholder="输入要合成的文本内容...",
)
gr.Markdown("**🎭 声音克隆(可选)**")
prompt_wav = gr.Audio(
label="🎵 参考音频",
type="filepath",
elem_classes="input-field"
)
prompt_wav = gr.Audio(label="🎵 参考音频", type="filepath", elem_classes="input-field")
prompt_text = gr.Textbox(
label="📝 参考文本(可选)",
elem_classes="input-field",
placeholder="如不填写,将自动识别参考音频内容"
placeholder="如不填写,将自动识别参考音频内容",
)
# 中栏:模型选择和参数配置 (35%)
@@ -1043,15 +1008,11 @@ with gr.Blocks(
value="None",
interactive=True,
elem_classes="input-field",
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型"
)
refresh_lora_btn = gr.Button(
"🔄 刷新模型列表",
elem_classes="button-refresh",
size="sm"
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型",
)
refresh_lora_btn = gr.Button("🔄 刷新模型列表", elem_classes="button-refresh", size="sm")
gr.Markdown("#### ⚙️ 生成参数")
cfg_scale = gr.Slider(
@@ -1060,59 +1021,50 @@ with gr.Blocks(
maximum=5.0,
value=2.0,
step=0.1,
info="引导系数,值越大越贴近提示"
info="引导系数,值越大越贴近提示",
)
steps = gr.Slider(
label="🔢 推理步数",
minimum=1,
maximum=50,
value=10,
step=1,
info="生成质量与步数成正比,但耗时更长"
info="生成质量与步数成正比,但耗时更长",
)
seed = gr.Number(
label="🎲 随机种子",
value=-1,
precision=0,
elem_classes="input-field",
info="-1 为随机,固定值可复现结果"
info="-1 为随机,固定值可复现结果",
)
generate_btn = gr.Button(
"🎵 生成音频",
variant="primary",
elem_classes="button-primary",
size="lg"
)
generate_btn = gr.Button("🎵 生成音频", variant="primary", elem_classes="button-primary", size="lg")
# 右栏:生成结果 (30%)
with gr.Column(scale=30, elem_classes="form-section"):
gr.Markdown("#### 🎧 生成结果")
audio_out = gr.Audio(
label="",
elem_classes="input-field",
show_label=False
)
audio_out = gr.Audio(label="", elem_classes="input-field", show_label=False)
gr.Markdown("#### 📋 状态信息")
status_out = gr.Textbox(
label="",
interactive=False,
elem_classes="input-field",
show_label=False,
lines=3,
placeholder="等待生成..."
placeholder="等待生成...",
)
def refresh_loras():
# 获取 LoRA checkpoints 及其 base model 信息
checkpoints_with_info = scan_lora_checkpoints(with_info=True)
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
# 输出调试信息
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点", file=sys.stderr)
for ckpt_path, base_model in checkpoints_with_info:
@@ -1120,22 +1072,27 @@ with gr.Blocks(
print(f" - {ckpt_path} (Base Model: {base_model})", file=sys.stderr)
else:
print(f" - {ckpt_path}", file=sys.stderr)
return gr.update(choices=choices, value="None")
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
# Auto-recognize audio when uploaded
prompt_wav.change(
fn=recognize_audio,
inputs=[prompt_wav],
outputs=[prompt_text]
)
prompt_wav.change(fn=recognize_audio, inputs=[prompt_wav], outputs=[prompt_text])
generate_btn.click(
run_inference,
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
outputs=[audio_out, status_out]
inputs=[
infer_text,
prompt_wav,
prompt_text,
lora_select,
cfg_scale,
steps,
seed,
train_pretrained_path,
],
outputs=[audio_out, status_out],
)
# --- Language Switching Logic ---
@@ -1144,111 +1101,141 @@ with gr.Blocks(
# Labels for advanced options
if lang == "zh":
adv = {
'grad_accum_steps': "梯度累积 (grad_accum_steps)",
'num_workers': "数据加载线程 (num_workers)",
'log_interval': "日志间隔 (log_interval)",
'valid_interval': "验证间隔 (valid_interval)",
'weight_decay': "权重衰减 (weight_decay)",
'warmup_steps': "warmup_steps",
'max_steps': "最大步数 (max_steps)",
'sample_rate': "采样率 (sample_rate)",
'enable_lm': "启用 LoRA LM (enable_lm)",
'enable_dit': "启用 LoRA DIT (enable_dit)",
'enable_proj': "启用投影 (enable_proj)",
'dropout': "LoRA Dropout",
'tensorboard_path': "Tensorboard 路径 (可选)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "分发模式 (distribute)",
"grad_accum_steps": "梯度累积 (grad_accum_steps)",
"num_workers": "数据加载线程 (num_workers)",
"log_interval": "日志间隔 (log_interval)",
"valid_interval": "验证间隔 (valid_interval)",
"weight_decay": "权重衰减 (weight_decay)",
"warmup_steps": "warmup_steps",
"max_steps": "最大步数 (max_steps)",
"sample_rate": "采样率 (sample_rate)",
"enable_lm": "启用 LoRA LM (enable_lm)",
"enable_dit": "启用 LoRA DIT (enable_dit)",
"enable_proj": "启用投影 (enable_proj)",
"dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard 路径 (可选)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
"distribute": "分发模式 (distribute)",
}
else:
adv = {
'grad_accum_steps': "Grad Accum Steps",
'num_workers': "Num Workers",
'log_interval': "Log Interval",
'valid_interval': "Valid Interval",
'weight_decay': "Weight Decay",
'warmup_steps': "Warmup Steps",
'max_steps': "Max Steps",
'sample_rate': "Sample Rate",
'enable_lm': "Enable LoRA LM",
'enable_dit': "Enable LoRA DIT",
'enable_proj': "Enable Projection",
'dropout': "LoRA Dropout",
'tensorboard_path': "Tensorboard Path (Optional)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "Distribute Mode",
"grad_accum_steps": "Grad Accum Steps",
"num_workers": "Num Workers",
"log_interval": "Log Interval",
"valid_interval": "Valid Interval",
"weight_decay": "Weight Decay",
"warmup_steps": "Warmup Steps",
"max_steps": "Max Steps",
"sample_rate": "Sample Rate",
"enable_lm": "Enable LoRA LM",
"enable_dit": "Enable LoRA DIT",
"enable_proj": "Enable Projection",
"dropout": "LoRA Dropout",
"tensorboard_path": "Tensorboard Path (Optional)",
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
"distribute": "Distribute Mode",
}
return (
gr.update(value=f"# {d['title']}"),
gr.update(label=d['tab_train']),
gr.update(label=d['tab_infer']),
gr.update(label=d['pretrained_path']),
gr.update(label=d['train_manifest']),
gr.update(label=d['val_manifest']),
gr.update(label=d['lr']),
gr.update(label=d['max_iters']),
gr.update(label=d['batch_size']),
gr.update(label=d['lora_rank']),
gr.update(label=d['lora_alpha']),
gr.update(label=d['save_interval']),
gr.update(label=d['output_name']),
gr.update(value=d['start_train']),
gr.update(value=d['stop_train']),
gr.update(label=d['train_logs']),
gr.update(label=d["tab_train"]),
gr.update(label=d["tab_infer"]),
gr.update(label=d["pretrained_path"]),
gr.update(label=d["train_manifest"]),
gr.update(label=d["val_manifest"]),
gr.update(label=d["lr"]),
gr.update(label=d["max_iters"]),
gr.update(label=d["batch_size"]),
gr.update(label=d["lora_rank"]),
gr.update(label=d["lora_alpha"]),
gr.update(label=d["save_interval"]),
gr.update(label=d["output_name"]),
gr.update(value=d["start_train"]),
gr.update(value=d["stop_train"]),
gr.update(label=d["train_logs"]),
# Advanced options (must match outputs order)
gr.update(label=adv['grad_accum_steps']),
gr.update(label=adv['num_workers']),
gr.update(label=adv['log_interval']),
gr.update(label=adv['valid_interval']),
gr.update(label=adv['weight_decay']),
gr.update(label=adv['warmup_steps']),
gr.update(label=adv['max_steps']),
gr.update(label=adv['sample_rate']),
gr.update(label=adv['enable_lm']),
gr.update(label=adv['enable_dit']),
gr.update(label=adv['enable_proj']),
gr.update(label=adv['dropout']),
gr.update(label=adv['tensorboard_path']),
gr.update(label=adv["grad_accum_steps"]),
gr.update(label=adv["num_workers"]),
gr.update(label=adv["log_interval"]),
gr.update(label=adv["valid_interval"]),
gr.update(label=adv["weight_decay"]),
gr.update(label=adv["warmup_steps"]),
gr.update(label=adv["max_steps"]),
gr.update(label=adv["sample_rate"]),
gr.update(label=adv["enable_lm"]),
gr.update(label=adv["enable_dit"]),
gr.update(label=adv["enable_proj"]),
gr.update(label=adv["dropout"]),
gr.update(label=adv["tensorboard_path"]),
# Distribution options
gr.update(label=adv['hf_model_id']),
gr.update(label=adv['distribute']),
gr.update(label=adv["hf_model_id"]),
gr.update(label=adv["distribute"]),
# Inference section
gr.update(label=d['text_to_synth']),
gr.update(label=d['ref_audio']),
gr.update(label=d['ref_text']),
gr.update(label=d['select_lora']),
gr.update(value=d['refresh']),
gr.update(label=d['cfg_scale']),
gr.update(label=d['infer_steps']),
gr.update(label=d['seed']),
gr.update(value=d['gen_audio']),
gr.update(label=d['gen_output']),
gr.update(label=d['status']),
gr.update(label=d["text_to_synth"]),
gr.update(label=d["ref_audio"]),
gr.update(label=d["ref_text"]),
gr.update(label=d["select_lora"]),
gr.update(value=d["refresh"]),
gr.update(label=d["cfg_scale"]),
gr.update(label=d["infer_steps"]),
gr.update(label=d["seed"]),
gr.update(value=d["gen_audio"]),
gr.update(label=d["gen_output"]),
gr.update(label=d["status"]),
)
lang_btn.change(
change_language,
inputs=[lang_btn],
outputs=[
title_md, tab_train, tab_infer,
train_pretrained_path, train_manifest, val_manifest,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
title_md,
tab_train,
tab_infer,
train_pretrained_path,
train_manifest,
val_manifest,
lr,
num_iters,
batch_size,
lora_rank,
lora_alpha,
save_interval,
output_name,
start_btn, stop_btn, logs_out,
start_btn,
stop_btn,
logs_out,
# advanced outputs
grad_accum_steps, num_workers, log_interval, valid_interval,
weight_decay, warmup_steps, max_steps, sample_rate,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
grad_accum_steps,
num_workers,
log_interval,
valid_interval,
weight_decay,
warmup_steps,
max_steps,
sample_rate,
enable_lm,
enable_dit,
enable_proj,
dropout,
tensorboard_path,
# distribution outputs
hf_model_id, distribute,
infer_text, prompt_wav, prompt_text,
lora_select, refresh_lora_btn, cfg_scale, steps, seed,
generate_btn, audio_out, status_out
]
hf_model_id,
distribute,
infer_text,
prompt_wav,
prompt_text,
lora_select,
refresh_lora_btn,
cfg_scale,
steps,
seed,
generate_btn,
audio_out,
status_out,
],
)
if __name__ == "__main__":
    # Ensure lora directory exists
    os.makedirs("lora", exist_ok=True)
    # Launch exactly once. NOTE(review): server_name="0.0.0.0" listens on all
    # network interfaces -- confirm this exposure is intended.
    app.queue().launch(server_name="0.0.0.0", server_port=7860)