file_utils.py

# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
#               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import torch
import torchaudio
import logging

logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(message)s')


def read_lists(list_file):
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            lists.append(line.strip())
    return lists


def read_json_lists(list_file):
    lists = read_lists(list_file)
    results = {}
    for fn in lists:
        with open(fn, 'r', encoding='utf8') as fin:
            results.update(json.load(fin))
    return results
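# Example (hypothetical paths): the list file names one entry per line, e.g. a text
# file containing "data/train/utt2info.json" and "data/dev/utt2info.json";
# read_json_lists() then merges the dicts loaded from every listed JSON file.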


def load_wav(wav, target_sr, min_sr=16000):
    speech, sample_rate = torchaudio.load(wav, backend='soundfile')
    speech = speech.mean(dim=0, keepdim=True)
    if sample_rate != target_sr:
        assert sample_rate >= min_sr, 'wav sample rate {} must be at least {}'.format(sample_rate, min_sr)
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
    return speech
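# Example (hypothetical file/rate): load a clip, downmix it to mono and resample it
# to the model's rate, e.g. speech = load_wav('prompt.wav', target_sr=24000),
# which returns a float tensor of shape (1, num_samples).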


def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
    """Parse an ONNX model, build a serialized TensorRT engine and write it to trt_model."""
    import tensorrt as trt
    logging.info("Converting onnx to trt...")
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    profile = builder.create_optimization_profile()
    # load onnx model
    with open(onnx_model, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            raise ValueError('failed to parse {}'.format(onnx_model))
    # set input shapes
    for i in range(len(trt_kwargs['input_names'])):
        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
    # set input and output data type
    for i in range(network.num_inputs):
        input_tensor = network.get_input(i)
        input_tensor.dtype = tensor_dtype
    for i in range(network.num_outputs):
        output_tensor = network.get_output(i)
        output_tensor.dtype = tensor_dtype
    config.add_optimization_profile(profile)
    engine_bytes = builder.build_serialized_network(network, config)
    # save trt engine
    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logging.info("Successfully converted onnx to trt.")


# NOTE: the exported model does not support bistream inference, as only the speech token embedding/head is kept
def export_cosyvoice2_vllm(model, model_path, device):
    """Save a vLLM-compatible checkpoint whose lm_head/embeddings are the speech token head/embedding."""
    if os.path.exists(model_path):
        return
    dtype = torch.bfloat16
    # lm_head
    use_bias = model.llm_decoder.bias is not None
    model.llm.model.lm_head = model.llm_decoder
    # embed_tokens
    embed_tokens = model.llm.model.model.embed_tokens
    model.llm.model.set_input_embeddings(model.speech_embedding)
    model.llm.model.to(device)
    model.llm.model.to(dtype)
    tmp_vocab_size = model.llm.model.config.vocab_size
    tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
    del model.llm.model.generation_config.eos_token_id
    del model.llm.model.config.bos_token_id
    del model.llm.model.config.eos_token_id
    model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
    model.llm.model.config.tie_word_embeddings = False
    model.llm.model.config.use_bias = use_bias
    model.llm.model.save_pretrained(model_path)
    if use_bias:
        os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
    # restore the original config and embeddings so the in-memory model is left unchanged
    model.llm.model.config.vocab_size = tmp_vocab_size
    model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
    model.llm.model.set_input_embeddings(embed_tokens)
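# Example (hypothetical call site): given a loaded CosyVoice2 LM wrapper that exposes
# llm_decoder, speech_embedding and llm.model, export it once before serving with vLLM:
# export_cosyvoice2_vllm(cosyvoice2_llm, 'pretrained_models/CosyVoice2-0.5B/vllm', torch.device('cuda'))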