update configs
This commit is contained in:
		
							
								
								
									
										15
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,11 +1,13 @@ | ||||
| # Nueral Architecture Search (NAS) | ||||
|  | ||||
| This project contains the following neural architecture search algorithms, implemented in [PyTorch](http://pytorch.org). More NAS resources can be found in [Awesome-NAS](https://github.com/D-X-Y/Awesome-NAS). | ||||
| This project contains the following neural architecture search algorithms, implemented in [PyTorch](http://pytorch.org). | ||||
| More NAS resources can be found in [Awesome-NAS](https://github.com/D-X-Y/Awesome-NAS). | ||||
|  | ||||
| - Network Pruning via Transformable Architecture Search, NeurIPS 2019 | ||||
| - One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 | ||||
| - Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 | ||||
| - several typical classification models, e.g., ResNet and DenseNet (see [BASELINE.md](https://github.com/D-X-Y/NAS-Projects/blob/master/BASELINE.md)) | ||||
| - 10 NAS algorithms for the neural topology in `exps/algos` | ||||
| - Several typical classification models, e.g., ResNet and DenseNet (see [BASELINE.md](https://github.com/D-X-Y/NAS-Projects/blob/master/BASELINE.md)) | ||||
|  | ||||
|  | ||||
| ## Requirements and Preparation | ||||
| @@ -15,7 +17,7 @@ Please install `PyTorch>=1.1.0`, `Python>=3.6`, and `opencv`. | ||||
| The CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`. | ||||
| Some methods use knowledge distillation (KD), which require pre-trained models. Please download these models from [Google Driver](https://drive.google.com/open?id=1ANmiYEGX-IQZTfH8w0aSpj-Wypg-0DR-) (or train by yourself) and save into `.latent-data`. | ||||
|  | ||||
| ### usefull tools | ||||
| ### Usefull tools | ||||
| 1. Compute the number of parameters and FLOPs of a model: | ||||
| ``` | ||||
| from utils import get_model_infos | ||||
| @@ -52,13 +54,18 @@ CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-width-gumbel.sh cifar10 Re | ||||
|  | ||||
| Search for both depth and width configuration of ResNet: | ||||
| ``` | ||||
| CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-cifar.sh cifar10 ResNet56  CIFARX 0.47 -1 | ||||
| CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-shape-cifar.sh cifar10 ResNet56  CIFARX 0.47 -1 | ||||
| ``` | ||||
|  | ||||
| args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed. | ||||
|  | ||||
| ### Model Configuration | ||||
| The searched shapes for ResNet-20/32/56/110/164 in Table 3 in the original paper are listed in [`configs/NeurIPS-2019`](https://github.com/D-X-Y/NAS-Projects/tree/master/configs/NeurIPS-2019). | ||||
| If you want to directly train a model with searched configuration of TAS, try these: | ||||
| ``` | ||||
| CUDA_VISIBLE_DEVICES=0,1 bash ./scripts/tas-infer-train.sh cifar10  C010-ResNet32 -1 | ||||
| CUDA_VISIBLE_DEVICES=0,1 bash ./scripts/tas-infer-train.sh cifar100 C100-ResNet32 -1 | ||||
| ``` | ||||
|  | ||||
|  | ||||
| ## [One-Shot Neural Architecture Search via Self-Evaluated Template Network](https://arxiv.org/abs/1910.05733) | ||||
|   | ||||
| @@ -3,9 +3,10 @@ | ||||
|   "arch"               : ["str"   , "resnet"], | ||||
|   "depth"              : ["int"   , "32"], | ||||
|   "module"             : ["str"   , "ResNetBasicblock"], | ||||
|   "super_type"         : ["str"   , "infer"], | ||||
|   "super_type"         : ["str"   , "infer-shape"], | ||||
|   "zero_init_residual" : ["bool"  , "0"], | ||||
|   "class_num"          : ["int"   , "100"], | ||||
|   "xchannels"          : ["int"   , ["3", "16", "4", "4", "4", "14", "6", "4", "8", "4", "4", "4", "32", "32", "9", "28", "28", "28", "28", "28", "32", "32", "64", "64", "64", "64", "64", "64", "64", "64", "64", "64"]], | ||||
|   "xchannels"          : ["int"   , ["3", "16", "4", "4", "6", "11", "6", "4", "8", "4", "4", "4", "32", "32", "9", "28", "28", "28", "28", "28", "32", "32", "64", "64", "64", "64", "64", "64", "64", "64", "64", "64"]], | ||||
|   "xblocks"            : ["int"   , ["5", "5", "5"]], | ||||
|   "estimated_FLOP"     : ["float" , "42.493184"] | ||||
| } | ||||
| @@ -1,287 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, argparse, collections | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
| from collections import defaultdict | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from config_utils import load_config, dict2config | ||||
| from datasets     import get_datasets | ||||
| # AA-NAS-Bench related module or function | ||||
| from models       import CellStructure, get_cell_based_tiny_net | ||||
| from aa_nas_api   import ArchResults, ResultsCount | ||||
| from AA_functions import pure_evaluate | ||||
|  | ||||
|  | ||||
|  | ||||
| def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict): | ||||
|   information = ArchResults(arch_index, arch_str) | ||||
|  | ||||
|   for checkpoint_path in checkpoints: | ||||
|     checkpoint = torch.load(checkpoint_path, map_location='cpu') | ||||
|     used_seed  = checkpoint_path.name.split('-')[-1].split('.')[0] | ||||
|     for dataset in datasets: | ||||
|       assert dataset in checkpoint, 'Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path) | ||||
|       results     = checkpoint[dataset] | ||||
|       assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path) | ||||
|       arch_config = {'channel': results['channel'], 'num_cells': results['num_cells'], 'arch_str': arch_str, 'class_num': results['config']['class_num']} | ||||
|       xresult     = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'], \ | ||||
|                                   results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None) | ||||
|       if dataset == 'cifar10-valid': | ||||
|         xresult.update_eval('x-valid' , results['valid_acc1es'], results['valid_losses']) | ||||
|       elif dataset == 'cifar10': | ||||
|         xresult.update_eval('ori-test', results['valid_acc1es'], results['valid_losses']) | ||||
|       elif dataset == 'cifar100' or dataset == 'ImageNet16-120': | ||||
|         xresult.update_eval('ori-test', results['valid_acc1es'], results['valid_losses']) | ||||
|         net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], | ||||
|                                   'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes':arch_config['class_num']}, None) | ||||
|         network = get_cell_based_tiny_net(net_config) | ||||
|         network.load_state_dict(xresult.get_net_param()) | ||||
|         network = network.cuda() | ||||
|         loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'valid')], network) | ||||
|         xresult.update_eval('x-valid', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) | ||||
|         loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset,  'test')], network) | ||||
|         xresult.update_eval('x-test' , {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss}) | ||||
|         xresult.update_latency(latencies) | ||||
|       else: | ||||
|         raise ValueError('invalid dataset name : {:}'.format(dataset)) | ||||
|       information.update(dataset, int(used_seed), xresult) | ||||
|   return information | ||||
|  | ||||
|  | ||||
|  | ||||
| def GET_DataLoaders(workers): | ||||
|  | ||||
|   torch.set_num_threads(workers) | ||||
|  | ||||
|   root_dir  = (Path(__file__).parent / '..').resolve() | ||||
|   torch_dir = Path(os.environ['TORCH_HOME']) | ||||
|   # cifar | ||||
|   cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config' | ||||
|   cifar_config = load_config(cifar_config_path, None, None) | ||||
|   print ('{:} Create data-loader for all datasets'.format(time_string())) | ||||
|   print ('-'*200) | ||||
|   TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1) | ||||
|   print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num)) | ||||
|   cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None) | ||||
|   assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14] | ||||
|   temp_dataset = deepcopy(TRAIN_CIFAR10) | ||||
|   temp_dataset.transform = VALID_CIFAR10.transform | ||||
|   # data loader | ||||
|   trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True) | ||||
|   train_cifar10_loader    = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True) | ||||
|   valid_cifar10_loader    = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True) | ||||
|   test__cifar10_loader    = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True) | ||||
|   print ('CIFAR-10  : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('CIFAR-10  : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('CIFAR-10  : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('CIFAR-10  : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size)) | ||||
|   print ('-'*200) | ||||
|   # CIFAR-100 | ||||
|   TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1) | ||||
|   print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num)) | ||||
|   cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None) | ||||
|   assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24] | ||||
|   train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True) | ||||
|   valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True) | ||||
|   test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True) | ||||
|   print ('CIFAR-100  : train-loader has {:3d} batch'.format(len(train_cifar100_loader))) | ||||
|   print ('CIFAR-100  : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader))) | ||||
|   print ('CIFAR-100  : test--loader has {:3d} batch'.format(len(test__cifar100_loader))) | ||||
|   print ('-'*200) | ||||
|  | ||||
|   imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config' | ||||
|   imagenet16_config = load_config(imagenet16_config_path, None, None) | ||||
|   TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1) | ||||
|   print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num)) | ||||
|   imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None) | ||||
|   assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20] | ||||
|   train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True) | ||||
|   valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True) | ||||
|   test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True) | ||||
|   print ('ImageNet-16-120  : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size)) | ||||
|   print ('ImageNet-16-120  : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size)) | ||||
|   print ('ImageNet-16-120  : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size)) | ||||
|  | ||||
|   # 'cifar10', 'cifar100', 'ImageNet16-120' | ||||
|   loaders = {'cifar10@trainval': trainval_cifar10_loader, | ||||
|              'cifar10@train'   : train_cifar10_loader, | ||||
|              'cifar10@valid'   : valid_cifar10_loader, | ||||
|              'cifar10@test'    : test__cifar10_loader, | ||||
|              'cifar100@train'  : train_cifar100_loader, | ||||
|              'cifar100@valid'  : valid_cifar100_loader, | ||||
|              'cifar100@test'   : test__cifar100_loader, | ||||
|              'ImageNet16-120@train': train_imagenet_loader, | ||||
|              'ImageNet16-120@valid': valid_imagenet_loader, | ||||
|              'ImageNet16-120@test' : test__imagenet_loader} | ||||
|   return loaders | ||||
|  | ||||
|  | ||||
|  | ||||
| def simplify(save_dir, meta_file, basestr, target_dir): | ||||
|   meta_infos     = torch.load(meta_file, map_location='cpu') | ||||
|   meta_archs     = meta_infos['archs'] # a list of architecture strings | ||||
|   meta_num_archs = meta_infos['total'] | ||||
|   meta_max_node  = meta_infos['max_node'] | ||||
|   assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) | ||||
|  | ||||
|   sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) | ||||
|   print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs))) | ||||
|    | ||||
|   subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 | ||||
|   num_seeds = defaultdict(lambda: 0) | ||||
|   for index, sub_dir in enumerate(sub_model_dirs): | ||||
|     xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth')) | ||||
|     arch_indexes = set() | ||||
|     for checkpoint in xcheckpoints: | ||||
|       temp_names = checkpoint.name.split('-') | ||||
|       assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name) | ||||
|       arch_indexes.add( temp_names[1] ) | ||||
|     subdir2archs[sub_dir] = sorted(list(arch_indexes)) | ||||
|     num_evaluated_arch   += len(arch_indexes) | ||||
|     # count number of seeds for each architecture | ||||
|     for arch_index in arch_indexes: | ||||
|       num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1 | ||||
|   print('{:} There are {:5d} architectures that have been evaluated ({:} in total).'.format(time_string(), num_evaluated_arch, meta_num_archs)) | ||||
|   for key in sorted( list( num_seeds.keys() ) ): print ('{:} There are {:5d} architectures that are evaluated {:} times.'.format(time_string(), num_seeds[key], key)) | ||||
|  | ||||
|   dataloader_dict = GET_DataLoaders( 6 ) | ||||
|  | ||||
|   to_save_simply = save_dir / 'simplifies' | ||||
|   to_save_allarc = save_dir / 'simplifies' / 'architectures' | ||||
|   if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True) | ||||
|   if not to_save_allarc.exists(): to_save_allarc.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|   assert (save_dir / target_dir) in subdir2archs, 'can not find {:}'.format(target_dir) | ||||
|   arch2infos, datasets = {}, ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120') | ||||
|   evaluated_indexes    = set() | ||||
|   target_directory     = save_dir / target_dir | ||||
|   arch_indexes         = subdir2archs[ target_directory ] | ||||
|   num_seeds            = defaultdict(lambda: 0) | ||||
|   end_time             = time.time() | ||||
|   arch_time            = AverageMeter() | ||||
|   for idx, arch_index in enumerate(arch_indexes): | ||||
|     checkpoints = list(target_directory.glob('arch-{:}-seed-*.pth'.format(arch_index))) | ||||
|     try: | ||||
|       arch_info = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict) | ||||
|       num_seeds[ len(checkpoints) ] += 1 | ||||
|     except: | ||||
|       print('Loading {:} failed, : {:}'.format(arch_index, checkpoints)) | ||||
|       continue | ||||
|     assert int(arch_index) not in evaluated_indexes, 'conflict arch-index : {:}'.format(arch_index) | ||||
|     assert 0 <= int(arch_index) < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index) | ||||
|     evaluated_indexes.add( int(arch_index) ) | ||||
|     arch2infos[int(arch_index)] = arch_info | ||||
|     torch.save(arch_info.state_dict(), to_save_allarc / '{:}-FULL.pth'.format(arch_index)) | ||||
|     #torch.save(arch_info, to_save_allarc / '{:}-FULL.pth'.format(arch_index)) | ||||
|     arch_info.clear_params() | ||||
|     torch.save(arch_info.state_dict(), to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index)) | ||||
|     # measure elapsed time | ||||
|     arch_time.update(time.time() - end_time) | ||||
|     end_time  = time.time() | ||||
|     need_time = '{:}'.format( convert_secs2time(arch_time.avg * (len(arch_indexes)-idx-1), True) ) | ||||
|     print('{:} {:} [{:03d}/{:03d}] : {:} still need {:}'.format(time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time)) | ||||
|   # measure time | ||||
|   xstrs = ['{:}:{:03d}'.format(key, num_seeds[key]) for key in sorted( list( num_seeds.keys() ) ) ] | ||||
|   print('{:} {:} done : {:}'.format(time_string(), target_dir, xstrs)) | ||||
|   final_infos = {'meta_archs' : meta_archs, | ||||
|                  'total_archs': meta_num_archs, | ||||
|                  'basestr'    : basestr, | ||||
|                  'arch2infos' : arch2infos, | ||||
|                  'evaluated_indexes': evaluated_indexes} | ||||
|   save_file_name = to_save_simply / '{:}.pth'.format(target_dir) | ||||
|   torch.save(final_infos, save_file_name) | ||||
|   print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) | ||||
|  | ||||
|  | ||||
|  | ||||
| def merge_all(save_dir, meta_file, basestr): | ||||
|   meta_infos     = torch.load(meta_file, map_location='cpu') | ||||
|   meta_archs     = meta_infos['archs'] | ||||
|   meta_num_archs = meta_infos['total'] | ||||
|   meta_max_node  = meta_infos['max_node'] | ||||
|   assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) | ||||
|  | ||||
|   sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) | ||||
|   print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs))) | ||||
|   for index, sub_dir in enumerate(sub_model_dirs): | ||||
|     arch_info_files = sorted( list(sub_dir.glob('arch-*-seed-*.pth') ) ) | ||||
|     print ('The {:02d}/{:02d}-th directory : {:} : {:} runs.'.format(index, len(sub_model_dirs), sub_dir, len(arch_info_files))) | ||||
|    | ||||
|   subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 | ||||
|   num_seeds = defaultdict(lambda: 0) | ||||
|   for index, sub_dir in enumerate(sub_model_dirs): | ||||
|     xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth')) | ||||
|     arch_indexes = set() | ||||
|     for checkpoint in xcheckpoints: | ||||
|       temp_names = checkpoint.name.split('-') | ||||
|       assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name) | ||||
|       arch_indexes.add( temp_names[1] ) | ||||
|     subdir2archs[sub_dir] = sorted(list(arch_indexes)) | ||||
|     num_evaluated_arch   += len(arch_indexes) | ||||
|     # count number of seeds for each architecture | ||||
|     for arch_index in arch_indexes: | ||||
|       num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1 | ||||
|   print('There are {:5d} architectures that have been evaluated ({:} in total).'.format(num_evaluated_arch, meta_num_archs)) | ||||
|   for key in sorted( list( num_seeds.keys() ) ): print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key)) | ||||
|  | ||||
|   arch2infos, evaluated_indexes = dict(), set() | ||||
|   for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()): | ||||
|     ckp_path = sub_dir.parent / 'simplifies' / '{:}.pth'.format(sub_dir.name) | ||||
|     if ckp_path.exists(): | ||||
|       sub_ckps = torch.load(ckp_path, map_location='cpu') | ||||
|       assert sub_ckps['total_archs'] == meta_num_archs and sub_ckps['basestr'] == basestr | ||||
|       xarch2infos = sub_ckps['arch2infos'] | ||||
|       xevalindexs = sub_ckps['evaluated_indexes'] | ||||
|       for eval_index in xevalindexs: | ||||
|         assert eval_index not in evaluated_indexes and eval_index not in arch2infos | ||||
|         arch2infos[eval_index] = xarch2infos[eval_index].state_dict() | ||||
|         evaluated_indexes.add( eval_index ) | ||||
|       print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), IDX, len(subdir2archs), ckp_path, len(xevalindexs))) | ||||
|     else: | ||||
|       print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path)) | ||||
|  | ||||
|   evaluated_indexes = sorted( list( evaluated_indexes ) ) | ||||
|   print ('Finally, there are {:} models.'.format(len(evaluated_indexes))) | ||||
|  | ||||
|   to_save_simply = save_dir / 'simplifies' | ||||
|   if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True) | ||||
|   final_infos = {'meta_archs' : meta_archs, | ||||
|                  'total_archs': meta_num_archs, | ||||
|                  'arch2infos' : arch2infos, | ||||
|                  'evaluated_indexes': evaluated_indexes} | ||||
|   save_file_name = to_save_simply / '{:}-final-infos.pth'.format(basestr) | ||||
|   torch.save(final_infos, save_file_name) | ||||
|   print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|   parser = argparse.ArgumentParser(description='An Algorithm-Agnostic (AA) NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--mode'         ,  type=str, choices=['cal', 'merge'],                  help='The running mode for this script.') | ||||
|   parser.add_argument('--base_save_dir',  type=str, default='./output/AA-NAS-BENCH-4',     help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--target_dir'   ,  type=str,                                            help='The target directory.') | ||||
|   parser.add_argument('--max_node'     ,  type=int, default=4,                                 help='The maximum node in a cell.') | ||||
|   parser.add_argument('--channel'      ,  type=int, default=16,                                help='The number of channels.') | ||||
|   parser.add_argument('--num_cells'    ,  type=int, default=5,                                 help='The number of cells in one stage.') | ||||
|   args = parser.parse_args() | ||||
|    | ||||
|   save_dir  = Path( args.base_save_dir ) | ||||
|   meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node) | ||||
|   assert save_dir.exists(),  'invalid save dir path : {:}'.format(save_dir) | ||||
|   assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) | ||||
|   print ('start the statistics of our nas-benchmark from {:} using {:}.'.format(save_dir, args.target_dir)) | ||||
|   basestr   = 'C{:}-N{:}'.format(args.channel, args.num_cells) | ||||
|    | ||||
|   if args.mode == 'cal': | ||||
|     simplify(save_dir, meta_path, basestr, args.target_dir) | ||||
|   elif args.mode == 'merge': | ||||
|     merge_all(save_dir, meta_path, basestr) | ||||
|   else: | ||||
|     raise ValueError('invalid mode : {:}'.format(args.mode)) | ||||
| @@ -9,35 +9,8 @@ from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ['evaluate_for_seed', 'pure_evaluate'] | ||||
| 
 | ||||
