flow.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig

from cosyvoice.utils.mask import make_pad_mask
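
# Flow-matching mel-spectrogram predictors for CosyVoice. A hedged summary
# (inferred from the code below, not from upstream documentation): both classes
# map discrete speech tokens plus an x-vector speaker embedding to mel features
# through a conditional-flow-matching decoder; CausalMaskedDiffWithXvec
# additionally supports chunked/streaming operation via encoder/decoder caches.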


class MaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss
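
    # Expected batch layout for forward() (shapes inferred from the code, not
    # from upstream docs): 'speech_token' (B, T_token) long, 'speech_token_len'
    # (B,), 'speech_feat' (B, T_mel, output_size), 'speech_feat_len' (B,), and
    # 'embedding' (B, spk_embed_dim) x-vectors.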
    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # get conditions: for half the utterances, expose a random mel prefix
        # (up to 30% of the utterance) as a prompt; see the __main__ sketch at
        # the bottom of this file for a standalone illustration
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        # NOTE this line should be unnecessary: h has already been through the
        # length_regulator and should have the same shape as feat
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}
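
    # Note on streaming (inferred from the code below, not upstream docs):
    # flow_cache carries decoder state across successive inference() calls so
    # mel frames can be produced chunk by chunk; the prompt's mel frames are
    # regenerated each call and trimmed off via feat[:, :, mel_len1:].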
    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
        assert token.shape[0] == 1

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        # 22050 / 256 mirrors mel_feat_conf's sampling_rate / hop_size
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

        # get conditions: the prompt mel fills the first mel_len1 frames
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            cache=flow_cache
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), flow_cache
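

# The causal variant below drops the explicit length_regulator; judging from
# the upsample_* entries in its encoder cache, the encoder itself upsamples
# tokens to the mel frame rate (presumably token_mel_ratio mel frames per
# token) and maintains streaming state via the cache dict threaded through
# inference().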


class CausalMaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len
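
    # Note: pre_lookahead_len tokens of right-context are withheld during
    # non-final streaming chunks (see inference() below), and forward() draws a
    # random streaming flag per batch so a single model learns both chunked and
    # full-context decoding.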
    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training: randomly train in streaming mode
        # (static_chunk_size > 0) or non-streaming mode (static_chunk_size = 0)
        streaming = random.random() < 0.5

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        h = self.encoder_proj(h)

        # get conditions: same random mel-prefix prompting as in
        # MaskedDiffWithXvec.forward above
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}
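
    # Chunked-inference contract (inferred from the code below): the caller
    # threads `cache` (a dict with 'encoder_cache' and 'decoder_cache' entries)
    # across calls, passing finalize=False for intermediate chunks, where the
    # last pre_lookahead_len tokens are withheld as right-context, and
    # finalize=True for the final chunk.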
    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  cache,
                  finalize):
        assert token.shape[0] == 1

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        if finalize is True:
            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
        else:
            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
        cache['encoder_cache']['offset'] = encoder_cache[0]
        cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
        cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
        cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
        cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
        cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # get conditions: the prompt mel fills the first mel_len1 frames
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, cache['decoder_cache'] = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            cache=cache['decoder_cache']
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), cache
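

if __name__ == "__main__":
    # Minimal standalone sketch (illustration only, not part of the upstream
    # module): reproduces the random mel-prefix prompt construction used in
    # both forward() methods, with made-up shapes standing in for a real batch.
    B, T, D = 4, 100, 80
    feat = torch.randn(B, T, D)                 # fake target mels
    feat_len = torch.tensor([100, 96, 80, 64])  # fake valid lengths
    conds = torch.zeros(feat.shape)
    for i, j in enumerate(feat_len):
        if random.random() < 0.5:
            continue                            # half the rows get no prompt
        index = random.randint(0, int(0.3 * j))  # prefix of up to 30%
        conds[i, :index] = feat[i, :index]      # expose the prefix as prompt
    # rows whose condition stayed all-zero received an empty prompt
    print((conds.abs().sum(dim=(1, 2)) > 0).tolist())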