from __future__ import print_function

import os
import os.path
import shutil
import sys

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torchvision.datasets import VisionDataset
from torchvision.datasets import utils

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle
class AvgrageMeter(object):
    """Running weighted average of a scalar metric (class name spelled as-is for existing callers)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated sum, count and average."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        """Add ``val`` observed ``n`` times and refresh ``avg``.

        Args:
            val: value to accumulate (typically a per-batch mean).
            n (int): weight of ``val``, e.g. the batch size. Default: 1
        """
        self.sum += val * n
        self.cnt += n
        # An update with n == 0 on an empty meter would otherwise divide by zero.
        if self.cnt:
            self.avg = self.sum / self.cnt
def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, of ``output`` scores against ``target`` labels.

    Args:
        output: score tensor whose class dimension is dim 1.
        target: label tensor with one entry per row of ``output``.
        topk (tuple of int): the k values to evaluate. Default: (1,)

    Returns:
        list: one 0-dim float tensor per k, holding the percentage of rows
        whose label is among the k highest scores.
    """
    largest_k = max(topk)
    n = target.size(0)
    _, top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)
    top_idx = top_idx.t()  # (largest_k, n): row j holds every sample's j-th guess
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
    scale = 100.0 / n
    return [hits[:k].contiguous().view(-1).float().sum(0).mul_(scale) for k in topk]
class Cutout(object):
    """Zero one random square patch of a ``(C, H, W)`` image tensor, in place.

    Args:
        length (int): side of the square patch; the patch is clipped at the
            image borders.
        prob (float): probability that a given image receives a patch.
            Default: 1.0
    """

    def __init__(self, length, prob=1.0):
        self.length = length
        self.prob = prob

    def __call__(self, img):
        # A single Bernoulli draw decides whether this image is patched at all.
        if not np.random.binomial(1, self.prob):
            return img
        height, width = img.size(1), img.size(2)
        half = self.length // 2
        cy = np.random.randint(height)
        cx = np.random.randint(width)
        top, bottom = np.clip(cy - half, 0, height), np.clip(cy + half, 0, height)
        left, right = np.clip(cx - half, 0, width), np.clip(cx + half, 0, width)
        keep = np.ones((height, width), np.float32)
        keep[top: bottom, left: right] = 0.
        # The same (H, W) mask is broadcast over every channel.
        img *= torch.from_numpy(keep).expand_as(img)
        return img
def _data_transforms_svhn(args):
    """Build the ``(train_transform, valid_transform)`` pair for SVHN.

    Training: random 32px crop with 4px padding, horizontal flip, tensor
    conversion and normalisation with the SVHN channel statistics, plus a
    trailing ``Cutout`` when ``args.cutout`` is truthy. Validation: tensor
    conversion and normalisation only.

    Args:
        args: namespace read for ``cutout``, ``cutout_length`` and, if
            present, ``cutout_prob`` (defaults to 1.0 when absent).

    Returns:
        tuple: ``(train_transform, valid_transform)``.
    """
    SVHN_MEAN = [0.4377, 0.4438, 0.4728]
    SVHN_STD = [0.1980, 0.2010, 0.1970]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    if args.cutout:
        # Cutout indexes img.size(1)/img.size(2), so it must run on the tensor,
        # i.e. after ToTensor.
        train_transform.transforms.append(
            Cutout(args.cutout_length, getattr(args, 'cutout_prob', 1.0)))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    return train_transform, valid_transform
def _data_transforms_cifar100(args):
    """Build the ``(train_transform, valid_transform)`` pair for CIFAR-100.

    Training: random 32px crop with 4px padding, horizontal flip, tensor
    conversion and normalisation with the CIFAR-100 channel statistics, plus
    a trailing ``Cutout`` when ``args.cutout`` is truthy. Validation: tensor
    conversion and normalisation only.

    Args:
        args: namespace read for ``cutout``, ``cutout_length`` and, if
            present, ``cutout_prob`` (defaults to 1.0 when absent).

    Returns:
        tuple: ``(train_transform, valid_transform)``.
    """
    CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
    CIFAR_STD = [0.2673, 0.2564, 0.2762]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        # Cutout indexes img.size(1)/img.size(2), so it must run after ToTensor.
        train_transform.transforms.append(
            Cutout(args.cutout_length, getattr(args, 'cutout_prob', 1.0)))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform
