|
|
@@ -40,17 +40,18 @@ def parquet_opener(data, mode='train', tts_data={}):
|
|
|
assert 'src' in sample
|
|
|
url = sample['src']
|
|
|
try:
|
|
|
- df = pq.read_table(url).to_pandas()
|
|
|
- for i in range(len(df)):
|
|
|
- if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
|
|
- continue
|
|
|
- sample.update(dict(df.loc[i]))
|
|
|
- if mode == 'train':
|
|
|
- # NOTE do not return sample directly, must initialize a new dict
|
|
|
- yield {**sample}
|
|
|
- else:
|
|
|
- for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
|
|
- yield {**sample, 'tts_index': index, 'tts_text': text}
|
|
|
+ for df in pq.ParquetFile(url).iter_batches(batch_size=64):
|
|
|
+ df = df.to_pandas()
|
|
|
+ for i in range(len(df)):
|
|
|
+ if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
|
|
+ continue
|
|
|
+ sample.update(dict(df.loc[i]))
|
|
|
+ if mode == 'train':
|
|
|
+ # NOTE do not return sample directly, must initialize a new dict
|
|
|
+ yield {**sample}
|
|
|
+ else:
|
|
|
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
|
|
+ yield {**sample, 'tts_index': index, 'tts_text': text}
|
|
|
except Exception as ex:
|
|
|
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
|
|
|