|
|
@@ -103,6 +103,7 @@ class TritonPythonModel:
|
|
|
|
|
|
self.http_client = httpx.AsyncClient()
|
|
|
self.api_base = "http://localhost:8000/v1/chat/completions"
|
|
|
+ self.speaker_cache = {}
|
|
|
|
|
|
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
|
|
|
"""Converts a tensor or list of speech token IDs to a string representation."""
|
|
|
@@ -240,10 +241,12 @@ class TritonPythonModel:
|
|
|
"""Forward pass through the vocoder component.
|
|
|
|
|
|
Args:
|
|
|
- prompt_speech_tokens: Prompt speech tokens tensor
|
|
|
- prompt_speech_feat: Prompt speech feat tensor
|
|
|
- prompt_spk_embedding: Prompt spk embedding tensor
|
|
|
+ index: Index of the request
|
|
|
target_speech_tokens: Target speech tokens tensor
|
|
|
+ request_id: Request ID
|
|
|
+ reference_wav: Reference waveform tensor
|
|
|
+ reference_wav_len: Reference waveform length tensor
|
|
|
+ finalize: Whether to finalize the request
|
|
|
|
|
|
Returns:
|
|
|
Generated waveform tensor
|
|
|
@@ -292,25 +295,16 @@ class TritonPythonModel:
|
|
|
|
|
|
async def _process_request(self, request):
|
|
|
request_id = request.request_id()
|
|
|
- # Extract input tensors
|
|
|
- wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
|
|
|
|
|
- # Process reference audio through audio tokenizer
|
|
|
+ reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
|
|
+ reference_text = reference_text[0][0].decode('utf-8')
|
|
|
|
|
|
+ wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
|
|
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
|
|
- prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
|
|
- prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
|
|
|
|
|
- wav_tensor = wav.as_numpy()
|
|
|
- wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
|
|
- prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
|
|
- speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
|
|
- token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
|
|
- prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
|
|
- prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
|
|
-
|
|
|
- reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
|
|
- reference_text = reference_text[0][0].decode('utf-8')
|
|
|
+ if reference_text not in self.speaker_cache:
|
|
|
+ self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
|
|
|
+ prompt_speech_tokens = self.speaker_cache[reference_text]
|
|
|
|
|
|
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
|
|
target_text = target_text[0][0].decode('utf-8')
|