processor.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from io import BytesIO

import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence

torchaudio.set_audio_backend('soundfile')

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}

def parquet_opener(data, mode='train', tts_data={}):
    """ Given a url or local file list, yield the rows of each parquet file as samples.

        Inplace operation.

        Args:
            data(Iterable[str]): url or local file list

        Returns:
            Iterable[{src, ...parquet row fields}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            df = pq.read_table(url).to_pandas()
            for i in range(len(df)):
                if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                    continue
                sample.update(dict(df.loc[i]))
                if mode == 'train':
                    # NOTE do not return sample directly, must initialize a new dict
                    yield {**sample}
                else:
                    for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                        yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
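
# A hedged sketch (not in the original file): the parquet columns this pipeline
# expects, inferred from the fields accessed downstream ('utt', 'audio_data',
# 'text', 'text_token', 'speech_token', 'utt_embedding', 'spk_embedding').
# The file path below is hypothetical.
def _demo_parquet_opener():
    for s in parquet_opener(iter([{'src': 'data/train-00000.parquet'}])):
        # each yielded dict merges the 'src' key with one parquet row
        print(s['utt'], len(s['audio_data']))
        break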

def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length.

        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length (10ms frames)
            min_length: drop utterance which is less than min_length (10ms frames)
            token_max_length: drop utterance whose token length is greater than
                token_max_length, especially when using char units for
                English modeling
            token_min_length: drop utterance whose token length is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (10ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['speech'] is a torch.Tensor; at a 10ms hop there are 100 frames per second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
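
# A minimal worked example (not in the original file): the frame arithmetic
# used by filter() above. A 2-second mono clip at 16 kHz has 32000 samples,
# so num_frames = 32000 / 16000 * 100 = 200 10ms frames, which passes the
# default min_length=10 / max_length=10240 bounds.
def _demo_filter_frames():
    sample_rate, seconds = 16000, 2
    speech = torch.zeros(1, sample_rate * seconds)
    num_frames = speech.size(1) / sample_rate * 100
    print(num_frames)  # 200.0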

def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.

        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample

def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate data.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: truncate length

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        if waveform.shape[1] > truncate_length:
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
        sample['speech'] = waveform
        yield sample
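
# A hedged sketch (not in the original file): truncate() crops a random window
# when the clip is longer than truncate_length and zero-pads when shorter, so
# every waveform leaves the stage with exactly truncate_length samples.
def _demo_truncate():
    short = {'speech': torch.ones(1, 10)}
    long = {'speech': torch.ones(1, 40)}
    for out in truncate(iter([short, long]), truncate_length=16):
        print(out['speech'].shape)  # torch.Size([1, 16]) both times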

def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        yield sample

def compute_f0(data, pitch_extractor, mode='train'):
    """ Extract f0.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        mat = pitch_extractor(waveform).transpose(1, 2)
        # interpolate f0 so its frame count matches the fbank features
        mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
        sample['pitch_feat'] = mat[0, 0]
        yield sample

def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample

def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Decode text to chars or BPE.

        Inplace operation.

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample
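
# A hedged sketch (not in the original file): tokenize() only assumes that the
# object returned by get_tokenizer() exposes encode(text, allowed_special=...),
# a tiktoken-style signature. The stub below is hypothetical.
class _StubTokenizer:
    def encode(self, text, allowed_special='all'):
        # toy "tokenizer": one token per character
        return [ord(c) for c in text]

def _demo_tokenize():
    out = tokenize(iter([{'text': 'hi'}]), get_tokenizer=_StubTokenizer, allowed_special='all')
    print(next(out)['text_token'])  # [104, 105]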

def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data.

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x

def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`.

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x
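
# A hedged sketch (not in the original file): shuffle() randomizes within a
# large window, then sort() orders each smaller window by feature length, so
# the downstream batcher sees runs of similar-length utterances and wastes
# fewer frames on padding. Feature lengths here are arbitrary toy values.
def _demo_sort():
    samples = [{'speech_feat': torch.zeros(n, 80)} for n in (300, 100, 200, 150)]
    for x in sort(iter(samples), sort_size=4):
        print(x['speech_feat'].size(0))  # 100, 150, 200, 300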

def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`.

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf

def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`.

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            # avoid yielding an empty batch when a single sample alone
            # exceeds max_frames_in_batch
            if len(buf) > 0:
                yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
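
# A minimal worked example (not in the original file): with toy feature
# lengths 100, 120 and 300 and max_frames_in_batch=400, the first two samples
# share a batch (120 * 2 = 240 padded frames), while the third would push the
# padded total to 300 * 3 = 900 and therefore starts a new batch.
def _demo_dynamic_batch():
    samples = [{'speech_feat': torch.zeros(n, 80)} for n in (100, 120, 300)]
    for b in dynamic_batch(iter(samples), max_frames_in_batch=400):
        print(len(b))  # 2, then 1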

def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch.
    """
    if mode == 'inference':
        return static_batch(data, 1)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))

def padding(data, use_spk_embedding, mode='train'):
    """ Pad the data into training batches.

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Dict]: batch dict of keys, padded feats and lengths
    """
    for sample in data:
        assert isinstance(sample, list)
        # speech_feat is (time, num_mels), so size(0) is the frame count
        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        pitch_feat = [sample[i]['pitch_feat'] for i in order]
        pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
        pitch_feat = pad_sequence(pitch_feat,
                                  batch_first=True,
                                  padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "pitch_feat": pitch_feat,
            "pitch_feat_len": pitch_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
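
# A hedged sketch (not in the original file): how these generator stages
# compose into one pipeline. Stage order follows the data each stage consumes:
# text tokens before filter(), fbank before f0 and batching. feat_extractor
# and pitch_extractor stand in for real extractors (e.g. mel-spectrogram and
# f0 modules) supplied by the training config; all four parameters here are
# placeholders.
def _demo_pipeline(urls, get_tokenizer, feat_extractor, pitch_extractor):
    data = parquet_opener(iter([{'src': u} for u in urls]))
    data = tokenize(data, get_tokenizer, allowed_special='all')
    data = filter(data)
    data = resample(data)
    data = compute_fbank(data, feat_extractor)
    data = compute_f0(data, pitch_extractor)
    data = parse_embedding(data, normalize=True)
    data = shuffle(data)
    data = sort(data)
    data = batch(data, batch_type='dynamic')
    data = padding(data, use_spk_embedding=False)
    for b in data:
        yield b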