fix: fix some bugs in resuming multi-GPU training
This commit is contained in:
@@ -188,15 +188,9 @@ def train(
|
|||||||
num_training_steps=total_training_steps,
|
num_training_steps=total_training_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to load checkpoint and resume training
|
# All ranks load the same checkpoint to keep model and optimizer state in sync.
|
||||||
start_step = 0
|
start_step = load_checkpoint(model, optimizer, scheduler, save_dir, rank=accelerator.rank)
|
||||||
if accelerator.rank == 0:
|
accelerator.barrier()
|
||||||
start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
|
|
||||||
# Broadcast start_step to all processes
|
|
||||||
if hasattr(accelerator, 'all_reduce'):
|
|
||||||
start_step_tensor = torch.tensor(start_step, device=accelerator.device)
|
|
||||||
accelerator.all_reduce(start_step_tensor)
|
|
||||||
start_step = int(start_step_tensor.item())
|
|
||||||
|
|
||||||
if start_step > 0 and accelerator.rank == 0:
|
if start_step > 0 and accelerator.rank == 0:
|
||||||
tracker.print(f"Resuming training from step {start_step}")
|
tracker.print(f"Resuming training from step {start_step}")
|
||||||
@@ -205,11 +199,12 @@ def train(
|
|||||||
resume = {"step": start_step}
|
resume = {"step": start_step}
|
||||||
|
|
||||||
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
|
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
|
||||||
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume):
|
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
|
||||||
try:
|
try:
|
||||||
cur_step = int(_resume.get("step", start_step))
|
cur_step = int(_resume.get("step", start_step))
|
||||||
except Exception:
|
except Exception:
|
||||||
cur_step = start_step
|
cur_step = start_step
|
||||||
|
if _rank == 0:
|
||||||
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
|
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
|
||||||
try:
|
try:
|
||||||
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
|
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
|
||||||
@@ -297,8 +292,8 @@ def train(
|
|||||||
if step % log_interval == 0 or step == num_iters - 1:
|
if step % log_interval == 0 or step == num_iters - 1:
|
||||||
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
|
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
|
||||||
loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
|
loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
|
||||||
# Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
|
# Account for all GPUs when converting steps to epochs.
|
||||||
epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
|
epoch = (step * grad_accum_steps * batch_size * accelerator.world_size) / max(1, num_train_samples)
|
||||||
loss_values["epoch"] = float(epoch)
|
loss_values["epoch"] = float(epoch)
|
||||||
loss_values["grad_norm"] = float(grad_norm)
|
loss_values["grad_norm"] = float(grad_norm)
|
||||||
tracker.log_metrics(loss_values, split="train")
|
tracker.log_metrics(loss_values, split="train")
|
||||||
@@ -478,7 +473,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log(f"[Warning] Failed to load reference audio: {e}")
|
log(f"[Warning] Failed to load reference audio: {e}")
|
||||||
|
|
||||||
# 记录原模式,避免异常时状态不一致
|
# Preserve the original mode so validation failures do not leak into training.
|
||||||
prev_training = unwrapped_model.training
|
prev_training = unwrapped_model.training
|
||||||
try:
|
try:
|
||||||
# Inference setup
|
# Inference setup
|
||||||
@@ -532,7 +527,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore training setup(无论成功失败都恢复)
|
# Always restore the training state, even if generation fails.
|
||||||
try:
|
try:
|
||||||
# unwrapped_model.to(torch.float32)
|
# unwrapped_model.to(torch.float32)
|
||||||
unwrapped_model.audio_vae = None
|
unwrapped_model.audio_vae = None
|
||||||
@@ -544,11 +539,14 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
|||||||
log(f"[Warning] Failed to restore model state: {e}")
|
log(f"[Warning] Failed to restore model state: {e}")
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||||
"""
|
"""
|
||||||
Load the latest checkpoint if it exists.
|
Load the latest checkpoint if it exists.
|
||||||
|
Called by all ranks so that distributed state stays aligned.
|
||||||
Returns the step number to resume from, or 0 if no checkpoint found.
|
Returns the step number to resume from, or 0 if no checkpoint found.
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
latest_folder = save_dir / "latest"
|
latest_folder = save_dir / "latest"
|
||||||
if not latest_folder.exists():
|
if not latest_folder.exists():
|
||||||
return 0
|
return 0
|
||||||
@@ -571,8 +569,8 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
|||||||
ckpt = torch.load(lora_weights_path, map_location="cpu")
|
ckpt = torch.load(lora_weights_path, map_location="cpu")
|
||||||
state_dict = ckpt.get("state_dict", ckpt)
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
|
|
||||||
# Load only lora weights
|
|
||||||
unwrapped.load_state_dict(state_dict, strict=False)
|
unwrapped.load_state_dict(state_dict, strict=False)
|
||||||
|
if rank == 0:
|
||||||
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
|
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
# Full finetune: load model.safetensors or pytorch_model.bin
|
# Full finetune: load model.safetensors or pytorch_model.bin
|
||||||
@@ -589,25 +587,38 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
|||||||
state_dict = ckpt.get("state_dict", ckpt)
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
|
|
||||||
unwrapped.load_state_dict(state_dict, strict=False)
|
unwrapped.load_state_dict(state_dict, strict=False)
|
||||||
|
if rank == 0:
|
||||||
print(f"Loaded model weights from {model_path}", file=sys.stderr)
|
print(f"Loaded model weights from {model_path}", file=sys.stderr)
|
||||||
|
|
||||||
# Load optimizer state
|
# Load optimizer state
|
||||||
optimizer_path = latest_folder / "optimizer.pth"
|
optimizer_path = latest_folder / "optimizer.pth"
|
||||||
if optimizer_path.exists():
|
if optimizer_path.exists():
|
||||||
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
|
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
|
||||||
|
if rank == 0:
|
||||||
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
|
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
|
||||||
|
|
||||||
# Load scheduler state
|
# Load scheduler state
|
||||||
scheduler_path = latest_folder / "scheduler.pth"
|
scheduler_path = latest_folder / "scheduler.pth"
|
||||||
if scheduler_path.exists():
|
if scheduler_path.exists():
|
||||||
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
|
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
|
||||||
|
if rank == 0:
|
||||||
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
|
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
|
||||||
|
|
||||||
# Try to infer step from checkpoint folders
|
state_path = latest_folder / "training_state.json"
|
||||||
|
if state_path.exists():
|
||||||
|
with open(state_path, "r", encoding="utf-8") as f:
|
||||||
|
state = json.load(f)
|
||||||
|
resume_step = int(state.get("step", 0))
|
||||||
|
if rank == 0:
|
||||||
|
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||||
|
return resume_step
|
||||||
|
|
||||||
|
# Fallback for older checkpoints without metadata.
|
||||||
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
|
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
|
||||||
if step_folders:
|
if step_folders:
|
||||||
steps = [int(d.name.split("_")[1]) for d in step_folders]
|
steps = [int(d.name.split("_")[1]) for d in step_folders]
|
||||||
resume_step = max(steps)
|
resume_step = max(steps)
|
||||||
|
if rank == 0:
|
||||||
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||||
return resume_step
|
return resume_step
|
||||||
|
|
||||||
@@ -620,6 +631,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
||||||
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
|
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -641,7 +653,6 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
|
|
||||||
# Save LoRA config and base model path to a separate JSON file
|
# Save LoRA config and base model path to a separate JSON file
|
||||||
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
|
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
|
||||||
import json
|
|
||||||
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
|
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
|
||||||
lora_info = {
|
lora_info = {
|
||||||
"base_model": base_model_to_save,
|
"base_model": base_model_to_save,
|
||||||
@@ -668,6 +679,8 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
|
|
||||||
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
||||||
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
||||||
|
with open(folder / "training_state.json", "w", encoding="utf-8") as f:
|
||||||
|
json.dump({"step": int(step)}, f)
|
||||||
|
|
||||||
# Update (or create) a `latest` folder by copying the most recent checkpoint
|
# Update (or create) a `latest` folder by copying the most recent checkpoint
|
||||||
latest_link = save_dir / "latest"
|
latest_link = save_dir / "latest"
|
||||||
|
|||||||
Reference in New Issue
Block a user