1
0

export_onnx.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
  2. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import print_function
  16. import argparse
  17. import logging
  18. logging.getLogger('matplotlib').setLevel(logging.WARNING)
  19. import os
  20. import sys
  21. import onnxruntime
  22. import random
  23. import torch
  24. from tqdm import tqdm
  25. ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
  26. sys.path.append('{}/../..'.format(ROOT_DIR))
  27. sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
  28. from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
  29. from cosyvoice.utils.file_utils import logging
  30. def get_dummy_input(batch_size, seq_len, out_channels, device):
  31. x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
  32. mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
  33. mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
  34. t = torch.rand((batch_size), dtype=torch.float32, device=device)
  35. spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
  36. cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
  37. return x, mask, mu, t, spks, cond
  38. def get_args():
  39. parser = argparse.ArgumentParser(description='export your model for deployment')
  40. parser.add_argument('--model_dir',
  41. type=str,
  42. default='pretrained_models/CosyVoice-300M',
  43. help='local path')
  44. args = parser.parse_args()
  45. print(args)
  46. return args
  47. @torch.no_grad()
  48. def main():
  49. args = get_args()
  50. logging.basicConfig(level=logging.DEBUG,
  51. format='%(asctime)s %(levelname)s %(message)s')
  52. try:
  53. model = CosyVoice(args.model_dir)
  54. except Exception:
  55. try:
  56. model = CosyVoice2(args.model_dir)
  57. except Exception:
  58. raise TypeError('no valid model_type!')
  59. if not isinstance(model, CosyVoice2):
  60. # 1. export flow decoder estimator
  61. estimator = model.model.flow.decoder.estimator
  62. estimator.eval()
  63. device = model.model.device
  64. batch_size, seq_len = 2, 256
  65. out_channels = model.model.flow.decoder.estimator.out_channels
  66. x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
  67. torch.onnx.export(
  68. estimator,
  69. (x, mask, mu, t, spks, cond),
  70. '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
  71. export_params=True,
  72. opset_version=18,
  73. do_constant_folding=True,
  74. input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
  75. output_names=['estimator_out'],
  76. dynamic_axes={
  77. 'x': {2: 'seq_len'},
  78. 'mask': {2: 'seq_len'},
  79. 'mu': {2: 'seq_len'},
  80. 'cond': {2: 'seq_len'},
  81. 'estimator_out': {2: 'seq_len'},
  82. }
  83. )
  84. # 2. test computation consistency
  85. option = onnxruntime.SessionOptions()
  86. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  87. option.intra_op_num_threads = 1
  88. providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
  89. estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
  90. sess_options=option, providers=providers)
  91. for _ in tqdm(range(10)):
  92. x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
  93. output_pytorch = estimator(x, mask, mu, t, spks, cond)
  94. ort_inputs = {
  95. 'x': x.cpu().numpy(),
  96. 'mask': mask.cpu().numpy(),
  97. 'mu': mu.cpu().numpy(),
  98. 't': t.cpu().numpy(),
  99. 'spks': spks.cpu().numpy(),
  100. 'cond': cond.cpu().numpy()
  101. }
  102. output_onnx = estimator_onnx.run(None, ort_inputs)[0]
  103. torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
  104. logging.info('successfully export estimator')
  105. else:
  106. # 1. export flow decoder estimator
  107. estimator = model.model.flow.decoder.estimator
  108. estimator.forward = estimator.forward_chunk
  109. estimator.eval()
  110. device = model.model.device
  111. batch_size, seq_len = 2, 256
  112. out_channels = model.model.flow.decoder.estimator.out_channels
  113. x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
  114. cache = model.model.init_flow_cache()['decoder_cache']
  115. cache.pop('offset')
  116. cache = {k: v[0] for k, v in cache.items()}
  117. torch.onnx.export(
  118. estimator,
  119. (x, mask, mu, t, spks, cond,
  120. cache['down_blocks_conv_cache'],
  121. cache['down_blocks_kv_cache'],
  122. cache['mid_blocks_conv_cache'],
  123. cache['mid_blocks_kv_cache'],
  124. cache['up_blocks_conv_cache'],
  125. cache['up_blocks_kv_cache'],
  126. cache['final_blocks_conv_cache']),
  127. '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
  128. export_params=True,
  129. opset_version=18,
  130. do_constant_folding=True,
  131. input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
  132. 'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
  133. output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
  134. 'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
  135. dynamic_axes={
  136. 'x': {2: 'seq_len'},
  137. 'mask': {2: 'seq_len'},
  138. 'mu': {2: 'seq_len'},
  139. 'cond': {2: 'seq_len'},
  140. 'down_blocks_kv_cache': {3: 'cache_in_len'},
  141. 'mid_blocks_kv_cache': {3: 'cache_in_len'},
  142. 'up_blocks_kv_cache': {3: 'cache_in_len'},
  143. 'estimator_out': {2: 'seq_len'},
  144. 'down_blocks_kv_cache_out': {3: 'cache_out_len'},
  145. 'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
  146. 'up_blocks_kv_cache_out': {3: 'cache_out_len'},
  147. }
  148. )
  149. # 2. test computation consistency
  150. option = onnxruntime.SessionOptions()
  151. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  152. option.intra_op_num_threads = 1
  153. providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
  154. estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
  155. sess_options=option, providers=providers)
  156. for _ in tqdm(range(10)):
  157. x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
  158. cache = model.model.init_flow_cache()['decoder_cache']
  159. cache.pop('offset')
  160. cache = {k: v[0] for k, v in cache.items()}
  161. output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
  162. ort_inputs = {
  163. 'x': x.cpu().numpy(),
  164. 'mask': mask.cpu().numpy(),
  165. 'mu': mu.cpu().numpy(),
  166. 't': t.cpu().numpy(),
  167. 'spks': spks.cpu().numpy(),
  168. 'cond': cond.cpu().numpy(),
  169. }
  170. output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
  171. for i, j in zip(output_pytorch, output_onnx):
  172. torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
  173. logging.info('successfully export estimator')
  174. if __name__ == "__main__":
  175. main()