lyuxiang.lx 11 months ago
parent
commit
82219cdd27

+ 2 - 2
cosyvoice/cli/cosyvoice.py

@@ -26,7 +26,7 @@ from cosyvoice.utils.class_utils import get_model_type
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -48,7 +48,7 @@ class CosyVoice:
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))

+ 1 - 0
examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

@@ -158,6 +158,7 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
+    token_mel_ratio: 2
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480

+ 2 - 2
runtime/python/grpc/server.py

@@ -34,10 +34,10 @@ logging.basicConfig(level=logging.DEBUG,
 class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
     def __init__(self, args):
         try:
-            self.cosyvoice = CosyVoice(args.model_dir)
+            self.cosyvoice = CosyVoice(args.model_dir, trt_concurrent=args.max_conc)
         except Exception:
             try:
-                self.cosyvoice = CosyVoice2(args.model_dir)
+                self.cosyvoice = CosyVoice2(args.model_dir, trt_concurrent=args.max_conc)
             except Exception:
                 raise TypeError('no valid model_type!')
         logging.info('grpc service initialized')