Update test-weights of NAS-Bench-201

This commit is contained in:
D-X-Y 2020-03-16 11:11:01 +11:00
parent fb76814369
commit e76451c791
3 changed files with 58 additions and 14 deletions

View File

@ -3,13 +3,15 @@
###############################################################################################
# Before run these commands, the files must be properly put.
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 1
###############################################################################################
import os, sys, time, glob, random, argparse
import os, gc, sys, time, glob, random, argparse
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from collections import OrderedDict
from tqdm import tqdm
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
@ -24,30 +26,48 @@ def get_cor(A, B):
return float(np.corrcoef(A, B)[0,1])
def tostr(accdict, norms):
xstr = []
for key, accs in accdict.items():
cor = get_cor(accs, norms)
xstr.append('{:}: {:.3f}'.format(key, cor))
return ' '.join(xstr)
def evaluate(api, weight_dir, data: str, use_12epochs_result: bool, valid_or_test: bool):
print('\nEvaluate dataset={:}'.format(data))
norms, accs = [], []
for idx in tqdm(range(len(api))):
final_accs = OrderedDict({'cifar10-valid': [], 'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(len(api)):
info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
if valid_or_test:
accs.append(info['valid-accuracy'])
else:
accs.append(info['test-accuracy'])
for key in final_accs.keys():
info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
final_accs[key].append(info['test-accuracy'])
config = api.get_net_config(idx, data)
net = get_cell_based_tiny_net(config)
api.reload(weight_dir, idx)
params = api.get_net_param(idx, data, None)
cur_norms = []
for seed, param in params.items():
net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append( summary['lognorm'] )
with torch.no_grad():
net.load_state_dict(param)
_, summary = weight_watcher.analyze(net, alphas=False)
cur_norms.append( summary['lognorm'] )
norms.append( float(np.mean(cur_norms)) )
api.clear_params(idx, use_12epochs_result)
correlation = get_cor(norms, accs)
print('For {:} with {:} epochs on {:} : the correlation is {:}'.format(data, 12 if use_12epochs_result else 200, 'valid' if valid_or_test else 'test', correlation))
if idx % 200 == 199 or idx + 1 == len(api):
correlation = get_cor(norms, accs)
head = '{:05d}/{:05d}'.format(idx, len(api))
stem = tostr(final_accs, norms)
print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}. {:}'.format(time_string(), head, data, 12 if use_12epochs_result else 200, 'valid' if valid_or_test else 'test', correlation, stem))
torch.cuda.empty_cache() ; gc.collect()
def main(meta_file: str, weight_dir, save_dir):
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result, valid_or_test):
api = API(meta_file)
datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
print(time_string() + ' ' + '='*50)
@ -62,7 +82,8 @@ def main(meta_file: str, weight_dir, save_dir):
print('Using 200 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
print(time_string() + ' ' + '='*50)
evaluate(api, weight_dir, 'cifar10-valid', False, True)
#evaluate(api, weight_dir, 'cifar10-valid', False, True)
evaluate(api, weight_dir, xdata, use_12epochs_result, valid_or_test)
print('{:} finish this test.'.format(time_string()))
@ -71,6 +92,9 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--base_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
parser.add_argument('--dataset' , type=str, default=None, help='.')
parser.add_argument('--use_12' , type=int, default=None, help='.')
parser.add_argument('--use_valid', type=int, default=None, help='.')
args = parser.parse_args()
save_dir = Path(args.save_dir)
@ -80,5 +104,5 @@ if __name__ == '__main__':
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
main(str(meta_file), weight_dir, save_dir)
main(str(meta_file), weight_dir, save_dir, args.dataset, bool(args.use_12), bool(args.use_valid))

View File

@ -53,7 +53,7 @@ class NASBench201API(object):
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
self.filename = Path(file_path_or_dict).name
file_path_or_dict = torch.load(file_path_or_dict)
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy( file_path_or_dict )
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
@ -112,7 +112,7 @@ class NASBench201API(object):
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu')
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
@ -723,7 +723,7 @@ class ArchResults(object):
def create_from_state_dict(state_dict_or_file):
x = ArchResults(-1, -1)
if isinstance(state_dict_or_file, str): # a file path
state_dict = torch.load(state_dict_or_file)
state_dict = torch.load(state_dict_or_file, map_location='cpu')
elif isinstance(state_dict_or_file, dict):
state_dict = state_dict_or_file
else:

View File

@ -0,0 +1,20 @@
#!/bin/bash
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 1
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 3 parameters for dataset, use_12_epoch, and use_validation_set"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \
--dataset $1 \
--use_12 $2 --use_valid $3