
update vc/tts code

lyuxiang.lx 1 year ago
commit 72b89a52fb

+ 65 - 12
cosyvoice/cli/model.py

@@ -35,7 +35,7 @@ class CosyVoiceModel:
         self.token_max_hop_len = 200
         self.token_overlap_len = 20
         # mel fade in out
-        self.mel_overlap_len = 34
+        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
         self.mel_window = np.hamming(2 * self.mel_overlap_len)
         # hift cache
         self.mel_cache_len = 20
@@ -54,9 +54,10 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
-        self.llm.to(self.device).eval()
-        self.llm.half()
+        if self.llm is not None:
+            self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+            self.llm.to(self.device).eval()
+            self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         self.flow.to(self.device).eval()
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
@@ -131,11 +132,11 @@ class CosyVoiceModel:
                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech
 
-    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
@@ -148,7 +149,8 @@ class CosyVoiceModel:
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
@@ -164,7 +166,7 @@ class CosyVoiceModel:
                     break
             p.join()
            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
@@ -175,7 +177,58 @@ class CosyVoiceModel:
         else:
             # deal with all tokens
             p.join()
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
+
+    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
+            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+        if stream is True:
+            token_hop_len = self.token_min_hop_len
+            while True:
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                    break
+            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
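
The model.py changes do two things: inference() is renamed to tts() and gains a sibling vc() that pushes pre-extracted source speech tokens straight through the flow/HiFT stack (no LLM pass, which is presumably why load() now tolerates llm being None), and mel_overlap_len is no longer hard-coded to 34 but derived from token_overlap_len and the flow model's token frame rate. A standalone sanity check, assuming the 50 Hz input_frame_rate that the old constant implied (not repo code):

# Hypothetical standalone check: the derived overlap reproduces the old
# hard-coded 34 when the speech-token frame rate is 50 Hz (assumed).
token_overlap_len = 20
input_frame_rate = 50          # assumed value of self.flow.input_frame_rate
sample_rate, hop_size = 22050, 256
mel_overlap_len = int(token_overlap_len / input_frame_rate * sample_rate / hop_size)
print(mel_overlap_len)         # -> 34

A related detail: the per-uuid token buffers now hold plain Python ints, so each chunk is rebuilt with torch.tensor(...).unsqueeze(dim=0) instead of torch.concat over a list of (1, N) tensors.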

+ 1 - 1
cosyvoice/flow/flow.py

@@ -125,7 +125,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
         mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
-        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
 
         # get conditions
         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
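
The only change in flow.py is that the regulator now receives the model's own token frame rate instead of assuming 50 Hz internally. For context, mel_len2 is the number of mel frames the generated tokens should expand to; a toy calculation with made-up numbers:

# Toy numbers only (not repo code): how the prompt/generated mel lengths are
# split before length_regulator.inference(..., self.input_frame_rate) is called.
input_frame_rate = 50               # hypothetical self.input_frame_rate
mel_len1 = 120                      # prompt_feat.shape[1] (made up)
token_len2 = 100                    # generated token count (made up)
mel_len2 = int(token_len2 / input_frame_rate * 22050 / 256)
print(mel_len1, mel_len2)           # -> 120 172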

+ 5 - 4
cosyvoice/flow/length_regulator.py

@@ -49,13 +49,14 @@ class InterpolateRegulator(nn.Module):
         olens = ylens
         return out * mask, olens
 
-    def inference(self, x1, x2, mel_len1, mel_len2):
+    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
        # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
         # x in (B, T, D)
         if x2.shape[1] > 40:
-            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
-            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
-            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+                                   mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
             x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
         else:
             x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
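
Previously the 20-token head and tail of the generated segment were always interpolated to 34 mel frames; they now stretch to int(20 / input_frame_rate * 22050 / 256) frames, so a model whose tokenizer runs at a different rate still gets a clean seam between prompt and generated mel. A minimal sketch (rates other than 50 Hz are assumptions for illustration):

# How many mel frames the fixed 20-token head/tail maps to at different
# speech-token frame rates; 50 Hz reproduces the old constant 34.
def head_tail_mel_frames(input_frame_rate, token_len=20, sample_rate=22050, hop_size=256):
    return int(token_len / input_frame_rate * sample_rate / hop_size)

print(head_tail_mel_frames(50))   # -> 34
print(head_tail_mel_frames(25))   # -> 68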

+ 2 - 4
cosyvoice/tokenizer/tokenizer.py

@@ -1,9 +1,7 @@
 import base64
 import os
-import string
-from dataclasses import dataclass, field
-from functools import cached_property, lru_cache
-from typing import Dict, List, Optional, Tuple
+from functools import lru_cache
+from typing import Optional
 from whisper.tokenizer import Tokenizer
 
 import tiktoken

+ 1 - 0
cosyvoice/utils/common.py

@@ -145,6 +145,7 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)
 
+
 def set_all_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)