
Merge pull request #580 from FunAudioLLM/dev/lyuxiang.lx

fix hifigan init bug
Xiang Lyu 1 year ago
parent
commit
d8f00f4793

+ 9 - 2
cosyvoice/bin/train.py

@@ -68,6 +68,10 @@ def get_args():
                         action='store_true',
                         default=False,
                         help='Use pinned memory buffers used for reading')
+    parser.add_argument('--use_amp',
+                        action='store_true',
+                        default=False,
+                        help='Use automatic mixed precision training')
     parser.add_argument('--deepspeed.save_states',
                         dest='save_states',
                         default='model_only',
@@ -133,6 +137,9 @@ def main():
     # Get executor
     executor = Executor(gan=gan)
 
+    # Init scaler, used for pytorch amp mixed precision training
+    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
+
     # Start training loop
     for epoch in range(info_dict['max_epoch']):
         executor.epoch = epoch
@@ -141,9 +148,9 @@ def main():
         group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
         if gan is True:
             executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
-                                        writer, info_dict, group_join)
+                                        writer, info_dict, scaler, group_join)
         else:
-            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
         dist.destroy_process_group(group_join)
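
Note: this hunk adds a --use_amp flag and builds a torch.cuda.amp.GradScaler only when the flag is set, so downstream code can treat "scaler is None" as AMP disabled. Below is a minimal, self-contained sketch of that pattern in isolation; the toy model, optimizer, and data are hypothetical stand-ins, not the repo's.

import torch

model = torch.nn.Linear(80, 80).cuda()                      # hypothetical stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
use_amp = True

# Same construction as in train.py: scaler stays None when AMP is off.
scaler = torch.cuda.amp.GradScaler() if use_amp else None

for _ in range(10):
    x, y = torch.randn(4, 80).cuda(), torch.randn(4, 80).cuda()
    optimizer.zero_grad()
    # Forward runs under autocast only when a scaler exists.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        loss = torch.nn.functional.mse_loss(model(x), y)
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)       # skips the step if grads contain inf/nan
        scaler.update()
    else:
        loss.backward()
        optimizer.step()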
 
 

+ 12 - 12
cosyvoice/utils/executor.py

@@ -32,7 +32,7 @@ class Executor:
         self.rank = int(os.environ.get('RANK', 0))
         self.device = torch.device('cuda:{}'.format(self.rank))
 
-    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join):
+    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
         ''' Train one epoch
         '''
 
@@ -65,10 +65,10 @@ class Executor:
                     context = nullcontext
 
                 with context():
-                    info_dict = batch_forward(model, batch_dict, info_dict)
-                    info_dict = batch_backward(model, info_dict)
+                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                    info_dict = batch_backward(model, scaler, info_dict)
 
-                info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
                 log_per_step(writer, info_dict)
                 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
                 if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
@@ -82,7 +82,7 @@ class Executor:
         self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
 
     def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
-                           writer, info_dict, group_join):
+                           writer, info_dict, scaler, group_join):
         ''' Train one epoch
         '''
 
@@ -116,16 +116,16 @@ class Executor:
 
                 with context():
                     batch_dict['turn'] = 'discriminator'
-                    info_dict = batch_forward(model, batch_dict, info_dict)
-                    info_dict = batch_backward(model, info_dict)
-                info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, info_dict)
+                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                    info_dict = batch_backward(model, scaler, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
                 optimizer.zero_grad()
                 log_per_step(writer, info_dict)
                 with context():
                     batch_dict['turn'] = 'generator'
-                    info_dict = batch_forward(model, batch_dict, info_dict)
-                    info_dict = batch_backward(model, info_dict)
-                info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
+                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                    info_dict = batch_backward(model, scaler, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
                 optimizer_d.zero_grad()
                 log_per_step(writer, info_dict)
                 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
@@ -157,7 +157,7 @@ class Executor:
 
             if self.gan is True:
                 batch_dict['turn'] = 'generator'
-            info_dict = batch_forward(model, batch_dict, info_dict)
+            info_dict = batch_forward(model, batch_dict, None, info_dict)
 
             for k, v in info_dict['loss_dict'].items():
                 if k not in total_loss_dict:
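
Note: the executor now threads the scaler through batch_forward, batch_backward, and update_parameter_and_lr in both the plain and GAN loops, and passes None during CV so evaluation stays in full precision. In the GAN loop the same scaler is reused for the discriminator and generator turns. A toy sketch of that two-turn pattern follows; the linear "generator"/"discriminator" and the non-saturating losses are placeholders, not CosyVoice's HiFiGAN modules.

import torch
import torch.nn.functional as F

gen = torch.nn.Linear(16, 16).cuda()         # hypothetical generator stand-in
disc = torch.nn.Linear(16, 1).cuda()         # hypothetical discriminator stand-in
opt_g = torch.optim.AdamW(gen.parameters(), lr=1e-4)
opt_d = torch.optim.AdamW(disc.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()         # one scaler shared by both turns

x = torch.randn(8, 16).cuda()

# Discriminator turn (mirrors batch_dict['turn'] = 'discriminator').
with torch.cuda.amp.autocast(enabled=True):
    d_loss = F.softplus(disc(gen(x).detach())).mean()
scaler.scale(d_loss).backward()
scaler.step(opt_d)
scaler.update()
opt_d.zero_grad()
opt_g.zero_grad()                            # the executor also clears the other optimizer's grads

# Generator turn (mirrors batch_dict['turn'] = 'generator').
with torch.cuda.amp.autocast(enabled=True):
    g_loss = F.softplus(-disc(gen(x))).mean()
scaler.scale(g_loss).backward()
scaler.step(opt_g)
scaler.update()
opt_g.zero_grad()
opt_d.zero_grad()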

+ 74 - 42
cosyvoice/utils/train_utils.py

@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
 import logging
 import os
 import torch
@@ -110,38 +109,60 @@ def wrap_cuda_model(args, model):
 
 
 def init_optimizer_and_scheduler(args, configs, model, gan):
-    if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
-    elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
-    else:
-        raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-    if configs['train_conf']['scheduler'] == 'warmuplr':
-        scheduler_type = WarmupLR
-        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
-        scheduler_type = NoamHoldAnnealing
-        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'constantlr':
-        scheduler_type = ConstantLR
-        scheduler = ConstantLR(optimizer)
+    if gan is False:
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf'])
+
+        # use deepspeed optimizer for speedup
+        if args.train_engine == "deepspeed":
+            def scheduler(opt):
+                return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
+            model, optimizer, _, scheduler = deepspeed.initialize(
+                args=args,
+                model=model,
+                optimizer=None,
+                lr_scheduler=scheduler,
+                model_parameters=model.parameters())
+
+        optimizer_d, scheduler_d = None, None
+
     else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
-
-    # use deepspeed optimizer for speedup
-    if args.train_engine == "deepspeed":
-        def scheduler(opt):
-            return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
-        model, optimizer, _, scheduler = deepspeed.initialize(
-            args=args,
-            model=model,
-            optimizer=None,
-            lr_scheduler=scheduler,
-            model_parameters=model.parameters())
-
-    # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
-    if gan is True:
+        # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf'])
+
         if configs['train_conf']['optim_d'] == 'adam':
             optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
         elif configs['train_conf']['optim_d'] == 'adamw':
@@ -160,8 +181,6 @@ def init_optimizer_and_scheduler(args, configs, model, gan):
             scheduler_d = ConstantLR(optimizer_d)
         else:
             raise ValueError("unknown scheduler: " + configs['train_conf'])
-    else:
-        optimizer_d, scheduler_d = None, None
     return model, optimizer, scheduler, optimizer_d, scheduler_d
 
 
@@ -216,7 +235,7 @@ def cosyvoice_join(group_join, info_dict):
         return False
 
 
-def batch_forward(model, batch, info_dict):
+def batch_forward(model, batch, scaler, info_dict):
     device = int(os.environ.get('LOCAL_RANK', 0))
 
     dtype = info_dict["dtype"]
@@ -228,7 +247,7 @@ def batch_forward(model, batch, info_dict):
         dtype = torch.float32
 
     if info_dict['train_engine'] == 'torch_ddp':
-        autocast = nullcontext()
+        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
     else:
         autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
 
@@ -237,27 +256,40 @@ def batch_forward(model, batch, info_dict):
     return info_dict
 
 
-def batch_backward(model, info_dict):
+def batch_backward(model, scaler, info_dict):
     if info_dict["train_engine"] == "deepspeed":
         scaled_loss = model.backward(info_dict['loss_dict']['loss'])
     else:
         scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
-        scaled_loss.backward()
+        if scaler is not None:
+            scaler.scale(scaled_loss).backward()
+        else:
+            scaled_loss.backward()
 
     info_dict['loss_dict']['loss'] = scaled_loss
     return info_dict
 
 
-def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
+def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
     grad_norm = 0.0
     if info_dict['train_engine'] == "deepspeed":
         info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
         model.step()
         grad_norm = model.get_global_grad_norm()
     elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
-        grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
-        if torch.isfinite(grad_norm):
-            optimizer.step()
+        # Use mixed precision training
+        if scaler is not None:
+            scaler.unscale_(optimizer)
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            # We don't check grad here since that if the gradient
+            # has inf/nan values, scaler.step will skip
+            # optimizer.step().
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            if torch.isfinite(grad_norm):
+                optimizer.step()
         optimizer.zero_grad()
         scheduler.step()
     info_dict["lr"] = optimizer.param_groups[0]['lr']
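
Note: the substantive change in update_parameter_and_lr is the ordering on the AMP path: gradients are unscaled before clip_grad_norm_ so clipping sees their true magnitudes, and the explicit torch.isfinite check is dropped because scaler.step already skips optimizer.step() when the unscaled gradients contain inf/nan. A minimal sketch of that ordering, with a toy model, no gradient accumulation, and an assumed grad_clip value (the repo reads it from info_dict['grad_clip']):

import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(32, 32).cuda()           # hypothetical stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
grad_clip = 5.0                                  # assumed value for illustration

x = torch.randn(4, 32).cuda()
with torch.cuda.amp.autocast(enabled=True):
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()

scaler.unscale_(optimizer)                                     # unscale first ...
grad_norm = clip_grad_norm_(model.parameters(), grad_clip)     # ... so clipping uses real gradient norms

scaler.step(optimizer)                           # skips optimizer.step() on inf/nan grads
scaler.update()
optimizer.zero_grad()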

+ 1 - 0
examples/libritts/cosyvoice/run.sh

@@ -99,6 +99,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
       --num_workers ${num_workers} \
       --prefetch ${prefetch} \
       --pin_memory \
+      --use_amp \
       --deepspeed_config ./conf/ds_stage2.json \
       --deepspeed.save_states model+optimizer
   done