@@ -35,7 +35,8 @@ class CosyVoiceModel:
         self.token_max_hop_len = 200
         self.token_overlap_len = 20
         # mel fade in out
-        self.mel_overlap_len = 34
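+        # match the mel overlap to the token overlap: tokens / input_frame_rate = seconds, x 22050 (output sample rate) / 256 (mel hop) = mel frames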
+        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
         self.mel_window = np.hamming(2 * self.mel_overlap_len)
         # hift cache
         self.mel_cache_len = 20
@@ -54,9 +55,11 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}

     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
-        self.llm.to(self.device).eval()
-        self.llm.half()
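+        # the llm is optional here so flow/hift-only setups (e.g. the voice-conversion path below) can still load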
+        if self.llm is not None:
+            self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+            self.llm.to(self.device).eval()
+            self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         self.flow.to(self.device).eval()
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
@@ -131,11 +134,11 @@ class CosyVoiceModel:
                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech

-    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
@@ -148,7 +151,9 @@ class CosyVoiceModel:
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
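+                    # the token buffer now stores plain ints, so build a (1, T) tensor here instead of concatenating cached tensors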
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
@@ -164,7 +169,7 @@ class CosyVoiceModel:
                     break
             p.join()
             # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
@@ -175,7 +180,60 @@ class CosyVoiceModel:
         else:
             # deal with all tokens
             p.join()
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
+
+    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
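+            # no LLM stage in voice conversion: every source token is available up front, so llm_end is True from the start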
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
+            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+        if stream is True:
+            token_hop_len = self.token_min_hop_len
+            while True:
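+                # unlike the tts loop there is no time.sleep polling: the full token sequence is already buffered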
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                    break
+            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
|