yuekaiz 6 月之前
父节点
当前提交
33aee03ed5

+ 1 - 1
examples/grpo/cosyvoice2/infer_dataset.py

@@ -53,7 +53,7 @@ except RuntimeError:
     pass
 
 
-TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
 
 
 def audio_decode_cosyvoice2(

+ 0 - 2
examples/grpo/cosyvoice2/pretrained_to_huggingface.py

@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-
 # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #

+ 1 - 3
examples/grpo/cosyvoice2/scripts/offline-decode-files.py

@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-#
 # Copyright (c)  2023 by manyeyes
 # Copyright (c)  2023  Xiaomi Corporation
 
@@ -195,7 +193,7 @@ def write_error_stats(
             hyp = list("".join(hyp))
             results[i] = (cut_id, ref, hyp)
 
-    for cut_id, ref, hyp in results:
+    for _cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
         for ref_word, hyp_word in ali:
             if ref_word == ERR:

+ 1 - 1
examples/grpo/cosyvoice2/token2wav_asr_server.py

@@ -295,7 +295,7 @@ def main():
         metrics_port=8002,
     )
 
-    device_ids = [i for i in range(args.number_of_devices)]
+    device_ids = list(range(args.number_of_devices))
     device_ids = device_ids * args.number_of_instances_per_device
 
     with Triton(config=triton_config) as triton:

+ 10 - 2
runtime/triton_trtllm/client_grpc.py

@@ -122,7 +122,10 @@ def write_triton_stats(stats, summary_file):
             total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
             total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
             summary_f.write(
-                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"
+                f"queue time {total_queue_time_s:<5.2f} s, "
+                f"compute infer time {total_infer_time_s:<5.2f} s, "
+                f"compute input time {total_input_time_s:<5.2f} s, "
+                f"compute output time {total_output_time_s:<5.2f} s \n"
             )
             model_batch_stats = model_state["batch_stats"]
             for batch in model_batch_stats:
@@ -136,7 +139,12 @@ def write_triton_stats(stats, summary_file):
                 compute_input_time_ms = int(compute_input["ns"]) / 1e6
                 compute_output_time_ms = int(compute_output["ns"]) / 1e6
                 summary_f.write(
-                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
+                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
+                    f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
+                    f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
+                    f"{compute_infer_time_ms / batch_count:.2f} ms, "
+                    f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
+                    f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
                 )
                 summary_f.write(
                     f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "

+ 0 - 1
runtime/triton_trtllm/client_http.py

@@ -25,7 +25,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import requests
 import soundfile as sf
-import json
 import numpy as np
 import argparse
 

+ 0 - 3
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

@@ -25,12 +25,9 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import json
-import math
 import os
-import re
 import threading
 import time
-from typing import Dict, List, Tuple, Optional, Union
 
 import numpy as np
 import torch

+ 1 - 2
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py

@@ -178,7 +178,6 @@ class TritonPythonModel:
             yield final_id
             buffer = buffer[match.end():]
 
-
     def forward_audio_tokenizer(self, wav, wav_len):
         """Forward pass through the audio tokenizer component.
 
@@ -263,7 +262,7 @@ class TritonPythonModel:
             ],
             inputs=inputs_tensor,
             request_id=request_id,
-            parameters={"priority": index+1},
+            parameters={"priority": index + 1},
         )
 
         inference_response = await inference_request.async_exec()

+ 0 - 1
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -28,7 +28,6 @@ import json
 import os
 
 import logging
-from typing import List, Dict
 
 import torch
 from torch.utils.dlpack import to_dlpack

+ 7 - 2
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py

@@ -48,9 +48,11 @@ import hashlib
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
+
 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
 
+
 def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
     """
     Generates a unique ID for a torch.Tensor.
@@ -65,6 +67,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
 
     return hasher.hexdigest()
 
+
 class TritonPythonModel:
     """Triton Python model for vocoder.
 
@@ -114,7 +117,6 @@ class TritonPythonModel:
 
             request_id = request.request_id()
 
-
             wav_array = pb_utils.get_input_tensor_by_name(
                 request, "reference_wav").as_numpy()
             wav_len = pb_utils.get_input_tensor_by_name(
@@ -125,7 +127,10 @@ class TritonPythonModel:
 
             spk_id = get_spk_id_from_prompt_audio(wav)
 
-            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
+            audio_hat = self.token2wav_model.forward_streaming(
+                target_speech_tokens, finalize, request_id=request_id,
+                speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
+            )
 
             outputs = []
 

+ 67 - 44
runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py

@@ -35,7 +35,7 @@ import numpy as np
 from hyperpyyaml import load_hyperpyyaml
 
 
-def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
+def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
     """perform fade_in_out in tensor style
     """
     mel_overlap_len = int(window.shape[0] / 2)
@@ -45,6 +45,7 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torc
         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel
 
+
 def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
     import tensorrt as trt
     logging.info("Converting onnx to trt...")
@@ -90,6 +91,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
         f.write(engine_bytes)
     logging.info("Succesfully convert onnx to trt...")
 
+
 class TrtContextWrapper:
     def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
         self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
@@ -108,6 +110,7 @@ class TrtContextWrapper:
     def release_estimator(self, context, stream):
         self.trt_context_pool.put([context, stream])
 
+
 class CosyVoice2_Token2Wav(torch.nn.Module):
     def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
         super().__init__()
@@ -131,27 +134,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
-        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
-                                                    providers=["CPUExecutionProvider"])
+        self.spk_model = onnxruntime.InferenceSession(
+            f"{model_dir}/campplus.onnx", sess_options=option,
+            providers=["CPUExecutionProvider"])
         self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
 
-        gpu="l20"
+        gpu = "l20"
         if enable_trt:
             if streaming:
-                self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
-                                    f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
-                                    1,
-                                    self.dtype, streaming)
+                self.load_trt(
+                    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
+                    f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
+                    1,
+                    self.dtype, streaming
+                )
             else:
-                self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
-                                    f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
-                                    1,
-                                    self.dtype)
-            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
-                                f'{model_dir}/campplus.onnx',
-                                1,
-                                False)
-
+                self.load_trt(
+                    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
+                    f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                    1,
+                    self.dtype
+                )
+            self.load_spk_trt(
+                f'{model_dir}/campplus.{gpu}.fp32.trt',
+                f'{model_dir}/campplus.onnx',
+                1,
+                False
+            )
 
         self.streaming_flow_cache = {}
         self.speaker_cache = {}
@@ -215,7 +224,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             opt_batch_size = 2
             max_batch_size = 16
             if streaming:
-                opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
+                opt_batch_size, max_batch_size = 1, 1  # only support batch size 1 for streaming tts
             trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
             convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
         del self.flow.decoder.estimator
@@ -228,13 +237,27 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
         if streaming:
             min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
-            opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)]
-            max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)]
+            opt_shape = [
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
+                (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
+                (16, opt_batch_size * 2, 8, 100, 128)
+            ]
+            max_shape = [
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
+                (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
+                (16, max_batch_size * 2, 8, 1000, 128)
+            ]
             input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
         else:
             min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
-            opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
-            max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+            opt_shape = [
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
+            ]
+            max_shape = [
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
+            ]
             input_names = ["x", "mask", "mu", "cond", "t", "spks"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
@@ -279,11 +302,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             mel_len = mel.shape[0]
             prompt_mels_for_flow.append(mel)
             prompt_mels_lens_for_flow.append(mel_len)
-        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
+        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
+            prompt_mels_for_flow, batch_first=True, padding_value=0
+        )  # [B, T', num_mels=80]
         prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
         return prompt_mels_for_flow, prompt_mels_lens_for_flow
 
-    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+    def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
+                     generated_speech_tokens_list: list[list[int]],
+                     prompt_mels_for_flow: torch.Tensor,
+                     prompt_mels_lens_for_flow: torch.Tensor,
+                     spk_emb_for_flow: torch.Tensor):
         batch_size = prompt_mels_for_flow.shape[0]
         flow_inputs = []
         flow_inputs_lens = []
@@ -311,7 +340,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             generated_wavs.append(wav)
         return generated_wavs
 
-
     @torch.inference_mode()
     def forward(
         self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
@@ -320,7 +348,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
 
-        generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+        generated_mels, generated_mels_lens = self.forward_flow(
+            prompt_speech_tokens_list, generated_speech_tokens_list,
+            prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+        )
 
         generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
         return generated_wavs
@@ -337,7 +368,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
         return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
 
-
     def get_prompt_audio_cache_for_streaming_tts(
         self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
     ):
@@ -356,7 +386,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
         return new_cache
 
-
     @torch.inference_mode()
     def forward_streaming(
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
@@ -379,9 +408,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
-            mel = torch.zeros(1, 80, 0, device='cuda'),
-            source = torch.zeros(1, 1, 0, device='cuda'),
-            speech = torch.zeros(1, 0, device='cuda'),
+                mel=torch.zeros(1, 80, 0, device='cuda'),
+                source=torch.zeros(1, 1, 0, device='cuda'),
+                speech=torch.zeros(1, 0, device='cuda'),
             )
 
         current_request_cache = self.streaming_flow_cache[request_id]
@@ -389,7 +418,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
         generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
 
-
         chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
             token=generated_speech_tokens,
             spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
@@ -400,15 +428,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         self.streaming_flow_cache[request_id] = new_streaming_flow_cache
 
-
         if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
             self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
                 self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
                 self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
             ], dim=4)
 
-
-
         hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
         hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
         hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
@@ -422,9 +447,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         # update vocoder cache
         self.hift_cache_dict[request_id] = dict(
-            mel = mel[..., -self.mel_cache_len:].clone().detach(),
-            source = source[:, :, -self.source_cache_len:].clone().detach(),
-            speech = speech[:, -self.source_cache_len:].clone().detach(),
+            mel=mel[..., -self.mel_cache_len:].clone().detach(),
+            source=source[:, :, -self.source_cache_len:].clone().detach(),
+            speech=speech[:, -self.source_cache_len:].clone().detach(),
         )
         if not last_chunk:
             speech = speech[:, :-self.source_cache_len]
@@ -436,6 +461,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         return speech
 
+
 def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
     for i, item in enumerate(batch):
@@ -447,6 +473,7 @@ def collate_fn(batch):
 
     return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
 
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--enable-trt", action="store_true")
@@ -457,6 +484,7 @@ def get_args():
     parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
     return parser.parse_args()
 
+
 if __name__ == "__main__":
     args = get_args()
     model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
@@ -466,22 +494,17 @@ if __name__ == "__main__":
 
     dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
 
-
     data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
 
-
     for epoch in range(args.warmup):
         start_time = time.time()
-
         for batch in data_loader:
             ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
 
             generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
 
-
             for id, wav in zip(ids, generated_wavs):
                 torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
-
         end_time = time.time()
         epoch_time = end_time - start_time
-        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
+        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

+ 0 - 1
runtime/triton_trtllm/offline_inference.py

@@ -28,7 +28,6 @@ import argparse
 import json
 import os
 import sys
-from pathlib import Path
 
 import torch
 import torch.distributed as dist

+ 0 - 5
runtime/triton_trtllm/scripts/test_llm.py

@@ -15,11 +15,6 @@
 # limitations under the License.
 
 import argparse
-import ast
-import csv
-import os
-from pathlib import Path
-from typing import List, Optional
 
 import numpy as np
 import torch

+ 12 - 4
runtime/triton_trtllm/streaming_inference.py

@@ -9,6 +9,7 @@ import time
 from token2wav_dit import CosyVoice2_Token2Wav
 import soundfile as sf
 
+
 def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
     prompt_speech_tokens_list, prompt_text_list = [], []
@@ -23,6 +24,7 @@ def collate_fn(batch):
 
     return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
 
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--enable-trt", action="store_true")
@@ -79,7 +81,11 @@ if __name__ == "__main__":
                     this_chunk_size = token_frame_rate * (2 ** chunk_index)
 
                 if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
-                    wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+                    wavs = token2wav_model.forward_streaming(
+                        buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
+                        False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
+                        prompt_audio_sample_rate=prompt_audio_sample_rate
+                    )
                     buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
 
                     output_wavs.append(wavs)
@@ -87,7 +93,10 @@ if __name__ == "__main__":
                     chunk_index += 1
 
                 else:
-                    wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+                    wavs = token2wav_model.forward_streaming(
+                        buffer, True, request_id=id, speaker_id=f"{id}",
+                        prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
+                    )
                     output_wavs.append(wavs)
                     total_forward_count += 1
                     # chunk_index += 1
@@ -96,7 +105,6 @@ if __name__ == "__main__":
             for i, wav in enumerate(output_wavs):
                 output_wavs[i] = wav.cpu().numpy().squeeze()
 
-
             audios = output_wavs
             reconstructed_audio = np.concatenate(audios)
             sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
@@ -111,4 +119,4 @@ if __name__ == "__main__":
             print(f"Cost time without speaker cache: {end_time - start_time} seconds")
         else:
             print(f"Cost time with speaker cache: {end_time - start_time} seconds")
-            print(f"Total flow matching forward calls: {total_forward_count}")
+            print(f"Total flow matching forward calls: {total_forward_count}")