#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
# python exps/NAS-Bench-201/check.py --base_save_dir ./output/NAS-BENCH-201-4
#####################################################
"""Audit the checkpoints saved for NAS-Bench-201.

For every ``*-*-C{channel}-N{num_cells}`` sub-directory of the save directory,
print how many architectures have checkpoints, how many seeds each of them
has, and which of the expected per-seed checkpoint files exist.
"""
import os, sys, time, argparse, collections
from shutil import copyfile
import torch
import torch.nn as nn
from pathlib import Path
from collections import defaultdict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time

# Seeds whose checkpoints are expected for every evaluated architecture.
DEFAULT_SEEDS = (777, 888, 999)


def check_files(save_dir, meta_file, basestr, seeds=DEFAULT_SEEDS):
    """Report checkpoint coverage for every sub-directory matching ``basestr``.

    Args:
        save_dir (Path): directory holding the ``*-*-{basestr}`` sub-directories.
        meta_file (Path): file read with ``torch.load``; it must contain the keys
            ``'archs'`` and ``'total'`` with ``total == len(archs)``.
        basestr (str): suffix of the sub-directory names, e.g. ``'C16-N5'``.
        seeds (iterable of int): seeds whose checkpoint files
            ``arch-{index}-seed-{seed:04d}.pth`` are expected for each
            architecture found in a sub-directory.

    Returns:
        tuple(dict, dict): ``(dir2ckps, dir2ckp_exists)``, both keyed by
        ``str(sub_dir)``. The first maps to the list of expected checkpoint file
        names, the second to a parallel list of booleans telling whether each
        of those files exists.

    Raises:
        AssertionError: on an inconsistent meta file, a malformed checkpoint
            file name, or an architecture index outside ``meta_infos['archs']``.
    """
    seeds = list(seeds)
    meta_infos = torch.load(meta_file, map_location='cpu')
    meta_archs = meta_infos['archs']
    meta_num_archs = meta_infos['total']
    assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))

    sub_model_dirs = sorted(save_dir.glob('*-*-{:}'.format(basestr)))
    print('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))

    # Pass 1: discover which architectures each sub-directory holds and how
    # many checkpoint files (one per seed) exist for each of them.
    subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
    num_seeds = defaultdict(lambda: 0)  # number-of-ckps -> number-of-archs
    for sub_dir in sub_model_dirs:
        ckps_per_arch = collections.Counter()  # arch index (str) -> #ckp files
        for checkpoint in sub_dir.glob('arch-*-seed-*.pth'):
            temp_names = checkpoint.name.split('-')
            assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
            ckps_per_arch[temp_names[1]] += 1
        subdir2archs[sub_dir] = sorted(ckps_per_arch)
        num_evaluated_arch += len(ckps_per_arch)
        # Count seeds per architecture from the single listing above rather
        # than re-globbing the directory once per architecture (O(archs*files)).
        for count in ckps_per_arch.values():
            num_seeds[count] += 1
    print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k * v for k, v in num_seeds.items())))
    for key in sorted(num_seeds.keys()):
        print('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))

    # Pass 2: for every architecture, check which of the expected per-seed
    # checkpoint files exist.
    dir2ckps, dir2ckp_exists = dict(), dict()
    start_time, epoch_time = time.time(), AverageMeter()
    for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
        numrs = defaultdict(lambda: 0)  # number-of-existing-ckps -> number-of-archs
        all_checkpoints, all_ckp_exists = [], []
        for arch_index in arch_indexes:
            # Build names from the string index so its zero-padding matches disk.
            checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds]
            ckp_exists = [(sub_dir / x).exists() for x in checkpoints]
            arch_index = int(arch_index)
            assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
            all_checkpoints += checkpoints
            all_ckp_exists += ckp_exists
            numrs[sum(ckp_exists)] += 1
        dir2ckps[str(sub_dir)] = all_checkpoints
        dir2ckp_exists[str(sub_dir)] = all_ckp_exists
        # measure time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        numrstr = ', '.join(['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())])
        print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(
            time_string(), IDX + 1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists),
            sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs) - IDX - 1), True), numrstr))
    return dir2ckps, dir2ckp_exists


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
    parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell.')
    parser.add_argument('--channel', type=int, default=16, help='The number of channels.')
    parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.')
    parser.add_argument('--seeds', type=int, nargs='+', default=list(DEFAULT_SEEDS), help='The seeds every architecture is expected to be trained with.')
    args = parser.parse_args()

    save_dir = Path(args.base_save_dir)
    meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
    assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
    assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
    print('check NAS-Bench-201 in {:}'.format(save_dir))

    basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
    check_files(save_dir, meta_path, basestr, args.seeds)