
use spk_embedding when sft

lyuxiang.lx, 1 year ago
parent commit 0fd15bb12b

cosyvoice/flow/flow.py (+1 -1)

@@ -60,7 +60,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         token_len = batch['speech_token_len'].to(device)
         feat = batch['speech_feat'].to(device)
         feat_len = batch['speech_feat_len'].to(device)
-        embedding = batch['utt_embedding'].to(device)
+        embedding = batch['embedding'].to(device)
 
         # xvec projection
         embedding = F.normalize(embedding, dim=1)

cosyvoice/llm/llm.py (+1 -1)

@@ -97,7 +97,7 @@ class TransformerLM(torch.nn.Module):
         text_token_len = batch['text_token_len'].to(device)
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)
-        embedding = batch['utt_embedding'].to(device)
+        embedding = batch['embedding'].to(device)
 
         # 1. prepare llm_target
         lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]

cosyvoice/utils/executor.py (+4 -0)

@@ -52,6 +52,10 @@ class Executor:
                 info_dict["batch_idx"] = batch_idx
                 if cosyvoice_join(group_join, info_dict):
                     break
+                if info_dict["use_spk_embedding"] is True:
+                    batch_dict["embedding"] = batch_dict["spk_embedding"]
+                else:
+                    batch_dict["embedding"] = batch_dict["utt_embedding"]
 
                 # Disable gradient synchronizations across DDP processes.
                 # Within this context, gradients will be accumulated on module
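
A minimal, runnable sketch of what the new selection in executor.py does, using toy tensors. Only the flag name and the batch keys come from this commit; the tensor shapes and the trailing normalization (which mirrors the flow.py hunk above) are illustrative assumptions:

    import torch
    import torch.nn.functional as F

    # Toy batch carrying both embedding variants; the 192-dim shape is an
    # assumption, only the key names are taken from the diff.
    batch_dict = {
        "utt_embedding": torch.randn(2, 192),  # per-utterance x-vector
        "spk_embedding": torch.randn(2, 192),  # per-speaker x-vector
    }

    info_dict = {"use_spk_embedding": True}  # True during sft, per the yaml comment

    # Selection added in executor.py: expose the chosen variant under a single
    # "embedding" key so flow.py and llm.py stay agnostic to its source.
    if info_dict["use_spk_embedding"] is True:
        batch_dict["embedding"] = batch_dict["spk_embedding"]
    else:
        batch_dict["embedding"] = batch_dict["utt_embedding"]

    # Downstream, both models read the unified key and L2-normalize it.
    embedding = F.normalize(batch_dict["embedding"], dim=1)
    print(embedding.shape)  # torch.Size([2, 192])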

examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml (+1 -0)

@@ -190,6 +190,7 @@ train_conf:
     scheduler: warmuplr
     scheduler_conf:
         warmup_steps: 25000
+    use_spk_embedding: False # change to True during sft
     max_epoch: 200
     grad_clip: 5
     accum_grad: 2

examples/libritts/cosyvoice/conf/cosyvoice.yaml (+1 -0)

@@ -190,6 +190,7 @@ train_conf:
     scheduler: warmuplr # change to constantlr during sft
     scheduler_conf:
         warmup_steps: 2500
+    use_spk_embedding: False # change to True during sft
     max_epoch: 200
     grad_clip: 5
     accum_grad: 2
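
For orientation, a rough sketch of how the new train_conf key is expected to reach the executor's check; the loading path here is an assumption, only the key name, its default, and the executor lookup come from this commit:

    from copy import deepcopy

    # train_conf as written in cosyvoice.yaml above (only the relevant keys shown).
    train_conf = {
        "use_spk_embedding": False,  # change to True during sft
        "max_epoch": 200,
        "grad_clip": 5,
        "accum_grad": 2,
    }

    # Assumption: the training entry point copies train_conf into the info_dict
    # that executor.py consults inside the batch loop.
    info_dict = deepcopy(train_conf)

    if info_dict["use_spk_embedding"] is True:
        print("sft: conditioning on spk_embedding")
    else:
        print("pretraining: conditioning on utt_embedding")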