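"""Validate a saved pool of candidate networks with training-free scores.

Loads a networks pool produced by a search run, scores every unique genotype
with the selected zero-cost criterion (Jacobian-based scoring by default),
and writes the best-scoring genotype to ``best_networks.json``. For
nas-bench-201 the chosen architecture is also queried against the benchmark
API and its benchmark accuracies are logged.

Illustrative usage (script and pool paths depend on your checkout):
    python <this_script>.py --ckpt_path <dir with networks_pool.json> --dataset cifar10
"""
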
import argparse
import glob
import json
import logging
import os
import random
import shutil
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import tqdm
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, '../')
from nasbench201.cell_infers.tiny_network import TinyNetwork
from nasbench201.genotypes import Structure
from nas_201_api import NASBench201API as API
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from pycls.models.nas.genotypes import Genotype
import nasbench201.utils as ig_utils
from foresight.pruners import predictive  # used below; previously a wildcard import
from Scorers.scorer import Jocab_Scorer  # Jacobian-based scorer; spelling matches the upstream module
from mobilenet_search_space.retrain_architecture.model import Network
from sota.cnn.hdf5 import H5Dataset

parser = argparse.ArgumentParser("sota")
parser.add_argument('--data', type=str, default='../data',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--save_path', type=str, default='../experiments/sota', help='path under which experiments are saved')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--ckpt_path', type=str, help='path to the saved networks pool')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--maxiter', default=1, type=int, help='score is the max of this many evaluations of the network')
parser.add_argument('--batch_size', type=int, default=256, help='batch size for scoring')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--init_channels', type=int, default=16, help='number of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--validate_rounds', type=int, default=10, help='number of scoring rounds per network')
parser.add_argument('--proj_crit', type=str, default='jacob', choices=['loss', 'acc', 'var', 'cor', 'norm', 'jacob', 'snip', 'fisher', 'synflow', 'grad_norm', 'grasp', 'jacob_cov', 'comb', 'meco', 'zico'], help='criterion for scoring')
parser.add_argument('--edge_decision', type=str, default='random', choices=['random', 'reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once'], help='which edge to be projected next')
args = parser.parse_args()

# Fix all seed sources for reproducible scoring.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)


def load_network_pool(ckpt_path):
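    """Read ``networks_pool.json`` under ``ckpt_path`` and return
    ``(search_space, dataset, networks, pool_size)``; pools saved without a
    ``pool_size`` field fall back to ``len(networks)``."""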
    with open(os.path.join(ckpt_path, 'networks_pool.json'), 'r') as save_file:
        for line in save_file:
            networks_pool = json.loads(line)
            if 'pool_size' in networks_pool:
                return (networks_pool['search_space'], networks_pool['dataset'],
                        networks_pool['networks'], networks_pool['pool_size'])
            else:
                return (networks_pool['search_space'], networks_pool['dataset'],
                        networks_pool['networks'], len(networks_pool['networks']))


#### experiment setup
search_space, dataset, networks_pool, pool_size = load_network_pool(args.ckpt_path)
search_space = search_space.strip()
dataset = dataset.strip()
expid = args.save

args.save = '{}/{}-valid-{}-{}-{}-{}'.format(args.save_path, search_space, args.save, args.seed, pool_size, args.validate_rounds)
if dataset != 'cifar10':
    args.save += '-' + dataset
if args.edge_decision != 'random':
    args.save += '-' + args.edge_decision
if args.proj_crit != 'jacob':
    args.save += '-' + args.proj_crit

scripts_to_save = glob.glob('*.py') + ['../exp_scripts/{}.sh'.format(expid)]
if os.path.exists(args.save):
    if input("WARNING: {} exists, override? [y/n] ".format(args.save)) == 'y':
        print('proceed to override saving directory')
        shutil.rmtree(args.save)
    else:
        sys.exit(0)
# note: scripts are collected but not copied (scripts_to_save=None preserves the original behavior)
ig_utils.create_exp_dir(args.save, scripts_to_save=None)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')

log_file = 'log.txt'
log_path = os.path.join(args.save, log_file)
logging.info('======> log filename: %s', log_file)
logging.info('load pool from space:%s and dataset:%s', search_space, dataset)

if os.path.exists(log_path):
    if input("WARNING: {} exists, override? [y/n] ".format(log_file)) == 'y':
        print('proceed to override log file directory')
    else:
        sys.exit(0)

fh = logging.FileHandler(log_path, mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')

#### macros
if dataset == 'cifar100':
    n_classes = 100
elif dataset == 'imagenet16-120':
    n_classes = 120
elif dataset == 'imagenet':
    n_classes = 1000
else:
    n_classes = 10

if search_space == 'nas-bench-201':
    api = API('../data/NAS-Bench-201-v1_0-e61699.pth')

if search_space == 'nb_macro':
    import pickle as pkl
    # the pickle holds a header followed by (arch, val_top1, test_top1, train_time) records
    with open('../data/nbmacro-base-0.pickle', 'rb') as f:
        head = pkl.load(f)
        value = pkl.load(f)
    api = {}
    for v in value:
        h, val_t1, test_t1, t_time = v
        api[h] = test_t1


def main():
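    """Score every unique genotype in the pool on up to ``validate_rounds``
    batches and save the best-scoring genotype to ``best_networks.json``."""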
    #### data
    if dataset == 'imagenet':
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)

        num_train = len(train_data)
        indices = list(range(num_train))
        # only validate_rounds * batch_size samples are needed for scoring
        split = int(np.floor(args.validate_rounds * args.batch_size))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size, num_workers=4, pin_memory=True,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]))
    else:
        if dataset == 'cifar10':
            train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
            train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
        elif dataset == 'cifar100':
            train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
            train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
        elif dataset == 'svhn':
            train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
            train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
            valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
        elif dataset == 'imagenet16-120':
            from nasbench201.DownsampledImageNet import ImageNet16
            mean = [x / 255 for x in [122.68, 116.66, 104.01]]
            std = [x / 255 for x in [63.22, 61.26, 65.09]]
            lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
            train_transform = transforms.Compose(lists)
            train_data = ImageNet16(root=os.path.join(args.data, 'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
            valid_data = ImageNet16(root=os.path.join(args.data, 'imagenet16'), train=False, transform=train_transform, use_num_of_class_only=120)
            assert len(train_data) == 151700

        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(args.validate_rounds * args.batch_size))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=4)

    gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
    torch.cuda.set_device(gpu)

    if args.proj_crit == 'jacob':
        validate_scorer = Jocab_Scorer(gpu)

    best_networks = None
    crit_list = []    # accumulated score per unique genotype, aligned with net_history
    net_history = []  # unique genotypes seen so far
    logging.info('number of scoring batches: %d', len(train_queue))
    for net_config in tqdm.tqdm(networks_pool, desc="networks", position=0):
        net_id = net_config['id']
        net_genotype = net_config['genotype']
        if net_genotype not in net_history:
            net_history.append(net_genotype)
            network = get_networks_from_genotype(net_genotype, dataset, search_space)
            if args.proj_crit == 'jacob':
                validate_scorer.setup_hooks(network, args.batch_size)
            for step, (input, target) in tqdm.tqdm(enumerate(train_queue), desc="validate_rounds", position=1, leave=False):
                # .cuda() is not in-place; rebind so the scorer receives GPU tensors
                input = input.cuda()
                target = target.cuda()
                if args.proj_crit == 'jacob':
                    score = validate_scorer.score(network, input, target)
                else:
                    network.requires_feature = False
                    measures = predictive.find_measures(network,
                                                        train_queue,
                                                        ('random', 1, n_classes),
                                                        torch.device("cuda"),
                                                        measure_names=[args.proj_crit])
                    score = measures[args.proj_crit]

                if step == 0:
                    crit_list.append(score)
                else:
                    crit_list[-1] += score
                if args.proj_crit != 'jacob':
                    # find_measures draws its own data from train_queue, so a
                    # single call per network is enough for non-jacob criteria
                    break
    best_networks = net_history[np.nanargmax(crit_list)]

    if search_space == 'nas-bench-201':
        cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
            cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = query(api, best_networks, logging)

    networks_info = {}
    networks_info['search_space'] = search_space
    networks_info['dataset'] = dataset
    networks_info['networks'] = best_networks
    with open(os.path.join(args.save, 'best_networks.json'), 'w') as save_file:
        json.dump(networks_info, save_file)


#### util functions
def distill(result):
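    """Parse the multi-line result string returned by the NAS-Bench-201 API
    into train/valid/test accuracies for cifar10, cifar100 and ImageNet16-120."""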
    result = result.split('\n')
    cifar10 = result[5].replace(' ', '').split(':')
    cifar100 = result[7].replace(' ', '').split(':')
    imagenet16 = result[9].replace(' ', '').split(':')

    cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('='))
    cifar10_test = float(cifar10[2][-7:-2].strip('='))
    cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('='))
    cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('='))
    cifar100_test = float(cifar100[3][-7:-2].strip('='))
    imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('='))
    imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('='))
    imagenet16_test = float(imagenet16[3][-7:-2].strip('='))

    return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test


def query(api, genotype, logging):
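    """Query NAS-Bench-201 for ``genotype`` (hp='200'), log the per-dataset
    accuracies, and return the values parsed by ``distill``."""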
    result = api.query_by_arch(genotype, hp='200')
    logging.info('{:}'.format(result))
    cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
    logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
    logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
    logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)
    return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test


def get_networks_from_genotype(genotype_str, dataset, search_space):
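    """Build a scorable model from a genotype string: a TinyNetwork for
    nas-bench-201, a mobilenet Network from space-separated op ids, or a
    DARTS-style NetworkCIFAR/NetworkImageNet from a JSON genotype."""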
    if search_space == 'nas-bench-201':
        net_index = api.query_index_by_arch(genotype_str)
        net_config = api.get_net_config(net_index, 'cifar10-valid')
        print(net_config)
        genotype = Structure.str2structure(net_config['arch_str'])
        network = TinyNetwork(net_config['C'], net_config['N'], genotype, n_classes)
        return network
    elif search_space == 'mobilenet':
        rngs = [int(op_id) for op_id in genotype_str.split(' ')]
        network = Network(rngs, n_class=n_classes)
        return network
    else:
        genotype_config = json.loads(genotype_str)
        genotype = Genotype(normal=genotype_config['normal'], normal_concat=genotype_config['normal_concat'], reduce=genotype_config['reduce'], reduce_concat=genotype_config['reduce_concat'])

        if dataset == 'imagenet':
            network = NetworkImageNet(args.init_channels, n_classes, args.layers, False, genotype)
        else:
            network = NetworkCIFAR(args.init_channels, n_classes, args.layers, False, genotype)
        network.drop_path_prob = 0.
        return network


if __name__ == '__main__':
    main()