Merge pull request #186 from symhsym/patch-1

Update train_voxcpm_finetune.py
This commit is contained in:
xliucs
2026-02-11 18:05:39 +08:00
committed by GitHub
+26 -6
View File
@@ -477,20 +477,28 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s") log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
except Exception as e: except Exception as e:
log(f"[Warning] Failed to load reference audio: {e}") log(f"[Warning] Failed to load reference audio: {e}")
# Remember the original train/eval mode so the model state stays consistent if an exception occurs
prev_training = unwrapped_model.training
try: try:
# Inference setup # Inference setup
unwrapped_model.eval() unwrapped_model.eval()
unwrapped_model.to(torch.bfloat16) # unwrapped_model.to(torch.bfloat16)
unwrapped_model.audio_vae = audio_vae.to(torch.float32) unwrapped_model.audio_vae = audio_vae.to(torch.float32)
log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'") log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
autocast_ctx = (
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
if torch.cuda.is_available()
else contextlib.nullcontext()
)
with torch.no_grad(): with torch.no_grad():
generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0) with autocast_ctx:
generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
# Restore training setup # Restore training setup
unwrapped_model.to(torch.float32) # unwrapped_model.to(torch.float32)
unwrapped_model.audio_vae = None # unwrapped_model.audio_vae = None
if generated is None or len(generated) == 0: if generated is None or len(generated) == 0:
log(f"[Warning] Generated audio is empty for sample {i}") log(f"[Warning] Generated audio is empty for sample {i}")
@@ -523,6 +531,18 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
import traceback import traceback
traceback.print_exc() traceback.print_exc()
finally:
# Restore training setup (always restored, whether generation succeeded or failed)
try:
# unwrapped_model.to(torch.float32)
unwrapped_model.audio_vae = None
if prev_training:
unwrapped_model.train()
else:
unwrapped_model.eval()
except Exception as e:
log(f"[Warning] Failed to restore model state: {e}")
def load_checkpoint(model, optimizer, scheduler, save_dir: Path): def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
""" """
@@ -671,4 +691,4 @@ if __name__ == "__main__":
else: else:
# Otherwise use command line args (parsed by argbind) # Otherwise use command line args (parsed by argbind)
with argbind.scope(args): with argbind.scope(args):
train() train()