dataset.py

# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import math
from functools import partial

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

from cosyvoice.utils.file_utils import read_lists


class Processor(IterableDataset):

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)
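

# Illustrative sketch (not part of the original file): each pipeline stage is
# a callable that takes the upstream iterator and yields transformed samples,
# so stages compose by wrapping one Processor in another. `add_field` below
# is a hypothetical stage:
#
#     def add_field(data, mode='train'):
#         for sample in data:
#             sample['mode'] = mode
#             yield sample
#
#     stage = Processor(upstream_dataset, add_field, mode='train')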


class DistributedSampler:

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

            Args:
                data(List): input data list

            Returns:
                List: data list after sample
        """
        data = list(range(len(data)))
        # force datalist even
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            if len(data) < self.world_size:
                data = data * math.ceil(self.world_size / len(data))
                data = data[:self.world_size]
            data = data[self.rank::self.world_size]
        if len(data) < self.num_workers:
            data = data * math.ceil(self.num_workers / len(data))
            data = data[:self.num_workers]
        data = data[self.worker_id::self.num_workers]
        return data
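

# Worked example (illustrative): with partition=True, shuffle disabled,
# world_size=2 and num_workers=2, an 8-entry list yields indexes [0..7];
# rank 0 first keeps [0, 2, 4, 6], and worker 1 within that rank then keeps
# [2, 6], so every (rank, worker) pair reads a disjoint slice of the list.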


class DataList(IterableDataset):

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        indexes = self.sampler.sample(self.lists)
        for index in indexes:
            data = dict(src=self.lists[index])
            data.update(sampler_info)
            yield data
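

# Each element yielded by DataList is a dict built from one line of the data
# list plus the sharding info returned by update(), e.g. (illustrative values):
#     {'src': '<one entry from lists>', 'rank': 0, 'world_size': 1,
#      'worker_id': 0, 'num_workers': 1}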


def Dataset(data_list_file,
            data_pipeline,
            mode='train',
            gan=False,
            dpo=False,
            shuffle=True,
            partition=True):
    """ Construct dataset from arguments

        There are two shuffle stages in the Dataset. The first is a global
        shuffle at the shard/raw file level; the second is a shuffle at the
        training-sample level.

        Args:
            data_list_file(str): file listing the shard/raw data entries
            data_pipeline(List[Callable]): processor functions applied in order
            mode(str): processing mode, e.g. 'train', passed to every processor
            gan(bool): flag forwarded to the final padding processor
            dpo(bool): flag forwarded to the final padding processor
            shuffle(bool): whether to shuffle the data list
            partition(bool): whether to do data partition in terms of rank
    """
    lists = read_lists(data_list_file)
    dataset = DataList(lists,
                       shuffle=shuffle,
                       partition=partition)

    # bind the gan/dpo flags to the final (padding) processor
    data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
    for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
    return dataset
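

# End-to-end sketch (hedged: `my_stage`, `my_padding`, and 'data.list' are
# hypothetical placeholders for the real processor functions and list file
# that the recipes supply via the data_pipeline argument). Note the last
# pipeline entry must accept the gan/dpo keyword arguments, since Dataset
# binds them with functools.partial:
#
#     def my_stage(data, mode='train'):
#         for sample in data:
#             yield sample                    # transform each sample here
#
#     def my_padding(data, mode='train', gan=False, dpo=False):
#         for sample in data:
#             yield sample                    # batch/pad samples here
#
#     dataset = Dataset('data.list', [my_stage, my_padding], mode='train')
#     dataset.set_epoch(0)
#     for batch in dataset:
#         pass                                # feed batch to the trainer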