lyuxiang.lx 5 days ago
parent
commit
f26cde56df

+ 9 - 0
cosyvoice/cli/model.py

@@ -256,6 +256,10 @@ class CosyVoice2Model(CosyVoiceModel):
         self.fp16 = fp16
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
+        # NOTE increase token_hop_len incrementally to avoid duplicate inference 
+        self.token_max_hop_len = 4 * self.token_hop_len
+        self.stream_scale_factor = 2
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -353,6 +357,7 @@ class CosyVoice2Model(CosyVoiceModel):
                                                      stream=stream,
                                                      finalize=False)
                     token_offset += this_token_hop_len
+                    self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
                     yield {'tts_speech': this_tts_speech.cpu()}
                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                     break
@@ -403,6 +408,10 @@ class CosyVoice3Model(CosyVoice2Model):
         self.fp16 = fp16
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
+        # NOTE increase token_hop_len incrementally to avoid duplicate inference 
+        self.token_max_hop_len = 4 * self.token_hop_len
+        self.stream_scale_factor = 2
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         # rtf and decoding related
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()

+ 54 - 61
cosyvoice/dataset/processor.py

@@ -17,6 +17,7 @@ import random
 import pyarrow.parquet as pq
 from io import BytesIO
 import numpy as np
+import whisper
 import torch
 import torchaudio
 from torch.nn.utils.rnn import pad_sequence
@@ -179,6 +180,23 @@ def compute_fbank(data,
         yield sample
 
 
+def compute_whisper_fbank(data, num_frames=-1, mode='train'):
+    """ Extract whisper fbank 
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+
+        Returns:
+            Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        if num_frames != -1:
+            assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
+        sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
+        sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
+        yield sample
+
+
 def compute_f0(data, sample_rate, hop_size, mode='train'):
     """ Extract f0
 
@@ -215,11 +233,12 @@ def parse_embedding(data, normalize, mode='train'):
     """
     for sample in data:
         if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
-            speech_16k = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
-            embedding = embedding_extractor.inference(speech_16k)
+            sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
+            embedding = embedding_extractor.inference(sample['speech_16k'])
             sample['spk_embedding'] = sample['utt_embedding'] = embedding
-        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
-        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
+        else:
+            sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
+            sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
         if normalize:
             sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
             sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
@@ -242,8 +261,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
         if 'instruct' in sample:
             sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
-        else:
-            sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
         yield sample
 
 
@@ -371,66 +388,42 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
     """
     for sample in data:
         assert isinstance(sample, list)
-        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
-                                       dtype=torch.int32)
-        order = torch.argsort(speech_feat_len, descending=True)
-
-        utts = [sample[i]['utt'] for i in order]
-        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
-        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
-        speech = pad_sequence(speech, batch_first=True, padding_value=0)
-        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
-        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
-        speech_token = pad_sequence(speech_token,
-                                    batch_first=True,
-                                    padding_value=0)
-        speech_feat = [sample[i]['speech_feat'] for i in order]
-        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
-        speech_feat = pad_sequence(speech_feat,
-                                   batch_first=True,
-                                   padding_value=0)
-        text = [sample[i]['text'] for i in order]
+        order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
+        batch = {}
+        batch['utts'] = [sample[i]['utt'] for i in order]
+        batch['text'] = [sample[i]['text'] for i in order]
         text_token = [torch.tensor(sample[i]['text_token']) for i in order]
-        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
-        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
-        instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
-        instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
-        instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
-        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
-        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
-        batch = {
-            "utts": utts,
-            "speech": speech,
-            "speech_len": speech_len,
-            "speech_token": speech_token,
-            "speech_token_len": speech_token_len,
-            "speech_feat": speech_feat,
-            "speech_feat_len": speech_feat_len,
-            "text": text,
-            "text_token": text_token,
-            "text_token_len": text_token_len,
-            "instruct_token": instruct_token,
-            "instruct_token_len": instruct_token_len,
-            "utt_embedding": utt_embedding,
-            "spk_embedding": spk_embedding,
-        }
+        batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
+        batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
+        speech_feat = [sample[i]['speech_feat'] for i in order]
+        batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
+        batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
+        batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
+        batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
+        if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
+            instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
+            batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
+            batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
+        if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
+            whisper_feat = [torch.tensor(sample[i]['whisper_feat']) for i in order]
+            batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
+            batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
+        if torch.tensor(['speech_token' in sample[i] for i in order]).all():
+            speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
+            batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
+            batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
         if gan is True:
-            # in gan train, we need pitch_feat
+            # in gan train, we need speech/pitch_feat
+            speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
+            batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
+            batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
             pitch_feat = [sample[i]['pitch_feat'] for i in order]
