make_parquet_list.py

#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import multiprocessing
import os
import time

import pandas as pd
import torch
from tqdm import tqdm
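
# Example invocation (the directory paths below are illustrative, not part of
# the original script):
#   python make_parquet_list.py --num_utts_per_parquet 1000 --num_processes 4 \
#       --src_dir data/train --des_dir data/train/parquet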


def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
    """Pack one shard of utterances into a parquet file, then write the
    utt->parquet and spk->parquet index json files for that shard."""
    start_time = time.time()
    # Read the raw audio bytes for every utterance in this shard
    data_list = []
    for utt in tqdm(utt_list):
        with open(utt2wav[utt], 'rb') as f:
            data_list.append(f.read())
    wav_list = [utt2wav[utt] for utt in utt_list]
    text_list = [utt2text[utt] for utt in utt_list]
    spk_list = [utt2spk[utt] for utt in utt_list]
    uttembedding_list = [utt2embedding[utt] for utt in utt_list]
    spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
    speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
    if args.dpo:
        reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
    # Save to parquet, then write the utt2parquet_file / spk2parquet_file indices
    df = pd.DataFrame()
    df['utt'] = utt_list
    df['wav'] = wav_list
    df['audio_data'] = data_list
    df['text'] = text_list
    df['spk'] = spk_list
    df['utt_embedding'] = uttembedding_list
    df['spk_embedding'] = spkembedding_list
    df['speech_token'] = speech_token_list
    if args.dpo:
        df['reject_speech_token'] = reject_speech_token_list
    df.to_parquet(parquet_file)
    with open(utt2parquet_file, 'w') as f:
        json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
    with open(spk2parquet_file, 'w') as f:
        json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
    logging.info('spend time {}'.format(time.time() - start_time))
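
# Expected Kaldi-style inputs under --src_dir (example lines are illustrative):
#   wav.scp  -> "<utt_id> <wav_path>",   e.g. "utt001 /data/wavs/utt001.wav"
#   text     -> "<utt_id> <transcript>", e.g. "utt001 hello world"
#   utt2spk  -> "<utt_id> <spk_id>",     e.g. "utt001 spk01"
# plus the utt2embedding.pt, spk2embedding.pt and utt2speech_token.pt torch
# dumps loaded below.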

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_utts_per_parquet',
                        type=int,
                        default=1000,
                        help='num utts per parquet')
    parser.add_argument('--num_processes',
                        type=int,
                        default=1,
                        help='num processes for make parquets')
    parser.add_argument('--src_dir',
                        type=str)
    parser.add_argument('--des_dir',
                        type=str)
    parser.add_argument('--dpo',
                        action='store_true',
                        default=False,
                        help='Use Direct Preference Optimization')
    args = parser.parse_args()

    # Load the Kaldi-style mapping files
    utt2wav, utt2text, utt2spk = {}, {}, {}
    with open('{}/wav.scp'.format(args.src_dir)) as f:
        for l in f:
            l = l.replace('\n', '').split()
            utt2wav[l[0]] = l[1]
    with open('{}/text'.format(args.src_dir)) as f:
        for l in f:
            l = l.replace('\n', '').split()
            utt2text[l[0]] = ' '.join(l[1:])
    with open('{}/utt2spk'.format(args.src_dir)) as f:
        for l in f:
            l = l.replace('\n', '').split()
            utt2spk[l[0]] = l[1]
    utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
    spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
    utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
    if args.dpo:
        # Rejected speech tokens are read from a parallel "<src_dir>_reject" directory
        utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir))
    utts = list(utt2wav.keys())

    # Use a process pool to build parquet shards in parallel
    pool = multiprocessing.Pool(processes=args.num_processes)
    parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
    for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
        parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
        utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
        spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
        parquet_list.append(parquet_file)
        utt2parquet_list.append(utt2parquet_file)
        spk2parquet_list.append(spk2parquet_file)
        pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
    pool.close()
    pool.join()

    # Write list files indexing all generated shards
    with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
            open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
            open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
        for name in parquet_list:
            f1.write(name + '\n')
        for name in utt2parquet_list:
            f2.write(name + '\n')
        for name in spk2parquet_list:
            f3.write(name + '\n')
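
# Example: inspect the first generated shard (note the shards use a ".tar"
# extension but are plain parquet files, so pd.read_parquet reads them by
# content regardless of the name):
#   df = pd.read_parquet(os.path.join(args.des_dir, 'parquet_000000000.tar'))
#   print(df.columns.tolist())  # utt, wav, audio_data, text, spk, ...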