make_parquet_list.py

#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
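#
# Summary (derived from the code below): packs Kaldi-style metadata from
# --src_dir (wav.scp, text, utt2spk, plus optional instruct text, embeddings,
# speech tokens and DPO reject tokens) into parquet shards under --des_dir,
# and writes data.list / utt2data.list / spk2data.list index files.
#
# Example invocation (paths are illustrative, not prescribed by this file):
#   python tools/make_parquet_list.py --num_utts_per_parquet 1000 \
#       --num_processes 8 --src_dir data/train --des_dir data/train/parquet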
import argparse
import json
import logging
import multiprocessing
import os
import time

import pandas as pd
import torch
from tqdm import tqdm


def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
    # Note: this worker relies on the module-level dicts (utt2wav, utt2text, ...)
    # and on `args`, which are inherited by the forked pool processes.
    start_time = time.time()
    data_list = []
    for utt in tqdm(utt_list):
        with open(utt2wav[utt], 'rb') as fin:
            data_list.append(fin.read())
    # Save the shard to parquet, then write the utt2parquet_file / spk2parquet_file mappings
    spk_list = [utt2spk[utt] for utt in utt_list]
    df = pd.DataFrame()
    df['utt'] = utt_list
    df['audio_data'] = data_list
    df['wav'] = [utt2wav[utt] for utt in utt_list]
    df['text'] = [utt2text[utt] for utt in utt_list]
    df['spk'] = spk_list
    if utt2embedding is not None:
        df['utt_embedding'] = [utt2embedding[utt] for utt in utt_list]
    if spk2embedding is not None:
        df['spk_embedding'] = [spk2embedding[utt2spk[utt]] for utt in utt_list]
    if utt2speech_token is not None:
        df['speech_token'] = [utt2speech_token[utt] for utt in utt_list]
    if utt2instruct is not None:
        df['instruct'] = [utt2instruct[utt] for utt in utt_list]
    if args.dpo:
        df['reject_speech_token'] = [utt2reject_speech_token.get(utt, None) for utt in utt_list]
    df.to_parquet(parquet_file)
    with open(utt2parquet_file, 'w') as f:
        json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
    with open(spk2parquet_file, 'w') as f:
        json.dump({k: parquet_file for k in set(spk_list)}, f, ensure_ascii=False, indent=2)
    logging.info('spend time {}'.format(time.time() - start_time))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_utts_per_parquet',
                        type=int,
                        default=1000,
                        help='num utts per parquet')
    parser.add_argument('--num_processes',
                        type=int,
                        default=1,
                        help='num processes for make parquets')
    parser.add_argument('--src_dir',
                        type=str)
    parser.add_argument('--des_dir',
                        type=str)
    parser.add_argument('--dpo',
                        action='store_true',
                        default=False,
                        help='Use Direct Preference Optimization')
    args = parser.parse_args()
    # Kaldi-style metadata: one "utt value..." pair per line
    utt2wav, utt2text, utt2spk = {}, {}, {}
    with open('{}/wav.scp'.format(args.src_dir)) as f:
        for l in f:
            l = l.replace('\n', '').split()
            utt2wav[l[0]] = l[1]
    with open('{}/text'.format(args.src_dir)) as f:
        for l in f:
            l = l.replace('\n', '').split()
            utt2text[l[0]] = ' '.join(l[1:])
    with open('{}/utt2spk'.format(args.src_dir)) as f:
        for l in f:
            l = l.replace('\n', '').split()
            utt2spk[l[0]] = l[1]
    if os.path.exists('{}/instruct'.format(args.src_dir)):
        utt2instruct = {}
        with open('{}/instruct'.format(args.src_dir)) as f:
            for l in f:
                l = l.replace('\n', '').split()
                utt2instruct[l[0]] = ' '.join(l[1:])
    else:
        utt2instruct = None
    # Optional precomputed features; a missing file simply leaves the column out
    utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) if os.path.exists('{}/utt2embedding.pt'.format(args.src_dir)) else None
    spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) if os.path.exists('{}/spk2embedding.pt'.format(args.src_dir)) else None
    utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) if os.path.exists('{}/utt2speech_token.pt'.format(args.src_dir)) else None
    if args.dpo:
        utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir)) if os.path.exists('{}_reject/utt2speech_token.pt'.format(args.src_dir)) else {}
    utts = list(utt2wav.keys())

    # Using process pool to speedup
    pool = multiprocessing.Pool(processes=args.num_processes)
    parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
    for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
        parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
        utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
        spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
        parquet_list.append(parquet_file)
        utt2parquet_list.append(utt2parquet_file)
        spk2parquet_list.append(spk2parquet_file)
        pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
    pool.close()
    pool.join()
    with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
            open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
            open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
        for name in parquet_list:
            f1.write(name + '\n')
        for name in utt2parquet_list:
            f2.write(name + '\n')
        for name in spk2parquet_list:
            f3.write(name + '\n')
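
# Sanity check sketch (not part of the script): a finished shard can be read
# back with pandas (needs pyarrow or fastparquet); the shard name below is
# only an example of the naming pattern used above.
#
#   df = pd.read_parquet('parquet_000000000.tar')
#   print(len(df), df.columns.tolist())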