# convert_checkpoint.py

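"""Convert a Hugging Face Qwen checkpoint into the TensorRT-LLM checkpoint format.

Supports optional weight-only (INT8/INT4), SmoothQuant, and GPTQ/AWQ quantization,
an INT8 KV cache, and tensor / pipeline / context / expert parallelism, as
reflected in the CLI flags defined below.
"""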
import argparse
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

from transformers import AutoConfig

import tensorrt_llm
from tensorrt_llm._utils import release_gc
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import QWenForCausalLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None, required=True)
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument('--cp_size',
                        type=int,
                        default=1,
                        help='N-way context parallelism size')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'float16', 'bfloat16', 'float32'],
        help="The data type for the model weights and activations if not quantized. "
        "If 'auto', the data type is automatically inferred from the source model; "
        "however, if the source dtype is float32, it is converted to float16.")
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
    parser.add_argument(
        '--disable_weight_only_quant_plugin',
        default=False,
        action="store_true",
        help='By default, the plugin implementation is used for weight-only quantization. '
        'Enable this flag to use the out-of-the-box (OOTB) implementation instead of the plugin. '
        'You must also pass --use_weight_only for this argument to have an effect.')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4', 'int4_gptq'],
        help='Define the precision for the weights when using weight-only quantization. '
        'You must also pass --use_weight_only for this argument to have an effect.')
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='ccdv/cnn_dailymail',
        help="The Hugging Face dataset name or the local directory of the dataset used for calibration.")
    parser.add_argument(
        "--smoothquant",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to SmoothQuant the model, and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1].")
    parser.add_argument(
        '--per_channel',
        action="store_true",
        default=False,
        help='By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        action="store_true",
        default=False,
        help='By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token instead chooses a custom scaling factor for each token at run time. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action="store_true",
        help='By default, we use dtype for the KV cache. int8_kv_cache chooses int8 quantization for the KV cache.')
    parser.add_argument(
        '--per_group',
        default=False,
        action="store_true",
        help='By default, we use a single static scaling factor to scale weights in the int4 range. '
        'per_group instead chooses a custom scaling factor for each group at run time. '
        'This flag is built for GPTQ/AWQ quantization.')
    parser.add_argument('--group_size',
                        type=int,
                        default=128,
                        help='Group size used in GPTQ quantization.')
    parser.add_argument("--load_model_on_cpu", action="store_true")
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help='By default, embedding parallelism is disabled. Set this flag to enable it.')
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help='By default, the embedding lookup table is sharded along the vocab dimension (embedding_sharding_dim=0). '
        'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0.')
    parser.add_argument('--output_dir',
                        type=str,
                        default='tllm_checkpoint',
                        help='The path to save the TensorRT-LLM checkpoint')
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='The number of workers used to convert the checkpoint in parallel')
    parser.add_argument(
        '--moe_tp_size',
        type=int,
        default=-1,
        help='N-way tensor parallelism size for MoE; defaults to tp_size, which uses TP only for MoE.')
    parser.add_argument(
        '--moe_ep_size',
        type=int,
        default=-1,
        help='N-way expert parallelism size for MoE; defaults to 1, which uses TP only for MoE.')
    args = parser.parse_args()
    return args
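

# Summary of how the CLI flags map to the QuantAlgo chosen in args_to_quant_config():
#   --use_weight_only --weight_only_precision int8   -> W8A16
#   --use_weight_only --weight_only_precision int4   -> W4A16
#   --weight_only_precision int4_gptq                -> W4A16_GPTQ (with group_size, zero point)
#   --smoothquant <alpha> [--per_channel] [--per_token] -> one of the W8A8_SQ_* plugin algos
#   --int8_kv_cache                                  -> INT8 KV cache quantization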
def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
    '''Build a QuantConfig with quantization info based on the command line args.'''
    quant_config = QuantConfig()
    if args.use_weight_only:
        if args.weight_only_precision == 'int8':
            quant_config.quant_algo = QuantAlgo.W8A16
        elif args.weight_only_precision == 'int4':
            quant_config.quant_algo = QuantAlgo.W4A16
    elif args.smoothquant:
        quant_config.smoothquant_val = args.smoothquant
        if args.per_channel:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        else:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN

    if args.int8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.INT8

    if args.weight_only_precision == 'int4_gptq':
        quant_config.group_size = args.group_size
        quant_config.has_zero_point = True
        quant_config.pre_quant_scale = False
        quant_config.quant_algo = QuantAlgo.W4A16_GPTQ

    return quant_config
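

# update_quant_config_from_hf() overrides the CLI-derived QuantConfig with any
# AWQ/GPTQ `quantization_config` found in the Hugging Face model config.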
def update_quant_config_from_hf(quant_config, hf_config,
                                override_fields) -> tuple[QuantConfig, dict]:
    hf_config_dict = hf_config.to_dict()
    if hf_config_dict.get('quantization_config'):
        # update the quant_algo, and clamp_val.
        if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict['quantization_config'].get(
                'group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('zero_point', False)
            override_fields.update({"use_autoawq": True})
        elif hf_config_dict['quantization_config'].get(
                'quant_method') == 'gptq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            desc_act = hf_config_dict['quantization_config'].get(
                'desc_act', False)
            if desc_act:
                raise ValueError("GPTQ with desc_act=True is not implemented!")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict['quantization_config'].get(
                'group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('sym', False)
    return quant_config, override_fields


def args_to_build_options(args):
    return {
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'disable_weight_only_quant_plugin':
        args.disable_weight_only_quant_plugin
    }


def convert_and_save_hf(args):
    model_dir = args.model_dir
    world_size = args.tp_size * args.pp_size
    # Need to convert the CLI args to key-value pairs and override them in the generated config dict.
    # Ideally these fields would be moved out of the config and passed into the build API directly;
    # they are kept here for compatibility until that refactor is done.
    override_fields = {}
    override_fields.update(args_to_build_options(args))
    quant_config = args_to_quant_config(args)

    try:
        hf_config = AutoConfig.from_pretrained(model_dir,
                                               trust_remote_code=True)
        quant_config, override_fields = update_quant_config_from_hf(
            quant_config, hf_config, override_fields)
    except BaseException:
        logger.warning("AutoConfig cannot load the huggingface config.")

    if args.smoothquant is not None or args.int8_kv_cache:
        mapping = Mapping(world_size=world_size,
                          tp_size=args.tp_size,
                          pp_size=args.pp_size,
                          moe_tp_size=args.moe_tp_size,
                          moe_ep_size=args.moe_ep_size,
                          cp_size=args.cp_size)
        QWenForCausalLM.quantize(args.model_dir,
                                 args.output_dir,
                                 dtype=args.dtype,
                                 mapping=mapping,
                                 quant_config=quant_config,
                                 calib_dataset=args.calib_dataset,
                                 **override_fields)
    else:

        def convert_and_save_rank(args, rank):
            mapping = Mapping(world_size=world_size,
                              rank=rank,
                              tp_size=args.tp_size,
                              pp_size=args.pp_size,
                              moe_tp_size=args.moe_tp_size,
                              moe_ep_size=args.moe_ep_size)
            qwen = QWenForCausalLM.from_hugging_face(model_dir,
                                                     args.dtype,
                                                     mapping=mapping,
                                                     quant_config=quant_config,
                                                     **override_fields)
            qwen.config.mapping.cp_size = args.cp_size
            qwen.config.mapping.attn_tp_size = -1
            qwen.config.mapping.attn_cp_size = -1
            qwen.config.mapping.world_size *= args.cp_size
            qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
            del qwen

        execute(args.workers, [convert_and_save_rank] * world_size, args)
        release_gc()


def execute(workers, func, args):
    if workers == 1:
        for rank, f in enumerate(func):
            f(args, rank)
    else:
        with ThreadPoolExecutor(max_workers=workers) as p:
            futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
            exceptions = []
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    traceback.print_exc()
                    exceptions.append(e)
            assert len(
                exceptions
            ) == 0, "Checkpoint conversion failed, please check error log."


def main():
    print(tensorrt_llm.__version__)
    args = parse_arguments()

    if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
        # MoE defaults to tp-only
        args.moe_tp_size = args.tp_size
        args.moe_ep_size = 1
    elif (args.moe_tp_size == -1):
        args.moe_tp_size = args.tp_size // args.moe_ep_size
    elif (args.moe_ep_size == -1):
        args.moe_ep_size = args.tp_size // args.moe_tp_size
    assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
            ), "moe_tp_size * moe_ep_size must equal tp_size"

    tik = time.time()
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    assert args.model_dir is not None
    convert_and_save_hf(args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


if __name__ == '__main__':
    main()
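
# Example invocation (model and output paths are illustrative, not from the source):
#   python convert_checkpoint.py --model_dir <hf_qwen_model_dir> \
#       --output_dir ./tllm_checkpoint --dtype float16 --tp_size 2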