# MeCo/sota/cnn/train.py
import os
import random
import sys
sys.path.insert(0, '../../')
import glob
import numpy as np
import torch
import nasbench201.utils as ig_utils
import logging
import argparse
import shutil
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import json
from sota.cnn.model import Network
from torch.utils.tensorboard import SummaryWriter
from collections import namedtuple
parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data', type=str, default='../../data',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='c100_s4_pgd', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
#### common
parser.add_argument('--resume_epoch', type=int, default=0, help="load ckpt, start training at resume_epoch")
parser.add_argument('--ckpt_interval', type=int, default=50, help="interval (epoch) for saving checkpoints")
parser.add_argument('--resume_expid', type=str, default='', help="full expid to resume from, name == ckpt folder name")
parser.add_argument('--fast', action='store_true', default=False, help="fast mode for debugging")
parser.add_argument('--queue', action='store_true', default=False, help="queueing for gpu")
parser.add_argument('--from_dir', action='store_true', default=True, help="load arch from a directory")
args = parser.parse_args()


def load_network_pool(ckpt_path):
    with open(os.path.join(ckpt_path, 'best_networks.json'), 'r') as save_file:
        networks_pool = json.load(save_file)
    return networks_pool['networks']


Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
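# Genotype follows the DARTS convention: 'normal'/'reduce' hold (op_name, input_node)
# pairs for the two cell types, and the *_concat fields list the intermediate nodes
# whose outputs are concatenated into the cell output; main() rebuilds one of these
# from the loaded JSON when --from_dir is set.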
#### args augment
expid = args.save
print(args.from_dir)
if args.from_dir:
    id_name = os.path.split(args.arch)[1]
    args.arch = load_network_pool(args.arch)
    args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
        args.dataset, args.save, id_name, args.seed)
else:
    args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
        args.dataset, args.save, args.arch, args.seed)

if args.cutout:
    args.save += '-cutout-' + str(args.cutout_length) + '-' + str(args.cutout_prob)
if args.auxiliary:
    args.save += '-auxiliary-' + str(args.auxiliary_weight)
#### logging
if args.resume_epoch > 0:  # do not delete dir if resuming
    args.save = '../../experiments/sota/{}/{}'.format(args.dataset, args.resume_expid)
    assert os.path.exists(args.save), 'resume but {} does not exist!'.format(args.save)
else:
    scripts_to_save = glob.glob('*.py')
    if os.path.exists(args.save):
        if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
            print('proceed to override saving directory')
            shutil.rmtree(args.save)
        else:
            exit(0)
    ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
log_file = 'log_resume_{}.txt'.format(args.resume_epoch) if args.resume_epoch > 0 else 'log.txt'
fh = logging.FileHandler(os.path.join(args.save, log_file), mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')
if args.dataset == 'cifar100':
    n_classes = 100
else:
    n_classes = 10


def seed_torch(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False
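
# Entry point: picks/pins a GPU, seeds all RNGs, rebuilds the cell genotype
# (from the JSON pool loaded above, or from a named genotype when --from_dir is
# not set), and trains the resulting network from scratch.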
def main():
    torch.set_num_threads(3)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    #### gpu queueing
    if args.queue:
        ig_utils.queue_gpu()
    gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
    torch.cuda.set_device(gpu)
    cudnn.enabled = True
    seed_torch(args.seed)
    logging.info('gpu device = %d' % gpu)
    logging.info("args = %s", args)

    if args.from_dir:
        genotype_config = json.loads(args.arch)
        genotype = Genotype(normal=genotype_config['normal'], normal_concat=genotype_config['normal_concat'],
                            reduce=genotype_config['reduce'], reduce_concat=genotype_config['reduce_concat'])
    else:
        # NOTE: assumed import location; the repo is expected to provide a genotypes
        # module with named architectures for this branch.
        from sota.cnn import genotypes
        genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, n_classes, args.layers, args.auxiliary, genotype)
    model = model.cuda()
    logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    if args.dataset == 'cifar10':
        train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
        train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
    elif args.dataset == 'cifar100':
        train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
        train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
    elif args.dataset == 'svhn':
        train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
        train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
        valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=0)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs),
        # eta_min=1e-4
    )
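    # With the default eta_min (0), the learning rate follows a cosine curve from
    # args.learning_rate down to ~0 over args.epochs epochs.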

    #### resume
    start_epoch = 0
    if args.resume_epoch > 0:
        logging.info('loading checkpoint from {}'.format(expid))
        filename = os.path.join(args.save, 'checkpoint_{}.pth.tar'.format(args.resume_epoch))

        if os.path.isfile(filename):
            print("=> loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location='cpu')
            resume_epoch = checkpoint['epoch']  # epoch
            model.load_state_dict(checkpoint['state_dict'])  # model
            scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])  # optimizer
            start_epoch = args.resume_epoch
            print("=> loaded checkpoint '{}' (epoch {})".format(filename, resume_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(filename))

    #### main training
    best_valid_acc = 0
    for epoch in range(start_epoch, args.epochs):
        lr = scheduler.get_last_lr()[0]  # get_last_lr() avoids the warning from calling get_lr() outside step()
        if args.cutout:
            train_transform.transforms[-1].cutout_prob = args.cutout_prob
            logging.info('epoch %d lr %e cutout_prob %e', epoch, lr,
                         train_transform.transforms[-1].cutout_prob)
        else:
            logging.info('epoch %d lr %e', epoch, lr)
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('Acc/train', train_acc, epoch)
        writer.add_scalar('Obj/train', train_obj, epoch)

        ## scheduler
        scheduler.step()

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)
        writer.add_scalar('Acc/valid', valid_acc, epoch)
        writer.add_scalar('Obj/valid', valid_obj, epoch)

        ## checkpoint
        if (epoch + 1) % args.ckpt_interval == 0:
            save_state_dict = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            ig_utils.save_checkpoint(save_state_dict, False, args.save, per_epoch=True)

        best_valid_acc = max(best_valid_acc, valid_acc)
        logging.info('best valid_acc %f', best_valid_acc)

    writer.close()
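
# One training epoch: cross-entropy on the main head plus, when --auxiliary is set,
# a weighted loss on the auxiliary head; gradients are clipped to args.grad_clip
# before each optimizer step.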
def train(train_queue, model, criterion, optimizer):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.train()

    for step, (input, target) in enumerate(train_queue):
        input = input.cuda()
        target = target.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits, logits_aux = model(input)
        loss = criterion(logits, target)
        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.data, n)
        top1.update(prec1.data, n)
        top5.update(prec5.data, n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
        if args.fast:
            logging.info('//// WARNING: FAST MODE')
            break

    return top1.avg, objs.avg
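
# Evaluation pass over the validation set: gradients disabled, auxiliary head ignored.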
def infer(valid_queue, model, criterion):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            logits, _ = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data, n)
            top1.update(prec1.data, n)
            top5.update(prec5.data, n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
            if args.fast:
                logging.info('//// WARNING: FAST MODE')
                break

    return top1.avg, objs.avg


if __name__ == '__main__':
    main()
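
# Example invocation (illustrative; the --arch path is hypothetical and, because
# --from_dir defaults to True, must be a directory containing the best_networks.json
# written by the search stage):
#
#   python train.py --dataset cifar10 --epochs 600 --batch_size 96 \
#       --arch ../../experiments/sota/cifar10/search-EXP --seed 0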