average_model.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) 2020 Mobvoi Inc (Di Wu)
  2. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import argparse
  17. import glob
  18. import sys
  19. import yaml
  20. import torch
  def get_args():
      """Parse the command-line options of the model-averaging script.

      Options:
          --dst_model: required; destination of the averaged model.
          --src_path:  required; directory that holds the source models.
          --val_best:  flag; pick the models to average by validation loss.
          --num:       how many models to average (default 5, must be >= 1).

      Returns:
          argparse.Namespace with the attributes ``dst_model``, ``src_path``,
          ``val_best`` and ``num``. The namespace is also printed to stdout.

      Exits through ``parser.error`` (SystemExit, status 2) when a required
      option is missing or ``--num`` is smaller than 1.
      """
      parser = argparse.ArgumentParser(description='average model')
      parser.add_argument('--dst_model', required=True, help='averaged model')
      parser.add_argument('--src_path',
                          required=True,
                          help='src model path for average')
      parser.add_argument('--val_best',
                          action="store_true",
                          help='average the models with the lowest '
                          'validation loss')
      parser.add_argument('--num',
                          default=5,
                          type=int,
                          help='nums for averaged model')
      args = parser.parse_args()
      # A count below 1 cannot be averaged: it would later be used both as a
      # slice bound and as the divisor of the summed parameters.
      if args.num < 1:
          parser.error('--num must be a positive integer, got {}'.format(
              args.num))
      print(args)
      return args
  def main():
      """Average ``args.num`` checkpoints and save the result.

      Reads the options from ``get_args()``, selects the checkpoint files,
      sums their entries key by key, divides every sum by the number of
      checkpoints and writes the resulting dict to ``args.dst_model`` with
      ``torch.save``.

      Raises:
          ValueError: if ``args.num`` is smaller than 1, or if the number of
              selected checkpoints differs from ``args.num``.
      """
      args = get_args()
      num = args.num
      if num < 1:
          raise ValueError('--num must be >= 1, got {}'.format(num))

      path_list = _select_checkpoints(args.src_path, num, args.val_best)
      print(path_list)
      # Explicit check instead of ``assert`` so it is not stripped under -O;
      # averaging fewer files than requested would use the wrong divisor.
      if len(path_list) != num:
          raise ValueError('expected {} checkpoints to average, found {}'.format(
              num, len(path_list)))

      avg = _average_states(path_list)
      print('Saving to {}'.format(args.dst_model))
      torch.save(avg, args.dst_model)


  def _select_checkpoints(src_path, num, val_best):
      """Return the paths of up to ``num`` checkpoints found under ``src_path``.

      With ``val_best`` set, every ``*.yaml`` file in ``src_path`` whose name
      does not start with ``train`` or ``init`` is read; the entries
      ``loss_dict.loss``, ``epoch``, ``step`` and ``tag`` are extracted and the
      ``num`` records with the smallest loss are mapped to
      ``epoch_<epoch>_whole.pt``.

      Without ``val_best`` the ``num`` most recently modified
      ``epoch_*_whole.pt`` files are returned (oldest first), so that the
      caller always receives a defined list.
      """
      if not val_best:
          candidates = glob.glob(os.path.join(src_path, 'epoch_*_whole.pt'))
          candidates.sort(key=os.path.getmtime)
          return candidates[-num:]

      val_scores = []
      for yaml_path in glob.glob(os.path.join(src_path, '*.yaml')):
          if os.path.basename(yaml_path).startswith(('train', 'init')):
              continue
          with open(yaml_path, 'r') as f:
              # BaseLoader yields plain strings only, hence the casts below.
              dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
          val_scores.append([
              int(dic_yaml['epoch']),
              int(dic_yaml['step']),
              float(dic_yaml['loss_dict']['loss']),
              dic_yaml['tag'],
          ])

      # Ascending by loss: the best (lowest-loss) records come first.
      best_scores = sorted(val_scores, key=lambda score: score[2])[:num]
      print("best val (epoch, step, loss, tag) = " + str(best_scores))
      return [
          os.path.join(src_path, 'epoch_{}_whole.pt'.format(score[0]))
          for score in best_scores
      ]


  def _average_states(path_list):
      """Load every checkpoint in ``path_list`` and return the key-wise mean.

      Each file is expected to hold a mapping; its values are presumably
      tensors (they must support ``clone()`` and in-place ``+=``) - confirm
      against the code that writes the checkpoints.
      """
      avg = {}
      for path in path_list:
          print('Processing {}'.format(path))
          # NOTE(security): torch.load unpickles the file; only average
          # checkpoints that come from a trusted source.
          states = torch.load(path, map_location=torch.device('cpu'))
          for key, value in states.items():
              if key in avg:
                  avg[key] += value
              else:
                  # Clone so the accumulation never aliases the loaded value.
                  avg[key] = value.clone()

      count = len(path_list)
      for key in avg:
          if avg[key] is not None:
              # pytorch 1.6 use true_divide instead of /=
              avg[key] = torch.true_divide(avg[key], count)
      return avg
  # Script entry point: run the averaging only when this file is executed
  # directly, not when it is imported as a module.
  if __name__ == "__main__":
      main()