fix: ft log and setting
This commit is contained in:
@@ -225,7 +225,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
|
||||
if tgt_mask is not None:
|
||||
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
|
||||
loss = (weights * losses).sum() / torch.sum(tgt_mask)
|
||||
loss = (weights * losses).sum() / torch.clamp(torch.sum(tgt_mask), min=1.0)
|
||||
else:
|
||||
loss = losses.mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user