- import json
- import os
- import logging
- import torch
- from torch.utils.dlpack import to_dlpack
- import triton_python_backend_utils as pb_utils
- from hyperpyyaml import load_hyperpyyaml
# Process-wide setup for this Triton worker: configure logging and pin
# PyTorch intra-op parallelism to a single thread (Triton manages concurrency).
torch.set_num_threads(1)
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
class TritonPythonModel:
    """Triton Python model for CosyVoice3 vocoder (CausalHiFTGenerator).

    Stateless: converts mel spectrogram to waveform.
    CausalHiFTGenerator manages its own internal cache.
    """

    def initialize(self, args):
        """Load the CausalHiFTGenerator vocoder and move it to this instance's GPU.

        Args:
            args: Triton-supplied dict. 'model_config' is a JSON string whose
                'parameters' section must contain 'model_dir' (path holding
                cosyvoice3.yaml, hift.pt and the CosyVoice-BlankEN assets).
        """
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        model_dir = model_params["model_dir"]

        # Use the GPU Triton assigned to this model instance so multi-GPU
        # instance groups work; fall back to the default CUDA device when the
        # key is absent (older backend versions).
        device_id = args.get("model_instance_device_id")
        self.device = torch.device(f"cuda:{device_id}" if device_id is not None else "cuda")

        # Load CausalHiFTGenerator from cosyvoice3.yaml
        with open(os.path.join(model_dir, 'cosyvoice3.yaml'), 'r') as f:
            configs = load_hyperpyyaml(f, overrides={
                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
            })
        self.hift = configs['hift']

        # Checkpoint keys are prefixed with 'generator.'; strip the prefix so
        # they line up with the module's parameter names (strict load below).
        hift_state_dict = {
            k.replace('generator.', ''): v
            for k, v in torch.load(
                os.path.join(model_dir, 'hift.pt'),
                map_location='cpu', weights_only=True
            ).items()
        }
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()
        logger.info("CausalHiFTGenerator initialized successfully")

    def execute(self, requests):
        """Vocode each request's mel spectrogram into a waveform.

        Args:
            requests: Triton InferenceRequest list. Each carries a "mel" tensor
                ([80, T] or [1, 80, T]) and a scalar "finalize" flag.

        Returns:
            One InferenceResponse per request with a "tts_speech" tensor of
            shape [1, S].
        """
        responses = []
        for request in requests:  # index was unused; iterate directly
            mel = pb_utils.get_input_tensor_by_name(request, "mel")
            mel = torch.utils.dlpack.from_dlpack(mel.to_dlpack()).to(self.device)
            if mel.dim() == 2:
                mel = mel.unsqueeze(0)  # [80, T] -> [1, 80, T]
            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
            with torch.no_grad():
                speech, _ = self.hift.inference(speech_feat=mel, finalize=finalize)
            # speech shape: [1, 1, S] or [1, S] depending on hift version.
            # reshape(1, -1) flattens to [1, S] in both cases; the previous
            # squeeze()/unsqueeze(0) collapsed a length-1 waveform (S == 1)
            # to a 0-dim tensor and emitted shape [1] instead of [1, 1].
            speech = speech.reshape(1, -1)
            speech_tensor = pb_utils.Tensor.from_dlpack(
                "tts_speech", to_dlpack(speech))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[speech_tensor])
            responses.append(inference_response)
        return responses