|
@@ -5,6 +5,7 @@ import torch.nn.functional as F
|
|
|
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
|
|
|
from cosyvoice.utils.losses import tpr_loss, mel_loss
|
|
|
|
|
|
+
|
|
|
class HiFiGan(nn.Module):
|
|
|
def __init__(self, generator, discriminator, mel_spec_transform,
|
|
|
multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
|
|
@@ -44,7 +45,9 @@ class HiFiGan(nn.Module):
|
|
|
else:
|
|
|
loss_tpr = torch.zeros(1).to(device)
|
|
|
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
|
|
- loss = loss_gen + self.feat_match_loss_weight * loss_fm + self.multi_mel_spectral_recon_loss_weight * loss_mel + self.tpr_loss_weight * loss_tpr + loss_f0
|
|
|
+ loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
|
|
|
+ self.multi_mel_spectral_recon_loss_weight * loss_mel + \
|
|
|
+ self.tpr_loss_weight * loss_tpr + loss_f0
|
|
|
return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
|
|
|
|
|
def forward_discriminator(self, batch, device):
|
|
@@ -63,4 +66,4 @@ class HiFiGan(nn.Module):
|
|
|
loss_tpr = torch.zeros(1).to(device)
|
|
|
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
|
|
loss = loss_disc + self.tpr_loss_weight * loss_tpr + loss_f0
|
|
|
- return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
|
|
+ return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|