| 
 | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|   latencies = [] | ||||
|   network.eval() | ||||
|   with torch.no_grad(): | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|       targets = targets.cuda(non_blocking=True) | ||||
|       inputs  = inputs.cuda(non_blocking=True) | ||||
|       data_time.update(time.time() - end) | ||||
|       # forward | ||||
|       features, logits = network(inputs) | ||||
|       loss             = criterion(logits, targets) | ||||
|       batch_time.update(time.time() - end) | ||||
|       if batch is None or batch == inputs.size(0): | ||||
|         batch = inputs.size(0) | ||||
|         latencies.append( batch_time.val - data_time.val ) | ||||
|       # record loss and accuracy | ||||
|       prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|       losses.update(loss.item(),  inputs.size(0)) | ||||
|       top1.update  (prec1.item(), inputs.size(0)) | ||||
|       top5.update  (prec5.item(), inputs.size(0)) | ||||
|       end = time.time() | ||||
|   if len(latencies) > 2: latencies = latencies[1:] | ||||
|   return losses.avg, top1.avg, top5.avg, latencies | ||||
| __all__ = ['evaluate_for_seed'] | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| @@ -47,7 +20,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode): | ||||
|   elif mode == 'valid': network.eval() | ||||
|   else: raise ValueError("The mode is not right : {:}".format(mode)) | ||||
| 
 | ||||
