
use bf16 for amp

lyuxiang.lx 1 month ago
Commit 11515d0d5a
2 changed files with 3 additions and 3 deletions
  1. cosyvoice/utils/executor.py (+1 -1)
  2. cosyvoice/utils/train_utils.py (+2 -2)

+ 1 - 1
cosyvoice/utils/executor.py

@@ -166,7 +166,7 @@ class Executor:
             for k, v in info_dict['loss_dict'].items():
                 if k not in total_loss_dict:
                     total_loss_dict[k] = []
-                total_loss_dict[k].append(v.item() * num_utts)
+                total_loss_dict[k].append(v.mean().item() * num_utts)
             log_per_step(None, info_dict)
         for k, v in total_loss_dict.items():
             total_loss_dict[k] = sum(v) / total_num_utts
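
The executor.py change guards against loss values that are not scalar tensors. Calling .item() on a multi-element tensor raises a RuntimeError, while .mean() first reduces it to a scalar (and is a no-op for tensors that already hold a single value, so scalar losses behave as before). A minimal sketch of the failure mode, with hypothetical values:

import torch

num_utts = 4
v = torch.tensor([0.5, 0.7])          # e.g. an unreduced per-sample loss

# old code: v.item() raises RuntimeError for multi-element tensors
# ("a Tensor with 2 elements cannot be converted to Scalar")
new_val = v.mean().item() * num_utts  # reduce to a scalar first, then accumulate
print(new_val)                        # ≈ 2.4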

+ 2 - 2
cosyvoice/utils/train_utils.py

@@ -71,7 +71,7 @@ def init_dataset_and_dataloader(args, configs, gan, dpo):
 
 def check_modify_and_save_config(args, configs):
     if args.train_engine == "torch_ddp":
-        configs['train_conf']["dtype"] = 'fp32'
+        configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
     else:
         with open(args.deepspeed_config, 'r') as fin:
             ds_configs = json.load(fin)
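
With this change, torch_ddp runs no longer hard-code fp32: passing --use_amp selects bf16. Further down in batch_forward (second hunk below), the dtype string is presumably mapped to a torch dtype along these lines; the mapping itself is not part of this diff, so treat it as a hypothetical reconstruction:

import torch

info_dict = {'dtype': 'bf16'}  # as set by check_modify_and_save_config above

# sketch of the mapping that precedes the "dtype = torch.float32"
# fallback visible in the hunk below
if info_dict['dtype'] == 'fp16':
    dtype = torch.float16
elif info_dict['dtype'] == 'bf16':
    dtype = torch.bfloat16
else:
    dtype = torch.float32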
@@ -247,7 +247,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None
         dtype = torch.float32
 
     if info_dict['train_engine'] == 'torch_ddp':
-        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
+        autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
     else:
         autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
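
The autocast fix is the substantive one: torch.cuda.amp.autocast defaults to torch.float16 on CUDA, so before this commit the torch_ddp branch ran AMP in fp16 even when dtype had been computed as bf16. Passing dtype explicitly makes the branch honor it. A minimal sketch (requires a CUDA device with bf16 support):

import torch

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(2, 8, device='cuda')

# before: no dtype argument, so autocast falls back to its CUDA default (fp16)
with torch.cuda.amp.autocast(enabled=True):
    assert model(x).dtype == torch.float16

# after: bf16 is requested explicitly; bf16 keeps fp32's exponent range,
# which is generally more forgiving than fp16 without GradScaler tuning
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    assert model(x).dtype == torch.bfloat16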