1
0
Selaa lähdekoodia

add onnx export

lyuxiang.lx 1 vuosi sitten
vanhempi
commit
2ce724045b

+ 8 - 1
cosyvoice/bin/export_jit.py

@@ -44,7 +44,7 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
 
     # 1. export llm text_encoder
     llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
@@ -60,5 +60,12 @@ def main():
     script = torch.jit.optimize_for_inference(script)
     script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
 
+    # 3. export flow encoder
+    flow_encoder = cosyvoice.model.flow.encoder
+    script = torch.jit.script(flow_encoder)
+    script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+
 if __name__ == '__main__':
     main()

+ 71 - 190
cosyvoice/bin/export_onnx.py

@@ -1,4 +1,5 @@
 # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,217 +13,97 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import print_function
+
 import argparse
 import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
 import sys
-
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
-import onnxruntime as ort
-import numpy as np
-
-# try:
-#     import tensorrt
-#     import tensorrt as trt
-# except ImportError:
-#     error_msg_zh = [
-#         "step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x",
-#         "step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl",
-#         "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${TensorRT-Path}/lib/"
-#     ]
-#     print("\n".join(error_msg_zh))
-#     sys.exit(1)
-
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+import onnxruntime
+import random
 import torch
+from tqdm import tqdm
 from cosyvoice.cli.cosyvoice import CosyVoice
 
 
-def calculate_onnx(onnx_file, x, mask, mu, t, spks, cond):
-    providers = ['CUDAExecutionProvider']
-    sess_options = ort.SessionOptions()
-
-    providers = [
-        'CUDAExecutionProvider'
-    ]
-
-    # Load the ONNX model
-    session = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
-    
-    x_np = x.cpu().numpy()
-    mask_np = mask.cpu().numpy()
-    mu_np = mu.cpu().numpy()
-    t_np = np.array(t.cpu()) 
-    spks_np = spks.cpu().numpy()
-    cond_np = cond.cpu().numpy()
-
-    ort_inputs = {
-        'x': x_np,
-        'mask': mask_np,
-        'mu': mu_np,
-        't': t_np,
-        'spks': spks_np,
-        'cond': cond_np
-    }
-
-    output = session.run(None, ort_inputs)
-
-    return output[0]
-
-# def calculate_tensorrt(trt_file, x, mask, mu, t, spks, cond):
-#     trt.init_libnvinfer_plugins(None, "")
-#     logger = trt.Logger(trt.Logger.WARNING)
-#     runtime = trt.Runtime(logger)
-#     with open(trt_file, 'rb') as f:
-#         serialized_engine = f.read()
-#     engine = runtime.deserialize_cuda_engine(serialized_engine)
-#     context = engine.create_execution_context()
-
-#     bs = x.shape[0]
-#     hs = x.shape[1]
-#     seq_len = x.shape[2]
-
-#     ret = torch.zeros_like(x)
-
-#     # Set input shapes for dynamic dimensions
-#     context.set_input_shape("x", x.shape)
-#     context.set_input_shape("mask", mask.shape)
-#     context.set_input_shape("mu", mu.shape)
-#     context.set_input_shape("t", t.shape)
-#     context.set_input_shape("spks", spks.shape)
-#     context.set_input_shape("cond", cond.shape)
-
-#     # bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
-#     # names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out']
-#     #
-#     # for i in range(len(bindings)):
-#     #     context.set_tensor_address(names[i], bindings[i])
-#     #
-#     # handle = torch.cuda.current_stream().cuda_stream
-#     # context.execute_async_v3(stream_handle=handle)
-
-#     # Create a list of bindings
-#     bindings = [int(x.data_ptr()), int(mask.data_ptr()), int(mu.data_ptr()), int(t.data_ptr()), int(spks.data_ptr()), int(cond.data_ptr()), int(ret.data_ptr())]
-
-#     # Execute the inference
-#     context.execute_v2(bindings=bindings)
-
-#     torch.cuda.synchronize()
-
-#     return ret
-
-
-# def test_calculate_value(estimator, onnx_file, trt_file, dummy_input, args):
-#     torch_output = estimator.forward(**dummy_input).cpu().detach().numpy()
-#     onnx_output = calculate_onnx(onnx_file, **dummy_input)
-#     tensorrt_output = calculate_tensorrt(trt_file, **dummy_input).cpu().detach().numpy()
-#     atol = 2e-3  # Absolute tolerance
-#     rtol = 1e-4  # Relative tolerance
-
-#     print(f"args.export_half: {args.export_half}, args.model_dir: {args.model_dir}")
-#     print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
-
-#     print("torch_output diff with onnx_output: ", )
-#     print(f"compare with atol: {atol}, rtol: {rtol} ", np.allclose(torch_output, onnx_output, atol, rtol))
-#     print(f"max diff value: ", np.max(np.fabs(torch_output - onnx_output)))
-#     print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
-
-#     print("torch_output diff with tensorrt_output: ")
-#     print(f"compare with atol: {atol}, rtol: {rtol} ", np.allclose(torch_output, tensorrt_output, atol, rtol))
-#     print(f"max diff value: ", np.max(np.fabs(torch_output - tensorrt_output)))
-#     print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
-
-#     print("onnx_output diff with tensorrt_output: ")
-#     print(f"compare with atol: {atol}, rtol: {rtol} ", np.allclose(onnx_output, tensorrt_output, atol, rtol))
-#     print(f"max diff value: ", np.max(np.fabs(onnx_output - tensorrt_output)))
-#     print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
+def get_dummy_input(batch_size, seq_len, out_channels, device):
+    x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+    mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    t = torch.rand((batch_size), dtype=torch.float32, device=device)
+    spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+    cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    return x, mask, mu, t, spks, cond
 
 
 def get_args():
