Parcourir la source

fix cosyvoice3 training

lyuxiang.lx il y a 1 mois
Parent
commit
4d7295a9a7

+ 3 - 7
cosyvoice/dataset/processor.py

@@ -26,7 +26,7 @@ import pyworld as pw
 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
 
 
-def parquet_opener(data, mode='train', tts_data={}):
+def parquet_opener(data, mode='train'):
     """ Give url or local file, return file descriptor
         Inplace operation.
 
@@ -44,12 +44,8 @@ def parquet_opener(data, mode='train', tts_data={}):
                 df = df.to_pandas()
                 for i in range(len(df)):
                     sample.update(dict(df.loc[i]))
-                    if mode == 'train':
-                        # NOTE do not return sample directly, must initialize a new dict
-                        yield {**sample}
-                    else:
-                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
-                            yield {**sample, 'tts_index': index, 'tts_text': text}
+                    # NOTE do not return sample directly, must initialize a new dict
+                    yield {**sample}
         except Exception as ex:
             logging.warning('Failed to open {}, ex info {}'.format(url, ex))
 

+ 3 - 3
cosyvoice/flow/flow.py

@@ -332,8 +332,9 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode
-        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
-        h = self.encoder_proj(h)
+        h = self.pre_lookahead_layer(token)
+        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
+        mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
 
         # get conditions
         conds = torch.zeros(feat.shape, device=token.device)
@@ -344,7 +345,6 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
             conds[i, :index] = feat[i, :index]
         conds = conds.transpose(1, 2)
 
-        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),
             mask.unsqueeze(1),

+ 13 - 7
cosyvoice/llm/llm.py

@@ -301,18 +301,23 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
 
-    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
+    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
         speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
+        # NOTE add instruct_token in CosyVoice3
+        if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
+            instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
+            instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
         for i in range(len(text_token)):
             # bistream sequence
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
-                this_lm_target, this_lm_input = [], []
-                this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(sos_emb.squeeze(dim=0))
+                this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
+                if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
+                    this_lm_target += [IGNORE_ID] * instruct_token_len[i]
+                    this_lm_input.append(instruct_token_emb[i])
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -333,8 +338,8 @@ class Qwen2LM(TransformerLM):
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
-                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
@@ -681,6 +686,7 @@ class CosyVoice3LM(Qwen2LM):
 
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
+        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
 
         # 3. sos and task_id
         sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
@@ -691,7 +697,7 @@ class CosyVoice3LM(Qwen2LM):
 
         # 3. prepare llm_input/target
         lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
-                                                                         speech_token, speech_token_emb, speech_token_len)
+                                                                         speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
         lm_target = lm_target.to(device)
 
         # 4. run lm forward

+ 5 - 5
cosyvoice/utils/train_utils.py

@@ -53,7 +53,7 @@ def init_distributed(args):
 def init_dataset_and_dataloader(args, configs, gan, dpo):
     data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
     train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
-    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=False, partition=False)
+    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='dev', gan=gan, dpo=dpo, shuffle=False, partition=False)
 
     # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
     train_data_loader = DataLoader(train_dataset,
@@ -164,18 +164,18 @@ def init_optimizer_and_scheduler(args, configs, model, gan):
             raise ValueError("unknown scheduler: " + configs['train_conf'])
 
         if configs['train_conf']['optim_d'] == 'adam':
-            optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+            optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
         elif configs['train_conf']['optim_d'] == 'adamw':
-            optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+            optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
         else:
             raise ValueError("unknown optimizer: " + configs['train_conf'])
 
         if configs['train_conf']['scheduler_d'] == 'warmuplr':
             scheduler_type = WarmupLR
-            scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
+            scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_d'])
         elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
             scheduler_type = NoamHoldAnnealing
-            scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
+            scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_d'])
         elif configs['train_conf']['scheduler'] == 'constantlr':
             scheduler_type = ConstantLR
             scheduler_d = ConstantLR(optimizer_d)

+ 1 - 1
examples/libritts/cosyvoice3/conf/cosyvoice3.yaml

@@ -136,7 +136,7 @@ filter: !name:cosyvoice.dataset.processor.filter
 resample: !name:cosyvoice.dataset.processor.resample
     resample_rate: !ref <sample_rate>
 truncate: !name:cosyvoice.dataset.processor.truncate
-    truncate_length: 24480 # must be a multiplier of hop_size
+    truncate_length: 24960 # must be a multiplier of hop_size and token_mel_ratio
 feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     n_fft: 1920
     num_mels: 80