#!/usr/bin/env python3
#
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation

"""
This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a non-streaming model.

(1) For paraformer

    ./python-api-examples/offline-decode-files.py \
      --tokens=/path/to/tokens.txt \
      --paraformer=/path/to/paraformer.onnx \
      --num-threads=2 \
      --decoding-method=greedy_search \
      --debug=false \
      --sample-rate=16000 \
      --feature-dim=80 \
      /path/to/0.wav \
      /path/to/1.wav

(2) For transducer models from icefall

    ./python-api-examples/offline-decode-files.py \
      --tokens=/path/to/tokens.txt \
      --encoder=/path/to/encoder.onnx \
      --decoder=/path/to/decoder.onnx \
      --joiner=/path/to/joiner.onnx \
      --num-threads=2 \
      --decoding-method=greedy_search \
      --debug=false \
      --sample-rate=16000 \
      --feature-dim=80 \
      /path/to/0.wav \
      /path/to/1.wav

(3) For CTC models from NeMo

    python3 ./python-api-examples/offline-decode-files.py \
      --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
      --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
      --num-threads=2 \
      --decoding-method=greedy_search \
      --debug=false \
      ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
      ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
      ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav

(4) For Whisper models

    python3 ./python-api-examples/offline-decode-files.py \
      --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
      --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
      --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
      --whisper-task=transcribe \
      --num-threads=1 \
      ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
      ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
      ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav

(5) For CTC models from WeNet

    python3 ./python-api-examples/offline-decode-files.py \
      --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
      --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
      ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
      ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
      ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav

(6) For tdnn models of the yesno recipe from icefall

    python3 ./python-api-examples/offline-decode-files.py \
      --sample-rate=8000 \
      --feature-dim=23 \
      --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
      --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
      ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
      ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
      ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav

Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
import logging
import re
import string
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union

import kaldialign
import numpy as np
import sherpa_onnx
import soundfile as sf
from datasets import load_dataset
from zhon.hanzi import punctuation

# Chinese (fullwidth) plus ASCII punctuation
punctuation_all = punctuation + string.punctuation

Pathlike = Union[str, Path]

def remove_punctuation(text: str) -> str:
    for x in punctuation_all:
        if x == "'":
            continue
        text = text.replace(x, "")
    return text
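
# For example, remove_punctuation("你好，世界! it's ok") keeps the apostrophe
# and returns "你好世界 it's ok"; spaces are not touched here.
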
def store_transcripts(
    filename: Pathlike,
    texts: Iterable[Tuple[str, str, str]],
    char_level: bool = False,
) -> None:
    """Save predicted results and reference transcripts to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
        If it is a multi-talker ASR system, the ref and hyp may also be lists of
        strings.
      char_level:
        If True, split the reference and hypothesis into characters before
        writing them out.

    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf8") as f:
        for cut_id, ref, hyp in texts:
            if char_level:
                ref = list("".join(ref))
                hyp = list("".join(hyp))
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)
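
# Each utterance yields a pair of tab-separated lines in the output file, e.g.
#   utt1:  ref=这是一个测试
#   utt1:  hyp=这是一个测试
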
def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str, str]],
    enable_log: bool = True,
    compute_CER: bool = False,
    sclite_mode: bool = False,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted as `ADDISON` (a substitution error).

          Another example is::

              FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      f:
        File to write the statistics to.
      test_set_name:
        Name of the test set, used in the log message.
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
      compute_CER:
        If True, split references and hypotheses into characters and compute
        the character error rate instead.
      sclite_mode:
        Passed to kaldialign.align() to mimic sclite-style alignment.

    Returns:
      Return the total error rate as a float (in percent).
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"

    if compute_CER:
        # Split both ref and hyp into characters so the alignment below
        # yields a character error rate.
        for i, res in enumerate(results):
            cut_id, ref, hyp = res
            ref = list("".join(ref))
            hyp = list("".join(hyp))
            results[i] = (cut_id, ref, hyp)

    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
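    # Overall error rate = (ins + del + sub) / len(ref), reported in percent;
    # when compute_CER=True the units are characters instead of words.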
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)
    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins
        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)

    return float(tot_err_rate)
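
# Minimal usage sketch for write_error_stats (hypothetical data):
#   results = [("utt1", list("这是测试"), list("这是一测试"))]
#   with open("errs.txt", "w") as f:
#       err_rate = write_error_stats(f, "demo", results)
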
def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--hotwords-file",
        type=str,
        default="",
        help="""
        The file containing hotwords, one word/phrase per line, like
        HELLO WORLD
        你好世界
        """,
    )

    parser.add_argument(
        "--hotwords-score",
        type=float,
        default=1.5,
        help="""
        The hotword score of each token for biasing word/phrase. Used only if
        --hotwords-file is given.
        """,
    )

    parser.add_argument(
        "--modeling-unit",
        type=str,
        default="",
        help="""
        The modeling unit of the model. Valid values are cjkchar, bpe,
        cjkchar+bpe. Used only when --hotwords-file is given.
        """,
    )

    parser.add_argument(
        "--bpe-vocab",
        type=str,
        default="",
        help="""
        Path to the BPE vocabulary. The BPE vocabulary is generated by
        sentencepiece; you can also export it from a BPE model with
        `scripts/export_bpe_vocab.py`. Used only when --hotwords-file is given
        and --modeling-unit is bpe or cjkchar+bpe.
        """,
    )

    parser.add_argument(
        "--encoder",
        default="",
        type=str,
        help="Path to the encoder model",
    )

    parser.add_argument(
        "--decoder",
        default="",
        type=str,
        help="Path to the decoder model",
    )

    parser.add_argument(
        "--joiner",
        default="",
        type=str,
        help="Path to the joiner model",
    )

    parser.add_argument(
        "--paraformer",
        default="",
        type=str,
        help="Path to the model.onnx from Paraformer",
    )

    parser.add_argument(
        "--nemo-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from NeMo CTC",
    )

    parser.add_argument(
        "--wenet-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from WeNet CTC",
    )

    parser.add_argument(
        "--tdnn-model",
        default="",
        type=str,
        help="Path to the model.onnx for the tdnn model of the yesno recipe",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    parser.add_argument(
        "--whisper-encoder",
        default="",
        type=str,
        help="Path to whisper encoder model",
    )

    parser.add_argument(
        "--whisper-decoder",
        default="",
        type=str,
        help="Path to whisper decoder model",
    )

    parser.add_argument(
        "--whisper-language",
        default="",
        type=str,
        help="""It specifies the spoken language in the input audio file.
        Example values: en, fr, de, zh, jp.
        Available languages for multilingual models can be found at
        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
        If not specified, we infer the language from the input audio file.
        """,
    )

    parser.add_argument(
        "--whisper-task",
        default="transcribe",
        choices=["transcribe", "translate"],
        type=str,
        help="""For multilingual models, if you specify translate, the output
        will be in English.
        """,
    )

    parser.add_argument(
        "--whisper-tail-paddings",
        default=-1,
        type=int,
        help="""Number of tail padding frames.
        We have removed the 30-second constraint from whisper, so you need to
        choose the amount of tail padding frames by yourself.
        Use -1 to use a default value for tail padding.
        """,
    )

    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="Valid values are greedy_search and modified_beam_search",
    )

    parser.add_argument(
        "--debug",
        type=bool,
        default=False,
        help="True to show debug messages",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="""Sample rate of the feature extractor. Must match the one
        expected by the model. Note: The input sound files can have a
        different sample rate from this argument.""",
    )

    parser.add_argument(
        "--feature-dim",
        type=int,
        default=80,
        help="Feature dimension. Must match the one expected by the model",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to decode. Each file must be of WAVE "
        "format with a single channel, and each sample should be 16-bit "
        "(i.e., int16_t). "
        "The sample rate of the file can be arbitrary and does not need to "
        "be 16 kHz",
    )

    parser.add_argument(
        "--name",
        type=str,
        default="",
        help="Name used in the output file names, i.e., recogs-<name>.txt "
        "and errs-<name>.txt",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        default="",
        help="Directory to save the recognition results and error stats",
    )

    parser.add_argument(
        "--label",
        type=str,
        default=None,
        help="Path to the label file. Each line has four |-separated fields: "
        "audio_path|prompt_text|prompt_audio|text",
    )

    # Dataset related arguments for loading labels when a label file is not
    # provided
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="yuekai/seed_tts_cosy2",
        help="Huggingface dataset name for loading labels",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        help="Dataset split name for loading labels",
    )

    return parser.parse_args()
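
# Example invocation for this modified script (paths are hypothetical):
#   python3 offline-decode-files.py \
#     --tokens=./paraformer/tokens.txt \
#     --paraformer=./paraformer/model.onnx \
#     --name=wenetspeech4tts --log-dir=./logs \
#     --label=labels.txt \
#     wavs/*.wav
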
def assert_file_exists(filename: str):
    assert Path(filename).is_file(), (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
    )
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """
    Args:
      wave_filename:
        Path to a wave file. It should be single channel; the samples are
        read as 32-bit floating point PCM. Its sample rate does not need to
        match the sample rate expected by the model.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples,
         which are normalized to the range [-1, 1].
       - Sample rate of the wave file.
    """
    samples, sample_rate = sf.read(wave_filename, dtype="float32")
    assert (
        samples.ndim == 1
    ), f"Expected single channel, but got {samples.ndim} channels."

    samples_float32 = samples.astype(np.float32)

    return samples_float32, sample_rate
def normalize_text_alimeeting(text: str) -> str:
    """
    Text normalization similar to M2MeT challenge baseline.
    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
    """
    text = text.replace("\u00A0", "")  # non-breaking space, seen in test_hard
    text = text.replace(" ", "")
    text = text.replace("<sil>", "")
    text = text.replace("<%>", "")
    text = text.replace("<->", "")
    text = text.replace("<$>", "")
    text = text.replace("<#>", "")
    text = text.replace("<_>", "")
    text = text.replace("<space>", "")
    text = text.replace("`", "")
    text = text.replace("&", "")
    text = text.replace(",", "")
    if re.search("[a-zA-Z]", text):
        text = text.upper()
    # Map fullwidth letters to their ASCII counterparts
    text = text.replace("Ａ", "A")
    text = text.replace("ａ", "A")
    text = text.replace("ｂ", "B")
    text = text.replace("ｃ", "C")
    text = text.replace("ｋ", "K")
    text = text.replace("ｔ", "T")
    text = text.replace("，", "")
    text = text.replace("丶", "")
    text = text.replace("。", "")
    text = text.replace("、", "")
    text = text.replace("？", "")
    text = remove_punctuation(text)
    return text
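
# For example:
#   normalize_text_alimeeting("这是 一个<sil>测试, ok") -> "这是一个测试OK"
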
def main():
    args = get_args()
    assert_file_exists(args.tokens)
    assert args.num_threads > 0, args.num_threads

    # This modified example only supports the paraformer path; all other
    # model-type arguments must be left empty.
    assert len(args.nemo_ctc) == 0, args.nemo_ctc
    assert len(args.wenet_ctc) == 0, args.wenet_ctc
    assert len(args.whisper_encoder) == 0, args.whisper_encoder
    assert len(args.whisper_decoder) == 0, args.whisper_decoder
    assert len(args.tdnn_model) == 0, args.tdnn_model

    assert_file_exists(args.paraformer)

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=args.paraformer,
        tokens=args.tokens,
        num_threads=args.num_threads,
        sample_rate=args.sample_rate,
        feature_dim=args.feature_dim,
        decoding_method=args.decoding_method,
        debug=args.debug,
    )
    print("Started!")
    start_time = time.time()

    streams, results = [], []
    total_duration = 0
    for i, wave_filename in enumerate(args.sound_files):
        assert_file_exists(wave_filename)
        samples, sample_rate = read_wave(wave_filename)
        duration = len(samples) / sample_rate
        total_duration += duration
        s = recognizer.create_stream()
        s.accept_waveform(sample_rate, samples)
        streams.append(s)
        # Decode in batches of 10 streams. (The original `i % 10 == 0` test
        # decoded a batch of one at i == 0.)
        if (i + 1) % 10 == 0:
            recognizer.decode_streams(streams)
            results += [s.result.text for s in streams]
            streams = []
            print(f"Processed {i + 1} files")
    # Process the last (possibly partial) batch
    if streams:
        recognizer.decode_streams(streams)
        results += [s.result.text for s in streams]

    end_time = time.time()
    print("Done!")
    results_dict = {}
    for wave_filename, result in zip(args.sound_files, results):
        print(f"{wave_filename}\n{result}")
        print("-" * 10)
        wave_basename = Path(wave_filename).stem
        results_dict[wave_basename] = result

    elapsed_seconds = end_time - start_time
    rtf = elapsed_seconds / total_duration
    print(f"num_threads: {args.num_threads}")
    print(f"decoding_method: {args.decoding_method}")
    print(f"Wave duration: {total_duration:.3f} s")
    print(f"Elapsed time: {elapsed_seconds:.3f} s")
    print(
        f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
    )
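    # An RTF below 1 means decoding runs faster than real time.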
    # Load labels either from a file or from a Huggingface dataset
    labels_dict = {}
    if args.label:
        # Load labels from file (original functionality)
        print(f"Loading labels from file: {args.label}")
        with open(args.label, "r") as f:
            for line in f:
                fields = line.strip().split("|")
                fields = [item for item in fields if item]
                assert len(fields) == 4
                audio_path, prompt_text, prompt_audio, text = fields
                labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
    else:
        # Load labels from dataset (new functionality)
        print(
            f"Loading labels from dataset: {args.dataset_name}, "
            f"split: {args.split_name}"
        )
        # Note: the dataset name is derived from the split name here and
        # overrides --dataset-name
        if "zero" in args.split_name:
            dataset_name = "yuekai/CV3-Eval"
        else:
            dataset_name = "yuekai/seed_tts_cosy2"
        dataset = load_dataset(
            dataset_name,
            split=args.split_name,
            trust_remote_code=True,
        )
        for item in dataset:
            audio_id = item["id"]
            labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
        print(f"Loaded {len(labels_dict)} labels from dataset")
    # Perform evaluation if labels are available
    if labels_dict:
        final_results = []
        for key, value in results_dict.items():
            if key in labels_dict:
                final_results.append((key, labels_dict[key], value))
            else:
                print(f"Warning: No label found for {key}, skipping...")

        if final_results:
            store_transcripts(
                filename=f"{args.log_dir}/recogs-{args.name}.txt",
                texts=final_results,
            )
            # ref and hyp are plain strings here, so the alignment inside
            # write_error_stats is character by character, which is the
            # desired granularity for these normalized Chinese transcripts.
            with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
                write_error_stats(f, "test-set", final_results, enable_log=True)

            with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
                print(f.readline())  # WER
                print(f.readline())  # Detailed errors
        else:
            print("No matching labels found for evaluation")
    else:
        print("No labels available for evaluation")


if __name__ == "__main__":
    main()