flow.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import random
  16. from typing import Dict, Optional
  17. import torch
  18. import torch.nn as nn
  19. from torch.nn import functional as F
  20. from omegaconf import DictConfig
  21. from cosyvoice.utils.mask import make_pad_mask
  22. class MaskedDiffWithXvec(torch.nn.Module):
  23. def __init__(self,
  24. input_size: int = 512,
  25. output_size: int = 80,
  26. spk_embed_dim: int = 192,
  27. output_type: str = "mel",
  28. vocab_size: int = 4096,
  29. input_frame_rate: int = 50,
  30. only_mask_loss: bool = True,
  31. encoder: torch.nn.Module = None,
  32. length_regulator: torch.nn.Module = None,
  33. decoder: torch.nn.Module = None,
  34. decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
  35. 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
  36. 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
  37. 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
  38. 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
  39. super().__init__()
  40. self.input_size = input_size
  41. self.output_size = output_size
  42. self.decoder_conf = decoder_conf
  43. self.vocab_size = vocab_size
  44. self.output_type = output_type
  45. self.input_frame_rate = input_frame_rate
  46. logging.info(f"input frame rate={self.input_frame_rate}")
  47. self.input_embedding = nn.Embedding(vocab_size, input_size)
  48. self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
  49. self.encoder = encoder
  50. self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
  51. self.decoder = decoder
  52. self.length_regulator = length_regulator
  53. self.only_mask_loss = only_mask_loss
  54. def forward(
  55. self,
  56. batch: dict,
  57. device: torch.device,
  58. ) -> Dict[str, Optional[torch.Tensor]]:
  59. token = batch['speech_token'].to(device)
  60. token_len = batch['speech_token_len'].to(device)
  61. feat = batch['speech_feat'].to(device)
  62. feat_len = batch['speech_feat_len'].to(device)
  63. embedding = batch['embedding'].to(device)
  64. # xvec projection
  65. embedding = F.normalize(embedding, dim=1)
  66. embedding = self.spk_embed_affine_layer(embedding)
  67. # concat text and prompt_text
  68. mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
  69. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  70. # text encode
  71. h, h_lengths = self.encoder(token, token_len)
  72. h = self.encoder_proj(h)
  73. h, h_lengths = self.length_regulator(h, feat_len)
  74. # get conditions
  75. conds = torch.zeros(feat.shape, device=token.device)
  76. for i, j in enumerate(feat_len):
  77. if random.random() < 0.5:
  78. continue
  79. index = random.randint(0, int(0.3 * j))
  80. conds[i, :index] = feat[i, :index]
  81. conds = conds.transpose(1, 2)
  82. mask = (~make_pad_mask(feat_len)).to(h)
  83. # NOTE this is unnecessary, feat/h already same shape
  84. loss, _ = self.decoder.compute_loss(
  85. feat.transpose(1, 2).contiguous(),
  86. mask.unsqueeze(1),
  87. h.transpose(1, 2).contiguous(),
  88. embedding,
  89. cond=conds
  90. )
  91. return {'loss': loss}
  92. @torch.inference_mode()
  93. def inference(self,
  94. token,
  95. token_len,
  96. prompt_token,
  97. prompt_token_len,
  98. prompt_feat,
  99. prompt_feat_len,
  100. embedding,
  101. flow_cache):
  102. assert token.shape[0] == 1
  103. # xvec projection
  104. embedding = F.normalize(embedding, dim=1)
  105. embedding = self.spk_embed_affine_layer(embedding)
  106. # concat speech token and prompt speech token
  107. token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
  108. token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
  109. mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
  110. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  111. # text encode
  112. h, h_lengths = self.encoder(token, token_len)
  113. h = self.encoder_proj(h)
  114. mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
  115. h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
  116. # get conditions
  117. conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
  118. conds[:, :mel_len1] = prompt_feat
  119. conds = conds.transpose(1, 2)
  120. mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
  121. feat, flow_cache = self.decoder(
  122. mu=h.transpose(1, 2).contiguous(),
  123. mask=mask.unsqueeze(1),
  124. spks=embedding,
  125. cond=conds,
  126. n_timesteps=10,
  127. prompt_len=mel_len1,
  128. cache=flow_cache
  129. )
  130. feat = feat[:, :, mel_len1:]
  131. assert feat.shape[2] == mel_len2
  132. return feat.float(), flow_cache
  133. class CausalMaskedDiffWithXvec(torch.nn.Module):
  134. def __init__(self,
  135. input_size: int = 512,
  136. output_size: int = 80,
  137. spk_embed_dim: int = 192,
  138. output_type: str = "mel",
  139. vocab_size: int = 4096,
  140. input_frame_rate: int = 50,
  141. only_mask_loss: bool = True,
  142. token_mel_ratio: int = 2,
  143. pre_lookahead_len: int = 3,
  144. encoder: torch.nn.Module = None,
  145. decoder: torch.nn.Module = None,
  146. decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
  147. 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
  148. 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
  149. 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
  150. 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
  151. super().__init__()
  152. self.input_size = input_size
  153. self.output_size = output_size
  154. self.decoder_conf = decoder_conf
  155. self.vocab_size = vocab_size
  156. self.output_type = output_type
  157. self.input_frame_rate = input_frame_rate
  158. logging.info(f"input frame rate={self.input_frame_rate}")
  159. self.input_embedding = nn.Embedding(vocab_size, input_size)
  160. self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
  161. self.encoder = encoder
  162. self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
  163. self.decoder = decoder
  164. self.only_mask_loss = only_mask_loss
  165. self.token_mel_ratio = token_mel_ratio
  166. self.pre_lookahead_len = pre_lookahead_len
  167. def forward(
  168. self,
  169. batch: dict,
  170. device: torch.device,
  171. ) -> Dict[str, Optional[torch.Tensor]]:
  172. token = batch['speech_token'].to(device)
  173. token_len = batch['speech_token_len'].to(device)
  174. feat = batch['speech_feat'].to(device)
  175. feat_len = batch['speech_feat_len'].to(device)
  176. embedding = batch['embedding'].to(device)
  177. # NOTE unified training, static_chunk_size > 0 or = 0
  178. streaming = True if random.random() < 0.5 else False
  179. # xvec projection
  180. embedding = F.normalize(embedding, dim=1)
  181. embedding = self.spk_embed_affine_layer(embedding)
  182. # concat text and prompt_text
  183. mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
  184. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  185. # text encode
  186. h, h_lengths = self.encoder(token, token_len, streaming=streaming)
  187. h = self.encoder_proj(h)
  188. # get conditions
  189. conds = torch.zeros(feat.shape, device=token.device)
  190. for i, j in enumerate(feat_len):
  191. if random.random() < 0.5:
  192. continue
  193. index = random.randint(0, int(0.3 * j))
  194. conds[i, :index] = feat[i, :index]
  195. conds = conds.transpose(1, 2)
  196. mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
  197. loss, _ = self.decoder.compute_loss(
  198. feat.transpose(1, 2).contiguous(),
  199. mask.unsqueeze(1),
  200. h.transpose(1, 2).contiguous(),
  201. embedding,
  202. cond=conds,
  203. streaming=streaming,
  204. )
  205. return {'loss': loss}
  206. @torch.inference_mode()
  207. def inference(self,
  208. token,
  209. token_len,
  210. prompt_token,
  211. prompt_token_len,
  212. prompt_feat,
  213. prompt_feat_len,
  214. embedding,
  215. streaming,
  216. finalize):
  217. assert token.shape[0] == 1
  218. # xvec projection
  219. embedding = F.normalize(embedding, dim=1)
  220. embedding = self.spk_embed_affine_layer(embedding)
  221. # concat text and prompt_text
  222. token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
  223. mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
  224. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  225. # text encode
  226. if finalize is True:
  227. h, h_lengths = self.encoder(token, token_len, streaming=streaming)
  228. else:
  229. token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
  230. h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
  231. mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
  232. h = self.encoder_proj(h)
  233. # get conditions
  234. conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
  235. conds[:, :mel_len1] = prompt_feat
  236. conds = conds.transpose(1, 2)
  237. mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
  238. feat, _ = self.decoder(
  239. mu=h.transpose(1, 2).contiguous(),
  240. mask=mask.unsqueeze(1),
  241. spks=embedding,
  242. cond=conds,
  243. n_timesteps=10,
  244. streaming=streaming
  245. )
  246. feat = feat[:, :, mel_len1:]
  247. assert feat.shape[2] == mel_len2
  248. return feat.float(), None
  249. class CausalMaskedDiffWithDiT(torch.nn.Module):
  250. def __init__(self,
  251. input_size: int = 512,
  252. output_size: int = 80,
  253. spk_embed_dim: int = 192,
  254. output_type: str = "mel",
  255. vocab_size: int = 4096,
  256. input_frame_rate: int = 50,
  257. only_mask_loss: bool = True,
  258. token_mel_ratio: int = 2,
  259. pre_lookahead_len: int = 3,
  260. pre_lookahead_layer: torch.nn.Module = None,
  261. decoder: torch.nn.Module = None,
  262. decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
  263. 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
  264. 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
  265. 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
  266. 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
  267. super().__init__()
  268. self.input_size = input_size
  269. self.output_size = output_size
  270. self.decoder_conf = decoder_conf
  271. self.vocab_size = vocab_size
  272. self.output_type = output_type
  273. self.input_frame_rate = input_frame_rate
  274. logging.info(f"input frame rate={self.input_frame_rate}")
  275. self.input_embedding = nn.Embedding(vocab_size, input_size)
  276. self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
  277. self.pre_lookahead_len = pre_lookahead_len
  278. self.pre_lookahead_layer = pre_lookahead_layer
  279. self.decoder = decoder
  280. self.only_mask_loss = only_mask_loss
  281. self.token_mel_ratio = token_mel_ratio
  282. def forward(
  283. self,
  284. batch: dict,
  285. device: torch.device,
  286. ) -> Dict[str, Optional[torch.Tensor]]:
  287. token = batch['speech_token'].to(device)
  288. token_len = batch['speech_token_len'].to(device)
  289. feat = batch['speech_feat'].to(device)
  290. feat_len = batch['speech_feat_len'].to(device)
  291. embedding = batch['embedding'].to(device)
  292. # NOTE unified training, static_chunk_size > 0 or = 0
  293. streaming = True if random.random() < 0.5 else False
  294. # xvec projection
  295. embedding = F.normalize(embedding, dim=1)
  296. embedding = self.spk_embed_affine_layer(embedding)
  297. # concat text and prompt_text
  298. mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
  299. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  300. # text encode
  301. h = self.pre_lookahead_layer(token)
  302. h = h.repeat_interleave(self.token_mel_ratio, dim=1)
  303. mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
  304. # get conditions
  305. conds = torch.zeros(feat.shape, device=token.device)
  306. for i, j in enumerate(feat_len):
  307. if random.random() < 0.5:
  308. continue
  309. index = random.randint(0, int(0.3 * j))
  310. conds[i, :index] = feat[i, :index]
  311. conds = conds.transpose(1, 2)
  312. loss, _ = self.decoder.compute_loss(
  313. feat.transpose(1, 2).contiguous(),
  314. mask.unsqueeze(1),
  315. h.transpose(1, 2).contiguous(),
  316. embedding,
  317. cond=conds,
  318. streaming=streaming,
  319. )
  320. return {'loss': loss}
  321. @torch.inference_mode()
  322. def inference(self,
  323. token,
  324. token_len,
  325. prompt_token,
  326. prompt_token_len,
  327. prompt_feat,
  328. prompt_feat_len,
  329. embedding,
  330. streaming,
  331. finalize):
  332. assert token.shape[0] == 1
  333. # xvec projection
  334. embedding = F.normalize(embedding, dim=1)
  335. embedding = self.spk_embed_affine_layer(embedding)
  336. # concat text and prompt_text
  337. token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
  338. mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
  339. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  340. # text encode
  341. if finalize is True:
  342. h = self.pre_lookahead_layer(token)
  343. else:
  344. h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
  345. h = h.repeat_interleave(self.token_mel_ratio, dim=1)
  346. mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
  347. # get conditions
  348. conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
  349. conds[:, :mel_len1] = prompt_feat
  350. conds = conds.transpose(1, 2)
  351. mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
  352. feat, _ = self.decoder(
  353. mu=h.transpose(1, 2).contiguous(),
  354. mask=mask.unsqueeze(1),
  355. spks=embedding,
  356. cond=conds,
  357. n_timesteps=10,
  358. streaming=streaming
  359. )
  360. feat = feat[:, :, mel_len1:]
  361. assert feat.shape[2] == mel_len2
  362. return feat.float(), None
  363. if __name__ == '__main__':
  364. torch.backends.cudnn.deterministic = True
  365. torch.backends.cudnn.benchmark = False
  366. from hyperpyyaml import load_hyperpyyaml
  367. with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
  368. configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
  369. model = configs['flow']
  370. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  371. model.to(device)
  372. model.eval()
  373. max_len = 10 * model.decoder.estimator.static_chunk_size
  374. chunk_size = model.decoder.estimator.static_chunk_size
  375. context_size = model.pre_lookahead_layer.pre_lookahead_len
  376. token = torch.randint(0, 6561, size=(1, max_len)).to(device)
  377. token_len = torch.tensor([max_len]).to(device)
  378. prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
  379. prompt_token_len = torch.tensor([chunk_size]).to(device)
  380. prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
  381. prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
  382. prompt_embedding = torch.rand(1, 192).to(device)
  383. pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
  384. for i in range(0, max_len, chunk_size):
  385. finalize = True if i + chunk_size + context_size >= max_len else False
  386. pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
  387. prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
  388. pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
  389. print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())