Răsfoiți Sursa

add threading

MiXaiLL76 1 an în urmă
părinte
comite
7b3e285bca
1 a modificat fișierele cu 94 adăugiri și 30 ștergeri
  1. 94 30
      tools/extract_embedding.py

+ 94 - 30
tools/extract_embedding.py

@@ -18,53 +18,117 @@ import torchaudio
 from tqdm import tqdm
 import onnxruntime
 import torchaudio.compliance.kaldi as kaldi
+from queue import Queue, Empty
+from threading import Thread
+
+
+class ExtractEmbedding:
+    def __init__(self, model_path: str, queue: Queue, out_queue: Queue):
+        self.model_path = model_path
+        self.queue = queue
+        self.out_queue = out_queue
+        self.is_run = True
+
+    def run(self):
+        self.consumer_thread = Thread(target=self.consumer)
+        self.consumer_thread.start()
+
+    def stop(self):
+        self.is_run = False
+        self.consumer_thread.join()
+
+    def consumer(self):
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = (
+            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        )
+        option.intra_op_num_threads = 1
+        providers = ["CPUExecutionProvider"]
+        ort_session = onnxruntime.InferenceSession(
+            self.model_path, sess_options=option, providers=providers
+        )
+
+        while self.is_run:
+            try:
+                utt, wav_file = self.queue.get(timeout=1)
+
+                audio, sample_rate = torchaudio.load(wav_file)
+                if sample_rate != 16000:
+                    audio = torchaudio.transforms.Resample(
+                        orig_freq=sample_rate, new_freq=16000
+                    )(audio)
+                feat = kaldi.fbank(
+                    audio, num_mel_bins=80, dither=0, sample_frequency=16000
+                )
+                feat = feat - feat.mean(dim=0, keepdim=True)
+                embedding = (
+                    ort_session.run(
+                        None,
+                        {
+                            ort_session.get_inputs()[0]
+                            .name: feat.unsqueeze(dim=0)
+                            .cpu()
+                            .numpy()
+                        },
+                    )[0]
+                    .flatten()
+                    .tolist()
+                )
+                self.out_queue.put((utt, embedding))
+            except Empty:
+                self.is_run = False
+                break
 
 
 def main(args):
     utt2wav, utt2spk = {}, {}
-    with open('{}/wav.scp'.format(args.dir)) as f:
+    with open("{}/wav.scp".format(args.dir)) as f:
         for l in f:
-            l = l.replace('\n', '').split()
+            l = l.replace("\n", "").split()
             utt2wav[l[0]] = l[1]
-    with open('{}/utt2spk'.format(args.dir)) as f:
+    with open("{}/utt2spk".format(args.dir)) as f:
         for l in f:
-            l = l.replace('\n', '').split()
+            l = l.replace("\n", "").split()
             utt2spk[l[0]] = l[1]
 
-    option = onnxruntime.SessionOptions()
-    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-    option.intra_op_num_threads = 1
-    providers = ["CPUExecutionProvider"]
-    ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+    input_queue = Queue()
+    output_queue = Queue()
+    consumers = [
+        ExtractEmbedding(args.onnx_path, input_queue, output_queue)
+        for _ in range(args.num_thread)
+    ]
 
     utt2embedding, spk2embedding = {}, {}
-    for utt in tqdm(utt2wav.keys()):
-        audio, sample_rate = torchaudio.load(utt2wav[utt])
-        if sample_rate != 16000:
-            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
-        feat = kaldi.fbank(audio,
-                           num_mel_bins=80,
-                           dither=0,
-                           sample_frequency=16000)
-        feat = feat - feat.mean(dim=0, keepdim=True)
-        embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-        utt2embedding[utt] = embedding
-        spk = utt2spk[utt]
-        if spk not in spk2embedding:
-            spk2embedding[spk] = []
-        spk2embedding[spk].append(embedding)
+    for utt in tqdm(utt2wav.keys(), desc="Load data"):
+        input_queue.put((utt, utt2wav[utt]))
+
+    for c in consumers:
+        c.run()
+
+    with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar:
+        while any([c.is_run for c in consumers]):
+            try:
+                utt, embedding = output_queue.get(timeout=1)
+                utt2embedding[utt] = embedding
+                spk = utt2spk[utt]
+                if spk not in spk2embedding:
+                    spk2embedding[spk] = []
+                spk2embedding[spk].append(embedding)
+                pbar.update(1)
+            except Empty:
+                continue
+
     for k, v in spk2embedding.items():
         spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
 
-    torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
-    torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
+    torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
+    torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--dir',
-                        type=str)
-    parser.add_argument('--onnx_path',
-                        type=str)
+    parser.add_argument("--dir", type=str)
+    parser.add_argument("--onnx_path", type=str)
+    parser.add_argument("--num_thread", type=int, default=8)
     args = parser.parse_args()
     main(args)