-            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
-            pitch_feat = pad_sequence(pitch_feat,
-                                      batch_first=True,
-                                      padding_value=0)
-            batch["pitch_feat"] = pitch_feat
-            batch["pitch_feat_len"] = pitch_feat_len
+            batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
+            batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
         if dpo is True:
             reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
-            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
-            reject_speech_token = pad_sequence(reject_speech_token,
-                                               batch_first=True,
-                                               padding_value=0)
-            batch['reject_speech_token'] = reject_speech_token
-            batch['reject_speech_token_len'] = reject_speech_token_len
+            batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
+            batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
         if use_spk_embedding is True:
             batch["embedding"] = batch["spk_embedding"]
         else:

+ 11 - 4
cosyvoice/flow/flow.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
+import os, logging
 import random
 from typing import Dict, Optional
 import torch
@@ -19,7 +19,7 @@ import torch.nn as nn
 from torch.nn import functional as F
 from omegaconf import DictConfig
 from cosyvoice.utils.mask import make_pad_mask
-from cosyvoice.utils.onnx import SpeechTokenExtractor
+from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
 
 
 class MaskedDiffWithXvec(torch.nn.Module):
@@ -180,14 +180,19 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         self.only_mask_loss = only_mask_loss
         self.token_mel_ratio = token_mel_ratio
         self.pre_lookahead_len = pre_lookahead_len
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
 
     def forward(
             self,
             batch: dict,
             device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
-        token = batch['speech_token'].to(device)
-        token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
+        else:
+            token = batch['speech_token'].to(device)
+            token_len = batch['speech_token_len'].to(device)
         feat = batch['speech_feat'].to(device)
         feat_len = batch['speech_feat_len'].to(device)
         embedding = batch['embedding'].to(device)
@@ -309,6 +314,8 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
         self.decoder = decoder
         self.only_mask_loss = only_mask_loss
         self.token_mel_ratio = token_mel_ratio
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
 
     def forward(
             self,

+ 6 - 2
cosyvoice/llm/llm.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import queue
+import os, queue
 import random
 import time
 import threading
@@ -28,7 +28,7 @@ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.mask import make_pad_mask
-from cosyvoice.utils.onnx import SpeechTokenExtractor
+from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
 
 
 class TransformerLM(torch.nn.Module):
@@ -301,6 +301,8 @@ class Qwen2LM(TransformerLM):
         # 5. vllm related
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
 
     def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
         lm_target, lm_input = [], []
@@ -667,6 +669,8 @@ class CosyVoice3LM(Qwen2LM):
         # 5. vllm related
         self.stop_token_ids = [speech_token_size + i for i in range(200)]
         self.vllm_output_queue = {}
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
 
     def forward(
             self,

+ 5 - 6
cosyvoice/utils/onnx.py

@@ -18,14 +18,13 @@ class SpeechTokenExtractor():
                                                                      sess_options=option,
                                                                      providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
 
-    def inference(self, feat, feat_lengths, device):
-        ort_out = self.speech_tokenizer_session.run(None,
+    def inference(self, feat, feat_lengths):
+        speech_token = self.speech_tokenizer_session.run(None,
                                                     {self.speech_tokenizer_session.get_inputs()[0].name:
-                                                    feat.detach().cpu().numpy(),
+                                                    feat.transpose(1, 2).detach().cpu().numpy(),
                                                     self.speech_tokenizer_session.get_inputs()[1].name:
-                                                    feat_lengths.detach().cpu().numpy()})
-        speech_token, speech_token_embedding = ort_out[0], ort_out[1]
-        return torch.tensor(speech_token).to(device), (feat_lengths / 2).to(torch.int32).to(device)
+                                                    feat_lengths.detach().cpu().numpy()})[0]
+        return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
 
 
 class EmbeddingExtractor():

+ 3 - 0
examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

@@ -159,6 +159,8 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
     num_frames: 960
+compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
+    num_frames: 960
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480
@@ -183,6 +185,7 @@ data_pipeline: [
     !ref <resample>,
     !ref <compute_fbank>,
     !ref <parse_embedding>,
+    !ref <compute_whisper_fbank>,
     !ref <shuffle>,
     !ref <sort>,
     !ref <batch>,

+ 2 - 0
examples/libritts/cosyvoice3/conf/cosyvoice3.yaml

@@ -149,6 +149,7 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
     num_frames: 960
+compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480
@@ -173,6 +174,7 @@ data_pipeline: [
     !ref <resample>,
     !ref <compute_fbank>,
     !ref <parse_embedding>,
+    !ref <compute_whisper_fbank>,
     !ref <shuffle>,
     !ref <sort>,
     !ref <batch>,