-    parser = argparse.ArgumentParser(description='Export your model for deployment')
-    parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice-300M', help='Local path to the model directory')
-    parser.add_argument('--export_half', type=str, choices=['True', 'False'], default='False', help='Export with half precision (FP16)')
-    # parser.add_argument('--trt_max_len', type=int, default=8192, help='Export max len')
-    parser.add_argument('--exec_export', type=str, choices=['True', 'False'], default='True', help='Exec export')
-    
+    parser = argparse.ArgumentParser(description='export your model for deployment')
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='pretrained_models/CosyVoice-300M',
+                        help='local path')
     args = parser.parse_args()
-    args.export_half = args.export_half == 'True'
-    args.exec_export = args.exec_export == 'True'
-    print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
     print(args)
     return args
 
 def main():
     args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+
+    # 1. export flow decoder estimator
     estimator = cosyvoice.model.flow.decoder.estimator
 
-    dtype = torch.float32 if not args.export_half else torch.float16
-    device = torch.device("cuda")
-    batch_size = 1
-    seq_len = 256
+    device = cosyvoice.model.device
+    batch_size, seq_len = 1, 256
     out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
-    x = torch.rand((batch_size, out_channels, seq_len), dtype=dtype, device=device)
-    mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device)
-    mu = torch.rand((batch_size, out_channels, seq_len), dtype=dtype, device=device)
-    t = torch.rand((batch_size, ), dtype=dtype, device=device)
-    spks = torch.rand((batch_size, out_channels), dtype=dtype, device=device)
-    cond = torch.rand((batch_size, out_channels, seq_len), dtype=dtype, device=device)
-
-    onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx'
-    onnx_file_path = os.path.join(args.model_dir, onnx_file_name)
-    dummy_input = (x, mask, mu, t, spks, cond)
-
-    estimator = estimator.to(dtype)
-
-    if args.exec_export:
-        torch.onnx.export(
-            estimator,
-            dummy_input,
-            onnx_file_path,
-            export_params=True,
-            opset_version=18,
-            do_constant_folding=True,
-            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
-            output_names=['estimator_out'],
-            dynamic_axes={
-                'x': {2: 'seq_len'},
-                'mask': {2: 'seq_len'},
-                'mu': {2: 'seq_len'},
-                'cond': {2: 'seq_len'},
-                'estimator_out': {2: 'seq_len'},
-            }
-        )
-
-    # tensorrt_path = os.environ.get('tensorrt_root_dir')
-    # if not tensorrt_path:
-    #     raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.")
-
-    # if not os.path.isdir(tensorrt_path):
-    #     raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.")
-
-    # trt_lib_path = os.path.join(tensorrt_path, "lib")
-    # if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''):
-    #     print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.")
-    #     os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}"
-
-    # trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan'
-    # trt_file_path = os.path.join(args.model_dir, trt_file_name)
-
-    # trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec')
-    # trt_max_len = args.trt_max_len
-    # trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_path} " \
-    #               f"--minShapes=x:1x{out_channels}x1,mask:1x1x1,mu:1x{out_channels}x1,t:1,spks:1x{out_channels},cond:1x{out_channels}x1 " \
-    #               f"--maxShapes=x:1x{out_channels}x{trt_max_len},mask:1x1x{trt_max_len},mu:1x{out_channels}x{trt_max_len},t:1,spks:1x{out_channels},cond:1x{out_channels}x{trt_max_len} " + \
-    #               ("--fp16" if args.export_half else "")
-    
-    # print("execute ", trtexec_cmd)
-
-    # if args.exec_export:
-    #     os.system(trtexec_cmd)
-
-    # dummy_input = {'x': x, 'mask': mask, 'mu': mu, 't': t, 'spks': spks, 'cond': cond}
-    # test_calculate_value(estimator, onnx_file_path, trt_file_path, dummy_input, args)
+    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+    torch.onnx.export(
+        estimator,
+        (x, mask, mu, t, spks, cond),
+        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+        export_params=True,
+        opset_version=18,
+        do_constant_folding=True,
+        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+        output_names=['estimator_out'],
+        dynamic_axes={
+            'x': {0: 'batch_size', 2: 'seq_len'},
+            'mask': {0: 'batch_size', 2: 'seq_len'},
+            'mu': {0: 'batch_size', 2: 'seq_len'},
+            'cond': {0: 'batch_size', 2: 'seq_len'},
+            't': {0: 'batch_size'},
+            'spks': {0: 'batch_size'},
+            'estimator_out': {0: 'batch_size', 2: 'seq_len'},
+        }
+    )
+
+    # 2. test computation consistency
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    option.intra_op_num_threads = 1
+    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), sess_options=option, providers=providers)
+
+    for _ in tqdm(range(10)):
+        x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
+        output_pytorch = estimator(x, mask, mu, t, spks, cond)
+        ort_inputs = {
+            'x': x.cpu().numpy(),
+            'mask': mask.cpu().numpy(),
+            'mu': mu.cpu().numpy(),
+            't': t.cpu().numpy(),
+            'spks': spks.cpu().numpy(),
+            'cond': cond.cpu().numpy()
+        }
+        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+        torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
 
 if __name__ == "__main__":
     main()

