#!/usr/bin/env python3 # # Copyright (c) 2023 by manyeyes # Copyright (c) 2023 Xiaomi Corporation """ This file demonstrates how to use sherpa-onnx Python API to transcribe file(s) with a non-streaming model. (1) For paraformer ./python-api-examples/offline-decode-files.py \ --tokens=/path/to/tokens.txt \ --paraformer=/path/to/paraformer.onnx \ --num-threads=2 \ --decoding-method=greedy_search \ --debug=false \ --sample-rate=16000 \ --feature-dim=80 \ /path/to/0.wav \ /path/to/1.wav (2) For transducer models from icefall ./python-api-examples/offline-decode-files.py \ --tokens=/path/to/tokens.txt \ --encoder=/path/to/encoder.onnx \ --decoder=/path/to/decoder.onnx \ --joiner=/path/to/joiner.onnx \ --num-threads=2 \ --decoding-method=greedy_search \ --debug=false \ --sample-rate=16000 \ --feature-dim=80 \ /path/to/0.wav \ /path/to/1.wav (3) For CTC models from NeMo python3 ./python-api-examples/offline-decode-files.py \ --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \ --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \ --num-threads=2 \ --decoding-method=greedy_search \ --debug=false \ ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \ ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \ ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav (4) For Whisper models python3 ./python-api-examples/offline-decode-files.py \ --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ --whisper-task=transcribe \ --num-threads=1 \ ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav (5) For CTC models from WeNet python3 ./python-api-examples/offline-decode-files.py \ --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav (6) For tdnn models of the yesno recipe from icefall python3 ./python-api-examples/offline-decode-files.py \ --sample-rate=8000 \ --feature-dim=23 \ --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav Please refer to https://k2-fsa.github.io/sherpa/onnx/index.html to install sherpa-onnx and to download non-streaming pre-trained models used in this file. """ import argparse import time import wave from pathlib import Path from typing import List, Tuple, Dict, Iterable, TextIO, Union import numpy as np import sherpa_onnx import soundfile as sf from datasets import load_dataset import logging from collections import defaultdict import kaldialign from zhon.hanzi import punctuation import string punctuation_all = punctuation + string.punctuation Pathlike = Union[str, Path] def remove_punctuation(text: str) -> str: for x in punctuation_all: if x == '\'': continue text = text.replace(x, '') return text def store_transcripts( filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False ) -> None: """Save predicted results and reference transcripts to a file. Args: filename: File to save the results to. texts: An iterable of tuples. The first element is the cur_id, the second is the reference transcript and the third element is the predicted result. If it is a multi-talker ASR system, the ref and hyp may also be lists of strings. Returns: Return None. """ with open(filename, "w", encoding="utf8") as f: for cut_id, ref, hyp in texts: if char_level: ref = list("".join(ref)) hyp = list("".join(hyp)) print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) def write_error_stats( f: TextIO, test_set_name: str, results: List[Tuple[str, str]], enable_log: bool = True, compute_CER: bool = False, sclite_mode: bool = False, ) -> float: """Write statistics based on predicted results and reference transcripts. It will write the following to the given file: - WER - number of insertions, deletions, substitutions, corrects and total reference words. For example:: Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 reference words (2337 correct) - The difference between the reference transcript and predicted result. An instance is given below:: THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES The above example shows that the reference word is `EDISON`, but it is predicted to `ADDISON` (a substitution error). Another example is:: FOR THE FIRST DAY (SIR->*) I THINK The reference word `SIR` is missing in the predicted results (a deletion error). results: An iterable of tuples. The first element is the cut_id, the second is the reference transcript and the third element is the predicted result. enable_log: If True, also print detailed WER to the console. Otherwise, it is written only to the given file. Returns: Return None. """ subs: Dict[Tuple[str, str], int] = defaultdict(int) ins: Dict[str, int] = defaultdict(int) dels: Dict[str, int] = defaultdict(int) # `words` stores counts per word, as follows: # corr, ref_sub, hyp_sub, ins, dels words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) num_corr = 0 ERR = "*" if compute_CER: for i, res in enumerate(results): cut_id, ref, hyp = res ref = list("".join(ref)) hyp = list("".join(hyp)) results[i] = (cut_id, ref, hyp) for cut_id, ref, hyp in results: ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) for ref_word, hyp_word in ali: if ref_word == ERR: ins[hyp_word] += 1 words[hyp_word][3] += 1 elif hyp_word == ERR: dels[ref_word] += 1 words[ref_word][4] += 1 elif hyp_word != ref_word: subs[(ref_word, hyp_word)] += 1 words[ref_word][1] += 1 words[hyp_word][2] += 1 else: words[ref_word][0] += 1 num_corr += 1 ref_len = sum([len(r) for _, r, _ in results]) sub_errs = sum(subs.values()) ins_errs = sum(ins.values()) del_errs = sum(dels.values()) tot_errs = sub_errs + ins_errs + del_errs tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) if enable_log: logging.info( f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " f"[{tot_errs} / {ref_len}, {ins_errs} ins, " f"{del_errs} del, {sub_errs} sub ]" ) print(f"%WER = {tot_err_rate}", file=f) print( f"Errors: {ins_errs} insertions, {del_errs} deletions, " f"{sub_errs} substitutions, over {ref_len} reference " f"words ({num_corr} correct)", file=f, ) print( "Search below for sections starting with PER-UTT DETAILS:, " "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", file=f, ) print("", file=f) print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) for cut_id, ref, hyp in results: ali = kaldialign.align(ref, hyp, ERR) combine_successive_errors = True if combine_successive_errors: ali = [[[x], [y]] for x, y in ali] for i in range(len(ali) - 1): if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: ali[i + 1][0] = ali[i][0] + ali[i + 1][0] ali[i + 1][1] = ali[i][1] + ali[i + 1][1] ali[i] = [[], []] ali = [ [ list(filter(lambda a: a != ERR, x)), list(filter(lambda a: a != ERR, y)), ] for x, y in ali ] ali = list(filter(lambda x: x != [[], []], ali)) ali = [ [ ERR if x == [] else " ".join(x), ERR if y == [] else " ".join(y), ] for x, y in ali ] print( f"{cut_id}:\t" + " ".join( ( ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" for ref_word, hyp_word in ali ) ), file=f, ) print("", file=f) print("SUBSTITUTIONS: count ref -> hyp", file=f) for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): print(f"{count} {ref} -> {hyp}", file=f) print("", file=f) print("DELETIONS: count ref", file=f) for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): print(f"{count} {ref}", file=f) print("", file=f) print("INSERTIONS: count hyp", file=f) for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): print(f"{count} {hyp}", file=f) print("", file=f) print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) for _, word, counts in sorted( [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True ): (corr, ref_sub, hyp_sub, ins, dels) = counts tot_errs = ref_sub + hyp_sub + ins + dels ref_count = corr + ref_sub + dels hyp_count = corr + hyp_sub + ins print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) return float(tot_err_rate) def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( "--tokens", type=str, help="Path to tokens.txt", ) parser.add_argument( "--hotwords-file", type=str, default="", help=""" The file containing hotwords, one words/phrases per line, like HELLO WORLD 你好世界 """, ) parser.add_argument( "--hotwords-score", type=float, default=1.5, help=""" The hotword score of each token for biasing word/phrase. Used only if --hotwords-file is given. """, ) parser.add_argument( "--modeling-unit", type=str, default="", help=""" The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. Used only when hotwords-file is given. """, ) parser.add_argument( "--bpe-vocab", type=str, default="", help=""" The path to the bpe vocabulary, the bpe vocabulary is generated by sentencepiece, you can also export the bpe vocabulary through a bpe model by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given and modeling-unit is bpe or cjkchar+bpe. """, ) parser.add_argument( "--encoder", default="", type=str, help="Path to the encoder model", ) parser.add_argument( "--decoder", default="", type=str, help="Path to the decoder model", ) parser.add_argument( "--joiner", default="", type=str, help="Path to the joiner model", ) parser.add_argument( "--paraformer", default="", type=str, help="Path to the model.onnx from Paraformer", ) parser.add_argument( "--nemo-ctc", default="", type=str, help="Path to the model.onnx from NeMo CTC", ) parser.add_argument( "--wenet-ctc", default="", type=str, help="Path to the model.onnx from WeNet CTC", ) parser.add_argument( "--tdnn-model", default="", type=str, help="Path to the model.onnx for the tdnn model of the yesno recipe", ) parser.add_argument( "--num-threads", type=int, default=1, help="Number of threads for neural network computation", ) parser.add_argument( "--whisper-encoder", default="", type=str, help="Path to whisper encoder model", ) parser.add_argument( "--whisper-decoder", default="", type=str, help="Path to whisper decoder model", ) parser.add_argument( "--whisper-language", default="", type=str, help="""It specifies the spoken language in the input audio file. Example values: en, fr, de, zh, jp. Available languages for multilingual models can be found at https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 If not specified, we infer the language from the input audio file. """, ) parser.add_argument( "--whisper-task", default="transcribe", choices=["transcribe", "translate"], type=str, help="""For multilingual models, if you specify translate, the output will be in English. """, ) parser.add_argument( "--whisper-tail-paddings", default=-1, type=int, help="""Number of tail padding frames. We have removed the 30-second constraint from whisper, so you need to choose the amount of tail padding frames by yourself. Use -1 to use a default value for tail padding. """, ) parser.add_argument( "--blank-penalty", type=float, default=0.0, help=""" The penalty applied on blank symbol during decoding. Note: It is a positive value that would be applied to logits like this `logits[:, 0] -= blank_penalty` (suppose logits.shape is [batch_size, vocab] and blank id is 0). """, ) parser.add_argument( "--decoding-method", type=str, default="greedy_search", help="Valid values are greedy_search and modified_beam_search", ) parser.add_argument( "--debug", type=bool, default=False, help="True to show debug messages", ) parser.add_argument( "--sample-rate", type=int, default=16000, help="""Sample rate of the feature extractor. Must match the one expected by the model. Note: The input sound files can have a different sample rate from this argument.""", ) parser.add_argument( "--feature-dim", type=int, default=80, help="Feature dimension. Must match the one expected by the model", ) parser.add_argument( "sound_files", type=str, nargs="+", help="The input sound file(s) to decode. Each file must be of WAVE" "format with a single channel, and each sample has 16-bit, " "i.e., int16_t. " "The sample rate of the file can be arbitrary and does not need to " "be 16 kHz", ) parser.add_argument( "--name", type=str, default="", help="The directory containing the input sound files to decode", ) parser.add_argument( "--log-dir", type=str, default="", help="The directory containing the input sound files to decode", ) parser.add_argument( "--label", type=str, default=None, help="wav_base_name label", ) # Dataset related arguments for loading labels when label file is not provided parser.add_argument( "--dataset-name", type=str, default="yuekai/seed_tts_cosy2", help="Huggingface dataset name for loading labels", ) parser.add_argument( "--split-name", type=str, default="wenetspeech4tts", help="Dataset split name for loading labels", ) return parser.parse_args() def assert_file_exists(filename: str): assert Path(filename).is_file(), ( f"{filename} does not exist!\n" "Please refer to " "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" ) def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: """ Args: wave_filename: Path to a wave file. It should be single channel and can be of type 32-bit floating point PCM. Its sample rate does not need to be 24kHz. Returns: Return a tuple containing: - A 1-D array of dtype np.float32 containing the samples, which are normalized to the range [-1, 1]. - Sample rate of the wave file. """ samples, sample_rate = sf.read(wave_filename, dtype="float32") assert ( samples.ndim == 1 ), f"Expected single channel, but got {samples.ndim} channels." samples_float32 = samples.astype(np.float32) return samples_float32, sample_rate def normalize_text_alimeeting(text: str) -> str: """ Text normalization similar to M2MeT challenge baseline. See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl """ import re text = text.replace('\u00A0', '') # test_hard text = text.replace(" ", "") text = text.replace("", "") text = text.replace("<%>", "") text = text.replace("<->", "") text = text.replace("<$>", "") text = text.replace("<#>", "") text = text.replace("<_>", "") text = text.replace("", "") text = text.replace("`", "") text = text.replace("&", "") text = text.replace(",", "") if re.search("[a-zA-Z]", text): text = text.upper() text = text.replace("A", "A") text = text.replace("a", "A") text = text.replace("b", "B") text = text.replace("c", "C") text = text.replace("k", "K") text = text.replace("t", "T") text = text.replace(",", "") text = text.replace("丶", "") text = text.replace("。", "") text = text.replace("、", "") text = text.replace("?", "") text = remove_punctuation(text) return text def main(): args = get_args() assert_file_exists(args.tokens) assert args.num_threads > 0, args.num_threads assert len(args.nemo_ctc) == 0, args.nemo_ctc assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.paraformer) recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( paraformer=args.paraformer, tokens=args.tokens, num_threads=args.num_threads, sample_rate=args.sample_rate, feature_dim=args.feature_dim, decoding_method=args.decoding_method, debug=args.debug, ) print("Started!") start_time = time.time() streams, results = [], [] total_duration = 0 for i, wave_filename in enumerate(args.sound_files): assert_file_exists(wave_filename) samples, sample_rate = read_wave(wave_filename) duration = len(samples) / sample_rate total_duration += duration s = recognizer.create_stream() s.accept_waveform(sample_rate, samples) streams.append(s) if i % 10 == 0: recognizer.decode_streams(streams) results += [s.result.text for s in streams] streams = [] print(f"Processed {i} files") # process the last batch if streams: recognizer.decode_streams(streams) results += [s.result.text for s in streams] end_time = time.time() print("Done!") results_dict = {} for wave_filename, result in zip(args.sound_files, results): print(f"{wave_filename}\n{result}") print("-" * 10) wave_basename = Path(wave_filename).stem results_dict[wave_basename] = result elapsed_seconds = end_time - start_time rtf = elapsed_seconds / total_duration print(f"num_threads: {args.num_threads}") print(f"decoding_method: {args.decoding_method}") print(f"Wave duration: {total_duration:.3f} s") print(f"Elapsed time: {elapsed_seconds:.3f} s") print( f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" ) # Load labels either from file or from dataset labels_dict = {} if args.label: # Load labels from file (original functionality) print(f"Loading labels from file: {args.label}") with open(args.label, "r") as f: for line in f: # fields = line.strip().split(" ") # fields = [item for item in fields if item] # assert len(fields) == 4 # prompt_text, prompt_audio, text, audio_path = fields fields = line.strip().split("|") fields = [item for item in fields if item] assert len(fields) == 4 audio_path, prompt_text, prompt_audio, text = fields labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text) else: # Load labels from dataset (new functionality) print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}") if 'zero' in args.split_name: dataset_name = "yuekai/CV3-Eval" else: dataset_name = "yuekai/seed_tts_cosy2" dataset = load_dataset( dataset_name, split=args.split_name, trust_remote_code=True, ) for item in dataset: audio_id = item["id"] labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"]) print(f"Loaded {len(labels_dict)} labels from dataset") # Perform evaluation if labels are available if labels_dict: final_results = [] for key, value in results_dict.items(): if key in labels_dict: final_results.append((key, labels_dict[key], value)) else: print(f"Warning: No label found for {key}, skipping...") if final_results: store_transcripts( filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results ) with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f: write_error_stats(f, "test-set", final_results, enable_log=True) with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f: print(f.readline()) # WER print(f.readline()) # Detailed errors else: print("No matching labels found for evaluation") else: print("No labels available for evaluation") if __name__ == "__main__": main()