# model.py
  1. import json
  2. import os
  3. import logging
  4. import torch
  5. from torch.utils.dlpack import to_dlpack
  6. import triton_python_backend_utils as pb_utils
  7. from hyperpyyaml import load_hyperpyyaml
  8. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  9. logger = logging.getLogger(__name__)
  10. torch.set_num_threads(1)
  11. class TritonPythonModel:
  12. """Triton Python model for CosyVoice3 vocoder (CausalHiFTGenerator).
  13. Stateless: converts mel spectrogram to waveform.
  14. CausalHiFTGenerator manages its own internal cache.
  15. """
  16. def initialize(self, args):
  17. parameters = json.loads(args['model_config'])['parameters']
  18. model_params = {k: v["string_value"] for k, v in parameters.items()}
  19. model_dir = model_params["model_dir"]
  20. self.device = torch.device("cuda")
  21. # Load CausalHiFTGenerator from cosyvoice3.yaml
  22. with open(os.path.join(model_dir, 'cosyvoice3.yaml'), 'r') as f:
  23. configs = load_hyperpyyaml(f, overrides={
  24. 'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
  25. })
  26. self.hift = configs['hift']
  27. hift_state_dict = {
  28. k.replace('generator.', ''): v
  29. for k, v in torch.load(
  30. os.path.join(model_dir, 'hift.pt'),
  31. map_location='cpu', weights_only=True
  32. ).items()
  33. }
  34. self.hift.load_state_dict(hift_state_dict, strict=True)
  35. self.hift.to(self.device).eval()
  36. logger.info("CausalHiFTGenerator initialized successfully")
  37. def execute(self, requests):
  38. responses = []
  39. for req_idx, request in enumerate(requests):
  40. mel = pb_utils.get_input_tensor_by_name(request, "mel")
  41. mel = torch.utils.dlpack.from_dlpack(mel.to_dlpack()).to(self.device)
  42. if mel.dim() == 2:
  43. mel = mel.unsqueeze(0) # [80, T] -> [1, 80, T]
  44. finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
  45. with torch.no_grad():
  46. speech, _ = self.hift.inference(speech_feat=mel, finalize=finalize)
  47. # speech shape: [1, 1, S] or [1, S] depending on hift version
  48. speech = speech.squeeze() # flatten to [S]
  49. speech_tensor = pb_utils.Tensor.from_dlpack(
  50. "tts_speech", to_dlpack(speech.unsqueeze(0))) # [1, S] for batch dim
  51. inference_response = pb_utils.InferenceResponse(
  52. output_tensors=[speech_tensor])
  53. responses.append(inference_response)
  54. return responses