average_model.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright (c) 2020 Mobvoi Inc (Di Wu)
  2. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import argparse
  17. import glob
  18. import yaml
  19. import torch
  20. def get_args():
  21. parser = argparse.ArgumentParser(description='average model')
  22. parser.add_argument('--dst_model', required=True, help='averaged model')
  23. parser.add_argument('--src_path',
  24. required=True,
  25. help='src model path for average')
  26. parser.add_argument('--val_best',
  27. action="store_true",
  28. help='averaged model')
  29. parser.add_argument('--num',
  30. default=5,
  31. type=int,
  32. help='nums for averaged model')
  33. args = parser.parse_args()
  34. print(args)
  35. return args
  36. def main():
  37. args = get_args()
  38. val_scores = []
  39. if args.val_best:
  40. yamls = glob.glob('{}/*.yaml'.format(args.src_path))
  41. yamls = [
  42. f for f in yamls
  43. if not (os.path.basename(f).startswith('train')
  44. or os.path.basename(f).startswith('init'))
  45. ]
  46. for y in yamls:
  47. with open(y, 'r') as f:
  48. dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
  49. loss = float(dic_yaml['loss_dict']['loss'])
  50. epoch = int(dic_yaml['epoch'])
  51. step = int(dic_yaml['step'])
  52. tag = dic_yaml['tag']
  53. val_scores += [[epoch, step, loss, tag]]
  54. sorted_val_scores = sorted(val_scores,
  55. key=lambda x: x[2],
  56. reverse=False)
  57. print("best val (epoch, step, loss, tag) = " +
  58. str(sorted_val_scores[:args.num]))
  59. path_list = [
  60. args.src_path + '/epoch_{}_whole.pt'.format(score[0])
  61. for score in sorted_val_scores[:args.num]
  62. ]
  63. print(path_list)
  64. avg = {}
  65. num = args.num
  66. assert num == len(path_list)
  67. for path in path_list:
  68. print('Processing {}'.format(path))
  69. states = torch.load(path, map_location=torch.device('cpu'))
  70. for k in states.keys():
  71. if k not in avg.keys():
  72. avg[k] = states[k].clone()
  73. else:
  74. avg[k] += states[k]
  75. # average
  76. for k in avg.keys():
  77. if avg[k] is not None:
  78. # pytorch 1.6 use true_divide instead of /=
  79. avg[k] = torch.true_divide(avg[k], num)
  80. print('Saving to {}'.format(args.dst_model))
  81. torch.save(avg, args.dst_model)
  82. if __name__ == '__main__':
  83. main()