def _data_transforms_cifar10(args):
    """Build the ``(train_transform, valid_transform)`` pair for CIFAR-10.

    Training: random 32px crop with 4px padding, horizontal flip, tensor
    conversion and normalisation with the CIFAR-10 channel statistics, plus
    a trailing ``Cutout`` when ``args.cutout`` is truthy. Validation: tensor
    conversion and normalisation only.

    Args:
        args: namespace read for ``cutout``, ``cutout_length`` and, if
            present, ``cutout_prob`` (defaults to 1.0 when absent).

    Returns:
        tuple: ``(train_transform, valid_transform)``.
    """
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        # Cutout indexes img.size(1)/img.size(2), so it must run after ToTensor.
        train_transform.transforms.append(
            Cutout(args.cutout_length, getattr(args, 'cutout_prob', 1.0)))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform
def count_parameters_in_MB(model):
    """Return the number of parameters of ``model`` in millions.

    Parameters whose qualified name contains ``"auxiliary"`` are excluded.
    """
    # np.sum(generator) is deprecated in NumPy (it silently falls back to the
    # builtin sum with a warning); use the builtin directly.
    return sum(v.numel() for name, v in model.named_parameters()
               if "auxiliary" not in name) / 1e6
def count_parameters_in_Compact(model):
    """Size, in millions of parameters, of the stand-alone network for ``model``'s genotype.

    The network is built as ``sota.cnn.model.Network(36, model._num_classes,
    20, True, genotype)`` and measured with ``count_parameters_in_MB``.
    """
    from sota.cnn.model import Network as CompactModel
    arch = model.genotype()
    return count_parameters_in_MB(
        CompactModel(36, model._num_classes, 20, True, arch))
def save_checkpoint(state, is_best, save, per_epoch=False, prefix=''):
    """Write ``state`` with ``torch.save`` and optionally mirror it as the best model.

    Args:
        state (dict): object to serialize; must contain ``'epoch'`` when
            ``per_epoch`` is True.
        is_best (bool): also copy the file to ``<save>/model_best.pth.tar``.
        save (str): target directory.
        per_epoch (bool): write ``<prefix>checkpoint_<epoch>.pth.tar`` instead
            of the rolling ``<prefix>checkpoint.pth.tar``. Default: False
        prefix (str): prepended to the checkpoint file name (not to
            ``model_best.pth.tar``). Default: ''
    """
    filename = prefix
    if per_epoch:
        filename += 'checkpoint_{}.pth.tar'.format(state['epoch'])
    else:
        # Without this else, per-epoch names also got 'checkpoint.pth.tar'
        # appended.
        filename += 'checkpoint.pth.tar'
    filename = os.path.join(save, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(save, 'model_best.pth.tar'))
def load_checkpoint(model, optimizer, save, epoch=None):
    """Restore model and optimizer state from a checkpoint in ``save``.

    Args:
        model: module whose ``load_state_dict`` receives ``checkpoint['state_dict']``.
        optimizer: optimizer whose ``load_state_dict`` receives
            ``checkpoint['optimizer']``.
        save (str): directory holding the checkpoint.
        epoch (int, optional): load ``checkpoint_<epoch>.pth.tar``; when None
            the rolling ``checkpoint.pth.tar`` is used.

    Returns:
        tuple: ``(model, optimizer, start_epoch, best_acc_top1)``. When the file
        does not exist, ``start_epoch`` and ``best_acc_top1`` are 0 and the
        model and optimizer are returned untouched.
    """
    if epoch is None:
        filename = 'checkpoint.pth.tar'
    else:
        filename = 'checkpoint_{}.pth.tar'.format(epoch)
    filename = os.path.join(save, filename)

    start_epoch = 0
    # Must be bound on the "no checkpoint" path too, otherwise the return
    # statement raises UnboundLocalError.
    best_acc_top1 = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        best_acc_top1 = checkpoint['best_acc_top1']
        # NOTE(review): key names assume the training script saved
        # {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), ...}
        # — confirm against the caller of save_checkpoint.
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))
    return model, optimizer, start_epoch, best_acc_top1
