fix: fix some bugs in resuming multi-GPU training

This commit is contained in:
刘鑫
2026-03-13 18:43:07 +08:00
parent 7823e14b82
commit 23ed7ffeee
+31 -18
View File
@@ -188,15 +188,9 @@ def train(
num_training_steps=total_training_steps, num_training_steps=total_training_steps,
) )
# Try to load checkpoint and resume training # All ranks load the same checkpoint to keep model and optimizer state in sync.
start_step = 0 start_step = load_checkpoint(model, optimizer, scheduler, save_dir, rank=accelerator.rank)
if accelerator.rank == 0: accelerator.barrier()
start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
# Broadcast start_step to all processes
if hasattr(accelerator, 'all_reduce'):
start_step_tensor = torch.tensor(start_step, device=accelerator.device)
accelerator.all_reduce(start_step_tensor)
start_step = int(start_step_tensor.item())
if start_step > 0 and accelerator.rank == 0: if start_step > 0 and accelerator.rank == 0:
tracker.print(f"Resuming training from step {start_step}") tracker.print(f"Resuming training from step {start_step}")
@@ -205,11 +199,12 @@ def train(
resume = {"step": start_step} resume = {"step": start_step}
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT) # Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume): def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
try: try:
cur_step = int(_resume.get("step", start_step)) cur_step = int(_resume.get("step", start_step))
except Exception: except Exception:
cur_step = start_step cur_step = start_step
if _rank == 0:
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr) print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
try: try:
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist) save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
@@ -297,8 +292,8 @@ def train(
if step % log_interval == 0 or step == num_iters - 1: if step % log_interval == 0 or step == num_iters - 1:
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()} loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
loss_values["lr"] = float(optimizer.param_groups[0]["lr"]) loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
# Approximate epoch: seen samples / total samples (considering grad_accum and batch_size) # Account for all GPUs when converting steps to epochs.
epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples) epoch = (step * grad_accum_steps * batch_size * accelerator.world_size) / max(1, num_train_samples)
loss_values["epoch"] = float(epoch) loss_values["epoch"] = float(epoch)
loss_values["grad_norm"] = float(grad_norm) loss_values["grad_norm"] = float(grad_norm)
tracker.log_metrics(loss_values, split="train") tracker.log_metrics(loss_values, split="train")
@@ -478,7 +473,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
except Exception as e: except Exception as e:
log(f"[Warning] Failed to load reference audio: {e}") log(f"[Warning] Failed to load reference audio: {e}")
# 记录原模式,避免异常时状态不一致 # Preserve the original mode so validation failures do not leak into training.
prev_training = unwrapped_model.training prev_training = unwrapped_model.training
try: try:
# Inference setup # Inference setup
@@ -532,7 +527,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
traceback.print_exc() traceback.print_exc()
finally: finally:
# Restore training setup(无论成功失败都恢复) # Always restore the training state, even if generation fails.
try: try:
# unwrapped_model.to(torch.float32) # unwrapped_model.to(torch.float32)
unwrapped_model.audio_vae = None unwrapped_model.audio_vae = None
@@ -544,11 +539,14 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
log(f"[Warning] Failed to restore model state: {e}") log(f"[Warning] Failed to restore model state: {e}")
def load_checkpoint(model, optimizer, scheduler, save_dir: Path): def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
""" """
Load the latest checkpoint if it exists. Load the latest checkpoint if it exists.
Called by all ranks so that distributed state stays aligned.
Returns the step number to resume from, or 0 if no checkpoint found. Returns the step number to resume from, or 0 if no checkpoint found.
""" """
import json
latest_folder = save_dir / "latest" latest_folder = save_dir / "latest"
if not latest_folder.exists(): if not latest_folder.exists():
return 0 return 0
@@ -571,8 +569,8 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
ckpt = torch.load(lora_weights_path, map_location="cpu") ckpt = torch.load(lora_weights_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt) state_dict = ckpt.get("state_dict", ckpt)
# Load only lora weights
unwrapped.load_state_dict(state_dict, strict=False) unwrapped.load_state_dict(state_dict, strict=False)
if rank == 0:
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr) print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
else: else:
# Full finetune: load model.safetensors or pytorch_model.bin # Full finetune: load model.safetensors or pytorch_model.bin
@@ -589,25 +587,38 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
state_dict = ckpt.get("state_dict", ckpt) state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False) unwrapped.load_state_dict(state_dict, strict=False)
if rank == 0:
print(f"Loaded model weights from {model_path}", file=sys.stderr) print(f"Loaded model weights from {model_path}", file=sys.stderr)
# Load optimizer state # Load optimizer state
optimizer_path = latest_folder / "optimizer.pth" optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists(): if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu")) optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
if rank == 0:
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr) print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
# Load scheduler state # Load scheduler state
scheduler_path = latest_folder / "scheduler.pth" scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists(): if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu")) scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
if rank == 0:
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr) print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
# Try to infer step from checkpoint folders state_path = latest_folder / "training_state.json"
if state_path.exists():
with open(state_path, "r", encoding="utf-8") as f:
state = json.load(f)
resume_step = int(state.get("step", 0))
if rank == 0:
print(f"Resuming from step {resume_step}", file=sys.stderr)
return resume_step
# Fallback for older checkpoints without metadata.
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")] step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
if step_folders: if step_folders:
steps = [int(d.name.split("_")[1]) for d in step_folders] steps = [int(d.name.split("_")[1]) for d in step_folders]
resume_step = max(steps) resume_step = max(steps)
if rank == 0:
print(f"Resuming from step {resume_step}", file=sys.stderr) print(f"Resuming from step {resume_step}", file=sys.stderr)
return resume_step return resume_step
@@ -620,6 +631,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable) - Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable) - LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
""" """
import json
import shutil import shutil
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
@@ -641,7 +653,6 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
# Save LoRA config and base model path to a separate JSON file # Save LoRA config and base model path to a separate JSON file
# If distribute=True, save hf_model_id; otherwise save local pretrained_path # If distribute=True, save hf_model_id; otherwise save local pretrained_path
import json
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None) base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
lora_info = { lora_info = {
"base_model": base_model_to_save, "base_model": base_model_to_save,
@@ -668,6 +679,8 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
torch.save(optimizer.state_dict(), folder / "optimizer.pth") torch.save(optimizer.state_dict(), folder / "optimizer.pth")
torch.save(scheduler.state_dict(), folder / "scheduler.pth") torch.save(scheduler.state_dict(), folder / "scheduler.pth")
with open(folder / "training_state.json", "w", encoding="utf-8") as f:
json.dump({"step": int(step)}, f)
# Update (or create) a `latest` folder by copying the most recent checkpoint # Update (or create) a `latest` folder by copying the most recent checkpoint
latest_link = save_dir / "latest" latest_link = save_dir / "latest"