# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from io import BytesIO

import pyarrow.parquet as pq
import pyworld as pw
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


def parquet_opener(data, mode='train'):
    """ Given a url or local path of a parquet file, yield its rows as samples.
        Inplace operation.

        Args:
            data: Iterable[{src}], url or local file list

        Returns:
            Iterable[{src, ...parquet columns}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    sample.update(dict(df.loc[i]))
                    # NOTE do not return sample directly, must initialize a new dict
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))


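# NOTE: every stage in this file is a generator that consumes and yields
# sample dicts, so a full pipeline is built by chaining stages. A minimal
# sketch (the parquet path and tokenizer factory are placeholders, not part
# of this module):
#
#   stream = parquet_opener([{'src': 'train.parquet'}])
#   stream = tokenize(stream, get_tokenizer, allowed_special='all')
#   stream = filter(stream)

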
def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (in 10ms frames)
            min_length: drop utterances shorter than min_length (in 10ms frames)
            token_max_length: drop utterances with more than token_max_length
                tokens, especially when using char units for English modeling
            token_min_length: drop utterances with fewer than
                token_min_length tokens
            min_output_input_ratio: minimum ratio of
                token_length / feats_length (10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['speech'] is a torch.Tensor; there are 100 frames every second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample


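# NOTE: besides length filtering, the stage above also decodes the raw
# 'audio_data' bytes into the mono 'speech' waveform tensor, so it must run
# before resample / compute_fbank in the pipeline.

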
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: drop utterances sampled below this rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        # peak-normalize to avoid clipping after resampling
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample


def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate data.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: truncate length in samples

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        if waveform.shape[1] > truncate_length:
            # random crop
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            # zero-pad up to truncate_length
            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
        sample['speech'] = waveform
        yield sample


def compute_fbank(data,
                  feat_extractor,
                  token_mel_ratio=0,
                  mode='train'):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        if token_mel_ratio != 0:
            # trim to align speech_token and speech_feat
            token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
            feat = feat[:token_mel_ratio * token_len]
            sample["speech_token"] = sample["speech_token"][:token_len]
        sample['speech_feat'] = feat
        yield sample


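# Shape note: sample['speech_feat'] ends up as (T, num_mels). When
# token_mel_ratio is nonzero, each speech token is assumed to cover exactly
# token_mel_ratio mel frames; e.g. with token_mel_ratio=2, a 101-frame feat
# and 60 speech tokens are trimmed to 100 frames and 50 tokens.

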
def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
        # interpolate f0 to match the mel frame count of speech_feat
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample


def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Decode text to chars or BPE
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if 'instruct' in sample:
            sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
        else:
            sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
        yield sample


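# Design note: the tokenizer is built lazily via get_tokenizer() inside the
# stage, so each dataloader worker constructs its own instance rather than
# receiving a pickled copy; the allowed_special keyword on encode() is
# assumed to follow a tiktoken-style tokenizer interface.

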
def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        # cost of the batch after padding every sample to the longest one
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf


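# Worked example: with max_frames_in_batch=12000 and incoming lengths of
# 4000, 5000 and 3000 frames, the third sample would cost
# max(5000, 3000) * 3 = 15000 > 12000 after padding, so [4000, 5000] is
# emitted as one batch and the 3000-frame sample starts the next buffer.

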
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    elif batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    else:
        logging.fatal('Unsupported batch type {}'.format(batch_type))


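# Usage sketch: batch_type typically comes from the data config, e.g.
#   batches = batch(stream, batch_type='dynamic', max_frames_in_batch=12000)
# Note that logging.fatal only logs (it is an alias of logging.critical);
# an unsupported batch_type falls through and the wrapper returns None.

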
def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
    """ Pad each list of samples into a training batch

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Dict], one padded batch per input list
    """
    for sample in data:
        assert isinstance(sample, list)
        # sort the batch by feature length, descending;
        # speech_feat is (T, num_mels), so the time dimension is size(0)
        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)
        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
        instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
        instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "instruct_token": instruct_token,
            "instruct_token_len": instruct_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if gan is True:
            # in gan train, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if dpo is True:
            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
            reject_speech_token = pad_sequence(reject_speech_token,
                                               batch_first=True,
                                               padding_value=0)
            batch['reject_speech_token'] = reject_speech_token
            batch['reject_speech_token_len'] = reject_speech_token_len
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch


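if __name__ == '__main__':
    # Minimal smoke test of the batching stages on synthetic samples; the
    # zero 'speech_feat' tensors below are placeholders, not real features,
    # and this block is not part of the training pipeline.
    demo = [{'speech_feat': torch.zeros(t, 80)} for t in (230, 180, 420, 310, 150)]
    for b in dynamic_batch(iter(demo), max_frames_in_batch=800):
        print('dynamic batch lengths:', [s['speech_feat'].size(0) for s in b])
    for b in static_batch(iter(demo), batch_size=2):
        print('static batch size:', len(b))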