def save(model, model_path):
    """Serialize ``model.state_dict()`` to ``model_path`` via ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, model_path)
def load(model, model_path, map_location=None):
    """Load a state dict from ``model_path`` into ``model`` in place.

    Args:
        model: module receiving the weights via ``load_state_dict``.
        model_path (str): file written by ``torch.save(model.state_dict(), ...)``.
        map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
            GPU-saved file on a CPU-only host). Default: None
    """
    model.load_state_dict(torch.load(model_path, map_location=map_location))
def drop_path(x, drop_prob):
    """Zero whole samples of ``x`` with probability ``drop_prob``, in place.

    Surviving samples are rescaled by ``1 / (1 - drop_prob)`` so the expected
    value is unchanged. ``x`` is modified in place and returned; with
    ``drop_prob <= 0`` it is returned untouched.

    Args:
        x: tensor whose first dimension indexes samples (any rank >= 1).
        drop_prob (float): probability of dropping each sample.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        # new_empty follows x's device and dtype, so this also runs on CPU
        # (a hard-coded torch.cuda.FloatTensor only works on GPU).
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
def create_exp_dir(path, scripts_to_save=None):
    """Create the experiment directory and optionally snapshot scripts into it.

    Args:
        path (str): experiment directory; created (with any missing parents)
            if it does not exist.
        scripts_to_save (iterable of str, optional): files copied into
            ``<path>/scripts`` under their base names.
    """
    if not os.path.exists(path):
        os.makedirs(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        # Re-running into an existing experiment dir must not fail on an
        # already-present scripts folder.
        if not os.path.isdir(scripts_dir):
            os.makedirs(scripts_dir)
        for script in scripts_to_save:
            shutil.copyfile(script, os.path.join(scripts_dir, os.path.basename(script)))
class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # NOTE(review): data_batch_5 is deliberately left out of the training split
    # (presumably held out for other use) — confirm with the experiment setup.
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        #['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        # Verifies the md5 of every listed batch before any file is unpickled.
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                # CIFAR-10 batches use 'labels'; the CIFAR-100 layout uses 'fine_labels'.
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self):
        """Read the class names and build ``class_to_idx`` from the meta file."""
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not utils.check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        """Return True when every train and test batch file matches its md5."""
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not utils.check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Download and extract the archive unless all listed files already verify."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        utils.download_and_extract_archive(self.url, self.root,
                                           filename=self.filename,
                                           md5=self.tgz_md5)

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")
def pick_gpu_lowest_memory():
    """Return the index of the GPU with the smallest used/total memory ratio.

    ``gpustat`` is imported inside the function. On ties, the GPU reported
    first wins.
    """
    import gpustat
    snapshot = gpustat.GPUStatCollection.new_query()
    candidates = [(int(gpu.entry['index']),
                   float(gpu.memory_used) / float(gpu.memory_total))
                  for gpu in snapshot]
    best_index, _ = min(candidates, key=lambda pair: pair[1])
    return best_index
