fix VoxCPM2 training sample_rate: 48000 -> 16000 (match AudioVAE encoder)

Made-with: Cursor
This commit is contained in:
刘鑫
2026-04-07 22:57:42 +08:00
parent da700f264e
commit 46cfce0c97
4 changed files with 53 additions and 2 deletions
+41
View File
@@ -99,6 +99,24 @@ def get_timestamp_str():
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def detect_sample_rate(pretrained_path: str) -> Optional[int]:
    """Read ``audio_vae_config.sample_rate`` from the model's ``config.json``.

    This is the AudioVAE *encoder* input rate, which is the correct rate for
    resampling training data.

    Args:
        pretrained_path: Directory expected to contain ``config.json``.

    Returns:
        The sample rate as an ``int``, or ``None`` when detection fails: the
        path is empty, ``config.json`` is missing or unreadable, the JSON is
        malformed, or ``audio_vae_config.sample_rate`` is absent or cannot be
        converted to an integer.
    """
    # An empty path would silently resolve to ./config.json in the current
    # working directory, and None would raise inside os.path.join; neither
    # identifies a model directory, so there is nothing to detect.
    if not pretrained_path:
        return None

    config_file = os.path.join(pretrained_path, "config.json")
    if not os.path.isfile(config_file):
        return None

    try:
        with open(config_file, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        return int(cfg["audio_vae_config"]["sample_rate"])
    except (KeyError, TypeError, ValueError, OSError) as e:
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
        # TypeError covers a null/non-dict "audio_vae_config" or a null
        # "sample_rate"; OSError covers a file that cannot be opened or read.
        print(f"Warning: failed to detect sample_rate from {config_file}: {e}", file=sys.stderr)
        return None
def get_or_load_asr_model():
global asr_model
if asr_model is None:
@@ -377,6 +395,16 @@ def start_training(
os.makedirs(checkpoints_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)
# Auto-detect sample_rate from model config.json to prevent mismatch
detected_sr = detect_sample_rate(pretrained_path)
if detected_sr is not None:
if int(sample_rate) != detected_sr:
training_log += (
f"[Auto-fix] sample_rate changed from {int(sample_rate)} to {detected_sr} "
f"(read from {pretrained_path}/config.json audio_vae_config.sample_rate)\n"
)
sample_rate = detected_sr
# Create config dictionary
# Resolve max_steps default
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
@@ -929,6 +957,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
show_label=False,
)
def on_pretrained_path_change(path):
    """Return a component update carrying the sample rate detected for *path*.

    When detection yields nothing, an empty update is returned so the
    target component keeps its current value.
    """
    detected = detect_sample_rate(path)
    if detected is None:
        return gr.update()
    return gr.update(value=detected)
# Whenever the pretrained-path input changes, run on_pretrained_path_change
# on its new value and route the returned update to the sample_rate component.
train_pretrained_path.change(
on_pretrained_path_change,
inputs=[train_pretrained_path],
outputs=[sample_rate],
)
start_btn.click(
start_training,
inputs=[