@@ -59,6 +59,9 @@ class CosyVoiceModel:
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        for _ in range(trt_concurrent):
+            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
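The hunk above pre-fills a bounded queue with one CUDA stream context per allowed concurrent TensorRT inference, so parallel sessions do not all serialize on the default stream. A minimal, self-contained sketch of that pool pattern, using a hypothetical StreamPool class and a `concurrency` argument standing in for `trt_concurrent` (illustration only, not repo code):

    import queue
    from contextlib import nullcontext

    import torch

    class StreamPool:
        """Hypothetical illustration of the diff's stream pool, not repo code."""

        def __init__(self, concurrency: int, device: str = 'cuda'):
            # Bounded queue: at most `concurrency` TRT inferences run at once,
            # each inside its own torch.cuda stream context.
            self._pool = queue.Queue(maxsize=concurrency)
            for _ in range(concurrency):
                ctx = torch.cuda.stream(torch.cuda.Stream(device)) if torch.cuda.is_available() else nullcontext()
                self._pool.put(ctx)

        def acquire(self):
            return self._pool.get()   # blocks until a stream is free

        def release(self, ctx):
            self._pool.put(ctx)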
@@ -66,6 +69,7 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.trt_context_dict = {}

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -176,11 +180,13 @@ class CosyVoiceModel:
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
+        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
+            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
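In this hunk the pool checkout (`self.trt_context_pool.get()`) deliberately happens before `with self.lock:`. `Queue.get()` blocks whenever all `trt_concurrent` contexts are checked out, and blocking while holding the session lock would stall every other session (the CosyVoice2Model hunk below fixes exactly that ordering). The pattern, reduced to its essentials with hypothetical names:

    ctx = pool.acquire()        # may block, so do it BEFORE taking the lock
    with registry_lock:         # lock is held only for cheap dict updates
        sessions[session_id] = ctx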
@@ -234,6 +240,8 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
+            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
+            self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()
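Cleanup returns the session's stream context to the pool before dropping its bookkeeping entry, which unblocks any thread waiting in `get()`. Note that the return happens only on the normal exit path; a `try`/`finally` variant (a hardening sketch under assumed names, not part of this diff) would also release the context if inference raises:

    ctx = pool.acquire()
    try:
        with registry_lock:
            sessions[session_id] = ctx
        run_inference(ctx)      # hypothetical stand-in for the tts body
    finally:
        with registry_lock:
            pool.release(sessions.pop(session_id))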
@@ -324,10 +332,11 @@ class CosyVoice2Model(CosyVoiceModel):
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
+        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
-            self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
+            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
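The `-`/`+` pair in this last hunk is the substantive fix: the old code called `self.trt_context_pool.get()` while holding `self.lock`, but the cleanup path above returns contexts to the pool under that same lock. Once the pool drains, a thread blocking in `get()` with the lock held prevents every finishing session from reaching `put()`, producing a circular wait between the lock and the bounded queue. A stripped-down illustration of the hazard (hypothetical names, not repo code):

    import queue
    import threading

    lock = threading.Lock()
    pool = queue.Queue(maxsize=1)   # pretend the only stream is checked out

    def start_session():
        with lock:
            ctx = pool.get()        # blocks while holding `lock`...

    def finish_session(ctx):
        with lock:                  # ...so the returning thread can never
            pool.put(ctx)           # reach put(): both threads wait forever

Fetching the context before entering the lock, as the `+` lines do, breaks the cycle.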