+ 4 - 9
cosyvoice/cli/cosyvoice.py

@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True, load_trt=False, load_onnx=True, use_fp16=False):
+    def __init__(self, model_dir, load_jit=True, load_onnx=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -39,17 +39,12 @@ class CosyVoice:
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
-
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-                                    '{}/llm.llm.fp16.zip'.format(model_dir))
-
-        # if load_trt:
-        #     self.model.load_trt(model_dir, use_fp16)
-        
+                                    '{}/llm.llm.fp16.zip'.format(model_dir),
+                                    '{}/flow.encoder.fp32.zip'.format(model_dir))
         if load_onnx:
-            self.model.load_onnx(model_dir, use_fp16)
-            
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
         del configs
 
     def list_avaliable_spks(self):

+ 13 - 39
cosyvoice/cli/model.py

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import torch
 import numpy as np
 import threading
@@ -20,7 +19,6 @@ from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
 import numpy as np
-import onnxruntime as ort
 
 class CosyVoiceModel:
 
@@ -62,47 +60,22 @@ class CosyVoiceModel:
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()
 
-    def load_jit(self, llm_text_encoder_model, llm_llm_model):
+    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model)
         self.llm.text_encoder = llm_text_encoder
         llm_llm = torch.jit.load(llm_llm_model)
         self.llm.llm = llm_llm
+        flow_encoder = torch.jit.load(flow_encoder_model)
+        self.flow.encoder = flow_encoder
 
