Selaa lähdekoodia

Merge pull request #1325 from FunAudioLLM/dev/lyuxiang.lx

fix trt wrapper bug
Xiang Lyu 6 kuukautta sitten
vanhempi
commit
4159a18469

+ 10 - 1
cosyvoice/cli/model.py

@@ -59,6 +59,9 @@ class CosyVoiceModel:
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        for _ in range(trt_concurrent):
+            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -66,6 +69,7 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.trt_context_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -176,11 +180,13 @@ class CosyVoiceModel:
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
+        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
+            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -234,6 +240,8 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
+            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
+            self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()
@@ -324,10 +332,11 @@ class CosyVoice2Model(CosyVoiceModel):
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
+        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
-            self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
+            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:

+ 3 - 0
cosyvoice/utils/mask.py

@@ -230,6 +230,9 @@ def add_optional_chunk_mask(xs: torch.Tensor,
     else:
         chunk_masks = masks
     assert chunk_masks.dtype == torch.bool
+    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
+        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
     return chunk_masks
 
 

+ 1 - 1
examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

@@ -15,7 +15,7 @@ token_mel_ratio: 2
 
 # stream related params
 chunk_size: 25 # streaming inference chunk size, in token
-num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
+num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
 
 # model params
 # for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.