# length_regulator.py
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Tuple
  15. import torch.nn as nn
  16. import torch
  17. from torch.nn import functional as F
  18. from cosyvoice.utils.mask import make_pad_mask
  19. class InterpolateRegulator(nn.Module):
  20. def __init__(
  21. self,
  22. channels: int,
  23. sampling_ratios: Tuple,
  24. out_channels: int = None,
  25. groups: int = 1,
  26. ):
  27. super().__init__()
  28. self.sampling_ratios = sampling_ratios
  29. out_channels = out_channels or channels
  30. model = nn.ModuleList([])
  31. if len(sampling_ratios) > 0:
  32. for _ in sampling_ratios:
  33. module = nn.Conv1d(channels, channels, 3, 1, 1)
  34. norm = nn.GroupNorm(groups, channels)
  35. act = nn.Mish()
  36. model.extend([module, norm, act])
  37. model.append(
  38. nn.Conv1d(channels, out_channels, 1, 1)
  39. )
  40. self.model = nn.Sequential(*model)
  41. def forward(self, x, ylens=None):
  42. # x in (B, T, D)
  43. mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
  44. x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
  45. out = self.model(x).transpose(1, 2).contiguous()
  46. olens = ylens
  47. return out * mask, olens
  48. def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
  49. # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
  50. # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
  51. # x in (B, T, D)
  52. if x2.shape[1] > 40:
  53. x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
  54. x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
  55. mode='linear')
  56. x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
  57. x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
  58. else:
  59. x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
  60. if x1.shape[1] != 0:
  61. x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
  62. x = torch.concat([x1, x2], dim=2)
  63. else:
  64. x = x2
  65. out = self.model(x).transpose(1, 2).contiguous()
  66. return out, mel_len1 + mel_len2