# hifigan.py
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
from cosyvoice.utils.losses import tpr_loss, mel_loss
  7. class HiFiGan(nn.Module):
  8. def __init__(self, generator, discriminator, mel_spec_transform,
  9. multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
  10. tpr_loss_weight=1.0, tpr_loss_tau=0.04):
  11. super(HiFiGan, self).__init__()
  12. self.generator = generator
  13. self.discriminator = discriminator
  14. self.mel_spec_transform = mel_spec_transform
  15. self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
  16. self.feat_match_loss_weight = feat_match_loss_weight
  17. self.tpr_loss_weight = tpr_loss_weight
  18. self.tpr_loss_tau = tpr_loss_tau
  19. def forward(
  20. self,
  21. batch: dict,
  22. device: torch.device,
  23. ) -> Dict[str, Optional[torch.Tensor]]:
  24. if batch['turn'] == 'generator':
  25. return self.forward_generator(batch, device)
  26. else:
  27. return self.forward_discriminator(batch, device)
  28. def forward_generator(self, batch, device):
  29. real_speech = batch['speech'].to(device)
  30. pitch_feat = batch['pitch_feat'].to(device)
  31. # 1. calculate generator outputs
  32. generated_speech, generated_f0 = self.generator(batch, device)
  33. # 2. calculate discriminator outputs
  34. y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
  35. # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
  36. loss_gen, _ = generator_loss(y_d_gs)
  37. loss_fm = feature_loss(fmap_rs, fmap_gs)
  38. loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
  39. if self.tpr_loss_weight != 0:
  40. loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
  41. else:
  42. loss_tpr = torch.zeros(1).to(device)
  43. loss_f0 = F.l1_loss(generated_f0, pitch_feat)
  44. loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
  45. self.multi_mel_spectral_recon_loss_weight * loss_mel + \
  46. self.tpr_loss_weight * loss_tpr + loss_f0
  47. return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
  48. def forward_discriminator(self, batch, device):
  49. real_speech = batch['speech'].to(device)
  50. # 1. calculate generator outputs
  51. with torch.no_grad():
  52. generated_speech, generated_f0 = self.generator(batch, device)
  53. # 2. calculate discriminator outputs
  54. y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
  55. # 3. calculate discriminator losses, tpr losses [Optional]
  56. loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
  57. if self.tpr_loss_weight != 0:
  58. loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
  59. else:
  60. loss_tpr = torch.zeros(1).to(device)
  61. loss = loss_disc + self.tpr_loss_weight * loss_tpr
  62. return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}