@@ -28,6 +28,8 @@ import json
 import math
 import os
 import re
+import threading
+import time
 from typing import Dict, List, Tuple, Optional, Union

 import numpy as np
@@ -35,13 +37,14 @@ import torch
 from torch.utils.dlpack import from_dlpack, to_dlpack
 import triton_python_backend_utils as pb_utils
 from transformers import AutoTokenizer
-import torchaudio.compliance.kaldi as kaldi
+
 import torchaudio
-import onnxruntime

 from matcha.utils.audio import mel_spectrogram

+torch.set_num_threads(1)
+

 class TritonPythonModel:
     """Triton Python model for Spark TTS.
@@ -62,6 +65,8 @@ class TritonPythonModel:
         parameters = self.model_config['parameters']
         model_params = {k: v["string_value"] for k, v in parameters.items()}
         self.logger.log_info(f"model_params:{model_params}")
+        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
+        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

         # Initialize tokenizer
         llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
@@ -72,11 +77,9 @@
         self.device = torch.device("cuda")
         self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)

-        campplus_model = f'{model_params["model_dir"]}/campplus.onnx'
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+        self.token_frame_rate = 25
+        self.flow_pre_lookahead_len = 3
+        self.token_hop_len = 15

     def forward_llm(self, input_ids):
         """
@@ -105,7 +108,7 @@
         """
         # convert input_ids to numpy, with shape [1, sequence_length]
         input_ids = input_ids.cpu().numpy()
-        max_tokens = 1024
+        max_tokens = 750
         input_dict = {
             "request_output_len": np.array([[max_tokens]], dtype=np.int32),
             "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
@@ -114,6 +117,8 @@
             "runtime_top_p": np.array([[0.95]], dtype=np.float32),
             "runtime_top_k": np.array([[50]], dtype=np.int32),
             "temperature": np.array([[0.8]], dtype=np.float32),
+            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
+            "random_seed": np.array([[42]], dtype=np.uint64),
             "input_ids": input_ids,
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
         }
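Besides lowering max_tokens from 1024 to 750, this pins repetition_penalty and a fixed random_seed, which makes LLM sampling reproducible for a given input. If these knobs ever need to vary per deployment, the same parameters block already read in initialize could carry them. A minimal sketch under that assumption; the parameter keys below are hypothetical, not ones the current config defines:

    import numpy as np

    # Hypothetical parameters block; "temperature" and "repetition_penalty"
    # as config keys are assumptions for illustration only.
    model_params = {"temperature": "0.8", "repetition_penalty": "1.1"}

    sampling_inputs = {
        "temperature": np.array([[float(model_params["temperature"])]], dtype=np.float32),
        "repetition_penalty": np.array([[float(model_params["repetition_penalty"])]], dtype=np.float32),
        "random_seed": np.array([[42]], dtype=np.uint64),  # fixed seed: reproducible sampling
    }
    print({k: v.item() for k, v in sampling_inputs.items()})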
@@ -188,12 +193,40 @@

         return prompt_speech_tokens

+    def forward_speaker_embedding(self, wav):
+        """Forward pass through the speaker embedding component.
+
+        Args:
+            wav: Input waveform tensor
+
+        Returns:
+            Prompt speaker embedding tensor
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='speaker_embedding',
+            requested_output_names=['prompt_spk_embedding'],
+            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
+        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
+
+        return prompt_spk_embedding
+
     def forward_token2wav(
         self,
         prompt_speech_tokens: torch.Tensor,
         prompt_speech_feat: torch.Tensor,
         prompt_spk_embedding: torch.Tensor,
-        target_speech_tokens: torch.Tensor) -> torch.Tensor:
+        target_speech_tokens: torch.Tensor,
+        request_id: str,
+        token_offset: Optional[int] = None,
+        finalize: Optional[bool] = None) -> torch.Tensor:
         """Forward pass through the vocoder component.

         Args:
@@ -210,11 +243,21 @@
         prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))

+        inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+
+        if token_offset is not None:
+            assert finalize is not None
+            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
+            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+            inputs_tensor.append(token_offset_tensor)
+            inputs_tensor.append(finalize_tensor)
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
             requested_output_names=['waveform'],
-            inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+            inputs=inputs_tensor,
+            request_id=request_id,
         )

         inference_response = inference_request.exec()
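With this change forward_token2wav supports two calling conventions: offline, where token_offset and finalize are omitted and a single call vocodes the whole token sequence, and streaming, where token_offset tells token2wav how many tokens of this request_id were already vocoded and finalize marks the last chunk. A sketch of the call sequence the streaming loop below produces, with a stand-in synth() in place of the real token2wav model (chunk sizes follow the exponential schedule):

    def synth(tokens, request_id, token_offset=None, finalize=None):
        # Stand-in for forward_token2wav; prints the mode instead of vocoding.
        mode = "offline" if token_offset is None else ("final" if finalize else "chunk")
        print(f"{request_id}: {mode:7s} offset={token_offset} n_tokens={len(tokens)}")

    tokens = list(range(48))
    synth(tokens, "req-0")                                        # offline: one call
    synth(tokens[:18], "req-0", token_offset=0, finalize=False)   # chunk 1: hop 15 + 3 lookahead
    synth(tokens[:43], "req-0", token_offset=15, finalize=False)  # chunk 2: hop grew to 25
    synth(tokens, "req-0", token_offset=40, finalize=True)        # flush the remaining tail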
@@ -235,17 +278,6 @@
         input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
         return input_ids

-    def _extract_spk_embedding(self, speech):
-        feat = kaldi.fbank(speech,
-                           num_mel_bins=80,
-                           dither=0,
-                           sample_frequency=16000)
-        feat = feat - feat.mean(dim=0, keepdim=True)
-        embedding = self.campplus_session.run(None,
-                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-        embedding = torch.tensor([embedding]).to(self.device).half()
-        return embedding
-
     def _extract_speech_feat(self, speech):
         speech_feat = mel_spectrogram(
             speech,
@@ -263,6 +295,14 @@
         speech_feat = speech_feat.unsqueeze(dim=0)
         return speech_feat

+    def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
+        for generated_ids in generated_ids_iter:
+            generated_ids = generated_ids.tolist()
+            if len(generated_ids) == 0:
+                break
+            semantic_token_ids_arr.extend(generated_ids)
+        llm_is_done_flag[0] = True
+
     def execute(self, requests):
         """Execute inference on the batched requests.
@@ -275,6 +315,7 @@
         responses = []

         for request in requests:
+            request_id = request.request_id()
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
             wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
@@ -292,6 +333,8 @@
             prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
             prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

+            flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
+
             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')

@@ -307,25 +350,76 @@

             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)
+            prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)

             if self.decoupled:
                 response_sender = request.get_response_sender()
-                request_id = request.request_id()
-                generated_ids = []
-                for generated_id in generated_ids_iter:
-                    # convert the numpy array into a int32 tensor
-                    generated_id = generated_id.tolist()
-                    if len(generated_id) > 0:
-                        assert len(generated_id) == 1, "Generated ID is not a single integer"
-                        generated_ids.append(generated_id[0])
-                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)

-                # Prepare response
-                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+                semantic_token_ids_arr = []
+                llm_is_done_flag = [False]
+
+                llm_thread = threading.Thread(
+                    target=self._llm_gen_thread,
+                    args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
+                )
+
+                llm_thread.start()
+
+                token_offset, chunk_index = 0, 0
+                start_time = time.time()
+                this_token_hop_len = self.token_hop_len
+
+                while True:
+                    pending_num = len(semantic_token_ids_arr) - token_offset
+
+                    if llm_is_done_flag[0]:
+                        break
+
+                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
+                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
+                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
+
+                        sub_tts_speech = self.forward_token2wav(
+                            prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding,
+                            this_tts_speech_token, request_id, token_offset, False)
+
+                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                        token_offset += this_token_hop_len
+                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
+
+                        if self.dynamic_chunk_strategy == "exponential":
+                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "time_based":
+                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
+                            cost_time = time.time() - start_time
+                            duration = token_offset / self.token_frame_rate
+                            if chunk_index > 0 and cost_time > 0:
+                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
+                                if avg_chunk_processing_time > 0:
+                                    multiples = (duration - cost_time) / avg_chunk_processing_time
+                                    self.logger.log_info(f"multiples: {multiples}")
+                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
+                                    if multiples > 4:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
+                                    elif multiples > 2:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
+                                    else:
+                                        this_token_hop_len = self.token_hop_len
+                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
+                        chunk_index += 1
+                    else:
+                        time.sleep(0.02)
+
+                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
+                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
+
+                llm_thread.join()
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                 self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
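The two strategies trade first-chunk latency against per-call overhead. "exponential" doubles the hop after each chunk (15, 25, 50, 100, ... tokens), so playback starts quickly and later chunks amortize the flow-model cost; "time_based" compares the audio duration already scheduled against elapsed wall-clock time and widens the hop only when synthesis is comfortably ahead of real time. A standalone sketch of the exponential schedule (it ignores token availability, which the real loop also checks):

    # How the "exponential" strategy grows the hop; constants mirror the patch.
    token_frame_rate, token_hop_len = 25, 15

    hop, chunks, token_offset = token_hop_len, [], 0
    for chunk_index in range(5):
        chunks.append(hop)
        token_offset += hop
        hop = token_frame_rate * (2 ** chunk_index)
    print(chunks)                                                    # [15, 25, 50, 100, 200]
    print(token_offset / token_frame_rate, "s of audio scheduled")   # 15.6 s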
@@ -334,8 +428,7 @@
                 if generated_ids is None or len(generated_ids) == 0:
                     raise pb_utils.TritonModelException("Generated IDs is None or empty")

-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
+                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)

                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
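End to end, a client opens a gRPC stream against this decoupled model and receives waveform chunks until the final flag closes the stream. A hedged client sketch with tritonclient: the tensor names reference_wav, reference_wav_len, and reference_text come from this file, while the model name spark_tts, the target_text input, the endpoint, and the FP32 audio layout are illustrative assumptions:

    import queue

    import numpy as np
    import tritonclient.grpc as grpcclient

    audio = np.zeros((1, 16000), dtype=np.float32)  # stand-in 1 s reference wav
    chunks = queue.Queue()

    def on_response(result, error):
        # Each callback delivers one waveform chunk (or an error) from the stream.
        chunks.put(error if error else result)

    client = grpcclient.InferenceServerClient("localhost:8001")
    inputs = [
        grpcclient.InferInput("reference_wav", list(audio.shape), "FP32"),
        grpcclient.InferInput("reference_wav_len", [1, 1], "INT32"),
        grpcclient.InferInput("reference_text", [1, 1], "BYTES"),
        grpcclient.InferInput("target_text", [1, 1], "BYTES"),  # assumed input name
    ]
    inputs[0].set_data_from_numpy(audio)
    inputs[1].set_data_from_numpy(np.array([[audio.shape[1]]], dtype=np.int32))
    inputs[2].set_data_from_numpy(np.array([[b"reference transcript"]], dtype=object))
    inputs[3].set_data_from_numpy(np.array([[b"text to synthesize"]], dtype=object))

    client.start_stream(callback=on_response)
    client.async_stream_infer("spark_tts", inputs, request_id="req-0")
    # ... drain chunks until the final-flag response arrives, then:
    client.stop_stream()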