# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation

"""
This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a non-streaming model.

(1) For paraformer

    ./python-api-examples/offline-decode-files.py \
      --tokens=/path/to/tokens.txt \
      --paraformer=/path/to/paraformer.onnx \
      --num-threads=2 \
      --decoding-method=greedy_search \
      --debug=false \
      --sample-rate=16000 \
      --feature-dim=80 \
      /path/to/0.wav \
      /path/to/1.wav

(2) For transducer models from icefall

    ./python-api-examples/offline-decode-files.py \
      --tokens=/path/to/tokens.txt \
      --encoder=/path/to/encoder.onnx \
      --decoder=/path/to/decoder.onnx \
      --joiner=/path/to/joiner.onnx \
      --num-threads=2 \
      --decoding-method=greedy_search \
      --debug=false \
      --sample-rate=16000 \
      --feature-dim=80 \
      /path/to/0.wav \
      /path/to/1.wav

(3) For CTC models from NeMo

    python3 ./python-api-examples/offline-decode-files.py \
      --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
      --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
      --num-threads=2 \
      --decoding-method=greedy_search \
      --debug=false \
      ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
      ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
      ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav

(4) For Whisper models

    python3 ./python-api-examples/offline-decode-files.py \
      --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
      --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
      --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
      --whisper-task=transcribe \
      --num-threads=1 \
      ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
      ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
      ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav

(5) For CTC models from WeNet

    python3 ./python-api-examples/offline-decode-files.py \
      --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
      --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
      ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
      ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
      ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav

(6) For tdnn models of the yesno recipe from icefall

    python3 ./python-api-examples/offline-decode-files.py \
      --sample-rate=8000 \
      --feature-dim=23 \
      --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
      --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
      ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
      ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
      ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav

Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
import logging
import re
import string
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union

import kaldialign
import numpy as np
import sherpa_onnx
import soundfile as sf
from datasets import load_dataset
from zhon.hanzi import punctuation

punctuation_all = punctuation + string.punctuation

Pathlike = Union[str, Path]
def remove_punctuation(text: str) -> str:
    for x in punctuation_all:
        if x == "'":
            continue
        text = text.replace(x, "")
    return text
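
# Example: remove_punctuation("你好，world!") returns "你好world".
# Both CJK and ASCII punctuation are stripped; only the ASCII apostrophe
# is kept, so English contractions such as "don't" survive.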

def store_transcripts(
    filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
    """Save predicted results and reference transcripts to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
        If it is a multi-talker ASR system, the ref and hyp may also be lists of
        strings.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf8") as f:
        for cut_id, ref, hyp in texts:
            if char_level:
                ref = list("".join(ref))
                hyp = list("".join(hyp))
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)
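
# The stored file has one ref/hyp pair of lines per utterance. With string
# transcripts and hypothetical ids, the output looks like:
#
#   utt-001:	ref=hello world
#   utt-001:	hyp=hello word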

def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str, str]],
    enable_log: bool = True,
    compute_CER: bool = False,
    sclite_mode: bool = False,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

              FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
    Returns:
      Return the total error rate (WER, or CER if compute_CER is True) as a
      float.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"

    if compute_CER:
        for i, res in enumerate(results):
            cut_id, ref, hyp = res
            ref = list("".join(ref))
            hyp = list("".join(hyp))
            results[i] = (cut_id, ref, hyp)

    for _cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)
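
# A minimal usage sketch with hypothetical ids and word lists: it prints the
# detailed report to stdout and returns the WER in percent.
#
#   import sys
#   demo_results = [
#       ("utt-001", ["HELLO", "WORLD"], ["HELLO", "WORD"]),
#       ("utt-002", ["YES"], ["YES"]),
#   ]
#   wer = write_error_stats(sys.stdout, "demo", demo_results)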

def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--hotwords-file",
        type=str,
        default="",
        help="""
        The file containing hotwords, one word/phrase per line, like
        HELLO WORLD
        你好世界
        """,
    )

    parser.add_argument(
        "--hotwords-score",
        type=float,
        default=1.5,
        help="""
        The hotword score of each token for biasing word/phrase. Used only if
        --hotwords-file is given.
        """,
    )

    parser.add_argument(
        "--modeling-unit",
        type=str,
        default="",
        help="""
        The modeling unit of the model. Valid values are cjkchar, bpe,
        cjkchar+bpe. Used only when --hotwords-file is given.
        """,
    )

    parser.add_argument(
        "--bpe-vocab",
        type=str,
        default="",
        help="""
        The path to the bpe vocabulary. The bpe vocabulary is generated by
        sentencepiece; you can also export the bpe vocabulary from a bpe model
        with `scripts/export_bpe_vocab.py`. Used only when --hotwords-file is
        given and --modeling-unit is bpe or cjkchar+bpe.
        """,
    )
    parser.add_argument(
        "--encoder",
        default="",
        type=str,
        help="Path to the encoder model",
    )

    parser.add_argument(
        "--decoder",
        default="",
        type=str,
        help="Path to the decoder model",
    )

    parser.add_argument(
        "--joiner",
        default="",
        type=str,
        help="Path to the joiner model",
    )

    parser.add_argument(
        "--paraformer",
        default="",
        type=str,
        help="Path to the model.onnx from Paraformer",
    )

    parser.add_argument(
        "--nemo-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from NeMo CTC",
    )

    parser.add_argument(
        "--wenet-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from WeNet CTC",
    )

    parser.add_argument(
        "--tdnn-model",
        default="",
        type=str,
        help="Path to the model.onnx for the tdnn model of the yesno recipe",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )
    parser.add_argument(
        "--whisper-encoder",
        default="",
        type=str,
        help="Path to whisper encoder model",
    )

    parser.add_argument(
        "--whisper-decoder",
        default="",
        type=str,
        help="Path to whisper decoder model",
    )

    parser.add_argument(
        "--whisper-language",
        default="",
        type=str,
        help="""It specifies the spoken language in the input audio file.
        Example values: en, fr, de, zh, jp.
        Available languages for multilingual models can be found at
        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
        If not specified, we infer the language from the input audio file.
        """,
    )

    parser.add_argument(
        "--whisper-task",
        default="transcribe",
        choices=["transcribe", "translate"],
        type=str,
        help="""For multilingual models, if you specify translate, the output
        will be in English.
        """,
    )

    parser.add_argument(
        "--whisper-tail-paddings",
        default=-1,
        type=int,
        help="""Number of tail padding frames.
        We have removed the 30-second constraint from whisper, so you need to
        choose the amount of tail padding frames by yourself.
        Use -1 to use a default value for tail padding.
        """,
    )
    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )
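    # Worked example: with --blank-penalty=2.0 and blank id 0, every blank
    # logit is reduced by 2.0 before the token decision, so the decoder is
    # less inclined to emit blanks, which tends to reduce deletion errors.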
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="Valid values are greedy_search and modified_beam_search",
    )
    parser.add_argument(
        "--debug",
        # Note: argparse's type=bool treats any non-empty string (including
        # "false") as True, so parse the flag value explicitly to make
        # --debug=false behave as documented.
        type=lambda s: s.lower() in ("true", "1", "yes"),
        default=False,
        help="True to show debug messages",
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="""Sample rate of the feature extractor. Must match the one
        expected by the model. Note: The input sound files can have a
        different sample rate from this argument.""",
    )

    parser.add_argument(
        "--feature-dim",
        type=int,
        default=80,
        help="Feature dimension. Must match the one expected by the model",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to decode. Each file must be of WAVE "
        "format with a single channel, and each sample should be 16-bit "
        "(i.e., int16_t). "
        "The sample rate of the files can be arbitrary and does not need to "
        "be 16 kHz",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="",
        help="The name of this run; it is used in the output file names",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        default="",
        help="The directory to store the recognition results and error stats",
    )

    parser.add_argument(
        "--label",
        type=str,
        default=None,
        help="Path to the label file. Each line has four |-separated fields: "
        "audio_path|prompt_text|prompt_audio|text",
    )

    # Dataset related arguments for loading labels when a label file is not provided
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="yuekai/seed_tts_cosy2",
        help="Huggingface dataset name for loading labels",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        help="Dataset split name for loading labels",
    )

    return parser.parse_args()

def assert_file_exists(filename: str):
    assert Path(filename).is_file(), (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
    )

def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """
    Args:
      wave_filename:
        Path to a wave file. It should be single channel. The samples are
        read as 32-bit floating point PCM; the sample rate of the file can
        be arbitrary.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples,
         which are normalized to the range [-1, 1].
       - Sample rate of the wave file.
    """
    samples, sample_rate = sf.read(wave_filename, dtype="float32")
    assert (
        samples.ndim == 1
    ), f"Expected single channel, but got {samples.ndim} channels."

    samples_float32 = samples.astype(np.float32)

    return samples_float32, sample_rate
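
# Usage example (hypothetical path): load a mono wav and print its duration.
#
#   samples, sr = read_wave("/path/to/0.wav")
#   print(f"{len(samples) / sr:.2f} seconds at {sr} Hz")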

def normalize_text_alimeeting(text: str) -> str:
    """
    Text normalization similar to the M2MeT challenge baseline.
    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
    """
    text = text.replace("\u00A0", "")  # test_hard
    text = text.replace(" ", "")
    text = text.replace("<sil>", "")
    text = text.replace("<%>", "")
    text = text.replace("<->", "")
    text = text.replace("<$>", "")
    text = text.replace("<#>", "")
    text = text.replace("<_>", "")
    text = text.replace("<space>", "")
    text = text.replace("`", "")
    text = text.replace("&", "")
    text = text.replace(",", "")
    if re.search("[a-zA-Z]", text):
        text = text.upper()
    # Map full-width letters to their ASCII equivalents.
    text = text.replace("Ａ", "A")
    text = text.replace("ａ", "A")
    text = text.replace("ｂ", "B")
    text = text.replace("ｃ", "C")
    text = text.replace("ｋ", "K")
    text = text.replace("ｔ", "T")
    # Remove full-width and CJK punctuation.
    text = text.replace("，", "")
    text = text.replace("丶", "")
    text = text.replace("。", "")
    text = text.replace("、", "")
    text = text.replace("？", "")
    text = remove_punctuation(text)
    return text
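
# Example: normalize_text_alimeeting("hello, <sil>世界。") returns "HELLO世界":
# spaces, tags, and punctuation are removed and Latin letters are upper-cased.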

def main():
    args = get_args()
    assert_file_exists(args.tokens)
    assert args.num_threads > 0, args.num_threads

    # This modified example supports only the paraformer model, so make sure
    # none of the other model types is given.
    assert len(args.nemo_ctc) == 0, args.nemo_ctc
    assert len(args.wenet_ctc) == 0, args.wenet_ctc
    assert len(args.whisper_encoder) == 0, args.whisper_encoder
    assert len(args.whisper_decoder) == 0, args.whisper_decoder
    assert len(args.tdnn_model) == 0, args.tdnn_model

    assert_file_exists(args.paraformer)

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=args.paraformer,
        tokens=args.tokens,
        num_threads=args.num_threads,
        sample_rate=args.sample_rate,
        feature_dim=args.feature_dim,
        decoding_method=args.decoding_method,
        debug=args.debug,
    )
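    # Note: the upstream sherpa-onnx example also constructs recognizers for
    # the other model types (e.g. via OfflineRecognizer.from_transducer or
    # OfflineRecognizer.from_whisper) depending on which flags are set; this
    # modified script keeps only the paraformer path.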
    print("Started!")
    start_time = time.time()

    streams, results = [], []
    total_duration = 0

    for i, wave_filename in enumerate(args.sound_files):
        assert_file_exists(wave_filename)
        samples, sample_rate = read_wave(wave_filename)
        duration = len(samples) / sample_rate
        total_duration += duration

        s = recognizer.create_stream()
        s.accept_waveform(sample_rate, samples)
        streams.append(s)

        # Decode the accumulated streams in batches.
        if i % 10 == 0:
            recognizer.decode_streams(streams)
            results += [s.result.text for s in streams]
            streams = []
            print(f"Processed {i + 1} files")

    # process the last batch
    if streams:
        recognizer.decode_streams(streams)
        results += [s.result.text for s in streams]

    end_time = time.time()
    print("Done!")

    results_dict = {}
    for wave_filename, result in zip(args.sound_files, results):
        print(f"{wave_filename}\n{result}")
        print("-" * 10)
        wave_basename = Path(wave_filename).stem
        results_dict[wave_basename] = result

    elapsed_seconds = end_time - start_time
    rtf = elapsed_seconds / total_duration
    print(f"num_threads: {args.num_threads}")
    print(f"decoding_method: {args.decoding_method}")
    print(f"Wave duration: {total_duration:.3f} s")
    print(f"Elapsed time: {elapsed_seconds:.3f} s")
    print(
        f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
    )
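    # Example: decoding 100 s of audio in 5 s gives RTF = 5/100 = 0.05;
    # an RTF below 1 means the system runs faster than real time.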

    # Load labels either from a file or from a dataset
    labels_dict = {}
    if args.label:
        # Load labels from a file (original functionality).
        # Each line has four |-separated fields:
        #   audio_path|prompt_text|prompt_audio|text
        # (An older format used space-separated fields in the order
        # prompt_text prompt_audio text audio_path.)
        print(f"Loading labels from file: {args.label}")
        with open(args.label, "r") as f:
            for line in f:
                fields = line.strip().split("|")
                fields = [item for item in fields if item]
                assert len(fields) == 4
                audio_path, prompt_text, prompt_audio, text = fields
                labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
    else:
        # Load labels from a dataset (new functionality)
        print(
            f"Loading labels from dataset: {args.dataset_name}, "
            f"split: {args.split_name}"
        )
        if "zero" in args.split_name:
            dataset_name = "yuekai/CV3-Eval"
        else:
            dataset_name = "yuekai/seed_tts_cosy2"
        dataset = load_dataset(
            dataset_name,
            split=args.split_name,
            trust_remote_code=True,
        )
        for item in dataset:
            audio_id = item["id"]
            labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
        print(f"Loaded {len(labels_dict)} labels from dataset")

    # Perform evaluation if labels are available
    if labels_dict:
        final_results = []
        for key, value in results_dict.items():
            if key in labels_dict:
                final_results.append((key, labels_dict[key], value))
            else:
                print(f"Warning: No label found for {key}, skipping...")

        if final_results:
            store_transcripts(
                filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
            )
            with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
                write_error_stats(f, "test-set", final_results, enable_log=True)

            with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
                print(f.readline())  # WER
                print(f.readline())  # Detailed errors
        else:
            print("No matching labels found for evaluation")
    else:
        print("No labels available for evaluation")


if __name__ == "__main__":
    main()