- #!/usr/bin/env python3
- #
- # Copyright (c) 2023 by manyeyes
- # Copyright (c) 2023 Xiaomi Corporation
- """
- This file demonstrates how to use sherpa-onnx Python API to transcribe
- file(s) with a non-streaming model.
- (1) For paraformer
- ./python-api-examples/offline-decode-files.py \
- --tokens=/path/to/tokens.txt \
- --paraformer=/path/to/paraformer.onnx \
- --num-threads=2 \
- --decoding-method=greedy_search \
- --debug=false \
- --sample-rate=16000 \
- --feature-dim=80 \
- /path/to/0.wav \
- /path/to/1.wav
- (2) For transducer models from icefall
- ./python-api-examples/offline-decode-files.py \
- --tokens=/path/to/tokens.txt \
- --encoder=/path/to/encoder.onnx \
- --decoder=/path/to/decoder.onnx \
- --joiner=/path/to/joiner.onnx \
- --num-threads=2 \
- --decoding-method=greedy_search \
- --debug=false \
- --sample-rate=16000 \
- --feature-dim=80 \
- /path/to/0.wav \
- /path/to/1.wav
- (3) For CTC models from NeMo
- python3 ./python-api-examples/offline-decode-files.py \
- --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
- --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
- --num-threads=2 \
- --decoding-method=greedy_search \
- --debug=false \
- ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
- ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
- ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
- (4) For Whisper models
- python3 ./python-api-examples/offline-decode-files.py \
- --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
- --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
- --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
- --whisper-task=transcribe \
- --num-threads=1 \
- ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
- ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
- ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
- (5) For CTC models from WeNet
- python3 ./python-api-examples/offline-decode-files.py \
- --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
- --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
- ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
- ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
- ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
- (6) For tdnn models of the yesno recipe from icefall
- python3 ./python-api-examples/offline-decode-files.py \
- --sample-rate=8000 \
- --feature-dim=23 \
- --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
- --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
- ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
- ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
- ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
- Please refer to
- https://k2-fsa.github.io/sherpa/onnx/index.html
- to install sherpa-onnx and to download non-streaming pre-trained models
- used in this file.
- Note: main() in this modified script constructs only a paraformer
- recognizer and rejects the other model types; the usage examples above
- are kept for reference.
- """
- import argparse
- import logging
- import string
- import time
- from collections import defaultdict
- from pathlib import Path
- from typing import Dict, Iterable, List, TextIO, Tuple, Union
- import kaldialign
- import numpy as np
- import sherpa_onnx
- import soundfile as sf
- from datasets import load_dataset
- from zhon.hanzi import punctuation
- punctuation_all = punctuation + string.punctuation
- Pathlike = Union[str, Path]
- def remove_punctuation(text: str) -> str:
- for x in punctuation_all:
- if x == '\'':
- continue
- text = text.replace(x, '')
- return text
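- # Illustrative examples (doctest style; not executed by this script).
- # Both CJK and ASCII punctuation are stripped, while apostrophes are
- # kept so English contractions survive:
- #
- #   >>> remove_punctuation("你好，world!")
- #   '你好world'
- #   >>> remove_punctuation("it's")
- #   "it's"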
- def store_transcripts(
- filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
- ) -> None:
- """Save predicted results and reference transcripts to a file.
- Args:
- filename:
- File to save the results to.
- texts:
- An iterable of tuples. The first element is the cut_id, the second is
- the reference transcript and the third element is the predicted result.
- If it is a multi-talker ASR system, the ref and hyp may also be lists of
- strings.
- Returns:
- Return None.
- """
- with open(filename, "w", encoding="utf8") as f:
- for cut_id, ref, hyp in texts:
- if char_level:
- ref = list("".join(ref))
- hyp = list("".join(hyp))
- print(f"{cut_id}:\tref={ref}", file=f)
- print(f"{cut_id}:\thyp={hyp}", file=f)
- def write_error_stats(
- f: TextIO,
- test_set_name: str,
- results: List[Tuple[str, str, str]],
- enable_log: bool = True,
- compute_CER: bool = False,
- sclite_mode: bool = False,
- ) -> float:
- """Write statistics based on predicted results and reference transcripts.
- It will write the following to the given file:
- - WER
- - number of insertions, deletions, substitutions, corrects and total
- reference words. For example::
- Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
- reference words (2337 correct)
- - The difference between the reference transcript and predicted result.
- An instance is given below::
- THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
- The above example shows that the reference word is `EDISON`,
- but it is predicted to `ADDISON` (a substitution error).
- Another example is::
- FOR THE FIRST DAY (SIR->*) I THINK
- The reference word `SIR` is missing in the predicted
- results (a deletion error).
- results:
- An iterable of tuples. The first element is the cut_id, the second is
- the reference transcript and the third element is the predicted result.
- enable_log:
- If True, also print detailed WER to the console.
- Otherwise, it is written only to the given file.
- compute_CER:
- If True, compute CER (character error rate) instead of WER by
- splitting the reference and hypothesis into characters before
- alignment.
- Returns:
- Return the total error rate (WER, or CER if compute_CER is True) as a
- float, in percent.
- """
- subs: Dict[Tuple[str, str], int] = defaultdict(int)
- ins: Dict[str, int] = defaultdict(int)
- dels: Dict[str, int] = defaultdict(int)
- # `words` stores counts per word, as follows:
- # corr, ref_sub, hyp_sub, ins, dels
- words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
- num_corr = 0
- ERR = "*"
- if compute_CER:
- for i, res in enumerate(results):
- cut_id, ref, hyp = res
- ref = list("".join(ref))
- hyp = list("".join(hyp))
- results[i] = (cut_id, ref, hyp)
- for cut_id, ref, hyp in results:
- ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
- for ref_word, hyp_word in ali:
- if ref_word == ERR:
- ins[hyp_word] += 1
- words[hyp_word][3] += 1
- elif hyp_word == ERR:
- dels[ref_word] += 1
- words[ref_word][4] += 1
- elif hyp_word != ref_word:
- subs[(ref_word, hyp_word)] += 1
- words[ref_word][1] += 1
- words[hyp_word][2] += 1
- else:
- words[ref_word][0] += 1
- num_corr += 1
- ref_len = sum([len(r) for _, r, _ in results])
- sub_errs = sum(subs.values())
- ins_errs = sum(ins.values())
- del_errs = sum(dels.values())
- tot_errs = sub_errs + ins_errs + del_errs
- tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
- if enable_log:
- logging.info(
- f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
- f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
- f"{del_errs} del, {sub_errs} sub ]"
- )
- print(f"%WER = {tot_err_rate}", file=f)
- print(
- f"Errors: {ins_errs} insertions, {del_errs} deletions, "
- f"{sub_errs} substitutions, over {ref_len} reference "
- f"words ({num_corr} correct)",
- file=f,
- )
- print(
- "Search below for sections starting with PER-UTT DETAILS:, "
- "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
- file=f,
- )
- print("", file=f)
- print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
- for cut_id, ref, hyp in results:
- ali = kaldialign.align(ref, hyp, ERR)
- combine_successive_errors = True
- if combine_successive_errors:
- ali = [[[x], [y]] for x, y in ali]
- for i in range(len(ali) - 1):
- if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
- ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
- ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
- ali[i] = [[], []]
- ali = [
- [
- list(filter(lambda a: a != ERR, x)),
- list(filter(lambda a: a != ERR, y)),
- ]
- for x, y in ali
- ]
- ali = list(filter(lambda x: x != [[], []], ali))
- ali = [
- [
- ERR if x == [] else " ".join(x),
- ERR if y == [] else " ".join(y),
- ]
- for x, y in ali
- ]
- print(
- f"{cut_id}:\t"
- + " ".join(
- (
- ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
- for ref_word, hyp_word in ali
- )
- ),
- file=f,
- )
- print("", file=f)
- print("SUBSTITUTIONS: count ref -> hyp", file=f)
- for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
- print(f"{count} {ref} -> {hyp}", file=f)
- print("", file=f)
- print("DELETIONS: count ref", file=f)
- for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
- print(f"{count} {ref}", file=f)
- print("", file=f)
- print("INSERTIONS: count hyp", file=f)
- for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
- print(f"{count} {hyp}", file=f)
- print("", file=f)
- print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
- for _, word, counts in sorted(
- [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
- ):
- (corr, ref_sub, hyp_sub, ins_count, del_count) = counts
- tot_errs = ref_sub + hyp_sub + ins_count + del_count
- ref_count = corr + ref_sub + del_count
- hyp_count = corr + hyp_sub + ins_count
- print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
- return float(tot_err_rate)
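- # Illustrative usage (hypothetical data, icefall-style word lists): one
- # substitution over 4 reference words gives %WER = 1/4 = 25.00. Note that
- # main() below passes plain strings, in which case kaldialign aligns
- # characters, i.e., the reported rate is effectively a CER.
- #
- #   import sys
- #   demo = [("utt1", "FOR THE FIRST DAY".split(), "FOR THE FOURTH DAY".split())]
- #   wer = write_error_stats(sys.stdout, "demo", demo)  # writes "%WER = 25.00" plus details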
- def get_args():
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter
- )
- parser.add_argument(
- "--tokens",
- type=str,
- help="Path to tokens.txt",
- )
- parser.add_argument(
- "--hotwords-file",
- type=str,
- default="",
- help="""
- The file containing hotwords, one word/phrase per line, e.g.,
- HELLO WORLD
- 你好世界
- """,
- )
- parser.add_argument(
- "--hotwords-score",
- type=float,
- default=1.5,
- help="""
- The hotword score of each token for biasing word/phrase. Used only if
- --hotwords-file is given.
- """,
- )
- parser.add_argument(
- "--modeling-unit",
- type=str,
- default="",
- help="""
- The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
- Used only when hotwords-file is given.
- """,
- )
- parser.add_argument(
- "--bpe-vocab",
- type=str,
- default="",
- help="""
- The path to the bpe vocabulary. The bpe vocabulary is generated by
- sentencepiece; you can also export it from a bpe model with
- `scripts/export_bpe_vocab.py`. Used only when --hotwords-file is given
- and --modeling-unit is bpe or cjkchar+bpe.
- """,
- )
- parser.add_argument(
- "--encoder",
- default="",
- type=str,
- help="Path to the encoder model",
- )
- parser.add_argument(
- "--decoder",
- default="",
- type=str,
- help="Path to the decoder model",
- )
- parser.add_argument(
- "--joiner",
- default="",
- type=str,
- help="Path to the joiner model",
- )
- parser.add_argument(
- "--paraformer",
- default="",
- type=str,
- help="Path to the model.onnx from Paraformer",
- )
- parser.add_argument(
- "--nemo-ctc",
- default="",
- type=str,
- help="Path to the model.onnx from NeMo CTC",
- )
- parser.add_argument(
- "--wenet-ctc",
- default="",
- type=str,
- help="Path to the model.onnx from WeNet CTC",
- )
- parser.add_argument(
- "--tdnn-model",
- default="",
- type=str,
- help="Path to the model.onnx for the tdnn model of the yesno recipe",
- )
- parser.add_argument(
- "--num-threads",
- type=int,
- default=1,
- help="Number of threads for neural network computation",
- )
- parser.add_argument(
- "--whisper-encoder",
- default="",
- type=str,
- help="Path to whisper encoder model",
- )
- parser.add_argument(
- "--whisper-decoder",
- default="",
- type=str,
- help="Path to whisper decoder model",
- )
- parser.add_argument(
- "--whisper-language",
- default="",
- type=str,
- help="""It specifies the spoken language in the input audio file.
- Example values: en, fr, de, zh, ja.
- Available languages for multilingual models can be found at
- https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
- If not specified, we infer the language from the input audio file.
- """,
- )
- parser.add_argument(
- "--whisper-task",
- default="transcribe",
- choices=["transcribe", "translate"],
- type=str,
- help="""For multilingual models, if you specify translate, the output
- will be in English.
- """,
- )
- parser.add_argument(
- "--whisper-tail-paddings",
- default=-1,
- type=int,
- help="""Number of tail padding frames.
- We have removed the 30-second constraint from whisper, so you need to
- choose the number of tail padding frames yourself.
- Use -1 to use a default value for tail padding.
- """,
- )
- parser.add_argument(
- "--blank-penalty",
- type=float,
- default=0.0,
- help="""
- The penalty applied on blank symbol during decoding.
- Note: It is a positive value that would be applied to logits like
- this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
- [batch_size, vocab] and blank id is 0).
- """,
- )
- parser.add_argument(
- "--decoding-method",
- type=str,
- default="greedy_search",
- help="Valid values are greedy_search and modified_beam_search",
- )
- parser.add_argument(
- "--debug",
- # Note: argparse's type=bool treats any non-empty string (including
- # "false") as True, so parse the string explicitly to honor the
- # --debug=false usage shown at the top of this file.
- type=lambda s: str(s).lower() in ("true", "1", "yes"),
- default=False,
- help="True to show debug messages",
- )
- parser.add_argument(
- "--sample-rate",
- type=int,
- default=16000,
- help="""Sample rate of the feature extractor. Must match the one
- expected by the model. Note: The input sound files can have a
- different sample rate from this argument.""",
- )
- parser.add_argument(
- "--feature-dim",
- type=int,
- default=80,
- help="Feature dimension. Must match the one expected by the model",
- )
- parser.add_argument(
- "sound_files",
- type=str,
- nargs="+",
- help="The input sound file(s) to decode. Each file must be of WAVE"
- "format with a single channel, and each sample has 16-bit, "
- "i.e., int16_t. "
- "The sample rate of the file can be arbitrary and does not need to "
- "be 16 kHz",
- )
- parser.add_argument(
- "--name",
- type=str,
- default="",
- help="A tag used in the names of the output files: recogs-{name}.txt and errs-{name}.txt",
- )
- parser.add_argument(
- "--log-dir",
- type=str,
- default="",
- help="The directory in which to save the recognition results and error stats",
- )
- parser.add_argument(
- "--label",
- type=str,
- default=None,
- help="""Path to a label file. Each line has four '|'-separated fields:
- audio_path|prompt_text|prompt_audio|text. The stem of audio_path is the
- key and text is the reference transcript.""",
- )
-
- # Dataset related arguments for loading labels when label file is not provided
- parser.add_argument(
- "--dataset-name",
- type=str,
- default="yuekai/seed_tts_cosy2",
- help="Huggingface dataset name for loading labels",
- )
-
- parser.add_argument(
- "--split-name",
- type=str,
- default="wenetspeech4tts",
- help="Dataset split name for loading labels",
- )
-
- return parser.parse_args()
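- # An example invocation of this modified script (hypothetical paths):
- #
- #   python3 offline-decode-files.py \
- #     --tokens=./paraformer/tokens.txt \
- #     --paraformer=./paraformer/model.onnx \
- #     --label=./labels.txt \
- #     --name=paraformer-demo \
- #     --log-dir=./logs \
- #     wavs/0.wav wavs/1.wav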
- def assert_file_exists(filename: str):
- assert Path(filename).is_file(), (
- f"{filename} does not exist!\n"
- "Please refer to "
- "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
- )
- def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
- """
- Args:
- wave_filename:
- Path to a wave file. It should be single channel; any encoding that
- soundfile supports works. Its sample rate can be arbitrary.
- Returns:
- Return a tuple containing:
- - A 1-D array of dtype np.float32 containing the samples,
- which are normalized to the range [-1, 1].
- - Sample rate of the wave file.
- """
- samples, sample_rate = sf.read(wave_filename, dtype="float32")
- assert (
- samples.ndim == 1
- ), f"Expected single channel, but got {samples.ndim} channels."
- samples_float32 = samples.astype(np.float32)
- return samples_float32, sample_rate
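- # Illustrative usage (hypothetical file): a 2-second mono 16 kHz file
- # yields 32000 float32 samples normalized to [-1, 1].
- #
- #   samples, sr = read_wave("wavs/0.wav")
- #   duration = len(samples) / sr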
- def normalize_text_alimeeting(text: str) -> str:
- """
- Text normalization similar to M2MeT challenge baseline.
- See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
- """
- import re
- text = text.replace("\u00A0", "")  # non-breaking space (seen in the test_hard split)
- text = text.replace(" ", "")
- text = text.replace("<sil>", "")
- text = text.replace("<%>", "")
- text = text.replace("<->", "")
- text = text.replace("<$>", "")
- text = text.replace("<#>", "")
- text = text.replace("<_>", "")
- text = text.replace("<space>", "")
- text = text.replace("`", "")
- text = text.replace("&", "")
- text = text.replace(",", "")
- if re.search("[a-zA-Z]", text):
- text = text.upper()
- text = text.replace("A", "A")
- text = text.replace("a", "A")
- text = text.replace("b", "B")
- text = text.replace("c", "C")
- text = text.replace("k", "K")
- text = text.replace("t", "T")
- text = text.replace(",", "")
- text = text.replace("丶", "")
- text = text.replace("。", "")
- text = text.replace("、", "")
- text = text.replace("?", "")
- text = remove_punctuation(text)
- return text
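- # Illustrative example (doctest style): spaces and tags are removed,
- # Latin letters are uppercased, and punctuation is dropped.
- #
- #   >>> normalize_text_alimeeting("嗯 <sil> 对，ok 的。")
- #   '嗯对OK的'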
- def main():
- args = get_args()
- assert_file_exists(args.tokens)
- assert args.num_threads > 0, args.num_threads
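- # This modified script builds only a paraformer recognizer, so make sure
- # no other model type was given on the command line.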
- assert len(args.nemo_ctc) == 0, args.nemo_ctc
- assert len(args.wenet_ctc) == 0, args.wenet_ctc
- assert len(args.whisper_encoder) == 0, args.whisper_encoder
- assert len(args.whisper_decoder) == 0, args.whisper_decoder
- assert len(args.tdnn_model) == 0, args.tdnn_model
- assert_file_exists(args.paraformer)
- recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
- paraformer=args.paraformer,
- tokens=args.tokens,
- num_threads=args.num_threads,
- sample_rate=args.sample_rate,
- feature_dim=args.feature_dim,
- decoding_method=args.decoding_method,
- debug=args.debug,
- )
- print("Started!")
- start_time = time.time()
- streams, results = [], []
- total_duration = 0
- for i, wave_filename in enumerate(args.sound_files):
- assert_file_exists(wave_filename)
- samples, sample_rate = read_wave(wave_filename)
- duration = len(samples) / sample_rate
- total_duration += duration
- s = recognizer.create_stream()
- s.accept_waveform(sample_rate, samples)
- streams.append(s)
- # Decode the accumulated streams every 10 files.
- if (i + 1) % 10 == 0:
- recognizer.decode_streams(streams)
- results += [s.result.text for s in streams]
- streams = []
- print(f"Processed {i + 1} files")
- # process the last batch
- if streams:
- recognizer.decode_streams(streams)
- results += [s.result.text for s in streams]
- end_time = time.time()
- print("Done!")
- results_dict = {}
- for wave_filename, result in zip(args.sound_files, results):
- print(f"{wave_filename}\n{result}")
- print("-" * 10)
- wave_basename = Path(wave_filename).stem
- results_dict[wave_basename] = result
- elapsed_seconds = end_time - start_time
- rtf = elapsed_seconds / total_duration
- print(f"num_threads: {args.num_threads}")
- print(f"decoding_method: {args.decoding_method}")
- print(f"Wave duration: {total_duration:.3f} s")
- print(f"Elapsed time: {elapsed_seconds:.3f} s")
- print(
- f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
- )
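- # RTF is elapsed wall-clock time divided by total audio duration; e.g.,
- # decoding 100 s of audio in 5 s gives RTF = 5/100 = 0.05, i.e., 20x
- # faster than real time.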
-
- # Load labels either from file or from dataset
- labels_dict = {}
-
- if args.label:
- # Load labels from file (original functionality)
- print(f"Loading labels from file: {args.label}")
- with open(args.label, "r") as f:
- for line in f:
- # fields = line.strip().split(" ")
- # fields = [item for item in fields if item]
- # assert len(fields) == 4
- # prompt_text, prompt_audio, text, audio_path = fields
- fields = line.strip().split("|")
- fields = [item for item in fields if item]
- assert len(fields) == 4
- audio_path, prompt_text, prompt_audio, text = fields
- labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
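- # For example, a label line might look like (hypothetical values):
- #
- #   wavs/utt1.wav|提示文本|wavs/prompt1.wav|这是目标文本
- #
- # which stores labels_dict["utt1"] = normalize_text_alimeeting("这是目标文本").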
- else:
- # Load labels from a Hugging Face dataset (new functionality).
- # Note: the dataset is derived from --split-name; --dataset-name is
- # currently not consulted here.
- if "zero" in args.split_name:
- dataset_name = "yuekai/CV3-Eval"
- else:
- dataset_name = "yuekai/seed_tts_cosy2"
- print(f"Loading labels from dataset: {dataset_name}, split: {args.split_name}")
- dataset = load_dataset(
- dataset_name,
- split=args.split_name,
- trust_remote_code=True,
- )
-
- for item in dataset:
- audio_id = item["id"]
- labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
-
- print(f"Loaded {len(labels_dict)} labels from dataset")
- # Perform evaluation if labels are available
- if labels_dict:
- final_results = []
- for key, value in results_dict.items():
- if key in labels_dict:
- final_results.append((key, labels_dict[key], value))
- else:
- print(f"Warning: No label found for {key}, skipping...")
- if final_results:
- store_transcripts(
- filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
- )
- with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
- write_error_stats(f, "test-set", final_results, enable_log=True)
- with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
- print(f.readline()) # WER
- print(f.readline()) # Detailed errors
- else:
- print("No matching labels found for evaluation")
- else:
- print("No labels available for evaluation")
- if __name__ == "__main__":
- main()
|