convert_checkpoint.py

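"""Convert a Hugging Face Qwen checkpoint into a TensorRT-LLM checkpoint.

The script parses conversion options from the command line, derives the
quantization configuration (weight-only, SmoothQuant, GPTQ/AWQ, INT8 KV cache),
and writes one checkpoint shard per rank according to the requested
tensor/pipeline/context parallelism mapping.
"""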
import argparse
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

from transformers import AutoConfig

import tensorrt_llm
from tensorrt_llm._utils import release_gc
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import QWenForCausalLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None, required=True)
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument('--cp_size',
                        type=int,
                        default=1,
                        help='N-way context parallelism size')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'float16', 'bfloat16', 'float32'],
        help=
        "The data type for the model weights and activations if not quantized. "
        "If 'auto', the data type is automatically inferred from the source model; "
        "however, if the source dtype is float32, it is converted to float16.")
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize the weights of the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
    parser.add_argument(
        '--disable_weight_only_quant_plugin',
        default=False,
        action="store_true",
        help=
        'By default, the plugin implementation is used for weight-only quantization. '
        'Enabling this flag uses the OOTB implementation instead of the plugin. '
        'You must also pass --use_weight_only for this argument to have an effect.')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4', 'int4_gptq'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also pass --use_weight_only for this argument to have an effect.')
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='ccdv/cnn_dailymail',
        help=
        "The Hugging Face dataset name or the local directory of the dataset used for calibration."
    )
    parser.add_argument(
        "--smoothquant",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to SmoothQuant the model and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1].")
    parser.add_argument(
        '--per_channel',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        action="store_true",
        default=False,
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token instead chooses, at run time, a custom scaling factor for each token. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action="store_true",
        help=
        'By default, the model dtype is used for the KV cache. int8_kv_cache enables int8 '
        'quantization for the KV cache.')
    parser.add_argument(
        '--per_group',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor to scale weights in the int4 range. '
        'per_group instead chooses, at run time, a custom scaling factor for each group. '
        'The flag is built for GPTQ/AWQ quantization.')
    parser.add_argument('--group_size',
                        type=int,
                        default=128,
                        help='Group size used in GPTQ quantization.')
    parser.add_argument("--load_model_on_cpu", action="store_true")
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default, embedding parallelism is disabled. Setting this flag enables embedding parallelism.'
    )
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default, the embedding lookup table is sharded along the vocab dimension (embedding_sharding_dim=0). '
        'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim=0.')
    parser.add_argument('--output_dir',
                        type=str,
                        default='tllm_checkpoint',
                        help='The path to save the TensorRT-LLM checkpoint')
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='The number of workers for converting the checkpoint in parallel')
    parser.add_argument(
        '--moe_tp_size',
        type=int,
        default=-1,
        help=
        'N-way tensor parallelism size for MoE; defaults to tp_size, which does tp-only for MoE.'
    )
    parser.add_argument(
        '--moe_ep_size',
        type=int,
        default=-1,
        help=
        'N-way expert parallelism size for MoE; defaults to 1, which does tp-only for MoE.'
    )
    args = parser.parse_args()
    return args
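# Example usage of the CLI defined in parse_arguments() (the model path and output
# directory below are placeholders; adjust them for your setup):
#   python convert_checkpoint.py --model_dir ./Qwen2-7B-Instruct \
#       --output_dir ./tllm_checkpoint --dtype float16 --tp_size 2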
def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
    '''Return a QuantConfig populated with quantization info from the command line args.'''
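    # For example, `--use_weight_only --weight_only_precision int4` maps to
    # QuantAlgo.W4A16, while `--smoothquant 0.5 --per_channel --per_token` maps to
    # QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN.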
    quant_config = QuantConfig()
    if args.use_weight_only:
        if args.weight_only_precision == 'int8':
            quant_config.quant_algo = QuantAlgo.W8A16
        elif args.weight_only_precision == 'int4':
            quant_config.quant_algo = QuantAlgo.W4A16
    elif args.smoothquant:
        quant_config.smoothquant_val = args.smoothquant
        if args.per_channel:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        else:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN

    if args.int8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.INT8

    if args.weight_only_precision == 'int4_gptq':
        quant_config.group_size = args.group_size
        quant_config.has_zero_point = True
        quant_config.pre_quant_scale = False
        quant_config.quant_algo = QuantAlgo.W4A16_GPTQ

    return quant_config
def update_quant_config_from_hf(quant_config, hf_config,
                                override_fields) -> tuple[QuantConfig, dict]:
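    '''Override the quantization settings with the `quantization_config` section of the
    Hugging Face config (AWQ or GPTQ), if present, and update `override_fields` accordingly.'''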
    hf_config_dict = hf_config.to_dict()
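    # Illustrative shape of an HF `quantization_config` handled below (only the keys
    # read by this function are shown): {"quant_method": "gptq", "group_size": 128,
    # "desc_act": False, "sym": True} or {"quant_method": "awq", "group_size": 128,
    # "zero_point": True}.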
    if hf_config_dict.get('quantization_config'):
        # update the quant_algo, and clamp_val.
        if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict[
                'quantization_config'].get('group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('zero_point', False)
            override_fields.update({"use_autoawq": True})
        elif hf_config_dict['quantization_config'].get(
                'quant_method') == 'gptq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            desc_act = hf_config_dict['quantization_config'].get(
                'desc_act', False)
            if desc_act:
                raise ValueError("GPTQ with desc_act=True is not implemented!")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict[
                'quantization_config'].get('group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('sym', False)
    return quant_config, override_fields
def args_to_build_options(args):
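    '''Collect the build-related CLI options into a dict of config override fields.'''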
    return {
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'disable_weight_only_quant_plugin':
        args.disable_weight_only_quant_plugin
    }
def convert_and_save_hf(args):
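    '''Convert the Hugging Face model and save it as a TensorRT-LLM checkpoint.

    SmoothQuant and INT8-KV-cache conversions go through QWenForCausalLM.quantize()
    (which runs calibration); all other configurations load the model per rank via
    QWenForCausalLM.from_hugging_face() and save one checkpoint shard per rank.
    '''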
    model_dir = args.model_dir
    world_size = args.tp_size * args.pp_size
    # Need to convert the CLI args to key-value pairs and override them in the generated config dict.
    # Ideally these fields should be moved out of the config and passed into the build API;
    # they are kept here for compatibility until the refactor is done.
    override_fields = {}
    override_fields.update(args_to_build_options(args))
    quant_config = args_to_quant_config(args)
    try:
        hf_config = AutoConfig.from_pretrained(model_dir,
                                               trust_remote_code=True)
        quant_config, override_fields = update_quant_config_from_hf(
            quant_config, hf_config, override_fields)
    except:
        logger.warning("AutoConfig cannot load the huggingface config.")

    if args.smoothquant is not None or args.int8_kv_cache:
        mapping = Mapping(world_size=world_size,
                          tp_size=args.tp_size,
                          pp_size=args.pp_size,
                          moe_tp_size=args.moe_tp_size,
                          moe_ep_size=args.moe_ep_size,
                          cp_size=args.cp_size)
        QWenForCausalLM.quantize(args.model_dir,
                                 args.output_dir,
                                 dtype=args.dtype,
                                 mapping=mapping,
                                 quant_config=quant_config,
                                 calib_dataset=args.calib_dataset,
                                 **override_fields)
    else:

        def convert_and_save_rank(args, rank):
            mapping = Mapping(world_size=world_size,
                              rank=rank,
                              tp_size=args.tp_size,
                              pp_size=args.pp_size,
                              moe_tp_size=args.moe_tp_size,
                              moe_ep_size=args.moe_ep_size)
            qwen = QWenForCausalLM.from_hugging_face(model_dir,
                                                     args.dtype,
                                                     mapping=mapping,
                                                     quant_config=quant_config,
                                                     **override_fields)
            qwen.config.mapping.cp_size = args.cp_size
            qwen.config.mapping.attn_tp_size = -1
            qwen.config.mapping.attn_cp_size = -1
            qwen.config.mapping.world_size *= args.cp_size
            qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
            del qwen

        execute(args.workers, [convert_and_save_rank] * world_size, args)
        release_gc()
def execute(workers, func, args):
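    '''Run the per-rank conversion functions, serially or in a thread pool of `workers` threads.'''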
    if workers == 1:
        for rank, f in enumerate(func):
            f(args, rank)
    else:
        with ThreadPoolExecutor(max_workers=workers) as p:
            futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
            exceptions = []
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    traceback.print_exc()
                    exceptions.append(e)
            assert len(exceptions) == 0, \
                "Checkpoint conversion failed, please check error log."
def main():
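    '''Entry point: resolve the MoE parallelism defaults, then run and time the conversion.'''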
    print(tensorrt_llm.__version__)
    args = parse_arguments()

    if args.moe_tp_size == -1 and args.moe_ep_size == -1:
        # MoE defaults to tp-only
        args.moe_tp_size = args.tp_size
        args.moe_ep_size = 1
    elif args.moe_tp_size == -1:
        args.moe_tp_size = args.tp_size // args.moe_ep_size
    elif args.moe_ep_size == -1:
        args.moe_ep_size = args.tp_size // args.moe_tp_size
    assert args.moe_tp_size * args.moe_ep_size == args.tp_size, \
        "moe_tp_size * moe_ep_size must equal tp_size"

    tik = time.time()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    assert args.model_dir is not None
    convert_and_save_hf(args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


if __name__ == '__main__':
    main()