from __future__ import print_function

import numpy as np
import os
import os.path
import sys
import shutil
import torch
import torchvision.transforms as transforms

from PIL import Image
from torch.autograd import Variable
from torchvision.datasets import VisionDataset
from torchvision.datasets import utils

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


class AvgrageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


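# Illustrative sketch (not part of the original utilities): how `accuracy` and
# `AvgrageMeter` are typically combined to track running top-1/top-5 accuracy.
# The tensor shapes and the `_example_` name are assumptions for demonstration.
def _example_track_accuracy():
    meter_top1, meter_top5 = AvgrageMeter(), AvgrageMeter()
    logits = torch.randn(8, 10)            # batch of 8 samples, 10 classes
    targets = torch.randint(0, 10, (8,))   # integer class labels
    top1, top5 = accuracy(logits, targets, topk=(1, 5))
    meter_top1.update(top1.item(), n=logits.size(0))
    meter_top5.update(top5.item(), n=logits.size(0))
    return meter_top1.avg, meter_top5.avg

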
class Cutout(object):
    def __init__(self, length, prob=1.0):
        self.length = length
        self.prob = prob

    def __call__(self, img):
        if np.random.binomial(1, self.prob):
            h, w = img.size(1), img.size(2)
            mask = np.ones((h, w), np.float32)
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.
            mask = torch.from_numpy(mask)
            mask = mask.expand_as(img)
            img *= mask
        return img


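# Illustrative sketch (assumption, not original code): Cutout operates on a
# CHW tensor, so it is appended after ToTensor() in a transform pipeline, as
# the _data_transforms_* helpers below do.
def _example_cutout():
    transform = transforms.Compose([
        transforms.ToTensor(),
        Cutout(length=16, prob=1.0),  # zero out a random 16x16 patch
    ])
    img = Image.fromarray(np.uint8(np.random.rand(32, 32, 3) * 255))
    return transform(img)  # 3x32x32 tensor with a cut-out square

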
def _data_transforms_svhn(args):
    SVHN_MEAN = [0.4377, 0.4438, 0.4728]
    SVHN_STD = [0.1980, 0.2010, 0.1970]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(SVHN_MEAN, SVHN_STD),
    ])
    return train_transform, valid_transform


def _data_transforms_cifar100(args):
    CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
    CIFAR_STD = [0.2673, 0.2564, 0.2762]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform


def _data_transforms_cifar10(args):
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length,
                                                 args.cutout_prob))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform


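# Illustrative sketch (assumption): the transform builders only read
# `cutout`, `cutout_length` and `cutout_prob` from `args`, so a plain
# namespace is enough to exercise them outside the usual argparse setup.
def _example_build_cifar10_transforms():
    import argparse
    args = argparse.Namespace(cutout=True, cutout_length=16, cutout_prob=1.0)
    train_transform, valid_transform = _data_transforms_cifar10(args)
    return train_transform, valid_transform

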
def count_parameters_in_MB(model):
    # use the builtin sum: calling np.sum on a generator is deprecated in NumPy
    return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6


def count_parameters_in_Compact(model):
    from sota.cnn.model import Network as CompactModel
    genotype = model.genotype()
    compact_model = CompactModel(36, model._num_classes, 20, True, genotype)
    num_params = count_parameters_in_MB(compact_model)
    return num_params


def save_checkpoint(state, is_best, save, per_epoch=False, prefix=''):
    filename = prefix
    if per_epoch:
        epoch = state['epoch']
        filename += 'checkpoint_{}.pth.tar'.format(epoch)
    else:
        filename += 'checkpoint.pth.tar'
    filename = os.path.join(save, filename)
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(save, 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)


def load_checkpoint(model, optimizer, save, epoch=None):
    if epoch is None:
        filename = 'checkpoint.pth.tar'
    else:
        filename = 'checkpoint_{}.pth.tar'.format(epoch)
    filename = os.path.join(save, filename)
    start_epoch = 0
    best_acc_top1 = 0  # default when no checkpoint is found
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        best_acc_top1 = checkpoint['best_acc_top1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, best_acc_top1


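# Illustrative sketch (assumption): the checkpoint dict is expected to carry
# 'epoch', 'best_acc_top1', 'state_dict' and 'optimizer', matching what
# load_checkpoint reads back. The model, optimizer and path are toy stand-ins.
def _example_checkpoint_roundtrip(save_dir='/tmp/exp'):
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    os.makedirs(save_dir, exist_ok=True)
    state = {
        'epoch': 1,
        'best_acc_top1': 0.0,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    save_checkpoint(state, is_best=True, save=save_dir)
    return load_checkpoint(model, optimizer, save_dir)

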
def save(model, model_path):
    torch.save(model.state_dict(), model_path)


def load(model, model_path):
    model.load_state_dict(torch.load(model_path))


def drop_path(x, drop_prob):
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
        x.div_(keep_prob)
        x.mul_(mask)
    return x


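# Illustrative sketch (assumption): drop_path is applied in-place to a CUDA
# feature map of shape (N, C, H, W), zeroing entire samples with probability
# drop_prob and rescaling the survivors. Requires a CUDA device.
def _example_drop_path():
    if not torch.cuda.is_available():
        return None
    features = torch.ones(4, 8, 2, 2).cuda()
    return drop_path(features, drop_prob=0.2)

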
def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.makedirs(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)


class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        # ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self):
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not utils.check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not utils.check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        utils.download_and_extract_archive(self.url, self.root,
                                           filename=self.filename,
                                           md5=self.tgz_md5)

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")


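# Illustrative sketch (assumption): this local CIFAR10 class is used like the
# torchvision one, except that its training split covers only the four data
# batches listed in train_list. The path and flags below are placeholders.
def _example_load_cifar10(data_dir='~/data'):
    import argparse
    args = argparse.Namespace(cutout=False, cutout_length=16, cutout_prob=1.0)
    train_transform, _ = _data_transforms_cifar10(args)
    dataset = CIFAR10(root=os.path.expanduser(data_dir), train=True,
                      download=True, transform=train_transform)
    return len(dataset)  # 40000 images with the four-batch training split

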
def pick_gpu_lowest_memory():
    import gpustat
    stats = gpustat.GPUStatCollection.new_query()
    ids = map(lambda gpu: int(gpu.entry['index']), stats)
    ratios = map(lambda gpu: float(gpu.memory_used) / float(gpu.memory_total), stats)
    bestGPU = min(zip(ids, ratios), key=lambda x: x[1])[0]
    return bestGPU


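# Illustrative sketch (assumption): pick_gpu_lowest_memory needs the optional
# `gpustat` package and at least one visible GPU; a typical use is to bind the
# process to the least-loaded device before building the model.
def _example_select_gpu():
    gpu_id = pick_gpu_lowest_memory()
    torch.cuda.set_device(gpu_id)
    return gpu_id

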
#### early stopping (from RobustNAS)
class EVLocalAvg(object):
    def __init__(self, window=5, ev_freq=2, total_epochs=50):
        """ Keep track of the eigenvalues local average.
        Args:
            window (int): number of elements used to compute local average.
                Default: 5
            ev_freq (int): frequency used to compute eigenvalues. Default:
                every 2 epochs
            total_epochs (int): total number of epochs that DARTS runs.
                Default: 50
        """
        self.window = window
        self.ev_freq = ev_freq
        self.epochs = total_epochs

        self.stop_search = False
        self.stop_epoch = total_epochs - 1
        self.stop_genotype = None
        self.stop_numparam = 0

        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}

        # start and end index of the local average window
        self.la_start_idx = 0
        self.la_end_idx = self.window

    def reset(self):
        self.ev = []
        self.ev_local_avg = []
        self.genotypes = {}
        self.numparams = {}
        self.la_epochs = {}

    def update(self, epoch, ev, genotype, numparam=0):
        """ Method to update the local average list.

        Args:
            epoch (int): current epoch
            ev (float): current dominant eigenvalue
            genotype (namedtuple): current genotype
            numparam (int): number of parameters of the current genotype.
                Default: 0

        """
        self.ev.append(ev)
        self.genotypes.update({epoch: genotype})
        self.numparams.update({epoch: numparam})
        # set the stop_genotype to the current genotype in case the early stop
        # procedure decides not to early stop
        self.stop_genotype = genotype

        # since the local average computation starts after the dominant
        # eigenvalue in the first epoch is already computed we have to wait
        # at least until we have 3 eigenvalues in the list.
        if (len(self.ev) >= int(np.ceil(self.window/2))) and (epoch <
                                                              self.epochs - 1):
            # start sliding the window as soon as the number of eigenvalues in
            # the list becomes equal to the window size
            if len(self.ev) < self.window:
                self.ev_local_avg.append(np.mean(self.ev))
            else:
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window
                self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
                                                         self.la_end_idx]))
                self.la_start_idx += 1
                self.la_end_idx += 1

            # keep track of the offset between the current epoch and the epoch
            # corresponding to the local average. NOTE: in the end the size of
            # self.ev and self.ev_local_avg should be equal
            self.la_epochs.update({epoch: int(epoch -
                                              int(self.ev_freq*np.floor(self.window/2)))})

        elif len(self.ev) < int(np.ceil(self.window/2)):
            self.la_epochs.update({epoch: -1})

        # since there is an offset between the current epoch and the local
        # average epoch, loop in the last epoch to compute the local average of
        # these number of elements: window, window - 1, window - 2, ..., ceil(window/2)
        elif epoch == self.epochs - 1:
            for i in range(int(np.ceil(self.window/2))):
                assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window - i
                self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx:
                                                         self.la_end_idx + 1]))
                self.la_start_idx += 1

    def early_stop(self, epoch, factor=1.3, es_start_epoch=10, delta=4, criteria='local_avg'):
        """ Early stopping criterion

        Args:
            epoch (int): current epoch
            factor (float): threshold factor for the ratio between the current
                and previous eigenvalue. Default: 1.3
            es_start_epoch (int): until this epoch do not consider early
                stopping. Default: 10
            delta (int): factor influencing which previous local average we
                consider for early stopping. Default: 4
            criteria (str): whether to compare local averages ('local_avg') or
                raw eigenvalues ('exact'). Default: 'local_avg'
        """
        if criteria == 'local_avg':
            if int(self.la_epochs[epoch] - self.ev_freq*delta) >= es_start_epoch:
                current_la = self.ev_local_avg[-1]
                previous_la = self.ev_local_avg[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    self.stop_epoch = int(self.la_epochs[epoch] - self.ev_freq*delta)
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        elif criteria == 'exact':
            if epoch > es_start_epoch:
                current_la = self.ev[-1]
                previous_la = self.ev[-1 - delta]
                self.stop_search = current_la / previous_la > factor
                if self.stop_search:
                    self.stop_epoch = epoch - delta
                    self.stop_genotype = self.genotypes[self.stop_epoch]
                    self.stop_numparam = self.numparams[self.stop_epoch]
        else:
            print('ERROR IN EARLY STOP: WRONG CRITERIA:', criteria)
            sys.exit(1)


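# Illustrative sketch (assumption, not original code): how the early-stopping
# tracker is typically driven from a search loop. The eigenvalue sequence and
# the genotype placeholder are made up purely for demonstration.
def _example_early_stopping():
    tracker = EVLocalAvg(window=5, ev_freq=2, total_epochs=50)
    # flat, then sharply growing dominant eigenvalues (dummy numbers)
    eigenvalues = [0.1] * 30 + [0.1 * 1.5 ** (e + 1) for e in range(20)]
    for epoch, ev in enumerate(eigenvalues):
        tracker.update(epoch, ev, genotype='dummy_genotype', numparam=0)
        if epoch < tracker.epochs - 1:
            tracker.early_stop(epoch, factor=1.3, es_start_epoch=10, delta=4)
        if tracker.stop_search:
            break
    return tracker.stop_epoch, tracker.stop_genotype

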
def gen_comb(eids):
    comb = []
    for r in range(len(eids)):
        for c in range(r + 1, len(eids)):
            comb.append((eids[r], eids[c]))

    return comb
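

# Illustrative sketch (assumption): gen_comb enumerates all unordered pairs of
# edge ids, e.g. for pairwise comparisons between edges.
def _example_gen_comb():
    return gen_comb([0, 1, 2])  # [(0, 1), (0, 2), (1, 2)]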