prepare_reject_sample.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import argparse
  2. import logging
  3. import os
  4. from tqdm import tqdm
  5. import torch, torchaudio
  6. from cosyvoice.cli.cosyvoice import CosyVoice2
  7. from cosyvoice.utils.file_utils import load_wav
  8. logger = logging.getLogger()
  9. def main():
  10. cosyvoice = CosyVoice2(args.ref_model)
  11. utt2wav, utt2text = {}, {}
  12. with open('{}/wav.scp'.format(args.src_dir)) as f:
  13. for l in f:
  14. l = l.split('\n')[0].split()
  15. utt2wav[l[0]] = l[1]
  16. with open('{}/text'.format(args.src_dir)) as f:
  17. for l in f:
  18. l = l.split('\n')[0].split()
  19. utt2text[l[0]] = ' '.join(l[1:])
  20. os.makedirs('{}/wav'.format(args.des_dir), exist_ok=True)
  21. with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
  22. for utt, wav in tqdm(utt2wav.items()):
  23. prompt_speech_16k = load_wav(wav, 16000)
  24. if prompt_speech_16k.shape[1] >= 30 * 16000:
  25. continue
  26. speech_list = []
  27. for i, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
  28. speech_list.append(j['tts_speech'])
  29. negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
  30. torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')
  31. f.write('{} {}\n'.format(utt, negative_wav))
  32. if __name__ == "__main__":
  33. parser = argparse.ArgumentParser()
  34. parser.add_argument('--src_dir',
  35. type=str)
  36. parser.add_argument('--des_dir',
  37. type=str)
  38. parser.add_argument('--ref_model',
  39. type=str)
  40. args = parser.parse_args()
  41. main()