@@ -43,6 +43,7 @@ import torchaudio


 from matcha.utils.audio import mel_spectrogram
+from datetime import datetime

 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
@@ -86,6 +87,7 @@ class TritonPythonModel:
         model_params = {k: v["string_value"] for k, v in parameters.items()}
         self.logger.log_info(f"model_params:{model_params}")
         self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
+        # self.dynamic_chunk_strategy = "equal"
         self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

         # Initialize tokenizer
@@ -105,7 +107,9 @@ class TritonPythonModel:
         if not os.path.exists(spk_info_path):
             raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
         spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
-        # self.default_spk_info = spk_info["001"]
+        self.default_spk_info = spk_info["001"]
+        self.http_client = httpx.AsyncClient()
+        self.runtime_cache = {}

     def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
         """Converts a tensor or list of speech token IDs to a string representation."""
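Note: moving the `httpx.AsyncClient` into `initialize` means one connection pool is shared across all requests instead of a fresh TCP/TLS handshake per LLM call. A minimal sketch of the lifecycle this implies (class name and URL are hypothetical, not from this repo):

```python
import asyncio
import httpx

class LLMClient:
    """Owns one AsyncClient for the whole service lifetime."""

    def __init__(self) -> None:
        # Connections to the endpoint are pooled and reused across requests.
        self.http_client = httpx.AsyncClient()

    async def ping(self, url: str) -> int:
        resp = await self.http_client.get(url)
        return resp.status_code

    async def close(self) -> None:
        # Mirror of finalize(): close the pool exactly once, at shutdown.
        await self.http_client.aclose()

async def main() -> None:
    client = LLMClient()
    try:
        print(await client.ping("http://localhost:8000/v1/models"))
    finally:
        await client.close()

if __name__ == "__main__":
    asyncio.run(main())
```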
@@ -131,7 +135,6 @@ class TritonPythonModel:
             {"role": "user", "content": full_text},
             {"role": "assistant", "content": prompt_speech_tokens_str}
         ]
-        print(chat)

         payload = {
             "model": "trt_engines_bfloat16",
@@ -148,31 +151,33 @@ class TritonPythonModel:
         api_base = "http://localhost:8000/v1/chat/completions"

         buffer = ""
-        async with httpx.AsyncClient() as client:
-            async with client.stream("POST", api_base, json=payload, timeout=None) as response:
-                response.raise_for_status()
-                async for line in response.aiter_lines():
-                    if line.startswith("data: "):
-                        line_data = line[len("data: "):].strip()
-                        if line_data == "[DONE]":
-                            break
-                        try:
-                            json_data = json.loads(line_data)
-                            content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
-                            if content:
-                                buffer += content
-                                while True:
-                                    match = re.search(r"<\|s_(\d+)\|>", buffer)
-                                    if not match:
-                                        break
-
-                                    token_num = int(match.group(1))
-                                    final_id = token_num + ORIGINAL_VOCAB_SIZE
-                                    yield final_id
-                                    buffer = buffer[match.end():]
-                        except json.JSONDecodeError:
-                            self.logger.log_info(f"Skipping non-JSON line: {line_data}")
-                            continue
+        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
+            self.logger.log_info(f"start LLM stream, target_text: {target_text[:5]}, time: {datetime.now()}")
+            self.logger.log_info(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    line_data = line[len("data: "):].strip()
+                    if line_data == "[DONE]":
+                        break
+                    try:
+                        json_data = json.loads(line_data)
+                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
+                        if content:
+                            buffer += content
+                            self.logger.log_info(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
+                            while True:
+                                match = re.search(r"<\|s_(\d+)\|>", buffer)
+                                if not match:
+                                    break
+
+                                token_num = int(match.group(1))
+                                final_id = token_num + ORIGINAL_VOCAB_SIZE
+                                yield final_id
+                                buffer = buffer[match.end():]
+                    except json.JSONDecodeError:
+                        self.logger.log_info(f"Skipping non-JSON line: {line_data}")
+                        continue

         # Process any remaining complete tokens in the buffer after the stream ends
         while True:
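Note: the inner `while True` / `re.search` loop drains every complete `<|s_N|>` marker from the buffer and keeps the unconsumed tail, so a token split across two SSE deltas survives until its closing `|>` arrives. The same logic as a standalone, testable sketch (the function name is hypothetical):

```python
import re
from typing import List, Tuple

SPEECH_TOKEN_RE = re.compile(r"<\|s_(\d+)\|>")
ORIGINAL_VOCAB_SIZE = 151663  # must match the constant at the top of model.py

def drain_speech_tokens(buffer: str) -> Tuple[List[int], str]:
    """Return all complete token IDs in `buffer` plus the leftover tail."""
    ids: List[int] = []
    pos = 0
    for match in SPEECH_TOKEN_RE.finditer(buffer):
        # Map the local speech-token index into the extended LLM vocabulary.
        ids.append(int(match.group(1)) + ORIGINAL_VOCAB_SIZE)
        pos = match.end()
    return ids, buffer[pos:]

# A marker split across stream chunks stays buffered until completed:
ids, tail = drain_speech_tokens("<|s_42|><|s_7")
assert ids == [42 + ORIGINAL_VOCAB_SIZE]
assert tail == "<|s_7"
```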
@@ -236,7 +241,7 @@ class TritonPythonModel:

         return prompt_spk_embedding

-    def forward_token2wav(
+    async def forward_token2wav(
         self,
         index: int,
         target_speech_tokens: torch.Tensor,
@@ -258,20 +263,57 @@ class TritonPythonModel:
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
         finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
         inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
-
+
+        # Optional cache inputs: absent on the first chunk of a request, present afterwards.
+        if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None:
+            # inputs_tensor.extend([
+            #     pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()),
+            #     pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()),
+            #     pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()),
+            #     pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()),
+            #     pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()),
+            #     pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()),
+            #     pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()),
+            # ])
+            inputs_tensor.extend([
+                self.runtime_cache[request_id]["conformer_cnn_cache"],
+                self.runtime_cache[request_id]["conformer_att_cache"],
+                self.runtime_cache[request_id]["estimator_cnn_cache"],
+                self.runtime_cache[request_id]["estimator_att_cache"],
+                self.runtime_cache[request_id]["mel"],
+                self.runtime_cache[request_id]["source"],
+                self.runtime_cache[request_id]["speech"],
+            ])
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav_dit',
-            requested_output_names=['waveform'],
+            requested_output_names=[
+                "waveform",
+                "conformer_cnn_cache",
+                "conformer_att_cache",
+                "estimator_cnn_cache",
+                "estimator_att_cache",
+                "mel",
+                "source",
+                "speech",
+            ],
             inputs=inputs_tensor,
             request_id=request_id,
             parameters={"priority": index+1},
         )

-        inference_response = inference_request.exec()
+        inference_response = await inference_request.async_exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())

+        self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache")
+        self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache")
+        self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache")
+        self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache")
+        self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel")
+        self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source")
+        self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech")
+
         # Extract and convert output waveform
         waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
         waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
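Note: this hunk turns `token2wav_dit` into a stateful streaming call: the cache tensors emitted for chunk k are fed back as inputs for chunk k+1, keyed by `request_id`. Stripped of the Triton plumbing, the round-trip reduces to the sketch below, where `infer` stands in for the BLS call and is hypothetical:

```python
from typing import Callable, Dict, List, Optional

CACHE_KEYS: List[str] = [
    "conformer_cnn_cache", "conformer_att_cache",
    "estimator_cnn_cache", "estimator_att_cache",
    "mel", "source", "speech",
]

def run_chunk(
    base_inputs: List[object],
    cache: Dict[str, Optional[object]],
    infer: Callable[[List[object]], Dict[str, object]],
) -> object:
    """One streaming chunk: feed last chunk's state in, store new state out."""
    inputs = list(base_inputs)
    if cache[CACHE_KEYS[0]] is not None:  # every chunk after the first
        inputs.extend(cache[k] for k in CACHE_KEYS)
    outputs = infer(inputs)
    for k in CACHE_KEYS:
        cache[k] = outputs[k]  # becomes the input state of the next chunk
    return outputs["waveform"]
```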
@@ -297,6 +339,16 @@ class TritonPythonModel:

     async def _process_request(self, request):
         request_id = request.request_id()
+        if request_id not in self.runtime_cache:
+            self.runtime_cache[request_id] = {
+                "conformer_cnn_cache": None,
+                "conformer_att_cache": None,
+                "estimator_cnn_cache": None,
+                "estimator_att_cache": None,
+                "mel": None,
+                "source": None,
+                "speech": None,
+            }
         # Extract input tensors
         wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")

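Note: an equivalent, more compact way to write this per-request initialization, should the key list ever be shared with `forward_token2wav` (a sketch, with `CACHE_KEYS` as in the note above):

```python
CACHE_KEYS = (
    "conformer_cnn_cache", "conformer_att_cache",
    "estimator_cnn_cache", "estimator_att_cache",
    "mel", "source", "speech",
)

def ensure_cache_slot(runtime_cache: dict, request_id: str) -> dict:
    # dict.fromkeys fills every key with None, matching the literal above;
    # setdefault creates the slot only on the first chunk of a request.
    return runtime_cache.setdefault(request_id, dict.fromkeys(CACHE_KEYS))
```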
@@ -308,6 +360,7 @@ class TritonPythonModel:

         wav_tensor = wav.as_numpy()
         wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+        self.logger.log_info(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
         prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
         speech_feat = self._extract_speech_feat(prompt_speech_resample)
         token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
@@ -316,7 +369,7 @@ class TritonPythonModel:

         reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
         reference_text = reference_text[0][0].decode('utf-8')
-        # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+        prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)

         # reference_text = self.default_spk_info["prompt_text"]
         # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -333,6 +386,7 @@ class TritonPythonModel:

         target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
         target_text = target_text[0][0].decode('utf-8')
+        self.logger.log_info(f"target_text: {target_text}, time: {datetime.now()}")

         if self.decoupled:
             response_sender = request.get_response_sender()
@@ -341,7 +395,7 @@ class TritonPythonModel:
             token_offset, chunk_index = 0, 0
             start_time = time.time()
             this_token_hop_len = self.token_hop_len
-
+            self.logger.log_info(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
             async for generated_ids in self.forward_llm_async(
                 target_text=target_text,
                 reference_text=reference_text,
@@ -350,18 +404,18 @@ class TritonPythonModel:
                 if not generated_ids:
                     break
                 semantic_token_ids_arr.append(generated_ids)
-
+                self.logger.log_info(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
                 while True:
                     pending_num = len(semantic_token_ids_arr) - token_offset
                     if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                         this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
-
-                        sub_tts_speech = self.forward_token2wav(
+                        self.logger.log_info(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
+                        sub_tts_speech = await self.forward_token2wav(
                             chunk_index,
                             this_tts_speech_token, request_id, wav, wav_len, False
                         )
-
+                        self.logger.log_info(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
                         audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                         inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                         response_sender.send(inference_response)
@@ -371,6 +425,8 @@ class TritonPythonModel:

                         if self.dynamic_chunk_strategy == "exponential":
                             this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "equal":
+                            this_token_hop_len = self.token_hop_len
                         elif self.dynamic_chunk_strategy == "time_based":
                             # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
                             cost_time = time.time() - start_time
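Note: with the new branch there are three strategies for how many speech tokens each flow/vocoder chunk consumes. Illustrative first few hop lengths, assuming `token_frame_rate = 25` and `token_hop_len = 25` (example values, not taken from the config):

```python
TOKEN_FRAME_RATE = 25  # assumed example value
TOKEN_HOP_LEN = 25     # assumed example value

def next_hop_len(strategy: str, chunk_index: int) -> int:
    if strategy == "exponential":
        # Small first chunk for low latency, doubling afterwards: 25, 50, 100, ...
        return TOKEN_FRAME_RATE * (2 ** chunk_index)
    if strategy == "equal":
        # Constant chunk size: steady per-chunk latency, more flow calls overall.
        return TOKEN_HOP_LEN
    raise ValueError(f"unknown strategy: {strategy}")  # "time_based" adapts to wall-clock cost

for i in range(4):
    print(i, next_hop_len("exponential", i), next_hop_len("equal", i))
# 0 25 25 / 1 50 25 / 2 100 25 / 3 200 25
```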
@@ -393,29 +449,13 @@ class TritonPythonModel:
                    break

            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
-            sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
+            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
             audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
             inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
             response_sender.send(inference_response)
-
-            ## debug
-            ## save semantic_token_ids_arr and reference_text, target_text to a single json file
-            # save into a torch .pt
-            # for i, item in enumerate(semantic_token_ids_arr):
-            #     semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE
-            # import json
-            # data = {
-            #     "semantic_token_ids_arr": semantic_token_ids_arr,
-            #     "reference_text": reference_text,
-            #     "target_text": target_text
-            # }
-            # with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f:
-            #     torch.save(data, f)
-            # with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f:
-            #     json.dump(data, f)
-
-            # ##
-
+            if request_id in self.runtime_cache:
+                del self.runtime_cache[request_id]
+                self.logger.log_info(f"Deleted cache for request_id: {request_id}")
             response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
             self.logger.log_info("send tritonserver_response_complete_final to end")
         else:
@@ -436,3 +476,8 @@ class TritonPythonModel:
         ]
         await asyncio.gather(*tasks)
         return None
+
+    def finalize(self):
+        self.logger.log_info("Finalizing CosyVoice DIT model")
+        if hasattr(self, "http_client"):
+            asyncio.run(self.http_client.aclose())
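Note: `asyncio.run()` raises `RuntimeError` if `finalize` is ever invoked on a thread that already hosts a running event loop. If that can happen in a given deployment, a defensive variant might look like this sketch:

```python
import asyncio
import httpx

def close_http_client(client: httpx.AsyncClient) -> None:
    try:
        # Normal case: no loop on this thread, run the coroutine to completion.
        asyncio.run(client.aclose())
    except RuntimeError:
        # A loop is already running here: schedule the close on it instead.
        asyncio.get_running_loop().create_task(client.aclose())
```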