|
|
@@ -28,9 +28,10 @@ import json
|
|
|
import math
|
|
|
import os
|
|
|
import re
|
|
|
-import threading
|
|
|
import time
|
|
|
from typing import Dict, List, Tuple, Optional, Union
|
|
|
+import asyncio
|
|
|
+import httpx
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
@@ -42,11 +43,30 @@ import torchaudio
|
|
|
|
|
|
|
|
|
from matcha.utils.audio import mel_spectrogram
|
|
|
+from datetime import datetime
|
|
|
|
|
|
ORIGINAL_VOCAB_SIZE = 151663
|
|
|
torch.set_num_threads(1)
|
|
|
|
|
|
|
|
|
+def parse_speech_token_string(response_text: str) -> List[int]:
|
|
|
+ """
|
|
|
+ Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
|
|
|
+ """
|
|
|
+ speech_tokens = response_text.strip().split('><')
|
|
|
+ if len(speech_tokens) > 1:
|
|
|
+ # Add back the missing '<' and '>' for proper parsing
|
|
|
+ speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
|
|
|
+ speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
|
|
|
+
|
|
|
+ speech_ids = []
|
|
|
+ for token_str in speech_tokens:
|
|
|
+ match = re.match(r'<\|s_(\d+)\|>', token_str)
|
|
|
+ if match:
|
|
|
+ speech_ids.append(int(match.group(1)))
|
|
|
+ return speech_ids
|
|
|
+
|
|
|
+
|
|
|
class TritonPythonModel:
|
|
|
"""Triton Python model for Spark TTS.
|
|
|
|
|
|
@@ -67,6 +87,7 @@ class TritonPythonModel:
|
|
|
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
|
|
self.logger.log_info(f"model_params:{model_params}")
|
|
|
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
|
|
|
+ # self.dynamic_chunk_strategy = "equal"
|
|
|
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
|
|
|
|
|
|
# Initialize tokenizer
|
|
|
@@ -87,92 +108,86 @@ class TritonPythonModel:
|
|
|
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
|
|
|
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
|
|
self.default_spk_info = spk_info["001"]
|
|
|
-
|
|
|
- def forward_llm(self, input_ids):
|
|
|
+ self.http_client = httpx.AsyncClient()
|
|
|
+
|
|
|
+ def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
|
|
|
+ """Converts a tensor or list of speech token IDs to a string representation."""
|
|
|
+ if isinstance(speech_tokens, torch.Tensor):
|
|
|
+ # Ensure tensor is on CPU and flattened
|
|
|
+ speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
|
|
|
+
|
|
|
+ speech_id_str = ""
|
|
|
+ for token_id in speech_tokens:
|
|
|
+ # Convert token ID back to the speech number N
|
|
|
+ token_num = token_id - ORIGINAL_VOCAB_SIZE
|
|
|
+ speech_id_str += f"<|s_{token_num}|>"
|
|
|
+ return speech_id_str
|
|
|
+
|
|
|
+ async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
|
|
|
"""
|
|
|
- Prepares the response from the language model based on the provided
|
|
|
- inputs. Creates a `pb_utils.InferenceRequest` object with passed
|
|
|
- `llm_request_inputs` to send to a decoupled TensorRTLLM model.
|
|
|
- For each response from the language model:
|
|
|
- - Checks for errors and raise an exception if any are found.
|
|
|
- - Extracts the "output_ids" tensor from the response.
|
|
|
- - Determines the finish reason based on the presence of the
|
|
|
- end-of-sequence token or reaching the maximum length.
|
|
|
- - Appends the generated token IDs to `output_ids`.
|
|
|
- - If the finish reason is determined, decodes the output IDs to text
|
|
|
- and prepares the final response.
|
|
|
-
|
|
|
- The final response includes the generated text, finish reason,
|
|
|
- completion tokens, prompt tokens, and total tokens.
|
|
|
-
|
|
|
- Parameters
|
|
|
- ----------
|
|
|
- - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
|
|
|
-
|
|
|
- Returns
|
|
|
- -------
|
|
|
- - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
|
|
|
+ Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
|
|
|
"""
|
|
|
- # convert input_ids to numpy, with shape [1, sequence_length]
|
|
|
- input_ids = input_ids.cpu().numpy()
|
|
|
- max_tokens = 750
|
|
|
- input_dict = {
|
|
|
- "request_output_len": np.array([[max_tokens]], dtype=np.int32),
|
|
|
- "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
|
|
- "pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
|
|
- "streaming": np.array([[self.decoupled]], dtype=np.bool_),
|
|
|
- "runtime_top_p": np.array([[0.95]], dtype=np.float32),
|
|
|
- "runtime_top_k": np.array([[50]], dtype=np.int32),
|
|
|
- "temperature": np.array([[0.8]], dtype=np.float32),
|
|
|
- "repetition_penalty": np.array([[1.1]], dtype=np.float32),
|
|
|
- "random_seed": np.array([[42]], dtype=np.uint64),
|
|
|
- "input_ids": input_ids,
|
|
|
- "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
|
|
|
- }
|
|
|
+ full_text = f"{reference_text}{target_text}"
|
|
|
+ prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
|
|
|
|
|
|
- # Convert inputs to Triton tensors
|
|
|
- input_tensor_list = [
|
|
|
- pb_utils.Tensor(k, v) for k, v in input_dict.items()
|
|
|
+ chat = [
|
|
|
+ {"role": "user", "content": full_text},
|
|
|
+ {"role": "assistant", "content": prompt_speech_tokens_str}
|
|
|
]
|
|
|
|
|
|
- # Create and execute inference request
|
|
|
- llm_request = pb_utils.InferenceRequest(
|
|
|
- model_name="tensorrt_llm",
|
|
|
- requested_output_names=["output_ids", "sequence_length"],
|
|
|
- inputs=input_tensor_list,
|
|
|
- )
|
|
|
-
|
|
|
- llm_responses = llm_request.exec(decoupled=self.decoupled)
|
|
|
- if self.decoupled:
|
|
|
- for llm_response in llm_responses:
|
|
|
- if llm_response.has_error():
|
|
|
- raise pb_utils.TritonModelException(llm_response.error().message())
|
|
|
-
|
|
|
- # Extract and process output
|
|
|
- output_ids = pb_utils.get_output_tensor_by_name(
|
|
|
- llm_response, "output_ids").as_numpy()
|
|
|
- seq_lens = pb_utils.get_output_tensor_by_name(
|
|
|
- llm_response, "sequence_length").as_numpy()
|
|
|
-
|
|
|
- # Get actual output IDs up to the sequence length
|
|
|
- actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
|
|
-
|
|
|
- yield actual_output_ids
|
|
|
- else:
|
|
|
- llm_response = llm_responses
|
|
|
- if llm_response.has_error():
|
|
|
- raise pb_utils.TritonModelException(llm_response.error().message())
|
|
|
-
|
|
|
- # Extract and process output
|
|
|
- output_ids = pb_utils.get_output_tensor_by_name(
|
|
|
- llm_response, "output_ids").as_numpy()
|
|
|
- seq_lens = pb_utils.get_output_tensor_by_name(
|
|
|
- llm_response, "sequence_length").as_numpy()
|
|
|
+ payload = {
|
|
|
+ "model": "trt_engines_bfloat16",
|
|
|
+ "messages": chat,
|
|
|
+ "max_tokens": 750,
|
|
|
+ "temperature": 0.8,
|
|
|
+ "top_p": 0.95,
|
|
|
+ "top_k": 50,
|
|
|
+ "repetition_penalty": 1.1,
|
|
|
+ "stop": ["<|eos1|>", "<|eos|>"],
|
|
|
+ "stream": True,
|
|
|
+ }
|
|
|
|
|
|
- # Get actual output IDs up to the sequence length
|
|
|
- actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
|
|
+ api_base = "http://localhost:8000/v1/chat/completions"
|
|
|
+
|
|
|
+ buffer = ""
|
|
|
+ async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
|
|
|
+ print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
|
+ print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
|
+ response.raise_for_status()
|
|
|
+ async for line in response.aiter_lines():
|
|
|
+ if line.startswith("data: "):
|
|
|
+ line_data = line[len("data: "):].strip()
|
|
|
+ if line_data == "[DONE]":
|
|
|
+ break
|
|
|
+ try:
|
|
|
+ json_data = json.loads(line_data)
|
|
|
+ content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
|
|
|
+ if content:
|
|
|
+ buffer += content
|
|
|
+ print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
|
+ while True:
|
|
|
+ match = re.search(r"<\|s_(\d+)\|>", buffer)
|
|
|
+ if not match:
|
|
|
+ break
|
|
|
+
|
|
|
+ token_num = int(match.group(1))
|
|
|
+ final_id = token_num + ORIGINAL_VOCAB_SIZE
|
|
|
+ yield final_id
|
|
|
+ buffer = buffer[match.end():]
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ self.logger.log_info(f"Skipping non-JSON line: {line_data}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # Process any remaining complete tokens in the buffer after the stream ends
|
|
|
+ while True:
|
|
|
+ match = re.search(r"<\|s_(\d+)\|>", buffer)
|
|
|
+ if not match:
|
|
|
+ break
|
|
|
+ token_num = int(match.group(1))
|
|
|
+ final_id = token_num + ORIGINAL_VOCAB_SIZE
|
|
|
+ yield final_id
|
|
|
+ buffer = buffer[match.end():]
|
|
|
|
|
|
- yield actual_output_ids
|
|
|
|
|
|
def forward_audio_tokenizer(self, wav, wav_len):
|
|
|
"""Forward pass through the audio tokenizer component.
|
|
|
@@ -225,7 +240,7 @@ class TritonPythonModel:
|
|
|
|
|
|
return prompt_spk_embedding
|
|
|
|
|
|
- def forward_token2wav(
|
|
|
+ async def forward_token2wav(
|
|
|
self,
|
|
|
index: int,
|
|
|
target_speech_tokens: torch.Tensor,
|
|
|
@@ -247,17 +262,19 @@ class TritonPythonModel:
|
|
|
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
|
|
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
|
|
|
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
|
|
|
-
|
|
|
+
|
|
|
# Create and execute inference request
|
|
|
inference_request = pb_utils.InferenceRequest(
|
|
|
model_name='token2wav_dit',
|
|
|
- requested_output_names=['waveform'],
|
|
|
+ requested_output_names=[
|
|
|
+ "waveform",
|
|
|
+ ],
|
|
|
inputs=inputs_tensor,
|
|
|
request_id=request_id,
|
|
|
parameters={"priority": index+1},
|
|
|
)
|
|
|
|
|
|
- inference_response = inference_request.exec()
|
|
|
+ inference_response = await inference_request.async_exec()
|
|
|
if inference_response.has_error():
|
|
|
raise pb_utils.TritonModelException(inference_response.error().message())
|
|
|
|
|
|
@@ -267,14 +284,6 @@ class TritonPythonModel:
|
|
|
|
|
|
return waveform
|
|
|
|
|
|
- def parse_input(self, text, prompt_text, prompt_speech_tokens):
|
|
|
- total_text = f"{prompt_text}{text}"
|
|
|
- prompt = self.prompt_template.format(input_text=total_text)
|
|
|
- input_ids = self.tokenizer.encode(prompt)
|
|
|
- input_ids = torch.tensor([input_ids], dtype=torch.int32)
|
|
|
- input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
|
|
|
- return input_ids
|
|
|
-
|
|
|
def _extract_speech_feat(self, speech):
|
|
|
speech_feat = mel_spectrogram(
|
|
|
speech,
|
|
|
@@ -292,106 +301,75 @@ class TritonPythonModel:
|
|
|
speech_feat = speech_feat.unsqueeze(dim=0)
|
|
|
return speech_feat
|
|
|
|
|
|
- def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
|
|
|
- for generated_ids in generated_ids_iter:
|
|
|
- generated_ids = generated_ids.tolist()
|
|
|
- if len(generated_ids) == 0:
|
|
|
- break
|
|
|
- semantic_token_ids_arr.extend(generated_ids)
|
|
|
- llm_is_done_flag[0] = True
|
|
|
+ async def _process_request(self, request):
|
|
|
+ request_id = request.request_id()
|
|
|
+ # Extract input tensors
|
|
|
+ wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
|
|
+
|
|
|
+ # Process reference audio through audio tokenizer
|
|
|
+ if wav is not None:
|
|
|
+ wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
|
|
+ prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
|
|
+ prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
|
|
+
|
|
|
+ wav_tensor = wav.as_numpy()
|
|
|
+ wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
|
|
+ print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
|
|
|
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
|
|
+ speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
|
|
+ token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
|
|
+ prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
|
|
+ prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
|
|
+
|
|
|
+ reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
|
|
+ reference_text = reference_text[0][0].decode('utf-8')
|
|
|
+ # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
|
|
+
|
|
|
+ # reference_text = self.default_spk_info["prompt_text"]
|
|
|
+ # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
|
|
+ # prompt_speech_feat = None
|
|
|
+ # prompt_spk_embedding = None
|
|
|
|
|
|
- def execute(self, requests):
|
|
|
- """Execute inference on the batched requests.
|
|
|
+ else:
|
|
|
+ # using pre-cached reference text
|
|
|
+ assert False, "using pre-cached reference text is not supported"
|
|
|
+ reference_text = self.default_spk_info["prompt_text"]
|
|
|
+ prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
|
|
+ prompt_speech_feat = None
|
|
|
+ prompt_spk_embedding = None
|
|
|
|
|
|
- Args:
|
|
|
- requests: List of inference requests
|
|
|
+ target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
|
|
+ target_text = target_text[0][0].decode('utf-8')
|
|
|
+ print(f"target_text: {target_text}, time: {datetime.now()}")
|
|
|
|
|
|
- Returns:
|
|
|
- List of inference responses containing generated audio
|
|
|
- """
|
|
|
- responses = []
|
|
|
-
|
|
|
- for request in requests:
|
|
|
- request_id = request.request_id()
|
|
|
- # Extract input tensors
|
|
|
- wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
|
|
-
|
|
|
- # Process reference audio through audio tokenizer
|
|
|
- if wav is not None:
|
|
|
- wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
|
|
- prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
|
|
- prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
|
|
-
|
|
|
- wav_tensor = wav.as_numpy()
|
|
|
- wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
|
|
- prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
|
|
- speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
|
|
- token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
|
|
- prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
|
|
- prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
|
|
-
|
|
|
- reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
|
|
- reference_text = reference_text[0][0].decode('utf-8')
|
|
|
- # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
|
|
-
|
|
|
- # reference_text = self.default_spk_info["prompt_text"]
|
|
|
- # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
|
|
- # prompt_speech_feat = None
|
|
|
- # prompt_spk_embedding = None
|
|
|
-
|
|
|
- else:
|
|
|
- assert False, "wav is None"
|
|
|
- # using pre-cached reference text
|
|
|
- reference_text = self.default_spk_info["prompt_text"]
|
|
|
- prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
|
|
- prompt_speech_feat = None
|
|
|
- prompt_spk_embedding = None
|
|
|
-
|
|
|
- target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
|
|
- target_text = target_text[0][0].decode('utf-8')
|
|
|
-
|
|
|
- # Prepare prompt for LLM
|
|
|
- input_ids = self.parse_input(
|
|
|
- text=target_text,
|
|
|
- prompt_text=reference_text,
|
|
|
+ if self.decoupled:
|
|
|
+ response_sender = request.get_response_sender()
|
|
|
+
|
|
|
+ semantic_token_ids_arr = []
|
|
|
+ token_offset, chunk_index = 0, 0
|
|
|
+ start_time = time.time()
|
|
|
+ this_token_hop_len = self.token_hop_len
|
|
|
+ print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
|
+ async for generated_ids in self.forward_llm_async(
|
|
|
+ target_text=target_text,
|
|
|
+ reference_text=reference_text,
|
|
|
prompt_speech_tokens=prompt_speech_tokens,
|
|
|
- )
|
|
|
-
|
|
|
- # Generate semantic tokens with LLM
|
|
|
- generated_ids_iter = self.forward_llm(input_ids)
|
|
|
-
|
|
|
- if self.decoupled:
|
|
|
- response_sender = request.get_response_sender()
|
|
|
-
|
|
|
- semantic_token_ids_arr = []
|
|
|
- llm_is_done_flag = [False]
|
|
|
-
|
|
|
- llm_thread = threading.Thread(
|
|
|
- target=self._llm_gen_thread,
|
|
|
- args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
|
|
|
- )
|
|
|
-
|
|
|
- llm_thread.start()
|
|
|
-
|
|
|
- token_offset, chunk_index = 0, 0
|
|
|
- start_time = time.time()
|
|
|
- this_token_hop_len = self.token_hop_len
|
|
|
-
|
|
|
+ ):
|
|
|
+ if not generated_ids:
|
|
|
+ break
|
|
|
+ semantic_token_ids_arr.append(generated_ids)
|
|
|
+ print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
|
while True:
|
|
|
pending_num = len(semantic_token_ids_arr) - token_offset
|
|
|
-
|
|
|
- if llm_is_done_flag[0]:
|
|
|
- break
|
|
|
-
|
|
|
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
|
|
|
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
|
|
|
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
|
|
-
|
|
|
- sub_tts_speech = self.forward_token2wav(
|
|
|
+ print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
|
+ sub_tts_speech = await self.forward_token2wav(
|
|
|
chunk_index,
|
|
|
this_tts_speech_token, request_id, wav, wav_len, False
|
|
|
)
|
|
|
-
|
|
|
+ print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
|
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
|
|
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
|
|
response_sender.send(inference_response)
|
|
|
@@ -401,6 +379,8 @@ class TritonPythonModel:
|
|
|
|
|
|
if self.dynamic_chunk_strategy == "exponential":
|
|
|
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
|
|
|
+ elif self.dynamic_chunk_strategy == "equal":
|
|
|
+ this_token_hop_len = self.token_hop_len
|
|
|
elif self.dynamic_chunk_strategy == "time_based":
|
|
|
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
|
|
|
cost_time = time.time() - start_time
|
|
|
@@ -420,19 +400,36 @@ class TritonPythonModel:
|
|
|
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
|
|
|
chunk_index += 1
|
|
|
else:
|
|
|
- time.sleep(0.02)
|
|
|
-
|
|
|
- this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
|
|
- sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
|
|
|
- audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
|
|
- inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
|
|
- response_sender.send(inference_response)
|
|
|
-
|
|
|
- llm_thread.join()
|
|
|
- response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
|
|
- self.logger.log_info("send tritonserver_response_complete_final to end")
|
|
|
- else:
|
|
|
- raise NotImplementedError("Decoupled mode is not supported")
|
|
|
-
|
|
|
- if not self.decoupled:
|
|
|
- return responses
|
|
|
+ break
|
|
|
+
|
|
|
+ this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
|
|
+ sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
|
|
|
+ audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
|
|
+ inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
|
|
+ response_sender.send(inference_response)
|
|
|
+
|
|
|
+ response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
|
|
+ self.logger.log_info("send tritonserver_response_complete_final to end")
|
|
|
+ else:
|
|
|
+ raise NotImplementedError("Decoupled mode is not supported")
|
|
|
+
|
|
|
+ async def execute(self, requests):
|
|
|
+ """Execute inference on the batched requests.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ requests: List of inference requests
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List of inference responses containing generated audio
|
|
|
+ """
|
|
|
+ tasks = [
|
|
|
+ asyncio.create_task(self._process_request(request))
|
|
|
+ for request in requests
|
|
|
+ ]
|
|
|
+ await asyncio.gather(*tasks)
|
|
|
+ return None
|
|
|
+
|
|
|
+ def finalize(self):
|
|
|
+ self.logger.log_info("Finalizing CosyVoice DIT model")
|
|
|
+ if hasattr(self, "http_client"):
|
|
|
+ asyncio.run(self.http_client.aclose())
|