length_regulator.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Tuple
  15. import torch.nn as nn
  16. from torch.nn import functional as F
  17. from cosyvoice.utils.mask import make_pad_mask
  18. class InterpolateRegulator(nn.Module):
  19. def __init__(
  20. self,
  21. channels: int,
  22. sampling_ratios: Tuple,
  23. out_channels: int = None,
  24. groups: int = 1,
  25. ):
  26. super().__init__()
  27. self.sampling_ratios = sampling_ratios
  28. out_channels = out_channels or channels
  29. model = nn.ModuleList([])
  30. if len(sampling_ratios) > 0:
  31. for _ in sampling_ratios:
  32. module = nn.Conv1d(channels, channels, 3, 1, 1)
  33. norm = nn.GroupNorm(groups, channels)
  34. act = nn.Mish()
  35. model.extend([module, norm, act])
  36. model.append(
  37. nn.Conv1d(channels, out_channels, 1, 1)
  38. )
  39. self.model = nn.Sequential(*model)
  40. def forward(self, x, ylens=None):
  41. # x in (B, T, D)
  42. mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
  43. x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
  44. out = self.model(x).transpose(1, 2).contiguous()
  45. olens = ylens
  46. return out * mask, olens