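"""Search-space wrappers for scoring and querying architectures from three NAS
benchmarks: NAS-Bench-201, NAS-Bench-101, and the NDS design spaces.

Each wrapper exposes a similar interface (iteration over architectures,
get_network, random_arch, and accuracy/training-time queries) so downstream
code can select a benchmark via get_search_space(args).
"""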
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
from nasbench import api as nasbench101api
from nas_101_api.model import Network
from nas_101_api.model_spec import ModelSpec
import itertools
import random
import numpy as np
from models.cell_searchs.genotypes import Structure
from copy import deepcopy
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from pycls.models.anynet import AnyNet
from pycls.models.nas.genotypes import GENOTYPES, Genotype
import json
import torch


class Nasbench201:
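    """Wrapper around the NAS-Bench-201 API (15,625 cell-based architectures).

    Networks are instantiated from the 'cifar10-valid' configuration with a
    single output class, and accuracies/training times are read back from the
    benchmark's recorded 12-epoch and 200-epoch statistics.
    """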
    def __init__(self, dataset, apiloc):
        self.dataset = dataset
        self.api = API(apiloc, verbose=False)
        self.epochs = '12'

    def get_network(self, uid):
        # NOTE: the 'cifar10-valid' config is used regardless of self.dataset.
        config = self.api.get_net_config(uid, 'cifar10-valid')
        config['num_classes'] = 1
        network = get_cell_based_tiny_net(config)
        return network

    def __iter__(self):
        for uid in range(len(self)):
            network = self.get_network(uid)
            yield uid, network

    def __getitem__(self, index):
        return index

    def __len__(self):
        return 15625

    def num_activations(self):
        network = self.get_network(0)
        return network.classifier.in_features

    def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
        # Validation accuracy after the 12-epoch schedule, always read from the
        # 'cifar10-valid' split.
        info = self.api.get_more_info(uid, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
        return info['valid-accuracy']

    def get_final_accuracy(self, uid, acc_type, trainval):
        # Accuracy after the full 200-epoch schedule.
        if self.dataset == 'cifar10' and trainval:
            info = self.api.query_meta_info_by_index(uid, hp='200').get_metrics('cifar10-valid', 'x-valid')
        else:
            info = self.api.query_meta_info_by_index(uid, hp='200').get_metrics(self.dataset, acc_type)
        return info['accuracy']

    def get_accuracy(self, uid, acc_type, trainval=True):
        archinfo = self.api.query_meta_info_by_index(uid)
        if self.dataset == 'cifar10' and trainval:
            return archinfo.get_metrics('cifar10-valid', acc_type)['accuracy']
        else:
            return archinfo.get_metrics(self.dataset, acc_type)['accuracy']

    def get_accuracy_for_all_datasets(self, uid):
        archinfo = self.api.query_meta_info_by_index(uid, hp='200')

        c10 = archinfo.get_metrics('cifar10', 'ori-test')['accuracy']
        c10_val = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy']

        c100 = archinfo.get_metrics('cifar100', 'x-test')['accuracy']
        c100_val = archinfo.get_metrics('cifar100', 'x-valid')['accuracy']

        imagenet = archinfo.get_metrics('ImageNet16-120', 'x-test')['accuracy']
        imagenet_val = archinfo.get_metrics('ImageNet16-120', 'x-valid')['accuracy']

        return c10, c10_val, c100, c100_val, imagenet, imagenet_val

    def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
        unique_hash = self.__getitem__(arch)
        time = self.get_training_time(unique_hash)
        acc12 = self.get_12epoch_accuracy(unique_hash, acc_type, trainval, traincifar10)
        acc = self.get_final_accuracy(unique_hash, acc_type, trainval)
        return acc12, acc, time

    def random_arch(self):
        return random.randint(0, len(self) - 1)

    def get_training_time(self, unique_hash):
        # Training time for the 12-epoch schedule plus one validation pass.
        info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, hp='12', is_random=True)
        return info['train-all-time'] + info['valid-per-time']

    def mutate_arch(self, arch):
        # Swap the operation on one randomly chosen edge of a randomly chosen
        # cell node, then look the mutated cell up in the benchmark.
        op_names = get_search_spaces('cell', 'nas-bench-201')
        config = self.api.get_net_config(arch, 'cifar10-valid')
        parent_arch = Structure(self.api.str2lists(config['arch_str']))
        child_arch = deepcopy(parent_arch)
        node_id = random.randint(0, len(child_arch.nodes) - 1)
        node_info = list(child_arch.nodes[node_id])
        snode_id = random.randint(0, len(node_info) - 1)
        xop = random.choice(op_names)
        while xop == node_info[snode_id][0]:
            xop = random.choice(op_names)
        node_info[snode_id] = (xop, node_info[snode_id][1])
        child_arch.nodes[node_id] = tuple(node_info)
        arch_index = self.api.query_index_by_arch(child_arch)
        return arch_index


class Nasbench101:
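    """Wrapper around the NAS-Bench-101 API, whose architectures are keyed by
    module hash. Architectures are addressed here by their position in the
    API's hash iterator, and accuracy queries return the best recorded final
    test accuracy across the stored training budgets and repeats.
    """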
    def __init__(self, dataset, apiloc, args):
        self.dataset = dataset
        self.api = nasbench101api.NASBench(apiloc)
        self.args = args

    def get_accuracy(self, unique_hash, acc_type, trainval=True):
        # Best final test accuracy across all recorded epoch budgets and repeats.
        spec = self.get_spec(unique_hash)
        _, stats = self.api.get_metrics_from_spec(spec)
        maxacc = 0.
        for ep in stats:
            for statmap in stats[ep]:
                newacc = statmap['final_test_accuracy']
                if newacc > maxacc:
                    maxacc = newacc
        return maxacc

    def get_final_accuracy(self, uid, acc_type, trainval):
        return self.get_accuracy(uid, acc_type, trainval)

    def get_training_time(self, unique_hash):
        # Training time of the run that achieved the best final test accuracy.
        spec = self.get_spec(unique_hash)
        _, stats = self.api.get_metrics_from_spec(spec)
        maxacc = -1.
        maxtime = 0.
        for ep in stats:
            for statmap in stats[ep]:
                newacc = statmap['final_test_accuracy']
                if newacc > maxacc:
                    maxacc = newacc
                    maxtime = statmap['final_training_time']
        return maxtime

    def get_network(self, unique_hash):
        spec = self.get_spec(unique_hash)
        network = Network(spec, self.args)
        return network

    def get_spec(self, unique_hash):
        matrix = self.api.fixed_statistics[unique_hash]['module_adjacency']
        operations = self.api.fixed_statistics[unique_hash]['module_operations']
        spec = ModelSpec(matrix, operations)
        return spec

    def __iter__(self):
        for unique_hash in self.api.hash_iterator():
            network = self.get_network(unique_hash)
            yield unique_hash, network

    def __getitem__(self, index):
        return next(itertools.islice(self.api.hash_iterator(), index, None))

    def __len__(self):
        return len(self.api.hash_iterator())

    def num_activations(self):
        for unique_hash in self.api.hash_iterator():
            network = self.get_network(unique_hash)
            return network.classifier.in_features

    def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
        unique_hash = self.__getitem__(arch)
        # Scale the recorded training time down by 12/108 to approximate a
        # 12-epoch budget.
        time = 12. * self.get_training_time(unique_hash) / 108.
        acc = self.get_accuracy(unique_hash, acc_type, trainval)
        return acc, acc, time

    def random_arch(self):
        return random.randint(0, len(self) - 1)

    def mutate_arch(self, arch):
        # Flip one entry of the upper-triangular adjacency matrix and return the
        # index of the first mutated spec that exists in the benchmark; returns
        # None if no single-edge change is valid.
        unique_hash = self.__getitem__(arch)
        matrix = self.api.fixed_statistics[unique_hash]['module_adjacency']
        operations = self.api.fixed_statistics[unique_hash]['module_operations']
        coords = [(i, j) for i in range(matrix.shape[0]) for j in range(i + 1, matrix.shape[1])]
        random.shuffle(coords)
        # loop through changes until we find a change that's allowed
        for i, j in coords:
            # try the candidate edge values in a particular order
            for k in [m for m in np.unique(matrix) if m != matrix[i, j]]:
                newmatrix = matrix.copy()
                newmatrix[i, j] = k
                spec = ModelSpec(newmatrix, operations)
                try:
                    newhash = self.api._hash_spec(spec)
                    if newhash in self.api.fixed_statistics:
                        return [n for n, m in enumerate(self.api.fixed_statistics.keys()) if m == newhash][0]
                except Exception:
                    pass


class ReturnFeatureLayer(torch.nn.Module):
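    """Wraps a module so that its forward pass returns (output, input),
    exposing the features fed into the wrapped layer."""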
    def __init__(self, mod):
        super(ReturnFeatureLayer, self).__init__()
        self.mod = mod

    def forward(self, x):
        return self.mod(x), x


def return_feature_layer(network, prefix=''):
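    """Recursively wrap every torch.nn.Linear in the network with
    ReturnFeatureLayer so that pre-classifier features can be captured."""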
    for n, ch in list(network.named_children()):
        if isinstance(ch, torch.nn.Linear):
            setattr(network, n, ReturnFeatureLayer(ch))
        else:
            return_feature_layer(ch, prefix + '\t')


class NDS:
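    """Wrapper around an NDS (Network Design Spaces) sweep loaded from
    nds_data/<searchspace>.json. Networks are rebuilt either from a DARTS-style
    genotype or from an AnyNet configuration, and accuracies are read from the
    logged training curves."""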
    def __init__(self, searchspace):
        self.searchspace = searchspace
        with open(f'nds_data/{searchspace}.json', 'r') as f:
            data = json.load(f)
        try:
            # Some sweeps are stored as {'top': [...], 'mid': [...]}; others are flat lists.
            data = data['top'] + data['mid']
        except Exception:
            pass
        self.data = data

    def __iter__(self):
        for unique_hash in range(len(self)):
            network = self.get_network(unique_hash)
            yield unique_hash, network

    def get_network_config(self, uid):
        return self.data[uid]['net']

    def get_network_optim_config(self, uid):
        return self.data[uid]['optim']

    def get_network(self, uid):
        netinfo = self.data[uid]
        config = netinfo['net']
        if 'genotype' in config:
            # Cell-based design spaces (DARTS, ENAS, PNAS, NASNet, Amoeba, ...).
            gen = config['genotype']
            genotype = Genotype(normal=gen['normal'], normal_concat=gen['normal_concat'],
                                reduce=gen['reduce'], reduce_concat=gen['reduce_concat'])
            if '_in' in self.searchspace:
                network = NetworkImageNet(config['width'], 1, config['depth'], config['aux'], genotype)
            else:
                network = NetworkCIFAR(config['width'], 1, config['depth'], config['aux'], genotype)
            network.drop_path_prob = 0.
            L = config['depth']
        else:
            # AnyNet-style design spaces (ResNet, ResNeXt, Vanilla, ...); rename
            # legacy keys to the argument names expected by AnyNet.
            if 'bot_muls' in config and 'bms' not in config:
                config['bms'] = config['bot_muls']
                del config['bot_muls']
            if 'num_gs' in config and 'gws' not in config:
                config['gws'] = config['num_gs']
                del config['num_gs']
            config['nc'] = 1
            config['se_r'] = None
            config['stem_w'] = 12
            L = sum(config['ds'])
            # Stem types include 'res_stem_cifar', 'res_stem_in' and 'simple_stem_in'.
            if 'ResN' in self.searchspace:
                config['stem_type'] = 'res_stem_in'
            else:
                config['stem_type'] = 'simple_stem_in'
            if config['block_type'] == 'double_plain_block':
                config['block_type'] = 'vanilla_block'
            network = AnyNet(**config)
        return_feature_layer(network)
        return network

    def __getitem__(self, index):
        return index

    def __len__(self):
        return len(self.data)

    def random_arch(self):
        return random.randint(0, len(self.data) - 1)

    def get_final_accuracy(self, uid, acc_type, trainval):
        # Final test accuracy, converted from the logged top-1 error.
        return 100. - self.data[uid]['test_ep_top1'][-1]


def get_search_space(args):
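    """Build the search-space wrapper selected by args.nasspace; NDS variants
    map directly onto the corresponding nds_data JSON files."""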
    if args.nasspace == 'nasbench201':
        return Nasbench201(args.dataset, args.api_loc)
    elif args.nasspace == 'nasbench101':
        return Nasbench101(args.dataset, args.api_loc, args)
    elif args.nasspace == 'nds_resnet':
        return NDS('ResNet')
    elif args.nasspace == 'nds_amoeba':
        return NDS('Amoeba')
    elif args.nasspace == 'nds_amoeba_in':
        return NDS('Amoeba_in')
    elif args.nasspace == 'nds_darts_in':
        return NDS('DARTS_in')
    elif args.nasspace == 'nds_darts':
        return NDS('DARTS')
    elif args.nasspace == 'nds_darts_fix-w-d':
        return NDS('DARTS_fix-w-d')
    elif args.nasspace == 'nds_darts_lr-wd':
        return NDS('DARTS_lr-wd')
    elif args.nasspace == 'nds_enas':
        return NDS('ENAS')
    elif args.nasspace == 'nds_enas_in':
        return NDS('ENAS_in')
    elif args.nasspace == 'nds_enas_fix-w-d':
        return NDS('ENAS_fix-w-d')
    elif args.nasspace == 'nds_pnas':
        return NDS('PNAS')
    elif args.nasspace == 'nds_pnas_fix-w-d':
        return NDS('PNAS_fix-w-d')
    elif args.nasspace == 'nds_pnas_in':
        return NDS('PNAS_in')
    elif args.nasspace == 'nds_nasnet':
        return NDS('NASNet')
    elif args.nasspace == 'nds_nasnet_in':
        return NDS('NASNet_in')
    elif args.nasspace == 'nds_resnext-a':
        return NDS('ResNeXt-A')
    elif args.nasspace == 'nds_resnext-a_in':
        return NDS('ResNeXt-A_in')
    elif args.nasspace == 'nds_resnext-b':
        return NDS('ResNeXt-B')
    elif args.nasspace == 'nds_resnext-b_in':
        return NDS('ResNeXt-B_in')
    elif args.nasspace == 'nds_vanilla':
        return NDS('Vanilla')
    elif args.nasspace == 'nds_vanilla_lr-wd':
        return NDS('Vanilla_lr-wd')
    elif args.nasspace == 'nds_vanilla_lr-wd_in':
        return NDS('Vanilla_lr-wd_in')
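

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original pipeline): build a search
    # space from command-line flags and inspect one random architecture. The
    # flag names mirror the attributes read by get_search_space(); the api_loc
    # default below is only an illustrative placeholder path.
    import argparse

    parser = argparse.ArgumentParser(description='Inspect a NAS benchmark search space')
    parser.add_argument('--nasspace', default='nasbench201')
    parser.add_argument('--dataset', default='cifar10')
    parser.add_argument('--api_loc', default='path/to/NAS-Bench-201.pth')
    args = parser.parse_args()

    searchspace = get_search_space(args)
    print(f'{args.nasspace}: {len(searchspace)} architectures')
    uid = searchspace.random_arch()
    net = searchspace.get_network(uid)
    print(f'sampled arch {uid}: {sum(p.numel() for p in net.parameters())} parameters')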