
Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main

lyuxiang.lx 1 month ago
commit dd5cdb6ebf
1 changed file with 5 additions and 3 deletions

runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py (+5 −3)

@@ -28,6 +28,7 @@ import json
 import os
 import threading
 import time
+from uuid import uuid4
 
 import numpy as np
 import torch
@@ -364,6 +365,7 @@ class TritonPythonModel:
             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)
 
+            token2wav_request_id = request_id or str(uuid4())
             if self.decoupled:
                 response_sender = request.get_response_sender()
 
@@ -392,7 +394,7 @@ class TritonPythonModel:
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
 
                         sub_tts_speech = self.forward_token2wav(
-                            this_tts_speech_token, request_id, prompt_speech_tokens,
+                            this_tts_speech_token, token2wav_request_id, prompt_speech_tokens,
                             prompt_speech_feat, prompt_spk_embedding, token_offset, False
                         )
 
@@ -427,7 +429,7 @@ class TritonPythonModel:
                         time.sleep(0.02)
 
                 this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
+                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
@@ -441,7 +443,7 @@ class TritonPythonModel:
                 if generated_ids is None or len(generated_ids) == 0:
                     raise pb_utils.TritonModelException("Generated IDs is None or empty")
 
-                audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
+                audio = self.forward_token2wav(generated_ids, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
 
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
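
The only functional change is the new `token2wav_request_id = request_id or str(uuid4())` line: when the incoming `request_id` is empty, a fresh UUID is generated so every `forward_token2wav` call carries a non-empty, unique identifier in both the decoupled (streaming) and non-decoupled paths. A minimal standalone sketch of that fallback pattern, with a hypothetical helper name that is not part of the model code:

```python
from uuid import uuid4


def resolve_token2wav_request_id(request_id: str) -> str:
    # Reuse the caller's request_id when it is non-empty; otherwise
    # generate a fresh UUID so the downstream token2wav request is
    # never sent with an empty identifier.
    return request_id or str(uuid4())


# Example: an empty request_id falls back to a generated UUID.
print(resolve_token2wav_request_id(""))        # e.g. '9f1c2e3a-...'
print(resolve_token2wav_request_id("req-42"))  # 'req-42'
```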