# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig

from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
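
# Flow modules that map discrete speech tokens (conditioned on a speaker
# x-vector and an optional mel prompt) to mel-spectrogram features through a
# conditional flow-matching decoder:
#   - MaskedDiffWithXvec: offline model with an explicit length regulator.
#   - CausalMaskedDiffWithXvec: adds chunk-aware ("streaming") encoding with a
#     short pre-lookahead so mels can be generated incrementally.
#   - CausalMaskedDiffWithDiT: streaming variant that replaces the encoder with
#     a pre-lookahead layer and a fixed token-to-mel upsampling ratio.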


class MaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss
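
    # Training: compute the flow-matching loss on the target mel. With 50%
    # probability, a random prefix (up to 30% of each utterance) of the target
    # mel is kept in `conds` so the decoder learns prompt continuation.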
    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask
        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)
        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)
        mask = (~make_pad_mask(feat_len)).to(h)
        # NOTE this is unnecessary, feat/h already same shape
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}
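
    # Inference: encode prompt + target tokens together, upsample to the mel
    # frame rate via the length regulator, and run the CFM decoder; the prompt
    # mel is passed as `cond` and `flow_cache` is threaded through the decoder
    # so successive calls can reuse its state.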
    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        # concat speech token and prompt speech token
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask
        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)
        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            cache=flow_cache
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), flow_cache
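

# Streaming-capable variant: the encoder accepts a `streaming` flag and, for
# non-final chunks, a short pre-lookahead context, so mel features can be
# produced chunk by chunk at inference time.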
class CausalMaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len
        if online_feature is True:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        if 'speech_token' not in batch:
            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
        else:
            token = batch['speech_token'].to(device)
            token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)
        # NOTE unified training, static_chunk_size > 0 or = 0
        streaming = True if random.random() < 0.5 else False
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask
        # text encode
        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        h = self.encoder_proj(h)
        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)
        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}
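
    # Inference: for non-final chunks the last `pre_lookahead_len` tokens are
    # handed to the encoder as lookahead context rather than being decoded; the
    # prompt mel conditions the CFM decoder and only the frames after the
    # prompt are returned.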
    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  streaming,
                  finalize):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask
        # text encode
        if finalize is True:
            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        else:
            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)
        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)
        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            streaming=streaming
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), None
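

# Variant without a separate encoder/length regulator: a pre-lookahead layer
# processes the token embeddings and the result is repeat-interleaved by
# `token_mel_ratio` to reach the mel frame rate before the decoder.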
class CausalMaskedDiffWithDiT(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 pre_lookahead_layer: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.pre_lookahead_len = pre_lookahead_len
        self.pre_lookahead_layer = pre_lookahead_layer
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        if online_feature is True:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        if 'speech_token' not in batch:
            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
        else:
            token = batch['speech_token'].to(device)
            token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)
        # NOTE unified training, static_chunk_size > 0 or = 0
        streaming = True if random.random() < 0.5 else False
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask
        # text encode
        h = self.pre_lookahead_layer(token)
        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
        mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  streaming,
                  finalize):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask
        # text encode
        if finalize is True:
            h = self.pre_lookahead_layer(token)
        else:
            h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)
        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            streaming=streaming
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), None
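

# Quick consistency check for the flow model built from the
# Fun-CosyVoice3-0.5B config: run full (finalize=True) streaming inference on
# random tokens, then chunk-by-chunk inference with lookahead, and print the
# max absolute mel difference for each chunk.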
if __name__ == '__main__':
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    from hyperpyyaml import load_hyperpyyaml
    with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
        configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
    model = configs['flow']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()
    max_len = 10 * model.decoder.estimator.static_chunk_size
    chunk_size = model.decoder.estimator.static_chunk_size
    context_size = model.pre_lookahead_layer.pre_lookahead_len
    token = torch.randint(0, 6561, size=(1, max_len)).to(device)
    token_len = torch.tensor([max_len]).to(device)
    prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
    prompt_token_len = torch.tensor([chunk_size]).to(device)
    prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
    prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
    prompt_embedding = torch.rand(1, 192).to(device)
    pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
    for i in range(0, max_len, chunk_size):
        finalize = True if i + chunk_size + context_size >= max_len else False
        pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
                                        prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
        pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
        print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())