# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
from io import BytesIO
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


def parquet_opener(data, mode='train', tts_data={}):
    """ Open a url or local parquet file and yield one sample dict per row.
        Inplace operation.

        Args:
            data(Iterable[str]): url or local file list

        Returns:
            Iterable[{src, ...parquet row fields}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                        continue
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
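
# Each parquet row is expected to carry at least the columns consumed by the
# stages below: 'utt', 'audio_data', 'text', 'speech_token', 'utt_embedding',
# 'spk_embedding' and, for DPO training, 'reject_speech_token'. This list is
# inferred from the field accesses in this file, not from a published schema.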


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter sample according to feature and label length
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (in 10ms frames)
            min_length: drop utterances shorter than min_length (in 10ms frames)
            token_max_length: drop utterances with more than token_max_length
                tokens, especially when using char units for English modeling
            token_min_length: drop utterances with fewer than token_min_length
                tokens
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['speech'] is a torch.Tensor, we have 100 frames every second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
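
# Frame arithmetic used above: num_frames = samples / sample_rate * 100, i.e.
# one frame per 10 ms regardless of sample rate. A 4 s clip therefore has 400
# frames, and the default bounds min_length=10 / max_length=10240 correspond
# to roughly 0.1 s and 102 s of audio.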


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: drop samples whose original rate is below this value

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        # renormalize if the (resampled) peak exceeds 1 to avoid clipping
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample


def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate data.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: truncate length

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        if waveform.shape[1] > truncate_length:
            # randomly crop a truncate_length window
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            # zero-pad short utterances up to truncate_length
            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
        sample['speech'] = waveform
        yield sample
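
# With the default truncate_length=24576 samples, this keeps about 1.1 s of
# audio at the default 22.05 kHz rate (24576 / 22050 ≈ 1.11 s); shorter clips
# are zero-padded, so every sample leaves this stage with the same length.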


def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        # assuming feat_extractor returns (1, num_mels, T), mat is (T, num_mels)
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        yield sample


def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample
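
# frame_period is in milliseconds, e.g. hop_size=256 at sample_rate=22050 Hz
# gives 256 * 1000 / 22050 ≈ 11.6 ms per frame; the final interpolation to
# sample['speech_feat'].shape[0] then aligns the f0 track one-to-one with the
# fbank frames.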


def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Encode text into chars or BPE tokens
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x
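
# This is a buffered local shuffle: only shuffle_size samples are held in
# memory at a time, so ordering is randomized within each buffer rather than
# across the whole dataset.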


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
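
# Worked example of the padding-aware budget above, with the default
# max_frames_in_batch=12000: samples of 1000, 1200 and 900 frames cost
# max(1000, 1200, 900) * 3 = 3600 padded frames; a 4th sample of 5000 frames
# would cost 5000 * 4 = 20000 > 12000, so the first three are emitted as one
# batch and the 4th sample starts a new one.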


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch
    """
    if mode == 'inference':
        return static_batch(data, 1)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))


def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
    """ Padding the data into training data

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Dict]: batch of padded tensors and their lengths
    """
    for sample in data:
        assert isinstance(sample, list)
        # sort the batch by feature length (time is dim 0 of speech_feat),
        # longest first
        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)
        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if dpo:
            # DPO also needs the rejected speech token sequence, padded like
            # the chosen one
            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
            reject_speech_token = pad_sequence(reject_speech_token,
                                               batch_first=True,
                                               padding_value=0)
            batch['reject_speech_token'] = reject_speech_token
            batch['reject_speech_token_len'] = reject_speech_token_len
        if gan is True:
            # in gan train, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
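

# ---------------------------------------------------------------------------
# Usage sketch: the functions above are generator stages meant to be chained,
# each consuming and yielding sample dicts. The composition below is a minimal
# illustration, not part of this module; 'shard-00000.parquet', feat_extractor
# and get_tokenizer are placeholders that the surrounding training framework
# (typically a Dataset wrapper mapping these processors over shards) supplies.
#
#     data = parquet_opener([{'src': 'shard-00000.parquet'}])
#     data = tokenize(data, get_tokenizer, allowed_special='all')
#     data = filter(data)
#     data = resample(data, resample_rate=22050)
#     data = compute_fbank(data, feat_extractor)
#     data = parse_embedding(data, normalize=True)
#     data = shuffle(data)
#     data = sort(data)
#     data = batch(data, batch_type='dynamic', max_frames_in_batch=12000)
#     data = padding(data, use_spk_embedding=False, dpo=True)
#     for mini_batch in data:
#         ...  # mini_batch is a dict of padded tensors keyed as in padding()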