root 2 месяцев назад
Родитель
Сommit
31a0adc73d

+ 102 - 57
runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py

@@ -43,6 +43,7 @@ import torchaudio
 
 
 from matcha.utils.audio import mel_spectrogram
+from datetime import datetime
 
 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
@@ -86,6 +87,7 @@ class TritonPythonModel:
         model_params = {k: v["string_value"] for k, v in parameters.items()}
         self.logger.log_info(f"model_params:{model_params}")
         self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
+        # self.dynamic_chunk_strategy = "equal"
         self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
 
         # Initialize tokenizer
@@ -105,7 +107,9 @@ class TritonPythonModel:
         if not os.path.exists(spk_info_path):
             raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
         spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
-        # self.default_spk_info = spk_info["001"]
+        self.default_spk_info = spk_info["001"]
+        self.http_client = httpx.AsyncClient()
+        self.runtime_cache = {}
 
     def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
         """Converts a tensor or list of speech token IDs to a string representation."""
@@ -131,7 +135,6 @@ class TritonPythonModel:
             {"role": "user", "content": full_text},
             {"role": "assistant", "content": prompt_speech_tokens_str}
         ]
-        print(chat)
 
         payload = {
             "model": "trt_engines_bfloat16",
@@ -148,31 +151,33 @@ class TritonPythonModel:
         api_base = "http://localhost:8000/v1/chat/completions"
 
         buffer = ""
-        async with httpx.AsyncClient() as client:
-            async with client.stream("POST", api_base, json=payload, timeout=None) as response:
-                response.raise_for_status()
-                async for line in response.aiter_lines():
-                    if line.startswith("data: "):
-                        line_data = line[len("data: "):].strip()
-                        if line_data == "[DONE]":
-                            break
-                        try:
-                            json_data = json.loads(line_data)
-                            content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
-                            if content:
-                                buffer += content
-                                while True:
-                                    match = re.search(r"<\|s_(\d+)\|>", buffer)
-                                    if not match:
-                                        break
-
-                                    token_num = int(match.group(1))
-                                    final_id = token_num + ORIGINAL_VOCAB_SIZE
-                                    yield final_id
-                                    buffer = buffer[match.end():]
-                        except json.JSONDecodeError:
-                            self.logger.log_info(f"Skipping non-JSON line: {line_data}")
-                            continue
+        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
+            print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
+            print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    line_data = line[len("data: "):].strip()
+                    if line_data == "[DONE]":
+                        break
+                    try:
+                        json_data = json.loads(line_data)
+                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
+                        if content:
+                            buffer += content
+                            print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
+                            while True:
+                                match = re.search(r"<\|s_(\d+)\|>", buffer)
+                                if not match:
+                                    break
+
+                                token_num = int(match.group(1))
+                                final_id = token_num + ORIGINAL_VOCAB_SIZE
+                                yield final_id
+                                buffer = buffer[match.end():]
+                    except json.JSONDecodeError:
+                        self.logger.log_info(f"Skipping non-JSON line: {line_data}")
+                        continue
 
         # Process any remaining complete tokens in the buffer after the stream ends
         while True:
@@ -236,7 +241,7 @@ class TritonPythonModel:
 
         return prompt_spk_embedding
 
-    def forward_token2wav(
+    async def forward_token2wav(
             self,
             index: int,
             target_speech_tokens: torch.Tensor,
@@ -258,20 +263,57 @@ class TritonPythonModel:
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
         finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
         inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
-
+        
+        # optional cache inputs
+        if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None:
+            # inputs_tensor.extend([
+            #     pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()),
+            #     pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()),
+            #     pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()),
+            #     pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()),
+            #     pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()),
+            #     pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()),
+            #     pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()),
+            # ])
+            inputs_tensor.extend([
+                self.runtime_cache[request_id]["conformer_cnn_cache"],
+                self.runtime_cache[request_id]["conformer_att_cache"],
+                self.runtime_cache[request_id]["estimator_cnn_cache"],
+                self.runtime_cache[request_id]["estimator_att_cache"],
+                self.runtime_cache[request_id]["mel"],
+                self.runtime_cache[request_id]["source"],
+                self.runtime_cache[request_id]["speech"],
+            ])
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav_dit',
-            requested_output_names=['waveform'],
+            requested_output_names=[
+                "waveform",
+                "conformer_cnn_cache",
+                "conformer_att_cache",
+                "estimator_cnn_cache",
+                "estimator_att_cache",
+                "mel",
+                "source",
+                "speech",
+            ],
             inputs=inputs_tensor,
             request_id=request_id,
             parameters={"priority": index+1},
         )
 
-        inference_response = inference_request.exec()
+        inference_response = await inference_request.async_exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
 
+        self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache")
+        self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache")
+        self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache")
+        self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache")
+        self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel")
+        self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source")
+        self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech")
+
         # Extract and convert output waveform
         waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
         waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
@@ -297,6 +339,16 @@ class TritonPythonModel:
 
     async def _process_request(self, request):
         request_id = request.request_id()
+        if request_id not in self.runtime_cache:
+            self.runtime_cache[request_id] = {
+                "conformer_cnn_cache": None,
+                "conformer_att_cache": None,
+                "estimator_cnn_cache": None,
+                "estimator_att_cache": None,
+                "mel": None,
+                "source": None,
+                "speech": None,
+            }
         # Extract input tensors
         wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
 
@@ -308,6 +360,7 @@ class TritonPythonModel:
 
             wav_tensor = wav.as_numpy()
             wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+            print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
             prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
             speech_feat = self._extract_speech_feat(prompt_speech_resample)
             token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
@@ -316,7 +369,7 @@ class TritonPythonModel:
 
             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')
-            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+            prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
 
             # reference_text = self.default_spk_info["prompt_text"]
             # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -333,6 +386,7 @@ class TritonPythonModel:
 
         target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
         target_text = target_text[0][0].decode('utf-8')
+        print(f"target_text: {target_text}, time: {datetime.now()}")
 
         if self.decoupled:
             response_sender = request.get_response_sender()
@@ -341,7 +395,7 @@ class TritonPythonModel:
             token_offset, chunk_index = 0, 0
             start_time = time.time()
             this_token_hop_len = self.token_hop_len
-
+            print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
             async for generated_ids in self.forward_llm_async(
                 target_text=target_text,
                 reference_text=reference_text,
@@ -350,18 +404,18 @@ class TritonPythonModel:
                 if not generated_ids:
                     break
                 semantic_token_ids_arr.append(generated_ids)
-                
+                print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
                 while True:
                     pending_num = len(semantic_token_ids_arr) - token_offset
                     if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                         this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
-
-                        sub_tts_speech = self.forward_token2wav(
+                        print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
+                        sub_tts_speech = await self.forward_token2wav(
                             chunk_index,
                             this_tts_speech_token, request_id, wav, wav_len, False
                         )
-
+                        print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
                         audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                         inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                         response_sender.send(inference_response)
@@ -371,6 +425,8 @@ class TritonPythonModel:
 
                         if self.dynamic_chunk_strategy == "exponential":
                             this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "equal":
+                            this_token_hop_len = self.token_hop_len
                         elif self.dynamic_chunk_strategy == "time_based":
                             # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
                             cost_time = time.time() - start_time
@@ -393,29 +449,13 @@ class TritonPythonModel:
                         break
             
             this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
-            sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
+            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
             audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
             inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
             response_sender.send(inference_response)
-
-            ## debug
-            ## save semantic_token_ids_arr and reference_text, target_text to a single json file
-            # save into a torch .pt
-            # for i, item in enumerate(semantic_token_ids_arr):
-            #     semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE
-            # import json
-            # data = {
-            #     "semantic_token_ids_arr": semantic_token_ids_arr,
-            #     "reference_text": reference_text,
-            #     "target_text": target_text
-            # }
-            # with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f:
-            #     torch.save(data, f)
-            # with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f:
-            #     json.dump(data, f)
-            
-            # ##
-
+            if request_id in self.runtime_cache:
+                del self.runtime_cache[request_id]
+                self.logger.log_info(f"Deleted cache for request_id: {request_id}")
             response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
             self.logger.log_info("send tritonserver_response_complete_final to end")
         else:
@@ -436,3 +476,8 @@ class TritonPythonModel:
         ]
         await asyncio.gather(*tasks)
         return None
+
+    def finalize(self):
+        self.logger.log_info("Finalizing CosyVoice DIT model")
+        if hasattr(self, "http_client"):
+            asyncio.run(self.http_client.aclose())

+ 1 - 1
runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt

@@ -31,7 +31,7 @@ parameters [
    value: {string_value:"${model_dir}"}
   }
 ]
-
+parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}
 input [
   {
     name: "reference_wav"

+ 77 - 25
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py

@@ -103,39 +103,91 @@ class TritonPythonModel:
             List of inference responses containing generated waveforms
         """
         responses = []
-        # Process each request in batch
         for request in requests:
-            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
-            # shift the speech tokens according to the original vocab size
-            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
-            target_speech_tokens = target_speech_tokens.squeeze().tolist()
-
-            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
-           
-            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
-                
             request_id = request.request_id()
-               
 
-            wav_array = pb_utils.get_input_tensor_by_name(
-                request, "reference_wav").as_numpy()
-            wav_len = pb_utils.get_input_tensor_by_name(
-                request, "reference_wav_len").as_numpy().item()
-
-            wav_array = torch.from_numpy(wav_array)
-            # Prepare inputs
-            wav = wav_array[:, :wav_len].squeeze(0)
+            # Get inputs
+            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens")
+            target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack())
+            target_speech_tokens = target_speech_tokens.squeeze().tolist()
 
+            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+            wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item()
+            wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0)
             spk_id = get_spk_id_from_prompt_audio(wav)
-            # wav = wav.to(self.device)
 
-            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
+            # Handle cache
+            conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache")
+            if conformer_cnn_cache is not None:
+                self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack())
+                
+                conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache")
+                self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0,1)
+                
+                estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache")
+                self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0)
 
-            generated_wave = audio_hat.squeeze(0).cpu().numpy()
+                estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache")
+                self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0)
 
+                mel_np = pb_utils.get_input_tensor_by_name(request, "mel")
+                self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack())
+                
+                source_np = pb_utils.get_input_tensor_by_name(request, "source")
+                self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack())
+                
+                speech_np = pb_utils.get_input_tensor_by_name(request, "speech")
+                self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack())
+
+            # Forward pass
+            audio_hat = self.token2wav_model.forward_streaming(
+                target_speech_tokens, 
+                finalize, 
+                request_id=request_id, 
+                speaker_id=f"{spk_id}", 
+                prompt_audio=wav, 
+                prompt_audio_sample_rate=16000
+            )
+            
+            # Prepare outputs
+            outputs = []
             wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
-            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
+            outputs.append(wav_tensor)
+            
+            if request_id in self.token2wav_model.streaming_flow_cache:
+                cache = self.token2wav_model.streaming_flow_cache[request_id]
+                hifigan_cache = self.token2wav_model.hift_cache_dict[request_id]
+                conformer_cnn_cache = cache['conformer_cnn_cache']
+                conformer_att_cache = cache['conformer_att_cache'].transpose(0,1)
+                estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0)
+                estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0)
+                mel = hifigan_cache['mel']
+                source = hifigan_cache['source']
+                speech = hifigan_cache['speech']
+
+                outputs.extend([
+                    pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)),
+                    pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)),
+                    pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)),
+                    pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)),
+                    pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)),
+                    pb_utils.Tensor.from_dlpack("source", to_dlpack(source)),
+                    pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)),
+                ])
+            else:
+                outputs.extend([pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)),
+                pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)),
+                pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)),
+                pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)),
+                pb_utils.Tensor("mel", np.array([], dtype=np.float32)),
+                pb_utils.Tensor("source", np.array([], dtype=np.float32)),
+                pb_utils.Tensor("speech", np.array([], dtype=np.float32)),
+                ])
+
+            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
             responses.append(inference_response)
-
         return responses
+
+    def finalize(self):
+        self.logger.log_info("Finalizing Token2WavDiT model")

+ 7 - 39
runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py

@@ -372,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
     ):  
         if speaker_id not in self.speaker_cache:
-        # if 1:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
 
@@ -384,20 +383,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
             prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
 
-        # if speaker_id not in self.speaker_cache:
-        # if 1:
-            
             cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
             print(f"speaker_id {speaker_id} added to cache")
 
-            # get a clone of cache dict ['estimator_att_cache'] and later check if it would be change 
-        att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone()
-        cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone()
-        conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone()
-        conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone()
-    
-
         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
@@ -405,6 +394,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             source = torch.zeros(1, 1, 0, device='cuda'),
             speech = torch.zeros(1, 0, device='cuda'),
             )
+        # else:
+        #     for k, v in self.streaming_flow_cache[request_id].items():
+        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
+        #     for k, v in self.hift_cache_dict[request_id].items():
+        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
+        #     breakpoint()
 
         current_request_cache = self.streaming_flow_cache[request_id]
 
@@ -420,33 +415,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             n_timesteps=10,
         )
 
-        # get the original att_cache
-        original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache']
-        original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache']
-        original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache']
-        original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache']
-        if not torch.allclose(original_att_cache, att_cache_clone):
-            print("att_cache changed")
-            # print the last 10 elements of original_att_cache and att_cache_clone
-            print(original_att_cache[:, :, :, -10:])
-            print(att_cache_clone[:, :, :, -10:])
-            breakpoint()
-        if not torch.allclose(original_cnn_cache, cnn_cache_clone):
-            print("cnn_cache changed")
-            print(original_cnn_cache[..., -10:])
-            print(cnn_cache_clone[..., -10:])
-            breakpoint()
-        if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone):
-            print("conformer_cnn_cache changed")
-            print(original_conformer_cnn_cache[..., -10:])
-            print(conformer_cnn_cache_clone[..., -10:])
-            breakpoint()
-        if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone):
-            print("conformer_att_cache changed")
-            print(original_conformer_att_cache[..., -10:])
-            print(conformer_att_cache_clone[..., -10:])
-            breakpoint()
-
         self.streaming_flow_cache[request_id] = new_streaming_flow_cache
 
 
@@ -482,7 +450,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             assert request_id in self.streaming_flow_cache
             self.streaming_flow_cache.pop(request_id)
             self.hift_cache_dict.pop(request_id)
-        # breakpoint()
+
         return speech
 
 def collate_fn(batch):

+ 80 - 0
runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt

@@ -15,11 +15,14 @@
 name: "token2wav_dit"
 backend: "python"
 max_batch_size: ${triton_max_batch_size}
+
 dynamic_batching {
     max_queue_delay_microseconds: ${max_queue_delay_microseconds}
     priority_levels: 10
     default_priority_level: 10
 }
+
+parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}
 parameters [
   {
    key: "model_dir",
@@ -49,6 +52,48 @@ input [
     dims: [ 1 ]
     reshape: { shape: [ ] }
     optional: true
+  },
+  {
+    name: "conformer_cnn_cache"
+    data_type: TYPE_FP16
+    dims: [ 512, -1 ]
+    optional: true
+  },
+  {
+    name: "conformer_att_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 8, -1, 128 ]
+    optional: true
+  },
+  {
+    name: "estimator_cnn_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 16, -1, 1024, 2 ]
+    optional: true
+  },
+  {
+    name: "estimator_att_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 16, -1, 8, -1, 128 ]
+    optional: true
+  },
+  {
+    name: "mel"
+    data_type: TYPE_FP32
+    dims: [ 80, -1 ]
+    optional: true
+  },
+  {
+    name: "source"
+    data_type: TYPE_FP32
+    dims: [ 1, -1 ]
+    optional: true
+  },
+  {
+    name: "speech"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+    optional: true
   }
 ]
 output [
@@ -56,6 +101,41 @@ output [
     name: "waveform"
     data_type: TYPE_FP32
     dims: [ -1 ]
+  },
+  {
+    name: "conformer_cnn_cache"
+    data_type: TYPE_FP16
+    dims: [ 512, -1 ]
+  },
+  {
+    name: "conformer_att_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 8, -1, 128 ]
+  },
+  {
+    name: "estimator_cnn_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 16, -1, 1024, 2 ]
+  },
+  {
+    name: "estimator_att_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 16, -1, 8, -1, 128 ]
+  },
+  {
+    name: "mel"
+    data_type: TYPE_FP32
+    dims: [ 80, -1 ]
+  },
+  {
+    name: "source"
+    data_type: TYPE_FP32
+    dims: [ 1, -1 ]
+  },
+  {
+    name: "speech"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
   }
 ]