|   batch_time, end = AverageMeter(), time.time() | ||||
|   data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|   for i, (inputs, targets) in enumerate(xloader): | ||||
|     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader)) | ||||
| 
 | ||||
| @@ -72,7 +45,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode): | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, seed, logger): | ||||
| def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger): | ||||
| 
 | ||||
|   prepare_seed(seed) # random seed | ||||
|   net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny', | ||||
| @@ -83,7 +56,7 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see | ||||
|   #net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|   flop, param  = get_model_infos(net, config.xshape) | ||||
|   logger.log('Network : {:}'.format(net.get_message()), False) | ||||
|   logger.log('Seed-------------------------- {:} --------------------------'.format(seed)) | ||||
|   logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed)) | ||||
|   logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param)) | ||||
|   # train and valid | ||||
|   optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config) | ||||
| @@ -96,16 +69,17 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see | ||||
|     scheduler.update(epoch, 0.0) | ||||
| 
 | ||||
|     train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train') | ||||
|     with torch.no_grad(): | ||||
|       valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(valid_loader, network, criterion,      None,      None, 'valid') | ||||
|     train_losses[epoch] = train_loss | ||||
|     train_acc1es[epoch] = train_acc1  | ||||
|     train_acc5es[epoch] = train_acc5 | ||||
|     valid_losses[epoch] = valid_loss | ||||
|     valid_acc1es[epoch] = valid_acc1  | ||||
|     valid_acc5es[epoch] = valid_acc5 | ||||
|     train_times [epoch] = train_tm | ||||
|     valid_times [epoch] = valid_tm | ||||
|     with torch.no_grad(): | ||||
|       for key, xloder in valid_loaders.items(): | ||||
|         valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloder  , network, criterion,      None,      None, 'valid') | ||||
|         valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss | ||||
|         valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1  | ||||
|         valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5 | ||||
|         valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm | ||||
| 
 | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
