|
|
@@ -362,8 +362,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
|
|
spk_emb_for_flow.to(self.device),
|
|
|
n_timesteps=10
|
|
|
)
|
|
|
-
|
|
|
- # cache dict's tensor batch dim is 1 for now
|
|
|
+ # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
|
|
|
+ cache['estimator_att_cache'] = cache['estimator_att_cache'].clone()
|
|
|
+ cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone()
|
|
|
return cache
|
|
|
|
|
|
|
|
|
@@ -371,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
|
|
def forward_streaming(
|
|
|
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
|
|
|
):
|
|
|
-
|
|
|
if speaker_id not in self.speaker_cache:
|
|
|
assert prompt_audio is not None, "prompt_audio is required for new speaker"
|
|
|
assert prompt_audio_sample_rate == 16000
|
|
|
@@ -388,7 +388,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
|
|
self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
|
|
|
|
|
|
if request_id not in self.streaming_flow_cache:
|
|
|
- self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy()
|
|
|
+ self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
|
|
|
self.hift_cache_dict[request_id] = dict(
|
|
|
mel = torch.zeros(1, 80, 0, device='cuda'),
|
|
|
source = torch.zeros(1, 1, 0, device='cuda'),
|
|
|
@@ -396,12 +396,14 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
|
|
)
|
|
|
|
|
|
current_request_cache = self.streaming_flow_cache[request_id]
|
|
|
- prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
|
|
|
+
|
|
|
+ current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
|
|
|
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
|
|
|
|
|
|
+
|
|
|
chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
|
|
|
token=generated_speech_tokens,
|
|
|
- spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device),
|
|
|
+ spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
|
|
|
cache=current_request_cache,
|
|
|
last_chunk=last_chunk,
|
|
|
n_timesteps=10,
|
|
|
@@ -409,9 +411,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
|
|
|
|
|
self.streaming_flow_cache[request_id] = new_streaming_flow_cache
|
|
|
|
|
|
- if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
|
|
|
+
|
|
|
+ if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
|
|
|
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
|
|
|
- self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
|
|
|
+ self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
|
|
|
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
|
|
|
], dim=4)
|
|
|
|