Переглянути джерело

Merge pull request #653 from FunAudioLLM/dev/lyuxiang.lx

resume training
Xiang Lyu 1 рік тому
батько
коміт
d6dbdfbf31
2 змінених файлів з 18 додано та 5 видалено
  1. 15 3
      cosyvoice/bin/train.py
  2. 3 2
      cosyvoice/utils/train_utils.py

+ 15 - 3
cosyvoice/bin/train.py

@@ -118,9 +118,15 @@ def main():
 
     # load checkpoint
     model = configs[args.model]
+    start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
         if os.path.exists(args.checkpoint):
-            model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
+            state_dict = torch.load(args.checkpoint, map_location='cpu')
+            model.load_state_dict(state_dict, strict=False)
+            if 'step' in state_dict:
+                start_step = state_dict['step']
+            if 'epoch' in state_dict:
+                start_epoch = state_dict['epoch']
         else:
             logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
 
@@ -129,19 +135,25 @@ def main():
 
     # Get optimizer & scheduler
     model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+    scheduler.set_step(start_step)
+    if scheduler_d is not None:
+        scheduler_d.set_step(start_step)
 
     # Save init checkpoints
     info_dict = deepcopy(configs['train_conf'])
+    info_dict['step'] = start_step
+    info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)
 
     # Get executor
     executor = Executor(gan=gan)
+    executor.step = start_step
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
-
+    print('start step {} start epoch {}'.format(start_step, start_epoch))
     # Start training loop
-    for epoch in range(info_dict['max_epoch']):
+    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
         train_dataset.set_epoch(epoch)
         dist.barrier()

+ 3 - 2
cosyvoice/utils/train_utils.py

@@ -199,7 +199,7 @@ def save_model(model, model_name, info_dict):
 
     if info_dict["train_engine"] == "torch_ddp":
         if rank == 0:
-            torch.save(model.module.state_dict(), save_model_path)
+            torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
     else:
         with torch.no_grad():
             model.save_checkpoint(save_dir=model_dir,
@@ -284,7 +284,8 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
             # We don't check grad here since that if the gradient
             # has inf/nan values, scaler.step will skip
             # optimizer.step().
-            scaler.step(optimizer)
+            if torch.isfinite(grad_norm):
+                scaler.step(optimizer)
             scaler.update()
         else:
             grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])