1
0
Просмотр исходного кода

Implementing concurrent.futures

MiXaiLL76 1 год назад
Родитель
Сommit
73271d46f9
1 измененных файлов с 58 добавлено и 83 удалено
  1. 58 83
      tools/extract_embedding.py

+ 58 - 83
tools/extract_embedding.py

@@ -13,71 +13,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import os
+from concurrent.futures import ThreadPoolExecutor
+
+import onnxruntime
 import torch
 import torchaudio
-from tqdm import tqdm
-import onnxruntime
 import torchaudio.compliance.kaldi as kaldi
-from queue import Queue, Empty
-from threading import Thread
-
-
-class ExtractEmbedding:
-    def __init__(self, model_path: str, queue: Queue, out_queue: Queue):
-        self.model_path = model_path
-        self.queue = queue
-        self.out_queue = out_queue
-        self.is_run = True
-
-    def run(self):
-        self.consumer_thread = Thread(target=self.consumer)
-        self.consumer_thread.start()
-
-    def stop(self):
-        self.is_run = False
-        self.consumer_thread.join()
-
-    def consumer(self):
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = (
-            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        )
-        option.intra_op_num_threads = 1
-        providers = ["CPUExecutionProvider"]
-        ort_session = onnxruntime.InferenceSession(
-            self.model_path, sess_options=option, providers=providers
-        )
+from tqdm import tqdm
 
-        while self.is_run:
-            try:
-                utt, wav_file = self.queue.get(timeout=1)
 
-                audio, sample_rate = torchaudio.load(wav_file)
-                if sample_rate != 16000:
-                    audio = torchaudio.transforms.Resample(
-                        orig_freq=sample_rate, new_freq=16000
-                    )(audio)
-                feat = kaldi.fbank(
-                    audio, num_mel_bins=80, dither=0, sample_frequency=16000
-                )
-                feat = feat - feat.mean(dim=0, keepdim=True)
-                embedding = (
-                    ort_session.run(
-                        None,
-                        {
-                            ort_session.get_inputs()[0]
-                            .name: feat.unsqueeze(dim=0)
-                            .cpu()
-                            .numpy()
-                        },
-                    )[0]
-                    .flatten()
-                    .tolist()
-                )
-                self.out_queue.put((utt, embedding))
-            except Empty:
-                self.is_run = False
-                break
+def extract_embedding(input_list):
+    utt, wav_file, ort_session = input_list
+
+    audio, sample_rate = torchaudio.load(wav_file)
+    if sample_rate != 16000:
+        audio = torchaudio.transforms.Resample(
+            orig_freq=sample_rate, new_freq=16000
+        )(audio)
+    feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
+    feat = feat - feat.mean(dim=0, keepdim=True)
+    embedding = (
+        ort_session.run(
+            None,
+            {
+                ort_session.get_inputs()[0]
+                .name: feat.unsqueeze(dim=0)
+                .cpu()
+                .numpy()
+            },
+        )[0]
+        .flatten()
+        .tolist()
+    )
+    return (utt, embedding)
 
 
 def main(args):
@@ -91,32 +60,38 @@ def main(args):
             l = l.replace("\n", "").split()
             utt2spk[l[0]] = l[1]
 
-    input_queue = Queue()
-    output_queue = Queue()
-    consumers = [
-        ExtractEmbedding(args.onnx_path, input_queue, output_queue)
-        for _ in range(args.num_thread)
+    assert os.path.exists(args.onnx_path), "onnx_path not exists"
+
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = (
+        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    )
+    option.intra_op_num_threads = 1
+    providers = ["CPUExecutionProvider"]
+    ort_session = onnxruntime.InferenceSession(
+        args.onnx_path, sess_options=option, providers=providers
+    )
+
+    inputs = [
+        (utt, utt2wav[utt], ort_session)
+        for utt in tqdm(utt2wav.keys(), desc="Load data")
     ]
+    with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
+        results = list(
+            tqdm(
+                executor.map(extract_embedding, inputs),
+                total=len(inputs),
+                desc="Process data: ",
+            )
+        )
 
     utt2embedding, spk2embedding = {}, {}
-    for utt in tqdm(utt2wav.keys(), desc="Load data"):
-        input_queue.put((utt, utt2wav[utt]))
-
-    for c in consumers:
-        c.run()
-
-    with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar:
-        while any([c.is_run for c in consumers]):
-            try:
-                utt, embedding = output_queue.get(timeout=1)
-                utt2embedding[utt] = embedding
-                spk = utt2spk[utt]
-                if spk not in spk2embedding:
-                    spk2embedding[spk] = []
-                spk2embedding[spk].append(embedding)
-                pbar.update(1)
-            except Empty:
-                continue
+    for utt, embedding in results:
+        utt2embedding[utt] = embedding
+        spk = utt2spk[utt]
+        if spk not in spk2embedding:
+            spk2embedding[spk] = []
+        spk2embedding[spk].append(embedding)
 
     for k, v in spk2embedding.items():
         spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()