flow_matching.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import threading

import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM


class EstimatorWrapper:
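    """Thread-safe pool of TensorRT execution contexts created from one engine.

    Contexts are handed out through a queue, so concurrent inference calls
    never share an execution context.
    """
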
    def __init__(self, estimator_engine, estimator_count=2):
        self.estimators = queue.Queue()
        self.estimator_engine = estimator_engine
        for _ in range(estimator_count):
            estimator = estimator_engine.create_execution_context()
            if estimator is not None:
                self.estimators.put(estimator)
        if self.estimators.empty():
            raise Exception("No available estimator")

    def acquire_estimator(self):
        return self.estimators.get(), self.estimator_engine

    def release_estimator(self, estimator):
        self.estimators.put(estimator)
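
# A minimal usage sketch for EstimatorWrapper (the file name "estimator.plan"
# is a hypothetical example; any serialized TensorRT engine whose bindings
# match forward_estimator below would do):
#
#   import tensorrt as trt
#   trt_logger = trt.Logger(trt.Logger.WARNING)
#   with open("estimator.plan", "rb") as f:
#       engine = trt.Runtime(trt_logger).deserialize_cuda_engine(f.read())
#   wrapper = EstimatorWrapper(engine, estimator_count=2)
#   context, engine = wrapper.acquire_estimator()
#   try:
#       ...  # set shapes / tensor addresses, then context.execute_async_v3(...)
#   finally:
#       wrapper.release_estimator(context)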


class ConditionalCFM(BASECFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        self.estimator = estimator
        self.lock = threading.Lock()

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): conditioning features fed to the estimator
                shape: (batch_size, n_feats, mel_timesteps)

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature  # randn_like already matches mu's device and dtype
        cache_size = flow_cache.shape[2]
        # Reuse the cached noise and mu for the prompt and the chunk-overlap region,
        # so consecutive streaming chunks are generated from consistent inputs.
        if cache_size != 0:
            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)

        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
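
    # A minimal streaming sketch (hypothetical names `cfm` and `chunks`; assumes
    # each chunk's mu begins with the prompt and that the trailing 34 frames of
    # one chunk overlap the head of the next, matching the cache above):
    #
    #   flow_cache = torch.zeros(1, 80, 0, 2)
    #   for mu_chunk, mask_chunk in chunks:
    #       mel, flow_cache = cfm(mu_chunk, mask_chunk, n_timesteps=10,
    #                             prompt_len=prompt_len, flow_cache=flow_cache)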

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.

        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): conditioning features fed to the estimator
                shape: (batch_size, n_feats, mel_timesteps)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # Intermediate steps are collected so they can be inspected (e.g. dumped to
        # a file from a debugger); a `return_all_steps` flag may be added later.
        sol = []

        # Pre-allocate the doubled batch and fill it in place. Do not build it with
        # torch.cat: that may change the memory format and make the TensorRT engine
        # return wrong results.
        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)

        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox
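            # Row 0 of each doubled buffer carries the conditional inputs
            # (mu/spks/cond filled in); row 1 stays zero for the unconditional
            # pass. The two halves of the estimator output are combined below as
            #   v = (1 + w) * v_cond - w * v_uncond,  with w = inference_cfg_rate.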
            x_in[:] = x
            mask_in[:] = mask
            mu_in[0] = mu
            t_in[:] = t.unsqueeze(0)
            spks_in[0] = spks
            cond_in[0] = cond
            dphi_dt = self.forward_estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return sol[-1].float()

    def forward_estimator(self, x, mask, mu, t, spks, cond):
        if isinstance(self.estimator, torch.nn.Module):
            return self.estimator.forward(x, mask, mu, t, spks, cond)
        elif isinstance(self.estimator, EstimatorWrapper):
            estimator, engine = self.estimator.acquire_estimator()
            estimator.set_input_shape('x', (2, 80, x.size(2)))
            estimator.set_input_shape('mask', (2, 1, x.size(2)))
            estimator.set_input_shape('mu', (2, 80, x.size(2)))
            estimator.set_input_shape('t', (2,))
            estimator.set_input_shape('spks', (2, 80))
            estimator.set_input_shape('cond', (2, 80, x.size(2)))
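            # Six input bindings plus one output binding; the last address reuses
            # x's buffer, so the engine writes dphi/dt into x in place and x is
            # returned below.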
            data_ptrs = [x.contiguous().data_ptr(),
                         mask.contiguous().data_ptr(),
                         mu.contiguous().data_ptr(),
                         t.contiguous().data_ptr(),
                         spks.contiguous().data_ptr(),
                         cond.contiguous().data_ptr(),
                         x.data_ptr()]
            for idx, data_ptr in enumerate(data_ptrs):
                estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr)
            # run trt engine
            estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream)
            torch.cuda.current_stream().synchronize()
            self.estimator.release_estimator(estimator)
            return x
        else:
            with self.lock:
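                # A single execution context is shared on this path, so the lock
                # serializes concurrent callers.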
                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
                self.estimator.set_input_shape('t', (2,))
                self.estimator.set_input_shape('spks', (2, 80))
                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
                # run trt engine
                self.estimator.execute_v2([x.contiguous().data_ptr(),
                                           mask.contiguous().data_ptr(),
                                           mu.contiguous().data_ptr(),
                                           t.contiguous().data_ptr(),
                                           spks.contiguous().data_ptr(),
                                           cond.contiguous().data_ptr(),
                                           x.data_ptr()])
            return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)

        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z
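        # Optimal-transport conditional flow matching: y is a point on the path
        #   x_t = (1 - (1 - sigma_min) * t) * x_0 + t * x_1
        # and u is the target vector field
        #   u_t = x_1 - (1 - sigma_min) * x_0.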
        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            spks = spks * cfg_mask.view(-1, 1)
            cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y


class CausalConditionalCFM(ConditionalCFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
        self.rand_noise = torch.randn([1, 80, 50 * 300])

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor, optional): conditioning features fed to the estimator
                shape: (batch_size, n_feats, mel_timesteps)

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        # The noise is sliced from a fixed buffer, so a longer mu simply extends
        # the same noise prefix and no flow cache is needed.
        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
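

# A minimal inference sketch (hypothetical `cfm_params` and `estimator`;
# `in_channels=240` is an illustrative value, while the 80 mel bins and the
# (batch, 80) speaker shape are fixed by solve_euler above):
#
#   cfm = CausalConditionalCFM(in_channels=240, cfm_params=cfm_params,
#                              estimator=estimator)
#   mu = torch.randn(1, 80, 200)    # encoder output
#   mask = torch.ones(1, 1, 200)
#   spks = torch.randn(1, 80)
#   cond = torch.zeros(1, 80, 200)
#   mel, _ = cfm(mu, mask, n_timesteps=10, spks=spks, cond=cond)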