Explorar o código

optimize vc code

lyuxiang.lx hai 7 meses
pai
achega
37e48dd318
Modificáronse 2 ficheiros con 17 adicións e 62 borrados
  1. 1 1
      cosyvoice/cli/cosyvoice.py
  2. 16 61
      cosyvoice/cli/model.py

+ 1 - 1
cosyvoice/cli/cosyvoice.py

@@ -128,7 +128,7 @@ class CosyVoice:
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
         model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
-        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
+        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
             speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
             logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
             yield model_output

+ 16 - 61
cosyvoice/cli/model.py

@@ -122,6 +122,10 @@ class CosyVoiceModel:
                     self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
+    def vc_job(self, source_speech_token, uuid):
+        self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
+        self.llm_end_dict[uuid] = True
+
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
             tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
@@ -162,11 +166,11 @@ class CosyVoiceModel:
                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech
 
-    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
             prompt_text=torch.zeros(1, 0, dtype=torch.int32),
             llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
             flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
@@ -174,7 +178,10 @@ class CosyVoiceModel:
             self.hift_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
-        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        if source_speech_token.shape[1] == 0:
+            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        else:
+            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
         p.start()
         if stream is True:
             token_hop_len = self.token_min_hop_len
@@ -226,61 +233,6 @@ class CosyVoiceModel:
             self.flow_cache_dict.pop(this_uuid)
         torch.cuda.empty_cache()
 
-    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
-        # this_uuid is used to track variables related to this inference thread
-        this_uuid = str(uuid.uuid1())
-        with self.lock:
-            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
-            self.hift_cache_dict[this_uuid] = None
-            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
-            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
-        if stream is True:
-            token_hop_len = self.token_min_hop_len
-            while True:
-                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
-                        .unsqueeze(dim=0)
-                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                     prompt_token=flow_prompt_speech_token,
-                                                     prompt_feat=prompt_speech_feat,
-                                                     embedding=flow_embedding,
-                                                     uuid=this_uuid,
-                                                     finalize=False)
-                    yield {'tts_speech': this_tts_speech.cpu()}
-                    with self.lock:
-                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
-                    # increase token_hop_len for better speech quality
-                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
-                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
-                    break
-            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
-            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
-            this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                             prompt_token=flow_prompt_speech_token,
-                                             prompt_feat=prompt_speech_feat,
-                                             embedding=flow_embedding,
-                                             uuid=this_uuid,
-                                             finalize=True)
-            yield {'tts_speech': this_tts_speech.cpu()}
-        else:
-            # deal with all tokens
-            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
-            this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                             prompt_token=flow_prompt_speech_token,
-                                             prompt_feat=prompt_speech_feat,
-                                             embedding=flow_embedding,
-                                             uuid=this_uuid,
-                                             finalize=True,
-                                             speed=speed)
-            yield {'tts_speech': this_tts_speech.cpu()}
-        with self.lock:
-            self.tts_speech_token_dict.pop(this_uuid)
-            self.llm_end_dict.pop(this_uuid)
-            self.mel_overlap_dict.pop(this_uuid)
-            self.hift_cache_dict.pop(this_uuid)
-            self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
-
 
 class CosyVoice2Model(CosyVoiceModel):
 
@@ -386,18 +338,21 @@ class CosyVoice2Model(CosyVoiceModel):
                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech
 
-    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
             prompt_text=torch.zeros(1, 0, dtype=torch.int32),
             llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
             flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
             self.flow_cache_dict[this_uuid] = self.init_flow_cache()
-        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        if source_speech_token.shape[1] == 0:
+            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        else:
+            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
         p.start()
         if stream is True:
             assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"