root преди 2 месеца
родител
ревизия
444b7ff5df
променени са 2 файла, в които са добавени 22 реда и са изтрити 8 реда
  1. 11 0
      runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
  2. 11 8
      runtime/triton_trtllm/token2wav_dit.py

+ 11 - 0
runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh

@@ -2,6 +2,9 @@
 # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
 export CUDA_VISIBLE_DEVICES=0
 cosyvoice_path=/workspace/CosyVoice
+cosyvoice_path=/workspace_yuekai/tts/CosyVoice
+stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
+export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
 stage=$1
@@ -140,3 +143,11 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     done
   done
 fi
+
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+
+   python3 benchmark_streaming_token2wav.py --enable-trt
+
+
+fi

+ 11 - 8
runtime/triton_trtllm/token2wav_dit.py

@@ -362,8 +362,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             spk_emb_for_flow.to(self.device),
             n_timesteps=10
         )
-
-        # cache dict's tensor batch dim is 1 for now
+        # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
+        cache['estimator_att_cache'] = cache['estimator_att_cache'].clone()
+        cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone()
         return cache
 
 
@@ -371,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def forward_streaming(
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
     ):
-
         if speaker_id not in self.speaker_cache:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
@@ -388,7 +388,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
 
         if request_id not in self.streaming_flow_cache:
-            self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy()
+            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
             mel = torch.zeros(1, 80, 0, device='cuda'), 
             source = torch.zeros(1, 1, 0, device='cuda'),
@@ -396,12 +396,14 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             )
 
         current_request_cache = self.streaming_flow_cache[request_id]
-        prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
+
+        current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
         generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
 
+
         chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
             token=generated_speech_tokens,
-            spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device),
+            spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
             cache=current_request_cache,
             last_chunk=last_chunk,
             n_timesteps=10,
@@ -409,9 +411,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         self.streaming_flow_cache[request_id] = new_streaming_flow_cache
 
-        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
+
+        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
             self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
-                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
                 self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
             ], dim=4)