@@ -0,0 +1,753 @@
+#!/usr/bin/env python3
+#
+# Copyright (c) 2023 by manyeyes
+# Copyright (c) 2023 Xiaomi Corporation
+
+"""
+This file demonstrates how to use the sherpa-onnx Python API to transcribe
+file(s) with a non-streaming model.
+
+(1) For paraformer
+
+  ./python-api-examples/offline-decode-files.py \
+  --tokens=/path/to/tokens.txt \
+  --paraformer=/path/to/paraformer.onnx \
+  --num-threads=2 \
+  --decoding-method=greedy_search \
+  --debug=false \
+  --sample-rate=16000 \
+  --feature-dim=80 \
+  /path/to/0.wav \
+  /path/to/1.wav
+
+(2) For transducer models from icefall
+
+  ./python-api-examples/offline-decode-files.py \
+  --tokens=/path/to/tokens.txt \
+  --encoder=/path/to/encoder.onnx \
+  --decoder=/path/to/decoder.onnx \
+  --joiner=/path/to/joiner.onnx \
+  --num-threads=2 \
+  --decoding-method=greedy_search \
+  --debug=false \
+  --sample-rate=16000 \
+  --feature-dim=80 \
+  /path/to/0.wav \
+  /path/to/1.wav
+
+(3) For CTC models from NeMo
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
+  --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
+  --num-threads=2 \
+  --decoding-method=greedy_search \
+  --debug=false \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
+
+(4) For Whisper models
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+  --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+  --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+  --whisper-task=transcribe \
+  --num-threads=1 \
+  ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
+  ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
+  ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
+
+(5) For CTC models from WeNet
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
+  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
+
+(6) For tdnn models of the yesno recipe from icefall
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --sample-rate=8000 \
+  --feature-dim=23 \
+  --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
+  --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
+  ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
+  ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
+  ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
+
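+(7) For computing WER/CER against reference labels, an illustrative
+invocation of this file's own flags (all paths are placeholders):
+
+python3 ./offline-decode-files.py \
+  --tokens=/path/to/tokens.txt \
+  --paraformer=/path/to/paraformer.onnx \
+  --name=test_set \
+  --log-dir=./logs \
+  --label=/path/to/labels.txt \
+  /path/to/0.wav \
+  /path/to/1.wav
+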
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/index.html
+to install sherpa-onnx and to download non-streaming pre-trained models
+used in this file.
+
+Note: main() in this file only constructs a paraformer recognizer; the
+usage examples for the other model types will fail its assertions and are
+kept for reference only.
+"""
+import argparse
+import logging
+import re
+import string
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, Iterable, List, TextIO, Tuple, Union
+
+import kaldialign
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+from datasets import load_dataset
+from zhon.hanzi import punctuation
+
+punctuation_all = punctuation + string.punctuation
+
+Pathlike = Union[str, Path]
+
+
+def remove_punctuation(text: str) -> str:
+    """Remove Chinese and English punctuation, keeping apostrophes."""
+    for x in punctuation_all:
+        if x == "'":
+            continue
+        text = text.replace(x, "")
+    return text
+
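+# A quick illustration (not executed): remove_punctuation("hello, world! it's 你好。")
+# returns "hello world it's 你好"; punctuation is stripped, the apostrophe kept.
+
+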
+def store_transcripts(
+    filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
+) -> None:
+    """Save predicted results and reference transcripts to a file.
+
+    Args:
+      filename:
+        File to save the results to.
+      texts:
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript and the third element is the predicted result.
+        If it is a multi-talker ASR system, the ref and hyp may also be lists of
+        strings.
+      char_level:
+        If True, split both ref and hyp into characters before writing.
+    Returns:
+      Return None.
+    """
+    with open(filename, "w", encoding="utf8") as f:
+        for cut_id, ref, hyp in texts:
+            if char_level:
+                ref = list("".join(ref))
+                hyp = list("".join(hyp))
+            print(f"{cut_id}:\tref={ref}", file=f)
+            print(f"{cut_id}:\thyp={hyp}", file=f)
+
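+# For example (illustrative), store_transcripts("recogs.txt",
+# [("utt1", "ref text", "hyp text")]) writes two tab-separated lines:
+#   utt1:   ref=ref text
+#   utt1:   hyp=hyp text
+
+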
+def write_error_stats(
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, str, str]],
+    enable_log: bool = True,
+    compute_CER: bool = False,
+    sclite_mode: bool = False,
+) -> float:
+    """Write statistics based on predicted results and reference transcripts.
+
+    It will write the following to the given file:
+
+      - WER
+      - number of insertions, deletions, substitutions, corrects and total
+        reference words. For example::
+
+          Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
+          reference words (2337 correct)
+
+      - The difference between the reference transcript and predicted result.
+        An instance is given below::
+
+            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
+
+        The above example shows that the reference word is `EDISON`,
+        but it is predicted as `ADDISON` (a substitution error).
+
+        Another example is::
+
+            FOR THE FIRST DAY (SIR->*) I THINK
+
+        The reference word `SIR` is missing in the predicted
+        results (a deletion error).
+
+    Args:
+      f:
+        The file to write the statistics to.
+      test_set_name:
+        The name of the test set; it appears in the logged summary.
+      results:
+        A list of tuples. The first element is the cut_id, the second is
+        the reference transcript and the third element is the predicted
+        result. Note that if ref/hyp are plain strings, the alignment is
+        computed per character.
+      enable_log:
+        If True, also print detailed WER to the console.
+        Otherwise, it is written only to the given file.
+      compute_CER:
+        If True, split ref/hyp into characters first, i.e., compute CER
+        instead of WER.
+      sclite_mode:
+        Passed through to kaldialign.align.
+    Returns:
+      Return the total error rate (in percent) as a float.
+    """
+    subs: Dict[Tuple[str, str], int] = defaultdict(int)
+    ins: Dict[str, int] = defaultdict(int)
+    dels: Dict[str, int] = defaultdict(int)
+
+    # `words` stores counts per word, as follows:
+    #   corr, ref_sub, hyp_sub, ins, dels
+    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+    num_corr = 0
+    ERR = "*"
+
+    if compute_CER:
+        for i, res in enumerate(results):
+            cut_id, ref, hyp = res
+            ref = list("".join(ref))
+            hyp = list("".join(hyp))
+            results[i] = (cut_id, ref, hyp)
+
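+    # kaldialign.align pairs up the ref/hyp tokens, inserting ERR ("*") on the
+    # empty side of an insertion or deletion. For example (illustrative):
+    #   kaldialign.align(["a", "b"], ["a", "c"], "*") -> [("a", "a"), ("b", "c")]
+    #   kaldialign.align(["a"], ["a", "b"], "*") -> [("a", "a"), ("*", "b")]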
+    for cut_id, ref, hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
+        for ref_word, hyp_word in ali:
+            if ref_word == ERR:
+                ins[hyp_word] += 1
+                words[hyp_word][3] += 1
+            elif hyp_word == ERR:
+                dels[ref_word] += 1
+                words[ref_word][4] += 1
+            elif hyp_word != ref_word:
+                subs[(ref_word, hyp_word)] += 1
+                words[ref_word][1] += 1
+                words[hyp_word][2] += 1
+            else:
+                words[ref_word][0] += 1
+                num_corr += 1
+    ref_len = sum([len(r) for _, r, _ in results])
+    sub_errs = sum(subs.values())
+    ins_errs = sum(ins.values())
+    del_errs = sum(dels.values())
+    tot_errs = sub_errs + ins_errs + del_errs
+    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
+    if enable_log:
+        logging.info(
+            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+            f"{del_errs} del, {sub_errs} sub ]"
+        )
+
+    print(f"%WER = {tot_err_rate}", file=f)
+    print(
+        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+        f"{sub_errs} substitutions, over {ref_len} reference "
+        f"words ({num_corr} correct)",
+        file=f,
+    )
+    print(
+        "Search below for sections starting with PER-UTT DETAILS:, "
+        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+        file=f,
+    )
+
+    print("", file=f)
+    print("PER-UTT DETAILS: corr or (ref->hyp)", file=f)
+    for cut_id, ref, hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR)
+        # Merge runs of consecutive errors so that, e.g., two adjacent
+        # substitutions print as (A B->X Y) instead of (A->X) (B->Y).
+        combine_successive_errors = True
+        if combine_successive_errors:
+            ali = [[[x], [y]] for x, y in ali]
+            for i in range(len(ali) - 1):
+                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+                    ali[i] = [[], []]
+            ali = [
+                [
+                    list(filter(lambda a: a != ERR, x)),
+                    list(filter(lambda a: a != ERR, y)),
+                ]
+                for x, y in ali
+            ]
+            ali = list(filter(lambda x: x != [[], []], ali))
+            ali = [
+                [
+                    ERR if x == [] else " ".join(x),
+                    ERR if y == [] else " ".join(y),
+                ]
+                for x, y in ali
+            ]
+
+        print(
+            f"{cut_id}:\t"
+            + " ".join(
+                (
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    for ref_word, hyp_word in ali
+                )
+            ),
+            file=f,
+        )
+
+    print("", file=f)
+    print("SUBSTITUTIONS: count ref -> hyp", file=f)
+
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+        print(f"{count} {ref} -> {hyp}", file=f)
+
+    print("", file=f)
+    print("DELETIONS: count ref", file=f)
+    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+        print(f"{count} {ref}", file=f)
+
+    print("", file=f)
+    print("INSERTIONS: count hyp", file=f)
+    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+        print(f"{count} {hyp}", file=f)
+
+    print("", file=f)
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
+    for _, word, counts in sorted(
+        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+    ):
+        # Unpack with names that do not shadow the `ins`/`dels` dicts above.
+        (corr, ref_sub, hyp_sub, ins_count, del_count) = counts
+        tot_errs = ref_sub + hyp_sub + ins_count + del_count
+        ref_count = corr + ref_sub + del_count
+        hyp_count = corr + hyp_sub + ins_count
+
+        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
+    return float(tot_err_rate)
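+# For example (illustrative), with
+#   results = [("utt1", ["the", "cat"], ["the", "cat"]),
+#              ("utt2", ["a", "dog"], ["a", "log"])]
+# write_error_stats(f, "dev", results) writes "%WER = 25.00" (1 substitution
+# over 4 reference words) plus the per-utterance and per-word breakdowns to
+# f, and returns 25.0.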
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--hotwords-file",
+        type=str,
+        default="",
+        help="""
+        The file containing hotwords, one word/phrase per line, like
+        HELLO WORLD
+        你好世界
+        """,
+    )
+
+    parser.add_argument(
+        "--hotwords-score",
+        type=float,
+        default=1.5,
+        help="""
+        The hotword score of each token for biasing word/phrase. Used only if
+        --hotwords-file is given.
+        """,
+    )
+
+    parser.add_argument(
+        "--modeling-unit",
+        type=str,
+        default="",
+        help="""
+        The modeling unit of the model; valid values are cjkchar, bpe,
+        cjkchar+bpe. Used only when --hotwords-file is given.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-vocab",
+        type=str,
+        default="",
+        help="""
+        The path to the bpe vocabulary generated by sentencepiece; you can
+        also export the bpe vocabulary from a bpe model with
+        `scripts/export_bpe_vocab.py`. Used only when --hotwords-file is
+        given and --modeling-unit is bpe or cjkchar+bpe.
+        """,
+    )
+
+    parser.add_argument(
+        "--encoder",
+        default="",
+        type=str,
+        help="Path to the encoder model",
+    )
+
+    parser.add_argument(
+        "--decoder",
+        default="",
+        type=str,
+        help="Path to the decoder model",
+    )
+
+    parser.add_argument(
+        "--joiner",
+        default="",
+        type=str,
+        help="Path to the joiner model",
+    )
+
+    parser.add_argument(
+        "--paraformer",
+        default="",
+        type=str,
+        help="Path to the model.onnx from Paraformer",
+    )
+
+    parser.add_argument(
+        "--nemo-ctc",
+        default="",
+        type=str,
+        help="Path to the model.onnx from NeMo CTC",
+    )
+
+    parser.add_argument(
+        "--wenet-ctc",
+        default="",
+        type=str,
+        help="Path to the model.onnx from WeNet CTC",
+    )
+
+    parser.add_argument(
+        "--tdnn-model",
+        default="",
+        type=str,
+        help="Path to the model.onnx for the tdnn model of the yesno recipe",
+    )
+
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--whisper-encoder",
+        default="",
+        type=str,
+        help="Path to whisper encoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-decoder",
+        default="",
+        type=str,
+        help="Path to whisper decoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-language",
+        default="",
+        type=str,
+        help="""It specifies the spoken language in the input audio file.
+        Example values: en, fr, de, zh, ja.
+        Available languages for multilingual models can be found at
+        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+        If not specified, we infer the language from the input audio file.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-task",
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        type=str,
+        help="""For multilingual models, if you specify translate, the output
+        will be in English.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-tail-paddings",
+        default=-1,
+        type=int,
+        help="""Number of tail padding frames.
+        We have removed the 30-second constraint from whisper, so you need to
+        choose the number of tail padding frames by yourself.
+        Use -1 to use a default value for tail padding.
+        """,
+    )
+
+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="""
+        The penalty applied on blank symbol during decoding.
+        Note: It is a positive value that would be applied to logits like
+        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+        [batch_size, vocab] and blank id is 0).
+        """,
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
+    )
+    parser.add_argument(
+        "--debug",
+        # argparse's `type=bool` would treat any non-empty string (including
+        # "false") as True, so parse the value explicitly.
+        type=lambda s: s.lower() in ("true", "1", "yes"),
+        default=False,
+        help="True to show debug messages",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="""Sample rate of the feature extractor. Must match the one
+        expected by the model. Note: The input sound files can have a
+        different sample rate from this argument.""",
+    )
+
+    parser.add_argument(
+        "--feature-dim",
+        type=int,
+        default=80,
+        help="Feature dimension. Must match the one expected by the model",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to decode. Each file must be of WAVE "
+        "format with a single channel. "
+        "The sample rate of the file can be arbitrary and does not need to "
+        "be 16 kHz",
+    )
+
+    parser.add_argument(
+        "--name",
+        type=str,
+        default="",
+        help="A tag used in the names of the output files: "
+        "recogs-<name>.txt and errs-<name>.txt",
+    )
+
+    parser.add_argument(
+        "--log-dir",
+        type=str,
+        default="",
+        help="The directory to which the recognition results and error "
+        "statistics are written",
+    )
+
+    parser.add_argument(
+        "--label",
+        type=str,
+        default=None,
+        help="""Path to a label file. Each line contains four |-separated
+        fields: audio_path|prompt_text|prompt_audio|text. The stem of
+        audio_path is matched against the decoded file names.""",
+    )
+
+    # Dataset related arguments for loading labels when a label file is not provided
+    parser.add_argument(
+        "--dataset-name",
+        type=str,
+        default="yuekai/seed_tts_cosy2",
+        help="""Huggingface dataset name for loading labels. Ignored when
+        --split-name contains 'zero', in which case yuekai/CV3-Eval is
+        used instead.""",
+    )
+
+    parser.add_argument(
+        "--split-name",
+        type=str,
+        default="wenetspeech4tts",
+        help="Dataset split name for loading labels",
+    )
+
+    return parser.parse_args()
+
+
+def assert_file_exists(filename: str):
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    )
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel. The sample rate
+        of the file can be arbitrary; samples are read as 32-bit floats.
+
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples,
+         which are normalized to the range [-1, 1].
+       - Sample rate of the wave file.
+    """
+
+    samples, sample_rate = sf.read(wave_filename, dtype="float32")
+    assert (
+        samples.ndim == 1
+    ), f"Expected single channel, but got {samples.ndim} channels."
+
+    samples_float32 = samples.astype(np.float32)
+
+    return samples_float32, sample_rate
+
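+# Example usage (illustrative):
+#   samples, sample_rate = read_wave("/path/to/test.wav")
+#   assert samples.dtype == np.float32 and samples.ndim == 1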
+
+def normalize_text_alimeeting(text: str) -> str:
+    """
+    Text normalization similar to the M2MeT challenge baseline.
+    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
+    """
+    text = text.replace("\u00A0", "")  # non-breaking space, seen in test_hard
+    text = text.replace(" ", "")
+    text = text.replace("<sil>", "")
+    text = text.replace("<%>", "")
+    text = text.replace("<->", "")
+    text = text.replace("<$>", "")
+    text = text.replace("<#>", "")
+    text = text.replace("<_>", "")
+    text = text.replace("<space>", "")
+    text = text.replace("`", "")
+    text = text.replace("&", "")
+    text = text.replace(",", "")
+    if re.search("[a-zA-Z]", text):
+        text = text.upper()
+    # Map full-width characters to their ASCII counterparts and drop
+    # full-width/CJK punctuation.
+    text = text.replace("Ａ", "A")
+    text = text.replace("ａ", "A")
+    text = text.replace("ｂ", "B")
+    text = text.replace("ｃ", "C")
+    text = text.replace("ｋ", "K")
+    text = text.replace("ｔ", "T")
+    text = text.replace("，", "")
+    text = text.replace("丶", "")
+    text = text.replace("。", "")
+    text = text.replace("、", "")
+    text = text.replace("？", "")
+    text = remove_punctuation(text)
+    return text
+
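+# For example (illustrative): normalize_text_alimeeting("呃 这个<sil>模型，效果how about?")
+# returns "呃这个模型效果HOWABOUT".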
+
+def main():
+    args = get_args()
+    assert_file_exists(args.tokens)
+    assert args.num_threads > 0, args.num_threads
+
+    # This script only supports paraformer models; make sure the flags for
+    # the other model types are left empty.
+    assert len(args.nemo_ctc) == 0, args.nemo_ctc
+    assert len(args.wenet_ctc) == 0, args.wenet_ctc
+    assert len(args.whisper_encoder) == 0, args.whisper_encoder
+    assert len(args.whisper_decoder) == 0, args.whisper_decoder
+    assert len(args.tdnn_model) == 0, args.tdnn_model
+
+    assert_file_exists(args.paraformer)
+
+    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
+        paraformer=args.paraformer,
+        tokens=args.tokens,
+        num_threads=args.num_threads,
+        sample_rate=args.sample_rate,
+        feature_dim=args.feature_dim,
+        decoding_method=args.decoding_method,
+        debug=args.debug,
+    )
+
+    print("Started!")
+    start_time = time.time()
+
+    streams, results = [], []
+    total_duration = 0
+
+    for i, wave_filename in enumerate(args.sound_files):
+        assert_file_exists(wave_filename)
+        samples, sample_rate = read_wave(wave_filename)
+        duration = len(samples) / sample_rate
+        total_duration += duration
+        s = recognizer.create_stream()
+        s.accept_waveform(sample_rate, samples)
+
+        streams.append(s)
+        # Decode the accumulated streams in batches of 10.
+        if (i + 1) % 10 == 0:
+            recognizer.decode_streams(streams)
+            results += [s.result.text for s in streams]
+            streams = []
+            print(f"Processed {i + 1} files")
+    # process the last batch
+    if streams:
+        recognizer.decode_streams(streams)
+        results += [s.result.text for s in streams]
+    end_time = time.time()
+    print("Done!")
+
+    results_dict = {}
+    for wave_filename, result in zip(args.sound_files, results):
+        print(f"{wave_filename}\n{result}")
+        print("-" * 10)
+        wave_basename = Path(wave_filename).stem
+        results_dict[wave_basename] = result
+
+    elapsed_seconds = end_time - start_time
+    rtf = elapsed_seconds / total_duration
+    print(f"num_threads: {args.num_threads}")
+    print(f"decoding_method: {args.decoding_method}")
+    print(f"Wave duration: {total_duration:.3f} s")
+    print(f"Elapsed time: {elapsed_seconds:.3f} s")
+    print(
+        f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
+    )
+
+    # Load labels either from a file or from a Hugging Face dataset
+    labels_dict = {}
+
+    if args.label:
+        # Load labels from a file (original functionality)
+        print(f"Loading labels from file: {args.label}")
+        with open(args.label, "r") as f:
+            for line in f:
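+                # Each label line is expected to look like (illustrative):
+                #   /path/to/utt_0001.wav|prompt text|/path/to/prompt.wav|target text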
+                fields = line.strip().split("|")
+                fields = [item for item in fields if item]
+                assert len(fields) == 4
+                audio_path, prompt_text, prompt_audio, text = fields
+                labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
+    else:
+        # Load labels from a Hugging Face dataset (new functionality)
+        print(
+            f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}"
+        )
+        # Splits whose names contain "zero" live in yuekai/CV3-Eval; otherwise
+        # use the dataset given on the command line.
+        if "zero" in args.split_name:
+            dataset_name = "yuekai/CV3-Eval"
+        else:
+            dataset_name = args.dataset_name
+        dataset = load_dataset(
+            dataset_name,
+            split=args.split_name,
+            trust_remote_code=True,
+        )
+
+        for item in dataset:
+            audio_id = item["id"]
+            labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
+
+        print(f"Loaded {len(labels_dict)} labels from dataset")
+
+    # Perform evaluation if labels are available
+    if labels_dict:
+        final_results = []
+        for key, value in results_dict.items():
+            if key in labels_dict:
+                final_results.append((key, labels_dict[key], value))
+            else:
+                print(f"Warning: No label found for {key}, skipping...")
+
+        if final_results:
+            store_transcripts(
+                filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
+            )
+            with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
+                write_error_stats(f, "test-set", final_results, enable_log=True)
+
+            with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
+                print(f.readline())  # WER
+                print(f.readline())  # Detailed errors
+        else:
+            print("No matching labels found for evaluation")
+    else:
+        print("No labels available for evaluation")
+
+
+if __name__ == "__main__":
+    main()