diff --git a/scripts/train_voxcpm_finetune.py b/scripts/train_voxcpm_finetune.py index b6a9f43..169c9e9 100644 --- a/scripts/train_voxcpm_finetune.py +++ b/scripts/train_voxcpm_finetune.py @@ -477,20 +477,28 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s") except Exception as e: log(f"[Warning] Failed to load reference audio: {e}") - + + # 记录原模式,避免异常时状态不一致 + prev_training = unwrapped_model.training try: # Inference setup unwrapped_model.eval() - unwrapped_model.to(torch.bfloat16) + # unwrapped_model.to(torch.bfloat16) unwrapped_model.audio_vae = audio_vae.to(torch.float32) log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'") + autocast_ctx = ( + torch.autocast(device_type="cuda", dtype=torch.bfloat16) + if torch.cuda.is_available() + else contextlib.nullcontext() + ) with torch.no_grad(): - generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0) + with autocast_ctx: + generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0) # Restore training setup - unwrapped_model.to(torch.float32) - unwrapped_model.audio_vae = None + # unwrapped_model.to(torch.float32) + # unwrapped_model.audio_vae = None if generated is None or len(generated) == 0: log(f"[Warning] Generated audio is empty for sample {i}") @@ -523,6 +531,18 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s import traceback traceback.print_exc() + finally: + # Restore training setup(无论成功失败都恢复) + try: + # unwrapped_model.to(torch.float32) + unwrapped_model.audio_vae = None + if prev_training: + unwrapped_model.train() + else: + unwrapped_model.eval() + except Exception as e: + log(f"[Warning] Failed to restore model state: {e}") + def load_checkpoint(model, optimizer, scheduler, save_dir: Path): """ @@ -671,4 +691,4 @@ if __name__ == "__main__": else: # Otherwise use command line args (parsed by argbind) with argbind.scope(args): - train() \ No newline at end of file + train()