MeCo/sota/cnn/train_imagenet.py
HamsterMimi 189df25fd3 upload
2023-05-04 13:09:03 +08:00

254 lines
9.5 KiB
Python

from torch.utils.tensorboard import SummaryWriter
import argparse
import glob
import logging
import sys
sys.path.insert(0, '../../')
import time
import random
import numpy as np
import os
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import nasbench201.utils as utils
from sota.cnn.model_imagenet import NetworkImageNet as Network
import sota.cnn.genotypes as genotypes
from sota.cnn.hdf5 import H5Dataset
# Command-line interface of the ImageNet evaluation script.
# Each entry is (flag, add_argument keyword options); entries are registered
# in this order, which is also the order shown by --help.
_ARGUMENT_SPECS = (
    # data and optimisation
    ('--data', dict(type=str, default='../../data', help='location of the data corpus')),
    ('--batch_size', dict(type=int, default=128, help='batch size')),
    ('--learning_rate', dict(type=float, default=0.1, help='init learning rate')),
    ('--momentum', dict(type=float, default=0.9, help='momentum')),
    ('--weight_decay', dict(type=float, default=3e-5, help='weight decay')),
    ('--report_freq', dict(type=float, default=100, help='report frequency')),
    ('--gpu', dict(type=int, default=0, help='gpu device id')),
    ('--epochs', dict(type=int, default=250, help='num of training epochs')),
    # network shape and regularisation
    ('--init_channels', dict(type=int, default=48, help='num of init channels')),
    ('--layers', dict(type=int, default=14, help='total number of layers')),
    ('--auxiliary', dict(action='store_true', default=False, help='use auxiliary tower')),
    ('--auxiliary_weight', dict(type=float, default=0.4, help='weight for auxiliary loss')),
    ('--drop_path_prob', dict(type=float, default=0, help='drop path probability')),
    # experiment bookkeeping
    ('--save', dict(type=str, default='EXP', help='experiment name')),
    ('--seed', dict(type=int, default=0, help='random_ws seed')),
    ('--arch', dict(type=str, default='c10_s3_pgd', help='which architecture to use')),
    # training schedule
    ('--grad_clip', dict(type=float, default=5., help='gradient clipping')),
    ('--label_smooth', dict(type=float, default=0.1, help='label smoothing')),
    ('--gamma', dict(type=float, default=0.97, help='learning rate decay')),
    ('--decay_period', dict(type=int, default=1, help='epochs between two learning rate decays')),
    # execution mode
    ('--parallel', dict(action='store_true', default=False, help='darts parallelism')),
    ('--load', dict(action='store_true', default=False,
                    help='whether load checkpoint for continue training')),
)

parser = argparse.ArgumentParser("imagenet")
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Parsed once at import time; the rest of the module reads this global.
args = parser.parse_args()
# Turn the experiment name into a unique output directory:
#   <root>/<name>-<timestamp>-<arch>-<seed>[-auxiliary-<weight>]-<random int>
args.save = '../../experiments/sota/imagenet/eval/' + '-'.join(
    [args.save, time.strftime("%Y%m%d-%H%M%S"), args.arch, str(args.seed)])
if args.auxiliary:
    args.save = '{}-auxiliary-{}'.format(args.save, args.auxiliary_weight)
# The random suffix is appended whether or not the auxiliary tower is used.
args.save = '{}-{}'.format(args.save, np.random.randint(10000))

# Create the directory, passing along the *.py files found in the current
# working directory as the scripts to save.
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log to stdout and, with the same message format, to <save dir>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard event files are written under <save dir>/runs.
writer = SummaryWriter('{}/runs'.format(args.save))

# Number of output classes of the classifier.
CLASSES = 1000
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss evaluated against label-smoothed targets.

    Every one-hot target row is blended with a uniform distribution over the
    classes: ``(1 - epsilon) * one_hot + epsilon / num_classes``.
    """

    def __init__(self, num_classes, epsilon):
        """Store the class count and the smoothing strength ``epsilon``."""
        super().__init__()
        self.epsilon = epsilon
        self.num_classes = num_classes
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Return the smoothed cross-entropy, averaged over dim 0, as a scalar.

        ``inputs`` holds one row of class scores per sample (log-softmax is
        taken over dim 1); ``targets`` holds one integer class index per row.
        """
        log_probs = self.logsoftmax(inputs)

        # Build the one-hot encoding of the class indices.
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)

        # Move epsilon of the probability mass to a uniform distribution.
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes

        # Average over the batch first, then sum the per-class contributions.
        return (-smoothed * log_probs).mean(0).sum()
