flow.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import random
  16. from typing import Dict, Optional
  17. import torch
  18. import torch.nn as nn
  19. from torch.nn import functional as F
  20. from omegaconf import DictConfig
  21. from cosyvoice.utils.mask import make_pad_mask
  22. from cosyvoice.utils.onnx import SpeechTokenExtractor
  23. class MaskedDiffWithXvec(torch.nn.Module):
  24. def __init__(self,
  25. input_size: int = 512,
  26. output_size: int = 80,
  27. spk_embed_dim: int = 192,
  28. output_type: str = "mel",
  29. vocab_size: int = 4096,
  30. input_frame_rate: int = 50,
  31. only_mask_loss: bool = True,
  32. encoder: torch.nn.Module = None,
  33. length_regulator: torch.nn.Module = None,
  34. decoder: torch.nn.Module = None,
  35. decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
  36. 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
  37. 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
  38. 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
  39. 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
  40. super().__init__()
  41. self.input_size = input_size
  42. self.output_size = output_size
  43. self.decoder_conf = decoder_conf
  44. self.vocab_size = vocab_size
  45. self.output_type = output_type
  46. self.input_frame_rate = input_frame_rate
  47. logging.info(f"input frame rate={self.input_frame_rate}")
  48. self.input_embedding = nn.Embedding(vocab_size, input_size)
  49. self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
  50. self.encoder = encoder
  51. self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
  52. self.decoder = decoder
  53. self.length_regulator = length_regulator
  54. self.only_mask_loss = only_mask_loss
  55. def forward(
  56. self,
  57. batch: dict,
  58. device: torch.device,
  59. ) -> Dict[str, Optional[torch.Tensor]]:
  60. token = batch['speech_token'].to(device)
  61. token_len = batch['speech_token_len'].to(device)
  62. feat = batch['speech_feat'].to(device)
  63. feat_len = batch['speech_feat_len'].to(device)
  64. embedding = batch['embedding'].to(device)
  65. # xvec projection
  66. embedding = F.normalize(embedding, dim=1)
  67. embedding = self.spk_embed_affine_layer(embedding)
  68. # concat text and prompt_text
  69. mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
  70. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  71. # text encode
  72. h, h_lengths = self.encoder(token, token_len)
  73. h = self.encoder_proj(h)
  74. h, h_lengths = self.length_regulator(h, feat_len)
  75. # get conditions
  76. conds = torch.zeros(feat.shape, device=token.device)
  77. for i, j in enumerate(feat_len):
  78. if random.random() < 0.5:
  79. continue
  80. index = random.randint(0, int(0.3 * j))
  81. conds[i, :index] = feat[i, :index]
  82. conds = conds.transpose(1, 2)
  83. mask = (~make_pad_mask(feat_len)).to(h)
  84. # NOTE this is unnecessary, feat/h already same shape
  85. loss, _ = self.decoder.compute_loss(
  86. feat.transpose(1, 2).contiguous(),
  87. mask.unsqueeze(1),
  88. h.transpose(1, 2).contiguous(),
  89. embedding,
  90. cond=conds
  91. )
  92. return {'loss': loss}
  93. @torch.inference_mode()
  94. def inference(self,
  95. token,
  96. token_len,
  97. prompt_token,
  98. prompt_token_len,
  99. prompt_feat,
  100. prompt_feat_len,
  101. embedding,
  102. flow_cache):
  103. assert token.shape[0] == 1
  104. # xvec projection
  105. embedding = F.normalize(embedding, dim=1)
  106. embedding = self.spk_embed_affine_layer(embedding)
  107. # concat speech token and prompt speech token
  108. token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
  109. token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
  110. mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
  111. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  112. # text encode
  113. h, h_lengths = self.encoder(token, token_len)
  114. h = self.encoder_proj(h)
  115. mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
  116. h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
  117. # get conditions
  118. conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
  119. conds[:, :mel_len1] = prompt_feat
  120. conds = conds.transpose(1, 2)
  121. mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
  122. feat, flow_cache = self.decoder(
  123. mu=h.transpose(1, 2).contiguous(),
  124. mask=mask.unsqueeze(1),
  125. spks=embedding,
  126. cond=conds,
  127. n_timesteps=10,
  128. prompt_len=mel_len1,
  129. cache=flow_cache
  130. )
  131. feat = feat[:, :, mel_len1:]
  132. assert feat.shape[2] == mel_len2
  133. return feat.float(), flow_cache
  134. class CausalMaskedDiffWithXvec(torch.nn.Module):
  135. def __init__(self,
  136. input_size: int = 512,
  137. output_size: int = 80,
  138. spk_embed_dim: int = 192,
  139. output_type: str = "mel",
  140. vocab_size: int = 4096,
  141. input_frame_rate: int = 50,
  142. only_mask_loss: bool = True,
  143. token_mel_ratio: int = 2,
  144. pre_lookahead_len: int = 3,
  145. encoder: torch.nn.Module = None,
  146. decoder: torch.nn.Module = None,
  147. decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
  148. 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
  149. 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
  150. 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
  151. 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
  152. super().__init__()
  153. self.input_size = input_size
  154. self.output_size = output_size
  155. self.decoder_conf = decoder_conf
  156. self.vocab_size = vocab_size
  157. self.output_type = output_type
  158. self.input_frame_rate = input_frame_rate
  159. logging.info(f"input frame rate={self.input_frame_rate}")
  160. self.input_embedding = nn.Embedding(vocab_size, input_size)
  161. self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
  162. self.encoder = encoder
  163. self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
  164. self.decoder = decoder
  165. self.only_mask_loss = only_mask_loss
  166. self.token_mel_ratio = token_mel_ratio
  167. self.pre_lookahead_len = pre_lookahead_len
  168. def forward(
  169. self,
  170. batch: dict,
  171. device: torch.device,
  172. ) -> Dict[str, Optional[torch.Tensor]]:
  173. token = batch['speech_token'].to(device)
  174. token_len = batch['speech_token_len'].to(device)
  175. feat = batch['speech_feat'].to(device)
  176. feat_len = batch['speech_feat_len'].to(device)
  177. embedding = batch['embedding'].to(device)
  178. # NOTE unified training, static_chunk_size > 0 or = 0
  179. streaming = True if random.random() < 0.5 else False
  180. # xvec projection
  181. embedding = F.normalize(embedding, dim=1)
  182. embedding = self.spk_embed_affine_layer(embedding)
  183. # concat text and prompt_text
  184. mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
  185. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  186. # text encode
  187. h, h_lengths = self.encoder(token, token_len, streaming=streaming)
  188. h = self.encoder_proj(h)
  189. # get conditions
  190. conds = torch.zeros(feat.shape, device=token.device)
  191. for i, j in enumerate(feat_len):
  192. if random.random() < 0.5:
  193. continue
  194. index = random.randint(0, int(0.3 * j))
  195. conds[i, :index] = feat[i, :index]
  196. conds = conds.transpose(1, 2)
  197. mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
  198. loss, _ = self.decoder.compute_loss(
  199. feat.transpose(1, 2).contiguous(),
  200. mask.unsqueeze(1),
  201. h.transpose(1, 2).contiguous(),
  202. embedding,
  203. cond=conds,
  204. streaming=streaming,
  205. )
  206. return {'loss': loss}
  207. @torch.inference_mode()
  208. def inference(self,
  209. token,
  210. token_len,
  211. prompt_token,
  212. prompt_token_len,
  213. prompt_feat,
  214. prompt_feat_len,
  215. embedding,
  216. streaming,
  217. finalize):
  218. assert token.shape[0] == 1
  219. # xvec projection
  220. embedding = F.normalize(embedding, dim=1)
  221. embedding = self.spk_embed_affine_layer(embedding)
  222. # concat text and prompt_text
  223. token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
  224. mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
  225. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  226. # text encode
  227. if finalize is True:
  228. h, h_lengths = self.encoder(token, token_len, streaming=streaming)
  229. else:
  230. token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
  231. h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
  232. mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
  233. h = self.encoder_proj(h)
  234. # get conditions
  235. conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
  236. conds[:, :mel_len1] = prompt_feat
  237. conds = conds.transpose(1, 2)
  238. mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
  239. feat, _ = self.decoder(
  240. mu=h.transpose(1, 2).contiguous(),
  241. mask=mask.unsqueeze(1),
  242. spks=embedding,
  243. cond=conds,
  244. n_timesteps=10,
  245. streaming=streaming
  246. )
  247. feat = feat[:, :, mel_len1:]
  248. assert feat.shape[2] == mel_len2
  249. return feat.float(), None
  250. class CausalMaskedDiffWithDiT(torch.nn.Module):
  251. def __init__(self,
  252. input_size: int = 512,
  253. output_size: int = 80,
  254. spk_embed_dim: int = 192,
  255. output_type: str = "mel",
  256. vocab_size: int = 4096,
  257. input_frame_rate: int = 50,
  258. only_mask_loss: bool = True,
  259. token_mel_ratio: int = 2,
  260. pre_lookahead_len: int = 3,
  261. pre_lookahead_layer: torch.nn.Module = None,
  262. decoder: torch.nn.Module = None,
  263. decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
  264. 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
  265. 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
  266. 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
  267. 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
  268. super().__init__()
  269. self.input_size = input_size
  270. self.output_size = output_size
  271. self.decoder_conf = decoder_conf
  272. self.vocab_size = vocab_size
  273. self.output_type = output_type
  274. self.input_frame_rate = input_frame_rate
  275. logging.info(f"input frame rate={self.input_frame_rate}")
  276. self.input_embedding = nn.Embedding(vocab_size, input_size)
  277. self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
  278. self.pre_lookahead_len = pre_lookahead_len
  279. self.pre_lookahead_layer = pre_lookahead_layer
  280. self.decoder = decoder
  281. self.only_mask_loss = only_mask_loss
  282. self.token_mel_ratio = token_mel_ratio
  283. def forward(
  284. self,
  285. batch: dict,
  286. device: torch.device,
  287. ) -> Dict[str, Optional[torch.Tensor]]:
  288. token = batch['speech_token'].to(device)
  289. token_len = batch['speech_token_len'].to(device)
  290. feat = batch['speech_feat'].to(device)
  291. feat_len = batch['speech_feat_len'].to(device)
  292. embedding = batch['embedding'].to(device)
  293. # NOTE unified training, static_chunk_size > 0 or = 0
  294. streaming = True if random.random() < 0.5 else False
  295. # xvec projection
  296. embedding = F.normalize(embedding, dim=1)
  297. embedding = self.spk_embed_affine_layer(embedding)
  298. # concat text and prompt_text
  299. mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
  300. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  301. # text encode
  302. h = self.pre_lookahead_layer(token)
  303. h = h.repeat_interleave(self.token_mel_ratio, dim=1)
  304. mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
  305. # get conditions
  306. conds = torch.zeros(feat.shape, device=token.device)
  307. for i, j in enumerate(feat_len):
  308. if random.random() < 0.5:
  309. continue
  310. index = random.randint(0, int(0.3 * j))
  311. conds[i, :index] = feat[i, :index]
  312. conds = conds.transpose(1, 2)
  313. loss, _ = self.decoder.compute_loss(
  314. feat.transpose(1, 2).contiguous(),
  315. mask.unsqueeze(1),
  316. h.transpose(1, 2).contiguous(),
  317. embedding,
  318. cond=conds,
  319. streaming=streaming,
  320. )
  321. return {'loss': loss}
  322. @torch.inference_mode()
  323. def inference(self,
  324. token,
  325. token_len,
  326. prompt_token,
  327. prompt_token_len,
  328. prompt_feat,
  329. prompt_feat_len,
  330. embedding,
  331. streaming,
  332. finalize):
  333. assert token.shape[0] == 1
  334. # xvec projection
  335. embedding = F.normalize(embedding, dim=1)
  336. embedding = self.spk_embed_affine_layer(embedding)
  337. # concat text and prompt_text
  338. token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
  339. mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
  340. token = self.input_embedding(torch.clamp(token, min=0)) * mask
  341. # text encode
  342. if finalize is True:
  343. h = self.pre_lookahead_layer(token)
  344. else:
  345. h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
  346. h = h.repeat_interleave(self.token_mel_ratio, dim=1)
  347. mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
  348. # get conditions
  349. conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
  350. conds[:, :mel_len1] = prompt_feat
  351. conds = conds.transpose(1, 2)
  352. mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
  353. feat, _ = self.decoder(
  354. mu=h.transpose(1, 2).contiguous(),
  355. mask=mask.unsqueeze(1),
  356. spks=embedding,
  357. cond=conds,
  358. n_timesteps=10,
  359. streaming=streaming
  360. )
  361. feat = feat[:, :, mel_len1:]
  362. assert feat.shape[2] == mel_len2
  363. return feat.float(), None
  364. if __name__ == '__main__':
  365. torch.backends.cudnn.deterministic = True
  366. torch.backends.cudnn.benchmark = False
  367. from hyperpyyaml import load_hyperpyyaml
  368. with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
  369. configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
  370. model = configs['flow']
  371. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  372. model.to(device)
  373. model.eval()
  374. max_len = 10 * model.decoder.estimator.static_chunk_size
  375. chunk_size = model.decoder.estimator.static_chunk_size
  376. context_size = model.pre_lookahead_layer.pre_lookahead_len
  377. token = torch.randint(0, 6561, size=(1, max_len)).to(device)
  378. token_len = torch.tensor([max_len]).to(device)
  379. prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
  380. prompt_token_len = torch.tensor([chunk_size]).to(device)
  381. prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
  382. prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
  383. prompt_embedding = torch.rand(1, 192).to(device)
  384. pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
  385. for i in range(0, max_len, chunk_size):
  386. finalize = True if i + chunk_size + context_size >= max_len else False
  387. pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
  388. prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
  389. pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
  390. print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())