# losses.py

import torch
import torch.nn.functional as F
from typing import Tuple


def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
    """Tau-clipped relative discriminator loss.

    For each discriminator, penalize the squared deviation of the real-vs-generated
    score gap from its median, restricted to positions where the gap falls below the
    median, and cap each discriminator's contribution at tau.
    """
    loss = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        m_DG = torch.median(dr - dg)  # median score gap between real and generated
        # mean squared deviation from the median over positions below the median gap
        L_rel = torch.mean(((dr - dg - m_DG) ** 2)[dr < dg + m_DG])
        # tau - relu(tau - x) == min(x, tau): cap this term's contribution at tau
        loss += tau - F.relu(tau - L_rel)
    return loss


def mel_loss(real_speech, generated_speech, mel_transforms):
    """Sum of L1 mel-spectrogram losses over a set of mel transforms (multi-resolution)."""
    loss = 0
    for transform in mel_transforms:
        mel_r = transform(real_speech)
        mel_g = transform(generated_speech)
        loss += F.l1_loss(mel_g, mel_r)
    return loss
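

# `mel_transforms` is expected to be an iterable of callables mapping a waveform tensor to a
# mel spectrogram. The sketch below (not part of the original file) shows one way such a list
# could be built with torchaudio.transforms.MelSpectrogram at a couple of STFT resolutions;
# the helper name, sample rate, and FFT sizes are illustrative assumptions, not repo values.
def build_mel_transforms(sample_rate=22050, n_ffts=(1024, 2048), n_mels=80):
    import torchaudio  # assumed available; imported lazily so it stays an optional dependency

    return [
        torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=n_fft // 4,
            n_mels=n_mels,
        )
        for n_fft in n_ffts
    ]
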
class DPOLoss(torch.nn.Module):
    """
    Direct Preference Optimization (DPO) loss.

    Supports the IPO variant (https://arxiv.org/pdf/2310.12036v2.pdf) and
    conservative DPO via label smoothing (https://ericmitchell.ai/cdpo.pdf).
    """

    def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # log-ratio of chosen vs. rejected under the policy and under the reference model
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios
        if self.ipo:
            # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
            losses = (logits - 1 / (2 * self.beta)) ** 2
        else:
            # Eq. 3 of https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 recovers the
            # original DPO loss (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        loss = losses.mean()
        # implicit rewards, detached so they serve only as logging/metrics signals
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return loss, chosen_rewards, rejected_rewards
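

# Minimal smoke test (a sketch, not part of the original file): exercises tpr_loss and DPOLoss
# with random dummy inputs just to check that both run and return scalars. The batch size,
# tau, and beta values here are illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)

    # tpr_loss expects lists of per-discriminator score tensors for real and generated audio
    disc_real = [torch.randn(4, 100), torch.randn(4, 50)]
    disc_generated = [torch.randn(4, 100), torch.randn(4, 50)]
    print(f"tpr loss: {tpr_loss(disc_real, disc_generated, tau=0.04).item():.4f}")

    # DPOLoss expects per-sequence (summed) log-probabilities for chosen/rejected samples
    batch = 4
    policy_chosen = -torch.rand(batch) * 10
    policy_rejected = -torch.rand(batch) * 10
    reference_chosen = -torch.rand(batch) * 10
    reference_rejected = -torch.rand(batch) * 10

    dpo = DPOLoss(beta=0.1)
    loss, chosen_rewards, rejected_rewards = dpo(
        policy_chosen, policy_rejected, reference_chosen, reference_rejected
    )
    print(f"dpo loss: {loss.item():.4f}")
    print(f"chosen reward mean: {chosen_rewards.mean().item():.4f}, "
          f"rejected reward mean: {rejected_rewards.mean().item():.4f}")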