From 23ed7ffeeeeed6843b6f39b42ecc98357e455f7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E9=91=AB?= Date: Fri, 13 Mar 2026 18:43:07 +0800 Subject: [PATCH] fix: fix some bugs in resuming multi-GPU training --- scripts/train_voxcpm_finetune.py | 71 +++++++++++++++++++------------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/scripts/train_voxcpm_finetune.py b/scripts/train_voxcpm_finetune.py index 169c9e9..b4a2f5b 100644 --- a/scripts/train_voxcpm_finetune.py +++ b/scripts/train_voxcpm_finetune.py @@ -188,15 +188,9 @@ def train( num_training_steps=total_training_steps, ) - # Try to load checkpoint and resume training - start_step = 0 - if accelerator.rank == 0: - start_step = load_checkpoint(model, optimizer, scheduler, save_dir) - # Broadcast start_step to all processes - if hasattr(accelerator, 'all_reduce'): - start_step_tensor = torch.tensor(start_step, device=accelerator.device) - accelerator.all_reduce(start_step_tensor) - start_step = int(start_step_tensor.item()) + # All ranks load the same checkpoint to keep model and optimizer state in sync. + start_step = load_checkpoint(model, optimizer, scheduler, save_dir, rank=accelerator.rank) + accelerator.barrier() if start_step > 0 and accelerator.rank == 0: tracker.print(f"Resuming training from step {start_step}") @@ -205,17 +199,18 @@ def train( resume = {"step": start_step} # Register signal handler to save checkpoint on termination (SIGTERM/SIGINT) - def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume): + def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank): try: cur_step = int(_resume.get("step", start_step)) except Exception: cur_step = start_step - print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr) - try: - save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist) - print("Checkpoint saved. Exiting.", file=sys.stderr) - except Exception as e: - print(f"Error saving checkpoint on signal: {e}", file=sys.stderr) + if _rank == 0: + print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr) + try: + save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist) + print("Checkpoint saved. Exiting.", file=sys.stderr) + except Exception as e: + print(f"Error saving checkpoint on signal: {e}", file=sys.stderr) os._exit(0) signal.signal(signal.SIGTERM, _signal_handler) @@ -297,8 +292,8 @@ def train( if step % log_interval == 0 or step == num_iters - 1: loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()} loss_values["lr"] = float(optimizer.param_groups[0]["lr"]) - # Approximate epoch: seen samples / total samples (considering grad_accum and batch_size) - epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples) + # Account for all GPUs when converting steps to epochs. + epoch = (step * grad_accum_steps * batch_size * accelerator.world_size) / max(1, num_train_samples) loss_values["epoch"] = float(epoch) loss_values["grad_norm"] = float(grad_norm) tracker.log_metrics(loss_values, split="train") @@ -478,7 +473,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s except Exception as e: log(f"[Warning] Failed to load reference audio: {e}") - # 记录原模式,避免异常时状态不一致 + # Preserve the original mode so validation failures do not leak into training. prev_training = unwrapped_model.training try: # Inference setup @@ -532,7 +527,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s traceback.print_exc() finally: - # Restore training setup(无论成功失败都恢复) + # Always restore the training state, even if generation fails. try: # unwrapped_model.to(torch.float32) unwrapped_model.audio_vae = None @@ -544,11 +539,14 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s log(f"[Warning] Failed to restore model state: {e}") -def load_checkpoint(model, optimizer, scheduler, save_dir: Path): +def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0): """ Load the latest checkpoint if it exists. + Called by all ranks so that distributed state stays aligned. Returns the step number to resume from, or 0 if no checkpoint found. """ + import json + latest_folder = save_dir / "latest" if not latest_folder.exists(): return 0 @@ -571,9 +569,9 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path): ckpt = torch.load(lora_weights_path, map_location="cpu") state_dict = ckpt.get("state_dict", ckpt) - # Load only lora weights unwrapped.load_state_dict(state_dict, strict=False) - print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr) + if rank == 0: + print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr) else: # Full finetune: load model.safetensors or pytorch_model.bin model_path = latest_folder / "model.safetensors" @@ -589,26 +587,39 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path): state_dict = ckpt.get("state_dict", ckpt) unwrapped.load_state_dict(state_dict, strict=False) - print(f"Loaded model weights from {model_path}", file=sys.stderr) + if rank == 0: + print(f"Loaded model weights from {model_path}", file=sys.stderr) # Load optimizer state optimizer_path = latest_folder / "optimizer.pth" if optimizer_path.exists(): optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu")) - print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr) + if rank == 0: + print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr) # Load scheduler state scheduler_path = latest_folder / "scheduler.pth" if scheduler_path.exists(): scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu")) - print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr) + if rank == 0: + print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr) - # Try to infer step from checkpoint folders + state_path = latest_folder / "training_state.json" + if state_path.exists(): + with open(state_path, "r", encoding="utf-8") as f: + state = json.load(f) + resume_step = int(state.get("step", 0)) + if rank == 0: + print(f"Resuming from step {resume_step}", file=sys.stderr) + return resume_step + + # Fallback for older checkpoints without metadata. step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")] if step_folders: steps = [int(d.name.split("_")[1]) for d in step_folders] resume_step = max(steps) - print(f"Resuming from step {resume_step}", file=sys.stderr) + if rank == 0: + print(f"Resuming from step {resume_step}", file=sys.stderr) return resume_step return 0 @@ -620,6 +631,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret - Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable) - LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable) """ + import json import shutil save_dir.mkdir(parents=True, exist_ok=True) @@ -641,7 +653,6 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret # Save LoRA config and base model path to a separate JSON file # If distribute=True, save hf_model_id; otherwise save local pretrained_path - import json base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None) lora_info = { "base_model": base_model_to_save, @@ -668,6 +679,8 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret torch.save(optimizer.state_dict(), folder / "optimizer.pth") torch.save(scheduler.state_dict(), folder / "scheduler.pth") + with open(folder / "training_state.json", "w", encoding="utf-8") as f: + json.dump({"step": int(step)}, f) # Update (or create) a `latest` folder by copying the most recent checkpoint latest_link = save_dir / "latest"