
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import hashlib
import json
import logging

import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils

from .token2wav_dit import CosyVoice2_Token2Wav

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# The LLM emits speech tokens offset by the size of its original text vocabulary;
# subtracting this offset recovers the raw codec token ids.
ORIGINAL_VOCAB_SIZE = 151663

# Limit intra-op CPU threads to avoid oversubscription under Triton.
torch.set_num_threads(1)

def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
    """
    Generates a unique ID for a torch.Tensor.
    Tensors with the same elements and properties will have the same ID.
    """
    # Convert tensor to a byte string
    tensor_bytes = tensor.numpy().tobytes()
    # Create a SHA-256 hash of the byte string
    hasher = hashlib.sha256()
    hasher.update(tensor_bytes)
    return hasher.hexdigest()
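
# Hypothetical usage sketch: identical prompt audio hashes to the same id, which
# execute() below passes as speaker_id to forward_streaming():
#   wav = torch.zeros(16000)
#   assert get_spk_id_from_prompt_audio(wav) == get_spk_id_from_prompt_audio(wav.clone())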


class TritonPythonModel:
    """Triton Python model for the CosyVoice2 token2wav stage.

    This model takes speech tokens and reference prompt audio as input and
    generates audio waveforms using the CosyVoice2_Token2Wav model.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters from the Triton model config
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        # Initialize device and token2wav model
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing token2wav model from {model_dir} on {self.device}")
        # FIXME: device id settings
        self.token2wav_model = CosyVoice2_Token2Wav(
            model_dir, enable_trt=True, streaming=True
        )
        logger.info("Token2Wav initialized successfully")
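
    # A minimal sketch of the parameters block initialize() expects in this
    # model's config.pbtxt (the path is illustrative, not taken from this repo):
    #
    #   parameters {
    #     key: "model_dir"
    #     value: { string_value: "/models/cosyvoice2" }
    #   }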

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests
        Returns:
            List of inference responses containing generated waveforms
        """
        responses = []
        # Process each request in batch
        for request in requests:
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)  # .to(self.device)

            # Shift the speech tokens back by the original vocab size to recover codec token ids
            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
            target_speech_tokens = target_speech_tokens.squeeze().tolist()

            # `finalize` supports both streaming and offline TTS: it marks the last
            # chunk of a streaming request; offline TTS sends a single finalized request.
            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
            request_id = request.request_id()

            wav_array = pb_utils.get_input_tensor_by_name(
                request, "reference_wav").as_numpy()
            wav_len = pb_utils.get_input_tensor_by_name(
                request, "reference_wav_len").as_numpy().item()
            wav_array = torch.from_numpy(wav_array)

            # Prepare inputs: trim padding and drop the batch dimension
            wav = wav_array[:, :wav_len].squeeze(0)
            spk_id = get_spk_id_from_prompt_audio(wav)
            # wav = wav.to(self.device)

            # Per-request streaming caches are updated inside the model, e.g.
            # self.token2wav_model.streaming_flow_cache[request_id]
            # self.token2wav_model.hift_cache_dict[request_id]
            audio_hat = self.token2wav_model.forward_streaming(
                target_speech_tokens, finalize,
                request_id=request_id,
                speaker_id=f"{spk_id}",
                prompt_audio=wav,
                prompt_audio_sample_rate=16000,
            )

            outputs = []
            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            outputs.append(wav_tensor)
            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
            responses.append(inference_response)
        return responses
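

# A minimal client-side sketch of how this model could be called with
# tritonclient.grpc. The model name, dtypes, and shapes below are assumptions,
# not verified against this repo's config.pbtxt:
#
#   import numpy as np
#   import tritonclient.grpc as grpcclient
#
#   client = grpcclient.InferenceServerClient("localhost:8001")
#   tokens = np.array([[151664, 151665]], dtype=np.int64)  # ids offset by ORIGINAL_VOCAB_SIZE
#   wav = np.zeros((1, 16000), dtype=np.float32)           # 16 kHz reference audio
#   inputs = [
#       grpcclient.InferInput("target_speech_tokens", tokens.shape, "INT64"),
#       grpcclient.InferInput("reference_wav", wav.shape, "FP32"),
#       grpcclient.InferInput("reference_wav_len", [1, 1], "INT32"),
#       grpcclient.InferInput("finalize", [1, 1], "BOOL"),
#   ]
#   inputs[0].set_data_from_numpy(tokens)
#   inputs[1].set_data_from_numpy(wav)
#   inputs[2].set_data_from_numpy(np.array([[16000]], dtype=np.int32))
#   inputs[3].set_data_from_numpy(np.array([[True]], dtype=bool))
#   result = client.infer("token2wav", inputs, request_id="req-0")
#   waveform = result.as_numpy("waveform")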