| @@ -7,7 +7,7 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
| 
 | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config | ||||
| from procedures   import save_checkpoint, copy_checkpoint | ||||
| @@ -15,7 +15,7 @@ from procedures   import get_machine_info | ||||
| from datasets     import get_datasets | ||||
| from log_utils    import Logger, AverageMeter, time_string, convert_secs2time | ||||
| from models       import CellStructure, CellArchitectures, get_search_spaces | ||||
| from AA_functions_v2 import evaluate_for_seed | ||||
| from functions    import evaluate_for_seed | ||||
| 
 | ||||
| 
 | ||||
| def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger): | ||||
| @@ -156,14 +156,14 @@ def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_ind | ||||
|   logger.close() | ||||
| 
 | ||||
| 
 | ||||
| def train_single_model(save_dir, workers, datasets, xpaths, use_less, splits, seeds, model_str, arch_config): | ||||
| def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config): | ||||
|   assert torch.cuda.is_available(), 'CUDA is not available.' | ||||
|   torch.backends.cudnn.enabled   = True | ||||
|   torch.backends.cudnn.deterministic = True | ||||
|   #torch.backends.cudnn.benchmark = True | ||||
|   torch.set_num_threads( workers ) | ||||
|    | ||||
|   save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}'.format(model_str, arch_config['channel'], arch_config['num_cells']) | ||||
|   save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}-{:}'.format('LESS' if use_less else 'FULL', model_str, arch_config['channel'], arch_config['num_cells']) | ||||
|   logger   = Logger(str(save_dir), 0, False) | ||||
|   if model_str in CellArchitectures: | ||||
|     arch   = CellArchitectures[model_str] | ||||
| @@ -247,18 +247,22 @@ def generate_meta_info(save_dir, max_node, divide=40): | ||||
|   torch.save(info, save_name) | ||||
|   print ('save the meta file into {:}'.format(save_name)) | ||||
| 
 | ||||
