From c34620ab1b0b3a9d61ec08783ffa31ada0a735d0 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 19 Jul 2020 08:11:29 +0000 Subject: [PATCH] Prototype generic nas model (cont.). --- .../nas-benchmark/algos/weight-sharing.config | 2 +- exps/algos-v2/search-cell.py | 85 +++++++++++++++++-- lib/models/cell_searchs/generic_model.py | 10 +++ 3 files changed, 91 insertions(+), 6 deletions(-) diff --git a/configs/nas-benchmark/algos/weight-sharing.config b/configs/nas-benchmark/algos/weight-sharing.config index e2d956d..a58d727 100644 --- a/configs/nas-benchmark/algos/weight-sharing.config +++ b/configs/nas-benchmark/algos/weight-sharing.config @@ -2,7 +2,7 @@ "scheduler": ["str", "cos"], "LR" : ["float", "0.025"], "eta_min" : ["float", "0.001"], - "epochs" : ["int", "250"], + "epochs" : ["int", "150"], "warmup" : ["int", "0"], "optim" : ["str", "SGD"], "decay" : ["float", "0.0005"], diff --git a/exps/algos-v2/search-cell.py b/exps/algos-v2/search-cell.py index e1ae220..19c9f68 100644 --- a/exps/algos-v2/search-cell.py +++ b/exps/algos-v2/search-cell.py @@ -2,7 +2,7 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # ###################################################################################### # python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 1 -# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 +# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3 # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1 #### # python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 1 @@ -30,6 +30,73 @@ from models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API +# The following three functions are used for DARTS-V2 +def _concat(xs): + return torch.cat([x.view(-1) for x in xs]) + + +def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2): + R = r / _concat(vector).norm() + for p, v in zip(network.weights, vector): + p.data.add_(R, v) + _, logits = network(base_inputs) + loss = criterion(logits, base_targets) + grads_p = torch.autograd.grad(loss, network.alphas) + + for p, v in zip(network.weights, vector): + p.data.sub_(2*R, v) + _, logits = network(base_inputs) + loss = criterion(logits, base_targets) + grads_n = torch.autograd.grad(loss, network.alphas) + + for p, v in zip(network.weights, vector): + p.data.add_(R, v) + return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)] + + +def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets): + # _compute_unrolled_model + _, logits = network(base_inputs) + loss = criterion(logits, base_targets) + LR, WD, momentum = w_optimizer.param_groups[0]['lr'], w_optimizer.param_groups[0]['weight_decay'], w_optimizer.param_groups[0]['momentum'] + with torch.no_grad(): + theta = _concat(network.weights) + try: + moment = _concat(w_optimizer.state[v]['momentum_buffer'] for v in network.weights) + moment = moment.mul_(momentum) + except: + moment = torch.zeros_like(theta) + dtheta = _concat(torch.autograd.grad(loss, network.weights)) + WD*theta + params = theta.sub(LR, moment+dtheta) + unrolled_model = deepcopy(network) + model_dict = unrolled_model.state_dict() + new_params, offset = {}, 0 + for k, v in network.named_parameters(): + if 'arch_parameters' in k: continue + v_length = np.prod(v.size()) + new_params[k] = params[offset: offset+v_length].view(v.size()) + offset += v_length + model_dict.update(new_params) + unrolled_model.load_state_dict(model_dict) + + unrolled_model.zero_grad() + _, unrolled_logits = unrolled_model(arch_inputs) + unrolled_loss = criterion(unrolled_logits, arch_targets) + unrolled_loss.backward() + + dalpha = unrolled_model.arch_parameters.grad + vector = [v.grad.data for v in unrolled_model.weights] + [implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets) + + dalpha.data.sub_(LR, implicit_grads.data) + + if network.arch_parameters.grad is None: + network.arch_parameters.grad = deepcopy( dalpha ) + else: + network.arch_parameters.grad.data.copy_( dalpha.data ) + return unrolled_loss.detach(), unrolled_logits.detach() + + def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger): data_time, batch_time = AverageMeter(), AverageMeter() base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() @@ -81,9 +148,12 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer else: raise ValueError('Invalid algo name : {:}'.format(algo)) network.zero_grad() - _, logits = network(arch_inputs) - arch_loss = criterion(logits, arch_targets) - arch_loss.backward() + if algo == 'darts-v2': + arch_loss, logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets) + else: + _, logits = network(arch_inputs) + arch_loss = criterion(logits, arch_targets) + arch_loss.backward() a_optimizer.step() # record arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5)) @@ -192,7 +262,10 @@ def main(xargs): params = count_parameters_in_MB(search_model) logger.log('The parameters of the search model = {:.2f} MB'.format(params)) logger.log('search-space : {:}'.format(search_space)) - api = API() + try: + api = API() + except: + api = None logger.log('{:} create API = {:} done'.format(time_string(), api)) last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') @@ -224,6 +297,7 @@ def main(xargs): epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) + network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate) search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger) search_time.update(time.time() - start_time) @@ -314,6 +388,7 @@ if __name__ == '__main__': # architecture leraning rate parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') + parser.add_argument('--drop_path_rate' , type=float, help='The drop path rate.') # log parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') diff --git a/lib/models/cell_searchs/generic_model.py b/lib/models/cell_searchs/generic_model.py index 908f6fa..c90d150 100644 --- a/lib/models/cell_searchs/generic_model.py +++ b/lib/models/cell_searchs/generic_model.py @@ -67,6 +67,14 @@ class GenericNAS201Model(nn.Module): if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell) else : self.dynamic_cell = None + def set_drop_path(self, progress, drop_path_rate): + if drop_path_rate is None: + self._drop_path = None + elif progress is None: + self._drop_path = drop_path_rate + else: + self._drop_path = progress * drop_path_rate + @property def mode(self): return self._mode @@ -210,6 +218,8 @@ class GenericNAS201Model(nn.Module): feature = cell.forward_gdas(feature, alphas, index) else: raise ValueError('invalid mode={:}'.format(self.mode)) else: feature = cell(feature) + if self.drop_path is not None: + feature = drop_path(feature, self.drop_path) out = self.lastact(feature) out = self.global_pooling(out) out = out.view(out.size(0), -1)