flow_matching.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import threading
  15. import torch
  16. import torch.nn.functional as F
  17. from matcha.models.components.flow_matching import BASECFM
  18. class ConditionalCFM(BASECFM):
  19. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  20. super().__init__(
  21. n_feats=in_channels,
  22. cfm_params=cfm_params,
  23. n_spks=n_spks,
  24. spk_emb_dim=spk_emb_dim,
  25. )
  26. self.t_scheduler = cfm_params.t_scheduler
  27. self.training_cfg_rate = cfm_params.training_cfg_rate
  28. self.inference_cfg_rate = cfm_params.inference_cfg_rate
  29. in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
  30. # Just change the architecture of the estimator here
  31. self.estimator = estimator
  32. self.lock = threading.Lock()
  33. @torch.inference_mode()
  34. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
  35. """Forward diffusion
  36. Args:
  37. mu (torch.Tensor): output of encoder
  38. shape: (batch_size, n_feats, mel_timesteps)
  39. mask (torch.Tensor): output_mask
  40. shape: (batch_size, 1, mel_timesteps)
  41. n_timesteps (int): number of diffusion steps
  42. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  43. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  44. shape: (batch_size, spk_emb_dim)
  45. cond: Not used but kept for future purposes
  46. Returns:
  47. sample: generated mel-spectrogram
  48. shape: (batch_size, n_feats, mel_timesteps)
  49. """
  50. z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
  51. cache_size = cache.shape[2]
  52. # fix prompt and overlap part mu and z
  53. if cache_size != 0:
  54. z[:, :, :cache_size] = cache[:, :, :, 0]
  55. mu[:, :, :cache_size] = cache[:, :, :, 1]
  56. z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
  57. mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
  58. cache = torch.stack([z_cache, mu_cache], dim=-1)
  59. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  60. if self.t_scheduler == 'cosine':
  61. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  62. return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
  63. def solve_euler(self, x, t_span, mu, mask, spks, cond):
  64. """
  65. Fixed euler solver for ODEs.
  66. Args:
  67. x (torch.Tensor): random noise
  68. t_span (torch.Tensor): n_timesteps interpolated
  69. shape: (n_timesteps + 1,)
  70. mu (torch.Tensor): output of encoder
  71. shape: (batch_size, n_feats, mel_timesteps)
  72. mask (torch.Tensor): output_mask
  73. shape: (batch_size, 1, mel_timesteps)
  74. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  75. shape: (batch_size, spk_emb_dim)
  76. cond: Not used but kept for future purposes
  77. """
  78. t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
  79. t = t.unsqueeze(dim=0)
  80. # I am storing this because I can later plot it by putting a debugger here and saving it to a file
  81. # Or in future might add like a return_all_steps flag
  82. sol = []
  83. # Do not use concat, it may cause memory format changed and trt infer with wrong results!
  84. x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  85. mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
  86. mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  87. t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
  88. spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
  89. cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  90. for step in range(1, len(t_span)):
  91. # Classifier-Free Guidance inference introduced in VoiceBox
  92. x_in[:] = x
  93. mask_in[:] = mask
  94. mu_in[0] = mu
  95. t_in[:] = t.unsqueeze(0)
  96. spks_in[0] = spks
  97. cond_in[0] = cond
  98. dphi_dt = self.forward_estimator(
  99. x_in, mask_in,
  100. mu_in, t_in,
  101. spks_in,
  102. cond_in
  103. )
  104. dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
  105. dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
  106. x = x + dt * dphi_dt
  107. t = t + dt
  108. sol.append(x)
  109. if step < len(t_span) - 1:
  110. dt = t_span[step + 1] - t
  111. return sol[-1].float()
  112. def forward_estimator(self, x, mask, mu, t, spks, cond):
  113. if isinstance(self.estimator, torch.nn.Module):
  114. return self.estimator(x, mask, mu, t, spks, cond)
  115. else:
  116. with self.lock:
  117. self.estimator.set_input_shape('x', (2, 80, x.size(2)))
  118. self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
  119. self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
  120. self.estimator.set_input_shape('t', (2,))
  121. self.estimator.set_input_shape('spks', (2, 80))
  122. self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
  123. # run trt engine
  124. assert self.estimator.execute_v2([x.contiguous().data_ptr(),
  125. mask.contiguous().data_ptr(),
  126. mu.contiguous().data_ptr(),
  127. t.contiguous().data_ptr(),
  128. spks.contiguous().data_ptr(),
  129. cond.contiguous().data_ptr(),
  130. x.data_ptr()]) is True
  131. return x
  132. def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
  133. """Computes diffusion loss
  134. Args:
  135. x1 (torch.Tensor): Target
  136. shape: (batch_size, n_feats, mel_timesteps)
  137. mask (torch.Tensor): target mask
  138. shape: (batch_size, 1, mel_timesteps)
  139. mu (torch.Tensor): output of encoder
  140. shape: (batch_size, n_feats, mel_timesteps)
  141. spks (torch.Tensor, optional): speaker embedding. Defaults to None.
  142. shape: (batch_size, spk_emb_dim)
  143. Returns:
  144. loss: conditional flow matching loss
  145. y: conditional flow
  146. shape: (batch_size, n_feats, mel_timesteps)
  147. """
  148. b, _, t = mu.shape
  149. # random timestep
  150. t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
  151. if self.t_scheduler == 'cosine':
  152. t = 1 - torch.cos(t * 0.5 * torch.pi)
  153. # sample noise p(x_0)
  154. z = torch.randn_like(x1)
  155. y = (1 - (1 - self.sigma_min) * t) * z + t * x1
  156. u = x1 - (1 - self.sigma_min) * z
  157. # during training, we randomly drop condition to trade off mode coverage and sample fidelity
  158. if self.training_cfg_rate > 0:
  159. cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
  160. mu = mu * cfg_mask.view(-1, 1, 1)
  161. spks = spks * cfg_mask.view(-1, 1)
  162. cond = cond * cfg_mask.view(-1, 1, 1)
  163. pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
  164. loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
  165. return loss, y
  166. class CausalConditionalCFM(ConditionalCFM):
  167. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  168. super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
  169. self.rand_noise = torch.randn([1, 80, 50 * 300])
  170. @torch.inference_mode()
  171. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
  172. """Forward diffusion
  173. Args:
  174. mu (torch.Tensor): output of encoder
  175. shape: (batch_size, n_feats, mel_timesteps)
  176. mask (torch.Tensor): output_mask
  177. shape: (batch_size, 1, mel_timesteps)
  178. n_timesteps (int): number of diffusion steps
  179. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  180. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  181. shape: (batch_size, spk_emb_dim)
  182. cond: Not used but kept for future purposes
  183. Returns:
  184. sample: generated mel-spectrogram
  185. shape: (batch_size, n_feats, mel_timesteps)
  186. """
  187. offset = cache.pop('offset')
  188. z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
  189. z = z[:, :, offset:]
  190. offset += mu.size(2)
  191. # fix prompt and overlap part mu and z
  192. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  193. if self.t_scheduler == 'cosine':
  194. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  195. mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
  196. cache['offset'] = offset
  197. return mel, cache
  198. def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
  199. """
  200. Fixed euler solver for ODEs.
  201. Args:
  202. x (torch.Tensor): random noise
  203. t_span (torch.Tensor): n_timesteps interpolated
  204. shape: (n_timesteps + 1,)
  205. mu (torch.Tensor): output of encoder
  206. shape: (batch_size, n_feats, mel_timesteps)
  207. mask (torch.Tensor): output_mask
  208. shape: (batch_size, 1, mel_timesteps)
  209. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  210. shape: (batch_size, spk_emb_dim)
  211. cond: Not used but kept for future purposes
  212. """
  213. t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
  214. t = t.unsqueeze(dim=0)
  215. # I am storing this because I can later plot it by putting a debugger here and saving it to a file
  216. # Or in future might add like a return_all_steps flag
  217. sol = []
  218. # estimator cache for each step
  219. down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
  220. mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x)
  221. up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
  222. # Do not use concat, it may cause memory format changed and trt infer with wrong results!
  223. x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  224. mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
  225. mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  226. t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
  227. spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
  228. cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  229. for step in range(1, len(t_span)):
  230. # Classifier-Free Guidance inference introduced in VoiceBox
  231. x_in[:] = x
  232. mask_in[:] = mask
  233. mu_in[0] = mu
  234. t_in[:] = t.unsqueeze(0)
  235. spks_in[0] = spks
  236. cond_in[0] = cond
  237. cache_step = {k: v[step - 1] for k, v in cache.items()}
  238. dphi_dt, cache_step = self.forward_estimator(
  239. x_in, mask_in,
  240. mu_in, t_in,
  241. spks_in,
  242. cond_in,
  243. cache_step
  244. )
  245. cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
  246. down_blocks_kv_cache_new[step - 1] = cache_step[1]
  247. cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
  248. mid_blocks_kv_cache_new[step - 1] = cache_step[3]
  249. cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
  250. up_blocks_kv_cache_new[step - 1] = cache_step[5]
  251. cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
  252. dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
  253. dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
  254. x = x + dt * dphi_dt
  255. t = t + dt
  256. sol.append(x)
  257. if step < len(t_span) - 1:
  258. dt = t_span[step + 1] - t
  259. cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4)
  260. cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4)
  261. cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4)
  262. return sol[-1].float(), cache
  263. def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
  264. if isinstance(self.estimator, torch.nn.Module):
  265. x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
  266. cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
  267. else:
  268. with self.lock:
  269. self.estimator.set_input_shape('x', (2, 80, x.size(2)))
  270. self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
  271. self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
  272. self.estimator.set_input_shape('t', (2,))
  273. self.estimator.set_input_shape('spks', (2, 80))
  274. self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
  275. self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
  276. self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
  277. self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
  278. self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
  279. self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
  280. self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
  281. self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
  282. # run trt engine
  283. down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
  284. mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
  285. up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
  286. assert self.estimator.execute_v2([x.contiguous().data_ptr(),
  287. mask.contiguous().data_ptr(),
  288. mu.contiguous().data_ptr(),
  289. t.contiguous().data_ptr(),
  290. spks.contiguous().data_ptr(),
  291. cond.contiguous().data_ptr(),
  292. cache['down_blocks_conv_cache'].contiguous().data_ptr(),
  293. cache['down_blocks_kv_cache'].contiguous().data_ptr(),
  294. cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
  295. cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
  296. cache['up_blocks_conv_cache'].contiguous().data_ptr(),
  297. cache['up_blocks_kv_cache'].contiguous().data_ptr(),
  298. cache['final_blocks_conv_cache'].contiguous().data_ptr(),
  299. x.data_ptr(),
  300. cache['down_blocks_conv_cache'].data_ptr(),
  301. down_blocks_kv_cache_out.data_ptr(),
  302. cache['mid_blocks_conv_cache'].data_ptr(),
  303. mid_blocks_kv_cache_out.data_ptr(),
  304. cache['up_blocks_conv_cache'].data_ptr(),
  305. up_blocks_kv_cache_out.data_ptr(),
  306. cache['final_blocks_conv_cache'].data_ptr()]) is True
  307. cache = (cache['down_blocks_conv_cache'],
  308. down_blocks_kv_cache_out,
  309. cache['mid_blocks_conv_cache'],
  310. mid_blocks_kv_cache_out,
  311. cache['up_blocks_conv_cache'],
  312. up_blocks_kv_cache_out,
  313. cache['final_blocks_conv_cache'])
  314. return x, cache