fix: fix some bugs in resuming multi-GPU training

This commit is contained in:
刘鑫
2026-03-13 18:43:07 +08:00
parent 7823e14b82
commit 23ed7ffeee
+42 -29
View File
@@ -188,15 +188,9 @@ def train(
num_training_steps=total_training_steps,
)
# Try to load checkpoint and resume training
start_step = 0
if accelerator.rank == 0:
start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
# Broadcast start_step to all processes
if hasattr(accelerator, 'all_reduce'):
start_step_tensor = torch.tensor(start_step, device=accelerator.device)
accelerator.all_reduce(start_step_tensor)
start_step = int(start_step_tensor.item())
# All ranks load the same checkpoint to keep model and optimizer state in sync.
start_step = load_checkpoint(model, optimizer, scheduler, save_dir, rank=accelerator.rank)
accelerator.barrier()
if start_step > 0 and accelerator.rank == 0:
tracker.print(f"Resuming training from step {start_step}")
@@ -205,17 +199,18 @@ def train(
resume = {"step": start_step}
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume):
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
try:
cur_step = int(_resume.get("step", start_step))
except Exception:
cur_step = start_step
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
try:
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
print("Checkpoint saved. Exiting.", file=sys.stderr)
except Exception as e:
print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
if _rank == 0:
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
try:
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
print("Checkpoint saved. Exiting.", file=sys.stderr)
except Exception as e:
print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
os._exit(0)
signal.signal(signal.SIGTERM, _signal_handler)
@@ -297,8 +292,8 @@ def train(
if step % log_interval == 0 or step == num_iters - 1:
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
# Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
# Account for all GPUs when converting steps to epochs.
epoch = (step * grad_accum_steps * batch_size * accelerator.world_size) / max(1, num_train_samples)
loss_values["epoch"] = float(epoch)
loss_values["grad_norm"] = float(grad_norm)
tracker.log_metrics(loss_values, split="train")
@@ -478,7 +473,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
except Exception as e:
log(f"[Warning] Failed to load reference audio: {e}")
# 记录原模式,避免异常时状态不一致
# Preserve the original mode so validation failures do not leak into training.
prev_training = unwrapped_model.training
try:
# Inference setup
@@ -532,7 +527,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
traceback.print_exc()
finally:
# Restore training setup(无论成功失败都恢复)
# Always restore the training state, even if generation fails.
try:
# unwrapped_model.to(torch.float32)
unwrapped_model.audio_vae = None
@@ -544,11 +539,14 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
log(f"[Warning] Failed to restore model state: {e}")
def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
"""
Load the latest checkpoint if it exists.
Called by all ranks so that distributed state stays aligned.
Returns the step number to resume from, or 0 if no checkpoint found.
"""
import json
latest_folder = save_dir / "latest"
if not latest_folder.exists():
return 0
@@ -571,9 +569,9 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
ckpt = torch.load(lora_weights_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
# Load only lora weights
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
if rank == 0:
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
else:
# Full finetune: load model.safetensors or pytorch_model.bin
model_path = latest_folder / "model.safetensors"
@@ -589,26 +587,39 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded model weights from {model_path}", file=sys.stderr)
if rank == 0:
print(f"Loaded model weights from {model_path}", file=sys.stderr)
# Load optimizer state
optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
if rank == 0:
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
# Load scheduler state
scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
if rank == 0:
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
# Try to infer step from checkpoint folders
state_path = latest_folder / "training_state.json"
if state_path.exists():
with open(state_path, "r", encoding="utf-8") as f:
state = json.load(f)
resume_step = int(state.get("step", 0))
if rank == 0:
print(f"Resuming from step {resume_step}", file=sys.stderr)
return resume_step
# Fallback for older checkpoints without metadata.
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
if step_folders:
steps = [int(d.name.split("_")[1]) for d in step_folders]
resume_step = max(steps)
print(f"Resuming from step {resume_step}", file=sys.stderr)
if rank == 0:
print(f"Resuming from step {resume_step}", file=sys.stderr)
return resume_step
return 0
@@ -620,6 +631,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
"""
import json
import shutil
save_dir.mkdir(parents=True, exist_ok=True)
@@ -641,7 +653,6 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
# Save LoRA config and base model path to a separate JSON file
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
import json
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
lora_info = {
"base_model": base_model_to_save,
@@ -668,6 +679,8 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
with open(folder / "training_state.json", "w", encoding="utf-8") as f:
json.dump({"step": int(step)}, f)
# Update (or create) a `latest` folder by copying the most recent checkpoint
latest_link = save_dir / "latest"