updates for beta
@@ -62,7 +62,7 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf
   # config. (containing some necessary arg)
   #   baseline: The baseline score (i.e. average val_acc) from the previous epoch
   data_time, batch_time = AverageMeter(), AverageMeter()
-  GradnormMeter, LossMeter, ValAccMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
+  GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
 
   shared_cnn.eval()
   controller.train()
@@ -96,8 +96,9 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf
     # account
     RewardMeter.update(reward.item())
     BaselineMeter.update(baseline.item())
-    ValAccMeter.update(val_top1.item())
+    ValAccMeter.update(val_top1.item()*100)
     LossMeter.update(loss.item())
+    EntropyMeter.update(entropy.item())
 
     # Average gradient over controller_num_aggregate samples
     loss = loss / config.ctl_num_aggre
@@ -116,7 +117,8 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf
       Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre)
       Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
       Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
-      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
+      Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg)
+      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)
 
   return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item()
 
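Every quantity above is logged through the meter's .val (latest sample) and .avg (running mean). The repo's AverageMeter is not shown in this diff; a minimal sketch of the assumed interface, in the classic PyTorch-examples style:

class AverageMeter(object):
  # minimal sketch of the assumed meter interface: .val is the latest
  # sample, .avg the sample-count-weighted running mean
  def __init__(self):
    self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0
  def update(self, val, n=1):
    self.val    = val
    self.sum   += val * n
    self.count += n
    self.avg    = self.sum / self.count

Under that reading, the new ValAccMeter.update(val_top1.item()*100) presumably rescales a fractional accuracy to a percentage so Prec@1 prints on the same scale as the other logs.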
@@ -250,7 +252,7 @@ def main(xargs):
     w_scheduler.update(epoch, 0.0)
     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
-    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
+    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), baseline))
 
     cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
     logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
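The baseline now printed per epoch is the controller's reward baseline, carried over between epochs via the baseline.item() returned by train_controller above. ENAS-style controllers typically maintain it as an exponential moving average of the reward; a hedged sketch of that update, where entropy_weight and bl_dec are illustrative names rather than the repo's exact arguments:

def reinforce_step(log_prob, entropy, val_acc, baseline,
                   entropy_weight=1e-4, bl_dec=0.99):
  # reward: validation accuracy plus an exploration bonus on the policy entropy
  reward   = val_acc + entropy_weight * entropy
  # exponential-moving-average baseline reduces gradient variance
  baseline = bl_dec * baseline + (1.0 - bl_dec) * reward
  # REINFORCE: raise the log-probability of archs that beat the baseline
  loss     = -log_prob * (reward - baseline)
  return loss, baseline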
@@ -264,7 +266,7 @@ def main(xargs):
     logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline))
     best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
     shared_cnn.module.update_arch(best_arch)
-    best_valid_acc = valid_func(valid_loader, shared_cnn, criterion)
+    _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)
 
     genotypes[epoch] = best_arch
     # check the best accuracy
@@ -301,6 +303,14 @@ def main(xargs):
     start_time = time.time()
 
   logger.log('\n' + '-'*100)
+  logger.log('During searching, the best architecture is {:}'.format(genotypes['best']))
+  logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
+  logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
+  final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
+  shared_cnn.module.update_arch(final_arch)
+  final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
+  logger.log('The Selected Final Architecture : {:}'.format(final_arch))
+  logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
   # check the performance from the architecture dataset
   #if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
   #  logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
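This final-selection block follows the usual ENAS recipe: draw controller_num_samples architectures from the trained controller, score each on the validation set, and re-evaluate only the winner. The get_best_arch used here is not shown in this diff; a hedged sketch of what such a helper plausibly does, assuming the controller returns a sampled architecture when called:

import torch

def sample_best_arch(controller, shared_cnn, valid_loader, criterion, n_samples):
  # hedged sketch, not the repo's get_best_arch: sample n_samples archs
  # from the controller and keep the one with the best top-1 accuracy
  best_arch, best_acc = None, -1.0
  controller.eval(); shared_cnn.eval()
  with torch.no_grad():
    for _ in range(n_samples):
      _, _, sampled_arch = controller()            # assumed (log_prob, entropy, arch)
      shared_cnn.module.update_arch(sampled_arch)  # as in the hunk above
      _, top1, _ = valid_func(valid_loader, shared_cnn, criterion)
      if top1 > best_acc:
        best_arch, best_acc = sampled_arch, top1
  return best_arch, best_acc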

@@ -23,7 +23,6 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
   data_time, batch_time = AverageMeter(), AverageMeter()
   base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
-  network.train()
   end = time.time()
   for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
     scheduler.update(None, 1.0 * step / len(xloader))
@@ -33,9 +32,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
     data_time.update(time.time() - end)
 
     # update the weights
-    network.module.set_cal_mode( 'urs' )
-    w_optimizer.zero_grad()
-    _, logits = network(base_inputs)
+    network.train()
+    sampled_arch = network.module.dync_genotype(True)
+    network.module.set_cal_mode('dynamic', sampled_arch)
+    #network.module.set_cal_mode( 'urs' )
+    network.zero_grad()
+    _, logits = network( torch.cat((base_inputs, arch_inputs), dim=0) )
+    logits    = logits[:base_inputs.size(0)]
     base_loss = criterion(logits, base_targets)
     base_loss.backward()
     w_optimizer.step()
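Note the switch from w_optimizer.zero_grad() to network.zero_grad(): the concatenated forward pass now touches the architecture parameters as well, and an optimizer only clears gradients for the parameters it owns, while Module.zero_grad() clears all of them. A toy demonstration of the difference:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)                        # stand-in for the search network
opt = torch.optim.SGD([net.weight], lr=0.1)  # owns only SOME of the parameters

net(torch.randn(1, 4)).sum().backward()      # grads land on weight AND bias
opt.zero_grad()                              # clears net.weight.grad only
assert net.bias.grad is not None             # bias grad would go stale here
net.zero_grad()                              # clears every parameter's grad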
@@ -46,8 +49,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
     base_top5.update  (base_prec5.item(), base_inputs.size(0))
 
     # update the architecture-weight
+    network.eval()
     network.module.set_cal_mode( 'joint' )
-    a_optimizer.zero_grad()
+    network.zero_grad()
     _, logits = network(arch_inputs)
     arch_loss = criterion(logits, arch_targets)
     arch_loss.backward()
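For reference, the calculation modes that appear across these hunks; their meanings are inferred from SETN's one-shot design, so treat the glosses as assumptions rather than documentation:

# set_cal_mode usage observed in this commit (glosses inferred, not authoritative):
#   set_cal_mode('urs')                   -- uniform random sampling of a path per forward
#   set_cal_mode('joint')                 -- soft mixture over all candidate operations
#   set_cal_mode('select')                -- evaluate only the currently selected ops
#   set_cal_mode('dynamic', sampled_arch) -- run one fixed, explicitly sampled genotype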
@@ -68,15 +72,42 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
       Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
       Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
       logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
-  return base_losses.avg, base_top1.avg, base_top5.avg
+      #print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
+      #print (network.module.arch_parameters)
+  return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
 
 
+def get_best_arch(xloader, network, n_samples):
+  with torch.no_grad():
+    network.eval()
+    archs, valid_accs = [], []
+    loader_iter = iter(xloader)
+    for i in range(n_samples):
+      try:
+        inputs, targets = next(loader_iter)
+      except StopIteration:
+        loader_iter = iter(xloader)
+        inputs, targets = next(loader_iter)
+
+      sampled_arch = network.module.dync_genotype(False)
+      network.module.set_cal_mode('dynamic', sampled_arch)
+      _, logits = network(inputs)
+      val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
+
+      archs.append( sampled_arch )
+      valid_accs.append( val_top1.item() )
+
+    best_idx = np.argmax(valid_accs)
+    best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
+    return best_arch, best_valid_acc
+
+
 def valid_func(xloader, network, criterion):
   data_time, batch_time = AverageMeter(), AverageMeter()
   arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
-  network.train()
   end = time.time()
   with torch.no_grad():
+    network.eval()
     for step, (arch_inputs, arch_targets) in enumerate(xloader):
       arch_targets = arch_targets.cuda(non_blocking=True)
       # measure data loading time
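The new get_best_arch restarts the loader iterator when the validation loader has fewer than n_samples batches, and it relies on np.argmax, so the file needs numpy imported at the top alongside torch. A hypothetical call site:

import numpy as np  # required by np.argmax inside get_best_arch

# hypothetical usage: score 100 sampled genotypes on held-out batches
best_arch, best_acc = get_best_arch(valid_loader, network, n_samples=100)
print('best sampled arch: {:} with top-1 {:.2f}%'.format(best_arch, best_acc))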
@@ -117,8 +148,8 @@ def main(xargs):
     logger.log('Load split file from {:}'.format(split_Fpath))
   else:
     raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
-  config_path = 'configs/nas-benchmark/algos/SETN.config'
-  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+  #config_path = 'configs/nas-benchmark/algos/SETN.config'
+  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
   # To split data
   train_data_v2 = deepcopy(train_data)
   train_data_v2.transform = valid_data.transform
@@ -126,7 +157,7 @@ def main(xargs):
   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
   # data loader
   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
-  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+  valid_loader  = torch.utils.data.DataLoader(valid_data,  batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
 
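The valid loader draws only the indices in valid_split via SubsetRandomSampler, so it can share the underlying dataset with the training side without leakage; test_batch_size is assumed to be a field of the loaded config. A self-contained toy version of the pattern:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# toy dataset standing in for CIFAR; the two index halves are disjoint
data = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
train_split, valid_split = list(range(0, 50)), list(range(50, 100))

valid_loader = DataLoader(data, batch_size=16,  # stands in for config.test_batch_size
                          sampler=SubsetRandomSampler(valid_split))
for inputs, targets in valid_loader:            # only indices 50..99 are drawn
  pass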
@@ -134,6 +165,7 @@ def main(xargs):
   model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
                               'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                               'space'    : search_space}, None)
+  logger.log('search space : {:}'.format(search_space))
   search_model = get_cell_based_tiny_net(model_config)
 
   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
@@ -173,17 +205,24 @@ def main(xargs):
     epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
 
-    search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
-    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
-    search_model.set_cal_mode('urs')
+    search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
+                = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
+    logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
+    logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
+
+    genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
+    network.module.set_cal_mode('dynamic', genotype)
     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
-    logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
-    search_model.set_cal_mode('joint')
-    valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
-    logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
-    search_model.set_cal_mode('select')
-    valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
-    logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
+    logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
+    #search_model.set_cal_mode('urs')
+    #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
+    #logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
+    #search_model.set_cal_mode('joint')
+    #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
+    #logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
+    #search_model.set_cal_mode('select')
+    #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
+    #logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
     # check the best accuracy
     valid_accuracies[epoch] = valid_a_top1
     if valid_a_top1 > valid_accuracies['best']:
@@ -192,7 +231,7 @@ def main(xargs):
       find_best = True
     else: find_best = False
 
-    genotypes[epoch] = search_model.genotype()
+    genotypes[epoch] = genotype
     logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
     # save checkpoint
     save_path = save_checkpoint({'epoch' : epoch + 1,
@@ -219,6 +258,7 @@ def main(xargs):
     start_time = time.time()
 
   # sampling
+  """
   with torch.no_grad():
     logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
   selected_archs = set()
@@ -238,6 +278,7 @@ def main(xargs):
     if best_arch is None or best_acc < valid_a_top1:
       best_arch, best_acc = arch, valid_a_top1
   logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
+  """
 
   logger.log('\n' + '-'*100)
   # check the performance from the architecture dataset
@@ -267,6 +308,7 @@ if __name__ == '__main__':
   parser.add_argument('--channel',            type=int,   help='The number of channels.')
   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.')
   parser.add_argument('--select_num',         type=int,   help='The number of selected architectures to evaluate.')
+  parser.add_argument('--config_path',        type=str,   help='The path to the configuration file.')
   # architecture learning rate
   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
  parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding')