Ver código fonte

Merge pull request #330 from hexisyztem/flow_tensorrt

Flow tensorrt
Xiang Lyu 1 ano atrás
pai
commit
d2dea3d928

+ 123 - 5
cosyvoice/bin/export_trt.py

@@ -1,8 +1,126 @@
-# TODO 跟export_jit一样的逻辑,完成flow部分的estimator的onnx导出。
-# tensorrt的安装方式,再这里写一下步骤提示如下,如果没有安装,那么不要执行这个脚本,提示用户先安装,不给选择
+# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+import sys
+
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+
 try:
     import tensorrt
 except ImportError:
-    print('step1, 下载\n step2. 解压,安装whl,')
-# 安装命令里tensosrt的根目录用环境变量导入,比如os.environ['tensorrt_root_dir']/bin/exetrace,然后python里subprocess里执行导出命令
-# 后面我会在run.sh里写好执行命令 tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
+    error_msg_zh = [
+        "step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x",
+        "step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl",
+        "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${TensorRT-Path}/lib/"
+    ]
+    print("\n".join(error_msg_zh))
+    sys.exit(1)
+
+import torch
+from cosyvoice.cli.cosyvoice import CosyVoice
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Export your model for deployment')
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='pretrained_models/CosyVoice-300M-SFT',
+                        help='Local path to the model directory')
+
+    parser.add_argument('--export_half',
+                        action='store_true',
+                        help='Export with half precision (FP16)')
+    
+    args = parser.parse_args()
+    print(args)
+    return args
+
+def main():
+    args = get_args()
+
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
+    estimator = cosyvoice.model.flow.decoder.estimator
+
+    dtype = torch.float32 if not args.export_half else torch.float16
+    device = torch.device("cuda")
+    batch_size = 1
+    seq_len = 256
+    hidden_size = cosyvoice.model.flow.output_size
+    x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
+    mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device)
+    mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
+    t = torch.rand((batch_size, ), dtype=dtype, device=device)
+    spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device)
+    cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
+
+    onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx'
+    onnx_file_path = os.path.join(args.model_dir, onnx_file_name)
+    dummy_input = (x, mask, mu, t, spks, cond)
+
+    estimator = estimator.to(dtype)
+
+    torch.onnx.export(
+        estimator,
+        dummy_input,
+        onnx_file_path,
+        export_params=True,
+        opset_version=18,
+        do_constant_folding=True,
+        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+        output_names=['estimator_out'],
+        dynamic_axes={
+            'x': {2: 'seq_len'},
+            'mask': {2: 'seq_len'},
+            'mu': {2: 'seq_len'},
+            'cond': {2: 'seq_len'},
+            'estimator_out': {2: 'seq_len'},
+        }
+    )
+
+    tensorrt_path = os.environ.get('tensorrt_root_dir')
+    if not tensorrt_path:
+        raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.")
+
+    if not os.path.isdir(tensorrt_path):
+        raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.")
+
+    trt_lib_path = os.path.join(tensorrt_path, "lib")
+    if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''):
+        print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.")
+        os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}"
+
+    trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan'
+    trt_file_path = os.path.join(args.model_dir, trt_file_name)
+
+    trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec')
+    trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_path} " \
+                  "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \
+                  "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \
+                  ("--fp16" if args.export_half else "")
+    
+    print("execute ", trtexec_cmd)
+
+    os.system(trtexec_cmd)
+
+    # print("x.shape", x.shape)
+    # print("mask.shape", mask.shape)
+    # print("mu.shape", mu.shape)
+    # print("t.shape", t.shape)
+    # print("spks.shape", spks.shape)
+    # print("cond.shape", cond.shape)
+
+if __name__ == "__main__":
+    main()

+ 5 - 1
cosyvoice/cli/cosyvoice.py

@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True):
+    def __init__(self, model_dir, load_jit=True, load_trt=True, use_fp16=False):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -39,9 +39,13 @@ class CosyVoice:
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+                        
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                     '{}/llm.llm.fp16.zip'.format(model_dir))
+        if load_trt:
+            self.model.load_trt(model_dir, use_fp16)
+            
         del configs
 
     def list_avaliable_spks(self):

+ 17 - 1
cosyvoice/cli/model.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import numpy as np
 import threading
@@ -19,7 +20,6 @@ from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
 
-
 class CosyVoiceModel:
 
     def __init__(self,
@@ -66,6 +66,22 @@ class CosyVoiceModel:
         llm_llm = torch.jit.load(llm_llm_model)
         self.llm.llm = llm_llm
 
+    def load_trt(self, model_dir, use_fp16):
+        import tensorrt as trt
+        trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
+        trt_file_path = os.path.join(model_dir, trt_file_name)
+        if not os.path.isfile(trt_file_path):
+            raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
+
+        trt.init_libnvinfer_plugins(None, "")
+        logger = trt.Logger(trt.Logger.WARNING)
+        runtime = trt.Runtime(logger)
+        with open(trt_file_path, 'rb') as f:
+            serialized_engine = f.read()
+        engine = runtime.deserialize_cuda_engine(serialized_engine)
+        self.flow.decoder.estimator_context = engine.create_execution_context()
+        self.flow.decoder.estimator = None
+
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),

+ 1 - 1
cosyvoice/flow/decoder.py

@@ -159,7 +159,7 @@ class ConditionalDecoder(nn.Module):
             _type_: _description_
         """
 
-        t = self.time_embeddings(t)
+        t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)
 
         x = pack([x, mu], "b * t")[0]

+ 1 - 1
cosyvoice/flow/flow.py

@@ -113,7 +113,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # concat text and prompt_text
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
-        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
+        mask = (~make_pad_mask(token_len)).to(embedding.dtype).unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode

+ 29 - 1
cosyvoice/flow/flow_matching.py

@@ -50,7 +50,7 @@ class ConditionalCFM(BASECFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
         z = torch.randn_like(mu) * temperature
-        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
@@ -71,6 +71,7 @@ class ConditionalCFM(BASECFM):
             cond: Not used but kept for future purposes
         """
         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+        t = t.unsqueeze(dim=0)
 
         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
         # Or in future might add like a return_all_steps flag
@@ -96,6 +97,33 @@ class ConditionalCFM(BASECFM):
 
         return sol[-1]
 
+    def forward_estimator(self, x, mask, mu, t, spks, cond):
+
+        if self.estimator is not None:
+            return self.estimator.forward(x, mask, mu, t, spks, cond)
+        else:
+            assert self.training is False, 'tensorrt cannot be used in training'
+            bs = x.shape[0]
+            hs = x.shape[1]
+            seq_len = x.shape[2]
+            # assert bs == 1 and hs == 80
+            ret = torch.empty_like(x)
+            self.estimator_context.set_input_shape("x", x.shape)
+            self.estimator_context.set_input_shape("mask", mask.shape)
+            self.estimator_context.set_input_shape("mu", mu.shape)
+            self.estimator_context.set_input_shape("t", t.shape)
+            self.estimator_context.set_input_shape("spks", spks.shape)
+            self.estimator_context.set_input_shape("cond", cond.shape)
+            bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
+            names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out']
+            
+            for i in range(len(bindings)):
+                self.estimator_context.set_tensor_address(names[i], bindings[i])
+
+            handle = torch.cuda.current_stream().cuda_stream
+            self.estimator_context.execute_async_v3(stream_handle=handle)
+            return ret
+
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss