# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
from io import BytesIO
import numpy as np
import whisper
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
from cosyvoice.utils.onnx import embedding_extractor, online_feature

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


def parquet_opener(data, mode='train'):
    """ Open a parquet file given a url or local path and yield its rows as samples.
        Inplace operation.

        Args:
            data(Iterable[{src}]): each sample dict carries a url or local file path in 'src'

        Returns:
            Iterable[{src, ...parquet columns}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    sample.update(dict(df.loc[i]))
                    # NOTE do not return sample directly, must initialize a new dict
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (in 10ms frames)
            min_length: drop utterances shorter than min_length (in 10ms frames)
            token_max_length: drop utterances with more than token_max_length
                tokens, especially when using char units for English modeling
            token_min_length: drop utterances with fewer than
                token_min_length tokens
            min_output_input_ratio: minimum ratio of
                token_length / feats_length (10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['speech'] is torch.Tensor with 100 frames per second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if online_feature is False and len(sample['speech_token']) == 0:
            continue
        if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: drop utterances whose original sample rate is below this value

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample


def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate data.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: truncate length

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        if waveform.shape[1] > truncate_length:
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
        sample['speech'] = waveform
        yield sample


def compute_fbank(data,
                  feat_extractor,
                  num_frames=-1,
                  mode='train'):
    """ Extract fbank.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        # NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
        if num_frames != -1:
            index = int(np.ceil(sample['speech'].shape[1] / num_frames))
            sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
        sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
        yield sample


def compute_whisper_fbank(data, num_frames=-1, mode='train'):
    """ Extract whisper fbank.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        if num_frames != -1:
            assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
        sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
        sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
        yield sample


def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:  # this happens when the harvest algorithm fails
            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, fall back to dio
        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample


def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
            sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
            embedding = embedding_extractor.inference(sample['speech_16k'])
            sample['spk_embedding'] = sample['utt_embedding'] = embedding
        else:
            sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
            sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Decode text to chars or BPE.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if 'instruct' in sample:
            sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data.

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    yield_size = int(shuffle_size / 2)
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf[:yield_size]:
                yield x
            buf = buf[yield_size:]
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`.

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`.

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`.

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
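
# Worked example of the dynamic batching rule above (numbers are illustrative):
# with max_frames_in_batch=12000, buffering samples of 5000 and 5500 frames costs
# 5500 * 2 = 11000 padded frames, so both stay in the buffer; a third sample of
# 6100 frames would cost 6100 * 3 = 18300 > 12000, so the two buffered samples are
# yielded as one batch and the new sample starts the next one.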


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch.
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    elif batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    else:
        logging.fatal('Unsupported batch type {}'.format(batch_type))


def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
    """ Pad each batched list of samples into training tensors.

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
        batch = {}
        batch['utts'] = [sample[i]['utt'] for i in order]
        batch['text'] = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
        batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
            instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
            batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
            batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
        if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
            whisper_feat = [sample[i]['whisper_feat'] for i in order]
            batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
            batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
        if torch.tensor(['speech_token' in sample[i] for i in order]).all():
            speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
            batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
            batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
        if gan is True:
            # in gan train, we need speech/pitch_feat
            speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
            batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
            batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
        if dpo is True:
            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
            batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
            batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
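

# --------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training recipe).
# It chains shuffle -> sort -> batch -> padding over synthetic samples to show how
# the generator stages compose. The field shapes below (80-dim mel features,
# 192-dim embeddings, hop size 256) are assumptions for the sketch, not values
# required by this module.
if __name__ == '__main__':
    def _fake_samples(n=8):
        for i in range(n):
            t = random.randint(50, 200)  # number of feature frames
            yield {
                'utt': 'utt_{}'.format(i),
                'text': 'hello world',
                'text_token': [1, 2, 3, 4],
                'speech': torch.randn(1, t * 256),  # raw waveform, hop size 256 assumed
                'speech_feat': torch.randn(t, 80),  # mel features, 80 bins assumed
                'utt_embedding': torch.randn(192),  # embedding dim 192 assumed
                'spk_embedding': torch.randn(192),
            }

    pipeline = _fake_samples()
    pipeline = shuffle(pipeline, shuffle_size=4)
    pipeline = sort(pipeline, sort_size=4)
    pipeline = batch(pipeline, batch_type='dynamic', max_frames_in_batch=1000)
    pipeline = padding(pipeline, use_spk_embedding=True)
    for b in pipeline:
        print(b['utts'], b['speech_feat'].shape, b['embedding'].shape)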