# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import hashlib
import json
import logging

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import to_dlpack

from .token2wav_dit import CosyVoice2_Token2Wav

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

ORIGINAL_VOCAB_SIZE = 151663

# Keep PyTorch to a single intra-op CPU thread; Triton may host several model
# instances per process, and per-instance thread pools oversubscribe the CPU.
torch.set_num_threads(1)

def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
    """Generate a deterministic ID for a torch.Tensor.

    Tensors with the same elements and properties map to the same ID.
    """
    # Serialize the tensor to bytes and hash with SHA-256.
    tensor_bytes = tensor.numpy().tobytes()
    hasher = hashlib.sha256()
    hasher.update(tensor_bytes)
    return hasher.hexdigest()
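
# Quick sanity check of the hashing behavior (illustrative sketch): prompts with
# identical samples map to the same speaker ID, so requests that reuse the same
# reference clip can share any per-speaker state keyed by this ID.
#
#   wav = torch.zeros(16000)
#   assert get_spk_id_from_prompt_audio(wav) == get_spk_id_from_prompt_audio(wav.clone())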

class TritonPythonModel:
    """Triton Python model for the token-to-waveform stage.

    This model takes target speech tokens and a reference audio prompt as input
    and generates audio waveforms using the CosyVoice2 Token2Wav (DiT) model.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters from the Triton model config.
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        # Initialize device and the token-to-waveform model.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
        # FIXME: device id settings
        self.token2wav_model = CosyVoice2_Token2Wav(
            model_dir, enable_trt=True, streaming=True
        )
        logger.info("Token2Wav initialized successfully")
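
    # For reference, initialize() expects the checkpoint location to arrive as a
    # string parameter in this model's config.pbtxt, roughly like the following
    # (the path is illustrative):
    #
    #   parameters [
    #     { key: "model_dir", value: { string_value: "/models/CosyVoice2-0.5B" } }
    #   ]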

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing generated waveforms
        """
        responses = []
        for request in requests:
            request_id = request.request_id()

            # Get inputs
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens")
            target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack())
            target_speech_tokens = target_speech_tokens.squeeze().tolist()
            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
            wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy()
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item()
            wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0)
            spk_id = get_spk_id_from_prompt_audio(wav)

            # Restore the streaming caches. The caches are round-tripped through
            # the client: each response returns the updated caches and the client
            # sends them back with the next chunk. The first chunk of a stream
            # carries no cache inputs, so restoration is skipped when the optional
            # conformer_cnn_cache input is absent. The transpose/squeeze calls
            # undo the reshaping applied on the output side below.
            conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache")
            if conformer_cnn_cache is not None:
                self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack())

                conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache")
                self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0, 1)

                estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache")
                self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0)

                estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache")
                self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0)

                mel_np = pb_utils.get_input_tensor_by_name(request, "mel")
                self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack())

                source_np = pb_utils.get_input_tensor_by_name(request, "source")
                self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack())

                speech_np = pb_utils.get_input_tensor_by_name(request, "speech")
                self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack())
            # Forward pass
            audio_hat = self.token2wav_model.forward_streaming(
                target_speech_tokens,
                finalize,
                request_id=request_id,
                speaker_id=spk_id,
                prompt_audio=wav,
                prompt_audio_sample_rate=16000,
            )

            # Prepare outputs
            outputs = []
            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            outputs.append(wav_tensor)

            # Echo the updated caches back so the client can send them with the
            # next chunk of this stream.
            if request_id in self.token2wav_model.streaming_flow_cache:
                cache = self.token2wav_model.streaming_flow_cache[request_id]
                hifigan_cache = self.token2wav_model.hift_cache_dict[request_id]
                conformer_cnn_cache = cache['conformer_cnn_cache']
                conformer_att_cache = cache['conformer_att_cache'].transpose(0, 1)
                estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0)
                estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0)
                mel = hifigan_cache['mel']
                source = hifigan_cache['source']
                speech = hifigan_cache['speech']
                outputs.extend([
                    pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)),
                    pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)),
                    pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)),
                    pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)),
                    pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)),
                    pb_utils.Tensor.from_dlpack("source", to_dlpack(source)),
                    pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)),
                ])
            else:
                # No cache left for this request (e.g., the stream has finished
                # and the per-request caches were cleared); return empty
                # placeholders instead.
                outputs.extend([
                    pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)),
                    pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)),
                    pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)),
                    pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)),
                    pb_utils.Tensor("mel", np.array([], dtype=np.float32)),
                    pb_utils.Tensor("source", np.array([], dtype=np.float32)),
                    pb_utils.Tensor("speech", np.array([], dtype=np.float32)),
                ])

            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
            responses.append(inference_response)

        return responses

    def finalize(self):
        # Use the module-level logger; self.logger is never set on this model.
        logger.info("Finalizing Token2WavDiT model")
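
# Illustrative client-side sketch of how the first chunk of a stream might be
# sent to this model over gRPC (this would live in a separate client script).
# The model name "token2wav_dit" and the INT32/BOOL dtypes are assumptions;
# check this model's config.pbtxt for the real names and types. Later chunks
# would additionally feed back the cache tensors returned in each response.
#
#   import numpy as np
#   import tritonclient.grpc as grpcclient
#
#   client = grpcclient.InferenceServerClient(url="localhost:8001")
#   tokens = np.array([[101, 102, 103]], dtype=np.int32)   # target speech tokens (hypothetical IDs)
#   wav = np.zeros((1, 16000), dtype=np.float32)           # 1 s reference audio @ 16 kHz
#   inputs = [
#       grpcclient.InferInput("target_speech_tokens", tokens.shape, "INT32"),
#       grpcclient.InferInput("finalize", [1, 1], "BOOL"),
#       grpcclient.InferInput("reference_wav", wav.shape, "FP32"),
#       grpcclient.InferInput("reference_wav_len", [1, 1], "INT32"),
#   ]
#   inputs[0].set_data_from_numpy(tokens)
#   inputs[1].set_data_from_numpy(np.array([[False]]))
#   inputs[2].set_data_from_numpy(wav)
#   inputs[3].set_data_from_numpy(np.array([[16000]], dtype=np.int32))
#   result = client.infer("token2wav_dit", inputs, request_id="stream-0")
#   waveform = result.as_numpy("waveform")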