flow_matching.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn.functional as F
  16. from matcha.models.components.flow_matching import BASECFM
  17. class ConditionalCFM(BASECFM):
  18. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  19. super().__init__(
  20. n_feats=in_channels,
  21. cfm_params=cfm_params,
  22. n_spks=n_spks,
  23. spk_emb_dim=spk_emb_dim,
  24. )
  25. self.t_scheduler = cfm_params.t_scheduler
  26. self.training_cfg_rate = cfm_params.training_cfg_rate
  27. self.inference_cfg_rate = cfm_params.inference_cfg_rate
  28. in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
  29. # Just change the architecture of the estimator here
  30. self.estimator = estimator
  31. self.estimator_context = None
  32. self.estimator_engine = None
  33. self.is_saved = None
  34. @torch.inference_mode()
  35. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
  36. """Forward diffusion
  37. Args:
  38. mu (torch.Tensor): output of encoder
  39. shape: (batch_size, n_feats, mel_timesteps)
  40. mask (torch.Tensor): output_mask
  41. shape: (batch_size, 1, mel_timesteps)
  42. n_timesteps (int): number of diffusion steps
  43. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  44. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  45. shape: (batch_size, spk_emb_dim)
  46. cond: Not used but kept for future purposes
  47. Returns:
  48. sample: generated mel-spectrogram
  49. shape: (batch_size, n_feats, mel_timesteps)
  50. """
  51. z = torch.randn_like(mu) * temperature
  52. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  53. if self.t_scheduler == 'cosine':
  54. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  55. return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
  56. def solve_euler(self, x, t_span, mu, mask, spks, cond):
  57. """
  58. Fixed euler solver for ODEs.
  59. Args:
  60. x (torch.Tensor): random noise
  61. t_span (torch.Tensor): n_timesteps interpolated
  62. shape: (n_timesteps + 1,)
  63. mu (torch.Tensor): output of encoder
  64. shape: (batch_size, n_feats, mel_timesteps)
  65. mask (torch.Tensor): output_mask
  66. shape: (batch_size, 1, mel_timesteps)
  67. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  68. shape: (batch_size, spk_emb_dim)
  69. cond: Not used but kept for future purposes
  70. """
  71. t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
  72. t = t.unsqueeze(dim=0)
  73. # I am storing this because I can later plot it by putting a debugger here and saving it to a file
  74. # Or in future might add like a return_all_steps flag
  75. sol = []
  76. for step in range(1, len(t_span)):
  77. dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
  78. # Classifier-Free Guidance inference introduced in VoiceBox
  79. if self.inference_cfg_rate > 0:
  80. cfg_dphi_dt = self.forward_estimator(
  81. x, mask,
  82. torch.zeros_like(mu), t,
  83. torch.zeros_like(spks) if spks is not None else None,
  84. torch.zeros_like(cond)
  85. )
  86. dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
  87. self.inference_cfg_rate * cfg_dphi_dt)
  88. x = x + dt * dphi_dt
  89. t = t + dt
  90. sol.append(x)
  91. if step < len(t_span) - 1:
  92. dt = t_span[step + 1] - t
  93. return sol[-1]
  94. def forward_estimator(self, x, mask, mu, t, spks, cond):
  95. if self.estimator_context is not None:
  96. assert self.training is False, 'tensorrt cannot be used in training'
  97. bs = x.shape[0]
  98. hs = x.shape[1]
  99. seq_len = x.shape[2]
  100. # assert bs == 1 and hs == 80
  101. ret = torch.empty_like(x)
  102. self.estimator_context.set_input_shape("x", x.shape)
  103. self.estimator_context.set_input_shape("mask", mask.shape)
  104. self.estimator_context.set_input_shape("mu", mu.shape)
  105. self.estimator_context.set_input_shape("t", t.shape)
  106. self.estimator_context.set_input_shape("spks", spks.shape)
  107. self.estimator_context.set_input_shape("cond", cond.shape)
  108. bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
  109. for i in range(len(bindings)):
  110. self.estimator_context.set_tensor_address(self.estimator_engine.get_tensor_name(i), bindings[i])
  111. handle = torch.cuda.current_stream().cuda_stream
  112. self.estimator_context.execute_async_v3(stream_handle=handle)
  113. return ret
  114. else:
  115. if self.is_saved == None:
  116. self.is_saved = True
  117. output = self.estimator.forward(x, mask, mu, t, spks, cond)
  118. torch.save(x, "x.pt")
  119. torch.save(mask, "mask.pt")
  120. torch.save(mu, "mu.pt")
  121. torch.save(t, "t.pt")
  122. torch.save(spks, "spks.pt")
  123. torch.save(cond, "cond.pt")
  124. torch.save(output, "output.pt")
  125. dummy_input = (x, mask, mu, t, spks, cond)
  126. torch.onnx.export(
  127. self.estimator,
  128. dummy_input,
  129. "estimator_fp32.onnx",
  130. export_params=True,
  131. opset_version=17,
  132. do_constant_folding=True,
  133. input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
  134. output_names=['output'],
  135. dynamic_axes={
  136. 'x': {2: 'seq_len'},
  137. 'mask': {2: 'seq_len'},
  138. 'mu': {2: 'seq_len'},
  139. 'cond': {2: 'seq_len'},
  140. 'output': {2: 'seq_len'},
  141. }
  142. )
  143. # print("x, x.shape", x, x.shape)
  144. # print("mask, mask.shape", mask, mask.shape)
  145. # print("mu, mu.shape", mu, mu.shape)
  146. # print("t, t.shape", t, t.shape)
  147. # print("spks, spks.shape", spks, spks.shape)
  148. # print("cond, cond.shape", cond, cond.shape)
  149. return self.estimator.forward(x, mask, mu, t, spks, cond)
  150. def compute_loss(self, x1, mask, mu, spks=None, cond=None):
  151. """Computes diffusion loss
  152. Args:
  153. x1 (torch.Tensor): Target
  154. shape: (batch_size, n_feats, mel_timesteps)
  155. mask (torch.Tensor): target mask
  156. shape: (batch_size, 1, mel_timesteps)
  157. mu (torch.Tensor): output of encoder
  158. shape: (batch_size, n_feats, mel_timesteps)
  159. spks (torch.Tensor, optional): speaker embedding. Defaults to None.
  160. shape: (batch_size, spk_emb_dim)
  161. Returns:
  162. loss: conditional flow matching loss
  163. y: conditional flow
  164. shape: (batch_size, n_feats, mel_timesteps)
  165. """
  166. b, _, t = mu.shape
  167. # random timestep
  168. t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
  169. if self.t_scheduler == 'cosine':
  170. t = 1 - torch.cos(t * 0.5 * torch.pi)
  171. # sample noise p(x_0)
  172. z = torch.randn_like(x1)
  173. y = (1 - (1 - self.sigma_min) * t) * z + t * x1
  174. u = x1 - (1 - self.sigma_min) * z
  175. # during training, we randomly drop condition to trade off mode coverage and sample fidelity
  176. if self.training_cfg_rate > 0:
  177. cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
  178. mu = mu * cfg_mask.view(-1, 1, 1)
  179. spks = spks * cfg_mask.view(-1, 1)
  180. cond = cond * cfg_mask.view(-1, 1, 1)
  181. pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
  182. loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
  183. return loss, y