-    # def load_trt(self, model_dir, use_fp16):
-    #     import tensorrt as trt
-    #     trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
-    #     trt_file_path = os.path.join(model_dir, trt_file_name)
-    #     if not os.path.isfile(trt_file_path):
-    #         raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
-
-    #     trt.init_libnvinfer_plugins(None, "")
-    #     logger = trt.Logger(trt.Logger.WARNING)
-    #     runtime = trt.Runtime(logger)
-    #     with open(trt_file_path, 'rb') as f:
-    #         serialized_engine = f.read()
-    #     engine = runtime.deserialize_cuda_engine(serialized_engine)
-    #     self.flow.decoder.estimator_context = engine.create_execution_context()
-    #     self.flow.decoder.estimator = None
-    
-    def load_onnx(self, model_dir, use_fp16):
-        onnx_file_name = 'estimator_fp16.onnx' if use_fp16 else 'estimator_fp32.onnx'
-        onnx_file_path = os.path.join(model_dir, onnx_file_name)
-        if not os.path.isfile(onnx_file_path):
-            raise f"{onnx_file_path} does not exist. Please use bin/export_trt.py to generate .onnx file"
-        
-        providers = ['CUDAExecutionProvider']
-        sess_options = ort.SessionOptions()
-
-        # Add TensorRT Execution Provider
-        providers = [
-            'CUDAExecutionProvider'
-        ]
-
-        # Load the ONNX model
-        self.flow.decoder.session = ort.InferenceSession(onnx_file_path, sess_options=sess_options, providers=providers)
-        # self.flow.decoder.estimator_context = None
-        self.flow.decoder.estimator = None
-        
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
@@ -207,4 +180,5 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()

+ 10 - 43
cosyvoice/flow/flow_matching.py

@@ -31,8 +31,6 @@ class ConditionalCFM(BASECFM):
         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
         # Just change the architecture of the estimator here
         self.estimator = estimator
-        self.estimator_context = None # for tensorrt
-        self.session = None # for onnx
 
     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
@@ -82,10 +80,10 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         for step in range(1, len(t_span)):
-            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
+            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.estimator(
+                cfg_dphi_dt = self.forward_estimator(
                     x, mask,
                     torch.zeros_like(mu), t,
                     torch.zeros_like(spks) if spks is not None else None,
@@ -102,51 +100,20 @@ class ConditionalCFM(BASECFM):
         return sol[-1]
 
     def forward_estimator(self, x, mask, mu, t, spks, cond):
-
-        if self.estimator is not None:
+        if isinstance(self.estimator, torch.nn.Module):
             return self.estimator.forward(x, mask, mu, t, spks, cond)
-        # elif self.estimator_context is not None:
-        #     assert self.training is False, 'tensorrt cannot be used in training'
-        #     bs = x.shape[0]
-        #     hs = x.shape[1]
-        #     seq_len = x.shape[2]
-        #     # assert bs == 1 and hs == 80
-        #     ret = torch.empty_like(x)
-        #     self.estimator_context.set_input_shape("x", x.shape)
-        #     self.estimator_context.set_input_shape("mask", mask.shape)
-        #     self.estimator_context.set_input_shape("mu", mu.shape)
-        #     self.estimator_context.set_input_shape("t", t.shape)
-        #     self.estimator_context.set_input_shape("spks", spks.shape)
-        #     self.estimator_context.set_input_shape("cond", cond.shape)
-
-        #     # Create a list of bindings
-        #     bindings = [int(x.data_ptr()), int(mask.data_ptr()), int(mu.data_ptr()), int(t.data_ptr()), int(spks.data_ptr()), int(cond.data_ptr()), int(ret.data_ptr())]
-
-        #     # Execute the inference
-        #     self.estimator_context.execute_v2(bindings=bindings)
-        #     return ret
         else:
-            x_np = x.cpu().numpy()
-            mask_np = mask.cpu().numpy()
-            mu_np = mu.cpu().numpy()
-            t_np = t.cpu().numpy()
-            spks_np = spks.cpu().numpy()
-            cond_np = cond.cpu().numpy()
-
             ort_inputs = {
-                'x': x_np,
-                'mask': mask_np,
-                'mu': mu_np,
-                't': t_np,
-                'spks': spks_np,
-                'cond': cond_np
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy()
             }
-
-            output = self.session.run(None, ort_inputs)[0]
-
+            output = self.estimator.run(None, ort_inputs)[0]
             return torch.tensor(output, dtype=x.dtype, device=x.device)
 
-
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss
 

+ 1 - 0
requirements.txt

@@ -15,6 +15,7 @@ matplotlib==3.7.5
 modelscope==1.15.0
 networkx==3.1
 omegaconf==2.3.0
+onnx==1.16.0
 onnxruntime-gpu==1.16.0; sys_platform == 'linux'
 onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows'
 openai-whisper==20231117