MiXaiLL76 1 year ago
parent
commit
1d05ae5fd3
1 changed files with 8 additions and 23 deletions
  1. 8 23
      tools/extract_embedding.py

+ 8 - 23
tools/extract_embedding.py

@@ -21,11 +21,10 @@ import torch
 import torchaudio
 import torchaudio.compliance.kaldi as kaldi
 from tqdm import tqdm
+from itertools import repeat
 
 
-def extract_embedding(input_list):
-    utt, wav_file, ort_session = input_list
-
+def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
     audio, sample_rate = torchaudio.load(wav_file)
     if sample_rate != 16000:
         audio = torchaudio.transforms.Resample(
@@ -33,19 +32,7 @@ def extract_embedding(input_list):
         )(audio)
     feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
     feat = feat - feat.mean(dim=0, keepdim=True)
-    embedding = (
-        ort_session.run(
-            None,
-            {
-                ort_session.get_inputs()[0]
-                .name: feat.unsqueeze(dim=0)
-                .cpu()
-                .numpy()
-            },
-        )[0]
-        .flatten()
-        .tolist()
-    )
+    embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
     return (utt, embedding)
 
 
@@ -72,16 +59,14 @@ def main(args):
         args.onnx_path, sess_options=option, providers=providers
     )
 
-    inputs = [
-        (utt, utt2wav[utt], ort_session)
-        for utt in tqdm(utt2wav.keys(), desc="Load data")
-    ]
+    all_utt = utt2wav.keys()
+
     with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
         results = list(
             tqdm(
-                executor.map(extract_embedding, inputs),
-                total=len(inputs),
-                desc="Process data: ",
+                executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
+                total=len(utt2wav),
+                desc="Process data: "
             )
         )