@@ -52,8 +52,8 @@ class CosyVoiceModel:
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
-        self.flow_cache_dict = {}
         self.mel_overlap_dict = {}
+        self.flow_cache_dict = {}
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
@@ -102,18 +102,17 @@ class CosyVoiceModel:
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
         tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
-                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                                  prompt_token=prompt_token.to(self.device),
-                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                  prompt_feat=prompt_feat.to(self.device),
-                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                                  embedding=embedding.to(self.device),
-                                                  required_cache_size=self.mel_overlap_len,
-                                                  flow_cache=self.flow_cache_dict[uuid])
+                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_token=prompt_token.to(self.device),
+                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_feat=prompt_feat.to(self.device),
+                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                  embedding=embedding.to(self.device),
+                                                  flow_cache=self.flow_cache_dict[uuid])
         self.flow_cache_dict[uuid] = flow_cache
 
         # mel overlap fade in out
-        if self.mel_overlap_dict[uuid] is not None:
+        if self.mel_overlap_dict[uuid].shape[2] != 0:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
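
This hunk does two things: it drops the `required_cache_size` argument (the returned `flow_cache` now carries whatever state the flow model needs), and it replaces the `None` check on the mel-overlap buffer with an emptiness check on the frame dimension. Each session now starts from a zero-frame buffer instead of `None`, so the first chunk skips the cross-fade without a separate null branch, and the buffer can be used in tensor ops unconditionally. A minimal standalone sketch of the sentinel convention (plain PyTorch, not the patched code; the 80-mel-bin shape is taken from the initializations below):

import torch

mel_overlap = torch.zeros(1, 80, 0)   # (batch, n_mels, frames): "nothing cached yet"
tts_mel = torch.randn(1, 80, 120)     # freshly generated mel chunk

# Same emptiness test as the patch: dim 2 counts frames.
if mel_overlap.shape[2] != 0:
    pass  # cross-fade would happen here, on every chunk after the first

# Unlike None, an empty tensor also participates in tensor ops harmlessly:
combined = torch.cat([mel_overlap, tts_mel], dim=2)
assert combined.shape == (1, 80, 120)
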
@@ -150,8 +149,9 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
-            self.flow_cache_dict[this_uuid] = None
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
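
Here the per-session initialization is updated to match: `mel_overlap_dict` and `flow_cache_dict` start as empty tensors with the shapes the flow model produces, while `hift_cache_dict` keeps its `None` sentinel. A small sketch of the assumed growth pattern, where the cache widens along the frame dimension (dim 2) as chunks are decoded; the trailing size-2 dimension of the flow cache is copied from the shapes above and its contents are internal to the flow model:

import torch

# Mirrors the per-uuid initialization in the hunk above (hypothetical helper).
def init_session_caches():
    return {
        "hift_cache": None,                      # still a None sentinel
        "mel_overlap": torch.zeros(1, 80, 0),    # (batch, n_mels, frames)
        "flow_cache": torch.zeros(1, 80, 0, 2),  # (batch, n_mels, frames, 2)
    }

caches = init_session_caches()

# Assumed pattern: each decoded chunk returns a cache that has grown
# along the frame dimension, so concatenation with the empty start is safe.
for _ in range(3):
    new_state = torch.randn(1, 80, 10, 2)
    caches["flow_cache"] = torch.cat([caches["flow_cache"], new_state], dim=2)

print(caches["flow_cache"].shape)  # torch.Size([1, 80, 30, 2])
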
@@ -207,7 +207,9 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
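
The voice-conversion path gets the same initialization, so both entry points share one cache convention. All of the hunks rely on the same underlying pattern: every piece of streaming state is keyed by a per-call uuid and touched only under the lock, so concurrent sessions never share buffers. A stripped-down sketch of that pattern (hypothetical class, not the real CosyVoiceModel):

import threading
import uuid

import torch

class SessionCaches:
    """Hypothetical: uuid-keyed streaming caches guarded by a lock."""

    def __init__(self):
        self.lock = threading.Lock()
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}

    def open_session(self):
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        return this_uuid

    def close_session(self, this_uuid):
        with self.lock:
            self.mel_overlap_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)

caches = SessionCaches()
sid = caches.open_session()
assert caches.mel_overlap_dict[sid].shape == (1, 80, 0)
caches.close_session(sid)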