@@ -0,0 +1,492 @@
+import json
+import re
+import time
+import asyncio
+
+import numpy as np
+import torch
+from torch.utils.dlpack import to_dlpack
+import triton_python_backend_utils as pb_utils
+
+import httpx
+import torchaudio
+from functools import partial
+from matcha.utils.audio import mel_spectrogram as matcha_mel_spectrogram
+
+
+torch.set_num_threads(1)
+
+# CosyVoice3 mel params: fmax=None (Nyquist), center=False
+mel_spectrogram = partial(matcha_mel_spectrogram,
+                          n_fft=1920, num_mels=80, sampling_rate=24000,
+                          hop_size=480, win_size=1920, fmin=0, fmax=None, center=False)
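+# At 24 kHz with hop_size=480 this gives 50 mel frames per second; with speech tokens at
+# 25 Hz that is 2 mel frames per token (matching self.token_mel_ratio = 2 in initialize()).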
+
+
+def parse_speech_token_string(response_text):
+    """Parse speech tokens from string like '<|s_123|><|s_456|>' into list of int IDs."""
+    speech_tokens = response_text.strip().split('><')
+    if len(speech_tokens) > 1:
+        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
+        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
+    speech_ids = []
+    for token_str in speech_tokens:
+        match = re.match(r'<\|s_(\d+)\|>', token_str)
+        if match:
+            speech_ids.append(int(match.group(1)))
+    return speech_ids
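+# Example: parse_speech_token_string('<|s_12|><|s_345|>') -> [12, 345]. A single token such
+# as '<|s_7|>' contains no '><' separator, so the re-bracketing step is skipped and the
+# regex still matches it directly.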
+
+
+class TritonPythonModel:
+    """CosyVoice3 BLS orchestrator for Triton Inference Server.
+
+    Orchestrates: audio_tokenizer, speaker_embedding, remote LLM (httpx),
+    token2wav (flow-only), and vocoder (CausalHiFTGenerator).
+    Supports both streaming (decoupled) and offline (non-decoupled) modes.
+    """
+
+    def initialize(self, args):
+        self.logger = pb_utils.Logger
+        self.model_config = json.loads(args['model_config'])
+        parameters = self.model_config['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+
+        self.device = torch.device("cuda")
+        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
+
+        # Streaming config
+        self.token_frame_rate = 25
+        self.flow_pre_lookahead_len = 3
+        self.token_hop_len = 15
+        self.token_mel_ratio = 2
+        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")
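+        # The first streaming chunk waits for token_hop_len + flow_pre_lookahead_len = 18
+        # tokens (~0.72 s of audio at 25 tokens/s) before the first flow/vocoder call.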
+        self.logger.log_info(f"CosyVoice3 BLS initialized, decoupled={self.decoupled}, "
+                             f"chunk_strategy={self.dynamic_chunk_strategy}")
+
+        # HTTP client for remote LLM (trtllm-serve default port: 8000)
+        self.http_client = httpx.AsyncClient()
+        self.api_base = model_params.get("llm_api_base", "http://localhost:8000/v1/chat/completions")
+
+        # Speaker cache to avoid redundant audio_tokenizer/speaker_embedding calls
+        self.speaker_cache = {}
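+        # Note: entries are never evicted; a long-running server that sees many distinct
+        # reference prompts may want to bound this cache (e.g. with an LRU policy).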
+
+    def _convert_speech_tokens_to_str(self, speech_tokens):
+        """Convert speech token IDs tensor/list to string like '<|s_N|>'."""
+        if isinstance(speech_tokens, torch.Tensor):
+            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
+        return "".join(f"<|s_{int(tid)}|>" for tid in speech_tokens)
+
+    def _extract_speech_feat(self, speech):
+        """Extract mel spectrogram from 24kHz speech for flow prompt."""
+        speech_feat = mel_spectrogram(speech).squeeze(dim=0).transpose(0, 1)
+        speech_feat = speech_feat.unsqueeze(dim=0).to(self.device)
+        return speech_feat
+
+    async def forward_llm_streaming(self, target_text, reference_text, prompt_speech_tokens):
+        """Async generator: stream LLM tokens via httpx SSE."""
+        full_text = f"{reference_text}{target_text}"
+        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
+
+        chat = [
+            {"role": "user", "content": full_text},
+            {"role": "assistant", "content": prompt_speech_tokens_str}
+        ]
+        payload = {
+            "model": "trt_engines_bfloat16",
+            "messages": chat,
+            "max_tokens": 750,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "stop": ["<|eos1|>", "<|eos|>"],
+            "stream": True,
+        }
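+        # The endpoint is assumed OpenAI-compatible (e.g. trtllm-serve); in streaming mode each
+        # SSE line is expected to look like: data: {"choices": [{"delta": {"content": "<|s_123|>"}}]}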
+
+        buffer = ""
+        async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    line_data = line[len("data: "):].strip()
+                    if line_data == "[DONE]":
+                        break
+                    try:
+                        json_data = json.loads(line_data)
+                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
+                        if content:
+                            buffer += content
+                            while True:
+                                match = re.search(r"<\|s_(\d+)\|>", buffer)
+                                if not match:
+                                    break
+                                token_num = int(match.group(1))
+                                # final_id = token_num + ORIGINAL_VOCAB_SIZE
+                                yield token_num
+                                buffer = buffer[match.end():]
+                    except json.JSONDecodeError:
+                        continue
+
+        # Flush remaining tokens
+        while True:
+            match = re.search(r"<\|s_(\d+)\|>", buffer)
+            if not match:
+                break
+            token_num = int(match.group(1))
+            # final_id = token_num + ORIGINAL_VOCAB_SIZE
+            yield token_num
+            buffer = buffer[match.end():]
+
+    async def forward_llm_offline(self, target_text, reference_text, prompt_speech_tokens):
+        """Non-streaming LLM call, returns all speech token IDs at once."""
+        full_text = f"{reference_text}{target_text}"
+        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
+
+        chat = [
+            {"role": "user", "content": full_text},
+            {"role": "assistant", "content": prompt_speech_tokens_str}
+        ]
+        payload = {
+            "model": "trt_engines_bfloat16",
+            "messages": chat,
+            "max_tokens": 750,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "stop": ["<|eos1|>", "<|eos|>"],
+            "stream": False,
+        }
+        response = await self.http_client.post(self.api_base, json=payload, timeout=None)
+        response.raise_for_status()
+        response_json = response.json()
+        generated_content = response_json['choices'][0]['message']['content']
+        speech_ids = parse_speech_token_string(generated_content)
+        # return [sid + ORIGINAL_VOCAB_SIZE for sid in speech_ids]
+        return speech_ids
+
+    def forward_audio_tokenizer(self, wav, wav_len):
+        """BLS call to audio_tokenizer."""
+        inference_request = pb_utils.InferenceRequest(
+            model_name='audio_tokenizer',
+            requested_output_names=['prompt_speech_tokens'],
+            inputs=[wav, wav_len]
+        )
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(
+            inference_response, 'prompt_speech_tokens')
+        return torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
+
+    def forward_speaker_embedding(self, wav):
+        """BLS call to speaker_embedding."""
+        inference_request = pb_utils.InferenceRequest(
+            model_name='speaker_embedding',
+            requested_output_names=['prompt_spk_embedding'],
+            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
+        )
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(
+            inference_response, 'prompt_spk_embedding')
+        return torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
+
+    async def forward_token2wav(self, target_speech_tokens, prompt_speech_tokens,
+                                prompt_speech_feat, prompt_spk_embedding,
+                                request_id, token_offset=None, finalize=True,
+                                priority=100):
+        """Async BLS call to token2wav (flow-only). Returns mel tensor."""
+        target_tokens_pb = pb_utils.Tensor.from_dlpack(
+            "target_speech_tokens", to_dlpack(target_speech_tokens))
+        prompt_tokens_pb = pb_utils.Tensor.from_dlpack(
+            "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+        prompt_feat_pb = pb_utils.Tensor.from_dlpack(
+            "prompt_speech_feat", to_dlpack(prompt_speech_feat))
+        prompt_emb_pb = pb_utils.Tensor.from_dlpack(
+            "prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
+
+        inputs = [target_tokens_pb, prompt_tokens_pb, prompt_feat_pb, prompt_emb_pb]
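+        # token_offset/finalize are only attached for streaming chunks; token2wav is assumed
+        # to return just the mel frames produced past token_offset, which the streaming
+        # caller accumulates (see _process_request_streaming below).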
+
+        if token_offset is not None:
+            inputs.append(pb_utils.Tensor("token_offset",
+                                          np.array([[token_offset]], dtype=np.int32)))
+            inputs.append(pb_utils.Tensor("finalize",
+                                          np.array([[finalize]], dtype=np.bool_)))
+
+        inference_request = pb_utils.InferenceRequest(
+            model_name='token2wav',
+            requested_output_names=['mel'],
+            inputs=inputs,
+            request_id=request_id,
+            parameters={"priority": priority},
+        )
+
+        inference_response = await inference_request.async_exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        mel = pb_utils.get_output_tensor_by_name(inference_response, 'mel')
+        return torch.utils.dlpack.from_dlpack(mel.to_dlpack())
+
+    async def forward_vocoder(self, mel, finalize):
+        """Async BLS call to vocoder. Returns speech tensor."""
+        if mel.dim() == 2:
+            mel = mel.unsqueeze(0)  # [80, T] -> [1, 80, T]
+        mel_pb = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel.float()))
+        finalize_pb = pb_utils.Tensor("finalize",
+                                      np.array([[finalize]], dtype=np.bool_))
+
+        inference_request = pb_utils.InferenceRequest(
+            model_name='vocoder',
+            requested_output_names=['tts_speech'],
+            inputs=[mel_pb, finalize_pb],
+        )
+
+        inference_response = await inference_request.async_exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        speech = pb_utils.get_output_tensor_by_name(inference_response, 'tts_speech')
+        return torch.utils.dlpack.from_dlpack(speech.to_dlpack()).cpu()
+
+    def _prepare_prompt(self, request):
+        """Extract reference audio, tokenize, compute speaker embedding and mel feat."""
+        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+
+        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text")
+        reference_text = reference_text.as_numpy()[0][0].decode('utf-8') if reference_text is not None else ""
+        if '<|endofprompt|>' not in reference_text:
+            reference_text = 'You are a helpful assistant.<|endofprompt|>' + reference_text
+
+        # Check speaker cache
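+        # A hit assumes reference_text uniquely identifies the reference clip: a different
+        # audio sent with identical text would silently reuse the first clip's features.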
+        if reference_text in self.speaker_cache:
+            cached = self.speaker_cache[reference_text]
+            return (cached['prompt_speech_tokens_for_llm'], cached['prompt_speech_tokens'],
+                    cached['prompt_speech_feat'], cached['prompt_spk_embedding'], reference_text)
+
+        # Audio tokenizer
+        wav_np = wav.as_numpy()
+        wav_len_val = wav_len.as_numpy()[0][0]
+        prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+        prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)  # [1, T]
+
+        # Speaker embedding
+        wav_tensor = torch.from_numpy(wav_np)
+        wav_tensor = wav_tensor[:, :wav_len_val]
+        prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+
+        # Mel extraction at 24kHz with CosyVoice3 params
+        prompt_speech_resample = torchaudio.transforms.Resample(
+            orig_freq=16000, new_freq=24000)(wav_tensor)
+        speech_feat = self._extract_speech_feat(prompt_speech_resample)
+
+        # Keep full tokens for LLM prefill (untruncated)
+        prompt_speech_tokens_for_llm = prompt_speech_tokens.clone()
+
+        # Align prompt speech feat and tokens to 2:1 ratio (for flow model only)
+        orig_feat_len = speech_feat.shape[1]
+        orig_token_len = prompt_speech_tokens.shape[-1]
+        token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+        prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+        prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
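+        # Example: a ~3 s reference gives roughly 150 mel frames and 75 speech tokens, so
+        # token_len = 75 and the feat is trimmed to 150 frames (2 mel frames per token).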
+
+        # Cache
+        self.speaker_cache[reference_text] = {
+            'prompt_speech_tokens_for_llm': prompt_speech_tokens_for_llm,
+            'prompt_speech_tokens': prompt_speech_tokens,
+            'prompt_speech_feat': prompt_speech_feat,
+            'prompt_spk_embedding': prompt_spk_embedding,
+        }
+
+        return prompt_speech_tokens_for_llm, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, reference_text
+
+    async def _process_request_streaming(self, request):
+        """Process a single request in streaming (decoupled) mode."""
+        request_id = request.request_id()
+        response_sender = request.get_response_sender()
+
+        try:
+            prompt_speech_tokens_for_llm, prompt_speech_tokens, prompt_speech_feat, \
+                prompt_spk_embedding, reference_text = self._prepare_prompt(request)
+
+            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+            target_text = target_text[0][0].decode('utf-8')
+
+            semantic_token_ids_arr = []
+            token_offset = 0
+            chunk_index = 0
+            this_token_hop_len = self.token_hop_len
+            accumulated_mel = None
+            speech_offset = 0
+            start_time = time.time()
+
+            async for generated_id in self.forward_llm_streaming(
+                target_text=target_text,
+                reference_text=reference_text,
+                prompt_speech_tokens=prompt_speech_tokens_for_llm,
+            ):
+                semantic_token_ids_arr.append(generated_id)
+
+                while True:
+                    pending_num = len(semantic_token_ids_arr) - token_offset
+                    if pending_num < this_token_hop_len + self.flow_pre_lookahead_len:
+                        break
+
+                    # Prepare tokens for this chunk
+                    end_idx = token_offset + this_token_hop_len + self.flow_pre_lookahead_len
+                    this_tokens = torch.tensor(
+                        semantic_token_ids_arr[:end_idx]
+                    ).unsqueeze(0).to(torch.int32).to(self.device)
+
+                    # Call token2wav (flow-only) -> mel_chunk
+                    mel_chunk = await self.forward_token2wav(
+                        this_tokens, prompt_speech_tokens,
+                        prompt_speech_feat, prompt_spk_embedding,
+                        request_id, token_offset=token_offset, finalize=False,
+                        priority=chunk_index + 1,
+                    )
+
+                    # Accumulate mel
+                    if mel_chunk.dim() == 2:
+                        mel_chunk = mel_chunk.unsqueeze(0)
+                    if accumulated_mel is None:
+                        accumulated_mel = mel_chunk
+                    else:
+                        accumulated_mel = torch.cat([accumulated_mel, mel_chunk], dim=2)
+
+                    # Call vocoder
+                    speech = await self.forward_vocoder(accumulated_mel, finalize=False)
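+                    # forward_vocoder is assumed to synthesize a waveform for the full mel it
+                    # receives; speech_offset below drops samples already sent in earlier chunks.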
+
+                    # Extract new speech
+                    new_speech = speech[:, speech_offset:]
+                    speech_offset += new_speech.shape[1]
+
+                    if new_speech.shape[1] > 0:
+                        audio_tensor = pb_utils.Tensor.from_dlpack(
+                            "waveform", to_dlpack(new_speech))
+                        inference_response = pb_utils.InferenceResponse(
+                            output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                    token_offset += this_token_hop_len
+
+                    # Dynamic chunk strategy
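+                    # With the default "exponential" strategy the hop grows as 15, 25, 50, 100,
+                    # ... tokens (~0.6 s, 1 s, 2 s, 4 s at 25 tokens/s): a small first chunk for
+                    # low latency, then larger chunks to reduce flow/vocoder calls.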
+                    if self.dynamic_chunk_strategy == "exponential":
+                        this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                    elif self.dynamic_chunk_strategy == "time_based":
+                        cost_time = time.time() - start_time
+                        duration = token_offset / self.token_frame_rate
+                        if chunk_index > 0 and cost_time > 0:
+                            avg_chunk_time = cost_time / (chunk_index + 1)
+                            if avg_chunk_time > 0:
+                                multiples = (duration - cost_time) / avg_chunk_time
+                                next_pending = len(semantic_token_ids_arr) - token_offset
+                                if multiples > 4:
+                                    this_token_hop_len = (next_pending // self.token_hop_len + 1) * self.token_hop_len
+                                elif multiples > 2:
+                                    this_token_hop_len = (next_pending // self.token_hop_len) * self.token_hop_len
+                                else:
+                                    this_token_hop_len = self.token_hop_len
+                                this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
+
+                    chunk_index += 1
+
+            # Final chunk with remaining tokens
+            if len(semantic_token_ids_arr) > 0:
+                remaining_tokens = torch.tensor(
+                    semantic_token_ids_arr
+                ).unsqueeze(0).to(torch.int32).to(self.device)
+
+                mel_chunk = await self.forward_token2wav(
+                    remaining_tokens, prompt_speech_tokens,
+                    prompt_speech_feat, prompt_spk_embedding,
+                    request_id, token_offset=token_offset, finalize=True,
+                    priority=chunk_index + 1,
+                )
+
+                if mel_chunk.dim() == 2:
+                    mel_chunk = mel_chunk.unsqueeze(0)
+                if accumulated_mel is None:
+                    accumulated_mel = mel_chunk
+                else:
+                    accumulated_mel = torch.cat([accumulated_mel, mel_chunk], dim=2)
+
+                speech = await self.forward_vocoder(accumulated_mel, finalize=True)
+
+                new_speech = speech[:, speech_offset:]
+                if new_speech.shape[1] > 0:
+                    audio_tensor = pb_utils.Tensor.from_dlpack(
+                        "waveform", to_dlpack(new_speech))
+                    inference_response = pb_utils.InferenceResponse(
+                        output_tensors=[audio_tensor])
+                    response_sender.send(inference_response)
+
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+        except Exception as e:
+            self.logger.log_error(f"Error in streaming request: {e}")
+            error_response = pb_utils.InferenceResponse(
+                error=pb_utils.TritonError(str(e)))
+            response_sender.send(error_response)
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+
+    async def _process_request_offline(self, request):
+        """Process a single request in offline (non-decoupled) mode."""
+        request_id = request.request_id()
+
+        prompt_speech_tokens_for_llm, prompt_speech_tokens, prompt_speech_feat, \
+            prompt_spk_embedding, reference_text = self._prepare_prompt(request)
+
+        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+        target_text = target_text[0][0].decode('utf-8')
+
+        # Get all speech tokens at once (use full untruncated prompt tokens for LLM)
+        all_token_ids = await self.forward_llm_offline(
+            target_text=target_text,
+            reference_text=reference_text,
+            prompt_speech_tokens=prompt_speech_tokens_for_llm,
+        )
+
+        if len(all_token_ids) == 0:
+            raise pb_utils.TritonModelException("LLM generated no speech tokens")
+
+        all_tokens = torch.tensor(all_token_ids).unsqueeze(0).to(torch.int32).to(self.device)
+
+        # token2wav (no token_offset, finalize=True) -> full mel
+        mel = await self.forward_token2wav(
+            all_tokens, prompt_speech_tokens,
+            prompt_speech_feat, prompt_spk_embedding,
+            request_id,
+        )
+
+        # vocoder -> full speech
+        speech = await self.forward_vocoder(mel, finalize=True)
+
+        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(speech))
+        return pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+
+    async def execute(self, requests):
+        if self.decoupled:
+            tasks = [
+                asyncio.create_task(self._process_request_streaming(request))
+                for request in requests
+            ]
+            await asyncio.gather(*tasks)
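+            # Decoupled responses are delivered via each request's response sender above, so
+            # execute() must return None in this branch (Triton Python backend convention).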
+            return None
+        else:
+            responses = []
+            for request in requests:
+                try:
+                    response = await self._process_request_offline(request)
+                    responses.append(response)
+                except Exception as e:
+                    self.logger.log_error(f"Error in offline request: {e}")
+                    responses.append(pb_utils.InferenceResponse(
+                        error=pb_utils.TritonError(str(e))))
+            return responses
+
+    def finalize(self):
+        self.logger.log_info("Finalizing CosyVoice3 BLS model")
+        if hasattr(self, "http_client"):
+            asyncio.run(self.http_client.aclose())