Update train_voxcpm_finetune.py
修复 issue #185 中提到的问题:训练过程中执行 validate 时,会对原模型先执行 to(torch.bfloat16)、再执行 to(torch.float32),这种往返转换可能导致模型参数数值发生漂移。本次修改不再对模型做 dtype 转换,改为在 torch.autocast 上下文中进行推理,并在 finally 中恢复模型的 train/eval 状态,从而使 validate 步骤保留原模型的数值。
This commit is contained in:
@@ -477,20 +477,28 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
|||||||
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
|
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log(f"[Warning] Failed to load reference audio: {e}")
|
log(f"[Warning] Failed to load reference audio: {e}")
|
||||||
|
|
||||||
|
# 记录原模式,避免异常时状态不一致
|
||||||
|
prev_training = unwrapped_model.training
|
||||||
try:
|
try:
|
||||||
# Inference setup
|
# Inference setup
|
||||||
unwrapped_model.eval()
|
unwrapped_model.eval()
|
||||||
unwrapped_model.to(torch.bfloat16)
|
# unwrapped_model.to(torch.bfloat16)
|
||||||
unwrapped_model.audio_vae = audio_vae.to(torch.float32)
|
unwrapped_model.audio_vae = audio_vae.to(torch.float32)
|
||||||
|
|
||||||
log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
|
log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
|
||||||
|
autocast_ctx = (
|
||||||
|
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else contextlib.nullcontext()
|
||||||
|
)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
|
with autocast_ctx:
|
||||||
|
generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
|
||||||
|
|
||||||
# Restore training setup
|
# Restore training setup
|
||||||
unwrapped_model.to(torch.float32)
|
# unwrapped_model.to(torch.float32)
|
||||||
unwrapped_model.audio_vae = None
|
# unwrapped_model.audio_vae = None
|
||||||
|
|
||||||
if generated is None or len(generated) == 0:
|
if generated is None or len(generated) == 0:
|
||||||
log(f"[Warning] Generated audio is empty for sample {i}")
|
log(f"[Warning] Generated audio is empty for sample {i}")
|
||||||
@@ -523,6 +531,18 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
|||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore training setup(无论成功失败都恢复)
|
||||||
|
try:
|
||||||
|
# unwrapped_model.to(torch.float32)
|
||||||
|
unwrapped_model.audio_vae = None
|
||||||
|
if prev_training:
|
||||||
|
unwrapped_model.train()
|
||||||
|
else:
|
||||||
|
unwrapped_model.eval()
|
||||||
|
except Exception as e:
|
||||||
|
log(f"[Warning] Failed to restore model state: {e}")
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
||||||
"""
|
"""
|
||||||
@@ -671,4 +691,4 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
# Otherwise use command line args (parsed by argbind)
|
# Otherwise use command line args (parsed by argbind)
|
||||||
with argbind.scope(args):
|
with argbind.scope(args):
|
||||||
train()
|
train()
|
||||||
|
|||||||
Reference in New Issue
Block a user