root 1 mês atrás
pai
commit
a019a2504e

+ 29 - 113
runtime/triton_trtllm/client_grpc.py

@@ -43,9 +43,9 @@ python3 client_grpc.py \
 import argparse
 import asyncio
 import json
-import queue  # Added
-import uuid  # Added
-import functools  # Added
+import queue
+import uuid
+import functools
 
 import os
 import time
@@ -55,13 +55,11 @@ from pathlib import Path
 import numpy as np
 import soundfile as sf
 import tritonclient
-import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
-import tritonclient.grpc as grpcclient_sync  # Added sync client import
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException
+import tritonclient.grpc.aio as grpcclient_aio
+import tritonclient.grpc as grpcclient_sync
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException
 
-from datetime import datetime
 
-# --- Added UserData and callback ---
 class UserData:
     def __init__(self):
         self._completed_requests = queue.Queue()
@@ -86,7 +84,7 @@ class UserData:
 def callback(user_data, result, error):
     if not error:
         if user_data._first_chunk_time is None:
-            user_data._first_chunk_time = time.time()  # Record time of first successful chunk
+            user_data._first_chunk_time = time.time()
         elif user_data._second_chunk_time is None:
             user_data._second_chunk_time = time.time()
 
@@ -99,10 +97,6 @@ def callback(user_data, result, error):
 def stream_callback(user_data_map, result, error):
     request_id = None
     if error:
-        # Note: InferenceServerException doesn't have a public request_id() method in all versions.
-        # This part might need adjustment depending on the tritonclient library version.
-        # A more robust way would be to wrap the error with the request_id if possible.
-        # For now, we assume we can't get request_id from error and it will timeout on the client side.
         print(f"An error occurred in the stream callback: {error}")
     else:
         request_id = result.get_response().id
@@ -115,31 +109,9 @@ def stream_callback(user_data_map, result, error):
             print(f"Warning: Could not find user_data for request_id {request_id}")
 
 
-# --- End Added UserData and callback ---
-
-
 def write_triton_stats(stats, summary_file):
     with open(summary_file, "w") as summary_f:
         model_stats = stats["model_stats"]
-        # write a note, the log is from triton_client.get_inference_statistics(), to better human readability
-        summary_f.write(
-            "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
-        )
-        summary_f.write("To learn more about the log, please refer to: \n")
-        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
-        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
-        summary_f.write(
-            "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
-        )
-        summary_f.write(
-            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
-        )
-        summary_f.write(
-            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
-        )
-        summary_f.write(
-            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
-        )
         for model_state in model_stats:
             if "last_inference" not in model_state:
                 continue
@@ -150,7 +122,7 @@ def write_triton_stats(stats, summary_file):
             total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
             total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
             summary_f.write(
-                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
+                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"
             )
             model_batch_stats = model_state["batch_stats"]
             for batch in model_batch_stats:
@@ -164,19 +136,18 @@ def write_triton_stats(stats, summary_file):
                 compute_input_time_ms = int(compute_input["ns"]) / 1e6
                 compute_output_time_ms = int(compute_output["ns"]) / 1e6
                 summary_f.write(
-                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
+                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
                 )
                 summary_f.write(
-                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
+                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
                 )
                 summary_f.write(
-                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
+                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
                 )
 
 
 def subtract_stats(stats_after, stats_before):
     """Subtracts two Triton inference statistics objects."""
-    # Deep copy to avoid modifying the original stats_after
     stats_diff = json.loads(json.dumps(stats_after))
 
     model_stats_before_map = {
@@ -196,7 +167,6 @@ def subtract_stats(stats_after, stats_before):
         if model_name in model_stats_before_map:
             model_stat_before = model_stats_before_map[model_name]
 
-            # Subtract counts
             model_stat_after["inference_count"] = str(
                 int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
             )
@@ -204,7 +174,6 @@ def subtract_stats(stats_after, stats_before):
                 int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
             )
 
-            # Subtract aggregate stats (like queue, compute times)
             if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
                 for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
                     if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
@@ -217,7 +186,6 @@ def subtract_stats(stats_after, stats_before):
                             count_before = int(model_stat_before["inference_stats"][key]["count"])
                             model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
 
-            # Subtract batch execution stats
             if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
                 batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
                 for batch_stat_after in model_stat_after["batch_stats"]:
@@ -338,7 +306,6 @@ def get_args():
         help="log directory",
     )
 
