processor.py
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from io import BytesIO

import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence

torchaudio.set_audio_backend('soundfile')
torchaudio.utils.sox_utils.set_buffer_size(16500)

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


def parquet_opener(data, mode='train', tts_data={}):
    """ Given a url or local parquet file, yield one sample per row.
    Inplace operation.

    Args:
        data(Iterable[str]): url or local file list

    Returns:
        Iterable[dict]: the input sample merged with one parquet row's columns
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            df = pq.read_table(url).to_pandas()
            for i in range(len(df)):
                if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                    continue
                sample.update(dict(df.loc[i]))
                if mode == 'train':
                    # NOTE do not return sample directly, must initialize a new dict
                    yield {**sample}
                else:
                    # in inference mode, yield one sample per tts text of this utt
                    for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                        yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
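
# A minimal sketch (not part of the original file) of driving parquet_opener by
# hand; 'train.parquet' and its column layout ('utt', 'audio_data', 'text',
# 'speech_token', ...) are assumptions based on the keys the processors below
# expect.
#
#   samples = parquet_opener([{'src': 'train.parquet'}])
#   first = next(samples)  # dict: 'src' plus the columns of one parquet row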


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length.
    Inplace operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        max_length: drop utterances longer than max_length (in 10ms frames)
        min_length: drop utterances shorter than min_length (in 10ms frames)
        token_max_length: drop utterances with more than token_max_length
            tokens, especially when using char units for English modeling
        token_min_length: drop utterances with fewer than
            token_min_length tokens
        min_output_input_ratio: minimal ratio of
            token_length / feats_length (10ms)
        max_output_input_ratio: maximum ratio of
            token_length / feats_length (10ms)

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        # decode the raw audio bytes carried in the parquet row
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        del sample['audio_data']
        # sample['speech'] is a torch.Tensor; one frame is 10ms, i.e. 100 frames per second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
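
# Worked example of the length filter (illustrative numbers, not from the
# original file): a 5 s clip at 16 kHz has 80000 samples, so
# num_frames = 80000 / 16000 * 100 = 500, which passes min_length=10 and
# max_length=10240; i.e. clips between 0.1 s and ~102.4 s are kept.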


def resample(data, resample_rate=22050, mode='train'):
    """ Resample data.
    Inplace operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        resample_rate: target resample rate

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            # only downsample; drop samples whose rate is below the target
            if sample_rate < resample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        # peak-normalize waveforms that exceed [-1, 1]
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample


def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank features.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        # (1, num_mels, T) -> (T, num_mels)
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        del sample['speech']
        yield sample


def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        # average the per-utterance embeddings of this speaker into one vector
        sample['spk_embedding'] = torch.stack(
            [torch.tensor(i, dtype=torch.float32) for i in sample['spk_embedding']],
            dim=0).mean(dim=0)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Encode text to chars or BPE tokens.
    Inplace operation.

    Args:
        data: Iterable[{key, wav, txt, sample_rate}]

    Returns:
        Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Locally shuffle the data.

    Args:
        data: Iterable[{key, feat, label}]
        shuffle_size: buffer size for shuffle

    Returns:
        Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # the samples left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
    Sort is used after shuffle and before batch, so we can group
    utts with similar lengths into a batch, and `sort_size` should
    be less than `shuffle_size`.

    Args:
        data: Iterable[{key, feat, label}]
        sort_size: buffer size for sort

    Returns:
        Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # the samples left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`.

    Args:
        data: Iterable[{key, feat, label}]
        batch_size: batch size

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in one batch
    reach `max_frames_in_batch`.

    Args:
        data: Iterable[{key, feat, label}]
        max_frames_in_batch: max frames in one batch

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        # after padding, every sample in the batch costs `longest_frames` frames
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            # guard against yielding an empty batch when a single sample
            # already exceeds max_frames_in_batch
            if len(buf) > 0:
                yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
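
# Worked example of the dynamic batching rule (illustrative lengths, not from
# the original file): with max_frames_in_batch=12000 and incoming feature
# lengths [3000, 4000, 5000], adding the third sample would cost
# 5000 * 3 = 15000 padded frames > 12000, so [3000, 4000] is emitted as one
# batch and the 5000-frame sample starts the next one.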


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch.
    """
    if mode == 'inference':
        return static_batch(data, 1)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))


def padding(data, mode='train'):
    """ Pad the data into training batches.

    Args:
        data: Iterable[List[{key, feat, label}]]

    Returns:
        Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        # speech_feat is (T, num_mels), so size(0) is the feature length;
        # sort samples by descending length before padding
        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)
        utts = [sample[i]['utt'] for i in order]
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        yield batch
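
# A minimal sketch (not part of the original file) of chaining the processors
# above into a training data pipeline; the order and arguments are assumptions
# based on the functions' docstrings and the keys each stage expects, and
# `my_get_tokenizer` / `my_feat_extractor` are hypothetical stand-ins for the
# real callables.
#
#   stream = parquet_opener([{'src': 'train.parquet'}])
#   stream = tokenize(stream, my_get_tokenizer, allowed_special='all')
#   stream = filter(stream)                      # decodes audio, drops outliers
#   stream = resample(stream, resample_rate=22050)
#   stream = compute_fbank(stream, my_feat_extractor)
#   stream = parse_embedding(stream, normalize=True)
#   stream = shuffle(stream)
#   stream = sort(stream)                        # group similar lengths
#   stream = batch(stream, batch_type='dynamic', max_frames_in_batch=12000)
#   for minibatch in padding(stream):
#       ...  # feed the `minibatch` dict to the model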