
add constant lr scheduler

lyuxiang.lx, 1 year ago
commit 793a24862c

+ 22 - 0
cosyvoice/utils/scheduler.py

@@ -715,3 +715,25 @@ class NoamHoldAnnealing(WarmupHoldPolicy):
 
     def set_step(self, step: int):
         self.last_epoch = step
+
+
+class ConstantLR(_LRScheduler):
+    """The ConstantLR scheduler
+
+    This scheduler keeps a constant lr
+
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+    ):
+        # Note: _LRScheduler.__init__() invokes step() once,
+        # which in turn calls get_lr() to set the initial lr
+        super().__init__(optimizer)
+
+    def get_lr(self):
+        return self.base_lrs
+
+    def set_step(self, step: int):
+        self.last_epoch = step
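As a usage sketch (not part of the commit; the Linear module and Adam settings are placeholders): get_lr() simply echoes the optimizer's base learning rates, so the lr reported by the scheduler stays fixed no matter how often step() is called.

import torch
from cosyvoice.utils.scheduler import ConstantLR

model = torch.nn.Linear(4, 4)                              # placeholder module for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = ConstantLR(optimizer)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())                         # prints [1e-05] at every step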

+ 4 - 1
cosyvoice/utils/train_utils.py

@@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
 from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
 
 from cosyvoice.dataset.dataset import Dataset
-from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing
+from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
 
 
 def init_distributed(args):
@@ -122,6 +122,9 @@ def init_optimizer_and_scheduler(args, configs, model):
     elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
         scheduler_type = NoamHoldAnnealing
         scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler_type = ConstantLR
+        scheduler = ConstantLR(optimizer)
     else:
         raise ValueError("unknown scheduler: " + configs['train_conf'])
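To illustrate the new branch, a train_conf fragment like the one below (field names taken from the dispatch above and the yaml further down; the 1e-5 value is only the sft-style example from that yaml's comments) would make init_optimizer_and_scheduler() build a ConstantLR:

configs = {
    'train_conf': {
        'optim': 'adam',
        'optim_conf': {'lr': 1e-5},     # lower lr, as suggested for sft in the yaml below
        'scheduler': 'constantlr',      # routes to the new ConstantLR branch
        'scheduler_conf': {},           # not passed to ConstantLR, which takes no extra arguments
    }
}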
 

+ 2 - 2
examples/libritts/cosyvoice/conf/cosyvoice.yaml

@@ -186,8 +186,8 @@ data_pipeline: [
 train_conf:
     optim: adam
     optim_conf:
-        lr: 0.001
-    scheduler: warmuplr
+        lr: 0.001 # change to 1e-5 during sft
+    scheduler: warmuplr # change to constantlr during sft
     scheduler_conf:
         warmup_steps: 2500
     max_epoch: 200

+ 1 - 1
tools/extract_embedding.py

@@ -54,7 +54,7 @@ def main(args):
             spk2embedding[spk] = []
         spk2embedding[spk].append(embedding)
     for k, v in spk2embedding.items():
-        spk2embedding[k] = torch.tensor(v).mean(dim=0, keepdim=True).tolist()
+        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
 
     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
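The keepdim change alters the stored shape of each speaker embedding. A small sketch of the difference (the embedding dimension 192 and the count of three utterances are assumed only for illustration):

import torch

v = [torch.randn(192).tolist() for _ in range(3)]     # per-utterance embeddings for one speaker
flat = torch.tensor(v).mean(dim=0)                     # shape (192,): what is saved now
nested = torch.tensor(v).mean(dim=0, keepdim=True)     # shape (1, 192): what was saved before
print(flat.shape, nested.shape)

So each entry of spk2embedding.pt is stored as a flat list of floats rather than a nested singleton list.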