fix: fix some bugs in resuming multi-GPU training
This commit is contained in:
@@ -188,15 +188,9 @@ def train(
|
||||
num_training_steps=total_training_steps,
|
||||
)
|
||||
|
||||
# Try to load checkpoint and resume training
|
||||
start_step = 0
|
||||
if accelerator.rank == 0:
|
||||
start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
|
||||
# Broadcast start_step to all processes
|
||||
if hasattr(accelerator, 'all_reduce'):
|
||||
start_step_tensor = torch.tensor(start_step, device=accelerator.device)
|
||||
accelerator.all_reduce(start_step_tensor)
|
||||
start_step = int(start_step_tensor.item())
|
||||
# All ranks load the same checkpoint to keep model and optimizer state in sync.
|
||||
start_step = load_checkpoint(model, optimizer, scheduler, save_dir, rank=accelerator.rank)
|
||||
accelerator.barrier()
|
||||
|
||||
if start_step > 0 and accelerator.rank == 0:
|
||||
tracker.print(f"Resuming training from step {start_step}")
|
||||
@@ -205,17 +199,18 @@ def train(
|
||||
resume = {"step": start_step}
|
||||
|
||||
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
|
||||
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume):
|
||||
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
|
||||
try:
|
||||
cur_step = int(_resume.get("step", start_step))
|
||||
except Exception:
|
||||
cur_step = start_step
|
||||
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
|
||||
try:
|
||||
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
|
||||
print("Checkpoint saved. Exiting.", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
|
||||
if _rank == 0:
|
||||
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
|
||||
try:
|
||||
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
|
||||
print("Checkpoint saved. Exiting.", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
|
||||
os._exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, _signal_handler)
|
||||
@@ -297,8 +292,8 @@ def train(
|
||||
if step % log_interval == 0 or step == num_iters - 1:
|
||||
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
|
||||
loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
|
||||
# Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
|
||||
epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
|
||||
# Account for all GPUs when converting steps to epochs.
|
||||
epoch = (step * grad_accum_steps * batch_size * accelerator.world_size) / max(1, num_train_samples)
|
||||
loss_values["epoch"] = float(epoch)
|
||||
loss_values["grad_norm"] = float(grad_norm)
|
||||
tracker.log_metrics(loss_values, split="train")
|
||||
@@ -478,7 +473,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
except Exception as e:
|
||||
log(f"[Warning] Failed to load reference audio: {e}")
|
||||
|
||||
# 记录原模式,避免异常时状态不一致
|
||||
# Preserve the original mode so validation failures do not leak into training.
|
||||
prev_training = unwrapped_model.training
|
||||
try:
|
||||
# Inference setup
|
||||
@@ -532,7 +527,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
# Restore training setup(无论成功失败都恢复)
|
||||
# Always restore the training state, even if generation fails.
|
||||
try:
|
||||
# unwrapped_model.to(torch.float32)
|
||||
unwrapped_model.audio_vae = None
|
||||
@@ -544,11 +539,14 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
log(f"[Warning] Failed to restore model state: {e}")
|
||||
|
||||
|
||||
def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
||||
def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
"""
|
||||
Load the latest checkpoint if it exists.
|
||||
Called by all ranks so that distributed state stays aligned.
|
||||
Returns the step number to resume from, or 0 if no checkpoint found.
|
||||
"""
|
||||
import json
|
||||
|
||||
latest_folder = save_dir / "latest"
|
||||
if not latest_folder.exists():
|
||||
return 0
|
||||
@@ -571,9 +569,9 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
||||
ckpt = torch.load(lora_weights_path, map_location="cpu")
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
|
||||
# Load only lora weights
|
||||
unwrapped.load_state_dict(state_dict, strict=False)
|
||||
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
|
||||
if rank == 0:
|
||||
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
|
||||
else:
|
||||
# Full finetune: load model.safetensors or pytorch_model.bin
|
||||
model_path = latest_folder / "model.safetensors"
|
||||
@@ -589,26 +587,39 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
|
||||
unwrapped.load_state_dict(state_dict, strict=False)
|
||||
print(f"Loaded model weights from {model_path}", file=sys.stderr)
|
||||
if rank == 0:
|
||||
print(f"Loaded model weights from {model_path}", file=sys.stderr)
|
||||
|
||||
# Load optimizer state
|
||||
optimizer_path = latest_folder / "optimizer.pth"
|
||||
if optimizer_path.exists():
|
||||
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
|
||||
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
|
||||
if rank == 0:
|
||||
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
|
||||
|
||||
# Load scheduler state
|
||||
scheduler_path = latest_folder / "scheduler.pth"
|
||||
if scheduler_path.exists():
|
||||
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
|
||||
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
|
||||
if rank == 0:
|
||||
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
|
||||
|
||||
# Try to infer step from checkpoint folders
|
||||
state_path = latest_folder / "training_state.json"
|
||||
if state_path.exists():
|
||||
with open(state_path, "r", encoding="utf-8") as f:
|
||||
state = json.load(f)
|
||||
resume_step = int(state.get("step", 0))
|
||||
if rank == 0:
|
||||
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||
return resume_step
|
||||
|
||||
# Fallback for older checkpoints without metadata.
|
||||
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
|
||||
if step_folders:
|
||||
steps = [int(d.name.split("_")[1]) for d in step_folders]
|
||||
resume_step = max(steps)
|
||||
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||
if rank == 0:
|
||||
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||
return resume_step
|
||||
|
||||
return 0
|
||||
@@ -620,6 +631,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
||||
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
|
||||
"""
|
||||
import json
|
||||
import shutil
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -641,7 +653,6 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
|
||||
# Save LoRA config and base model path to a separate JSON file
|
||||
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
|
||||
import json
|
||||
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
|
||||
lora_info = {
|
||||
"base_model": base_model_to_save,
|
||||
@@ -668,6 +679,8 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
|
||||
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
||||
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
||||
with open(folder / "training_state.json", "w", encoding="utf-8") as f:
|
||||
json.dump({"step": int(step)}, f)
|
||||
|
||||
# Update (or create) a `latest` folder by copying the most recent checkpoint
|
||||
latest_link = save_dir / "latest"
|
||||
|
||||
Reference in New Issue
Block a user