@@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
 from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
 
 from cosyvoice.dataset.dataset import Dataset
-from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing
+from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
 
 
 def init_distributed(args):
@@ -122,6 +122,9 @@ def init_optimizer_and_scheduler(args, configs, model):
     elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
         scheduler_type = NoamHoldAnnealing
         scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler_type = ConstantLR
+        scheduler = ConstantLR(optimizer)
     else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
+        raise ValueError("unknown scheduler: " + configs['train_conf']['scheduler'])
 
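Note: unlike the other branches, ConstantLR is constructed without
**configs['train_conf']['scheduler_conf'], so it takes no tuning
parameters beyond the optimizer itself. A minimal sketch of such a
scheduler, assuming it follows the torch _LRScheduler interface that the
existing WarmupLR in cosyvoice/utils/scheduler.py also uses (the real
class may differ in detail):

    import torch
    from torch.optim.lr_scheduler import _LRScheduler

    class ConstantLR(_LRScheduler):
        """Keep each param group's learning rate fixed at its initial value."""

        def __init__(self, optimizer: torch.optim.Optimizer):
            # _LRScheduler.__init__() calls step() internally, so invoke it
            # before doing anything else
            super().__init__(optimizer)

        def get_lr(self):
            # base_lrs holds the initial lr of each param group; returning it
            # unchanged means the lr never decays
            return self.base_lrs

With this branch in place, selecting the scheduler from a training config
should only require 'scheduler: constantlr' under train_conf, with no
scheduler_conf block.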