@@ -43,6 +43,7 @@ import torchaudio
 
 from matcha.utils.audio import mel_spectrogram
 
+ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
 
 
@@ -81,6 +82,12 @@ class TritonPythonModel:
         self.flow_pre_lookahead_len = 3
         self.token_hop_len = 15
 
+        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        self.default_spk_info = spk_info["001"]
+
     def forward_llm(self, input_ids):
         """
         Prepares the response from the language model based on the provided
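
For context, a sketch of the layout this code assumes for `spk2info.pt`: a dict keyed by speaker id, where each entry holds at least the transcript of the cached reference clip and its codec speech tokens. Only `prompt_text` and `speech_token` are read later in this file; everything else here, including the shape, is illustrative:

```python
import torch

# Hypothetical export script for spk2info.pt -- a sketch, not the project's
# actual tooling. Shapes and the transcript string are made up.
spk_info = {
    "001": {
        "prompt_text": "Transcript of the cached reference clip.",
        "speech_token": torch.zeros((1, 187), dtype=torch.int32),  # raw codec ids
    },
}
torch.save(spk_info, "spk2info.pt")
```
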
@@ -220,11 +227,11 @@ class TritonPythonModel:
 
     def forward_token2wav(
         self,
-        prompt_speech_tokens: torch.Tensor,
-        prompt_speech_feat: torch.Tensor,
-        prompt_spk_embedding: torch.Tensor,
         target_speech_tokens: torch.Tensor,
         request_id: str,
+        prompt_speech_tokens: torch.Tensor = None,
+        prompt_speech_feat: torch.Tensor = None,
+        prompt_spk_embedding: torch.Tensor = None,
         token_offset: int = None,
         finalize: bool = None) -> torch.Tensor:
         """Forward pass through the vocoder component.
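
With the reordered signature the two required arguments come first and every prompt tensor defaults to `None`, so both call styles stay readable. A minimal sketch, using the variable names from elsewhere in this file:

```python
# Zero-shot voice cloning: pass the prompt tensors extracted from reference_wav.
audio = self.forward_token2wav(
    target_speech_tokens, request_id,
    prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)

# Pre-cached default speaker: omit the prompt tensors entirely.
audio = self.forward_token2wav(target_speech_tokens, request_id)
```
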
@@ -238,12 +245,9 @@ class TritonPythonModel:
         Returns:
             Generated waveform tensor
         """
-        prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
-        prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
-        prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
 
-        inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+        inputs_tensor = [target_speech_tokens_tensor]
 
         if token_offset is not None:
             assert finalize is not None
@@ -252,6 +256,13 @@ class TritonPythonModel:
             inputs_tensor.append(token_offset_tensor)
             inputs_tensor.append(finalize_tensor)
 
+        if prompt_spk_embedding is not None:
+            assert prompt_speech_feat is not None
+            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
+            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
+            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
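
For the BLS request to legally omit these tensors, the `token2wav` model presumably declares them `optional: true` in its config.pbtxt; in Triton's Python backend an omitted optional input then surfaces as `None`. A sketch of the receiving side (the fallback branch is an assumption about token2wav's internals):

```python
# Inside token2wav's execute(): absent optional inputs come back as None.
prompt_feat = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat")
if prompt_feat is None:
    # No prompt tensors in this request: fall back to the pre-cached
    # default-speaker prompt state instead of the per-request one.
    ...
```
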
@@ -318,25 +329,30 @@ class TritonPythonModel:
             request_id = request.request_id()
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
-            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
 
             # Process reference audio through audio tokenizer
-
-            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-
-            wav_tensor = wav.as_numpy()
-            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-            speech_feat = self._extract_speech_feat(prompt_speech_resample)
-            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-
-            flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
-
-            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
-            reference_text = reference_text[0][0].decode('utf-8')
+            if wav is not None:
+                wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+                prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+                prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+                wav_tensor = wav.as_numpy()
+                wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+                prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+                speech_feat = self._extract_speech_feat(prompt_speech_resample)
+                token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+                prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+                prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+                reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+                reference_text = reference_text[0][0].decode('utf-8')
+                prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+            else:
+                # use the pre-cached default speaker info
+                reference_text = self.default_spk_info["prompt_text"]
+                prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+                prompt_speech_feat = None
+                prompt_spk_embedding = None
 
             target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
             target_text = target_text[0][0].decode('utf-8')
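
The `+ ORIGINAL_VOCAB_SIZE` shift exists because `spk2info.pt` stores codec token ids in their raw 0-based range, while the ids flowing through this pipeline place speech tokens after the first 151663 text-vocabulary entries of the LLM; this reading is an inference from the diff rather than a documented contract. The arithmetic itself is trivial (the concrete ids below are made up):

```python
raw_ids = torch.tensor([[5, 42, 907]])     # as stored in spk2info.pt
llm_ids = raw_ids + ORIGINAL_VOCAB_SIZE    # -> [[151668, 151705, 152570]]
codec_ids = llm_ids - ORIGINAL_VOCAB_SIZE  # back to the codec's id space
```
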
@@ -350,7 +366,6 @@ class TritonPythonModel:
 
             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)
-            prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
 
             if self.decoupled:
                 response_sender = request.get_response_sender()
@@ -380,8 +395,9 @@ class TritonPythonModel:
                     this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
 
                     sub_tts_speech = self.forward_token2wav(
-                        prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding,
-                        this_tts_speech_token, request_id, token_offset, False)
+                        this_tts_speech_token, request_id, prompt_speech_tokens,
+                        prompt_speech_feat, prompt_spk_embedding, token_offset, False
+                    )
 
                     audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                     inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
@@ -414,7 +430,7 @@ class TritonPythonModel:
                         time.sleep(0.02)
 
                 this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
+                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
@@ -428,7 +444,7 @@ class TritonPythonModel:
                 if generated_ids is None or len(generated_ids) == 0:
                     raise pb_utils.TritonModelException("Generated IDs is None or empty")
 
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)
+                audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
 
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
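
End to end, the change lets a client skip the reference audio entirely. A hypothetical client call against the offline (non-decoupled) configuration; the model name `cosyvoice2` and the server address are assumptions, while the tensor names come from this file:

```python
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Only target_text is sent; reference_wav is omitted, so the server
# falls back to the "001" entry cached in spk2info.pt.
text = np.array([["Hello from the cached default speaker."]], dtype=object)
inp = httpclient.InferInput("target_text", [1, 1], "BYTES")
inp.set_data_from_numpy(text)

result = client.infer("cosyvoice2", inputs=[inp])
waveform = result.as_numpy("waveform")
```
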