Prototype generic nas model (cont.).
This commit is contained in:
		| @@ -2,7 +2,7 @@ | |||||||
|   "scheduler": ["str",   "cos"], |   "scheduler": ["str",   "cos"], | ||||||
|   "LR"       : ["float", "0.025"], |   "LR"       : ["float", "0.025"], | ||||||
|   "eta_min"  : ["float", "0.001"], |   "eta_min"  : ["float", "0.001"], | ||||||
|   "epochs"   : ["int",   "250"], |   "epochs"   : ["int",   "150"], | ||||||
|   "warmup"   : ["int",   "0"], |   "warmup"   : ["int",   "0"], | ||||||
|   "optim"    : ["str",   "SGD"], |   "optim"    : ["str",   "SGD"], | ||||||
|   "decay"    : ["float", "0.0005"], |   "decay"    : ["float", "0.0005"], | ||||||
|   | |||||||
| @@ -2,7 +2,7 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||||
| ###################################################################################### | ###################################################################################### | ||||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 1 | # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 1 | ||||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 | # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3 | ||||||
| # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1 | # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1 | ||||||
| #### | #### | ||||||
| # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 1 | # python ./exps/algos-v2/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 1 | ||||||
| @@ -30,6 +30,73 @@ from models       import get_cell_based_tiny_net, get_search_spaces | |||||||
| from nas_201_api  import NASBench201API as API | from nas_201_api  import NASBench201API as API | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # The following three functions are used for DARTS-V2 | ||||||
|  | def _concat(xs): | ||||||
|  |   return torch.cat([x.view(-1) for x in xs]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2): | ||||||
|  |   R = r / _concat(vector).norm() | ||||||
|  |   for p, v in zip(network.weights, vector): | ||||||
|  |     p.data.add_(R, v) | ||||||
|  |   _, logits = network(base_inputs) | ||||||
|  |   loss = criterion(logits, base_targets) | ||||||
|  |   grads_p = torch.autograd.grad(loss, network.alphas) | ||||||
|  |  | ||||||
|  |   for p, v in zip(network.weights, vector): | ||||||
|  |     p.data.sub_(2*R, v) | ||||||
|  |   _, logits = network(base_inputs) | ||||||
|  |   loss = criterion(logits, base_targets) | ||||||
|  |   grads_n = torch.autograd.grad(loss, network.alphas) | ||||||
|  |  | ||||||
|  |   for p, v in zip(network.weights, vector): | ||||||
|  |     p.data.add_(R, v) | ||||||
|  |   return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets): | ||||||
|  |   # _compute_unrolled_model | ||||||
|  |   _, logits = network(base_inputs) | ||||||
|  |   loss = criterion(logits, base_targets) | ||||||
|  |   LR, WD, momentum = w_optimizer.param_groups[0]['lr'], w_optimizer.param_groups[0]['weight_decay'], w_optimizer.param_groups[0]['momentum'] | ||||||
|  |   with torch.no_grad(): | ||||||
|  |     theta = _concat(network.weights) | ||||||
|  |     try: | ||||||
|  |       moment = _concat(w_optimizer.state[v]['momentum_buffer'] for v in network.weights) | ||||||
|  |       moment = moment.mul_(momentum) | ||||||
|  |     except: | ||||||
|  |       moment = torch.zeros_like(theta) | ||||||
|  |     dtheta = _concat(torch.autograd.grad(loss, network.weights)) + WD*theta | ||||||
|  |     params = theta.sub(LR, moment+dtheta) | ||||||
|  |   unrolled_model = deepcopy(network) | ||||||
|  |   model_dict  = unrolled_model.state_dict() | ||||||
|  |   new_params, offset = {}, 0 | ||||||
|  |   for k, v in network.named_parameters(): | ||||||
|  |     if 'arch_parameters' in k: continue | ||||||
|  |     v_length = np.prod(v.size()) | ||||||
|  |     new_params[k] = params[offset: offset+v_length].view(v.size()) | ||||||
|  |     offset += v_length | ||||||
|  |   model_dict.update(new_params) | ||||||
|  |   unrolled_model.load_state_dict(model_dict) | ||||||
|  |  | ||||||
|  |   unrolled_model.zero_grad() | ||||||
|  |   _, unrolled_logits = unrolled_model(arch_inputs) | ||||||
|  |   unrolled_loss = criterion(unrolled_logits, arch_targets) | ||||||
|  |   unrolled_loss.backward() | ||||||
|  |  | ||||||
|  |   dalpha = unrolled_model.arch_parameters.grad | ||||||
|  |   vector = [v.grad.data for v in unrolled_model.weights] | ||||||
|  |   [implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets) | ||||||
|  |    | ||||||
|  |   dalpha.data.sub_(LR, implicit_grads.data) | ||||||
|  |  | ||||||
|  |   if network.arch_parameters.grad is None: | ||||||
|  |     network.arch_parameters.grad = deepcopy( dalpha ) | ||||||
|  |   else: | ||||||
|  |     network.arch_parameters.grad.data.copy_( dalpha.data ) | ||||||
|  |   return unrolled_loss.detach(), unrolled_logits.detach() | ||||||
|  |  | ||||||
|  |  | ||||||
| def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger): | def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger): | ||||||
|   data_time, batch_time = AverageMeter(), AverageMeter() |   data_time, batch_time = AverageMeter(), AverageMeter() | ||||||
|   base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() |   base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||||
| @@ -81,6 +148,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer | |||||||
|     else: |     else: | ||||||
|       raise ValueError('Invalid algo name : {:}'.format(algo)) |       raise ValueError('Invalid algo name : {:}'.format(algo)) | ||||||
|     network.zero_grad() |     network.zero_grad() | ||||||
|  |     if algo == 'darts-v2': | ||||||
|  |       arch_loss, logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets) | ||||||
|  |     else: | ||||||
|       _, logits = network(arch_inputs) |       _, logits = network(arch_inputs) | ||||||
|       arch_loss = criterion(logits, arch_targets) |       arch_loss = criterion(logits, arch_targets) | ||||||
|       arch_loss.backward() |       arch_loss.backward() | ||||||
| @@ -192,7 +262,10 @@ def main(xargs): | |||||||
|   params = count_parameters_in_MB(search_model) |   params = count_parameters_in_MB(search_model) | ||||||
|   logger.log('The parameters of the search model = {:.2f} MB'.format(params)) |   logger.log('The parameters of the search model = {:.2f} MB'.format(params)) | ||||||
|   logger.log('search-space : {:}'.format(search_space)) |   logger.log('search-space : {:}'.format(search_space)) | ||||||
|  |   try: | ||||||
|     api = API() |     api = API() | ||||||
|  |   except: | ||||||
|  |     api = None | ||||||
|   logger.log('{:} create API = {:} done'.format(time_string(), api)) |   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||||
|  |  | ||||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') |   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||||
| @@ -224,6 +297,7 @@ def main(xargs): | |||||||
|     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) |     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) | ||||||
|     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) |     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) | ||||||
|  |  | ||||||
|  |     network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate) | ||||||
|     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ |     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ | ||||||
|                 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger) |                 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger) | ||||||
|     search_time.update(time.time() - start_time) |     search_time.update(time.time() - start_time) | ||||||
| @@ -314,6 +388,7 @@ if __name__ == '__main__': | |||||||
|   # architecture leraning rate |   # architecture leraning rate | ||||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') |   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||||
|   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') |   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||||
|  |   parser.add_argument('--drop_path_rate'  ,  type=float, help='The drop path rate.') | ||||||
|   # log |   # log | ||||||
|   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') |   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') | ||||||
|   parser.add_argument('--save_dir',           type=str,   default='./output/search', help='Folder to save checkpoints and log.') |   parser.add_argument('--save_dir',           type=str,   default='./output/search', help='Folder to save checkpoints and log.') | ||||||
|   | |||||||
| @@ -67,6 +67,14 @@ class GenericNAS201Model(nn.Module): | |||||||
|     if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell) |     if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell) | ||||||
|     else                : self.dynamic_cell = None |     else                : self.dynamic_cell = None | ||||||
|  |  | ||||||
|  |   def set_drop_path(self, progress, drop_path_rate): | ||||||
|  |     if drop_path_rate is None: | ||||||
|  |       self._drop_path = None | ||||||
|  |     elif progress is None: | ||||||
|  |       self._drop_path = drop_path_rate | ||||||
|  |     else: | ||||||
|  |       self._drop_path = progress * drop_path_rate | ||||||
|  |  | ||||||
|   @property |   @property | ||||||
|   def mode(self): |   def mode(self): | ||||||
|     return self._mode |     return self._mode | ||||||
| @@ -210,6 +218,8 @@ class GenericNAS201Model(nn.Module): | |||||||
|           feature = cell.forward_gdas(feature, alphas, index) |           feature = cell.forward_gdas(feature, alphas, index) | ||||||
|         else: raise ValueError('invalid mode={:}'.format(self.mode)) |         else: raise ValueError('invalid mode={:}'.format(self.mode)) | ||||||
|       else: feature = cell(feature) |       else: feature = cell(feature) | ||||||
|  |       if self.drop_path is not None: | ||||||
|  |         feature = drop_path(feature, self.drop_path) | ||||||
|     out = self.lastact(feature) |     out = self.lastact(feature) | ||||||
|     out = self.global_pooling(out) |     out = self.global_pooling(out) | ||||||
|     out = out.view(out.size(0), -1) |     out = out.view(out.size(0), -1) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user