def seed_torch(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``seed`` through the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Feed the same seed to every random number generator in use.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Deterministic cuDNN kernels, with the benchmark-based autotuner off.
    cudnn.deterministic, cudnn.benchmark = True, False
def main():
    """Train the selected architecture on ImageNet, checkpointing every epoch.

    Reads its whole configuration from the module-level ``args`` and reports
    through ``logging`` and the module-level TensorBoard ``writer``. Exits the
    process with status 1 when no CUDA device is available.

    Per epoch: train with the label-smoothed loss, step the LR scheduler,
    evaluate on the validation set with plain cross-entropy, then save a
    checkpoint (flagged as best when top-1 validation accuracy improved).
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(args.gpu)
    cudnn.enabled = True
    seed_torch(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # Look the genotype up by name instead of eval()-ing a string assembled
    # from a command-line value.
    genotype = getattr(genotypes, args.arch)
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
    if args.parallel:
        model = nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    # Plain cross-entropy for validation, label-smoothed variant for training.
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)
    valid_data = H5Dataset(os.path.join(args.data, 'imagenet-val-256.h5'), transform=test_transform)
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)

    if args.load:
        # NOTE(review): the checkpoint directory is hard-coded, and the
        # scheduler created above is not fast-forwarded to start_epoch here;
        # confirm what utils.load_checkpoint restores before relying on --load.
        model, optimizer, start_epoch, best_acc_top1 = utils.load_checkpoint(
            model, optimizer, '../../experiments/sota/imagenet/eval/EXP-20200210-143540-c10_s3_pgd-0-auxiliary-0.4-2753')
    else:
        best_acc_top1 = 0
        start_epoch = 0

    # An attribute assigned on the nn.DataParallel wrapper does not reach the
    # wrapped network, so address the underlying module directly.
    # Presumably the network reads drop_path_prob from itself during forward;
    # verify against model_imagenet.
    net = model.module if isinstance(model, nn.DataParallel) else model

    for epoch in range(start_epoch, args.epochs):
        # Read the learning rate actually in use from the optimizer. Calling
        # scheduler.get_lr() outside of step() is deprecated and, for StepLR,
        # reports an extra gamma factor on decay epochs.
        logging.info('epoch %d lr %e', epoch, optimizer.param_groups[0]['lr'])
        # Drop-path probability ramps up linearly over the run.
        net.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer)
        logging.info('train_acc %f', train_acc)
        writer.add_scalar('Acc/train', train_acc, epoch)
        writer.add_scalar('Obj/train', train_obj, epoch)
        scheduler.step()

        valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc_top1 %f', valid_acc_top1)
        logging.info('valid_acc_top5 %f', valid_acc_top5)
        writer.add_scalar('Acc/valid_top1', valid_acc_top1, epoch)
        writer.add_scalar('Acc/valid_top5', valid_acc_top5, epoch)

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        # 'epoch' stores the index of the next epoch to run.
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc_top1': best_acc_top1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save)

    # Flush any pending TensorBoard events before the process exits.
    writer.close()
def train(train_queue, model, criterion, optimizer):
    """Run one training epoch over ``train_queue``.

    Args:
        train_queue: iterable yielding ``(images, target)`` batches.
        model: network whose call returns ``(logits, logits_aux)``.
        criterion: loss applied to the main logits and, when
            ``args.auxiliary`` is set, to the auxiliary logits as well.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(top1.avg, objs.avg)`` taken from the meters, i.e. the epoch's
        top-1 figure and loss figure.
    """
    # Meters from the project's utils; presumably batch-size-weighted running
    # averages (confirm in nasbench201.utils).
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.train()

    for step, (images, target) in enumerate(train_queue):
        # Non-blocking copies for both tensors, matching the target transfer.
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits, logits_aux = model(images)
        loss = criterion(logits, target)
        if args.auxiliary:
            # Add the weighted loss of the auxiliary output.
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = images.size(0)
        # detach() keeps the autograd graph out of the meters without using
        # the legacy .data attribute.
        objs.update(loss.detach(), n)
        top1.update(prec1.detach(), n)
        top5.update(prec5.detach(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg
def infer(valid_queue, model, criterion):
    """Evaluate ``model`` on ``valid_queue`` with gradients disabled.

    Args:
        valid_queue: iterable yielding ``(images, target)`` batches.
        model: network whose call returns a pair; only the first element
            (the main logits) is used.
        criterion: loss applied to the main logits.

    Returns:
        ``(top1.avg, top5.avg, objs.avg)`` taken from the meters.
    """
    # Meters from the project's utils; presumably batch-size-weighted running
    # averages (confirm in nasbench201.utils).
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (images, target) in enumerate(valid_queue):
            # Non-blocking copies for both tensors, matching the target transfer.
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            logits, _ = model(images)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = images.size(0)
            # detach() replaces the legacy .data attribute.
            objs.update(loss.detach(), n)
            top1.update(prec1.detach(), n)
            top5.update(prec5.detach(), n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, top5.avg, objs.avg
# Script entry point: call main() only when this file is executed directly.
if __name__ == '__main__':
    main()