flow.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from typing import Dict, Optional
  16. import torch
  17. import torch.nn as nn
  18. from torch.nn import functional as F
  19. from omegaconf import DictConfig
  20. from cosyvoice.utils.mask import make_pad_mask
  21. class MaskedDiffWithXvec(torch.nn.Module):
  22. def __init__(self,
  23. input_size: int = 512,
  24. output_size: int = 80,
  25. spk_embed_dim: int = 192,
  26. output_type: str = "mel",
  27. vocab_size: int = 4096,
  28. input_frame_rate: int = 50,
  29. only_mask_loss: bool = True,
  30. encoder: torch.nn.Module = None,
  31. length_regulator: torch.nn.Module = None,
  32. decoder: torch.nn.Module = None,
  33. decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
  34. mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
  35. super().__init__()
  36. self.input_size = input_size
  37. self.output_size = output_size
  38. self.decoder_conf = decoder_conf
  39. self.mel_feat_conf = mel_feat_conf
  40. self.vocab_size = vocab_size
  41. self.output_type = output_type
  42. self.input_frame_rate = input_frame_rate
  43. logging.info(f"input frame rate={self.input_frame_rate}")
  44. self.input_embedding = nn.Embedding(vocab_size, input_size)
  45. self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
  46. self.encoder = encoder
  47. self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
  48. self.decoder = decoder
  49. self.length_regulator = length_regulator
  50. self.only_mask_loss = only_mask_loss
  51. def forward(
  52. self,
  53. batch: dict,
  54. device: torch.device,
  55. ) -> Dict[str, Optional[torch.Tensor]]:
  56. token = batch['speech_token'].to(device)
  57. token_len = batch['speech_token_len'].to(device)
  58. feat = batch['speech_feat'].to(device)
  59. feat_len = batch['speech_feat_len'].to(device)
  60. embedding = batch['embedding'].to(device)
  61. # xvec projection
  62. embedding = F.normalize(embedding, dim=1)
  63. embedding = self.spk_embed_affine_layer(embedding)
  64. # concat text and prompt_text
  65. mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
  66. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  67. # text encode
  68. h, h_lengths = self.encoder(token, token_len)
  69. h = self.encoder_proj(h)
  70. h, h_lengths = self.length_regulator(h, feat_len)
  71. # get conditions
  72. conds = torch.zeros(feat.shape, device=token.device)
  73. conds = conds.transpose(1, 2)
  74. mask = (~make_pad_mask(feat_len)).to(h)
  75. feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
  76. loss, _ = self.decoder.compute_loss(
  77. feat.transpose(1, 2).contiguous(),
  78. mask.unsqueeze(1),
  79. h.transpose(1, 2).contiguous(),
  80. embedding,
  81. cond=conds
  82. )
  83. return {'loss': loss}
  84. @torch.inference_mode()
  85. def inference(self,
  86. token,
  87. token_len,
  88. prompt_token,
  89. prompt_token_len,
  90. prompt_feat,
  91. prompt_feat_len,
  92. embedding):
  93. assert token.shape[0] == 1
  94. # xvec projection
  95. embedding = F.normalize(embedding, dim=1)
  96. embedding = self.spk_embed_affine_layer(embedding)
  97. # concat text and prompt_text
  98. token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
  99. mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
  100. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  101. # text encode
  102. h, h_lengths = self.encoder(token, token_len)
  103. h = self.encoder_proj(h)
  104. feat_len = (token_len / 50 * 22050 / 256).int()
  105. h, h_lengths = self.length_regulator(h, feat_len)
  106. # get conditions
  107. conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
  108. if prompt_feat.shape[1] != 0:
  109. for i, j in enumerate(prompt_feat_len):
  110. conds[i, :j] = prompt_feat[i]
  111. conds = conds.transpose(1, 2)
  112. mask = (~make_pad_mask(feat_len)).to(h)
  113. feat = self.decoder(
  114. mu=h.transpose(1, 2).contiguous(),
  115. mask=mask.unsqueeze(1),
  116. spks=embedding,
  117. cond=conds,
  118. n_timesteps=10
  119. )
  120. if prompt_feat.shape[1] != 0:
  121. feat = feat[:, :, prompt_feat.shape[1]:]
  122. return feat