train_utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
  2. # 2023 Horizon Inc. (authors: Xingchen Song)
  3. # 2024 Alibaba Inc (authors: Xiang Lyu)
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from contextlib import nullcontext
  17. import logging
  18. import os
  19. import torch
  20. import json
  21. import re
  22. import datetime
  23. import yaml
  24. import deepspeed
  25. import torch.optim as optim
  26. import torch.distributed as dist
  27. from torch.utils.tensorboard import SummaryWriter
  28. from torch.utils.data import DataLoader
  29. from torch.nn.utils import clip_grad_norm_
  30. from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
  31. from cosyvoice.dataset.dataset import Dataset
  32. from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
  33. def init_distributed(args):
  34. world_size = int(os.environ.get('WORLD_SIZE', 1))
  35. local_rank = int(os.environ.get('LOCAL_RANK', 0))
  36. rank = int(os.environ.get('RANK', 0))
  37. logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
  38. ', rank {}, world_size {}'.format(rank, world_size))
  39. if args.train_engine == 'torch_ddp':
  40. torch.cuda.set_device(local_rank)
  41. dist.init_process_group(args.dist_backend)
  42. else:
  43. deepspeed.init_distributed(dist_backend=args.dist_backend)
  44. return world_size, local_rank, rank
  45. def init_dataset_and_dataloader(args, configs):
  46. train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True)
  47. cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False)
  48. # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
  49. train_data_loader = DataLoader(train_dataset,
  50. batch_size=None,
  51. pin_memory=args.pin_memory,
  52. num_workers=args.num_workers,
  53. prefetch_factor=args.prefetch)
  54. cv_data_loader = DataLoader(cv_dataset,
  55. batch_size=None,
  56. pin_memory=args.pin_memory,
  57. num_workers=args.num_workers,
  58. prefetch_factor=args.prefetch)
  59. return train_dataset, cv_dataset, train_data_loader, cv_data_loader
  60. def check_modify_and_save_config(args, configs):
  61. if args.train_engine == "torch_ddp":
  62. configs['train_conf']["dtype"] = 'fp32'
  63. else:
  64. with open(args.deepspeed_config, 'r') as fin:
  65. ds_configs = json.load(fin)
  66. if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
  67. configs['train_conf']["dtype"] = "fp16"
  68. elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
  69. configs['train_conf']["dtype"] = "bf16"
  70. else:
  71. configs['train_conf']["dtype"] = "fp32"
  72. assert ds_configs["train_micro_batch_size_per_gpu"] == 1
  73. # if use deepspeed, override ddp config
  74. configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
  75. configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
  76. configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
  77. configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
  78. configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
  79. return configs
  80. def wrap_cuda_model(args, model):
  81. local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
  82. world_size = int(os.environ.get('WORLD_SIZE', 1))
  83. if args.train_engine == "torch_ddp": # native pytorch ddp
  84. assert (torch.cuda.is_available())
  85. model.cuda()
  86. model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
  87. else:
  88. if int(os.environ.get('RANK', 0)) == 0:
  89. logging.info("Estimating model states memory needs (zero2)...")
  90. estimate_zero2_model_states_mem_needs_all_live(
  91. model,
  92. num_gpus_per_node=local_world_size,
  93. num_nodes=world_size // local_world_size)
  94. return model
  95. def init_optimizer_and_scheduler(args, configs, model):
  96. if configs['train_conf']['optim'] == 'adam':
  97. optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
  98. elif configs['train_conf']['optim'] == 'adamw':
  99. optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
  100. else:
  101. raise ValueError("unknown optimizer: " + configs['train_conf'])
  102. if configs['train_conf']['scheduler'] == 'warmuplr':
  103. scheduler_type = WarmupLR
  104. scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
  105. elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
  106. scheduler_type = NoamHoldAnnealing
  107. scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
  108. elif configs['train_conf']['scheduler'] == 'constantlr':
  109. scheduler_type = ConstantLR
  110. scheduler = ConstantLR(optimizer)
  111. else:
  112. raise ValueError("unknown scheduler: " + configs['train_conf'])
  113. # use deepspeed optimizer for speedup
  114. if args.train_engine == "deepspeed":
  115. def scheduler(opt):
  116. return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
  117. model, optimizer, _, scheduler = deepspeed.initialize(
  118. args=args,
  119. model=model,
  120. optimizer=None,
  121. lr_scheduler=scheduler,
  122. model_parameters=model.parameters())
  123. return model, optimizer, scheduler
  124. def init_optimizer_and_scheduler_gan(args, configs, model):
  125. if configs['train_conf']['optim'] == 'adam':
  126. optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
  127. elif configs['train_conf']['optim'] == 'adamw':
  128. optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
  129. else:
  130. raise ValueError("unknown optimizer: " + configs['train_conf'])
  131. if configs['train_conf']['scheduler'] == 'warmuplr':
  132. scheduler_type = WarmupLR
  133. scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
  134. elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
  135. scheduler_type = NoamHoldAnnealing
  136. scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
  137. elif configs['train_conf']['scheduler'] == 'constantlr':
  138. scheduler_type = ConstantLR
  139. scheduler = ConstantLR(optimizer)
  140. else:
  141. raise ValueError("unknown scheduler: " + configs['train_conf'])
  142. if configs['train_conf']['optim_d'] == 'adam':
  143. optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
  144. elif configs['train_conf']['optim_d'] == 'adamw':
  145. optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
  146. else:
  147. raise ValueError("unknown optimizer: " + configs['train_conf'])
  148. if configs['train_conf']['scheduler_d'] == 'warmuplr':
  149. scheduler_type = WarmupLR
  150. scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
  151. elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
  152. scheduler_type = NoamHoldAnnealing
  153. scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
  154. elif configs['train_conf']['scheduler'] == 'constantlr':
  155. scheduler_type = ConstantLR
  156. scheduler_d = ConstantLR(optimizer_d)
  157. else:
  158. raise ValueError("unknown scheduler: " + configs['train_conf'])
  159. # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
  160. return model, optimizer, scheduler, optimizer_d, scheduler_d
  161. def init_summarywriter(args):
  162. writer = None
  163. if int(os.environ.get('RANK', 0)) == 0:
  164. os.makedirs(args.model_dir, exist_ok=True)
  165. writer = SummaryWriter(args.tensorboard_dir)
  166. return writer
  167. def save_model(model, model_name, info_dict):
  168. rank = int(os.environ.get('RANK', 0))
  169. model_dir = info_dict["model_dir"]
  170. save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
  171. if info_dict["train_engine"] == "torch_ddp":
  172. if rank == 0:
  173. torch.save(model.module.state_dict(), save_model_path)
  174. else:
  175. with torch.no_grad():
  176. model.save_checkpoint(save_dir=model_dir,
  177. tag=model_name,
  178. client_state=info_dict)
  179. if rank == 0:
  180. info_path = re.sub('.pt$', '.yaml', save_model_path)
  181. info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
  182. with open(info_path, 'w') as fout:
  183. data = yaml.dump(info_dict)
  184. fout.write(data)
  185. logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
  186. def cosyvoice_join(group_join, info_dict):
  187. world_size = int(os.environ.get('WORLD_SIZE', 1))
  188. local_rank = int(os.environ.get('LOCAL_RANK', 0))
  189. rank = int(os.environ.get('RANK', 0))
  190. if info_dict["batch_idx"] != 0:
  191. # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
  192. try:
  193. dist.monitored_barrier(group=group_join,
  194. timeout=group_join.options._timeout)
  195. return False
  196. except RuntimeError as e:
  197. logging.info("Detected uneven workload distribution: {}\n".format(e) +
  198. "Break current worker to manually join all workers, " +
  199. "world_size {}, current rank {}, current local_rank {}\n".
  200. format(world_size, rank, local_rank))
  201. return True
  202. else:
  203. return False
  204. def batch_forward(model, batch, info_dict):
  205. device = int(os.environ.get('LOCAL_RANK', 0))
  206. dtype = info_dict["dtype"]
  207. if dtype == "fp16":
  208. dtype = torch.float16
  209. elif dtype == "bf16":
  210. dtype = torch.bfloat16
  211. else: # fp32
  212. dtype = torch.float32
  213. if info_dict['train_engine'] == 'torch_ddp':
  214. autocast = nullcontext()
  215. else:
  216. autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
  217. with autocast:
  218. info_dict['loss_dict'] = model(batch, device)
  219. return info_dict
  220. def batch_backward(model, info_dict):
  221. if info_dict["train_engine"] == "deepspeed":
  222. scaled_loss = model.backward(info_dict['loss_dict']['loss'])
  223. else:
  224. scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
  225. scaled_loss.backward()
  226. info_dict['loss_dict']['loss'] = scaled_loss
  227. return info_dict
  228. def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
  229. grad_norm = 0.0
  230. if info_dict['train_engine'] == "deepspeed":
  231. info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
  232. model.step()
  233. grad_norm = model.get_global_grad_norm()
  234. elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
  235. grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
  236. if torch.isfinite(grad_norm):
  237. optimizer.step()
  238. optimizer.zero_grad()
  239. scheduler.step()
  240. info_dict["lr"] = optimizer.param_groups[0]['lr']
  241. info_dict["grad_norm"] = grad_norm
  242. return info_dict
  243. def log_per_step(writer, info_dict):
  244. tag = info_dict["tag"]
  245. epoch = info_dict.get('epoch', 0)
  246. step = info_dict["step"]
  247. batch_idx = info_dict["batch_idx"]
  248. loss_dict = info_dict['loss_dict']
  249. rank = int(os.environ.get('RANK', 0))
  250. # only rank 0 write to tensorboard to avoid multi-process write
  251. if writer is not None:
  252. if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
  253. (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
  254. for k in ['epoch', 'lr', 'grad_norm']:
  255. writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
  256. for k, v in loss_dict.items():
  257. writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
  258. # TRAIN & CV, Shell log (stdout)
  259. if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
  260. log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
  261. for name, value in loss_dict.items():
  262. log_str += '{} {:.6f} '.format(name, value)
  263. if tag == "TRAIN":
  264. log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
  265. info_dict["lr"], info_dict['grad_norm'])
  266. log_str += ' rank {}'.format(rank)
  267. logging.debug(log_str)
  268. def log_per_save(writer, info_dict):
  269. tag = info_dict["tag"]
  270. epoch = info_dict["epoch"]
  271. step = info_dict["step"]
  272. loss_dict = info_dict["loss_dict"]
  273. lr = info_dict['lr']
  274. rank = int(os.environ.get('RANK', 0))
  275. logging.info(
  276. 'Epoch {} Step {} CV info lr {} {} rank {}'.format(
  277. epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
  278. if writer is not None:
  279. for k in ['epoch', 'lr']:
  280. writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
  281. for k, v in loss_dict.items():
  282. writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)