ソースを参照

add prompt audio cache

yuekaiz 5 ヶ月 前
コミット
6971536358

+ 9 - 6
runtime/triton_trtllm/README.md

@@ -77,16 +77,19 @@ The following results were obtained by decoding on a single L20 GPU with 26 prom
 **Streaming TTS (First Chunk Latency)**
 | Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
 |---|---|---|---|---|
-| Streaming, Decoupled=True | 1 | 220.43 | 218.07 | 0.1237 |
-| Streaming, Decoupled=True | 2 | 476.97 | 369.25 | 0.1022 |
-| Streaming, Decoupled=True | 4 | 1107.34 | 1243.75| 0.0922 |
+| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
+| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
+| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 |
+| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
+| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
+| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 |
 
 **Offline TTS (Full Sentence Latency)**
 | Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
 |---|---|---|---|---|---|
-| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
-| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
-| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
+| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
+| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
+| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
 
 ### OpenAI-Compatible Server
 

+ 19 - 5
runtime/triton_trtllm/client_grpc.py

@@ -257,7 +257,13 @@ def get_args():
         default=0.1,
         help="Chunk overlap duration for streaming reconstruction (in seconds)."
     )
-    # --- End Added arguments ---
+
+    parser.add_argument(
+        "--use-spk2info-cache",
+        type=bool,
+        default=False,
+        help="Use spk2info cache for reference audio.",
+    )
 
     return parser.parse_args()
 
