# onnx.py — ONNX-runtime helpers for speech tokenization and speaker embedding extraction.
import os
import random

import onnxruntime
import torch
import torchaudio.compliance.kaldi as kaldi
  5. class SpeechTokenExtractor():
  6. def __init__(self, model_path):
  7. self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
  8. option = onnxruntime.SessionOptions()
  9. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  10. option.intra_op_num_threads = 1
  11. self.speech_tokenizer_session = onnxruntime.InferenceSession(model_path,
  12. sess_options=option,
  13. providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
  14. def inference(self, feat, feat_lengths, device):
  15. speech_token = self.speech_tokenizer_session.run(None,
  16. {self.speech_tokenizer_session.get_inputs()[0].name:
  17. feat.transpose(1, 2).detach().cpu().numpy(),
  18. self.speech_tokenizer_session.get_inputs()[1].name:
  19. feat_lengths.detach().cpu().numpy()})[0]
  20. return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
  21. class EmbeddingExtractor():
  22. def __init__(self, model_path):
  23. option = onnxruntime.SessionOptions()
  24. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  25. option.intra_op_num_threads = 1
  26. self.max_len = 10 * 16000
  27. self.campplus_session = onnxruntime.InferenceSession(model_path,
  28. sess_options=option,
  29. providers=["CPUExecutionProvider"])
  30. def inference(self, speech):
  31. if speech.shape[1] > self.max_len:
  32. start_index = random.randint(0, speech.shape[1] - self.max_len)
  33. speech = speech[:, start_index: start_index + self.max_len]
  34. feat = kaldi.fbank(speech,
  35. num_mel_bins=80,
  36. dither=0,
  37. sample_frequency=16000)
  38. feat = feat - feat.mean(dim=0, keepdim=True)
  39. embedding = self.campplus_session.run(None,
  40. {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
  41. return torch.tensor(embedding).to(speech.device)
  42. # singleton mode, only initialized once
  43. onnx_path = os.environ.get('onnx_path')
  44. if onnx_path is not None:
  45. embedding_extractor, online_feature = EmbeddingExtractor(model_path=os.path.join(onnx_path, 'campplus.onnx')), True
  46. else:
  47. embedding_extractor, online_feature = None, False