model.py 3.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of NVIDIA CORPORATION nor the names of its
  12. # contributors may be used to endorse or promote products derived
  13. # from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
  16. # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  17. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
  18. # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
  19. # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
  20. # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
  21. # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
  22. # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
  23. # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  24. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  25. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  26. import json
  27. import torch
  28. from torch.utils.dlpack import to_dlpack
  29. import triton_python_backend_utils as pb_utils
  30. import os
  31. import numpy as np
  32. import s3tokenizer
  33. torch.set_num_threads(1)
  34. ORIGINAL_VOCAB_SIZE = 151663
  35. class TritonPythonModel:
  36. """Triton Python model for audio tokenization.
  37. This model takes reference audio input and extracts semantic tokens
  38. using s3tokenizer.
  39. """
  40. def initialize(self, args):
  41. """Initialize the model.
  42. Args:
  43. args: Dictionary containing model configuration
  44. """
  45. # Parse model parameters
  46. parameters = json.loads(args['model_config'])['parameters']
  47. model_params = {k: v["string_value"] for k, v in parameters.items()}
  48. self.device = torch.device("cuda")
  49. model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
  50. self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
  51. def execute(self, requests):
  52. """Execute inference on the batched requests.
  53. Args:
  54. requests: List of inference requests
  55. Returns:
  56. List of inference responses containing tokenized outputs
  57. """
  58. mels = []
  59. # Process each request in batch
  60. for request in requests:
  61. # Extract input tensors
  62. wav_array = pb_utils.get_input_tensor_by_name(
  63. request, "reference_wav").as_numpy()
  64. wav_len = pb_utils.get_input_tensor_by_name(
  65. request, "reference_wav_len").as_numpy().item()
  66. wav_array = torch.from_numpy(wav_array).to(self.device)
  67. # Prepare inputs
  68. wav = wav_array[:, :wav_len].squeeze(0)
  69. mels.append(s3tokenizer.log_mel_spectrogram(wav))
  70. mels, mels_lens = s3tokenizer.padding(mels)
  71. codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
  72. codes = codes.clone() + ORIGINAL_VOCAB_SIZE
  73. responses = []
  74. for i in range(len(requests)):
  75. prompt_speech_tokens = codes[i, :codes_lens[i].item()]
  76. prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
  77. "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
  78. inference_response = pb_utils.InferenceResponse(
  79. output_tensors=[prompt_speech_tokens_tensor])
  80. responses.append(inference_response)
  81. return responses