# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from io import BytesIO

import pyarrow.parquet as pq
import pyworld as pw
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


def parquet_opener(data, mode='train', tts_data=None):
    """ Open the parquet file given by url or local path and yield one sample per row.
        Inplace operation.

        Args:
            data(Iterable[{src}]): dicts carrying a 'src' url or local file path
            mode: 'train' or 'inference'
            tts_data: dict mapping utt id to a list of tts texts, used in inference mode

        Returns:
            Iterable[{src, ...}]: the input dict updated with the parquet columns
    """
    if tts_data is None:
        tts_data = {}
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                        continue
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
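
# For reference, the per-sample fields consumed later in this file: 'utt',
# 'audio_data', 'text', 'speech_token', 'utt_embedding' and 'spk_embedding' are
# expected to come from the parquet rows, while 'speech', 'speech_feat',
# 'text_token' and 'pitch_feat' are added by the processors below
# ('reject_speech_token' is only needed when padding(..., dpo=True)).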


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter sample according to feature and label length
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length (unit: 10 ms frames)
            min_length: drop utterance which is less than min_length (unit: 10 ms frames)
            token_max_length: drop utterance which is greater than
                token_max_length, especially when use char unit for
                english modeling
            token_min_length: drop utterance which is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (unit: 10 ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (unit: 10 ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['speech'] is torch.Tensor, we have 100 frames every second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample
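
# Rough guide to the thresholds above (my reading of the code, not part of the
# original comments): num_frames counts 10 ms units, so a 3 s clip gives
# num_frames = 300 regardless of sample rate; the defaults min_length=10 and
# max_length=10240 therefore keep clips between 0.1 s and ~102 s, and a
# 300-frame clip with 15 text tokens has an output/input ratio of
# 15 / 300 = 0.05, well inside [0.0005, 1].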


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: drop utterance whose sample rate is below this value

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample


def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate data.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: truncate length

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        if waveform.shape[1] > truncate_length:
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            # pad with zeros so every sample is exactly truncate_length long
            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
        sample['speech'] = waveform
        yield sample


def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        # feat_extractor is expected to return (1, num_mels, T); store as (T, num_mels)
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        yield sample


def compute_f0(data, sample_rate, hop_size, mode='train'):
    """ Extract f0

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech'].squeeze(dim=0).numpy().astype('double')
        _f0, t = pw.harvest(waveform, sample_rate, frame_period=frame_period)
        if sum(_f0 != 0) < 5:  # this happens when the harvest algorithm fails
            _f0, t = pw.dio(waveform, sample_rate, frame_period=frame_period)  # if harvest fails, try dio
        f0 = pw.stonemask(waveform, _f0, t, sample_rate)
        # interpolate f0 to the same number of frames as speech_feat
        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
        sample['pitch_feat'] = f0
        yield sample
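
# Illustrative arithmetic for frame_period (the numbers are examples, not values
# read from any config in this file): with hop_size=256 and sample_rate=22050,
# frame_period = 256 * 1000 / 22050 ≈ 11.6 ms, so pyworld emits roughly one f0
# value per feature frame (assuming the fbank extractor uses the same hop size)
# before the interpolation to speech_feat's length above.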


def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Tokenize text into chars or BPE tokens
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
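
# Toy walkthrough of the budget check above: with max_frames_in_batch=10 and
# incoming feature lengths 4, 5, 3, the buffer grows to [4] (padded cost
# 1 * 4 = 4), then [4, 5] (cost 2 * 5 = 10, still within budget); the next
# sample of length 3 would cost 3 * 5 = 15 > 10, so [4, 5] is yielded and the
# buffer restarts with the length-3 sample.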


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch
    """
    if mode == 'inference':
        return static_batch(data, 1)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))


def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
    """ Padding the data into training data

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[dict]: padded batch tensors and their lengths
    """
    for sample in data:
        assert isinstance(sample, list)
        # sort the batch by feature length, longest first;
        # speech_feat is (T, num_mels), so size(0) is the frame count
        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if dpo:
            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
            reject_speech_token = pad_sequence(reject_speech_token,
                                               batch_first=True,
                                               padding_value=0)
            batch['reject_speech_token'] = reject_speech_token
            batch['reject_speech_token_len'] = reject_speech_token_len
        if gan is True:
            # in gan train, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # only gan train needs speech, delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch
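
# A minimal sketch of how these processors are typically chained (the function
# names are the ones defined in this file, but `my_tokenizer_factory`,
# `my_feat_extractor` and the parquet path are hypothetical placeholders; in
# practice the order and arguments come from the dataset/config code, not from
# this module):
#
#   samples = parquet_opener([{'src': 'data/train-000.parquet'}])
#   samples = tokenize(samples, my_tokenizer_factory, allowed_special='all')
#   samples = filter(samples)
#   samples = resample(samples, resample_rate=22050)
#   samples = compute_fbank(samples, my_feat_extractor)
#   samples = parse_embedding(samples, normalize=True)
#   samples = shuffle(samples)
#   samples = sort(samples)
#   batches = batch(samples, batch_type='dynamic', max_frames_in_batch=12000)
#   batches = padding(batches, use_spk_embedding=False)
#   for b in batches:
#       ...  # feed b['speech_token'], b['speech_feat'], etc. to the trainer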