
use thread pool in tools

lyuxiang.lx 1 year ago
parent commit ff8e63567a
2 changed files with 61 additions and 70 deletions
  1. + 30 - 47 tools/extract_embedding.py
  2. + 31 - 23 tools/extract_speech_token.py

+ 30 - 47
tools/extract_embedding.py

@@ -13,74 +13,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-import os
-from concurrent.futures import ThreadPoolExecutor
-
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import onnxruntime
 import torch
 import torchaudio
 import torchaudio.compliance.kaldi as kaldi
 from tqdm import tqdm
-from itertools import repeat
 
 
-def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
-    audio, sample_rate = torchaudio.load(wav_file)
+def single_job(utt):
+    audio, sample_rate = torchaudio.load(utt2wav[utt])
     if sample_rate != 16000:
-        audio = torchaudio.transforms.Resample(
-            orig_freq=sample_rate, new_freq=16000
-        )(audio)
-    feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+    feat = kaldi.fbank(audio,
+                       num_mel_bins=80,
+                       dither=0,
+                       sample_frequency=16000)
     feat = feat - feat.mean(dim=0, keepdim=True)
     embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-    return (utt, embedding)
+    return utt, embedding
 
 
 def main(args):
-    utt2wav, utt2spk = {}, {}
-    with open("{}/wav.scp".format(args.dir)) as f:
-        for l in f:
-            l = l.replace("\n", "").split()
-            utt2wav[l[0]] = l[1]
-    with open("{}/utt2spk".format(args.dir)) as f:
-        for l in f:
-            l = l.replace("\n", "").split()
-            utt2spk[l[0]] = l[1]
-
-    assert os.path.exists(args.onnx_path), "onnx_path not exists"
-
-    option = onnxruntime.SessionOptions()
-    option.graph_optimization_level = (
-        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-    )
-    option.intra_op_num_threads = 1
-    providers = ["CPUExecutionProvider"]
-    ort_session = onnxruntime.InferenceSession(
-        args.onnx_path, sess_options=option, providers=providers
-    )
-
-    all_utt = utt2wav.keys()
-
-    with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
-        results = list(
-            tqdm(
-                executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
-                total=len(utt2wav),
-                desc="Process data: "
-            )
-        )
-
+    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
     utt2embedding, spk2embedding = {}, {}
-    for utt, embedding in results:
+    for future in tqdm(as_completed(all_task), total=len(all_task)):
+        utt, embedding = future.result()
         utt2embedding[utt] = embedding
         spk = utt2spk[utt]
         if spk not in spk2embedding:
             spk2embedding[spk] = []
         spk2embedding[spk].append(embedding)
-
     for k, v in spk2embedding.items():
         spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
-
     torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
     torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
 
@@ -91,4 +56,22 @@ if __name__ == "__main__":
     parser.add_argument("--onnx_path", type=str)
     parser.add_argument("--num_thread", type=int, default=8)
     args = parser.parse_args()
+
+    utt2wav, utt2spk = {}, {}
+    with open('{}/wav.scp'.format(args.dir)) as f:
+        for l in f:
+            l = l.replace('\n', '').split()
+            utt2wav[l[0]] = l[1]
+    with open('{}/utt2spk'.format(args.dir)) as f:
+        for l in f:
+            l = l.replace('\n', '').split()
+            utt2spk[l[0]] = l[1]
+
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    option.intra_op_num_threads = 1
+    providers = ["CPUExecutionProvider"]
+    ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+    executor = ThreadPoolExecutor(max_workers=args.num_thread)
+
     main(args)
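
For readers skimming the diff: the core change in extract_embedding.py swaps executor.map (a worker function receiving the session as an argument) for submit plus as_completed, with utt2wav, utt2spk, ort_session, and executor now built as module-level globals under __main__ so single_job can see them from worker threads. Below is a minimal sketch of the same pattern with the shared state passed explicitly instead of via globals; extract_all is an illustrative name, not part of the repo. Sharing one onnxruntime.InferenceSession across threads is sound because its run() call is thread-safe.

from concurrent.futures import ThreadPoolExecutor, as_completed

import onnxruntime
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from tqdm import tqdm


def single_job(utt, wav_path, ort_session):
    # Same steps as the committed single_job, minus the module globals.
    audio, sample_rate = torchaudio.load(wav_path)
    if sample_rate != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
    feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
    feat = feat - feat.mean(dim=0, keepdim=True)
    embedding = ort_session.run(
        None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
    )[0].flatten().tolist()
    return utt, embedding


def extract_all(utt2wav, ort_session, num_thread=8):
    # One session serves every worker thread; onnxruntime's run() is thread-safe.
    utt2embedding = {}
    with ThreadPoolExecutor(max_workers=num_thread) as executor:
        tasks = [executor.submit(single_job, utt, utt2wav[utt], ort_session) for utt in utt2wav]
        # total= keeps tqdm accurate, since as_completed yields in completion order.
        for future in tqdm(as_completed(tasks), total=len(tasks)):
            utt, embedding = future.result()  # re-raises any worker exception here
            utt2embedding[utt] = embedding
    return utt2embedding

Compared with executor.map, as_completed lets the progress bar advance as soon as any job finishes (map yields in submission order, so one slow early item stalls the bar) and makes explicit that future.result() is where worker exceptions resurface.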

+ 31 - 23
tools/extract_speech_token.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import logging
 import torch
 from tqdm import tqdm
@@ -22,7 +23,36 @@ import torchaudio
 import whisper
 
 
+def single_job(utt):
+    audio, sample_rate = torchaudio.load(utt2wav[utt])
+    if sample_rate != 16000:
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+    if audio.shape[1] / 16000 > 30:
+        logging.warning('speech token extraction is not supported for audio longer than 30s')
+        speech_token = []
+    else:
+        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
+        speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+                                              ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+    return utt, speech_token
+
+
 def main(args):
+    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
+    utt2speech_token = {}
+    for future in tqdm(as_completed(all_task), total=len(all_task)):
+        utt, speech_token = future.result()
+        utt2speech_token[utt] = speech_token
+    torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dir", type=str)
+    parser.add_argument("--onnx_path", type=str)
+    parser.add_argument("--num_thread", type=int, default=8)
+    args = parser.parse_args()
+
     utt2wav = {}
     with open('{}/wav.scp'.format(args.dir)) as f:
         for l in f:
@@ -34,28 +64,6 @@ def main(args):
     option.intra_op_num_threads = 1
     providers = ["CUDAExecutionProvider"]
     ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+    executor = ThreadPoolExecutor(max_workers=args.num_thread)
 
-    utt2speech_token = {}
-    for utt in tqdm(utt2wav.keys()):
-        audio, sample_rate = torchaudio.load(utt2wav[utt])
-        if sample_rate != 16000:
-            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
-        if audio.shape[1] / 16000 > 30:
-            logging.warning('do not support extract speech token for audio longer than 30s')
-            speech_token = []
-        else:
-            feat = whisper.log_mel_spectrogram(audio, n_mels=128)
-            speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
-                                                  ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
-        utt2speech_token[utt] = speech_token
-    torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--dir',
-                        type=str)
-    parser.add_argument('--onnx_path',
-                        type=str)
-    args = parser.parse_args()
     main(args)
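
extract_speech_token.py gets the same treatment, with the dataset map, ONNX session, and executor moved into the __main__ block so single_job can read them as globals from worker threads. For anyone who finds the globals fragile, here is an equivalent sketch (illustrative names, not from the repo) that binds the shared state with functools.partial instead; threading pays off in these tools because the heavy work happens inside torchaudio and ONNX Runtime, which largely release the GIL during their native calls.

import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial

import numpy as np
import torchaudio
import whisper
from tqdm import tqdm


def single_job(utt, utt2wav, ort_session):
    audio, sample_rate = torchaudio.load(utt2wav[utt])
    if sample_rate != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
    if audio.shape[1] / 16000 > 30:
        logging.warning('speech token extraction is not supported for audio longer than 30s')
        return utt, []
    feat = whisper.log_mel_spectrogram(audio, n_mels=128)
    # The tokenizer model takes two inputs: the mel features and their frame count.
    speech_token = ort_session.run(None, {
        ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
        ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32),
    })[0].flatten().tolist()
    return utt, speech_token


def extract_tokens(utt2wav, ort_session, num_thread=8):
    # Bind the shared state into the job instead of relying on module globals.
    job = partial(single_job, utt2wav=utt2wav, ort_session=ort_session)
    utt2speech_token = {}
    # The context manager shuts the pool down cleanly even if a job raises.
    with ThreadPoolExecutor(max_workers=num_thread) as executor:
        tasks = [executor.submit(job, utt) for utt in utt2wav]
        for future in tqdm(as_completed(tasks), total=len(tasks)):
            utt, speech_token = future.result()
            utt2speech_token[utt] = speech_token
    return utt2speech_token

One caveat when reusing this pattern: both scripts pin intra_op_num_threads = 1, which is what makes num_thread concurrent run() calls reasonable; without it, ONNX Runtime's own intra-op threads multiply against the pool's workers and oversubscribe the CPU.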