|
|
@@ -41,7 +41,7 @@ class HiFiGan(nn.Module):
|
|
|
loss_fm = feature_loss(fmap_rs, fmap_gs)
|
|
|
loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
|
|
|
if self.tpr_loss_weight != 0:
|
|
|
- loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
|
|
|
+ loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
|
|
|
else:
|
|
|
loss_tpr = torch.zeros(1).to(device)
|
|
|
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|