# MeCo/nasbench201/networks_proposal.py

import os
import sys
sys.path.insert(0, '../')
import time
import glob
import json
import shutil
import logging
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
import nasbench201.utils as ig_utils
from nasbench201.search_model_darts_proj import TinyNetworkDartsProj
from nasbench201.cell_operations import SearchSpaceNames
from nasbench201.init_projection import pt_project, global_op_greedy_pt_project, global_op_once_pt_project, global_edge_greedy_pt_project, global_edge_once_pt_project, shrink_pt_project, tenas_project
from nas_201_api import NASBench201API as API
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4, suppress=True)
parser = argparse.ArgumentParser("sota")
# data related
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'imagenet16-120'], help='choose dataset')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--batch_size', type=int, default=64, help='batch size for alpha')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--seed', type=int, default=2, help='random seed')
# search space settings
parser.add_argument('--search_space', type=str, default='nas-bench-201')
parser.add_argument('--pool_size', type=int, default=100, help='number of networks to propose')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
# system configuration
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
# default optimizer settings for the model
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--nesterov', action='store_true', default=True, help='use Nesterov momentum for SGD')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
#### common
parser.add_argument('--fast', action='store_true', default=True, help='skip loading the benchmark API, which is slow')
#### projection
parser.add_argument('--edge_decision', type=str, default='random', choices=['random','reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once', 'shrink_pt_project'], help='which edge to be projected next')
parser.add_argument('--proj_crit', type=str, default="comb", choices=['loss', 'acc', 'jacob', 'snip', 'fisher', 'synflow', 'grad_norm', 'grasp', 'jacob_cov','tenas', 'var', 'cor', 'norm', 'comb', 'meco'], help='criteria for projection')
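# NOTE: 'meco' is the criterion contributed by this repo (MeCo); the other
# proj_crit choices are zero-cost proxies from prior NAS work, presumably
# dispatched on args.proj_crit inside the projection routines imported above.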
args = parser.parse_args()
#### args augment
expid = args.save
args.save = '../experiments/nas-bench-201/prop-{}-{}-{}'.format(args.save, args.seed, args.pool_size)
if not args.dataset == 'cifar10':
    args.save += '-' + args.dataset
if not args.edge_decision == 'random':
    args.save += '-' + args.edge_decision
if not args.proj_crit == 'jacob':
    args.save += '-' + args.proj_crit
#### logging
scripts_to_save = glob.glob('*.py')  # + ['../exp_scripts/{}.sh'.format(expid)]
if os.path.exists(args.save):
    if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
        print('proceed to override saving directory')
        shutil.rmtree(args.save)
    else:
        exit(0)
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
log_file = 'log.txt'
log_path = os.path.join(args.save, log_file)
logging.info('======> log filename: %s', log_file)
if os.path.exists(log_path):
    if input("WARNING: {} exists, override?[y/n]".format(log_file)) == 'y':
        print('proceed to override log file')
    else:
        exit(0)
fh = logging.FileHandler(log_path, mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')
#### macros
if args.dataset == 'cifar100':
    n_classes = 100
elif args.dataset == 'imagenet16-120':
    n_classes = 120
else:
    n_classes = 10
def main():
    torch.set_num_threads(3)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
    torch.cuda.set_device(gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    logging.info('gpu device = %d', gpu)
    #### model
    criterion = nn.CrossEntropyLoss()
    search_space = SearchSpaceNames[args.search_space]
    # initialize the supernet (and a thin copy with a single stem channel)
    model = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args)
    model_thin = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args, stem_channels=1)
    model = model.cuda()
    model_thin = model_thin.cuda()
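    # NOTE: model_thin (stem_channels=1) appears to exist only for the TE-NAS-style
    # tenas_project call, which is commented out in the search loop below; the
    # default pt_project path does not use it.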
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
#### data
if args.dataset == 'cifar10':
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
elif args.dataset == 'cifar100':
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
elif args.dataset == 'imagenet16-120':
import torchvision.transforms as transforms
from nasbench201.DownsampledImageNet import ImageNet16
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26, 65.09]]
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
train_transform = transforms.Compose(lists)
train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
valid_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=False, transform=train_transform, use_num_of_class_only=120)
assert len(train_data) == 151700
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
pin_memory=True)
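    # NOTE: only the first train_portion fraction of the training indices is used
    # (drawn randomly by SubsetRandomSampler); this script builds no separate
    # validation queue.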
    # format the network pool dict
    networks_pool = {}
    networks_pool['search_space'] = args.search_space
    networks_pool['dataset'] = args.dataset
    networks_pool['networks'] = []
    networks_pool['pool_size'] = args.pool_size
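    # Illustrative shape of the dumped <args.save>/networks_pool.json (genotype
    # strings follow the NAS-Bench-201 arch-string format):
    #   {"search_space": "nas-bench-201", "dataset": "cifar10", "pool_size": 100,
    #    "networks": [{"id": "0", "genotype": "|nor_conv_3x3~0|+..."}, ...]}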
    #### architecture selection / projection
    for i in range(args.pool_size):
        network_info = {}
        logging.info('SEARCHING MODEL {} / {}'.format(i + 1, args.pool_size))
        if args.edge_decision == 'global_op_greedy':
            global_op_greedy_pt_project(train_queue, model, args)
        elif args.edge_decision == 'global_op_once':
            global_op_once_pt_project(train_queue, model, args)
        elif args.edge_decision == 'global_edge_greedy':
            global_edge_greedy_pt_project(train_queue, model, args)
        elif args.edge_decision == 'global_edge_once':
            global_edge_once_pt_project(train_queue, model, args)
        elif args.edge_decision == 'shrink_pt_project':
            shrink_pt_project(train_queue, model, args)
            api = API('../data/NAS-Bench-201-v1_0-e61699.pth')
            cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
                cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = query(api, model.genotype().tostr(), logging)
        else:
            # 'random', 'reverse', and 'order' edge decisions all go through
            # pt_project; args.proj_crit selects the scoring criterion inside it
            pt_project(train_queue, model, args)
            # tenas_project(train_queue, model, model_thin, args)
        network_info['id'] = str(i)
        network_info['genotype'] = model.genotype().tostr()
        networks_pool['networks'].append(network_info)
        model.reset_arch_parameters()

    with open(os.path.join(args.save, 'networks_pool.json'), 'w') as save_file:
        json.dump(networks_pool, save_file)
#### util functions
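# distill() scrapes accuracies out of the multi-line summary string returned by
# NASBench201API.query_by_arch; the hard-coded line/field offsets below assume
# the NAS-Bench-201 v1.0 result layout and will break if that format changes.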
def distill(result):
    result = result.split('\n')
    cifar10 = result[5].replace(' ', '').split(':')
    cifar100 = result[7].replace(' ', '').split(':')
    imagenet16 = result[9].replace(' ', '').split(':')

    cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('='))
    cifar10_test = float(cifar10[2][-7:-2].strip('='))
    cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('='))
    cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('='))
    cifar100_test = float(cifar100[3][-7:-2].strip('='))
    imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('='))
    imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('='))
    imagenet16_test = float(imagenet16[3][-7:-2].strip('='))

    return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
def query(api, genotype, logging):
    result = api.query_by_arch(genotype, hp='200')
    logging.info('{:}'.format(result))
    cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
    logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
    logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
    logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)
    return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
if __name__ == '__main__':
    main()
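
# Example invocation (illustrative; all flags are defined in the argparse block above):
#   python networks_proposal.py --dataset cifar10 --pool_size 100 \
#       --edge_decision random --proj_crit meco --save exp-meco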