import os
import random
import sys
sys.path.insert(0, '../../')
import glob
import numpy as np
import torch
import nasbench201.utils as ig_utils
import logging
import argparse
import shutil
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import json

from sota.cnn.model import Network
# needed by the non --from_dir branch below, which eval's "genotypes.<arch>";
# assumes the genotype definitions live in sota.cnn.genotypes
from sota.cnn import genotypes
from torch.utils.tensorboard import SummaryWriter
from collections import namedtuple

parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data', type=str, default='../../data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
# note: store_true flags with default=True cannot be switched off from the CLI
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='c100_s4_pgd', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
#### common
parser.add_argument('--resume_epoch', type=int, default=0, help='load ckpt, start training at resume_epoch')
parser.add_argument('--ckpt_interval', type=int, default=50, help='interval (epoch) for saving checkpoints')
parser.add_argument('--resume_expid', type=str, default='', help='full expid to resume from, name == ckpt folder name')
parser.add_argument('--fast', action='store_true', default=False, help='fast mode for debugging')
parser.add_argument('--queue', action='store_true', default=False, help='queueing for gpu')
parser.add_argument('--from_dir', action='store_true', default=True, help='load architecture from dir')
args = parser.parse_args()


def load_network_pool(ckpt_path):
    with open(os.path.join(ckpt_path, 'best_networks.json'), 'r') as save_file:
        networks_pool = json.load(save_file)
    return networks_pool['networks']


Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
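# The --from_dir path expects <arch>/best_networks.json. Judging by how the
# result is consumed in main() (json.loads, then Genotype(...) from four
# fields), the file presumably looks something like this (values illustrative):
#
#   {
#     "networks": "{\"normal\": [[\"sep_conv_3x3\", 0], ...],
#                   \"normal_concat\": [2, 3, 4, 5],
#                   \"reduce\": [[\"max_pool_3x3\", 0], ...],
#                   \"reduce_concat\": [2, 3, 4, 5]}"
#   }
#
# i.e. "networks" holds a JSON-encoded genotype string, not a nested object.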
#### args augment
expid = args.save
if args.from_dir:
    id_name = os.path.split(args.arch)[1]
    args.arch = load_network_pool(args.arch)
    args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
        args.dataset, args.save, id_name, args.seed)
else:
    args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
        args.dataset, args.save, args.arch, args.seed)
if args.cutout:
    args.save += '-cutout-' + str(args.cutout_length) + '-' + str(args.cutout_prob)
if args.auxiliary:
    args.save += '-auxiliary-' + str(args.auxiliary_weight)

#### logging
if args.resume_epoch > 0:  # do not delete dir if resume
    args.save = '../../experiments/sota/{}/{}'.format(args.dataset, args.resume_expid)
    assert os.path.exists(args.save), 'resume but {} does not exist!'.format(args.save)
else:
    scripts_to_save = glob.glob('*.py')
    if os.path.exists(args.save):
        if input('WARNING: {} exists, override?[y/n]'.format(args.save)) == 'y':
            print('proceed to override saving directory')
            shutil.rmtree(args.save)
        else:
            sys.exit(0)
    ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
log_file = 'log_resume_{}.txt'.format(args.resume_epoch) if args.resume_epoch > 0 else 'log.txt'
fh = logging.FileHandler(os.path.join(args.save, log_file), mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')

if args.dataset == 'cifar100':
    n_classes = 100
else:
    n_classes = 10


def seed_torch(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False


def main():
    torch.set_num_threads(3)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    #### gpu queueing
    if args.queue:
        ig_utils.queue_gpu()
    gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
    torch.cuda.set_device(gpu)
    cudnn.enabled = True
    seed_torch(args.seed)
    logging.info('gpu device = %d' % gpu)
    logging.info('args = %s', args)

    if args.from_dir:
        genotype_config = json.loads(args.arch)
        genotype = Genotype(normal=genotype_config['normal'],
                            normal_concat=genotype_config['normal_concat'],
                            reduce=genotype_config['reduce'],
                            reduce_concat=genotype_config['reduce_concat'])
    else:
        genotype = eval('genotypes.%s' % args.arch)
    model = Network(args.init_channels, n_classes, args.layers, args.auxiliary, genotype)
    model = model.cuda()
    logging.info('param size = %fMB', ig_utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    if args.dataset == 'cifar10':
        train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
        train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
    elif args.dataset == 'cifar100':
        train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
        train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
    elif args.dataset == 'svhn':
        train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
        train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
        valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
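    # Note: the epoch loop below re-tunes train_transform.transforms[-1].cutout_prob,
    # which assumes the ig_utils._data_transforms_* helpers place a Cutout op
    # (exposing a mutable `cutout_prob` attribute) LAST in the train pipeline,
    # roughly like this DARTS-style sketch (MEAN/STD and Cutout are assumptions):
    #
    #   train_transform = transforms.Compose([
    #       transforms.RandomCrop(32, padding=4),
    #       transforms.RandomHorizontalFlip(),
    #       transforms.ToTensor(),
    #       transforms.Normalize(MEAN, STD),
    #       Cutout(args.cutout_length, args.cutout_prob),  # must stay last
    #   ])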
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=0)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs),
        # eta_min=1e-4
    )

    #### resume
    start_epoch = 0
    if args.resume_epoch > 0:
        logging.info('loading checkpoint from {}'.format(expid))
        filename = os.path.join(args.save, 'checkpoint_{}.pth.tar'.format(args.resume_epoch))
        if os.path.isfile(filename):
            print("=> loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location='cpu')
            resume_epoch = checkpoint['epoch']  # epoch
            model.load_state_dict(checkpoint['state_dict'])  # model
            scheduler.load_state_dict(checkpoint['scheduler'])  # scheduler
            optimizer.load_state_dict(checkpoint['optimizer'])  # optimizer
            start_epoch = args.resume_epoch
            print("=> loaded checkpoint '{}' (epoch {})".format(filename, resume_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(filename))

    #### main training
    best_valid_acc = 0
    for epoch in range(start_epoch, args.epochs):
        lr = scheduler.get_last_lr()[0]  # lr in effect for this epoch
        if args.cutout:
            # assumes Cutout is the last op in the train pipeline (see note above)
            train_transform.transforms[-1].cutout_prob = args.cutout_prob
            logging.info('epoch %d lr %e cutout_prob %e', epoch, lr,
                         train_transform.transforms[-1].cutout_prob)
        else:
            logging.info('epoch %d lr %e', epoch, lr)
        # linearly ramp drop-path probability from 0 to drop_path_prob over training
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('Acc/train', train_acc, epoch)
        writer.add_scalar('Obj/train', train_obj, epoch)

        ## scheduler
        scheduler.step()

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)
        writer.add_scalar('Acc/valid', valid_acc, epoch)
        writer.add_scalar('Obj/valid', valid_obj, epoch)

        ## checkpoint
        if (epoch + 1) % args.ckpt_interval == 0:
            save_state_dict = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            ig_utils.save_checkpoint(save_state_dict, False, args.save, per_epoch=True)

        best_valid_acc = max(best_valid_acc, valid_acc)
        logging.info('best valid_acc %f', best_valid_acc)

    writer.close()
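# ig_utils.AvgrageMeter / ig_utils.accuracy are presumably the standard DARTS
# helpers: update(val, n) accumulates an n-weighted running average in .avg,
# and accuracy(logits, target, topk) returns top-k precision percentages.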
def train(train_queue, model, criterion, optimizer):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.train()

    for step, (input, target) in enumerate(train_queue):
        input = input.cuda()
        target = target.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits, logits_aux = model(input)
        loss = criterion(logits, target)
        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        # record running averages as plain floats rather than CUDA tensors
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
        if args.fast:
            logging.info('//// WARNING: FAST MODE')
            break

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            logits, _ = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
            if args.fast:
                logging.info('//// WARNING: FAST MODE')
                break

    return top1.avg, objs.avg


if __name__ == '__main__':
    main()
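# Example invocation (hypothetical file name; the relative paths assume this
# script runs from its own directory, matching the sys.path.insert above):
#
#   python train.py --dataset cifar10 --seed 0 \
#       --arch ../../experiments/sota/cifar10/<search-expid>
#
# where --arch (with the default --from_dir behavior) names a directory
# containing best_networks.json; checkpoints and TensorBoard logs land under
# ../../experiments/sota/<dataset>/eval/.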