Browse Source

Merge branch 'FunAudioLLM:main' into fastapi

New Bing 1 year ago
parent
commit
3513376c0f

+ 4 - 1
cosyvoice/cli/frontend.py

@@ -114,7 +114,10 @@ class CosyVoiceFrontEnd:
                                                 token_min_n=60, merge_len=20,
                                                 comma_split=False)]
         else:
-            text = self.en_tn_model.normalize(text)
+            if self.use_ttsfrd:
+                text = self.frd.get_frd_extra_info(text, 'input')
+            else:
+                text = self.en_tn_model.normalize(text)
             text = spell_out_number(text, self.inflect_parser)
             texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                                 token_min_n=60, merge_len=20,

+ 1 - 0
cosyvoice/cli/model.py

@@ -56,4 +56,5 @@ class CosyVoiceModel:
                                       prompt_feat_len=prompt_speech_feat_len.to(self.device),
                                       embedding=flow_embedding.to(self.device))
         tts_speech = self.hift.inference(mel=tts_mel).cpu()
+        torch.cuda.empty_cache()
         return {'tts_speech': tts_speech}

+ 1 - 1
cosyvoice/dataset/processor.py

@@ -167,7 +167,7 @@ def parse_embedding(data, normalize, mode='train'):
     """
     for sample in data:
         sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
-        sample['spk_embedding'] = torch.stack([torch.tensor(i, dtype=torch.float32) for i in sample['spk_embedding']], dim=0).mean(dim=0)
+        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
         if normalize:
             sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
             sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)

+ 1 - 1
cosyvoice/flow/flow.py

@@ -60,7 +60,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         token_len = batch['speech_token_len'].to(device)
         feat = batch['speech_feat'].to(device)
         feat_len = batch['speech_feat_len'].to(device)
-        embedding = batch['utt_embedding'].to(device)
+        embedding = batch['embedding'].to(device)
 
         # xvec projection
         embedding = F.normalize(embedding, dim=1)

+ 1 - 1
cosyvoice/llm/llm.py

@@ -97,7 +97,7 @@ class TransformerLM(torch.nn.Module):
         text_token_len = batch['text_token_len'].to(device)
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)
-        embedding = batch['utt_embedding'].to(device)
+        embedding = batch['embedding'].to(device)
 
         # 1. prepare llm_target
         lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]

+ 4 - 0
cosyvoice/utils/executor.py

@@ -52,6 +52,10 @@ class Executor:
                 info_dict["batch_idx"] = batch_idx
                 if cosyvoice_join(group_join, info_dict):
                     break
+                if info_dict["use_spk_embedding"] is True:
+                    batch_dict["embedding"] = batch_dict["spk_embedding"]
+                else:
+                    batch_dict["embedding"] = batch_dict["utt_embedding"]
 
                 # Disable gradient synchronizations across DDP processes.
                 # Within this context, gradients will be accumulated on module

+ 22 - 0
cosyvoice/utils/scheduler.py

@@ -715,3 +715,25 @@ class NoamHoldAnnealing(WarmupHoldPolicy):
 
     def set_step(self, step: int):
         self.last_epoch = step
+
+
+class ConstantLR(_LRScheduler):
+    """The ConstantLR scheduler
+
+    This scheduler keeps a constant lr
+
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+    ):
+        # __init__() must be invoked before setting field
+        # because step() is also invoked in __init__()
+        super().__init__(optimizer)
+
+    def get_lr(self):
+        return self.base_lrs
+
+    def set_step(self, step: int):
+        self.last_epoch = step

+ 4 - 1
cosyvoice/utils/train_utils.py

@@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
 from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
 
 from cosyvoice.dataset.dataset import Dataset
-from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing
+from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
 
 
 def init_distributed(args):
@@ -122,6 +122,9 @@ def init_optimizer_and_scheduler(args, configs, model):
     elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
         scheduler_type = NoamHoldAnnealing
         scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler_type = ConstantLR
+        scheduler = ConstantLR(optimizer)
     else:
         raise ValueError("unknown scheduler: " + configs['train_conf'])
 

+ 1 - 0
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

@@ -190,6 +190,7 @@ train_conf:
     scheduler: warmuplr
     scheduler_conf:
         warmup_steps: 25000
+    use_spk_embedding: False # change to True during sft
     max_epoch: 200
     grad_clip: 5
     accum_grad: 2

+ 3 - 2
examples/libritts/cosyvoice/conf/cosyvoice.yaml

@@ -186,10 +186,11 @@ data_pipeline: [
 train_conf:
     optim: adam
     optim_conf:
-        lr: 0.001
-    scheduler: warmuplr
+        lr: 0.001 # change to 1e-5 during sft
+    scheduler: warmuplr # change to constantlr during sft
     scheduler_conf:
         warmup_steps: 2500
+    use_spk_embedding: False # change to True during sft
     max_epoch: 200
     grad_clip: 5
     accum_grad: 2

+ 2 - 0
tools/extract_embedding.py

@@ -53,6 +53,8 @@ def main(args):
         if spk not in spk2embedding:
             spk2embedding[spk] = []
         spk2embedding[spk].append(embedding)
+    for k, v in spk2embedding.items():
+        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
 
     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))