Prototype generic nas model (cont.) for GDAS.
This commit is contained in:
		| @@ -377,8 +377,7 @@ def main(xargs): | ||||
|     start_epoch = last_info['epoch'] | ||||
|     checkpoint  = torch.load(last_info['last_checkpoint']) | ||||
|     genotypes   = checkpoint['genotypes'] | ||||
|     if xargs.algo == 'enas': | ||||
|       baseline  = checkpoint['baseline'] | ||||
|     baseline  = checkpoint['baseline'] | ||||
|     valid_accuracies = checkpoint['valid_accuracies'] | ||||
|     search_model.load_state_dict( checkpoint['search_model'] ) | ||||
|     w_scheduler.load_state_dict ( checkpoint['w_scheduler'] ) | ||||
| @@ -401,7 +400,7 @@ def main(xargs): | ||||
|     network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate) | ||||
|     if xargs.algo == 'gdas': | ||||
|       network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) ) | ||||
|       logger.log('[Reset tau as : {:}'.format(network.tau)) | ||||
|       logger.log('[RESET tau as : {:} and drop_path as {:}]'.format(network.tau, network.drop_path)) | ||||
|     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ | ||||
|                 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger) | ||||
|     search_time.update(time.time() - start_time) | ||||
| @@ -423,6 +422,7 @@ def main(xargs): | ||||
|       network.set_cal_mode('urs', None) | ||||
|     else: | ||||
|       raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo)) | ||||
|     logger.log('[{:}] - [get_best_arch] : {:} -> {:}'.format(epoch_str, genotype, temp_accuracy)) | ||||
|     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion, xargs.algo, logger) | ||||
|     logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype)) | ||||
|     valid_accuracies[epoch] = valid_a_top1 | ||||
| @@ -494,7 +494,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--eval_candidate_num', type=int,   default=100, help='The number of selected architectures to evaluate.') | ||||
|   # | ||||
|   parser.add_argument('--track_running_stats',type=int,   default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||
|   parser.add_argument('--affine'      ,       type=int,   default=1, choices=[0,1],help='Whether use affine=True or False in the BN layer.') | ||||
|   parser.add_argument('--affine'      ,       type=int,   default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.') | ||||
|   parser.add_argument('--config_path' ,       type=str,   default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.') | ||||
|   # architecture learning rate | ||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user