-    # --- Added arguments ---
     parser.add_argument(
         "--mode",
         type=str,
@@ -379,39 +346,33 @@ def load_audio(wav_path, target_sample_rate=16000):
 
 
 def prepare_request_input_output(
-    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
+    protocol_client,
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None,  # Optional padding for offline mode
+    padding_duration: int = None,
     use_spk2info_cache: bool = False
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
     lengths = np.array([[len(waveform)]], dtype=np.int32)
 
-    # Apply padding only if padding_duration is provided (for offline)
     if padding_duration:
         duration = len(waveform) / sample_rate
-        # Estimate target duration based on text length ratio (crude estimation)
-        # Avoid division by zero if reference_text is empty
         if reference_text:
             estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
-            estimated_target_duration = duration  # Assume target duration similar to reference if no text
+            estimated_target_duration = duration
 
-        # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
             (int(estimated_target_duration + duration) // padding_duration) + 1
         )
         samples = np.zeros((1, required_total_samples), dtype=np.float32)
         samples[0, : len(waveform)] = waveform
     else:
-        # No padding for streaming or if padding_duration is None
         samples = waveform.reshape(1, -1).astype(np.float32)
 
-    # Common input creation logic
     inputs = [
         protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
         protocol_client.InferInput(
@@ -450,12 +411,8 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
-    user_data.record_start_time()  # Record start time for first chunk latency calculation
-    # e.g. 08:47:34.827758
+    user_data.record_start_time()
 
-    print(f"Record start time in human readable: {datetime.now()}")
-    # input()
-    # Send request
     sync_triton_client.async_stream_infer(
         model_name,
         inputs,
@@ -464,30 +421,26 @@ def run_sync_streaming_inference(
         enable_empty_final_response=True,
     )
 
-    # Process results
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get(timeout=20)  # Add timeout
+            result = user_data._completed_requests.get(timeout=20)
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
-                # Don't stop the stream here, just return error
                 return None, None, None, None
-            # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
             if final is True:
                 break
 
             audio_chunk = result.as_numpy("waveform").reshape(-1)
-            if audio_chunk.size > 0:  # Only append non-empty chunks
+            if audio_chunk.size > 0:
                 audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")
 
         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
-            # Don't stop stream here, just return error
             return None, None, None, None
 
     end_time_total = time.time()
@@ -495,47 +448,36 @@ def run_sync_streaming_inference(
     first_chunk_latency = user_data.get_first_chunk_latency()
     second_chunk_latency = user_data.get_second_chunk_latency()
 
-    # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
-    actual_duration = 0
     if audios:
-        # Only spark_tts model uses cross-fade
         if model_name == "spark_tts":
             cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
             fade_out = np.linspace(1, 0, cross_fade_samples)
             fade_in = np.linspace(0, 1, cross_fade_samples)
             reconstructed_audio = None
 
-            # Simplified reconstruction based on client_grpc_streaming.py
             if not audios:
                 print("Warning: No audio chunks received.")
-                reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
+                reconstructed_audio = np.array([], dtype=np.float32)
             elif len(audios) == 1:
                 reconstructed_audio = audios[0]
             else:
-                reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
+                reconstructed_audio = audios[0][:-cross_fade_samples]
                 for i in range(1, len(audios)):
-                    # Cross-fade section
                     cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
                                            audios[i - 1][-cross_fade_samples:] * fade_out)
-                    # Middle section of the current chunk
                     middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-                    # Concatenate
                     reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
-                # Add the last part of the final chunk
                 reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
 
             if reconstructed_audio is not None and reconstructed_audio.size > 0:
                 actual_duration = len(reconstructed_audio) / save_sample_rate
-                # Save reconstructed audio
                 sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
             else:
                 print("Warning: No audio chunks received or reconstructed.")
-                actual_duration = 0  # Set duration to 0 if no audio
+                actual_duration = 0
         else:
             reconstructed_audio = np.concatenate(audios)
-            print(f"reconstructed_audio: {reconstructed_audio.shape}")
             actual_duration = len(reconstructed_audio) / save_sample_rate
-            # Save reconstructed audio
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
 
     else:
@@ -548,7 +490,7 @@ def run_sync_streaming_inference(
 async def send_streaming(
     manifest_item_list: list,
     name: str,
-    server_url: str,  # Changed from sync_triton_client
+    server_url: str,
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -561,12 +503,12 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
-    sync_triton_client = None  # Initialize client variable
+    sync_triton_client = None
     user_data_map = {}
 
-    try:  # Wrap in try...finally to ensure client closing
+    try:
         print(f"{name}: Initializing sync client for streaming...")
-        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
         sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
 
         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
@@ -593,7 +535,6 @@ async def send_streaming(
                 user_data_map[request_id] = user_data
 
                 audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
-                print("target_text: ", target_text, "time: ", datetime.now())
                 total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
                     run_sync_streaming_inference,
                     sync_triton_client,
@@ -627,7 +568,7 @@ async def send_streaming(
                 import traceback
                 traceback.print_exc()
 
-    finally:  # Ensure client is closed
+    finally:
         if sync_triton_client:
             try:
                 print(f"{name}: Closing stream and sync client...")
@@ -656,7 +597,6 @@ async def send(
     latency_data = []
     task_id = int(name[5:])
 
-    print(f"manifest_item_list: {manifest_item_list}")
     for i, item in enumerate(manifest_item_list):
         if i % log_interval == 0:
             print(f"{name}: {i}/{len(manifest_item_list)}")
@@ -697,7 +637,6 @@ def load_manifests(manifest_path):
             assert len(line.strip().split("|")) == 4
             utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
             utt = Path(utt).stem
-            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
             if not os.path.isabs(prompt_wav):
                 prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
             manifest_list.append(
@@ -738,23 +677,17 @@ async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
 
-    # --- Client Initialization based on mode ---
     triton_client = None
     protocol_client = None
     if args.mode == "offline":
         print("Initializing gRPC client for offline mode...")
-        # Use the async client for offline tasks
         triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
         protocol_client = grpcclient_aio
     elif args.mode == "streaming":
         print("Initializing gRPC client for streaming mode...")
-        # Use the sync client for streaming tasks, handled via asyncio.to_thread
-        # We will create one sync client instance PER TASK inside send_streaming.
-        # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
-        protocol_client = grpcclient_sync  # protocol client for input prep
+        protocol_client = grpcclient_sync
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
-    # --- End Client Initialization ---
 
     if args.reference_audio:
         args.num_tasks = 1
@@ -776,24 +709,18 @@ async def main():
             trust_remote_code=True,
         )
         manifest_item_list = []
-        tmp_audio_path="./asset_zero_shot_prompt.wav"
-        tmp_audio_text="希望你以后能够做的比我还好呦。"
         for i in range(len(dataset)):
             manifest_item_list.append(
                 {
                     "audio_filepath": dataset[i]["prompt_audio"],
                     "reference_text": dataset[i]["prompt_text"],
-                    # "audio_filepath": tmp_audio_path,
-                    # "reference_text": tmp_audio_text,
                     "target_audio_path": dataset[i]["id"],
                     "target_text": dataset[i]["target_text"],
                 }
             )
-        # manifest_item_list = manifest_item_list[:4]
     else:
         manifest_item_list = load_manifests(args.manifest_path)
 
-    # --- Statistics Fetching (Before) ---
     stats_client = None
     stats_before = None
     try:
@@ -803,7 +730,6 @@ async def main():
         stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
     except Exception as e:
         print(f"Could not retrieve statistics before running tasks: {e}")
-    # --- End Statistics Fetching (Before) ---
 
     num_tasks = min(args.num_tasks, len(manifest_item_list))
     manifest_item_list = split_data(manifest_item_list, num_tasks)
@@ -813,7 +739,6 @@ async def main():
     tasks = []
     start_time = time.time()
     for i in range(num_tasks):
-        # --- Task Creation based on mode ---
         if args.mode == "offline":
             task = asyncio.create_task(
                 send(
@@ -834,7 +759,7 @@ async def main():
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
-                    server_url=url,  # Pass URL instead of client
+                    server_url=url,
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -845,7 +770,6 @@ async def main():
                     use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
-        # --- End Task Creation ---
         tasks.append(task)
 
     ans_list = await asyncio.gather(*tasks)
@@ -858,7 +782,7 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
-            latency_data.extend(ans[1])  # Use extend for list of lists
+            latency_data.extend(ans[1])
         else:
             print("Warning: A task returned None, possibly due to an error.")
 
@@ -874,10 +798,8 @@ async def main():
     s += f"({total_duration / 3600:.2f} hours)\n"
     s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
 
-    # --- Statistics Reporting based on mode ---
     if latency_data:
         if args.mode == "offline":
-            # Original offline latency calculation
             latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
             if latency_list:
                 latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
@@ -892,7 +814,6 @@ async def main():
                 s += "No latency data collected for offline mode.\n"
 
         elif args.mode == "streaming":
-            # Calculate stats for total request latency and first chunk latency
             total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
             first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
             second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
@@ -937,7 +858,6 @@ async def main():
                 s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
     else:
         s += "No latency data collected.\n"
-    # --- End Statistics Reporting ---
 
     print(s)
     if args.manifest_path:
@@ -947,12 +867,10 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
-        name = "results"  # Default name if no manifest/split/audio provided
+        name = "results"
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)
 
-    # --- Statistics Fetching using temporary Async Client ---
-    # Use a separate async client for fetching stats regardless of mode
     try:
         if stats_client and stats_before:
             print("Fetching inference statistics after running tasks...")
@@ -980,11 +898,9 @@ async def main():
                 await stats_client.close()
             except Exception as e:
                 print(f"Error closing async stats client: {e}")
-    # --- End Statistics Fetching ---
 
 
 if __name__ == "__main__":
-    # asyncio.run(main()) # Use TaskGroup for better exception handling if needed
     async def run_main():
         try:
             await main()

+ 18 - 52
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py

@@ -43,7 +43,7 @@ import torchaudio
 
 
 from matcha.utils.audio import mel_spectrogram
-from datetime import datetime
+
 
 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
@@ -85,9 +85,7 @@ class TritonPythonModel:
         self.model_config = json.loads(args['model_config'])
         parameters = self.model_config['parameters']
         model_params = {k: v["string_value"] for k, v in parameters.items()}
-        self.logger.log_info(f"model_params:{model_params}")
         self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
-        # self.dynamic_chunk_strategy = "equal"
         self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
 
         # Initialize tokenizer
@@ -103,12 +101,8 @@ class TritonPythonModel:
         self.flow_pre_lookahead_len = 3
         self.token_hop_len = 15
 
-        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
-        if not os.path.exists(spk_info_path):
-            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
-        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
-        self.default_spk_info = spk_info["001"]
         self.http_client = httpx.AsyncClient()
+        self.api_base = "http://localhost:8000/v1/chat/completions"
 
     def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
         """Converts a tensor or list of speech token IDs to a string representation."""
@@ -147,12 +141,8 @@ class TritonPythonModel:
             "stream": True,
         }
 
-        api_base = "http://localhost:8000/v1/chat/completions"
-
         buffer = ""
-        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
-            print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
-            print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
+        async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
             response.raise_for_status()
             async for line in response.aiter_lines():
                 if line.startswith("data: "):
@@ -164,7 +154,6 @@ class TritonPythonModel:
                         content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
                         if content:
                             buffer += content
-                            print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
                             while True:
                                 match = re.search(r"<\|s_(\d+)\|>", buffer)
                                 if not match:
@@ -307,40 +296,24 @@ class TritonPythonModel:
         wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
 
         # Process reference audio through audio tokenizer
-        if wav is not None:
-            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
-            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-
-            wav_tensor = wav.as_numpy()
-            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-            print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
-            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-            speech_feat = self._extract_speech_feat(prompt_speech_resample)
-            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-
-            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
-            reference_text = reference_text[0][0].decode('utf-8')
-            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
-
-            # reference_text = self.default_spk_info["prompt_text"]
-            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
-            # prompt_speech_feat = None
-            # prompt_spk_embedding = None
 
-        else:
-            # using pre-cached reference text
-            assert False, "using pre-cached reference text is not supported"
-            reference_text = self.default_spk_info["prompt_text"]
-            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
-            prompt_speech_feat = None
-            prompt_spk_embedding = None
+        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+        prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+        prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+        wav_tensor = wav.as_numpy()
+        wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+        speech_feat = self._extract_speech_feat(prompt_speech_resample)
+        token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+        prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+        prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+        reference_text = reference_text[0][0].decode('utf-8')
 
         target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
         target_text = target_text[0][0].decode('utf-8')
-        print(f"target_text: {target_text}, time: {datetime.now()}")
 
         if self.decoupled:
             response_sender = request.get_response_sender()
@@ -349,7 +322,6 @@ class TritonPythonModel:
             token_offset, chunk_index = 0, 0
             start_time = time.time()
             this_token_hop_len = self.token_hop_len
-            print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
             async for generated_ids in self.forward_llm_async(
                 target_text=target_text,
                 reference_text=reference_text,
@@ -358,24 +330,20 @@ class TritonPythonModel:
                 if not generated_ids:
                     break
                 semantic_token_ids_arr.append(generated_ids)
-                print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
                 while True:
                     pending_num = len(semantic_token_ids_arr) - token_offset
                     if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                         this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                        print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
                         sub_tts_speech = await self.forward_token2wav(
                             chunk_index,
                             this_tts_speech_token, request_id, wav, wav_len, False
                         )
-                        print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
                         audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                         inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                         response_sender.send(inference_response)
 
                         token_offset += this_token_hop_len
-                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
 
                         if self.dynamic_chunk_strategy == "exponential":
                             this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
@@ -389,7 +357,6 @@ class TritonPythonModel:
                                 avg_chunk_processing_time = cost_time / (chunk_index + 1)
                                 if avg_chunk_processing_time > 0:
                                     multiples = (duration - cost_time) / avg_chunk_processing_time
-                                    self.logger.log_info(f"multiples: {multiples}")
                                     next_pending_num = len(semantic_token_ids_arr) - token_offset
                                     if multiples > 4:
                                         this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
@@ -409,9 +376,8 @@ class TritonPythonModel:
             response_sender.send(inference_response)
 
             response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-            self.logger.log_info("send tritonserver_response_complete_final to end")
         else:
-            raise NotImplementedError("Decoupled mode is not supported")
+            raise NotImplementedError("Offline TTS mode is not supported")
 
     async def execute(self, requests):
         """Execute inference on the batched requests.

+ 1 - 13
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py

@@ -106,13 +106,10 @@ class TritonPythonModel:
         # Process each request in batch
         for request in requests:
             target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
-            # shift the speech tokens according to the original vocab size
+            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens.squeeze().tolist()
 
-            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
-           
             finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
                 
             request_id = request.request_id()
@@ -124,23 +121,14 @@ class TritonPythonModel:
                 request, "reference_wav_len").as_numpy().item()
 
             wav_array = torch.from_numpy(wav_array)
-            # Prepare inputs
             wav = wav_array[:, :wav_len].squeeze(0)
 
             spk_id = get_spk_id_from_prompt_audio(wav)
-            # wav = wav.to(self.device)
-
-            # update cache before forward
-            # self.token2wav_model.streaming_flow_cache[request_id]
-            # self.token2wav_model.hift_cache_dict[request_id]
 
             audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
 
-            # get the cache after forward
             outputs = []
 
-            generated_wave = audio_hat.squeeze(0).cpu().numpy()
-
             wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
             outputs.append(wav_tensor)
             inference_response = pb_utils.InferenceResponse(output_tensors=outputs)

+ 0 - 10
runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py

@@ -320,7 +320,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def forward(
         self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
     ):
-        # assert all item in prompt_audios_sample_rate is 16000
         assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
         
 
@@ -335,7 +334,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def prepare_prompt_audio(
         self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
     ):
-        # assert all item in prompt_audios_sample_rate is 16000
         assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
         
 
@@ -385,7 +383,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
             cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
-            print(f"speaker_id {speaker_id} added to cache")
 
         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
@@ -394,12 +391,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             source = torch.zeros(1, 1, 0, device='cuda'),
             speech = torch.zeros(1, 0, device='cuda'),
             )
-        # else:
-        #     for k, v in self.streaming_flow_cache[request_id].items():
-        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
-        #     for k, v in self.hift_cache_dict[request_id].items():
-        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
-        #     breakpoint()
 
         current_request_cache = self.streaming_flow_cache[request_id]
 
@@ -477,7 +468,6 @@ def get_args():
 if __name__ == "__main__":
     args = get_args()
     model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
-    # mkdir output_dir if not exists
     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
     dataset_name = "yuekai/seed_tts_cosy2"

+ 0 - 7
runtime/triton_trtllm/streaming_inference.py

@@ -35,12 +35,6 @@ def get_args():
     return parser.parse_args()
 
 
-def fake_generated_id_iter(generated_speech_tokens_list):
-    for i in range(len(generated_speech_tokens_list)):
-        yield generated_speech_tokens_list[i]
-
-
-
 if __name__ == "__main__":
     args = get_args()
     
@@ -53,7 +47,6 @@ if __name__ == "__main__":
 
     token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
     
-    flow_pre_lookahead_len = 3
     CHUNK_SIZE = 25
     token_frame_rate = 25
     OVERLAP_SIZE = 0