update voxcpm2
This commit is contained in:
@@ -7,7 +7,7 @@ project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root / "src"))
|
||||
|
||||
import contextlib
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
import argbind
|
||||
import torch
|
||||
@@ -17,16 +17,19 @@ from transformers import get_cosine_schedule_with_warmup
|
||||
import signal
|
||||
import os
|
||||
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
try:
|
||||
from safetensors.torch import save_file
|
||||
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
|
||||
|
||||
from voxcpm.model import VoxCPMModel
|
||||
import json
|
||||
|
||||
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
from voxcpm.training import (
|
||||
Accelerator,
|
||||
@@ -61,15 +64,15 @@ def train(
|
||||
lora: dict = None,
|
||||
config_path: str = "",
|
||||
# Distribution options (for LoRA checkpoints)
|
||||
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||
):
|
||||
_ = config_path
|
||||
|
||||
|
||||
# Validate distribution options
|
||||
if lora is not None and distribute and not hf_model_id:
|
||||
raise ValueError("hf_model_id is required when distribute=True")
|
||||
|
||||
|
||||
accelerator = Accelerator(amp=True)
|
||||
|
||||
save_dir = Path(save_path)
|
||||
@@ -84,7 +87,15 @@ def train(
|
||||
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
|
||||
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
|
||||
|
||||
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
|
||||
# Auto-detect model architecture from config.json
|
||||
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
||||
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
||||
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
||||
if accelerator.rank == 0:
|
||||
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
||||
base_model = _model_cls.from_local(
|
||||
pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None
|
||||
)
|
||||
tokenizer = base_model.text_tokenizer
|
||||
|
||||
train_ds, val_ds = load_audio_text_datasets(
|
||||
@@ -166,7 +177,6 @@ def train(
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
unwrapped_model.train()
|
||||
|
||||
|
||||
# Only print param info on rank 0 to avoid cluttered output
|
||||
if accelerator.rank == 0:
|
||||
for name, param in model.named_parameters():
|
||||
@@ -191,7 +201,7 @@ def train(
|
||||
# All ranks load the same checkpoint to keep model and optimizer state in sync.
|
||||
start_step = load_checkpoint(model, optimizer, scheduler, save_dir, rank=accelerator.rank)
|
||||
accelerator.barrier()
|
||||
|
||||
|
||||
if start_step > 0 and accelerator.rank == 0:
|
||||
tracker.print(f"Resuming training from step {start_step}")
|
||||
|
||||
@@ -199,7 +209,19 @@ def train(
|
||||
resume = {"step": start_step}
|
||||
|
||||
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
|
||||
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
|
||||
def _signal_handler(
|
||||
signum,
|
||||
frame,
|
||||
_model=model,
|
||||
_optim=optimizer,
|
||||
_sched=scheduler,
|
||||
_save_dir=save_dir,
|
||||
_pretrained=pretrained_path,
|
||||
_hf_id=hf_model_id,
|
||||
_dist=distribute,
|
||||
_resume=resume,
|
||||
_rank=accelerator.rank,
|
||||
):
|
||||
try:
|
||||
cur_step = int(_resume.get("step", start_step))
|
||||
except Exception:
|
||||
@@ -229,8 +251,8 @@ def train(
|
||||
except StopIteration:
|
||||
data_epoch += 1
|
||||
# Key: set DistributedSampler epoch to ensure different data order each epoch
|
||||
sampler = getattr(train_loader, 'sampler', None)
|
||||
if hasattr(sampler, 'set_epoch'):
|
||||
sampler = getattr(train_loader, "sampler", None)
|
||||
if hasattr(sampler, "set_epoch"):
|
||||
sampler.set_epoch(data_epoch)
|
||||
train_iter = iter(train_loader)
|
||||
return next(train_iter)
|
||||
@@ -250,7 +272,7 @@ def train(
|
||||
|
||||
# Only sync gradients on the last micro-batch
|
||||
# Use no_sync() for intermediate steps to reduce communication overhead
|
||||
is_last_micro_step = (micro_step == grad_accum_steps - 1)
|
||||
is_last_micro_step = micro_step == grad_accum_steps - 1
|
||||
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
|
||||
|
||||
with sync_context:
|
||||
@@ -299,10 +321,22 @@ def train(
|
||||
tracker.log_metrics(loss_values, split="train")
|
||||
|
||||
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
|
||||
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
|
||||
sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
|
||||
valid_interval=valid_interval)
|
||||
validate(
|
||||
model,
|
||||
val_loader,
|
||||
batch_processor,
|
||||
accelerator,
|
||||
tracker,
|
||||
lambdas,
|
||||
writer=writer,
|
||||
step=step,
|
||||
val_ds=val_ds,
|
||||
audio_vae=audio_vae_for_gen,
|
||||
sample_rate=sample_rate,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
)
|
||||
|
||||
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
|
||||
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
|
||||
@@ -313,13 +347,26 @@ def train(
|
||||
writer.close()
|
||||
|
||||
|
||||
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
|
||||
val_texts=None, tokenizer=None, valid_interval=1000):
|
||||
def validate(
|
||||
model,
|
||||
val_loader,
|
||||
batch_processor,
|
||||
accelerator,
|
||||
tracker,
|
||||
lambdas,
|
||||
writer=None,
|
||||
step=0,
|
||||
val_ds=None,
|
||||
audio_vae=None,
|
||||
sample_rate=22050,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
valid_interval=1000,
|
||||
):
|
||||
"""Validate and generate sample audio"""
|
||||
import numpy as np
|
||||
import numpy as np # noqa: F401
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
model.eval()
|
||||
total_losses = []
|
||||
sub_losses = defaultdict(list) # Track individual sub-losses
|
||||
@@ -356,26 +403,37 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
# Compute mean total loss
|
||||
mean_total_loss = torch.stack(total_losses).mean()
|
||||
accelerator.all_reduce(mean_total_loss)
|
||||
|
||||
|
||||
# Compute mean of each sub-loss
|
||||
val_metrics = {"loss/total": mean_total_loss.item()}
|
||||
for key, values in sub_losses.items():
|
||||
mean_sub_loss = torch.stack(values).mean()
|
||||
accelerator.all_reduce(mean_sub_loss)
|
||||
val_metrics[key] = mean_sub_loss.item()
|
||||
|
||||
|
||||
tracker.log_metrics(val_metrics, split="val")
|
||||
|
||||
|
||||
# Generate sample audio for TensorBoard display
|
||||
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
|
||||
try:
|
||||
generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
|
||||
val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
|
||||
tracker=tracker)
|
||||
generate_sample_audio(
|
||||
model,
|
||||
val_ds,
|
||||
audio_vae,
|
||||
writer,
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
tracker=tracker,
|
||||
)
|
||||
except Exception as e:
|
||||
tracker.print(f"[Warning] Failed to generate sample audio: {e}")
|
||||
import traceback
|
||||
import io
|
||||
|
||||
buf = io.StringIO()
|
||||
traceback.print_exc(file=buf)
|
||||
tracker.print(buf.getvalue())
|
||||
@@ -390,7 +448,7 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||
missing.append("audio_vae")
|
||||
if missing and accelerator.rank == 0:
|
||||
tracker.print(f"[Warning] Skip audio generation: missing {', '.join(missing)}")
|
||||
|
||||
|
||||
model.train()
|
||||
|
||||
|
||||
@@ -398,6 +456,7 @@ def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
|
||||
"""Compute Mel Spectrogram (dB) using librosa"""
|
||||
import numpy as np
|
||||
import librosa
|
||||
|
||||
audio_np = audio_np.flatten().astype(np.float32)
|
||||
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
|
||||
return librosa.power_to_db(mel, ref=np.max)
|
||||
@@ -408,31 +467,45 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
||||
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
|
||||
"""
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa.display
|
||||
|
||||
|
||||
fmax = sample_rate // 2
|
||||
step_str = f" @ Step {step}" if step is not None else ""
|
||||
|
||||
|
||||
if ref_audio_np is not None and ref_mel is not None:
|
||||
# Comparison mode: reference vs generated
|
||||
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
|
||||
|
||||
img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
|
||||
ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
|
||||
plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
|
||||
|
||||
img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
|
||||
ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
|
||||
plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
|
||||
|
||||
img_ref = librosa.display.specshow(
|
||||
ref_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_ref
|
||||
)
|
||||
ax_ref.set_title(
|
||||
f"Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
color="#28A745",
|
||||
)
|
||||
plt.colorbar(img_ref, ax=ax_ref, format="%+2.0f dB", pad=0.02)
|
||||
|
||||
img_gen = librosa.display.specshow(
|
||||
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_gen
|
||||
)
|
||||
ax_gen.set_title(
|
||||
f"Generated - {len(gen_audio_np)/sample_rate:.2f}s", fontsize=10, fontweight="bold", color="#DC3545"
|
||||
)
|
||||
plt.colorbar(img_gen, ax=ax_gen, format="%+2.0f dB", pad=0.02)
|
||||
else:
|
||||
# Single figure mode: show generated only
|
||||
fig, ax = plt.subplots(figsize=(12, 4))
|
||||
img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
|
||||
ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
|
||||
plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
|
||||
|
||||
img = librosa.display.specshow(
|
||||
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax
|
||||
)
|
||||
ax.set_title(f"Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}", fontsize=11, fontweight="bold")
|
||||
plt.colorbar(img, ax=ax, format="%+2.0f dB", pad=0.02)
|
||||
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
|
||||
@@ -440,26 +513,38 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
||||
def normalize_audio(audio_np):
    """Peak-normalize audio so its largest absolute sample equals 0.9.

    Args:
        audio_np: NumPy array of audio samples.

    Returns:
        A scaled copy whose values lie in [-0.9, 0.9]. The input is returned
        unchanged when it is empty or when its peak is not strictly positive
        (all-zero signal, or a NaN peak, for which ``max_val > 0`` is False).
    """
    import numpy as np

    # A zero-size array has no maximum: ndarray.max() would raise ValueError,
    # so empty audio is passed through untouched instead of crashing.
    if np.size(audio_np) == 0:
        return audio_np

    max_val = np.abs(audio_np).max()
    # Guard against division by zero for silent input.
    if max_val > 0:
        return audio_np / max_val * 0.9
    return audio_np
|
||||
|
||||
|
||||
def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
|
||||
val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
|
||||
tracker=None):
|
||||
def generate_sample_audio(
|
||||
model,
|
||||
val_ds,
|
||||
audio_vae,
|
||||
writer,
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate=22050,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
pretrained_path=None,
|
||||
valid_interval=1000,
|
||||
tracker=None,
|
||||
):
|
||||
"""Select 2 fixed validation samples, generate audio and log to TensorBoard"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
log = tracker.print if tracker else print
|
||||
num_samples = min(2, len(val_ds))
|
||||
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
|
||||
|
||||
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
|
||||
|
||||
for i in range(num_samples):
|
||||
sample = val_ds[i]
|
||||
text = val_texts[i] if val_texts and i < len(val_texts) else "Hello, this is a test."
|
||||
|
||||
|
||||
# Load reference audio
|
||||
ref_audio_np = None
|
||||
try:
|
||||
@@ -468,7 +553,10 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
ref_sr = sample["audio"].get("sampling_rate", sample_rate)
|
||||
if ref_sr != sample_rate:
|
||||
import torchaudio.functional as F
|
||||
ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
|
||||
|
||||
ref_audio_np = (
|
||||
F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
|
||||
)
|
||||
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
|
||||
except Exception as e:
|
||||
log(f"[Warning] Failed to load reference audio: {e}")
|
||||
@@ -480,7 +568,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
unwrapped_model.eval()
|
||||
# unwrapped_model.to(torch.bfloat16)
|
||||
unwrapped_model.audio_vae = audio_vae.to(torch.float32)
|
||||
|
||||
|
||||
log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
|
||||
autocast_ctx = (
|
||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
@@ -490,27 +578,33 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
with torch.no_grad():
|
||||
with autocast_ctx:
|
||||
generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
|
||||
|
||||
|
||||
# Restore training setup
|
||||
# unwrapped_model.to(torch.float32)
|
||||
# unwrapped_model.audio_vae = None
|
||||
|
||||
|
||||
if generated is None or len(generated) == 0:
|
||||
log(f"[Warning] Generated audio is empty for sample {i}")
|
||||
continue
|
||||
|
||||
|
||||
# Process generated audio
|
||||
gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
|
||||
gen_audio_np = (
|
||||
generated.cpu().float().numpy().flatten()
|
||||
if isinstance(generated, torch.Tensor)
|
||||
else np.array(generated, dtype=np.float32).flatten()
|
||||
)
|
||||
gen_audio_np = normalize_audio(gen_audio_np)
|
||||
|
||||
|
||||
tag = f"val_sample_{i}"
|
||||
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
|
||||
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
|
||||
|
||||
|
||||
# Log reference audio
|
||||
if ref_audio_np is not None:
|
||||
writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
|
||||
|
||||
writer.add_audio(
|
||||
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
||||
)
|
||||
|
||||
# Generate mel spectrogram figure
|
||||
try:
|
||||
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
|
||||
@@ -520,10 +614,11 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
||||
log(f"[Audio] Created mel spectrogram figure for sample {i}")
|
||||
except Exception as e:
|
||||
log(f"[Warning] Failed to create mel spectrogram: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
log(f"[Warning] Failed to generate audio for sample {i}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
@@ -545,30 +640,29 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
Called by all ranks so that distributed state stays aligned.
|
||||
Returns the step number to resume from, or 0 if no checkpoint found.
|
||||
"""
|
||||
import json
|
||||
|
||||
latest_folder = save_dir / "latest"
|
||||
if not latest_folder.exists():
|
||||
return 0
|
||||
|
||||
|
||||
unwrapped = model.module if hasattr(model, "module") else model
|
||||
lora_cfg = unwrapped.lora_config
|
||||
|
||||
|
||||
# Load model weights
|
||||
if lora_cfg is not None:
|
||||
# LoRA: load lora_weights
|
||||
lora_weights_path = latest_folder / "lora_weights.safetensors"
|
||||
if not lora_weights_path.exists():
|
||||
lora_weights_path = latest_folder / "lora_weights.ckpt"
|
||||
|
||||
|
||||
if lora_weights_path.exists():
|
||||
if lora_weights_path.suffix == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(str(lora_weights_path))
|
||||
else:
|
||||
ckpt = torch.load(lora_weights_path, map_location="cpu")
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
|
||||
|
||||
unwrapped.load_state_dict(state_dict, strict=False)
|
||||
if rank == 0:
|
||||
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
|
||||
@@ -577,33 +671,34 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
model_path = latest_folder / "model.safetensors"
|
||||
if not model_path.exists():
|
||||
model_path = latest_folder / "pytorch_model.bin"
|
||||
|
||||
|
||||
if model_path.exists():
|
||||
if model_path.suffix == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(str(model_path))
|
||||
else:
|
||||
ckpt = torch.load(model_path, map_location="cpu")
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
|
||||
|
||||
unwrapped.load_state_dict(state_dict, strict=False)
|
||||
if rank == 0:
|
||||
print(f"Loaded model weights from {model_path}", file=sys.stderr)
|
||||
|
||||
|
||||
# Load optimizer state
|
||||
optimizer_path = latest_folder / "optimizer.pth"
|
||||
if optimizer_path.exists():
|
||||
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
|
||||
if rank == 0:
|
||||
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
|
||||
|
||||
|
||||
# Load scheduler state
|
||||
scheduler_path = latest_folder / "scheduler.pth"
|
||||
if scheduler_path.exists():
|
||||
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
|
||||
if rank == 0:
|
||||
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
|
||||
|
||||
|
||||
state_path = latest_folder / "training_state.json"
|
||||
if state_path.exists():
|
||||
with open(state_path, "r", encoding="utf-8") as f:
|
||||
@@ -621,28 +716,36 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
||||
if rank == 0:
|
||||
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||
return resume_step
|
||||
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
|
||||
def save_checkpoint(
|
||||
model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
save_dir: Path,
|
||||
step: int,
|
||||
pretrained_path: str = None,
|
||||
hf_model_id: str = "",
|
||||
distribute: bool = False,
|
||||
):
|
||||
"""
|
||||
Save checkpoint with different strategies for full finetune vs LoRA:
|
||||
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
||||
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
|
||||
"""
|
||||
import json
|
||||
import shutil
|
||||
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = f"step_{step:07d}"
|
||||
folder = save_dir / tag
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
unwrapped = model.module if hasattr(model, "module") else model
|
||||
full_state = unwrapped.state_dict()
|
||||
lora_cfg = unwrapped.lora_config
|
||||
|
||||
|
||||
if lora_cfg is not None:
|
||||
# LoRA finetune: save only lora_A/lora_B weights
|
||||
state_dict = {k: v for k, v in full_state.items() if "lora_" in k}
|
||||
@@ -650,7 +753,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
save_file(state_dict, folder / "lora_weights.safetensors")
|
||||
else:
|
||||
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
|
||||
|
||||
|
||||
# Save LoRA config and base model path to a separate JSON file
|
||||
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
|
||||
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
|
||||
@@ -667,16 +770,23 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
save_file(state_dict, folder / "model.safetensors")
|
||||
else:
|
||||
torch.save({"state_dict": state_dict}, folder / "pytorch_model.bin")
|
||||
|
||||
|
||||
# Copy config files from pretrained path
|
||||
if pretrained_path:
|
||||
pretrained_dir = Path(pretrained_path)
|
||||
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
|
||||
files_to_copy = [
|
||||
"config.json",
|
||||
"audiovae.pth",
|
||||
"audiovae.safetensors",
|
||||
"tokenizer.json",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer_config.json",
|
||||
]
|
||||
for fname in files_to_copy:
|
||||
src = pretrained_dir / fname
|
||||
if src.exists():
|
||||
shutil.copy2(src, folder / fname)
|
||||
|
||||
|
||||
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
||||
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
||||
with open(folder / "training_state.json", "w", encoding="utf-8") as f:
|
||||
|
||||
Reference in New Issue
Block a user