label_smoothing_loss.py

# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Label smoothing module."""

import torch
from torch import nn

class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    In a standard CE loss, the target label distribution is one-hot:
    [0,1,2] ->
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]

    In the label-smoothed CE loss, some probability mass is taken
    from the true label (1.0) and distributed evenly among the
    other labels.

    e.g. smoothing=0.1
    [0,1,2] ->
    [
        [0.9, 0.05, 0.05],
        [0.05, 0.9, 0.05],
        [0.05, 0.05, 0.9],
    ]

    Args:
        size (int): the number of classes
        padding_idx (int): padding class id which will be ignored in the loss
        smoothing (float): smoothing rate (0.0 means conventional CE)
        normalize_length (bool):
            normalize the loss by sequence length if True,
            by batch size if False
    """
    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute loss between x and target.

        The model output and label tensors are flattened to
        (batch*seqlen, class) shape, and a mask is applied to the
        padding positions, which are excluded from the loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor):
                target signal masked with self.padding_idx (batch, seqlen)
        Returns:
            loss (torch.Tensor): the KL divergence loss, a scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        # use zeros_like instead of torch.no_grad() for true_dist,
        # since no_grad() cannot be exported by JIT
        true_dist = torch.zeros_like(x)
        # spread the smoothing mass evenly over the non-target classes
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B,)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        # place the remaining confidence mass on the target class of each row
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
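

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream file): the tensor shapes,
# vocabulary size, and padding id below are illustrative assumptions, chosen
# only to show how the loss is called on (batch, seqlen, class) logits with
# padded target positions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seqlen, vocab = 2, 4, 10  # assumed toy dimensions
    pad_id = -1                      # assumed padding class id

    criterion = LabelSmoothingLoss(size=vocab,
                                   padding_idx=pad_id,
                                   smoothing=0.1,
                                   normalize_length=True)

    logits = torch.randn(batch, seqlen, vocab)         # (batch, seqlen, class)
    target = torch.randint(0, vocab, (batch, seqlen))  # (batch, seqlen)
    target[1, -2:] = pad_id                            # mark padded positions

    loss = criterion(logits, target)
    # with normalize_length=True the sum is divided by the number of
    # non-padded tokens; with False it is divided by the batch size
    print(loss.item())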