1
0

losses.py 607 B

1234567891011121314151617181920
  1. import torch
  2. import torch.nn.functional as F
  3. def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
  4. loss = 0
  5. for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
  6. m_DG = torch.median((dr - dg))
  7. L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
  8. loss += tau - F.relu(tau - L_rel)
  9. return loss
  10. def mel_loss(real_speech, generated_speech, mel_transforms):
  11. loss = 0
  12. for transform in mel_transforms:
  13. mel_r = transform(real_speech)
  14. mel_g = transform(generated_speech)
  15. loss += F.l1_loss(mel_g, mel_r)
  16. return loss