dataset.py

# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import json
import math
from functools import partial

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

from cosyvoice.utils.file_utils import read_lists, read_json_lists


class Processor(IterableDataset):

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)
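
# Illustrative sketch (not part of the original file): Processor wraps an
# iterable source with a generator-style function, so pipeline stages can be
# chained. `double` and `some_iterable_dataset` below are hypothetical names
# showing the expected `f(iterator, *args, **kw)` signature:
#
#   def double(samples):
#       for x in samples:
#           yield x * 2
#
#   # Processor(source, f) and source.apply(f) build the same kind of chain.
#   stage = Processor(some_iterable_dataset, double).apply(double)
#   list(stage)  # each source element, multiplied by 4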


class DistributedSampler:

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

            Args:
                data(List): input data list

            Returns:
                List: sampled index list into data
        """
        data = list(range(len(data)))
        # pad the index list so it divides evenly across ranks and workers
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            if len(data) < self.world_size:
                data = data * math.ceil(self.world_size / len(data))
                data = data[:self.world_size]
            data = data[self.rank::self.world_size]
        if len(data) < self.num_workers:
            data = data * math.ceil(self.num_workers / len(data))
            data = data[:self.num_workers]
        data = data[self.worker_id::self.num_workers]
        return data
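
# Illustrative sketch (not part of the original file): with 8 files,
# world_size=2 and num_workers=2, each (rank, worker) pair gets a disjoint
# quarter of the index list. The attributes are set by hand here only for
# demonstration; update() normally fills them in:
#
#   sampler = DistributedSampler(shuffle=False, partition=True)
#   sampler.rank, sampler.world_size = 0, 2
#   sampler.worker_id, sampler.num_workers = 1, 2
#   sampler.sample(list(range(8)))  # -> [2, 6]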


class DataList(IterableDataset):

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        indexes = self.sampler.sample(self.lists)
        for index in indexes:
            data = dict(src=self.lists[index])
            data.update(sampler_info)
            yield data
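
# Illustrative sketch (not part of the original file): each element yielded
# by DataList is a dict carrying one source entry plus the sampler context,
# which downstream pipeline stages can consult. Run in the main process
# without DDP initialized:
#
#   dl = DataList(['shard_0.tar', 'shard_1.tar'], shuffle=False)
#   next(iter(dl))
#   # -> {'src': 'shard_0.tar', 'rank': 0, 'world_size': 1,
#   #     'worker_id': 0, 'num_workers': 1}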


def Dataset(data_list_file,
            data_pipeline,
            mode='train',
            shuffle=True,
            partition=True,
            tts_file='',
            prompt_utt2data=''):
    """ Construct dataset from arguments

        There are two shuffle stages in the Dataset. The first is a global
        shuffle at the shard/raw file level; the second is a shuffle at the
        training-sample level.

        Args:
            data_list_file (str): file listing the shard/raw data files
            data_pipeline (list): processing functions applied in order
            mode (str): 'train' or 'inference'
            shuffle (bool): whether to shuffle the file list per epoch
            partition (bool): whether to partition data in terms of rank
            tts_file (str): JSON file keyed by utterance id with texts to
                synthesize (inference only)
            prompt_utt2data (str): JSON mapping from utterance id to its data
                file (inference only)
    """
    assert mode in ['train', 'inference']
    lists = read_lists(data_list_file)
    if mode == 'inference':
        with open(tts_file) as f:
            tts_data = json.load(f)
        utt2lists = read_json_lists(prompt_utt2data)
        # keep only the files actually needed in inference mode
        lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
    dataset = DataList(lists,
                       shuffle=shuffle,
                       partition=partition)
    if mode == 'inference':
        # bind tts_data to the first pipeline stage in inference mode
        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
    for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
    return dataset
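
# Usage sketch (illustrative, not part of the original file). The stage names
# below are hypothetical stand-ins for generator-style processors; each stage
# must accept an iterator plus a `mode` kwarg and yield processed samples.
# set_epoch() propagates through the Processor chain down to the sampler:
#
#   pipeline = [open_shard, tokenize, compute_fbank, batch]  # hypothetical
#   train_data = Dataset('data/train.data.list', pipeline,
#                        mode='train', shuffle=True, partition=True)
#   train_data.set_epoch(0)  # reseed the epoch-level file shuffle
#   loader = torch.utils.data.DataLoader(train_data, batch_size=None,
#                                        num_workers=2)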