fix VoxCPM2 training sample_rate: 48000 -> 16000 (match AudioVAE encoder)
Made-with: Cursor
This commit is contained in:
@@ -99,6 +99,24 @@ def get_timestamp_str():
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
def detect_sample_rate(pretrained_path: str) -> Optional[int]:
|
||||
"""Read audio_vae_config.sample_rate from the model's config.json.
|
||||
|
||||
This is the AudioVAE *encoder* input rate, which is the correct rate for
|
||||
resampling training data. Returns None when detection fails.
|
||||
"""
|
||||
config_file = os.path.join(pretrained_path, "config.json")
|
||||
if not os.path.isfile(config_file):
|
||||
return None
|
||||
try:
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
return int(cfg["audio_vae_config"]["sample_rate"])
|
||||
except (KeyError, ValueError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: failed to detect sample_rate from {config_file}: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def get_or_load_asr_model():
|
||||
global asr_model
|
||||
if asr_model is None:
|
||||
@@ -377,6 +395,16 @@ def start_training(
|
||||
os.makedirs(checkpoints_dir, exist_ok=True)
|
||||
os.makedirs(logs_dir, exist_ok=True)
|
||||
|
||||
# Auto-detect sample_rate from model config.json to prevent mismatch
|
||||
detected_sr = detect_sample_rate(pretrained_path)
|
||||
if detected_sr is not None:
|
||||
if int(sample_rate) != detected_sr:
|
||||
training_log += (
|
||||
f"[Auto-fix] sample_rate changed from {int(sample_rate)} to {detected_sr} "
|
||||
f"(read from {pretrained_path}/config.json audio_vae_config.sample_rate)\n"
|
||||
)
|
||||
sample_rate = detected_sr
|
||||
|
||||
# Create config dictionary
|
||||
# Resolve max_steps default
|
||||
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
||||
@@ -929,6 +957,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
def on_pretrained_path_change(path):
|
||||
"""Auto-detect sample_rate when pretrained model path changes."""
|
||||
sr = detect_sample_rate(path)
|
||||
if sr is not None:
|
||||
return gr.update(value=sr)
|
||||
return gr.update()
|
||||
|
||||
train_pretrained_path.change(
|
||||
on_pretrained_path_change,
|
||||
inputs=[train_pretrained_path],
|
||||
outputs=[sample_rate],
|
||||
)
|
||||
|
||||
start_btn.click(
|
||||
start_training,
|
||||
inputs=[
|
||||
|
||||
Reference in New Issue
Block a user