#### early stopping (from RobustNAS)
class EVLocalAvg(object):
    """Sliding local average of dominant eigenvalues, used as an early-stop signal."""

    def __init__(self, window=5, ev_freq=2, total_epochs=50):
        """ Keep track of the eigenvalues local average.

        Args:
            window (int): number of elements used to compute local average.
                Default: 5
            ev_freq (int): frequency used to compute eigenvalues. Default:
                every 2 epochs
            total_epochs (int): total number of epochs that DARTS runs.
                Default: 50
        """
        self.window = window
        self.ev_freq = ev_freq
        self.epochs = total_epochs

        self.stop_search = False
        self.stop_epoch = total_epochs - 1
        self.stop_genotype = None
        self.stop_numparam = 0

        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}

        # start and end index of the local average window
        self.la_start_idx = 0
        self.la_end_idx = self.window

    def reset(self):
        """Forget all recorded eigenvalues and rewind the averaging window."""
        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}
        # The window indices must be rewound as well; otherwise the first
        # full-window update after a reset slices past the (now shorter) ev
        # list and trips the window-size assertion in update().
        self.la_start_idx = 0
        self.la_end_idx = self.window

    def update(self, epoch, ev, genotype, numparam=0):
        """ Method to update the local average list.

        Args:
            epoch (int): current epoch
            ev (float): current dominant eigenvalue
            genotype (namedtuple): current genotype
            numparam: value stored for this epoch and reported back as
                ``stop_numparam`` if this epoch is chosen. Default: 0
        """
        self.ev.append(ev)
        self.genotypes.update({epoch: genotype})
        self.numparams.update({epoch: numparam})
        # set the stop_genotype to the current genotype in case the early stop
        # procedure decides not to early stop
        self.stop_genotype = genotype

        # ceil(window / 2); float division keeps the threshold identical on
        # Python 2, where window / 2 would truncate.
        half_up = int(np.ceil(self.window / 2.0))

        # since the local average computation starts after the dominant
        # eigenvalue in the first epoch is already computed we have to wait
        # at least until we have ceil(window / 2) eigenvalues in the list.
        if len(self.ev) >= half_up and epoch < self.epochs - 1:
            # start sliding the window as soon as the number of eigenvalues in
            # the list becomes equal to the window size
            if len(self.ev) < self.window:
                self.ev_local_avg.append(np.mean(self.ev))
            else:
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window
                self.ev_local_avg.append(
                    np.mean(self.ev[self.la_start_idx: self.la_end_idx]))
                self.la_start_idx += 1
                self.la_end_idx += 1

            # keep track of the offset between the current epoch and the epoch
            # corresponding to the local average (the centre of a full window,
            # floor(window / 2) eigenvalue steps back). NOTE: in the end the
            # size of self.ev and self.ev_local_avg should be equal
            self.la_epochs.update(
                {epoch: int(epoch - int(self.ev_freq * np.floor(self.window / 2.0)))})

        elif len(self.ev) < half_up:
            self.la_epochs.update({epoch: -1})

        # since there is an offset between the current epoch and the local
        # average epoch, loop in the last epoch to compute the local average of
        # these number of elements: window, window - 1, window - 2, ..., ceil(window/2)
        elif epoch == self.epochs - 1:
            for i in range(half_up):
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window - i
                self.ev_local_avg.append(
                    np.mean(self.ev[self.la_start_idx: self.la_end_idx + 1]))
                self.la_start_idx += 1

    def early_stop(self, epoch, factor=1.3, es_start_epoch=10, delta=4, criteria='local_avg'):
        """ Early stopping criterion

        Args:
            epoch (int): current epoch
            factor (float): threshold factor for the ratio between the current
                and previous eigenvalue. Default: 1.3
            es_start_epoch (int): until this epoch do not consider early
                stopping. Default: 10
            delta (int): factor influencing which previous local average we
                consider for early stopping. Default: 4
            criteria (str): ``'local_avg'`` compares local averages,
                ``'exact'`` compares raw eigenvalues. Default: 'local_avg'

        Raises:
            ValueError: if ``criteria`` is neither ``'local_avg'`` nor ``'exact'``.
        """
        if criteria == 'local_avg':
            # Epochs that update() mapped to -1, or never mapped at all (e.g.
            # the final epoch), cannot trigger a stop for es_start_epoch >= 0.
            la_epoch = self.la_epochs.get(epoch, -1)
            candidate = int(la_epoch - self.ev_freq * delta)
            if candidate >= es_start_epoch:
                current_la = self.ev_local_avg[-1]
                previous_la = self.ev_local_avg[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    self.stop_epoch = candidate
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        elif criteria == 'exact':
            if epoch > es_start_epoch:
                current_la = self.ev[-1]
                previous_la = self.ev[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    # NOTE(review): ev[-1 - delta] is delta *updates* back
                    # (ev_freq * delta epochs), but the stop epoch below is
                    # only delta epochs back — confirm which is intended.
                    self.stop_epoch = epoch - delta
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        else:
            # Previously printed and called exit(0), reporting success to the shell.
            raise ValueError('ERROR IN EARLY STOP: WRONG CRITERIA: {}'.format(criteria))
def gen_comb(eids):
    """Return every unordered pair of elements of ``eids``, in input order.

    Equivalent to pairing ``eids[r]`` with each later ``eids[c]`` (``r < c``),
    e.g. ``gen_comb([1, 2, 3]) == [(1, 2), (1, 3), (2, 3)]``. Any iterable
    is accepted, not only indexable sequences.
    """
    import itertools
    return list(itertools.combinations(eids, 2))