فهرست منبع

Update estimator count retrieval and memory pool limit in CosyVoice

- Simplified estimator count retrieval in CosyVoice and CosyVoice2 classes to directly access the configs dictionary.
- Adjusted memory pool limit in the ONNX to TensorRT conversion function from 8GB to 1GB for optimized resource management.
禾息 1 سال پیش
والد
کامیت
369f3c2c18
2فایلهای تغییر یافته به همراه3 افزوده شده و 3 حذف شده
  1. 2 2
      cosyvoice/cli/cosyvoice.py
  2. 1 1
      cosyvoice/utils/file_utils.py

+ 2 - 2
cosyvoice/cli/cosyvoice.py

@@ -54,7 +54,7 @@ class CosyVoice:
                                 '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.estimator_count = configs['flow']['decoder']['estimator'].get('estimator_count', 1)
+            self.estimator_count = configs.get('estimator_count', 1)
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                 self.fp16, self.estimator_count)
@@ -180,7 +180,7 @@ class CosyVoice2(CosyVoice):
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.estimator_count = configs['flow']['decoder']['estimator'].get('estimator_count', 1)
+            self.estimator_count = configs.get('estimator_count', 1)
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                 self.fp16, self.estimator_count)

+ 1 - 1
cosyvoice/utils/file_utils.py

@@ -61,7 +61,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
     network = builder.create_network(network_flags)
     parser = trt.OnnxParser(network, logger)
     config = builder.create_builder_config()
-    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33)  # 8GB
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB
     if fp16:
         config.set_flag(trt.BuilderFlag.FP16)
     profile = builder.create_optimization_profile()