@@ -283,7 +289,8 @@ def prepare_request_input_output(
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None  # Optional padding for offline mode
+    padding_duration: int = None,  # Optional padding for offline mode
+    use_spk2info_cache: bool = False
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -330,7 +337,8 @@ def prepare_request_input_output(
     inputs[3].set_data_from_numpy(input_data_numpy)
 
     outputs = [protocol_client.InferRequestedOutput("waveform")]
-
+    if use_spk2info_cache:
+        inputs = inputs[-1:]
     return inputs, outputs
 
 
@@ -453,6 +461,7 @@ async def send_streaming(
     save_sample_rate: int = 16000,
     chunk_overlap_duration: float = 0.1,
     padding_duration: int = None,
+    use_spk2info_cache: bool = False,
 ):
     total_duration = 0.0
     latency_data = []
@@ -478,7 +487,8 @@ async def send_streaming(
                     reference_text,
                     target_text,
                     sample_rate,
-                    padding_duration=padding_duration
+                    padding_duration=padding_duration,
+                    use_spk2info_cache=use_spk2info_cache
                 )
                 request_id = str(uuid.uuid4())
                 user_data = UserData()
@@ -534,6 +544,7 @@ async def send(
     padding_duration: int = None,
     audio_save_dir: str = "./",
     save_sample_rate: int = 16000,
+    use_spk2info_cache: bool = False,
 ):
     total_duration = 0.0
     latency_data = []
@@ -552,7 +563,8 @@ async def send(
             reference_text,
             target_text,
             sample_rate,
-            padding_duration=padding_duration
+            padding_duration=padding_duration,
+            use_spk2info_cache=use_spk2info_cache
         )
         sequence_id = 100000000 + i + task_id * 10
         start = time.time()
@@ -691,6 +703,7 @@ async def main():
                     audio_save_dir=args.log_dir,
                     padding_duration=1,
                     save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
+                    use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
         elif args.mode == "streaming":
@@ -706,6 +719,7 @@ async def main():
                     padding_duration=10,
                     save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                     chunk_overlap_duration=args.chunk_overlap_duration,
+                    use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
         # --- End Task Creation ---

+ 45 - 29
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

@@ -43,6 +43,7 @@ import torchaudio
 
 from matcha.utils.audio import mel_spectrogram
 
+ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
 
 
@@ -81,6 +82,12 @@ class TritonPythonModel:
         self.flow_pre_lookahead_len = 3
         self.token_hop_len = 15
 
+        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        self.default_spk_info = spk_info["001"]
+
     def forward_llm(self, input_ids):
         """
         Prepares the response from the language model based on the provided
@@ -220,11 +227,11 @@ class TritonPythonModel:
 
     def forward_token2wav(
             self,
-            prompt_speech_tokens: torch.Tensor,
-            prompt_speech_feat: torch.Tensor,
-            prompt_spk_embedding: torch.Tensor,
             target_speech_tokens: torch.Tensor,
             request_id: str,
+            prompt_speech_tokens: torch.Tensor = None,
+            prompt_speech_feat: torch.Tensor = None,
+            prompt_spk_embedding: torch.Tensor = None,
             token_offset: int = None,
             finalize: bool = None) -> torch.Tensor:
         """Forward pass through the vocoder component.
@@ -238,12 +245,9 @@ class TritonPythonModel:
         Returns:
             Generated waveform tensor
         """
-        prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
-        prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
-        prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
 
-        inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+        inputs_tensor = [target_speech_tokens_tensor]
 
         if token_offset is not None:
             assert finalize is not None
@@ -252,6 +256,13 @@ class TritonPythonModel:
             inputs_tensor.append(token_offset_tensor)
             inputs_tensor.append(finalize_tensor)
 
+        if prompt_spk_embedding is not None:
+            assert prompt_speech_feat is not None
+            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
+            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
+            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
@@ -318,25 +329,30 @@ class TritonPythonModel:
             request_id = request.request_id()
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
-            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
 
             # Process reference audio through audio tokenizer
-
-            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-
-            wav_tensor = wav.as_numpy()
-            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-            speech_feat = self._extract_speech_feat(prompt_speech_resample)
-            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-
-            flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
-
-            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
-            reference_text = reference_text[0][0].decode('utf-8')
+            if wav is not None:
+                wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+                prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+                prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+                wav_tensor = wav.as_numpy()
+                wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+                prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+                speech_feat = self._extract_speech_feat(prompt_speech_resample)
+                token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+                prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+                prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+                reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+                reference_text = reference_text[0][0].decode('utf-8')
+                prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+            else:
+                # using pre-cached reference text
+                reference_text = self.default_spk_info["prompt_text"]
+                prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+                prompt_speech_feat = None
+                prompt_spk_embedding = None
 
             target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
             target_text = target_text[0][0].decode('utf-8')
@@ -350,7 +366,6 @@ class TritonPythonModel:
 
             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)
-            prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
 
             if self.decoupled:
                 response_sender = request.get_response_sender()
@@ -380,8 +395,9 @@ class TritonPythonModel:
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
 
                         sub_tts_speech = self.forward_token2wav(
-                            prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding,
-                            this_tts_speech_token, request_id, token_offset, False)
+                            this_tts_speech_token, request_id, prompt_speech_tokens,
+                            prompt_speech_feat, prompt_spk_embedding, token_offset, False
+                        )
 
                         audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                         inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
@@ -414,7 +430,7 @@ class TritonPythonModel:
                         time.sleep(0.02)
 
                 this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
+                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
@@ -428,7 +444,7 @@ class TritonPythonModel:
                 if generated_ids is None or len(generated_ids) == 0:
                     raise pb_utils.TritonModelException("Generated IDs is None or empty")
 
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)
+                audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
 
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))

+ 3 - 0
runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt

@@ -37,16 +37,19 @@ input [
     name: "reference_wav"
     data_type: TYPE_FP32
     dims: [-1]
+    optional: true
   },
   {
     name: "reference_wav_len"
     data_type: TYPE_INT32
     dims: [1]
+    optional: true
   },
   {
     name: "reference_text"
     data_type: TYPE_STRING
     dims: [1]
+    optional: true
   },
   {
     name: "target_text"

+ 20 - 8
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -187,6 +187,12 @@ class TritonPythonModel:
             model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
         )
 
+        spk_info_path = os.path.join(model_dir, "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_dir}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        self.default_spk_info = spk_info["001"]
+
         logger.info("Token2Wav initialized successfully")
 
     def execute(self, requests):
@@ -202,17 +208,23 @@ class TritonPythonModel:
         # Process each request in batch
         for request in requests:
             target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
-            prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
-            prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
-
             target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
-            prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
-            prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
-            prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
+
+            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
+            if prompt_speech_tokens_tensor is not None:
+                prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
+                prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
+                prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
+                prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
+                prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
+                prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
+                prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
+            else:
+                prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
+                prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
+                prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
 
             # shift the speech tokens according to the original vocab size
-            prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
 
             # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.

+ 3 - 0
runtime/triton_trtllm/model_repo/token2wav/config.pbtxt

@@ -35,16 +35,19 @@ input [
     name: "prompt_speech_tokens"
     data_type: TYPE_INT32
     dims: [-1]
+    optional: true
   },
   {
     name: "prompt_speech_feat"
     data_type: TYPE_FP16
     dims: [-1, 80]
+    optional: true
   },
   {
     name: "prompt_spk_embedding"
     data_type: TYPE_FP16
     dims: [-1]
+    optional: true
   },
   {
     name: "token_offset"

+ 15 - 7
runtime/triton_trtllm/run.sh

@@ -15,6 +15,8 @@ trt_engines_dir=./trt_engines_${trt_dtype}
 
 model_repo=./model_repo_cosyvoice2
 
+use_spk2info_cache=True
+
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
     echo "Cloning CosyVoice"
     git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
@@ -27,6 +29,8 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     echo "Downloading CosyVoice2-0.5B"
     huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
     modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
+    # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
+    wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
 fi
 
 
@@ -57,10 +61,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     cosyvoice2_dir="cosyvoice2"
 
     cp -r ./model_repo/${cosyvoice2_dir} $model_repo
-    cp -r ./model_repo/audio_tokenizer $model_repo
     cp -r ./model_repo/tensorrt_llm $model_repo
     cp -r ./model_repo/token2wav $model_repo
-    cp -r ./model_repo/speaker_embedding $model_repo
+    if [ $use_spk2info_cache == "False" ]; then
+        cp -r ./model_repo/audio_tokenizer $model_repo
+        cp -r ./model_repo/speaker_embedding $model_repo
+    fi
 
     ENGINE_PATH=$trt_engines_dir
     MAX_QUEUE_DELAY_MICROSECONDS=0
@@ -71,11 +77,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     DECOUPLED_MODE=True # True for streaming, False for offline
 
     python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-    python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
     python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-    python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
     python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
-
+    if [ $use_spk2info_cache == "False" ]; then
+        python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+        python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    fi
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@@ -94,7 +101,7 @@ fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     echo "Running benchmark client grpc"
-    num_task=1
+    num_task=4
 
     mode=streaming
     BLS_INSTANCE_NUM=4
@@ -104,6 +111,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
         --model-name cosyvoice2 \
         --num-tasks $num_task \
         --mode $mode \
+        --use-spk2info-cache $use_spk2info_cache \
         --huggingface-dataset yuekai/seed_tts_cosy2 \
-        --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
+        --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
 fi