flow_matching.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import onnxruntime
  15. import torch
  16. import torch.nn.functional as F
  17. from matcha.models.components.flow_matching import BASECFM
  18. class ConditionalCFM(BASECFM):
  19. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  20. super().__init__(
  21. n_feats=in_channels,
  22. cfm_params=cfm_params,
  23. n_spks=n_spks,
  24. spk_emb_dim=spk_emb_dim,
  25. )
  26. self.t_scheduler = cfm_params.t_scheduler
  27. self.training_cfg_rate = cfm_params.training_cfg_rate
  28. self.inference_cfg_rate = cfm_params.inference_cfg_rate
  29. in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
  30. # Just change the architecture of the estimator here
  31. self.estimator = estimator
  32. @torch.inference_mode()
  33. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
  34. """Forward diffusion
  35. Args:
  36. mu (torch.Tensor): output of encoder
  37. shape: (batch_size, n_feats, mel_timesteps)
  38. mask (torch.Tensor): output_mask
  39. shape: (batch_size, 1, mel_timesteps)
  40. n_timesteps (int): number of diffusion steps
  41. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  42. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  43. shape: (batch_size, spk_emb_dim)
  44. cond: Not used but kept for future purposes
  45. Returns:
  46. sample: generated mel-spectrogram
  47. shape: (batch_size, n_feats, mel_timesteps)
  48. """
  49. z = torch.randn_like(mu) * temperature
  50. cache_size = flow_cache.shape[2]
  51. # fix prompt and overlap part mu and z
  52. if cache_size != 0:
  53. z[:, :, :cache_size] = flow_cache[:, :, :, 0]
  54. mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
  55. z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
  56. mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
  57. flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
  58. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  59. if self.t_scheduler == 'cosine':
  60. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  61. return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
  62. def solve_euler(self, x, t_span, mu, mask, spks, cond):
  63. """
  64. Fixed euler solver for ODEs.
  65. Args:
  66. x (torch.Tensor): random noise
  67. t_span (torch.Tensor): n_timesteps interpolated
  68. shape: (n_timesteps + 1,)
  69. mu (torch.Tensor): output of encoder
  70. shape: (batch_size, n_feats, mel_timesteps)
  71. mask (torch.Tensor): output_mask
  72. shape: (batch_size, 1, mel_timesteps)
  73. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  74. shape: (batch_size, spk_emb_dim)
  75. cond: Not used but kept for future purposes
  76. """
  77. t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
  78. t = t.unsqueeze(dim=0)
  79. # I am storing this because I can later plot it by putting a debugger here and saving it to a file
  80. # Or in future might add like a return_all_steps flag
  81. sol = []
  82. if self.inference_cfg_rate > 0:
  83. # Do not use concat, it may cause memory format changed and trt infer with wrong results!
  84. x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  85. mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
  86. mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  87. t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
  88. spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
  89. cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  90. else:
  91. x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
  92. for step in range(1, len(t_span)):
  93. # Classifier-Free Guidance inference introduced in VoiceBox
  94. if self.inference_cfg_rate > 0:
  95. x_in[:] = x
  96. mask_in[:] = mask
  97. mu_in[0] = mu
  98. t_in[:] = t.unsqueeze(0)
  99. spks_in[0] = spks
  100. cond_in[0] = cond
  101. else:
  102. x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
  103. dphi_dt = self.forward_estimator(
  104. x_in, mask_in,
  105. mu_in, t_in,
  106. spks_in,
  107. cond_in
  108. )
  109. if self.inference_cfg_rate > 0:
  110. dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
  111. dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
  112. x = x + dt * dphi_dt
  113. t = t + dt
  114. sol.append(x)
  115. if step < len(t_span) - 1:
  116. dt = t_span[step + 1] - t
  117. return sol[-1].float()
  118. def forward_estimator(self, x, mask, mu, t, spks, cond):
  119. if isinstance(self.estimator, torch.nn.Module):
  120. return self.estimator.forward(x, mask, mu, t, spks, cond)
  121. elif isinstance(self.estimator, onnxruntime.InferenceSession):
  122. ort_inputs = {
  123. 'x': x.cpu().numpy(),
  124. 'mask': mask.cpu().numpy(),
  125. 'mu': mu.cpu().numpy(),
  126. 't': t.cpu().numpy(),
  127. 'spks': spks.cpu().numpy(),
  128. 'cond': cond.cpu().numpy()
  129. }
  130. output = self.estimator.run(None, ort_inputs)[0]
  131. return torch.tensor(output, dtype=x.dtype, device=x.device)
  132. else:
  133. self.estimator.set_input_shape('x', (2, 80, x.size(2)))
  134. self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
  135. self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
  136. self.estimator.set_input_shape('t', (2,))
  137. self.estimator.set_input_shape('spks', (2, 80))
  138. self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
  139. # run trt engine
  140. self.estimator.execute_v2([x.contiguous().data_ptr(),
  141. mask.contiguous().data_ptr(),
  142. mu.contiguous().data_ptr(),
  143. t.contiguous().data_ptr(),
  144. spks.contiguous().data_ptr(),
  145. cond.contiguous().data_ptr(),
  146. x.data_ptr()])
  147. return x
  148. def compute_loss(self, x1, mask, mu, spks=None, cond=None):
  149. """Computes diffusion loss
  150. Args:
  151. x1 (torch.Tensor): Target
  152. shape: (batch_size, n_feats, mel_timesteps)
  153. mask (torch.Tensor): target mask
  154. shape: (batch_size, 1, mel_timesteps)
  155. mu (torch.Tensor): output of encoder
  156. shape: (batch_size, n_feats, mel_timesteps)
  157. spks (torch.Tensor, optional): speaker embedding. Defaults to None.
  158. shape: (batch_size, spk_emb_dim)
  159. Returns:
  160. loss: conditional flow matching loss
  161. y: conditional flow
  162. shape: (batch_size, n_feats, mel_timesteps)
  163. """
  164. b, _, t = mu.shape
  165. # random timestep
  166. t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
  167. if self.t_scheduler == 'cosine':
  168. t = 1 - torch.cos(t * 0.5 * torch.pi)
  169. # sample noise p(x_0)
  170. z = torch.randn_like(x1)
  171. y = (1 - (1 - self.sigma_min) * t) * z + t * x1
  172. u = x1 - (1 - self.sigma_min) * z
  173. # during training, we randomly drop condition to trade off mode coverage and sample fidelity
  174. if self.training_cfg_rate > 0:
  175. cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
  176. mu = mu * cfg_mask.view(-1, 1, 1)
  177. spks = spks * cfg_mask.view(-1, 1)
  178. cond = cond * cfg_mask.view(-1, 1, 1)
  179. pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
  180. loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
  181. return loss, y
  182. class CausalConditionalCFM(ConditionalCFM):
  183. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  184. super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
  185. self.rand_noise = torch.randn([1, 80, 50 * 300])
  186. @torch.inference_mode()
  187. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
  188. """Forward diffusion
  189. Args:
  190. mu (torch.Tensor): output of encoder
  191. shape: (batch_size, n_feats, mel_timesteps)
  192. mask (torch.Tensor): output_mask
  193. shape: (batch_size, 1, mel_timesteps)
  194. n_timesteps (int): number of diffusion steps
  195. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  196. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  197. shape: (batch_size, spk_emb_dim)
  198. cond: Not used but kept for future purposes
  199. Returns:
  200. sample: generated mel-spectrogram
  201. shape: (batch_size, n_feats, mel_timesteps)
  202. """
  203. z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
  204. if self.fp16 is True:
  205. z = z.half()
  206. # fix prompt and overlap part mu and z
  207. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  208. if self.t_scheduler == 'cosine':
  209. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  210. return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None