flow.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path


class MaskedDiffWithXvec(torch.nn.Module):
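    """Flow-matching mel-spectrogram predictor conditioned on speech tokens and an x-vector.

    Speech tokens are embedded, encoded and length-regulated to the mel frame rate; the
    decoder then generates mel features with conditional flow matching, conditioned on the
    projected speaker embedding and an optional mel prompt.
    """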

    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
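        """Compute the flow-matching training loss for one batch.

        For roughly half of the samples, a random prefix (up to 30% of the frames) of the
        ground-truth mel is kept as a condition, training the model to continue from a mel
        prompt at inference time.
        """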
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        # NOTE this is unnecessary, feat/h already same shape
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
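        """Generate mel features for `token`, continuing from a token/mel prompt.

        The prompt tokens and mel are prepended as context; the number of new mel frames is
        derived from the token count assuming a 22.05 kHz sample rate and a hop size of 256.
        Only frames after the prompt are returned, together with the updated decoder cache
        for chunked streaming calls.
        """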
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat speech token and prompt speech token
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            cache=flow_cache
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), flow_cache


class CausalMaskedDiffWithXvec(torch.nn.Module):
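    """Causal (streaming-capable) variant of MaskedDiffWithXvec.

    Training randomly alternates between streaming and non-streaming chunking, each speech
    token corresponds to `token_mel_ratio` mel frames, and `pre_lookahead_len` tokens of
    right context are consumed by non-final streaming chunks at inference time.
    """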

    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len
        if online_feature is True:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
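        """Compute the flow-matching loss with unified streaming / non-streaming training.

        If the batch carries no pre-extracted `speech_token`, tokens are extracted online from
        whisper features with the ONNX speech tokenizer. Each batch is randomly trained in
        streaming or non-streaming mode with equal probability.
        """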
        if 'speech_token' not in batch:
            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
        else:
            token = batch['speech_token'].to(device)
            token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training, static_chunk_size > 0 or = 0
        streaming = True if random.random() < 0.5 else False

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        h = self.encoder_proj(h)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  streaming,
                  finalize):
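        """Generate mel features for one (possibly non-final) chunk of speech tokens.

        When `finalize` is False, the last `pre_lookahead_len` tokens are passed to the encoder
        as right context rather than decoded, so consecutive streaming calls can overlap.
        Only frames after the mel prompt are returned; no decoder cache is kept across calls.
        """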
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        if finalize is True:
            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        else:
            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            streaming=streaming
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), None


class CausalMaskedDiffWithDiT(torch.nn.Module):
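    """Variant of CausalMaskedDiffWithXvec without a separate token encoder.

    Token embeddings pass through a pre-lookahead layer, are repeated `token_mel_ratio`
    times along the time axis, and are decoded directly by the DiT-based flow-matching
    decoder.
    """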

    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 pre_lookahead_layer: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.pre_lookahead_len = pre_lookahead_len
        self.pre_lookahead_layer = pre_lookahead_layer
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        if online_feature is True:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training, static_chunk_size > 0 or = 0
        streaming = True if random.random() < 0.5 else False

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h = self.pre_lookahead_layer(token)
        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
        mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  streaming,
                  finalize):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        if finalize is True:
            h = self.pre_lookahead_layer(token)
        else:
            h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]

        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            streaming=streaming
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), None
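

# Quick self-test: run inference once over the full token sequence (finalize=True), then rerun
# it chunk by chunk (chunk size = the decoder estimator's static_chunk_size, plus pre-lookahead
# context) and print the max absolute difference between chunked and full outputs per chunk.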
if __name__ == '__main__':
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    from hyperpyyaml import load_hyperpyyaml
    with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
        configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
    model = configs['flow']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    max_len = 10 * model.decoder.estimator.static_chunk_size
    chunk_size = model.decoder.estimator.static_chunk_size
    context_size = model.pre_lookahead_layer.pre_lookahead_len
    token = torch.randint(0, 6561, size=(1, max_len)).to(device)
    token_len = torch.tensor([max_len]).to(device)
    prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
    prompt_token_len = torch.tensor([chunk_size]).to(device)
    prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
    prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
    prompt_embedding = torch.rand(1, 192).to(device)

    pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding,
                                 streaming=True, finalize=True)
    for i in range(0, max_len, chunk_size):
        finalize = True if i + chunk_size + context_size >= max_len else False
        pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
                                        prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding,
                                        streaming=True, finalize=finalize)
        pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
        print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())