Quellcode durchsuchen

support streaming tts

root vor 3 Monaten
Ursprung
Commit
73d261dd48

+ 36 - 28
runtime/triton_trtllm/client_grpc.py

@@ -395,38 +395,45 @@ def run_sync_streaming_inference(
     # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
     actual_duration = 0
     if audios:
-        cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
-        fade_out = np.linspace(1, 0, cross_fade_samples)
-        fade_in = np.linspace(0, 1, cross_fade_samples)
-        reconstructed_audio = None
-
-        # Simplified reconstruction based on client_grpc_streaming.py
-        if not audios:
-            print("Warning: No audio chunks received.")
-            reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
-        elif len(audios) == 1:
-            reconstructed_audio = audios[0]
+        # Only spark_tts model uses cross-fade
+        if model_name == "spark_tts":
+            cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+            reconstructed_audio = None
+
+            # Simplified reconstruction based on client_grpc_streaming.py
+            if not audios:
+                print("Warning: No audio chunks received.")
+                reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
+            elif len(audios) == 1:
+                reconstructed_audio = audios[0]
+            else:
+                reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
+                for i in range(1, len(audios)):
+                    # Cross-fade section
+                    cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
+                                        audios[i - 1][-cross_fade_samples:] * fade_out)
+                    # Middle section of the current chunk
+                    middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
+                    # Concatenate
+                    reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
+                # Add the last part of the final chunk
+                reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
+
+            if reconstructed_audio is not None and reconstructed_audio.size > 0:
+                actual_duration = len(reconstructed_audio) / save_sample_rate
+                # Save reconstructed audio
+                sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
+            else:
+                print("Warning: No audio chunks received or reconstructed.")
+                actual_duration = 0  # Set duration to 0 if no audio
         else:
-            reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
-            for i in range(1, len(audios)):
-                # Cross-fade section
-                cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
-                                       audios[i - 1][-cross_fade_samples:] * fade_out)
-                # Middle section of the current chunk
-                middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-                # Concatenate
-                reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
-            # Add the last part of the final chunk
-            reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
-
-        if reconstructed_audio is not None and reconstructed_audio.size > 0:
+            reconstructed_audio = np.concatenate(audios)
+            print(f"reconstructed_audio: {reconstructed_audio.shape}")
             actual_duration = len(reconstructed_audio) / save_sample_rate
             # Save reconstructed audio
-            os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
-        else:
-            print("Warning: No audio chunks received or reconstructed.")
-            actual_duration = 0  # Set duration to 0 if no audio
 
     else:
         print("Warning: No audio chunks received.")
@@ -667,6 +674,7 @@ async def main():
     manifest_item_list = split_data(manifest_item_list, num_tasks)
 
     os.makedirs(args.log_dir, exist_ok=True)
+
     tasks = []
     start_time = time.time()
     for i in range(num_tasks):

+ 60 - 17
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

@@ -114,6 +114,7 @@ class TritonPythonModel:
             "runtime_top_p": np.array([[0.95]], dtype=np.float32),
             "runtime_top_k": np.array([[50]], dtype=np.int32),
             "temperature": np.array([[0.8]], dtype=np.float32),
+            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
             "input_ids": input_ids,
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
         }
@@ -144,6 +145,7 @@ class TritonPythonModel:
 
                 # Get actual output IDs up to the sequence length
                 actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
+                print(f"actual_output_ids: {actual_output_ids}")
 
                 yield actual_output_ids
         else:
@@ -193,7 +195,10 @@ class TritonPythonModel:
             prompt_speech_tokens: torch.Tensor,
             prompt_speech_feat: torch.Tensor,
             prompt_spk_embedding: torch.Tensor,
