executor.py

# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
import os

import torch
import torch.distributed as dist

from cosyvoice.utils.train_utils import (update_parameter_and_lr, log_per_step, log_per_save,
                                         batch_forward, batch_backward, save_model, cosyvoice_join)


class Executor:

    def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
        self.gan = gan
        self.ref_model = ref_model
        self.dpo_loss = dpo_loss
        self.step = 0
        self.epoch = 0
        self.rank = int(os.environ.get('RANK', 0))
        self.device = torch.device('cuda:{}'.format(self.rank))
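
    # NOTE: the executor assumes one training process per GPU, started by a
    # torch.distributed launcher (e.g. torchrun) that sets the RANK environment
    # variable; the rank read above selects the matching CUDA device.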

    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
        ''' Train one epoch
        '''

        lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
        logging.info('using accumulate grad, new batch size is {} times'
                     ' larger than before'.format(info_dict['accum_grad']))
        # A context manager to be used in conjunction with an instance of
        # torch.nn.parallel.DistributedDataParallel to be able to train
        # with uneven inputs across participating processes.
        model.train()
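        # When a frozen reference model and a DPO-style loss are supplied,
        # batch_forward (presumably) adds a preference-optimization term computed
        # against the reference model, which is kept in eval mode here.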
        if self.ref_model is not None:
            self.ref_model.eval()
        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
        with model_context():
            for batch_idx, batch_dict in enumerate(train_data_loader):
                info_dict["tag"] = "TRAIN"
                info_dict["step"] = self.step
                info_dict["epoch"] = self.epoch
                info_dict["batch_idx"] = batch_idx
                if cosyvoice_join(group_join, info_dict):
                    break

                # Disable gradient synchronizations across DDP processes.
                # Within this context, gradients will be accumulated on module
                # variables, which will later be synchronized.
                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                    context = model.no_sync
                # Used for single gpu training and DDP gradient synchronization
                # processes.
                else:
                    context = nullcontext
                with context():
                    info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
                    info_dict = batch_backward(model, scaler, info_dict)

                info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
                log_per_step(writer, info_dict)
                # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
                        (batch_idx + 1) % info_dict["accum_grad"] == 0:
                    dist.barrier()
                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
                    model.train()
                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                    self.step += 1
        dist.barrier()
        self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
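
    # The GAN variant below runs two passes per batch: a discriminator turn
    # followed by a generator turn, each updated by its own optimizer/scheduler pair.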

    def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                           writer, info_dict, scaler, group_join):
        ''' Train one epoch
        '''

        lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
        logging.info('using accumulate grad, new batch size is {} times'
                     ' larger than before'.format(info_dict['accum_grad']))
        # A context manager to be used in conjunction with an instance of
        # torch.nn.parallel.DistributedDataParallel to be able to train
        # with uneven inputs across participating processes.
        model.train()
        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
        with model_context():
            for batch_idx, batch_dict in enumerate(train_data_loader):
                info_dict["tag"] = "TRAIN"
                info_dict["step"] = self.step
                info_dict["epoch"] = self.epoch
                info_dict["batch_idx"] = batch_idx
                if cosyvoice_join(group_join, info_dict):
                    break

                # Disable gradient synchronizations across DDP processes.
                # Within this context, gradients will be accumulated on module
                # variables, which will later be synchronized.
                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                    context = model.no_sync
                # Used for single gpu training and DDP gradient synchronization
                # processes.
                else:
                    context = nullcontext
                with context():
                    batch_dict['turn'] = 'discriminator'
                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
                    info_dict = batch_backward(model, scaler, info_dict)
                info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
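                # The discriminator backward pass can also leave gradients on the
                # generator's parameters (shared forward graph), so they are cleared
                # here before the generator turn.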
                optimizer.zero_grad()
                log_per_step(writer, info_dict)
                with context():
                    batch_dict['turn'] = 'generator'
                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
                    info_dict = batch_backward(model, scaler, info_dict)
                info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
                optimizer_d.zero_grad()
                log_per_step(writer, info_dict)
                # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
                        (batch_idx + 1) % info_dict["accum_grad"] == 0:
                    dist.barrier()
                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
                    model.train()
                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                    self.step += 1
        dist.barrier()
        self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)

    @torch.inference_mode()
    def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
        ''' Cross validation on the cv dataset
        '''
        logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
        model.eval()
        total_num_utts, total_loss_dict = 0, {}  # avoid division by 0
        for batch_idx, batch_dict in enumerate(cv_data_loader):
            info_dict["tag"] = "CV"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx

            num_utts = len(batch_dict["utts"])
            total_num_utts += num_utts
            if self.gan is True:
                batch_dict['turn'] = 'generator'
            info_dict = batch_forward(model, batch_dict, None, info_dict)

            for k, v in info_dict['loss_dict'].items():
                if k not in total_loss_dict:
                    total_loss_dict[k] = []
                total_loss_dict[k].append(v.item() * num_utts)
            log_per_step(None, info_dict)
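        # Reduce to an utterance-weighted average: each batch contributed
        # loss * num_utts above, so dividing by total_num_utts gives the mean loss per utterance.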
        for k, v in total_loss_dict.items():
            total_loss_dict[k] = sum(v) / total_num_utts
        info_dict['loss_dict'] = total_loss_dict
        log_per_save(writer, info_dict)
        model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
        save_model(model, model_name, info_dict)
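

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this file). A training script
# would typically build the model, optimizers, data loaders, summary writer,
# AMP scaler and an optional join group itself, then drive the executor
# roughly like this; `max_epoch` and the setup objects below are placeholders:
#
#   executor = Executor(gan=False)
#   for epoch in range(info_dict['max_epoch']):
#       executor.epoch = epoch
#       executor.train_one_epoc(model, optimizer, scheduler, train_data_loader,
#                               cv_data_loader, writer, info_dict, scaler,
#                               group_join)
# ---------------------------------------------------------------------------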