streaming_inference.py

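"""Streaming token-to-wav benchmark for the CosyVoice2 Token2Wav model.

Feeds pre-generated CosyVoice2 speech tokens to the model chunk by chunk,
writes the reconstructed 24 kHz waveforms to --output-dir, and times three
passes over the dataset (warmup, without speaker cache, with speaker cache).

Example:
    python streaming_inference.py --model-dir ./Step-Audio-2-mini/token2wav --strategy equal
"""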
import argparse
import os
import time

import numpy as np
import soundfile as sf
import torch
import torchaudio
from datasets import load_dataset
from torch.utils.data import DataLoader

from token2wav_dit import CosyVoice2_Token2Wav
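

# Collate a Hugging Face dataset batch into parallel per-example lists; the
# benchmark below runs with batch_size=1 and reads index 0 of each list.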
def collate_fn(batch):
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    prompt_speech_tokens_list, prompt_text_list = [], []
    for item in batch:
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
        prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
        prompt_text_list.append(item['prompt_text'])
    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
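

# CLI flags: model and dataset locations plus the chunk-size schedule
# ("equal" fixed-size chunks vs. exponentially growing chunks).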
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-trt", action="store_true")
    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--output-dir", type=str, default="generated_wavs")
    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
    parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
    parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
    return parser.parse_args()
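

# Yields each utterance's token list in turn, apparently simulating an
# upstream token generator (unused in the benchmark below).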
def fake_generated_id_iter(generated_speech_tokens_list):
    for tokens in generated_speech_tokens_list:
        yield tokens


if __name__ == "__main__":
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    dataset = load_dataset(args.dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
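
    # Streaming chunk configuration. token_frame_rate names the CosyVoice2
    # token rate (25 tokens/s), so a 25-token chunk covers roughly one second
    # of audio. Each non-final call also consumes pre_lookahead_len future
    # tokens (read from token2wav_model.flow below; the local
    # flow_pre_lookahead_len is kept for reference but unused), and
    # consecutive chunks do not overlap (OVERLAP_SIZE = 0).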
    flow_pre_lookahead_len = 3
    CHUNK_SIZE = 25
    token_frame_rate = 25
    OVERLAP_SIZE = 0
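
    # Three timed passes over the dataset: pass 0 warms up the model (its
    # speaker cache is cleared afterwards), pass 1 measures latency without
    # the speaker cache, and pass 2 measures latency with the cache populated.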
    warmup_times = 3
    for pass_idx in range(warmup_times):
        start_time = time.time()
        total_forward_count = 0
        for batch in data_loader:
            tts_speech_list = []
            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
            id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
            assert prompt_audio_sample_rate == 16000
            prompt_text = prompt_text_list[0]
            prompt_speech_tokens = prompt_speech_tokens_list[0]
            semantic_token_ids_arr, token_offset = [], 0
            flow_prompt_speech_token_len = len(prompt_speech_tokens)

            buffer = generated_speech_tokens
            output_wavs = []
            chunk_index = 0
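            # Consume the token buffer chunk by chunk. Each intermediate call
            # also receives the flow model's pre-lookahead tokens as context;
            # the final call flushes the remaining tokens with finished=True.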
            while True:
                if args.strategy == "equal":
                    this_chunk_size = CHUNK_SIZE
                elif args.strategy == "exponential":
                    this_chunk_size = token_frame_rate * (2 ** chunk_index)
                if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
                    wavs = token2wav_model.forward_streaming(
                        buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
                        False,
                        request_id=id,
                        speaker_id=f"{id}",
                        prompt_audio=prompt_audio,
                        prompt_audio_sample_rate=prompt_audio_sample_rate,
                    )
                    buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
                    output_wavs.append(wavs)
                    total_forward_count += 1
                    chunk_index += 1
                else:
                    wavs = token2wav_model.forward_streaming(
                        buffer,
                        True,
                        request_id=id,
                        speaker_id=f"{id}",
                        prompt_audio=prompt_audio,
                        prompt_audio_sample_rate=prompt_audio_sample_rate,
                    )
                    output_wavs.append(wavs)
                    total_forward_count += 1
                    break
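
            # Move each chunk's waveform to CPU and stitch the chunks into the
            # full utterance, then write it out at 24 kHz.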
            for i, wav in enumerate(output_wavs):
                output_wavs[i] = wav.cpu().numpy().squeeze()
            reconstructed_audio = np.concatenate(output_wavs)
            sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")

        end_time = time.time()
        if pass_idx == 0:
            token2wav_model.speaker_cache = {}
            print(f"Warmup time: {end_time - start_time} seconds")
            print("Speaker cache cleared")
        elif pass_idx == 1:
            print(f"Elapsed time without speaker cache: {end_time - start_time} seconds")
        else:
            print(f"Elapsed time with speaker cache: {end_time - start_time} seconds")
        print(f"Total flow matching forward calls: {total_forward_count}")