Update weight watcher codes

This commit is contained in:
D-X-Y 2020-07-05 22:29:26 +00:00
parent 9659f132be
commit 6facc39a42
9 changed files with 167 additions and 161 deletions

View File

@@ -46,7 +46,8 @@ It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default).
```
from nas_201_api import NASBench201API as API
api = API('$path_to_meta_nas_bench_file')
api = API('NAS-Bench-201-v1_1-096897.pth')
# Create an API without the verbose log
api = API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
# The default path for benchmark file is '{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')
api = API(None)
```
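As a side note for the README snippet above, here is a minimal query sketch for the created `api` object. It is not part of this commit; it only reuses the `get_cost_info` / `get_more_info` calls and result keys that appear in the diffs below, and it assumes the benchmark file has already been downloaded.

```python
# Minimal sketch: query one architecture with the API object created above.
from nas_201_api import NASBench201API as API

api = API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
index = 100                                      # any index in [0, len(api))
cost_info = api.get_cost_info(index, 'cifar10')  # parameter count / FLOPs
info = api.get_more_info(index, 'cifar10', hp='200', is_random=False)
print(cost_info['params'], cost_info['flops'], info['test-accuracy'])
```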

View File

@@ -90,9 +90,9 @@ def visualize_sss_info(api, dataset, vis_save_dir):
    print ('Do not find cache file : {:}'.format(cache_file_path))
    params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
    for index in range(len(api)):
      info = api.get_cost_info(index, dataset)
      params.append(info['params'])
      flops.append(info['flops'])
      cost_info = api.get_cost_info(index, dataset, hp='90')
      params.append(cost_info['params'])
      flops.append(cost_info['flops'])
      # accuracy
      info = api.get_more_info(index, dataset, hp='90', is_random=False)
      train_accs.append(info['train-accuracy'])
@@ -178,9 +178,9 @@ def visualize_tss_info(api, dataset, vis_save_dir):
    print ('Do not find cache file : {:}'.format(cache_file_path))
    params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
    for index in range(len(api)):
      info = api.get_cost_info(index, dataset)
      params.append(info['params'])
      flops.append(info['flops'])
      cost_info = api.get_cost_info(index, dataset, hp='12')
      params.append(cost_info['params'])
      flops.append(cost_info['flops'])
      # accuracy
      info = api.get_more_info(index, dataset, hp='200', is_random=False)
      train_accs.append(info['train-accuracy'])
@@ -190,6 +190,7 @@ def visualize_tss_info(api, dataset, vis_save_dir):
        valid_accs.append(info['valid-accuracy'])
      else:
        valid_accs.append(info['valid-accuracy'])
    print('')
    info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs}
    torch.save(info, cache_file_path)
  else:

View File

@@ -1,113 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###############################################################################################
# Before running these commands, the files must be properly put.
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699
# python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
###############################################################################################
import os, gc, sys, math, argparse, psutil
import numpy as np
import torch
from pathlib import Path
from collections import OrderedDict

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API as API
from log_utils import time_string
from models import get_cell_based_tiny_net
from utils import weight_watcher


def get_cor(A, B):
  return float(np.corrcoef(A, B)[0,1])


def tostr(accdict, norms):
  xstr = []
  for key, accs in accdict.items():
    cor = get_cor(accs, norms)
    xstr.append('{:}: {:.3f}'.format(key, cor))
  return ' '.join(xstr)


def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
  print('\nEvaluate dataset={:}'.format(data))
  norms, process = [], psutil.Process(os.getpid())
  final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
  final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
  for idx in range(len(api)):
    # info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
    for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
      info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
      if key == 'cifar10-valid':
        final_val_accs['cifar10'].append(info['valid-accuracy'])
      elif key == 'cifar10':
        final_test_accs['cifar10'].append(info['test-accuracy'])
      else:
        final_test_accs[key].append(info['test-accuracy'])
        final_val_accs[key].append(info['valid-accuracy'])
    config = api.get_net_config(idx, data)
    net = get_cell_based_tiny_net(config)
    api.reload(weight_dir, idx)
    params = api.get_net_param(idx, data, None, use_12epochs_result=use_12epochs_result)
    cur_norms = []
    for seed, param in params.items():
      with torch.no_grad():
        net.load_state_dict(param)
        _, summary = weight_watcher.analyze(net, alphas=False)
        cur_norms.append(-summary['lognorm'])
    cur_norm = float(np.mean(cur_norms))
    if math.isnan(cur_norm):
      print (' IGNORE {:} due to nan.'.format(idx))
      continue
    norms.append(cur_norm)
    api.clear_params(idx, None)
    if idx % 200 == 199 or idx + 1 == len(api):
      head = '{:05d}/{:05d}'.format(idx, len(api))
      stem_val = tostr(final_val_accs, norms)
      stem_test = tostr(final_test_accs, norms)
      print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6))
      print(' [Valid] -->> {:}'.format(stem_val))
      print(' [Test.] -->> {:}'.format(stem_test))
      gc.collect()


def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):
  api = API(meta_file)
  datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
  print(time_string() + ' ' + '='*50)
  for data in datasets:
    nums = api.statistics(data, True)
    total = sum([k*v for k, v in nums.items()])
    print('Using 012 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
  print(time_string() + ' ' + '='*50)
  for data in datasets:
    nums = api.statistics(data, False)
    total = sum([k*v for k, v in nums.items()])
    print('Using 200 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
  print(time_string() + ' ' + '='*50)
  #evaluate(api, weight_dir, 'cifar10-valid', False, True)
  evaluate(api, weight_dir, xdata, use_12epochs_result)
  print('{:} finish this test.'.format(time_string()))


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
  parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
  parser.add_argument('--base_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
  parser.add_argument('--dataset' , type=str, default=None, help='.')
  parser.add_argument('--use_12' , type=int, default=None, help='.')
  args = parser.parse_args()

  save_dir = Path(args.save_dir)
  save_dir.mkdir(parents=True, exist_ok=True)
  meta_file = Path(args.base_path + '.pth')
  weight_dir = Path(args.base_path + '-archive')
  assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
  assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
  main(str(meta_file), weight_dir, save_dir, args.dataset, bool(args.use_12))

View File

@@ -1,20 +0,0 @@
#
# exps/experimental/test-api.py
#
import sys, time, random, argparse
from copy import deepcopy
import torchvision.models as models
from pathlib import Path

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API as API


def main():
  api = API(None)
  info = api.get_more_info(100, 'cifar100', 199, False, True)


if __name__ == '__main__':
  main()

View File

@@ -0,0 +1,151 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###########################################################################################################################################################
# Before running these commands, the files must be properly put.
#
# python exps/experimental/test-ww-bench.py --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699
# python exps/experimental/test-ww-bench.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar10
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar100
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset ImageNet16-120
###########################################################################################################################################################
import os, gc, sys, math, argparse, psutil
import numpy as np
import torch
from pathlib import Path
from collections import OrderedDict
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API, NASBench301API
from log_utils import time_string
from models import get_cell_based_tiny_net
from utils import weight_watcher

"""
def get_cor(A, B):
  return float(np.corrcoef(A, B)[0,1])

def tostr(accdict, norms):
  xstr = []
  for key, accs in accdict.items():
    cor = get_cor(accs, norms)
    xstr.append('{:}: {:.3f}'.format(key, cor))
  return ' '.join(xstr)
"""


def evaluate(api, weight_dir, data: str):
  print('\nEvaluate dataset={:}'.format(data))
  process = psutil.Process(os.getpid())
  norms, accuracies = [], []
  ok, total = 0, 5000
  for idx in range(total):
    arch_index = api.random()
    api.reload(weight_dir, arch_index)
    # compute the weight watcher results
    config = api.get_net_config(arch_index, data)
    net = get_cell_based_tiny_net(config)
    meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90')
    params = meta_info.get_net_param(data, 777)
    with torch.no_grad():
      net.load_state_dict(params)
      _, summary = weight_watcher.analyze(net, alphas=False)
    if 'lognorm' not in summary:
      api.clear_params(arch_index, None)
      del net
      continue
    cur_norm = -summary['lognorm']
    api.clear_params(arch_index, None)
    if math.isnan(cur_norm):
      del net, meta_info
      continue
    else:
      ok += 1
      norms.append(cur_norm)
    # query the accuracy
    info = meta_info.get_metrics(data, 'ori-test', iepoch=None, is_random=777)
    accuracies.append(info['accuracy'])
    del net, meta_info
    # print the information
    if idx % 20 == 0:
      gc.collect()
      print('{:} {:04d}_{:04d}/{:04d} ({:.2f} MB memory)'.format(time_string(), ok, idx, total, process.memory_info().rss / 1e6))
  return norms, accuracies


def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
  API = NASBench201API if search_space == 'tss' else NASBench301API
  save_dir.mkdir(parents=True, exist_ok=True)
  api = API(meta_file, verbose=False)
  datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
  print(time_string() + ' ' + '='*50)
  for data in datasets:
    hps = api.avaliable_hps
    for hp in hps:
      nums = api.statistics(data, hp=hp)
      total = sum([k*v for k, v in nums.items()])
      print('Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).'.format(hp, data, total, nums))
  print(time_string() + ' ' + '='*50)

  norms, accuracies = evaluate(api, weight_dir, xdata)

  indexes = list(range(len(norms)))
  norm_indexes = sorted(indexes, key=lambda i: norms[i])
  accy_indexes = sorted(indexes, key=lambda i: accuracies[i])
  labels = []
  for index in norm_indexes:
    labels.append(accy_indexes.index(index))

  dpi, width, height = 200, 1400, 800
  figsize = width / float(dpi), height / float(dpi)
  LabelSize, LegendFontsize = 18, 12
  resnet_scale, resnet_alpha = 120, 0.5

  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot(111)
  plt.xlim(min(indexes), max(indexes))
  plt.ylim(min(indexes), max(indexes))
  # plt.ylabel('y').set_rotation(30)
  plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
  plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
  ax.scatter(indexes, labels , marker='*', s=0.5, c='tab:red' , alpha=0.8)
  ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue' , alpha=0.8)
  ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Test accuracy')
  ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='Weight watcher')
  plt.grid(zorder=0)
  ax.set_axisbelow(True)
  plt.legend(loc=0, fontsize=LegendFontsize)
  ax.set_xlabel('architecture ranking sorted by the test accuracy ', fontsize=LabelSize)
  ax.set_ylabel('architecture ranking computed by weight watcher', fontsize=LabelSize)

  save_path = (save_dir / '{:}-{:}-test-ww.pdf'.format(search_space, xdata)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
  save_path = (save_dir / '{:}-{:}-test-ww.png'.format(search_space, xdata)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
  print ('{:} save into {:}'.format(time_string(), save_path))
  print('{:} finish this test.'.format(time_string()))


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
  parser.add_argument('--save_dir', type=str, default='./output/vis-nas-bench/', help='The base-name of folder to save checkpoints and log.')
  parser.add_argument('--search_space', type=str, default=None, choices=['tss', 'sss'], help='The search space.')
  parser.add_argument('--base_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
  parser.add_argument('--dataset' , type=str, default=None, help='.')
  args = parser.parse_args()

  save_dir = Path(args.save_dir)
  save_dir.mkdir(parents=True, exist_ok=True)
  meta_file = Path(args.base_path + '.pth')
  weight_dir = Path(args.base_path + '-archive')
  assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
  assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)

  main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset)

View File

@@ -77,6 +77,7 @@ class NASBench201API(NASBenchMetaAPI):
      self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
      # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
      self.arch2infos_dict = OrderedDict()
      self._avaliable_hps = set(['12', '200'])
      for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
        all_info = file_path_or_dict['arch2infos'][xkey]
        hp2archres = OrderedDict()

View File

@@ -75,11 +75,13 @@ class NASBench301API(NASBenchMetaAPI):
      self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
      # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
      self.arch2infos_dict = OrderedDict()
      self._avaliable_hps = set()
      for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
        all_infos = file_path_or_dict['arch2infos'][xkey]
        hp2archres = OrderedDict()
        for hp_key, results in all_infos.items():
          hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
          self._avaliable_hps.add(hp_key)  # save the available hyper-parameter
        self.arch2infos_dict[xkey] = hp2archres
      self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
      self.archstr2index = {}

View File

@@ -57,6 +57,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
  def __repr__(self):
    return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))

  @property
  def avaliable_hps(self):
    return list(copy.deepcopy(self._avaliable_hps))

  def random(self):
    """Return a random index of all architectures."""
    return random.randint(0, len(self.meta_archs)-1)
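A brief sketch of how the new `avaliable_hps` property combines with the existing `random()` and `statistics()` calls used elsewhere in this commit (illustrative only; the hp values shown are the NAS-Bench-201 ones registered above, and the file is assumed to be downloaded):

```python
# Sketch: enumerate hyper-parameter settings, then sample a random architecture.
from nas_201_api import NASBench201API as API

api = API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
for hp in api.avaliable_hps:                 # e.g. ['12', '200'] for NAS-Bench-201
  nums = api.statistics('cifar10', hp=hp)
  print(hp, sum(k * v for k, v in nums.items()))
arch_index = api.random()                    # random index in [0, len(api) - 1]
config = api.get_net_config(arch_index, 'cifar10')
```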

View File

@@ -1,21 +0,0 @@
#!/bin/bash
# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
echo script name: $0
echo $# arguments
if [ "$#" -ne 2 ]; then
  echo "Input illegal number of parameters " $#
  echo "Need 2 parameters for dataset and use_12_epoch"
  exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
  echo "Must set TORCH_HOME environment variable for data dir saving"
  exit 1
else
  echo "TORCH_HOME : $TORCH_HOME"
fi

CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
  --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \
  --dataset $1 \
  --use_12 $2