test_llm.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import torch
from transformers import AutoTokenizer

import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime import ModelRunnerCpp
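

# Minimal end-to-end generation test: tokenize the prompts, run a prebuilt
# TensorRT-LLM engine through ModelRunnerCpp, and print the decoded outputs.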
def parse_arguments(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input_text',
        type=str,
        nargs='+',
        default=["Born in north-east France, Soyer trained as a"])
    parser.add_argument('--tokenizer_dir',
                        type=str,
                        default="meta-llama/Meta-Llama-3-8B-Instruct")
    # NOTE: engine_dir must point at a directory containing an engine built
    # with trtllm-build; the HF model ID here is only a placeholder default.
    parser.add_argument('--engine_dir',
                        type=str,
                        default="meta-llama/Meta-Llama-3-8B-Instruct")
    parser.add_argument('--log_level', type=str, default="debug")
    parser.add_argument('--kv_cache_free_gpu_memory_fraction',
                        type=float,
                        default=0.6)
    parser.add_argument('--temperature', type=float, default=0.8)
    parser.add_argument('--top_k', type=int, default=50)
    parser.add_argument('--top_p', type=float, default=0.95)
    return parser.parse_args(args=args)
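

# Tokenize each prompt, optionally wrapping it in a prompt template first, and
# return the list of int32 token-id tensors that ModelRunnerCpp expects.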
def parse_input(tokenizer,
                input_text=None,
                prompt_template=None):
    batch_input_ids = []
    for curr_text in input_text:
        if prompt_template is not None:
            curr_text = prompt_template.format(input_text=curr_text)
        # The prompt template already supplies <|begin_of_text|>, so disable
        # the tokenizer's automatic special-token insertion to avoid a
        # duplicated BOS.
        input_ids = tokenizer.encode(curr_text, add_special_tokens=False)
        batch_input_ids.append(input_ids)
    batch_input_ids = [
        torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
    ]
    logger.debug(f"Input token ids (batch_size = {len(batch_input_ids)}):")
    for i, input_ids in enumerate(batch_input_ids):
        logger.debug(f"Request {i}: {input_ids.tolist()}")
    return batch_input_ids

def main(args):
    runtime_rank = tensorrt_llm.mpi_rank()
    logger.set_level(args.log_level)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
    # The original <|sos|>/<|task_id|>/<|eos1|> strings are not in the Llama 3
    # vocabulary (convert_tokens_to_ids would not resolve them), so use the
    # Llama 3 Instruct chat markup, assuming the default tokenizer above.
    prompt_template = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "{input_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
    end_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    batch_input_ids = parse_input(tokenizer=tokenizer,
                                  input_text=args.input_text,
                                  prompt_template=prompt_template)
    input_lengths = [x.size(0) for x in batch_input_ids]
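
    # from_dir deserializes the prebuilt engine, sizes the session to exactly
    # this batch, and reserves the requested fraction of free GPU memory for
    # the KV cache.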
    runner_kwargs = dict(
        engine_dir=args.engine_dir,
        rank=runtime_rank,
        max_output_len=1024,
        enable_context_fmha_fp32_acc=False,
        max_batch_size=len(batch_input_ids),
        max_input_len=max(input_lengths),
        kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
        cuda_graph_mode=False,
        gather_generation_logits=False,
    )
    runner = ModelRunnerCpp.from_dir(**runner_kwargs)
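
    # Top-k/top-p sampling with a single beam and a fixed seed; end_id doubles
    # as pad_id so padding past each sequence's end is just the EOS token.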
    with torch.no_grad():
        outputs = runner.generate(
            batch_input_ids=batch_input_ids,
            max_new_tokens=1024,
            end_id=end_id,
            pad_id=end_id,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            num_return_sequences=1,
            repetition_penalty=1.1,
            random_seed=42,
            streaming=False,
            output_sequence_lengths=True,
            output_generation_logits=False,
            return_dict=True,
            return_all_generated_tokens=False)
        torch.cuda.synchronize()
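
    # output_ids has shape [batch_size * num_return_sequences, num_beams,
    # max_seq_len]; sequence_lengths marks where each generated sequence ends.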
    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
    num_output_sents, num_beams, _ = output_ids.size()
    assert num_beams == 1
    beam = 0
    batch_size = len(input_lengths)
    num_return_sequences = num_output_sents // batch_size
    assert num_return_sequences == 1
    for i in range(batch_size * num_return_sequences):
        batch_idx = i // num_return_sequences
        # Echo the prompt, then decode only the newly generated tokens.
        inputs = output_ids[i][beam][:input_lengths[batch_idx]].tolist()
        input_text = tokenizer.decode(inputs)
        print(f'Input [Text {batch_idx}]: "{input_text}"')
        output_begin = input_lengths[batch_idx]
        output_end = sequence_lengths[i][beam]
        output_token_ids = output_ids[i][beam][output_begin:output_end].tolist()
        output_text = tokenizer.decode(output_token_ids)
        print(f'Output [Text {batch_idx}]: "{output_text}"')
        logger.debug(str(output_token_ids))


if __name__ == '__main__':
    args = parse_arguments()
    main(args)
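
# Example invocation (assumes an engine already built with trtllm-build; the
# engine path below is a placeholder):
#   python test_llm.py \
#       --engine_dir ./llama3_engine \
#       --tokenizer_dir meta-llama/Meta-Llama-3-8B-Instruct \
#       --input_text "Born in north-east France, Soyer trained as a"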