Browse Source

fix triton token2wav model cache thread unsafety

김의진 6 months ago
parent
commit
cd26dd1932
1 changed files with 4 additions and 3 deletions
  1. 4 3
      runtime/triton_trtllm/model_repo/token2wav/1/model.py

+ 4 - 3
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -28,6 +28,7 @@ import json
 import os
 
 import logging
+from uuid import uuid4
 
 import torch
 from torch.utils.dlpack import to_dlpack
@@ -235,17 +236,17 @@ class TritonPythonModel:
                     stream = True
                 else:
                     stream = False
-                request_id = request.request_id()
+                uuid = uuid4().hex
                 audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
                                                                  prompt_token=prompt_speech_tokens,
                                                                  prompt_feat=prompt_speech_feat,
                                                                  embedding=prompt_spk_embedding,
                                                                  token_offset=token_offset,
-                                                                 uuid=request_id,
+                                                                 uuid=uuid,
                                                                  stream=stream,
                                                                  finalize=finalize)
                 if finalize:
-                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
+                    self.token2wav_model.model.hift_cache_dict.pop(uuid)
 
             else:
                 tts_mel, _ = self.token2wav_model.model.flow.inference(