1
0

flow_matching.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn.functional as F
  16. from matcha.models.components.flow_matching import BASECFM
  17. class ConditionalCFM(BASECFM):
  18. def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  19. super().__init__(
  20. n_feats=in_channels,
  21. cfm_params=cfm_params,
  22. n_spks=n_spks,
  23. spk_emb_dim=spk_emb_dim,
  24. )
  25. self.t_scheduler = cfm_params.t_scheduler
  26. self.training_cfg_rate = cfm_params.training_cfg_rate
  27. self.inference_cfg_rate = cfm_params.inference_cfg_rate
  28. in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
  29. # Just change the architecture of the estimator here
  30. self.estimator = estimator
  31. self.estimator_context = None
  32. self.estimator_engine = None
  33. @torch.inference_mode()
  34. def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
  35. """Forward diffusion
  36. Args:
  37. mu (torch.Tensor): output of encoder
  38. shape: (batch_size, n_feats, mel_timesteps)
  39. mask (torch.Tensor): output_mask
  40. shape: (batch_size, 1, mel_timesteps)
  41. n_timesteps (int): number of diffusion steps
  42. temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
  43. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  44. shape: (batch_size, spk_emb_dim)
  45. cond: Not used but kept for future purposes
  46. Returns:
  47. sample: generated mel-spectrogram
  48. shape: (batch_size, n_feats, mel_timesteps)
  49. """
  50. z = torch.randn_like(mu) * temperature
  51. t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  52. if self.t_scheduler == 'cosine':
  53. t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  54. return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
  55. def solve_euler(self, x, t_span, mu, mask, spks, cond):
  56. """
  57. Fixed euler solver for ODEs.
  58. Args:
  59. x (torch.Tensor): random noise
  60. t_span (torch.Tensor): n_timesteps interpolated
  61. shape: (n_timesteps + 1,)
  62. mu (torch.Tensor): output of encoder
  63. shape: (batch_size, n_feats, mel_timesteps)
  64. mask (torch.Tensor): output_mask
  65. shape: (batch_size, 1, mel_timesteps)
  66. spks (torch.Tensor, optional): speaker ids. Defaults to None.
  67. shape: (batch_size, spk_emb_dim)
  68. cond: Not used but kept for future purposes
  69. """
  70. t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
  71. t = t.unsqueeze(dim=0)
  72. # I am storing this because I can later plot it by putting a debugger here and saving it to a file
  73. # Or in future might add like a return_all_steps flag
  74. sol = []
  75. for step in range(1, len(t_span)):
  76. dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
  77. # Classifier-Free Guidance inference introduced in VoiceBox
  78. if self.inference_cfg_rate > 0:
  79. cfg_dphi_dt = self.forward_estimator(
  80. x, mask,
  81. torch.zeros_like(mu), t,
  82. torch.zeros_like(spks) if spks is not None else None,
  83. torch.zeros_like(cond)
  84. )
  85. dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
  86. self.inference_cfg_rate * cfg_dphi_dt)
  87. x = x + dt * dphi_dt
  88. t = t + dt
  89. sol.append(x)
  90. if step < len(t_span) - 1:
  91. dt = t_span[step + 1] - t
  92. return sol[-1]
  93. def forward_estimator(self, x, mask, mu, t, spks, cond):
  94. if self.estimator_context is not None:
  95. assert self.training is False, 'tensorrt cannot be used in training'
  96. bs = x.shape[0]
  97. hs = x.shape[1]
  98. seq_len = x.shape[2]
  99. # assert bs == 1 and hs == 80
  100. ret = torch.empty_like(x)
  101. self.estimator_context.set_input_shape("x", x.shape)
  102. self.estimator_context.set_input_shape("mask", mask.shape)
  103. self.estimator_context.set_input_shape("mu", mu.shape)
  104. self.estimator_context.set_input_shape("t", t.shape)
  105. self.estimator_context.set_input_shape("spks", spks.shape)
  106. self.estimator_context.set_input_shape("cond", cond.shape)
  107. bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
  108. for i in range(len(bindings)):
  109. self.estimator_context.set_tensor_address(self.estimator_engine.get_tensor_name(i), bindings[i])
  110. handle = torch.cuda.current_stream().cuda_stream
  111. self.estimator_context.execute_async_v3(stream_handle=handle)
  112. return ret
  113. else:
  114. return self.estimator.forward(x, mask, mu, t, spks, cond)
  115. def compute_loss(self, x1, mask, mu, spks=None, cond=None):
  116. """Computes diffusion loss
  117. Args:
  118. x1 (torch.Tensor): Target
  119. shape: (batch_size, n_feats, mel_timesteps)
  120. mask (torch.Tensor): target mask
  121. shape: (batch_size, 1, mel_timesteps)
  122. mu (torch.Tensor): output of encoder
  123. shape: (batch_size, n_feats, mel_timesteps)
  124. spks (torch.Tensor, optional): speaker embedding. Defaults to None.
  125. shape: (batch_size, spk_emb_dim)
  126. Returns:
  127. loss: conditional flow matching loss
  128. y: conditional flow
  129. shape: (batch_size, n_feats, mel_timesteps)
  130. """
  131. b, _, t = mu.shape
  132. # random timestep
  133. t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
  134. if self.t_scheduler == 'cosine':
  135. t = 1 - torch.cos(t * 0.5 * torch.pi)
  136. # sample noise p(x_0)
  137. z = torch.randn_like(x1)
  138. y = (1 - (1 - self.sigma_min) * t) * z + t * x1
  139. u = x1 - (1 - self.sigma_min) * z
  140. # during training, we randomly drop condition to trade off mode coverage and sample fidelity
  141. if self.training_cfg_rate > 0:
  142. cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
  143. mu = mu * cfg_mask.view(-1, 1, 1)
  144. spks = spks * cfg_mask.view(-1, 1)
  145. cond = cond * cfg_mask.view(-1, 1, 1)
  146. pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
  147. loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
  148. return loss, y