from __future__ import print_function

import numpy as np
import os
import os.path
import sys
import shutil
import torch
import torchvision.transforms as transforms

from PIL import Image
from torch.autograd import Variable
from torchvision.datasets import VisionDataset
from torchvision.datasets import utils

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


class AvgrageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class Cutout(object):
    def __init__(self, length, prob=1.0):
        self.length = length
        self.prob = prob

    def __call__(self, img):
        if np.random.binomial(1, self.prob):
            h, w = img.size(1), img.size(2)
            mask = np.ones((h, w), np.float32)
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.
            mask = torch.from_numpy(mask)
            mask = mask.expand_as(img)
            img *= mask
        return img


def _data_transforms_svhn(args):
    SVHN_MEAN = [0.4377, 0.4438, 0.4728]
    SVHN_STD = [0.1980, 0.2010, 0.1970]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    return train_transform, valid_transform


def _data_transforms_cifar100(args):
    CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
    CIFAR_STD = [0.2673, 0.2564, 0.2762]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform


def _data_transforms_cifar10(args):
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform


def count_parameters_in_MB(model):
    # total number of parameters (excluding the auxiliary head), divided by 1e6;
    # use the builtin sum to avoid numpy's deprecated generator handling
    return sum(np.prod(v.size()) for name, v in model.named_parameters()
               if "auxiliary" not in name) / 1e6


def count_parameters_in_Compact(model):
    from sota.cnn.model import Network as CompactModel
    genotype = model.genotype()
    compact_model = CompactModel(36, model._num_classes, 20, True, genotype)
    num_params = count_parameters_in_MB(compact_model)
    return num_params
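# Illustrative sketch (not part of the original module): how AvgrageMeter and
# accuracy() are typically combined in an evaluation loop. `model` and `loader`
# are placeholders supplied by the caller, not defined in this file.
def _example_topk_tracking(model, loader):
    top1, top5 = AvgrageMeter(), AvgrageMeter()
    with torch.no_grad():
        for images, targets in loader:
            logits = model(images)
            prec1, prec5 = accuracy(logits, targets, topk=(1, 5))
            n = targets.size(0)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
    return top1.avg, top5.avg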
def save_checkpoint(state, is_best, save, per_epoch=False, prefix=''):
    filename = prefix
    if per_epoch:
        epoch = state['epoch']
        filename += 'checkpoint_{}.pth.tar'.format(epoch)
    else:
        filename += 'checkpoint.pth.tar'
    filename = os.path.join(save, filename)
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(save, 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)


def load_checkpoint(model, optimizer, save, epoch=None):
    if epoch is None:
        filename = 'checkpoint.pth.tar'
    else:
        filename = 'checkpoint_{}.pth.tar'.format(epoch)
    filename = os.path.join(save, filename)
    start_epoch = 0
    best_acc_top1 = 0.0  # default when no checkpoint is found
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        best_acc_top1 = checkpoint['best_acc_top1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, best_acc_top1


def save(model, model_path):
    torch.save(model.state_dict(), model_path)


def load(model, model_path):
    model.load_state_dict(torch.load(model_path))


def drop_path(x, drop_prob):
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.makedirs(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)
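# Illustrative sketch (not part of the original module): the checkpoint dict that
# save_checkpoint()/load_checkpoint() expect. `model`, `optimizer`, `epoch`,
# `best_acc_top1` and `is_best` are placeholders supplied by the training script.
def _example_checkpointing(model, optimizer, epoch, best_acc_top1, is_best, save_dir):
    state = {
        'epoch': epoch + 1,               # epoch to resume from
        'state_dict': model.state_dict(),
        'best_acc_top1': best_acc_top1,
        'optimizer': optimizer.state_dict(),
    }
    save_checkpoint(state, is_best, save_dir)
    # later, to resume training from the same directory:
    return load_checkpoint(model, optimizer, save_dir)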
class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is
            set to True.
        train (bool, optional): If True, creates dataset from training set,
            otherwise creates from test set.
        transform (callable, optional): A function/transform that takes in an
            PIL image and returns a transformed version. E.g,
            ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.

    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        #['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self):
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not utils.check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not utils.check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        utils.download_and_extract_archive(self.url, self.root,
                                           filename=self.filename,
                                           md5=self.tgz_md5)

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")


def pick_gpu_lowest_memory():
    import gpustat
    stats = gpustat.GPUStatCollection.new_query()
    ids = map(lambda gpu: int(gpu.entry['index']), stats)
    ratios = map(lambda gpu: float(gpu.memory_used) / float(gpu.memory_total), stats)
    bestGPU = min(zip(ids, ratios), key=lambda x: x[1])[0]
    return bestGPU
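# Illustrative sketch (not part of the original module): wiring the local CIFAR10
# class (which deliberately omits data_batch_5, see train_list above) to the
# transforms defined in this file. `args` is a placeholder namespace carrying the
# fields used by _data_transforms_cifar10 (cutout, cutout_length, cutout_prob).
def _example_cifar10_loader(args, root='./data'):
    from torch.utils.data import DataLoader
    train_transform, _ = _data_transforms_cifar10(args)
    train_data = CIFAR10(root=root, train=True, download=True,
                         transform=train_transform)
    return DataLoader(train_data, batch_size=96, shuffle=True,
                      pin_memory=True, num_workers=2)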
#### early stopping (from RobustNAS)
class EVLocalAvg(object):
    def __init__(self, window=5, ev_freq=2, total_epochs=50):
        """ Keep track of the eigenvalues local average.

        Args:
            window (int): number of elements used to compute local average.
                Default: 5
            ev_freq (int): frequency used to compute eigenvalues. Default:
                every 2 epochs
            total_epochs (int): total number of epochs that DARTS runs.
                Default: 50

        """
        self.window = window
        self.ev_freq = ev_freq
        self.epochs = total_epochs
        self.stop_search = False
        self.stop_epoch = total_epochs - 1
        self.stop_genotype = None
        self.stop_numparam = 0

        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}

        # start and end index of the local average window
        self.la_start_idx = 0
        self.la_end_idx = self.window

    def reset(self):
        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}

    def update(self, epoch, ev, genotype, numparam=0):
        """ Method to update the local average list.

        Args:
            epoch (int): current epoch
            ev (float): current dominant eigenvalue
            genotype (namedtuple): current genotype
            numparam (int): number of parameters of the current genotype

        """
        self.ev.append(ev)
        self.genotypes.update({epoch: genotype})
        self.numparams.update({epoch: numparam})
        # set the stop_genotype to the current genotype in case the early stop
        # procedure decides not to early stop
        self.stop_genotype = genotype

        # since the local average computation starts after the dominant
        # eigenvalue in the first epoch is already computed we have to wait
        # at least until we have 3 eigenvalues in the list.
        if (len(self.ev) >= int(np.ceil(self.window / 2))) and (epoch < self.epochs - 1):
            # start sliding the window as soon as the number of eigenvalues in
            # the list becomes equal to the window size
            if len(self.ev) < self.window:
                self.ev_local_avg.append(np.mean(self.ev))
            else:
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window
                self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
                                                         self.la_end_idx]))
                self.la_start_idx += 1
                self.la_end_idx += 1

            # keep track of the offset between the current epoch and the epoch
            # corresponding to the local average. NOTE: in the end the size of
            # self.ev and self.ev_local_avg should be equal
            self.la_epochs.update({epoch: int(epoch -
                                              int(self.ev_freq * np.floor(self.window / 2)))})

        elif len(self.ev) < int(np.ceil(self.window / 2)):
            self.la_epochs.update({epoch: -1})

        # since there is an offset between the current epoch and the local
        # average epoch, loop in the last epoch to compute the local average of
        # these number of elements: window, window - 1, window - 2, ...,
        # ceil(window/2)
        elif epoch == self.epochs - 1:
            for i in range(int(np.ceil(self.window / 2))):
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window - i
                self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
                                                         self.la_end_idx + 1]))
                self.la_start_idx += 1

    def early_stop(self, epoch, factor=1.3, es_start_epoch=10, delta=4,
                   criteria='local_avg'):
        """ Early stopping criterion

        Args:
            epoch (int): current epoch
            factor (float): threshold factor for the ratio between the current
                and previous eigenvalue. Default: 1.3
            es_start_epoch (int): until this epoch do not consider early
                stopping. Default: 10
            delta (int): factor influencing which previous local average we
                consider for early stopping. Default: 4

        """
        if criteria == 'local_avg':
            if int(self.la_epochs[epoch] - self.ev_freq * delta) >= es_start_epoch:
                current_la = self.ev_local_avg[-1]
                previous_la = self.ev_local_avg[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    self.stop_epoch = int(self.la_epochs[epoch] - self.ev_freq * delta)
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        elif criteria == 'exact':
            if epoch > es_start_epoch:
                current_la = self.ev[-1]
                previous_la = self.ev[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    self.stop_epoch = epoch - delta
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        else:
            print('ERROR IN EARLY STOP: WRONG CRITERIA:', criteria)
            exit(0)
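# Illustrative sketch (not part of the original module): how EVLocalAvg is
# typically driven during architecture search. `model` (assumed to expose a
# genotype() method) and `compute_dominant_eigenvalue` are placeholders; the
# real search script computes the dominant Hessian eigenvalue of the validation
# loss w.r.t. the architecture parameters every `ev_freq` epochs.
def _example_early_stopping_loop(model, compute_dominant_eigenvalue, epochs=50):
    ev_tracker = EVLocalAvg(window=5, ev_freq=2, total_epochs=epochs)
    for epoch in range(epochs):
        # ... one epoch of search updates would happen here ...
        if epoch % ev_tracker.ev_freq == 0:
            ev = compute_dominant_eigenvalue(model)
            ev_tracker.update(epoch, ev, model.genotype())
            ev_tracker.early_stop(epoch, factor=1.3, es_start_epoch=10, delta=4)
            if ev_tracker.stop_search:
                break
    return ev_tracker.stop_epoch, ev_tracker.stop_genotype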
def gen_comb(eids):
    # enumerate all unordered pairs (eids[r], eids[c]) with r < c
    comb = []
    for r in range(len(eids)):
        for c in range(r + 1, len(eids)):
            comb.append((eids[r], eids[c]))
    return comb
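# Illustrative note (not part of the original module): gen_comb yields the same
# pairs as itertools.combinations with r=2.
#
#   >>> gen_comb([0, 1, 2])
#   [(0, 1), (0, 2), (1, 2)]
#   >>> import itertools
#   >>> list(itertools.combinations([0, 1, 2], 2)) == gen_comb([0, 1, 2])
#   True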