@@ -51,9 +51,10 @@ def init_distributed(args):
     return world_size, local_rank, rank


-def init_dataset_and_dataloader(args, configs):
-    train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True)
-    cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False)
+def init_dataset_and_dataloader(args, configs, gan):
+    data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
+    train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True)
+    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False)

     # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
     train_data_loader = DataLoader(train_dataset,
@@ -108,30 +109,31 @@ def wrap_cuda_model(args, model):
     return model


-def init_optimizer_and_scheduler(args, configs, model):
-    if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
-    elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+def init_optimizer_and_scheduler(args, configs, model, gan):
+    key = 'train_conf_gan' if gan is True else 'train_conf'
+    if configs[key]['optim'] == 'adam':
+        optimizer = optim.Adam(model.parameters(), **configs[key]['optim_conf'])
+    elif configs[key]['optim'] == 'adamw':
+        optimizer = optim.AdamW(model.parameters(), **configs[key]['optim_conf'])
     else:
-        raise ValueError("unknown optimizer: " + configs['train_conf'])
+        raise ValueError("unknown optimizer: " + configs[key]['optim'])

-    if configs['train_conf']['scheduler'] == 'warmuplr':
+    if configs[key]['scheduler'] == 'warmuplr':
         scheduler_type = WarmupLR
-        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+        scheduler = WarmupLR(optimizer, **configs[key]['scheduler_conf'])
+    elif configs[key]['scheduler'] == 'NoamHoldAnnealing':
         scheduler_type = NoamHoldAnnealing
-        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler = NoamHoldAnnealing(optimizer, **configs[key]['scheduler_conf'])
+    elif configs[key]['scheduler'] == 'constantlr':
         scheduler_type = ConstantLR
         scheduler = ConstantLR(optimizer)
     else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
+        raise ValueError("unknown scheduler: " + configs[key]['scheduler'])

     # use deepspeed optimizer for speedup
     if args.train_engine == "deepspeed":
         def scheduler(opt):
-            return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
+            return scheduler_type(opt, **configs[key]['scheduler_conf'])
         model, optimizer, _, scheduler = deepspeed.initialize(
             args=args,
             model=model,
@@ -139,49 +141,28 @@ def init_optimizer_and_scheduler(args, configs, model):
             lr_scheduler=scheduler,
             model_parameters=model.parameters())

-    return model, optimizer, scheduler
-
-
-def init_optimizer_and_scheduler_gan(args, configs, model):
-    if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
-    elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
-    else:
-        raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-    if configs['train_conf']['scheduler'] == 'warmuplr':
-        scheduler_type = WarmupLR
-        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
-        scheduler_type = NoamHoldAnnealing
-        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'constantlr':
-        scheduler_type = ConstantLR
-        scheduler = ConstantLR(optimizer)
-    else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
-
-    if configs['train_conf']['optim_d'] == 'adam':
-        optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
-    elif configs['train_conf']['optim_d'] == 'adamw':
-        optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
-    else:
-        raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-    if configs['train_conf']['scheduler_d'] == 'warmuplr':
-        scheduler_type = WarmupLR
-        scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
-        scheduler_type = NoamHoldAnnealing
-        scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'constantlr':
-        scheduler_type = ConstantLR
-        scheduler_d = ConstantLR(optimizer_d)
-    else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
-
     # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
+    if gan is True:
+        if configs[key]['optim_d'] == 'adam':
+            optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs[key]['optim_conf'])
+        elif configs[key]['optim_d'] == 'adamw':
+            optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs[key]['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs[key]['optim_d'])
+
+        if configs[key]['scheduler_d'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler_d = WarmupLR(optimizer_d, **configs[key]['scheduler_conf'])
+        elif configs[key]['scheduler_d'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler_d = NoamHoldAnnealing(optimizer_d, **configs[key]['scheduler_conf'])
+        elif configs[key]['scheduler_d'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler_d = ConstantLR(optimizer_d)
+        else:
+            raise ValueError("unknown scheduler: " + configs[key]['scheduler_d'])
+    else:
+        optimizer_d, scheduler_d = None, None
     return model, optimizer, scheduler, optimizer_d, scheduler_d


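For reference, a minimal sketch of how a training entry point might drive the two refactored helpers after this change. Only the new signatures of init_dataset_and_dataloader and init_optimizer_and_scheduler come from the diff above; the args.gan flag, the pre-built `model`, and the dataloader return values are assumptions for illustration, not code from the repository.

    # Hypothetical caller (assumptions: an args.gan boolean and an already-constructed `model`).
    gan = args.gan
    # Return values of init_dataset_and_dataloader are assumed, not shown in the diff.
    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
        init_dataset_and_dataloader(args, configs, gan)
    model = wrap_cuda_model(args, model)
    model, optimizer, scheduler, optimizer_d, scheduler_d = \
        init_optimizer_and_scheduler(args, configs, model, gan)
    # When gan is False, optimizer_d and scheduler_d are None, so non-GAN training is unchanged.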