processor.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from io import BytesIO

import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence

torchaudio.set_audio_backend('soundfile')

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
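# NOTE: set_audio_backend is deprecated in recent torchaudio releases (the 2.x
# series dispatches backends automatically); the 'soundfile' backend is pinned
# here, presumably so torchaudio.load can decode the in-memory BytesIO buffers
# used in filter() below.
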
def parquet_opener(data, mode='train', tts_data={}):
    """ Open a url or local parquet file and yield one sample dict per row.
        Inplace operation.

        Args:
            data(Iterable[{src}]): url or local parquet file list

        Returns:
            Iterable[dict]: the input sample updated with the row's columns
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            df = pq.read_table(url).to_pandas()
            for i in range(len(df)):
                if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                    continue
                sample.update(dict(df.loc[i]))
                if mode == 'train':
                    # NOTE do not return sample directly, must initialize a new dict
                    yield {**sample}
                else:
                    for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                        yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))

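# Every processor in this file follows the same protocol: it consumes an
# iterable of sample dicts and lazily yields transformed dicts, so stages
# compose by nesting generator calls (a full chain is sketched at the end of
# this file).
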
def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (in 10 ms frames)
            min_length: drop utterances shorter than min_length (in 10 ms frames)
            token_max_length: drop utterances with more than token_max_length
                tokens, especially when using char units for English modeling
            token_min_length: drop utterances with fewer than
                token_min_length tokens
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (in 10 ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (in 10 ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        del sample['audio_data']
        # sample['speech'] is a torch.Tensor; with a 10 ms hop there are 100 frames per second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample

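# The frame arithmetic above assumes a 10 ms hop, i.e. 100 frames per second,
# so the defaults bound utterances to roughly 0.1 s - 102.4 s. For example, a
# 4 s clip at 16 kHz gives 64000 / 16000 * 100 = 400 frames, which passes both
# length checks.
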
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: drop samples whose source rate is below this

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                # drop low-bandwidth audio instead of upsampling it
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample

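# Design note: torchaudio.transforms.Resample precomputes its kernel at
# construction time, so building a new transform per sample is simple but
# wasteful; caching one transform per (orig_freq, new_freq) pair would avoid
# the repeated setup cost.
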
def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        del sample['speech']
        yield sample

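# `feat_extractor` is any callable mapping a (channel, time) waveform to a
# (channel, n_mels, frames) tensor; the squeeze/transpose above then stores a
# (frames, n_mels) matrix in 'speech_feat'. One plausible extractor (the
# parameters are illustrative, not this project's actual config):
#
#   feat_extractor = torchaudio.transforms.MelSpectrogram(
#       sample_rate=22050, n_fft=1024, hop_length=256, n_mels=80)
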
def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample

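# The embeddings arrive from the parquet rows as plain float lists;
# F.normalize(x, dim=0) rescales each 1-D vector to unit L2 norm, e.g.
# F.normalize(torch.tensor([3.0, 4.0]), dim=0) -> tensor([0.6000, 0.8000]).
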
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Decode text to chars or BPE

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample

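# `get_tokenizer` is a zero-argument factory; calling it inside the generator
# lets each dataloader worker build its own tokenizer instance. The
# `allowed_special` keyword matches the tiktoken-style encode() signature; a
# hypothetical factory:
#
#   from functools import partial
#   import tiktoken
#   get_tokenizer = partial(tiktoken.get_encoding, 'cl100k_base')
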
def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x

def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x

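# shuffle and sort work as a pair: shuffle randomizes over a large window,
# then sort re-orders a smaller window by feature length so that batching
# groups similar-length utterances and wastes little padding; the local sort
# window keeps the stream globally shuffled, which is why sort_size should
# stay below shuffle_size.
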
def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf

def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        # the len(buf) > 0 guard avoids yielding an empty batch when a single
        # sample alone exceeds max_frames_in_batch
        if frames_after_padding > max_frames_in_batch and len(buf) > 0:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf

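# Worked example: with max_frames_in_batch=12000, samples of 900, 1000 and
# 1100 frames accumulate (padded cost 1100 * 3 = 3300 frames), and the buffer
# keeps growing until longest_frames * (len(buf) + 1) would exceed 12000, at
# which point the buffered batch is emitted and the new sample starts the
# next one.
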
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch
    """
    if mode == 'inference':
        return static_batch(data, 1)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))

def padding(data, use_spk_embedding, mode='train'):
    """ Padding the data into training data

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[dict]: padded batch tensors keyed by name
    """
    for sample in data:
        assert isinstance(sample, list)
        # speech_feat is (frames, n_mels), so the utterance length lives in dim 0
        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
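

# A minimal sketch of how these stages chain together for training (the
# parquet path, tokenizer factory and feature extractor are hypothetical;
# note the order matters, e.g. tokenize must precede filter because filter
# reads 'text_token'):
#
#   data = parquet_opener([{'src': 'train.parquet'}])
#   data = tokenize(data, get_tokenizer, allowed_special='all')
#   data = filter(data)
#   data = resample(data)
#   data = compute_fbank(data, feat_extractor)
#   data = parse_embedding(data, normalize=True)
#   data = shuffle(data)
#   data = sort(data)
#   data = batch(data, batch_type='dynamic')
#   data = padding(data, use_spk_embedding=False)
#   for minibatch in data:
#       ...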