import os
import random
import sys

sys.path.insert(0, '../../')

import glob
import numpy as np
import torch
import nasbench201.utils as ig_utils
import logging
import argparse
import shutil
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import json

from sota.cnn.model import Network
# Needed by the non---from_dir branch below (eval("genotypes.%s" % args.arch)).
# Assumes the genotypes module sits next to sota.cnn.model; adjust if your layout differs.
from sota.cnn import genotypes
from torch.utils.tensorboard import SummaryWriter
from collections import namedtuple

parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data', type=str, default='../../data',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='c100_s4_pgd', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
#### common
parser.add_argument('--resume_epoch', type=int, default=0, help="load ckpt, start training at resume_epoch")
parser.add_argument('--ckpt_interval', type=int, default=50, help="interval (epoch) for saving checkpoints")
parser.add_argument('--resume_expid', type=str, default='', help="full expid to resume from, name == ckpt folder name")
parser.add_argument('--fast', action='store_true', default=False, help="fast mode for debugging")
parser.add_argument('--queue', action='store_true', default=False, help="queueing for gpu")

parser.add_argument('--from_dir', action='store_true', default=True, help="arch load from dir")
args = parser.parse_args()
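
# Example invocation (script name and flag values are illustrative, not fixed by this file):
#   python train.py --dataset cifar10 --arch <path_to_search_ckpt_dir> --seed 0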


def load_network_pool(ckpt_path):
    """Read the architecture pool written by the search phase and return its
    'networks' entry (expected to be a JSON-encoded genotype string)."""
    with open(os.path.join(ckpt_path, 'best_networks.json'), 'r') as save_file:
        networks_pool = json.load(save_file)
    return networks_pool['networks']
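

# DARTS-style genotype: `normal`/`reduce` list the (op_name, input_index) edges
# of the normal and reduction cells; `normal_concat`/`reduce_concat` give the
# node indices whose outputs are concatenated to form each cell's output.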
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

#### args augment
expid = args.save

if args.from_dir:
    id_name = os.path.split(args.arch)[1]
    args.arch = load_network_pool(args.arch)  # replace the dir path with the genotype JSON string
    args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
        args.dataset, args.save, id_name, args.seed)
else:
    args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
        args.dataset, args.save, args.arch, args.seed)
if args.cutout:
    args.save += '-cutout-' + str(args.cutout_length) + '-' + str(args.cutout_prob)
if args.auxiliary:
    args.save += '-auxiliary-' + str(args.auxiliary_weight)

#### logging
if args.resume_epoch > 0:  # do not delete dir if resuming
    args.save = '../../experiments/sota/{}/{}'.format(args.dataset, args.resume_expid)
    assert os.path.exists(args.save), 'resume but {} does not exist!'.format(args.save)
else:
    scripts_to_save = glob.glob('*.py')
    if os.path.exists(args.save):
        if input("WARNING: {} exists, override? [y/n] ".format(args.save)) == 'y':
            print('proceed to override saving directory')
            shutil.rmtree(args.save)
        else:
            exit(0)
    ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
log_file = 'log_resume_{}.txt'.format(args.resume_epoch) if args.resume_epoch > 0 else 'log.txt'
fh = logging.FileHandler(os.path.join(args.save, log_file), mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')

if args.dataset == 'cifar100':
    n_classes = 100
else:
    n_classes = 10


def seed_torch(seed=0):
    """Seed Python, NumPy and PyTorch RNGs; favor reproducibility over cudnn speed."""
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False


def main():
    torch.set_num_threads(3)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    #### gpu queueing
    if args.queue:
        ig_utils.queue_gpu()

    gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
    torch.cuda.set_device(gpu)
    cudnn.enabled = True
    seed_torch(args.seed)

    logging.info('gpu device = %d' % gpu)
    logging.info("args = %s", args)

    if args.from_dir:
        genotype_config = json.loads(args.arch)
        genotype = Genotype(normal=genotype_config['normal'], normal_concat=genotype_config['normal_concat'],
                            reduce=genotype_config['reduce'], reduce_concat=genotype_config['reduce_concat'])
    else:
        genotype = eval("genotypes.%s" % args.arch)

    model = Network(args.init_channels, n_classes, args.layers, args.auxiliary, genotype)
    model = model.cuda()

    logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    if args.dataset == 'cifar10':
        train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
        train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
    elif args.dataset == 'cifar100':
        train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
        train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
    elif args.dataset == 'svhn':
        train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
        train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
        valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
    else:
        raise ValueError('unsupported dataset: {}'.format(args.dataset))
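
    # NOTE: the epoch loop below sets train_transform.transforms[-1].cutout_prob,
    # which assumes Cutout is the last transform in the training pipeline.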

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=0)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0)
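
    # Cosine annealing over the whole run:
    #   lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2
    # with T_max = args.epochs and eta_min = 0 (PyTorch's default).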
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs),
        # eta_min=1e-4
    )

    #### resume
    start_epoch = 0
    if args.resume_epoch > 0:
        logging.info('loading checkpoint from {}'.format(expid))
        filename = os.path.join(args.save, 'checkpoint_{}.pth.tar'.format(args.resume_epoch))

        if os.path.isfile(filename):
            print("=> loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location='cpu')
            resume_epoch = checkpoint['epoch']  # epoch
            model.load_state_dict(checkpoint['state_dict'])  # model
            scheduler.load_state_dict(checkpoint['scheduler'])  # scheduler
            optimizer.load_state_dict(checkpoint['optimizer'])  # optimizer
            start_epoch = args.resume_epoch
            print("=> loaded checkpoint '{}' (epoch {})".format(filename, resume_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(filename))

    #### main training
    best_valid_acc = 0
    for epoch in range(start_epoch, args.epochs):
        lr = scheduler.get_last_lr()[0]
        if args.cutout:
            train_transform.transforms[-1].cutout_prob = args.cutout_prob
            logging.info('epoch %d lr %e cutout_prob %e', epoch, lr,
                         train_transform.transforms[-1].cutout_prob)
        else:
            logging.info('epoch %d lr %e', epoch, lr)
        # ramp drop path probability linearly from 0 to args.drop_path_prob
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('Acc/train', train_acc, epoch)
        writer.add_scalar('Obj/train', train_obj, epoch)

        ## scheduler
        scheduler.step()

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)
        writer.add_scalar('Acc/valid', valid_acc, epoch)
        writer.add_scalar('Obj/valid', valid_obj, epoch)

        ## checkpoint
        if (epoch + 1) % args.ckpt_interval == 0:
            save_state_dict = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            ig_utils.save_checkpoint(save_state_dict, False, args.save, per_epoch=True)

        best_valid_acc = max(best_valid_acc, valid_acc)
        logging.info('best valid_acc %f', best_valid_acc)
    writer.close()
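

# The auxiliary tower (enabled by --auxiliary) is a second classifier attached
# partway through the network; its cross-entropy is added to the main loss with
# weight args.auxiliary_weight: loss = CE(logits) + w * CE(logits_aux).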
def train(train_queue, model, criterion, optimizer):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.train()

    for step, (input, target) in enumerate(train_queue):
        input = input.cuda()
        target = target.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits, logits_aux = model(input)
        loss = criterion(logits, target)
        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        if args.fast:
            logging.info('//// WARNING: FAST MODE')
            break

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            logits, _ = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            if args.fast:
                logging.info('//// WARNING: FAST MODE')
                break

    return top1.avg, objs.avg


if __name__ == '__main__':
    main()