diff --git a/conf/voxcpm_v2/voxcpm_finetune_all.yaml b/conf/voxcpm_v2/voxcpm_finetune_all.yaml
index 9717290..0b32beb 100644
--- a/conf/voxcpm_v2/voxcpm_finetune_all.yaml
+++ b/conf/voxcpm_v2/voxcpm_finetune_all.yaml
@@ -1,7 +1,8 @@
 pretrained_path: /path/to/VoxCPM2/
 train_manifest: /path/to/train.jsonl
 val_manifest: null
-sample_rate: 48000
+sample_rate: 16000  # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
+out_sample_rate: 48000  # AudioVAE decoder output rate; only used at inference, not during training
 batch_size: 2
 grad_accum_steps: 8  # effective batch size = batch_size × grad_accum_steps = 16
 num_workers: 8
diff --git a/conf/voxcpm_v2/voxcpm_finetune_lora.yaml b/conf/voxcpm_v2/voxcpm_finetune_lora.yaml
index f5d2d8a..32c9a40 100644
--- a/conf/voxcpm_v2/voxcpm_finetune_lora.yaml
+++ b/conf/voxcpm_v2/voxcpm_finetune_lora.yaml
@@ -1,7 +1,8 @@
 pretrained_path: /path/to/VoxCPM2/
 train_manifest: /path/to/train.jsonl
 val_manifest: null
-sample_rate: 48000
+sample_rate: 16000  # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
+out_sample_rate: 48000  # AudioVAE decoder output rate; only used at inference, not during training
 batch_size: 2
 grad_accum_steps: 8  # effective batch size = batch_size × grad_accum_steps = 16
 num_workers: 8
diff --git a/lora_ft_webui.py b/lora_ft_webui.py
index d909e39..e9982ea 100644
--- a/lora_ft_webui.py
+++ b/lora_ft_webui.py
@@ -99,6 +99,28 @@ def get_timestamp_str():
     return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
 
 
+def detect_sample_rate(pretrained_path: str) -> Optional[int]:
+    """Read audio_vae_config.sample_rate from the model's config.json.
+
+    This is the AudioVAE *encoder* input rate, which is the correct rate for
+    resampling training data. Returns None when detection fails.
+    """
+    if not pretrained_path:
+        return None
+    config_file = os.path.join(pretrained_path, "config.json")
+    if not os.path.isfile(config_file):
+        return None
+    try:
+        with open(config_file, "r", encoding="utf-8") as f:
+            cfg = json.load(f)
+        return int(cfg["audio_vae_config"]["sample_rate"])
+    except (KeyError, TypeError, ValueError, OSError) as e:
+        # ValueError covers json.JSONDecodeError; TypeError covers a null or
+        # non-mapping audio_vae_config; OSError covers an unreadable file.
+        print(f"Warning: failed to detect sample_rate from {config_file}: {e}", file=sys.stderr)
+        return None
+
+
 def get_or_load_asr_model():
     global asr_model
     if asr_model is None:
@@ -377,6 +399,16 @@ def start_training(
     os.makedirs(checkpoints_dir, exist_ok=True)
     os.makedirs(logs_dir, exist_ok=True)
 
+    # Auto-detect sample_rate from model config.json to prevent mismatch
+    detected_sr = detect_sample_rate(pretrained_path)
+    if detected_sr is not None:
+        if int(sample_rate) != detected_sr:
+            training_log += (
+                f"[Auto-fix] sample_rate changed from {int(sample_rate)} to {detected_sr} "
+                f"(read from {pretrained_path}/config.json audio_vae_config.sample_rate)\n"
+            )
+        sample_rate = detected_sr
+
     # Create config dictionary
     # Resolve max_steps default
     resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
@@ -929,6 +961,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
             show_label=False,
         )
 
+    def on_pretrained_path_change(path):
+        """Auto-detect sample_rate when pretrained model path changes."""
+        sr = detect_sample_rate(path)
+        if sr is not None:
+            return gr.update(value=sr)
+        return gr.update()
+
+    train_pretrained_path.change(
+        on_pretrained_path_change,
+        inputs=[train_pretrained_path],
+        outputs=[sample_rate],
+    )
+
     start_btn.click(
         start_training,
         inputs=[
diff --git a/scripts/train_voxcpm_finetune.py b/scripts/train_voxcpm_finetune.py
index e034b69..2b05e6b 100644
--- a/scripts/train_voxcpm_finetune.py
+++ b/scripts/train_voxcpm_finetune.py
@@ -46,6 +46,7 @@ def train(
     train_manifest: str,
     val_manifest: str = "",
     sample_rate: int = 16_000,
+    out_sample_rate: int = 0,  # accepted from YAML for documentation; not used in training
     batch_size: int = 1,
     grad_accum_steps: int = 1,
     num_workers: int = 2,
@@ -68,6 +69,7 @@ def train(
     distribute: bool = False,  # If True, save hf_model_id as base_model; otherwise save pretrained_path
 ):
     _ = config_path
+    _ = out_sample_rate
 
     # Validate distribution options
     if lora is not None and distribute and not hf_model_id:
@@ -98,6 +100,14 @@ def train(
     )
     tokenizer = base_model.text_tokenizer
 
+    # Explicit raise (not assert) so the check still runs under `python -O`.
+    expected_sr = base_model.audio_vae.sample_rate
+    if sample_rate != expected_sr:
+        raise ValueError(
+            f"sample_rate mismatch: config says {sample_rate}, but the AudioVAE encoder expects {expected_sr}. "
+            f"Please set sample_rate: {expected_sr} in your training config. "
+        )
+
     train_ds, val_ds = load_audio_text_datasets(
         train_manifest=train_manifest,
         val_manifest=val_manifest,