#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
# python exps/NAS-Bench-201/check.py --base_save_dir
#####################################################
import os, sys, time, argparse, collections
from shutil import copyfile
import torch
import torch.nn as nn
from pathlib import Path
from collections import defaultdict

# Make the project-local `lib` directory importable (two levels up from this
# script), so that `log_utils` below resolves when run from the repo root.
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
def check_files(save_dir, meta_file, basestr):
    """Audit which NAS-Bench-201 architectures have saved checkpoints.

    Loads the meta file (a torch-saved dict with keys 'archs', 'total',
    'max_node'), scans every sub-directory of ``save_dir`` matching
    ``*-*-{basestr}`` for checkpoints named ``arch-<idx>-seed-<seed>.pth``,
    and prints per-directory progress plus summary counts of how many
    architectures were evaluated and with how many seeds each.

    Args:
        save_dir (Path): base directory holding the per-config sub-dirs.
        meta_file (Path or str): path to the ``meta-node-*.pth`` file.
        basestr (str): config suffix, e.g. ``C16-N5``, used to select sub-dirs.

    Returns:
        None; all results are reported via ``print``.
    """
    meta_infos = torch.load(meta_file, map_location='cpu')
    meta_archs = meta_infos['archs']
    meta_num_archs = meta_infos['total']
    # note: the meta dict also carries a 'max_node' key, unused here
    assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))

    sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
    print('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))

    # subdir2archs: sub-directory -> sorted list of arch-index strings found there
    # num_seeds   : #seeds-per-arch -> #architectures evaluated that many times
    subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
    num_seeds = defaultdict(lambda: 0)
    for sub_dir in sub_model_dirs:
        xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
        arch_indexes = set()
        for checkpoint in xcheckpoints:
            temp_names = checkpoint.name.split('-')
            assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
            arch_indexes.add(temp_names[1])
        subdir2archs[sub_dir] = sorted(list(arch_indexes))
        num_evaluated_arch += len(arch_indexes)
        # count number of seeds for each architecture
        for arch_index in arch_indexes:
            num_seeds[len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index))))] += 1
    print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k * v for k, v in num_seeds.items())))
    for key in sorted(list(num_seeds.keys())):
        print('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))

    dir2ckps, dir2ckp_exists = dict(), dict()
    start_time, epoch_time = time.time(), AverageMeter()
    for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
        seeds = [777, 888, 999]
        numrs = defaultdict(lambda: 0)  # #existing-seed-ckps -> #architectures
        all_checkpoints, all_ckp_exists = [], []
        for arch_index in arch_indexes:
            checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds]
            ckp_exists = [(sub_dir / x).exists() for x in checkpoints]
            arch_index = int(arch_index)
            assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
            all_checkpoints += checkpoints
            all_ckp_exists += ckp_exists
            numrs[sum(ckp_exists)] += 1
        dir2ckps[str(sub_dir)] = all_checkpoints
        dir2ckp_exists[str(sub_dir)] = all_ckp_exists
        # measure time and report an ETA for the remaining directories
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        numrstr = ', '.join(['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())])
        print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX + 1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs) - IDX - 1), True), numrstr))
if __name__ == '__main__':
    # CLI: validate the save directory / meta file, then audit checkpoints.
    parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
    parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell.')
    parser.add_argument('--channel', type=int, default=16, help='The number of channels.')
    parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.')
    args = parser.parse_args()

    save_dir = Path(args.base_save_dir)
    # the meta file is produced by the benchmark-generation scripts
    meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
    assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
    assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)

    print('check NAS-Bench-201 in {:}'.format(save_dir))

    # sub-directory suffix encodes channels and cells, e.g. 'C16-N5'
    basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
    check_files(save_dir, meta_path, basestr)