# onnx.py — ONNX-based speech-token / speaker-embedding extractors.
  1. import onnxruntime
  2. import torch, random
  3. from torch import nn
  4. import os
  5. import whisper
  6. import numpy as np
  7. import torchaudio.compliance.kaldi as kaldi
  8. import torch.nn.functional as F
  9. class SpeechTokenExtractor():
  10. def __init__(self, model_path):
  11. self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
  12. option = onnxruntime.SessionOptions()
  13. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  14. option.intra_op_num_threads = 1
  15. self.speech_tokenizer_session = onnxruntime.InferenceSession(model_path,
  16. sess_options=option,
  17. providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
  18. def inference(self, feat, feat_lengths):
  19. speech_token = self.speech_tokenizer_session.run(None,
  20. {self.speech_tokenizer_session.get_inputs()[0].name:
  21. feat.transpose(1, 2).detach().cpu().numpy(),
  22. self.speech_tokenizer_session.get_inputs()[1].name:
  23. feat_lengths.detach().cpu().numpy()})[0]
  24. return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
  25. class EmbeddingExtractor():
  26. def __init__(self, model_path):
  27. option = onnxruntime.SessionOptions()
  28. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  29. option.intra_op_num_threads = 1
  30. self.max_len = 10 * 16000
  31. self.campplus_session = onnxruntime.InferenceSession(model_path,
  32. sess_options=option,
  33. providers=["CPUExecutionProvider"])
  34. def inference(self, speech):
  35. if speech.shape[1] > self.max_len:
  36. start_index = random.randint(0, speech.shape[1] - self.max_len)
  37. speech = speech[:, start_index: start_index + self.max_len]
  38. feat = kaldi.fbank(speech,
  39. num_mel_bins=80,
  40. dither=0,
  41. sample_frequency=16000)
  42. feat = feat - feat.mean(dim=0, keepdim=True)
  43. embedding = self.campplus_session.run(None,
  44. {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
  45. return torch.tensor(embedding).to(speech.device)
  46. # singleton mode, only initialized once
  47. onnx_path = os.environ.get('onnx_path')
  48. if onnx_path is not None:
  49. embedding_extractor, online_feature = EmbeddingExtractor(model_path=os.path.join(onnx_path, 'campplus.onnx')), True
  50. else:
  51. embedding_extractor, online_feature = None, False