|
@@ -71,7 +71,7 @@ def init_dataset_and_dataloader(args, configs, gan, dpo):
|
|
|
|
|
|
def check_modify_and_save_config(args, configs):
|
|
|
if args.train_engine == "torch_ddp":
|
|
|
- configs['train_conf']["dtype"] = 'fp32'
|
|
|
+ configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
|
|
|
else:
|
|
|
with open(args.deepspeed_config, 'r') as fin:
|
|
|
ds_configs = json.load(fin)
|
|
@@ -247,7 +247,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None
|
|
|
dtype = torch.float32
|
|
|
|
|
|
if info_dict['train_engine'] == 'torch_ddp':
|
|
|
- autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
|
|
|
+ autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
|
|
|
else:
|
|
|
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
|
|
|
|