lyuxiang.lx 2 месяцев назад
Родитель
Сommit
84e41729ea

+ 6 - 3
cosyvoice/flow/flow.py

@@ -189,7 +189,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
             device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
         if 'speech_token' not in batch:
-            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
+            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
         else:
             token = batch['speech_token'].to(device)
             token_len = batch['speech_token_len'].to(device)
@@ -322,8 +322,11 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
             batch: dict,
             device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
-        token = batch['speech_token'].to(device)
-        token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            token = batch['speech_token'].to(device)
+            token_len = batch['speech_token_len'].to(device)
         feat = batch['speech_feat'].to(device)
         feat_len = batch['speech_feat_len'].to(device)
         embedding = batch['embedding'].to(device)

+ 11 - 4
cosyvoice/llm/llm.py

@@ -367,8 +367,11 @@ class Qwen2LM(TransformerLM):
         """
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            speech_token = batch['speech_token'].to(device)
+            speech_token_len = batch['speech_token_len'].to(device)
 
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
@@ -686,8 +689,12 @@ class CosyVoice3LM(Qwen2LM):
         """
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            speech_token = batch['speech_token'].to(device)
+            speech_token_len = batch['speech_token_len'].to(device)
+
         # NOTE should append instruct_token to sequence, not implemented yet
         instruct_token = batch['instruct_token'].to(device)
         instruct_token_len = batch['instruct_token_len'].to(device)

+ 2 - 6
cosyvoice/utils/onnx.py

@@ -1,11 +1,7 @@
 import onnxruntime
 import torch, random
-from torch import nn
 import os
-import whisper
-import numpy as np
 import torchaudio.compliance.kaldi as kaldi
-import torch.nn.functional as F
 
 
 class SpeechTokenExtractor():
@@ -18,13 +14,13 @@ class SpeechTokenExtractor():
                                                                      sess_options=option,
                                                                      providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
 
-    def inference(self, feat, feat_lengths):
+    def inference(self, feat, feat_lengths, device):
         speech_token = self.speech_tokenizer_session.run(None,
                                                     {self.speech_tokenizer_session.get_inputs()[0].name:
                                                     feat.transpose(1, 2).detach().cpu().numpy(),
                                                     self.speech_tokenizer_session.get_inputs()[1].name:
                                                     feat_lengths.detach().cpu().numpy()})[0]
-        return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
+        return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
 
 
 class EmbeddingExtractor():

+ 1 - 0
examples/libritts/cosyvoice3/conf/cosyvoice3.yaml

@@ -150,6 +150,7 @@ compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
     num_frames: 960
 compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
+    num_frames: 960
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480