|   script_name = save_dir / 'meta-node-{:}.opt-script.txt'.format(max_node) | ||||
|   with open(str(script_name), 'w') as cfile: | ||||
|     gaps = total_arch // divide | ||||
|     for start in range(0, total_arch, gaps): | ||||
|       xend = min(start+gaps, total_arch) | ||||
|       cfile.write('bash ./scripts-search/AA-NAS-train-archs.sh {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) | ||||
|   print ('save the training script into {:}'.format(script_name)) | ||||
|   script_name_full = save_dir / 'BENCH-102-N{:}.opt-full.script'.format(max_node) | ||||
|   script_name_less = save_dir / 'BENCH-102-N{:}.opt-less.script'.format(max_node) | ||||
|   full_file = open(str(script_name_full), 'w') | ||||
|   less_file = open(str(script_name_less), 'w') | ||||
|   gaps = total_arch // divide | ||||
|   for start in range(0, total_arch, gaps): | ||||
|     xend = min(start+gaps, total_arch) | ||||
|     full_file.write('bash ./scripts-search/NAS-Bench-102/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) | ||||
|     less_file.write('bash ./scripts-search/NAS-Bench-102/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) | ||||
|   print ('save the training script into {:} and {:}'.format(script_name_full, script_name_less)) | ||||
|   full_file.close() | ||||
|   less_file.close() | ||||
| 
 | ||||
|   script_name = save_dir / 'meta-node-{:}.cal-script.txt'.format(max_node) | ||||
|   macro = 'OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0' | ||||
|   with open(str(script_name), 'w') as cfile: | ||||
|     gaps = total_arch // divide | ||||
|     for start in range(0, total_arch, gaps): | ||||
|       xend = min(start+gaps, total_arch) | ||||
|       cfile.write('{:} python exps/AA-NAS-statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1)) | ||||
| @@ -278,7 +282,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--datasets',    type=str,   nargs='+',      help='The applied datasets.') | ||||
|   parser.add_argument('--xpaths',      type=str,   nargs='+',      help='The root path for this dataset.') | ||||
|   parser.add_argument('--splits',      type=int,   nargs='+',      help='The root path for this dataset.') | ||||
|   parser.add_argument('--use_less',    type=int,   default=0,      help='Using the less-training-epoch config.') | ||||
|   parser.add_argument('--use_less',    type=int,   default=0, choices=[0,1], help='Using the less-training-epoch config.') | ||||
|   parser.add_argument('--seeds'  ,     type=int,   nargs='+',      help='The range of models to be evaluated') | ||||
|   parser.add_argument('--channel',     type=int,                   help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',   type=int,                   help='The number of cells in one stage.') | ||||
| @@ -1,5 +1,5 @@ | ||||
| #!/bin/bash | ||||
| # bash ./scripts-search/AA-NAS-meta-gen.sh AA-NAS-BENCHMARK 4 | ||||
| # bash scripts-search/NAS-Bench-102/meta-gen.sh NAS-BENCH-102 4 | ||||
| echo script name: $0 | ||||
| echo $# arguments | ||||
| if [ "$#" -ne 2 ] ;then | ||||
| @@ -13,4 +13,4 @@ node=$2 | ||||
| 
 | ||||
| save_dir=./output/${name}-${node} | ||||
| 
 | ||||
| python ./exps/AA-NAS-Bench-main.py --mode meta --save_dir ${save_dir} --max_node ${node} | ||||
| python ./exps/NAS-Bench-102/main.py --mode meta --save_dir ${save_dir} --max_node ${node} | ||||
							
								
								
									
										34
									
								
								scripts-search/NAS-Bench-102/train-a-net.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								scripts-search/NAS-Bench-102/train-a-net.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,34 @@ | ||||
| #!/bin/bash | ||||
| # bash ./scripts-search/NAS-Bench-102/train-a-net.sh resnet 16 5 | ||||
| echo script name: $0 | ||||
| echo $# arguments | ||||
| if [ "$#" -ne 3 ] ;then | ||||
|   echo "Input illegal number of parameters " $# | ||||
|   echo "Need 3 parameters for network, channel, num-of-cells" | ||||
|   exit 1 | ||||
| fi | ||||
| if [ "$TORCH_HOME" = "" ]; then | ||||
|   echo "Must set TORCH_HOME envoriment variable for data dir saving" | ||||
|   exit 1 | ||||
| else | ||||
|   echo "TORCH_HOME : $TORCH_HOME" | ||||
| fi | ||||
|  | ||||
| model=$1 | ||||
| channel=$2 | ||||
| num_cells=$3 | ||||
|  | ||||
| save_dir=./output/NAS-BENCH-102-4/ | ||||
|  | ||||
| OMP_NUM_THREADS=4 python ./exps/NAS-Bench-102/main.py \ | ||||
| 	--mode specific-${model} --save_dir ${save_dir} --max_node 4 \ | ||||
| 	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \ | ||||
| 	--use_less 0 \ | ||||
| 	--splits         1       0        0              0 \ | ||||
| 	--xpaths $TORCH_HOME/cifar.python \ | ||||
| 		 $TORCH_HOME/cifar.python \ | ||||
| 		 $TORCH_HOME/cifar.python \ | ||||
| 		 $TORCH_HOME/cifar.python/ImageNet16 \ | ||||
| 	--channel ${channel} --num_cells ${num_cells} \ | ||||
| 	--workers 4 \ | ||||
| 	--seeds 777 888 999 | ||||
							
								
								
									
										43
									
								
								scripts-search/NAS-Bench-102/train-models.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								scripts-search/NAS-Bench-102/train-models.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | ||||
| #!/bin/bash | ||||
| # bash ./scripts-search/train-models.sh 0/1 0 100 -1 '777 888 999' | ||||
| echo script name: $0 | ||||
| echo $# arguments | ||||
| if [ "$#" -ne 5 ] ;then | ||||
|   echo "Input illegal number of parameters " $# | ||||
|   echo "Need 5 parameters for use-less-or-not, start-and-end, arch-index, and seeds" | ||||
|   exit 1 | ||||
| fi | ||||
| if [ "$TORCH_HOME" = "" ]; then | ||||
|   echo "Must set TORCH_HOME envoriment variable for data dir saving" | ||||
|   exit 1 | ||||
| else | ||||
|   echo "TORCH_HOME : $TORCH_HOME" | ||||
| fi | ||||
|  | ||||
| use_less=$1 | ||||
| xstart=$2 | ||||
| xend=$3 | ||||
| arch_index=$4 | ||||
| all_seeds=$5 | ||||
|  | ||||
| save_dir=./output/NAS-BENCH-102-4/ | ||||
|  | ||||
| if [ ${arch_index} == "-1" ]; then | ||||
|   mode=new | ||||
| else | ||||
|   mode=cover | ||||
| fi | ||||
|  | ||||
| OMP_NUM_THREADS=4 python ./exps/AA-NAS-Bench-main.py \ | ||||
| 	--mode ${mode} --save_dir ${save_dir} --max_node 4 \ | ||||
| 	--use_less ${use_less} \ | ||||
| 	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \ | ||||
| 	--splits   1       0       0        0 \ | ||||
| 	--xpaths $TORCH_HOME/cifar.python \ | ||||
| 		 $TORCH_HOME/cifar.python \ | ||||
| 		 $TORCH_HOME/cifar.python \ | ||||
| 		 $TORCH_HOME/cifar.python/ImageNet16 \ | ||||
| 	--channel 16 --num_cells 5 \ | ||||
| 	--workers 4 \ | ||||
| 	--srange ${xstart} ${xend} --arch_index ${arch_index} \ | ||||
| 	--seeds ${all_seeds} | ||||
| @@ -1,5 +1,5 @@ | ||||
| #!/bin/bash | ||||
| # bash ./scripts-search/search-cifar.sh cifar10 ResNet110 CIFAR 0.57 777 | ||||
| # bash ./scripts-search/search-shape-cifar.sh cifar10 ResNet110 CIFAR 0.57 777 | ||||
| set -e | ||||
| echo script name: $0 | ||||
| echo $# arguments | ||||
							
								
								
									
										51
									
								
								scripts/tas-infer-train.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								scripts/tas-infer-train.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,51 @@ | ||||
| #!/bin/bash | ||||
| # bash ./scripts/tas-infer-train.sh cifar10 C100-ResNet32 -1 | ||||
| set -e | ||||
| echo script name: $0 | ||||
| echo $# arguments | ||||
| if [ "$#" -ne 3 ] ;then | ||||
|   echo "Input illegal number of parameters " $# | ||||
|   echo "Need 3 parameters for the dataset and the-config-name and the-random-seed" | ||||
|   exit 1 | ||||
| fi | ||||
| if [ "$TORCH_HOME" = "" ]; then | ||||
|   echo "Must set TORCH_HOME envoriment variable for data dir saving" | ||||
|   exit 1 | ||||
| else | ||||
|   echo "TORCH_HOME : $TORCH_HOME" | ||||
| fi | ||||
|  | ||||
| dataset=$1 | ||||
| model=$2 | ||||
| rseed=$3 | ||||
| batch=256 | ||||
|  | ||||
| save_dir=./output/search-shape/TAS-INFER-${dataset}-${model} | ||||
|  | ||||
| python --version | ||||
|  | ||||
| # normal training | ||||
| xsave_dir=${save_dir}-NMT | ||||
| OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \ | ||||
| 	--data_path $TORCH_HOME/cifar.python \ | ||||
| 	--model_config ./configs/NeurIPS-2019/${model}.config \ | ||||
| 	--optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \ | ||||
| 	--procedure    basic \ | ||||
| 	--save_dir     ${xsave_dir} \ | ||||
| 	--cutout_length -1 \ | ||||
| 	--batch_size 256 --rand_seed ${rseed} --workers 6 \ | ||||
| 	--eval_frequency 1 --print_freq 100 --print_freq_eval 200 | ||||
|  | ||||
| # KD training | ||||
| xsave_dir=${save_dir}-KDT | ||||
| OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \ | ||||
| 	--data_path $TORCH_HOME/cifar.python \ | ||||
| 	--model_config ./configs/NeurIPS-2019/${model}.config \ | ||||
| 	--optim_config  ./configs/opts/CIFAR-E300-W5-L1-COS.config \ | ||||
| 	--KD_checkpoint ./.latent-data/basemodels/${dataset}/${model}.pth \ | ||||
| 	--procedure    Simple-KD \ | ||||
| 	--save_dir     ${xsave_dir} \ | ||||
| 	--KD_alpha 0.9 --KD_temperature 4 \ | ||||
| 	--cutout_length -1 \ | ||||
| 	--batch_size 256 --rand_seed ${rseed} --workers 6 \ | ||||
| 	--eval_frequency 1 --print_freq 100 --print_freq_eval 200 | ||||
		Reference in New Issue
	
	Block a user