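"""Compute zero-cost 'meco' scores for NAS-Bench-201 / NATS-Bench architectures.

For each architecture in the chosen search space ('tss' or 'sss'), the script
builds the network, runs a single forward pass on one batch of data, and sums a
per-layer score: the minimum real eigenvalue of the Pearson correlation matrix
of the layer's flattened feature-map channels ('meco_opt' approximates this
from a random subsample of 8 channels). Results are appended incrementally to a
pickle file in --outdir.
"""
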
import argparse
import os
import pickle
import random

import torch

# The wildcard import also provides get_cifar_dataloaders, used below.
from foresight.dataset import *
from foresight.models import nasbench2
from foresight.weight_initializers import init_net
from models import get_cell_based_tiny_net


def get_score(net, x, device, measure='meco'):
    result_list = []

    def forward_hook(module, data_input, data_output):
        # First element of the output (for tensor outputs, the first sample of
        # the batch; batch_size defaults to 1), flattened to (channels, H*W).
        fea = data_output[0].clone().detach()
        n = fea.shape[0]
        fea = fea.reshape(n, -1)
        if measure == 'meco':
            # Pearson correlation between channels; constant channels produce
            # NaN/Inf entries, which are zeroed before the eigendecomposition.
            corr = torch.corrcoef(fea)
            corr[torch.isnan(corr)] = 0
            corr[torch.isinf(corr)] = 0
            values = torch.linalg.eig(corr)[0]
            result = torch.min(torch.real(values))
        elif measure == 'meco_opt':
            # Cheaper variant: correlate a random subsample of 8 channels and
            # rescale. Guard against layers with fewer than 8 channels.
            k = min(8, n)
            idxs = random.sample(range(n), k)
            fea = fea[idxs, :]
            corr = torch.corrcoef(fea)
            corr[torch.isnan(corr)] = 0
            corr[torch.isinf(corr)] = 0
            values = torch.linalg.eig(corr)[0]
            result = torch.min(torch.real(values)) * n / k
        result_list.append(result)

    # Hook every module, run a single forward pass, then remove the hooks so
    # they do not accumulate across repeated calls on the same network.
    hooks = [module.register_forward_hook(forward_hook) for module in net.modules()]
    x = x.to(device)
    net(x)
    for hook in hooks:
        hook.remove()

    # Sum the per-layer scores, dropping any NaN/Inf entries.
    results = torch.tensor(result_list)
    results = results[torch.logical_not(torch.isnan(results))]
    results = results[torch.logical_not(torch.isinf(results))]
    res = torch.sum(results)
    result_list.clear()

    return res.item()
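
# Minimal usage sketch for get_score (illustrative only: the torchvision model
# and random input below are assumptions, not part of this script, which scores
# NAS-Bench-201 cells further down):
#
#   import torchvision
#   net = torchvision.models.resnet18(num_classes=10)
#   x = torch.randn(1, 3, 32, 32)
#   score = get_score(net, x, torch.device('cpu'), measure='meco')
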
def get_num_classes(args):
    # cifar100 -> 100, cifar10 -> 10, anything else (ImageNet16-120) -> 120.
    return 100 if args.dataset == 'cifar100' else 10 if args.dataset == 'cifar10' else 120


def parse_arguments():
    parser = argparse.ArgumentParser(description='Zero-cost Metrics for NAS-Bench-201')
    # parser.add_argument('--api_loc', default='../data/NAS-Bench-201-v1_0-e61699.pth',
    #                     type=str, help='path to API')
    parser.add_argument('--outdir', default='./',
                        type=str, help='output directory')
    parser.add_argument('--search_space', type=str, default='tss', choices=['tss', 'sss'])
    parser.add_argument('--init_w_type', type=str, default='none',
                        help='weight initialization (before pruning) type [none, xavier, kaiming, zero, one]')
    parser.add_argument('--init_b_type', type=str, default='none',
                        help='bias initialization (before pruning) type [none, xavier, kaiming, zero, one]')
    parser.add_argument('--measure', type=str, default='meco', choices=['meco', 'meco_opt'])
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset to use [cifar10, cifar100, ImageNet16-120]')
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to work on')
    parser.add_argument('--data_size', type=int, default=32, help='data_size')
    parser.add_argument('--num_data_workers', type=int, default=2, help='number of workers for dataloaders')
    parser.add_argument('--dataload', type=str, default='appoint', help='random, grasp, appoint supported')
    parser.add_argument('--dataload_info', type=int, default=1,
                        help='number of batches to use for random dataload or number of samples per class for grasp dataload')
    parser.add_argument('--seed', type=int, default=42, help='pytorch manual seed')
    parser.add_argument('--write_freq', type=int, default=1, help='frequency of write to file')
    parser.add_argument('--start', type=int, default=0, help='start index')
    parser.add_argument('--end', type=int, default=0, help='end index (0 means the whole benchmark)')
    # NOTE: with default=True and action='store_true' this flag is always on;
    # flip the default to False to also record benchmark accuracies below.
    parser.add_argument('--noacc', default=True, action='store_true',
                        help='avoid loading NASBench2 api and instead load a pickle file with tuple (index, arch_str)')
    args = parser.parse_args()
    args.device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    return args


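# Example invocation (the script filename is hypothetical):
#   python compute_meco.py --search_space tss --dataset cifar10 --measure meco \
#       --batch_size 1 --outdir ./results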
if __name__ == '__main__':
    args = parse_arguments()
    print(args.device)

    from nats_bench import create

    # NATS-Bench API for the chosen search space (topology or size).
    api = create(None, args.search_space, fast_mode=True, verbose=False)

    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # A single batch (batch_size defaults to 1) is all the scoring needs.
    train_loader, val_loader = get_cifar_dataloaders(args.batch_size, args.batch_size, args.dataset,
                                                     args.num_data_workers, resize=args.data_size)
    x, y = next(iter(train_loader))

    cached_res = []
    # Short dataset prefix used in the output filename.
    if 'cifar' in args.dataset:
        pre = 'cf'
    elif 'Image' in args.dataset:
        pre = 'im'
    elif 'oxford' in args.dataset:
        pre = 'ox'
    elif 'air' in args.dataset:
        pre = 'ai'
    else:
        raise ValueError(f'unsupported dataset: {args.dataset}')
    # With the defaults this yields nb2_tss_cf10_seed42_dlappoint_dlinfo1_initwnone_initbnone_1.p
    pfn = f'nb2_{args.search_space}_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}_{args.batch_size}.p'
    op = os.path.join(args.outdir, pfn)

    end = len(api) if args.end == 0 else args.end

    # Loop over NAS-Bench-201 architectures in the selected index range.
    for i, arch_str in enumerate(api):

        if i < args.start:
            continue
        if i >= end:
            break

        res = {'i': i, 'arch': arch_str}
        if args.search_space == 'tss':
            net = nasbench2.get_model_from_arch_str(arch_str, get_num_classes(args))
            # Sanity check: the model must round-trip to the same arch string.
            arch_str2 = nasbench2.get_arch_str_from_model(net)
            if arch_str != arch_str2:
                print(arch_str)
                print(arch_str2)
                raise ValueError('architecture string round-trip mismatch')
        elif args.search_space == 'sss':
            config = api.get_net_config(i, args.dataset)
            net = get_cell_based_tiny_net(config)
        net.to(args.device)

        init_net(net, args.init_w_type, args.init_b_type)

        measures = get_score(net, x, args.device, measure=args.measure)

        res[f'{args.measure}'] = measures

        if not args.noacc:
            # Look up benchmark accuracies (200-epoch hp for tss, 90 for sss).
            if args.search_space == 'tss':
                info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset,
                                         iepoch=None, hp='200', is_random=False)
            else:
                info = api.get_more_info(i, 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset,
                                         iepoch=None, hp='90', is_random=False)

            trainacc = info['train-accuracy']
            valacc = info['valid-accuracy']
            testacc = info['test-accuracy']

            res['trainacc'] = trainacc
            res['valacc'] = valacc
            res['testacc'] = testacc

        print(res)
        cached_res.append(res)

        # Periodically append the cached results to the output pickle file.
        if i % args.write_freq == 0 or i == len(api) - 1 or i == 10:
            print(f'writing {len(cached_res)} results to {op}')
            with open(op, 'ab') as pf:
                for cr in cached_res:
                    pickle.dump(cr, pf)
            cached_res = []
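
# Reading the results back (a sketch; the filename assumes the defaults above).
# The output file holds one pickled dict per record, so load repeatedly until
# EOFError:
#
#   import pickle
#   records = []
#   with open('nb2_tss_cf10_seed42_dlappoint_dlinfo1_initwnone_initbnone_1.p', 'rb') as f:
#       while True:
#           try:
#               records.append(pickle.load(f))
#           except EOFError:
#               break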