fix: ft log and setting

This commit is contained in:
Labmem-Zhouyx
2026-04-08 18:15:17 +08:00
parent ee3649c1b3
commit 68af4fe502
5 changed files with 63 additions and 22 deletions
+1 -1
View File
@@ -225,7 +225,7 @@ class UnifiedCFM(torch.nn.Module):
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
if tgt_mask is not None:
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
loss = (weights * losses).sum() / torch.sum(tgt_mask)
loss = (weights * losses).sum() / torch.clamp(torch.sum(tgt_mask), min=1.0)
else:
loss = losses.mean()