flow_matching.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import threading
  16. import torch
  17. import torch.nn.functional as F
  18. from matcha.models.components.flow_matching import BASECFM
  19. class ConditionalCFM(BASECFM):
  20. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  21. super().__init__(
  22. n_feats=in_channels,
  23. cfm_params=cfm_params,
  24. n_spks=n_spks,
  25. spk_emb_dim=spk_emb_dim,
  26. )
  27. self.t_scheduler = cfm_params.t_scheduler
  28. self.training_cfg_rate = cfm_params.training_cfg_rate
  29. self.inference_cfg_rate = cfm_params.inference_cfg_rate
  30. in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
  31. # Just change the architecture of the estimator here
  32. self.estimator = estimator
  33. self.lock = threading.Lock()
  34. @torch.inference_mode()
  35. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
  36. """Forward diffusion
  37. Args:
  38. mu (torch.Tensor): output of encoder
  39. shape: (batch_size, n_feats, mel_timesteps)
  40. mask (torch.Tensor): output_mask
  41. shape: (batch_size, 1, mel_timesteps)
  42. n_timesteps (int): number of diffusion steps
  43. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  44. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  45. shape: (batch_size, spk_emb_dim)
  46. cond: Not used but kept for future purposes
  47. Returns:
  48. sample: generated mel-spectrogram
  49. shape: (batch_size, n_feats, mel_timesteps)
  50. """
  51. z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
  52. cache_size = cache.shape[2]
  53. # fix prompt and overlap part mu and z
  54. if cache_size != 0:
  55. z[:, :, :cache_size] = cache[:, :, :, 0]
  56. mu[:, :, :cache_size] = cache[:, :, :, 1]
  57. z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
  58. mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
  59. cache = torch.stack([z_cache, mu_cache], dim=-1)
  60. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  61. if self.t_scheduler == 'cosine':
  62. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  63. return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
  64. def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
  65. """
  66. Fixed euler solver for ODEs.
  67. Args:
  68. x (torch.Tensor): random noise
  69. t_span (torch.Tensor): n_timesteps interpolated
  70. shape: (n_timesteps + 1,)
  71. mu (torch.Tensor): output of encoder
  72. shape: (batch_size, n_feats, mel_timesteps)
  73. mask (torch.Tensor): output_mask
  74. shape: (batch_size, 1, mel_timesteps)
  75. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  76. shape: (batch_size, spk_emb_dim)
  77. cond: Not used but kept for future purposes
  78. """
  79. t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
  80. t = t.unsqueeze(dim=0)
  81. # I am storing this because I can later plot it by putting a debugger here and saving it to a file
  82. # Or in future might add like a return_all_steps flag
  83. sol = []
  84. # Do not use concat, it may cause memory format changed and trt infer with wrong results!
  85. x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  86. mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
  87. mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  88. t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
  89. spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
  90. cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
  91. for step in range(1, len(t_span)):
  92. # Classifier-Free Guidance inference introduced in VoiceBox
  93. x_in[:] = x
  94. mask_in[:] = mask
  95. mu_in[0] = mu
  96. t_in[:] = t.unsqueeze(0)
  97. spks_in[0] = spks
  98. cond_in[0] = cond
  99. dphi_dt = self.forward_estimator(
  100. x_in, mask_in,
  101. mu_in, t_in,
  102. spks_in,
  103. cond_in,
  104. streaming
  105. )
  106. dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
  107. dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
  108. x = x + dt * dphi_dt
  109. t = t + dt
  110. sol.append(x)
  111. if step < len(t_span) - 1:
  112. dt = t_span[step + 1] - t
  113. return sol[-1].float()
  114. def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
  115. if isinstance(self.estimator, torch.nn.Module):
  116. return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
  117. else:
  118. estimator, trt_engine = self.estimator.acquire_estimator()
  119. estimator.set_input_shape('x', (2, 80, x.size(2)))
  120. estimator.set_input_shape('mask', (2, 1, x.size(2)))
  121. estimator.set_input_shape('mu', (2, 80, x.size(2)))
  122. estimator.set_input_shape('t', (2,))
  123. estimator.set_input_shape('spks', (2, 80))
  124. estimator.set_input_shape('cond', (2, 80, x.size(2)))
  125. data_ptrs = [x.contiguous().data_ptr(),
  126. mask.contiguous().data_ptr(),
  127. mu.contiguous().data_ptr(),
  128. t.contiguous().data_ptr(),
  129. spks.contiguous().data_ptr(),
  130. cond.contiguous().data_ptr(),
  131. x.data_ptr()]
  132. for i, j in enumerate(data_ptrs):
  133. estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
  134. # run trt engine
  135. assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
  136. torch.cuda.current_stream().synchronize()
  137. self.estimator.release_estimator(estimator)
  138. return x
  139. def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
  140. """Computes diffusion loss
  141. Args:
  142. x1 (torch.Tensor): Target
  143. shape: (batch_size, n_feats, mel_timesteps)
  144. mask (torch.Tensor): target mask
  145. shape: (batch_size, 1, mel_timesteps)
  146. mu (torch.Tensor): output of encoder
  147. shape: (batch_size, n_feats, mel_timesteps)
  148. spks (torch.Tensor, optional): speaker embedding. Defaults to None.
  149. shape: (batch_size, spk_emb_dim)
  150. Returns:
  151. loss: conditional flow matching loss
  152. y: conditional flow
  153. shape: (batch_size, n_feats, mel_timesteps)
  154. """
  155. b, _, t = mu.shape
  156. # random timestep
  157. t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
  158. if self.t_scheduler == 'cosine':
  159. t = 1 - torch.cos(t * 0.5 * torch.pi)
  160. # sample noise p(x_0)
  161. z = torch.randn_like(x1)
  162. y = (1 - (1 - self.sigma_min) * t) * z + t * x1
  163. u = x1 - (1 - self.sigma_min) * z
  164. # during training, we randomly drop condition to trade off mode coverage and sample fidelity
  165. if self.training_cfg_rate > 0:
  166. cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
  167. mu = mu * cfg_mask.view(-1, 1, 1)
  168. spks = spks * cfg_mask.view(-1, 1)
  169. cond = cond * cfg_mask.view(-1, 1, 1)
  170. pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
  171. loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
  172. return loss, y
  173. class CausalConditionalCFM(ConditionalCFM):
  174. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  175. super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
  176. self.rand_noise = torch.randn([1, 80, 50 * 300])
  177. @torch.inference_mode()
  178. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
  179. """Forward diffusion
  180. Args:
  181. mu (torch.Tensor): output of encoder
  182. shape: (batch_size, n_feats, mel_timesteps)
  183. mask (torch.Tensor): output_mask
  184. shape: (batch_size, 1, mel_timesteps)
  185. n_timesteps (int): number of diffusion steps
  186. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  187. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  188. shape: (batch_size, spk_emb_dim)
  189. cond: Not used but kept for future purposes
  190. Returns:
  191. sample: generated mel-spectrogram
  192. shape: (batch_size, n_feats, mel_timesteps)
  193. """
  194. z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
  195. # fix prompt and overlap part mu and z
  196. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  197. if self.t_scheduler == 'cosine':
  198. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  199. return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None