|
|
@@ -21,11 +21,10 @@ import torch
|
|
|
import torchaudio
|
|
|
import torchaudio.compliance.kaldi as kaldi
|
|
|
from tqdm import tqdm
|
|
|
+from itertools import repeat
|
|
|
|
|
|
|
|
|
-def extract_embedding(input_list):
|
|
|
- utt, wav_file, ort_session = input_list
|
|
|
-
|
|
|
+def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
|
|
|
audio, sample_rate = torchaudio.load(wav_file)
|
|
|
if sample_rate != 16000:
|
|
|
audio = torchaudio.transforms.Resample(
|
|
|
@@ -33,19 +32,7 @@ def extract_embedding(input_list):
|
|
|
)(audio)
|
|
|
feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
|
|
|
feat = feat - feat.mean(dim=0, keepdim=True)
|
|
|
- embedding = (
|
|
|
- ort_session.run(
|
|
|
- None,
|
|
|
- {
|
|
|
- ort_session.get_inputs()[0]
|
|
|
- .name: feat.unsqueeze(dim=0)
|
|
|
- .cpu()
|
|
|
- .numpy()
|
|
|
- },
|
|
|
- )[0]
|
|
|
- .flatten()
|
|
|
- .tolist()
|
|
|
- )
|
|
|
+ embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
|
|
return (utt, embedding)
|
|
|
|
|
|
|
|
|
@@ -72,16 +59,14 @@ def main(args):
|
|
|
args.onnx_path, sess_options=option, providers=providers
|
|
|
)
|
|
|
|
|
|
- inputs = [
|
|
|
- (utt, utt2wav[utt], ort_session)
|
|
|
- for utt in tqdm(utt2wav.keys(), desc="Load data")
|
|
|
- ]
|
|
|
+ all_utt = utt2wav.keys()
|
|
|
+
|
|
|
with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
|
|
|
results = list(
|
|
|
tqdm(
|
|
|
- executor.map(extract_embedding, inputs),
|
|
|
- total=len(inputs),
|
|
|
- desc="Process data: ",
|
|
|
+ executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
|
|
|
+ total=len(utt2wav),
|
|
|
+ desc="Process data: "
|
|
|
)
|
|
|
)
|
|
|
|