naswot/nasspace.py
Jack Turner b74255e1f3 v2
2021-02-26 16:12:51 +00:00

361 lines
16 KiB
Python

from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
from nasbench import api as nasbench101api
from nas_101_api.model import Network
from nas_101_api.model_spec import ModelSpec
import itertools
import random
import numpy as np
from models.cell_searchs.genotypes import Structure
from copy import deepcopy
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from pycls.models.anynet import AnyNet
from pycls.models.nas.genotypes import GENOTYPES, Genotype
import json
import torch
class Nasbench201:
def __init__(self, dataset, apiloc):
self.dataset = dataset
self.api = API(apiloc, verbose=False)
self.epochs = '12'
def get_network(self, uid):
#config = self.api.get_net_config(uid, self.dataset)
config = self.api.get_net_config(uid, 'cifar10-valid')
config['num_classes'] = 1
network = get_cell_based_tiny_net(config)
return network
def __iter__(self):
for uid in range(len(self)):
network = self.get_network(uid)
yield uid, network
def __getitem__(self, index):
return index
def __len__(self):
return 15625
def num_activations(self):
network = self.get_network(0)
return network.classifier.in_features
#def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
# archinfo = self.api.query_meta_info_by_index(uid)
# if (self.dataset == 'cifar10' or traincifar10) and trainval:
# #return archinfo.get_metrics('cifar10-valid', acc_type, iepoch=12)['accuracy']
# return archinfo.get_metrics('cifar10-valid', 'x-valid', iepoch=12)['accuracy']
# elif traincifar10:
# return archinfo.get_metrics('cifar10', acc_type, iepoch=12)['accuracy']
# else:
# return archinfo.get_metrics(self.dataset, 'ori-test', iepoch=12)['accuracy']
def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
#archinfo = self.api.query_meta_info_by_index(uid)
#if (self.dataset == 'cifar10' and trainval) or traincifar10:
info = self.api.get_more_info(uid, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
#else:
# info = self.api.get_more_info(uid, self.dataset, iepoch=None, hp=self.epochs, is_random=True)
return info['valid-accuracy']
def get_final_accuracy(self, uid, acc_type, trainval):
#archinfo = self.api.query_meta_info_by_index(uid)
if self.dataset == 'cifar10' and trainval:
info = self.api.query_meta_info_by_index(uid, hp='200').get_metrics('cifar10-valid', 'x-valid')
#info = self.api.query_by_index(uid, 'cifar10-valid', hp='200')
#info = self.api.get_more_info(uid, 'cifar10-valid', iepoch=None, hp='200', is_random=True)
else:
info = self.api.query_meta_info_by_index(uid, hp='200').get_metrics(self.dataset, acc_type)
#info = self.api.query_by_index(uid, self.dataset, hp='200')
#info = self.api.get_more_info(uid, self.dataset, iepoch=None, hp='200', is_random=True)
return info['accuracy']
#return info['valid-accuracy']
#if self.dataset == 'cifar10' and trainval:
# return archinfo.get_metrics('cifar10-valid', acc_type, iepoch=11)['accuracy']
#else:
# #return archinfo.get_metrics(self.dataset, 'ori-test', iepoch=12)['accuracy']
# return archinfo.get_metrics(self.dataset, 'x-test', iepoch=11)['accuracy']
##dataset = self.dataset
##if self.dataset == 'cifar10' and trainval:
## dataset = 'cifar10-valid'
##archinfo = self.api.get_more_info(uid, dataset, iepoch=None, use_12epochs_result=True, is_random=True)
##return archinfo['valid-accuracy']
def get_accuracy(self, uid, acc_type, trainval=True):
archinfo = self.api.query_meta_info_by_index(uid)
if self.dataset == 'cifar10' and trainval:
return archinfo.get_metrics('cifar10-valid', acc_type)['accuracy']
else:
return archinfo.get_metrics(self.dataset, acc_type)['accuracy']
def get_accuracy_for_all_datasets(self, uid):
archinfo = self.api.query_meta_info_by_index(uid,hp='200')
c10 = archinfo.get_metrics('cifar10', 'ori-test')['accuracy']
c10_val = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy']
c100 = archinfo.get_metrics('cifar100', 'x-test')['accuracy']
c100_val = archinfo.get_metrics('cifar100', 'x-valid')['accuracy']
imagenet = archinfo.get_metrics('ImageNet16-120', 'x-test')['accuracy']
imagenet_val = archinfo.get_metrics('ImageNet16-120', 'x-valid')['accuracy']
return c10, c10_val, c100, c100_val, imagenet, imagenet_val
#def train_and_eval(self, arch, dataname, acc_type, trainval=True):
# unique_hash = self.__getitem__(arch)
# time = self.get_training_time(unique_hash)
# acc12 = self.get_12epoch_accuracy(unique_hash, acc_type, trainval)
# acc = self.get_final_accuracy(unique_hash, acc_type, trainval)
# return acc12, acc, time
def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
unique_hash = self.__getitem__(arch)
time = self.get_training_time(unique_hash)
acc12 = self.get_12epoch_accuracy(unique_hash, acc_type, trainval, traincifar10)
acc = self.get_final_accuracy(unique_hash, acc_type, trainval)
return acc12, acc, time
def random_arch(self):
return random.randint(0, len(self)-1)
def get_training_time(self, unique_hash):
#info = self.api.get_more_info(unique_hash, 'cifar10-valid' if self.dataset == 'cifar10' else self.dataset, iepoch=None, use_12epochs_result=True, is_random=True)
#info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, use_12epochs_result=True, is_random=True)
info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, hp='12', is_random=True)
return info['train-all-time'] + info['valid-per-time']
#if self.dataset == 'cifar10' and trainval:
# info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
#else:
# info = self.api.get_more_info(unique_hash, self.dataset, iepoch=None, hp=self.epochs, is_random=True)
##info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, use_12epochs_result=True, is_random=True)
#return info['train-all-time'] + info['valid-per-time']
def mutate_arch(self, arch):
op_names = get_search_spaces('cell', 'nas-bench-201')
#config = self.api.get_net_config(arch, self.dataset)
config = self.api.get_net_config(arch, 'cifar10-valid')
parent_arch = Structure(self.api.str2lists(config['arch_str']))
child_arch = deepcopy( parent_arch )
node_id = random.randint(0, len(child_arch.nodes)-1)
node_info = list( child_arch.nodes[node_id] )
snode_id = random.randint(0, len(node_info)-1)
xop = random.choice( op_names )
while xop == node_info[snode_id][0]:
xop = random.choice( op_names )
node_info[snode_id] = (xop, node_info[snode_id][1])
child_arch.nodes[node_id] = tuple( node_info )
arch_index = self.api.query_index_by_arch( child_arch )
return arch_index
class Nasbench101:
def __init__(self, dataset, apiloc, args):
self.dataset = dataset
self.api = nasbench101api.NASBench(apiloc)
self.args = args
def get_accuracy(self, unique_hash, acc_type, trainval=True):
spec = self.get_spec(unique_hash)
_, stats = self.api.get_metrics_from_spec(spec)
maxacc = 0.
for ep in stats:
for statmap in stats[ep]:
newacc = statmap['final_test_accuracy']
if newacc > maxacc:
maxacc = newacc
return maxacc
def get_final_accuracy(self, uid, acc_type, trainval):
return self.get_accuracy(uid, acc_type, trainval)
def get_training_time(self, unique_hash):
spec = self.get_spec(unique_hash)
_, stats = self.api.get_metrics_from_spec(spec)
maxacc = -1.
maxtime = 0.
for ep in stats:
for statmap in stats[ep]:
newacc = statmap['final_test_accuracy']
if newacc > maxacc:
maxacc = newacc
maxtime = statmap['final_training_time']
return maxtime
def get_network(self, unique_hash):
spec = self.get_spec(unique_hash)
network = Network(spec, self.args)
return network
def get_spec(self, unique_hash):
matrix = self.api.fixed_statistics[unique_hash]['module_adjacency']
operations = self.api.fixed_statistics[unique_hash]['module_operations']
spec = ModelSpec(matrix, operations)
return spec
def __iter__(self):
for unique_hash in self.api.hash_iterator():
network = self.get_network(unique_hash)
yield unique_hash, network
def __getitem__(self, index):
return next(itertools.islice(self.api.hash_iterator(), index, None))
def __len__(self):
return len(self.api.hash_iterator())
def num_activations(self):
for unique_hash in self.api.hash_iterator():
network = self.get_network(unique_hash)
return network.classifier.in_features
def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
unique_hash = self.__getitem__(arch)
time =12.* self.get_training_time(unique_hash)/108.
acc = self.get_accuracy(unique_hash, acc_type, trainval)
return acc, acc, time
def random_arch(self):
return random.randint(0, len(self)-1)
def mutate_arch(self, arch):
unique_hash = self.__getitem__(arch)
matrix = self.api.fixed_statistics[unique_hash]['module_adjacency']
operations = self.api.fixed_statistics[unique_hash]['module_operations']
coords = [ (i, j) for i in range(matrix.shape[0]) for j in range(i+1, matrix.shape[1])]
random.shuffle(coords)
# loop through changes until we find change thats allowed
for i, j in coords:
# try the ops in a particular order
for k in [m for m in np.unique(matrix) if m != matrix[i, j]]:
newmatrix = matrix.copy()
newmatrix[i, j] = k
spec = ModelSpec(newmatrix, operations)
try:
newhash = self.api._hash_spec(spec)
if newhash in self.api.fixed_statistics:
return [n for n, m in enumerate(self.api.fixed_statistics.keys()) if m == newhash][0]
except:
pass
class ReturnFeatureLayer(torch.nn.Module):
def __init__(self, mod):
super(ReturnFeatureLayer, self).__init__()
self.mod = mod
def forward(self, x):
return self.mod(x), x
def return_feature_layer(network, prefix=''):
#for attr_str in dir(network):
# target_attr = getattr(network, attr_str)
# if isinstance(target_attr, torch.nn.Linear):
# setattr(network, attr_str, ReturnFeatureLayer(target_attr))
for n, ch in list(network.named_children()):
if isinstance(ch, torch.nn.Linear):
setattr(network, n, ReturnFeatureLayer(ch))
else:
return_feature_layer(ch, prefix + '\t')
class NDS:
def __init__(self, searchspace):
self.searchspace = searchspace
data = json.load(open(f'nds_data/{searchspace}.json', 'r'))
try:
data = data['top'] + data['mid']
except Exception as e:
pass
self.data = data
def __iter__(self):
for unique_hash in range(len(self)):
network = self.get_network(unique_hash)
yield unique_hash, network
def get_network_config(self, uid):
return self.data[uid]['net']
def get_network_optim_config(self, uid):
return self.data[uid]['optim']
def get_network(self, uid):
netinfo = self.data[uid]
config = netinfo['net']
#print(config)
if 'genotype' in config:
#print('geno')
gen = config['genotype']
genotype = Genotype(normal=gen['normal'], normal_concat=gen['normal_concat'], reduce=gen['reduce'], reduce_concat=gen['reduce_concat'])
if '_in' in self.searchspace:
network = NetworkImageNet(config['width'], 1, config['depth'], config['aux'], genotype)
else:
network = NetworkCIFAR(config['width'], 1, config['depth'], config['aux'], genotype)
network.drop_path_prob = 0.
#print(config)
#print('genotype')
L = config['depth']
else:
if 'bot_muls' in config and 'bms' not in config:
config['bms'] = config['bot_muls']
del config['bot_muls']
if 'num_gs' in config and 'gws' not in config:
config['gws'] = config['num_gs']
del config['num_gs']
config['nc'] = 1
config['se_r'] = None
config['stem_w'] = 12
L = sum(config['ds'])
if 'ResN' in self.searchspace:
config['stem_type'] = 'res_stem_in'
else:
config['stem_type'] = 'simple_stem_in'
#"res_stem_cifar": ResStemCifar,
#"res_stem_in": ResStemIN,
#"simple_stem_in": SimpleStemIN,
if config['block_type'] == 'double_plain_block':
config['block_type'] = 'vanilla_block'
network = AnyNet(**config)
return_feature_layer(network)
return network
def __getitem__(self, index):
return index
def __len__(self):
return len(self.data)
def random_arch(self):
return random.randint(0, len(self.data)-1)
def get_final_accuracy(self, uid, acc_type, trainval):
return 100.-self.data[uid]['test_ep_top1'][-1]
def get_search_space(args):
if args.nasspace == 'nasbench201':
return Nasbench201(args.dataset, args.api_loc)
elif args.nasspace == 'nasbench101':
return Nasbench101(args.dataset, args.api_loc, args)
elif args.nasspace == 'nds_resnet':
return NDS('ResNet')
elif args.nasspace == 'nds_amoeba':
return NDS('Amoeba')
elif args.nasspace == 'nds_amoeba_in':
return NDS('Amoeba_in')
elif args.nasspace == 'nds_darts_in':
return NDS('DARTS_in')
elif args.nasspace == 'nds_darts':
return NDS('DARTS')
elif args.nasspace == 'nds_darts_fix-w-d':
return NDS('DARTS_fix-w-d')
elif args.nasspace == 'nds_darts_lr-wd':
return NDS('DARTS_lr-wd')
elif args.nasspace == 'nds_enas':
return NDS('ENAS')
elif args.nasspace == 'nds_enas_in':
return NDS('ENAS_in')
elif args.nasspace == 'nds_enas_fix-w-d':
return NDS('ENAS_fix-w-d')
elif args.nasspace == 'nds_pnas':
return NDS('PNAS')
elif args.nasspace == 'nds_pnas_fix-w-d':
return NDS('PNAS_fix-w-d')
elif args.nasspace == 'nds_pnas_in':
return NDS('PNAS_in')
elif args.nasspace == 'nds_nasnet':
return NDS('NASNet')
elif args.nasspace == 'nds_nasnet_in':
return NDS('NASNet_in')
elif args.nasspace == 'nds_resnext-a':
return NDS('ResNeXt-A')
elif args.nasspace == 'nds_resnext-a_in':
return NDS('ResNeXt-A_in')
elif args.nasspace == 'nds_resnext-b':
return NDS('ResNeXt-B')
elif args.nasspace == 'nds_resnext-b_in':
return NDS('ResNeXt-B_in')
elif args.nasspace == 'nds_vanilla':
return NDS('Vanilla')
elif args.nasspace == 'nds_vanilla_lr-wd':
return NDS('Vanilla_lr-wd')
elif args.nasspace == 'nds_vanilla_lr-wd_in':
return NDS('Vanilla_lr-wd_in')