fix: fine-tuning (ft) log and settings
This commit is contained in:
+33
-8
@@ -14,8 +14,10 @@ from typing import Optional
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root / "src"))
|
||||
|
||||
# Default pretrained model path relative to this repo
|
||||
default_pretrained_path = str(project_root / "models" / "openbmb__VoxCPM1.5")
|
||||
# Default pretrained model path: prefer VoxCPM2 if it exists, fallback to VoxCPM1.5
|
||||
_v2_path = project_root / "models" / "openbmb__VoxCPM2"
|
||||
_v15_path = project_root / "models" / "openbmb__VoxCPM1.5"
|
||||
default_pretrained_path = str(_v2_path if _v2_path.exists() else _v15_path)
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
@@ -368,6 +370,7 @@ def start_training(
|
||||
warmup_steps=100,
|
||||
max_steps=None,
|
||||
sample_rate=44100,
|
||||
max_grad_norm=1.0,
|
||||
# LoRA advanced
|
||||
enable_lm=True,
|
||||
enable_dit=True,
|
||||
@@ -409,11 +412,25 @@ def start_training(
|
||||
# Resolve max_steps default
|
||||
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
||||
|
||||
# Auto-detect out_sample_rate from model config
|
||||
out_sample_rate = 0
|
||||
config_file = os.path.join(pretrained_path, "config.json")
|
||||
if os.path.isfile(config_file):
|
||||
try:
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
out_sr = cfg.get("audio_vae_config", {}).get("out_sample_rate")
|
||||
if out_sr:
|
||||
out_sample_rate = int(out_sr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
config = {
|
||||
"pretrained_path": pretrained_path,
|
||||
"train_manifest": train_manifest,
|
||||
"val_manifest": val_manifest,
|
||||
"sample_rate": int(sample_rate),
|
||||
"out_sample_rate": out_sample_rate,
|
||||
"batch_size": int(batch_size),
|
||||
"grad_accum_steps": int(grad_accum_steps),
|
||||
"num_workers": int(num_workers),
|
||||
@@ -425,6 +442,7 @@ def start_training(
|
||||
"weight_decay": float(weight_decay),
|
||||
"warmup_steps": int(warmup_steps),
|
||||
"max_steps": resolved_max_steps,
|
||||
"max_grad_norm": float(max_grad_norm),
|
||||
"save_path": checkpoints_dir,
|
||||
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
||||
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||
@@ -932,17 +950,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
with gr.Row():
|
||||
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
|
||||
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
|
||||
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
|
||||
max_grad_norm = gr.Number(label="梯度裁剪 (max_grad_norm, 0=关闭)", value=1.0)
|
||||
with gr.Row():
|
||||
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
|
||||
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
|
||||
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
|
||||
with gr.Row():
|
||||
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
|
||||
dropout = gr.Number(label="LoRA Dropout", value=0.0)
|
||||
|
||||
gr.Markdown("#### 分发选项 (Distribution)")
|
||||
with gr.Row():
|
||||
hf_model_id = gr.Textbox(
|
||||
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
|
||||
label="HuggingFace Model ID (e.g., openbmb/VoxCPM2)", value=""
|
||||
)
|
||||
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
||||
|
||||
@@ -992,6 +1012,7 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
max_grad_norm,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
@@ -1150,12 +1171,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
"warmup_steps": "warmup_steps",
|
||||
"max_steps": "最大步数 (max_steps)",
|
||||
"sample_rate": "采样率 (sample_rate)",
|
||||
"max_grad_norm": "梯度裁剪 (max_grad_norm, 0=关闭)",
|
||||
"enable_lm": "启用 LoRA LM (enable_lm)",
|
||||
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
||||
"enable_proj": "启用投影 (enable_proj)",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard 路径 (可选)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
|
||||
"distribute": "分发模式 (distribute)",
|
||||
}
|
||||
else:
|
||||
@@ -1168,12 +1190,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
"warmup_steps": "Warmup Steps",
|
||||
"max_steps": "Max Steps",
|
||||
"sample_rate": "Sample Rate",
|
||||
"max_grad_norm": "Max Grad Norm (0=disabled)",
|
||||
"enable_lm": "Enable LoRA LM",
|
||||
"enable_dit": "Enable LoRA DIT",
|
||||
"enable_proj": "Enable Projection",
|
||||
"dropout": "LoRA Dropout",
|
||||
"tensorboard_path": "Tensorboard Path (Optional)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
|
||||
"distribute": "Distribute Mode",
|
||||
}
|
||||
|
||||
@@ -1203,11 +1226,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
gr.update(label=adv["warmup_steps"]),
|
||||
gr.update(label=adv["max_steps"]),
|
||||
gr.update(label=adv["sample_rate"]),
|
||||
gr.update(label=adv["max_grad_norm"]),
|
||||
gr.update(label=adv["tensorboard_path"]),
|
||||
gr.update(label=adv["enable_lm"]),
|
||||
gr.update(label=adv["enable_dit"]),
|
||||
gr.update(label=adv["enable_proj"]),
|
||||
gr.update(label=adv["dropout"]),
|
||||
gr.update(label=adv["tensorboard_path"]),
|
||||
# Distribution options
|
||||
gr.update(label=adv["hf_model_id"]),
|
||||
gr.update(label=adv["distribute"]),
|
||||
@@ -1254,11 +1278,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
warmup_steps,
|
||||
max_steps,
|
||||
sample_rate,
|
||||
max_grad_norm,
|
||||
tensorboard_path,
|
||||
enable_lm,
|
||||
enable_dit,
|
||||
enable_proj,
|
||||
dropout,
|
||||
tensorboard_path,
|
||||
# distribution outputs
|
||||
hf_model_id,
|
||||
distribute,
|
||||
|
||||
Reference in New Issue
Block a user