# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F

from cosyvoice.utils.mask import make_pad_mask

  19. class InterpolateRegulator(nn.Module):
  20. def __init__(
  21. self,
  22. channels: int,
  23. sampling_ratios: Tuple,
  24. out_channels: int = None,
  25. groups: int = 1,
  26. ):
  27. super().__init__()
  28. self.sampling_ratios = sampling_ratios
  29. out_channels = out_channels or channels
  30. model = nn.ModuleList([])
  31. if len(sampling_ratios) > 0:
  32. for _ in sampling_ratios:
  33. module = nn.Conv1d(channels, channels, 3, 1, 1)
  34. norm = nn.GroupNorm(groups, channels)
  35. act = nn.Mish()
  36. model.extend([module, norm, act])
  37. model.append(
  38. nn.Conv1d(channels, out_channels, 1, 1)
  39. )
  40. self.model = nn.Sequential(*model)
  41. def forward(self, x, ylens=None):
  42. # x in (B, T, D)
  43. mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
  44. x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
  45. out = self.model(x).transpose(1, 2).contiguous()
  46. olens = ylens
  47. return out * mask, olens
  48. def inference(self, x1, x2, mel_len1, mel_len2):
  49. # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
  50. # x in (B, T, D)
  51. if x2.shape[1] > 40:
  52. x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
  53. x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
  54. x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
  55. x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
  56. else:
  57. x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
  58. if x1.shape[1] != 0:
  59. x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
  60. x = torch.concat([x1, x2], dim=2)
  61. else:
  62. x = x2
  63. out = self.model(x).transpose(1, 2).contiguous()
  64. return out, mel_len1 + mel_len2