prepare_reject_sample.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import argparse
  2. import logging
  3. import os
  4. from tqdm import tqdm
  5. import torch
  6. import torchaudio
  7. from cosyvoice.cli.cosyvoice import CosyVoice2
  8. from cosyvoice.utils.file_utils import load_wav
  9. logger = logging.getLogger()
  10. def main():
  11. cosyvoice = CosyVoice2(args.ref_model)
  12. utt2wav, utt2text = {}, {}
  13. with open('{}/wav.scp'.format(args.src_dir)) as f:
  14. for l in f:
  15. l = l.split('\n')[0].split()
  16. utt2wav[l[0]] = l[1]
  17. with open('{}/text'.format(args.src_dir)) as f:
  18. for l in f:
  19. l = l.split('\n')[0].split()
  20. utt2text[l[0]] = ' '.join(l[1:])
  21. os.makedirs('{}/wav'.format(args.des_dir), exist_ok=True)
  22. with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
  23. for utt, wav in tqdm(utt2wav.items()):
  24. prompt_speech_16k = load_wav(wav, 16000)
  25. if prompt_speech_16k.shape[1] >= 30 * 16000:
  26. continue
  27. speech_list = []
  28. for _, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
  29. speech_list.append(j['tts_speech'])
  30. negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
  31. torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')
  32. f.write('{} {}\n'.format(utt, negative_wav))
  33. if __name__ == "__main__":
  34. parser = argparse.ArgumentParser()
  35. parser.add_argument('--src_dir',
  36. type=str)
  37. parser.add_argument('--des_dir',
  38. type=str)
  39. parser.add_argument('--ref_model',
  40. type=str)
  41. args = parser.parse_args()
  42. main()