flow_matching.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import onnxruntime
  15. import torch
  16. import torch.nn.functional as F
  17. from matcha.models.components.flow_matching import BASECFM
  18. class ConditionalCFM(BASECFM):
  19. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  20. super().__init__(
  21. n_feats=in_channels,
  22. cfm_params=cfm_params,
  23. n_spks=n_spks,
  24. spk_emb_dim=spk_emb_dim,
  25. )
  26. self.t_scheduler = cfm_params.t_scheduler
  27. self.training_cfg_rate = cfm_params.training_cfg_rate
  28. self.inference_cfg_rate = cfm_params.inference_cfg_rate
  29. in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
  30. # Just change the architecture of the estimator here
  31. self.estimator = estimator
  32. @torch.inference_mode()
  33. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
  34. """Forward diffusion
  35. Args:
  36. mu (torch.Tensor): output of encoder
  37. shape: (batch_size, n_feats, mel_timesteps)
  38. mask (torch.Tensor): output_mask
  39. shape: (batch_size, 1, mel_timesteps)
  40. n_timesteps (int): number of diffusion steps
  41. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  42. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  43. shape: (batch_size, spk_emb_dim)
  44. cond: Not used but kept for future purposes
  45. Returns:
  46. sample: generated mel-spectrogram
  47. shape: (batch_size, n_feats, mel_timesteps)
  48. """
  49. z = torch.randn_like(mu) * temperature
  50. cache_size = flow_cache.shape[2]
  51. # fix prompt and overlap part mu and z
  52. if cache_size != 0:
  53. z[:, :, :cache_size] = flow_cache[:, :, :, 0]
  54. mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
  55. z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
  56. mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
  57. flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
  58. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  59. if self.t_scheduler == 'cosine':
  60. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  61. return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
  62. def solve_euler(self, x, t_span, mu, mask, spks, cond):
  63. """
  64. Fixed euler solver for ODEs.
  65. Args:
  66. x (torch.Tensor): random noise
  67. t_span (torch.Tensor): n_timesteps interpolated
  68. shape: (n_timesteps + 1,)
  69. mu (torch.Tensor): output of encoder
  70. shape: (batch_size, n_feats, mel_timesteps)
  71. mask (torch.Tensor): output_mask
  72. shape: (batch_size, 1, mel_timesteps)
  73. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  74. shape: (batch_size, spk_emb_dim)
  75. cond: Not used but kept for future purposes
  76. """
  77. t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
  78. t = t.unsqueeze(dim=0)
  79. # I am storing this because I can later plot it by putting a debugger here and saving it to a file
  80. # Or in future might add like a return_all_steps flag
  81. sol = []
  82. if self.inference_cfg_rate > 0:
  83. # Do not use concat, it may cause memory format changed and trt infer with wrong results!
  84. x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  85. mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
  86. mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  87. t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
  88. spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
  89. cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  90. else:
  91. x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
  92. for step in range(1, len(t_span)):
  93. # Classifier-Free Guidance inference introduced in VoiceBox
  94. if self.inference_cfg_rate > 0:
  95. x_in[:] = x
  96. mask_in[:] = mask
  97. mu_in[0] = mu
  98. t_in[:] = t.unsqueeze(0)
  99. spks_in[0] = spks
  100. cond_in[0] = cond
  101. else:
  102. x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
  103. dphi_dt = self.forward_estimator(
  104. x_in, mask_in,
  105. mu_in, t_in,
  106. spks_in,
  107. cond_in
  108. )
  109. if self.inference_cfg_rate > 0:
  110. dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
  111. dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
  112. x = x + dt * dphi_dt
  113. t = t + dt
  114. sol.append(x)
  115. if step < len(t_span) - 1:
  116. dt = t_span[step + 1] - t
  117. return sol[-1].float()
  118. def forward_estimator(self, x, mask, mu, t, spks, cond):
  119. if isinstance(self.estimator, torch.nn.Module):
  120. return self.estimator.forward(x, mask, mu, t, spks, cond)
  121. elif isinstance(self.estimator, onnxruntime.InferenceSession):
  122. ort_inputs = {
  123. 'x': x.cpu().numpy(),
  124. 'mask': mask.cpu().numpy(),
  125. 'mu': mu.cpu().numpy(),
  126. 't': t.cpu().numpy(),
  127. 'spk': spks.cpu().numpy(),
  128. 'cond': cond.cpu().numpy(),
  129. 'mask_rand': torch.randn(1, 1, 1).numpy()
  130. }
  131. output = self.estimator.run(None, ort_inputs)[0]
  132. return torch.tensor(output, dtype=x.dtype, device=x.device)
  133. else:
  134. if not x.is_contiguous():
  135. x = x.contiguous()
  136. if not mask.is_contiguous():
  137. mask = mask.contiguous()
  138. if not mu.is_contiguous():
  139. mu = mu.contiguous()
  140. if not t.is_contiguous():
  141. t = t.contiguous()
  142. if not spks.is_contiguous():
  143. spks = spks.contiguous()
  144. if not cond.is_contiguous():
  145. cond = cond.contiguous()
  146. self.estimator.set_input_shape('x', (2, 80, x.size(2)))
  147. self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
  148. self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
  149. self.estimator.set_input_shape('t', (2,))
  150. self.estimator.set_input_shape('spk', (2, 80))
  151. self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
  152. self.estimator.set_input_shape('mask_rand', (1, 1, 1))
  153. # run trt engine
  154. self.estimator.execute_v2([x.data_ptr(),
  155. mask.data_ptr(),
  156. mu.data_ptr(),
  157. t.data_ptr(),
  158. spks.data_ptr(),
  159. cond.data_ptr(),
  160. torch.randn(1, 1, 1).to(x.device).data_ptr(),
  161. x.data_ptr()])
  162. return x
  163. def compute_loss(self, x1, mask, mu, spks=None, cond=None):
  164. """Computes diffusion loss
  165. Args:
  166. x1 (torch.Tensor): Target
  167. shape: (batch_size, n_feats, mel_timesteps)
  168. mask (torch.Tensor): target mask
  169. shape: (batch_size, 1, mel_timesteps)
  170. mu (torch.Tensor): output of encoder
  171. shape: (batch_size, n_feats, mel_timesteps)
  172. spks (torch.Tensor, optional): speaker embedding. Defaults to None.
  173. shape: (batch_size, spk_emb_dim)
  174. Returns:
  175. loss: conditional flow matching loss
  176. y: conditional flow
  177. shape: (batch_size, n_feats, mel_timesteps)
  178. """
  179. b, _, t = mu.shape
  180. # random timestep
  181. t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
  182. if self.t_scheduler == 'cosine':
  183. t = 1 - torch.cos(t * 0.5 * torch.pi)
  184. # sample noise p(x_0)
  185. z = torch.randn_like(x1)
  186. y = (1 - (1 - self.sigma_min) * t) * z + t * x1
  187. u = x1 - (1 - self.sigma_min) * z
  188. # during training, we randomly drop condition to trade off mode coverage and sample fidelity
  189. if self.training_cfg_rate > 0:
  190. cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
  191. mu = mu * cfg_mask.view(-1, 1, 1)
  192. spks = spks * cfg_mask.view(-1, 1)
  193. cond = cond * cfg_mask.view(-1, 1, 1)
  194. pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
  195. loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
  196. return loss, y
  197. class CausalConditionalCFM(ConditionalCFM):
  198. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  199. super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
  200. self.rand_noise = torch.randn([1, 80, 50 * 300])
  201. @torch.inference_mode()
  202. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
  203. """Forward diffusion
  204. Args:
  205. mu (torch.Tensor): output of encoder
  206. shape: (batch_size, n_feats, mel_timesteps)
  207. mask (torch.Tensor): output_mask
  208. shape: (batch_size, 1, mel_timesteps)
  209. n_timesteps (int): number of diffusion steps
  210. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  211. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  212. shape: (batch_size, spk_emb_dim)
  213. cond: Not used but kept for future purposes
  214. Returns:
  215. sample: generated mel-spectrogram
  216. shape: (batch_size, n_feats, mel_timesteps)
  217. """
  218. z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
  219. if self.sp16 is True:
  220. z = z.half()
  221. # fix prompt and overlap part mu and z
  222. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  223. if self.t_scheduler == 'cosine':
  224. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  225. return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None