Sfoglia il codice sorgente

Merge pull request #639 from FunAudioLLM/dev/lyuxiang.lx

use stream read to save memory
Xiang Lyu 1 anno fa
parent
commit
7701325969
1 ha cambiato i file con 12 aggiunte e 11 eliminazioni
  1. 12 11
      cosyvoice/dataset/processor.py

+ 12 - 11
cosyvoice/dataset/processor.py

@@ -40,17 +40,18 @@ def parquet_opener(data, mode='train', tts_data={}):
         assert 'src' in sample
         url = sample['src']
         try:
-            df = pq.read_table(url).to_pandas()
-            for i in range(len(df)):
-                if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
-                    continue
-                sample.update(dict(df.loc[i]))
-                if mode == 'train':
-                    # NOTE do not return sample directly, must initialize a new dict
-                    yield {**sample}
-                else:
-                    for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
-                        yield {**sample, 'tts_index': index, 'tts_text': text}
+            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
+                df = df.to_pandas()
+                for i in range(len(df)):
+                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
+                        continue
+                    sample.update(dict(df.loc[i]))
+                    if mode == 'train':
+                        # NOTE do not return sample directly, must initialize a new dict
+                        yield {**sample}
+                    else:
+                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
+                            yield {**sample, 'tts_index': index, 'tts_text': text}
         except Exception as ex:
             logging.warning('Failed to open {}, ex info {}'.format(url, ex))