-            target_speech_tokens: torch.Tensor) -> torch.Tensor:
+            target_speech_tokens: torch.Tensor,
+            request_id: str,
+            token_offset: int = None,
+            finalize: bool = None) -> torch.Tensor:
         """Forward pass through the vocoder component.
 
         Args:
@@ -210,11 +215,22 @@ class TritonPythonModel:
         prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
 
+        inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+
+        if token_offset is not None:
+            assert finalize is not None
+            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
+            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+            inputs_tensor.append(token_offset_tensor)
+            inputs_tensor.append(finalize_tensor)
+
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
             requested_output_names=['waveform'],
-            inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+            inputs=inputs_tensor,
+            request_id=request_id,
         )
 
         inference_response = inference_request.exec()
@@ -275,6 +291,7 @@ class TritonPythonModel:
         responses = []
 
         for request in requests:
+            request_id = request.request_id()
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
             wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
@@ -292,6 +309,11 @@ class TritonPythonModel:
             prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
             prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
 
+
+            flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
+            token_hop_len = 25
+            flow_pre_lookahead_len = 3
+
             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')
 
@@ -308,24 +330,46 @@ class TritonPythonModel:
             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)
 
+            prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
+            print(f"here2")
             if self.decoupled:
                 response_sender = request.get_response_sender()
-                request_id = request.request_id()
-                generated_ids = []
-                for generated_id in generated_ids_iter:
-                    # convert the numpy array into a int32 tensor
-                    generated_id = generated_id.tolist()
-                    if len(generated_id) > 0:
-                        assert len(generated_id) == 1, "Generated ID is not a single integer"
-                        generated_ids.append(generated_id[0])
-                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
 
-                # Prepare response
-                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+
+
+                semantic_token_ids_arr, token_offset = [], 0
+                for generated_ids in generated_ids_iter:
+
+                    generated_ids = generated_ids.tolist()
+                    print(f"generated_id: {generated_ids}")
+                    semantic_token_ids_arr.extend(generated_ids)
+
+                    prompt_token_pad = int(np.ceil(flow_prompt_speech_token_len / token_hop_len) * token_hop_len - flow_prompt_speech_token_len)
+                    this_token_hop_len = token_hop_len + prompt_token_pad if token_offset == 0 else token_hop_len
+                    print(f"this_token_hop_len: {this_token_hop_len}")
+                    if len(semantic_token_ids_arr) - token_offset >= this_token_hop_len + flow_pre_lookahead_len:
+                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + flow_pre_lookahead_len]
+                        print(f"this_tts_speech_token: {this_tts_speech_token}")
+                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                        print(f"here3")
+                        
+                        sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, False)
+                        print(f"here4")
+                        # Prepare response to send
+                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                        self.logger.log_info(f"[{request_id}]")
+                        token_offset += this_token_hop_len
+                print(f"here")
+
+                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
+                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
+
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                 self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
@@ -334,8 +378,7 @@ class TritonPythonModel:
                 if generated_ids is None or len(generated_ids) == 0:
                     raise pb_utils.TritonModelException("Generated IDs is None or empty")
 
-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
+                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)
 
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))

+ 91 - 19
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -32,12 +32,16 @@ from typing import List, Dict
 
 import torch
 from torch.utils.dlpack import to_dlpack
+from torch.nn import functional as F
 
 import triton_python_backend_utils as pb_utils
 
 from hyperpyyaml import load_hyperpyyaml
+from cosyvoice.utils.common import fade_in_out
 from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
 from cosyvoice.utils.common import TrtContextWrapper
+from collections import defaultdict
+import numpy as np
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
@@ -81,6 +85,13 @@ class CosyVoice2Model:
         if self.fp16 is True:
             self.flow.half()
 
+        # streaming tts config
+        self.token_hop_len = 25
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        self.hift_cache_dict = defaultdict(lambda: None)
+
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
@@ -112,6 +123,43 @@ class CosyVoice2Model:
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
 
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             streaming=stream,
+                                             finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
+
 class TritonPythonModel:
     """Triton Python model for vocoder.
 
@@ -166,25 +214,49 @@ class TritonPythonModel:
             prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
 
-            tts_mel, _ = self.token2wav_model.model.flow.inference(
-                token=target_speech_tokens,
-                token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
-                    self.device
-                ),
-                prompt_token=prompt_speech_tokens,
-                prompt_token_len=torch.tensor(
-                    [prompt_speech_tokens.shape[1]], dtype=torch.int32
-                ).to(self.device),
-                prompt_feat=prompt_speech_feat,
-                prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
-                embedding=prompt_spk_embedding,
-                streaming=False,
-                finalize=True,
-            )
-
-            audio_hat, _ = self.token2wav_model.model.hift.inference(
-                speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
-            )
+            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
+            token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
+            if token_offset is not None:
+                token_offset = token_offset.as_numpy().item()
+                finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+                if not finalize:
+                    stream = True
+                else:
+                    stream = False
+                request_id = request.request_id()
+                print(f"token_offset: {token_offset}, finalize: {finalize}, request_id: {request_id}")
+                audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
+                                                                 prompt_token=prompt_speech_tokens,
+                                                                 prompt_feat=prompt_speech_feat,
+                                                                 embedding=prompt_spk_embedding,
+                                                                 token_offset=token_offset,
+                                                                 uuid=request_id,
+                                                                 stream=stream,
+                                                                 finalize=finalize)
+                if finalize:
+                    print(f"dict keys: {self.token2wav_model.model.hift_cache_dict.keys()}")
+                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
+
+            else:
+                tts_mel, _ = self.token2wav_model.model.flow.inference(
+                    token=target_speech_tokens,
+                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
+                        self.device
+                    ),
+                    prompt_token=prompt_speech_tokens,
+                    prompt_token_len=torch.tensor(
+                        [prompt_speech_tokens.shape[1]], dtype=torch.int32
+                    ).to(self.device),
+                    prompt_feat=prompt_speech_feat,
+                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
+                    embedding=prompt_spk_embedding,
+                    streaming=False,
+                    finalize=True,
+                )
+
+                audio_hat, _ = self.token2wav_model.model.hift.inference(
+                    speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+                )
 
             generated_wave = audio_hat.squeeze(0).cpu().numpy()
 

+ 14 - 0
runtime/triton_trtllm/model_repo/token2wav/config.pbtxt

@@ -45,6 +45,20 @@ input [
     name: "prompt_spk_embedding"
     data_type: TYPE_FP16
     dims: [-1]
+  },
+  {
+    name: "token_offset"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "finalize"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
   }
 ]
 output [