|
|
@@ -26,7 +26,7 @@ from cosyvoice.utils.class_utils import get_model_type
|
|
|
|
|
|
class CosyVoice:
|
|
|
|
|
|
- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
|
|
|
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
|
|
self.instruct = True if '-Instruct' in model_dir else False
|
|
|
self.model_dir = model_dir
|
|
|
self.fp16 = fp16
|
|
|
@@ -48,7 +48,7 @@ class CosyVoice:
|
|
|
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
|
|
load_jit, load_trt, fp16 = False, False, False
|
|
|
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
|
|
- self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
|
|
|
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
|
|
|
self.model.load('{}/llm.pt'.format(model_dir),
|
|
|
'{}/flow.pt'.format(model_dir),
|
|
|
'{}/hift.pt'.format(model_dir))
|