@@ -70,6 +70,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
         f.write(engine_bytes)
     logging.info("Succesfully convert onnx to trt...")
 
+
 class TrtContextWrapper:
     def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
         self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
@@ -88,12 +89,13 @@ class TrtContextWrapper:
     def release_estimator(self, context, stream):
         self.trt_context_pool.put([context, stream])
 
+
 class CosyVoice2_Token2Wav(torch.nn.Module):
     def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
         super().__init__()
         self.device_id = device_id
         self.device = f"cuda:{device_id}"
-
+
         self.flow = CausalMaskedDiffWithXvec()
         self.flow.half()
         self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
@@ -107,22 +109,20 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
-        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
-                                                      providers=["CPUExecutionProvider"])
-
+        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
+
         self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()
 
-        gpu="l20"
+        gpu = "l20"
         if enable_trt:
             self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
-                          f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
-                          1,
-                          True)
+                          f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                          1,
+                          True)
             self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
-                              f'{model_dir}/campplus.onnx',
-                              1,
-                              False)
-
+                              f'{model_dir}/campplus.onnx',
+                              1,
+                              False)
 
     def forward_spk_embedding(self, spk_feat):
         if isinstance(self.spk_model, onnxruntime.InferenceSession):
@@ -173,7 +173,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
         if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
-            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=2, max_batch_size=16)
+            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
             convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
         del self.flow.decoder.estimator
         import tensorrt as trt
@@ -182,10 +182,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
         self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
 
-    def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64):
+    def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
-        opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
-        max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+        opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
+        max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
+                     (max_batch_size * 2, 80)]
         input_names = ["x", "mask", "mu", "cond", "t", "spks"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
@@ -203,7 +204,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
             prompt_speech_tokens_list.append(speech_tokens_i)
         return prompt_speech_tokens_list
-
+
     def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
         spk_emb_for_flow = []
         for audio in prompt_audios_list:
@@ -213,9 +214,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             spk_emb = self.forward_spk_embedding(spk_feat)
 
             spk_emb_for_flow.append(spk_emb)
-        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
+        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
         return spk_emb_for_flow
-
+
     def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
         prompt_mels_for_flow = []
         prompt_mels_lens_for_flow = []
@@ -231,9 +232,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
         prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
         return prompt_mels_for_flow, prompt_mels_lens_for_flow
-
 
-    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
+                     prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
         batch_size = prompt_mels_for_flow.shape[0]
         flow_inputs = []
         flow_inputs_lens = []
@@ -262,14 +263,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             generated_wavs.append(wav)
         return generated_wavs
 
-
     @torch.inference_mode()
     def forward(
         self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
     ):
         # assert all item in prompt_audios_sample_rate is 16000
         assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
-
 
         prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
 
@@ -277,10 +276,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
 
-        generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+        generated_mels, generated_mels_lens = self.forward_flow(
+            prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
 
         generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
-
+
         return generated_wavs
 
 
@@ -288,13 +288,14 @@ def collate_fn(batch):
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    for i, item in enumerate(batch):
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
-        audio = torch.from_numpy(item['prompt_audio']['array']).float()
+        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
 
    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
 
+
 def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-trt", action="store_true")
@@ -305,6 +306,7 @@ def get_args():
    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
    return parser.parse_args()
 
+
 if __name__ == "__main__":
    args = get_args()
    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
@@ -315,22 +317,19 @@ if __name__ == "__main__":
 
     dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
 
-
     data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
-
-
+
     for epoch in range(args.warmup):
         start_time = time.time()
-
+
         for batch in data_loader:
             ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
 
             generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
-
 
             for id, wav in zip(ids, generated_wavs):
                 torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
-
+
         end_time = time.time()
         epoch_time = end_time - start_time
-        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
+        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")