|
|
@@ -199,8 +199,6 @@ class TritonPythonModel:
|
|
|
Returns:
|
|
|
Generated waveform tensor
|
|
|
"""
|
|
|
- print(prompt_speech_tokens.shape, prompt_speech_feat.shape, prompt_spk_embedding.shape, target_speech_tokens.shape)
|
|
|
- # Convert tensors to Triton format
|
|
|
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
|
|
|
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
|
|
|
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
|
|
|
@@ -228,9 +226,7 @@ class TritonPythonModel:
|
|
|
prompt = self.prompt_template.format(input_text=total_text)
|
|
|
input_ids = self.tokenizer.encode(prompt)
|
|
|
input_ids = torch.tensor([input_ids], dtype=torch.int32)
|
|
|
- print(input_ids.shape, "before cat")
|
|
|
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
|
|
|
- print(input_ids.shape, "after cat", prompt_speech_tokens.shape)
|
|
|
return input_ids
|
|
|
|
|
|
def _extract_spk_embedding(self, speech):
|
|
|
@@ -271,23 +267,15 @@ class TritonPythonModel:
|
|
|
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
|
|
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
|
|
|
|
|
- # TODO: FIX ME
|
|
|
+
|
|
|
wav_tensor = wav.as_numpy()
|
|
|
- print(wav_tensor.shape, "wav_tensor")
|
|
|
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
|
|
- print(wav_tensor.shape, "wav_tensor after")
|
|
|
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
|
|
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
|
|
- print(speech_feat.shape, "speech_feat")
|
|
|
- print(prompt_speech_tokens.shape, "prompt_speech_tokens here")
|
|
|
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
|
|
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
|
|
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
|
|
- print(prompt_speech_tokens.shape, "prompt_speech_tokens after")
|
|
|
- print(speech_feat.shape, "speech_feat after")
|
|
|
- print(token_len, "token_len")
|
|
|
|
|
|
- # Extract text inputs
|
|
|
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
|
|
reference_text = reference_text[0][0].decode('utf-8')
|
|
|
|