update code styles
This commit is contained in:
		
							
								
								
									
										66
									
								
								BASELINE.md
									
									
									
									
									
								
							
							
						
						
									
										66
									
								
								BASELINE.md
									
									
									
									
									
								
							| @@ -40,39 +40,39 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_ | |||||||
|  |  | ||||||
| ## Performance on ImageNet | ## Performance on ImageNet | ||||||
|  |  | ||||||
| |      Model     | FLOPs (GB) | Params (M) | Top-1 Error | Top-5 Error |  Optimizer | | |        Model      | FLOPs (GB) | Params (M) | Top-1 Error | Top-5 Error |  Optimizer | | ||||||
| |:--------------:|:----------:|:----------:|:-----------:|:-----------:|:----------:| | |:-----------------:|:----------:|:----------:|:-----------:|:-----------:|:----------:| | ||||||
| | ResNet-18      | 1.814      |  11.69     |   30.24     |   10.92     | Official   | | | ResNet-18         | 1.814      |  11.69     |   30.24     |   10.92     | Official   | | ||||||
| | ResNet-18      | 1.814      |  11.69     |   29.97     |   10.43     | Step-120   | | | ResNet-18         | 1.814      |  11.69     |   29.97     |   10.43     | Step-120   | | ||||||
| | ResNet-18      | 1.814      |  11.69     |   29.35     |   10.13     | Cosine-120 | | | ResNet-18         | 1.814      |  11.69     |   29.35     |   10.13     | Cosine-120 | | ||||||
| | ResNet-18      | 1.814      |  11.69     |   29.45     |   10.25     | Cosine-120 B1024 | | | ResNet-18         | 1.814      |  11.69     |   29.45     |   10.25     | Cosine-120 B1024 | | ||||||
| | ResNet-18      | 1.814      |  11.69     |   29.44     |   10.12     |Cosine-S-120| | | ResNet-18         | 1.814      |  11.69     |   29.44     |   10.12     | Cosine-S-120 | | ||||||
| | ResNet-18 (DS) | 2.053      |  11.71     |   28.53     |   9.69      |Cosine-S-120| | | ResNet-18 (DS)    | 2.053      |  11.71     |   28.53     |   9.69      | Cosine-S-120 | | ||||||
| | ResNet-34      | 3.663      |  21.80     |   25.65     |   8.06      |Cosine-120  | | | ResNet-34         | 3.663      |  21.80     |   25.65     |   8.06      | Cosine-120   | | ||||||
| | ResNet-34 (DS) | 3.903      |  21.82     |   25.05     |   7.67      |Cosine-S-120| | | ResNet-34 (DS)    | 3.903      |  21.82     |   25.05     |   7.67      | Cosine-S-120 | | ||||||
| | ResNet-50      | 4.089      |  25.56     |   23.85     |   7.13      | Official   | | | ResNet-50         | 4.089      |  25.56     |   23.85     |   7.13      | Official     | | ||||||
| | ResNet-50      | 4.089      |  25.56     |   22.54     |   6.45      |Cosine-120  | | | ResNet-50         | 4.089      |  25.56     |   22.54     |   6.45      | Cosine-120   | | ||||||
| | ResNet-50      | 4.089      |  25.56     |   22.71     |   6.38      |Cosine-120 B1024 | | | ResNet-50         | 4.089      |  25.56     |   22.71     |   6.38      | Cosine-120 B1024 | | ||||||
| | ResNet-50      | 4.089      |  25.56     |   22.34     |   6.22      |Cosine-S-120| | | ResNet-50         | 4.089      |  25.56     |   22.34     |   6.22      | Cosine-S-120 | | ||||||
| | ResNet-50 (DS) | 4.328      |  25.58     |   22.67     |   6.39      | Step-120   | | | ResNet-50 (DS)    | 4.328      |  25.58     |   22.67     |   6.39      | Step-120     | | ||||||
| | ResNet-50 (DS) | 4.328      |  25.58     |   21.94     |   6.23      | Cosine-120 | | | ResNet-50 (DS)    | 4.328      |  25.58     |   21.94     |   6.23      | Cosine-120   | | ||||||
| | ResNet-50 (DS) | 4.328      |  25.58     |   21.71     |   5.99      |Cosine-S-120| | | ResNet-50 (DS)    | 4.328      |  25.58     |   21.71     |   5.99      | Cosine-S-120 | | ||||||
| | ResNet-101     | 7.801      |  44.55     |   20.93     |   5.57      |Cosine-120  | | | ResNet-101        | 7.801      |  44.55     |   20.93     |   5.57      | Cosine-120   | | ||||||
| | ResNet-101     | 7.801      |  44.55     |   20.92     |   5.58      |Cosine-120 B1024 | | | ResNet-101        | 7.801      |  44.55     |   20.92     |   5.58      | Cosine-120 B1024 | | ||||||
| | ResNet-101 (DS)| 8.041      |  44.57     |   20.36     |   5.22      |Cosine-S-120| | | ResNet-101 (DS)   | 8.041      |  44.57     |   20.36     |   5.22      | Cosine-S-120 | | ||||||
| | ResNet-152     | 11.514     |  60.19     |   20.10     |   5.17      |Cosine-120 B1024 | | | ResNet-152        | 11.514     |  60.19     |   20.10     |   5.17      | Cosine-120 B1024 | | ||||||
| | ResNet-152 (DS)| 11.753     |  60.21     |   19.83     |   5.02      |Cosine-S-120| | | ResNet-152 (DS)   | 11.753     |  60.21     |   19.83     |   5.02      | Cosine-S-120 | | ||||||
| | ResNet-200     | 15.007     |  64.67     |   20.06     |   4.98      |Cosine-S-120| | | ResNet-200        | 15.007     |  64.67     |   20.06     |   4.98      | Cosine-S-120 | | ||||||
| | Next50-32x4d (DS)| 4.2      |  25.0      |   22.2      |     -       | Official   | | | Next50-32x4d (DS) | 4.2        |  25.0      |   22.2      |     -       | Official     | | ||||||
| | Next50-32x4d (DS)| 4.470    |  25.05     |   21.16     |   5.65      |Cosine-S-120| | | Next50-32x4d (DS) | 4.470      |  25.05     |   21.16     |   5.65      | Cosine-S-120 | | ||||||
| | MobileNet-V2   | 0.300      |  3.40      |   28.0      |     -       | Official   | | | MobileNet-V2      | 0.300      |  3.40      |   28.0      |     -       | Official     | | ||||||
| | MobileNet-V2   | 0.300      |  3.50      |   27.92     |   9.50      | MobileFast | | | MobileNet-V2      | 0.300      |  3.50      |   27.92     |   9.50      | MobileFast   | | ||||||
| | MobileNet-V2   | 0.300      |  3.50      |   27.56     |   9.26      | MobileFast-Smooth | | | MobileNet-V2      | 0.300      |  3.50      |   27.56     |   9.26      | MobileFast-Smooth | | ||||||
| | ShuffleNet-V2 1.0| 0.146    |  2.28      |   30.6      |   11.1      | Official   | | | ShuffleNet-V2 1.0 | 0.146      |  2.28      |   30.6      |   11.1      | Official     | | ||||||
| | ShuffleNet-V2 1.0| 0.145    |  2.28      |             |             |Cosine-S-120| | | ShuffleNet-V2 1.0 | 0.145      |  2.28      |             |             | Cosine-S-120 | | ||||||
| | ShuffleNet-V2 1.5| 0.299    |            |   27.4      |     -       | Official   | | | ShuffleNet-V2 1.5 | 0.299      |            |   27.4      |     -       | Official     | | ||||||
| | ShuffleNet-V2 1.5|          |            |             |             |Cosine-S-120| | | ShuffleNet-V2 1.5 |            |            |             |             | Cosine-S-120 | | ||||||
| | ShuffleNet-V2 2.0|          |            |             |             |Cosine-S-120| | | ShuffleNet-V2 2.0 |            |            |             |             | Cosine-S-120 | | ||||||
|  |  | ||||||
| `DS` indicates deep-stem for the first convolutional layer. | `DS` indicates deep-stem for the first convolutional layer. | ||||||
| ``` | ``` | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ | |||||||
|  |  | ||||||
| The following is a set of guidelines for contributing to NAS-Projects. | The following is a set of guidelines for contributing to NAS-Projects. | ||||||
|  |  | ||||||
| #### Table Of Contents | ## Table Of Contents | ||||||
|  |  | ||||||
| [How Can I Contribute?](#how-can-i-contribute) | [How Can I Contribute?](#how-can-i-contribute) | ||||||
|   * [Reporting Bugs](#reporting-bugs) |   * [Reporting Bugs](#reporting-bugs) | ||||||
|   | |||||||
| @@ -6,9 +6,9 @@ Each edge here is associated with an operation selected from a predefined operat | |||||||
| For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total. | For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total. | ||||||
|  |  | ||||||
| In this Markdown file, we provide: | In this Markdown file, we provide: | ||||||
| - [How to Use NAS-Bench-102](#how-to-use-nas-bench-102) | -	[How to Use NAS-Bench-102](#how-to-use-nas-bench-102) | ||||||
| - [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102) | -	[Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102) | ||||||
| - [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102) | -	[10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102) | ||||||
|  |  | ||||||
| Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`. | Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`. | ||||||
|  |  | ||||||
| @@ -140,6 +140,8 @@ This command will train 390 architectures (id from 0 to 389) using the following | |||||||
| | CIFAR-100       | train         | valid / test | | | CIFAR-100       | train         | valid / test | | ||||||
| | ImageNet-16-120 | train         | valid / test | | | ImageNet-16-120 | train         | valid / test | | ||||||
|  |  | ||||||
|  | Note that the above `train`, `valid`, and `test` indicate the proposed splits in our NAS-Bench-102, and they might be different with the original splits. | ||||||
|  |  | ||||||
| 3. calculate the latency, merge the results of all architectures, and simplify the results. | 3. calculate the latency, merge the results of all architectures, and simplify the results. | ||||||
| (see commands in `output/NAS-BENCH-102-4/meta-node-4.cal-script.txt` which is automatically generated by step-1). | (see commands in `output/NAS-BENCH-102-4/meta-node-4.cal-script.txt` which is automatically generated by step-1). | ||||||
| ``` | ``` | ||||||
| @@ -167,7 +169,7 @@ If researchers can provide better results with different hyper-parameters, we ar | |||||||
|  |  | ||||||
| **Note that** you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download) | **Note that** you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download) | ||||||
|  |  | ||||||
| - [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1` | - [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`, where `cifar10` can be replaced with `cifar100` or `ImageNet16-120`. | ||||||
| - [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1` | - [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1` | ||||||
| - [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh     cifar10 -1` | - [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh     cifar10 -1` | ||||||
| - [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh     cifar10 -1` | - [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh     cifar10 -1` | ||||||
|   | |||||||
| @@ -8,7 +8,6 @@ from tqdm import tqdm | |||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn |  | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| import matplotlib | import matplotlib | ||||||
| @@ -498,6 +497,8 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_ | |||||||
|  |  | ||||||
|   def get_accs(xdata): |   def get_accs(xdata): | ||||||
|     epochs, xresults = xdata['epoch'], [] |     epochs, xresults = xdata['epoch'], [] | ||||||
|  |     metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False) | ||||||
|  |     xresults.append( metrics['accuracy'] ) | ||||||
|     for iepoch in range(epochs): |     for iepoch in range(epochs): | ||||||
|       genotype = xdata['genotypes'][iepoch] |       genotype = xdata['genotypes'][iepoch] | ||||||
|       index = api.query_index_by_arch(genotype) |       index = api.query_index_by_arch(genotype) | ||||||
| @@ -547,7 +548,6 @@ if __name__ == '__main__': | |||||||
|   #visualize_relative_ranking(vis_save_dir) |   #visualize_relative_ranking(vis_save_dir) | ||||||
|  |  | ||||||
|   api = API(args.api_path) |   api = API(args.api_path) | ||||||
|   """ |  | ||||||
|   for x_maxs in [50, 250]: |   for x_maxs in [50, 250]: | ||||||
|     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
| @@ -555,11 +555,12 @@ if __name__ == '__main__': | |||||||
|     show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) |     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||||
|   just_show(api) |  | ||||||
|   """ |   """ | ||||||
|  |   just_show(api) | ||||||
|   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) |   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||||
|   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) |   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||||
|   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) |   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||||
|   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) |   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) |   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) |   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||||
|  |   """ | ||||||
|   | |||||||
| @@ -10,7 +10,6 @@ from copy import deepcopy | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from torch.distributions import Categorical |  | ||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from config_utils import load_config, dict2config, configure2str | from config_utils import load_config, dict2config, configure2str | ||||||
|   | |||||||
| @@ -121,9 +121,19 @@ def main(xargs): | |||||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) |     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) |     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||||
|   elif xargs.dataset == 'cifar100': |   elif xargs.dataset == 'cifar100': | ||||||
|     raise ValueError('not support yet : {:}'.format(xargs.dataset)) |     cifar100_test_split = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None) | ||||||
|   elif xargs.dataset.startswith('ImageNet16'): |     search_train_data = train_data | ||||||
|     raise ValueError('not support yet : {:}'.format(xargs.dataset)) |     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform | ||||||
|  |     search_data   = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid) | ||||||
|  |     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||||
|  |     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=xargs.workers, pin_memory=True) | ||||||
|  |   elif xargs.dataset == 'ImageNet16-120': | ||||||
|  |     imagenet_test_split = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None) | ||||||
|  |     search_train_data = train_data | ||||||
|  |     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform | ||||||
|  |     search_data   = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid) | ||||||
|  |     search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) | ||||||
|  |     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=xargs.workers, pin_memory=True) | ||||||
|   else: |   else: | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) | ||||||
|   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) |   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) | ||||||
| @@ -168,7 +178,7 @@ def main(xargs): | |||||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) |     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||||
|   else: |   else: | ||||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) |     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} |     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()} | ||||||
|  |  | ||||||
|   # start training |   # start training | ||||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup |   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||||
| @@ -230,7 +240,7 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--data_path',          type=str,   help='Path to dataset') |   parser.add_argument('--data_path',          type=str,   help='Path to dataset') | ||||||
|   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') |   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||||
|   # channels and number-of-cells |   # channels and number-of-cells | ||||||
|   parser.add_argument('--config_path',        type=str,   help='The config paths.') |   parser.add_argument('--config_path',        type=str,   help='The config path.') | ||||||
|   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') |   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') | ||||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') |   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') |   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||||
|   | |||||||
| @@ -181,8 +181,8 @@ def main(xargs): | |||||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) |     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||||
|   else: |   else: | ||||||
|     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) |     raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) | ||||||
|   config_path = 'configs/nas-benchmark/algos/DARTS.config' |   #config_path = 'configs/nas-benchmark/algos/DARTS.config' | ||||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||||
|   # To split data |   # To split data | ||||||
|   train_data_v2 = deepcopy(train_data) |   train_data_v2 = deepcopy(train_data) | ||||||
|   train_data_v2.transform = valid_data.transform |   train_data_v2.transform = valid_data.transform | ||||||
| @@ -233,7 +233,7 @@ def main(xargs): | |||||||
|     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) |     logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) | ||||||
|   else: |   else: | ||||||
|     logger.log("=> do not find the last-info file : {:}".format(last_info)) |     logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} |     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()} | ||||||
|  |  | ||||||
|   # start training |   # start training | ||||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup |   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||||
| @@ -297,6 +297,7 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--data_path',          type=str,   help='Path to dataset') |   parser.add_argument('--data_path',          type=str,   help='Path to dataset') | ||||||
|   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') |   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||||
|   # channels and number-of-cells |   # channels and number-of-cells | ||||||
|  |   parser.add_argument('--config_path',        type=str,   help='The config path.') | ||||||
|   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') |   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') | ||||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') |   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') |   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||||
|   | |||||||
| @@ -3,7 +3,7 @@ | |||||||
| ########################################################################### | ########################################################################### | ||||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||||
| ########################################################################### | ########################################################################### | ||||||
| import os, sys, time, glob, random, argparse | import os, sys, time, random, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| import torch | import torch | ||||||
| @@ -11,7 +11,7 @@ import torch.nn as nn | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
| from config_utils import load_config, dict2config, configure2str | from config_utils import load_config, dict2config | ||||||
| from datasets     import get_datasets, SearchDataset | from datasets     import get_datasets, SearchDataset | ||||||
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler | ||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
|   | |||||||
| @@ -1,12 +1,14 @@ | |||||||
| # python ./exps/vis/test.py | # python ./exps/vis/test.py | ||||||
| import os, sys, random | import os, sys, random | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  | from copy import deepcopy | ||||||
| import torch | import torch | ||||||
| import numpy as np | import numpy as np | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | from nas_102_api import NASBench102API as API | ||||||
|  |  | ||||||
| def test_nas_api(): | def test_nas_api(): | ||||||
|   from nas_102_api import ArchResults |   from nas_102_api import ArchResults | ||||||
| @@ -72,7 +74,40 @@ def test_auto_grad(): | |||||||
|     s_grads = torch.autograd.grad(grads, net.parameters()) |     s_grads = torch.autograd.grad(grads, net.parameters()) | ||||||
|     second_order_grads.append( s_grads ) |     second_order_grads.append( s_grads ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_one_shot_model(ckpath, use_train): | ||||||
|  |   from models import get_cell_based_tiny_net, get_search_spaces | ||||||
|  |   from datasets import get_datasets, SearchDataset | ||||||
|  |   from config_utils import load_config, dict2config | ||||||
|  |   from utils.nas_utils import evaluate_one_shot | ||||||
|  |   use_train = int(use_train) > 0 | ||||||
|  |   #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' | ||||||
|  |   #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' | ||||||
|  |   print ('ckpath : {:}'.format(ckpath)) | ||||||
|  |   ckp = torch.load(ckpath) | ||||||
|  |   xargs = ckp['args'] | ||||||
|  |   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||||
|  |   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None) | ||||||
|  |   if xargs.dataset == 'cifar10': | ||||||
|  |     cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None) | ||||||
|  |     xvalid_data = deepcopy(train_data) | ||||||
|  |     xvalid_data.transform = valid_data.transform | ||||||
|  |     valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True) | ||||||
|  |   else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet)) | ||||||
|  |   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||||
|  |   model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells, | ||||||
|  |                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||||
|  |                               'space'    : search_space, | ||||||
|  |                               'affine'   : False, 'track_running_stats': True}, None) | ||||||
|  |   search_model = get_cell_based_tiny_net(model_config) | ||||||
|  |   search_model.load_state_dict( ckp['search_model'] ) | ||||||
|  |   search_model = search_model.cuda() | ||||||
|  |   api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth') | ||||||
|  |   archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   #test_nas_api() |   #test_nas_api() | ||||||
|   #for i in range(200): plot('{:04d}'.format(i)) |   #for i in range(200): plot('{:04d}'.format(i)) | ||||||
|   test_auto_grad() |   #test_auto_grad() | ||||||
|  |   test_one_shot_model(sys.argv[1], sys.argv[2]) | ||||||
|   | |||||||
| @@ -9,16 +9,25 @@ class SearchDataset(data.Dataset): | |||||||
|  |  | ||||||
|   def __init__(self, name, data, train_split, valid_split, check=True): |   def __init__(self, name, data, train_split, valid_split, check=True): | ||||||
|     self.datasetname = name |     self.datasetname = name | ||||||
|     self.data        = data |     if isinstance(data, (list, tuple)): # new type of SearchDataset | ||||||
|     self.train_split = train_split.copy() |       assert len(data) == 2, 'invalid length: {:}'.format( len(data) ) | ||||||
|     self.valid_split = valid_split.copy() |       self.train_data  = data[0] | ||||||
|     if check: |       self.valid_data  = data[1] | ||||||
|       intersection = set(train_split).intersection(set(valid_split)) |       self.train_split = train_split.copy() | ||||||
|       assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection' |       self.valid_split = valid_split.copy() | ||||||
|  |       self.mode_str    = 'V2' # new mode  | ||||||
|  |     else: | ||||||
|  |       self.mode_str    = 'V1' # old mode  | ||||||
|  |       self.data        = data | ||||||
|  |       self.train_split = train_split.copy() | ||||||
|  |       self.valid_split = valid_split.copy() | ||||||
|  |       if check: | ||||||
|  |         intersection = set(train_split).intersection(set(valid_split)) | ||||||
|  |         assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection' | ||||||
|     self.length      = len(self.train_split) |     self.length      = len(self.train_split) | ||||||
|  |  | ||||||
|   def __repr__(self): |   def __repr__(self): | ||||||
|     return ('{name}(name={datasetname}, train={tr_L}, valid={val_L})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split))) |     return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str)) | ||||||
|  |  | ||||||
|   def __len__(self): |   def __len__(self): | ||||||
|     return self.length |     return self.length | ||||||
| @@ -27,6 +36,11 @@ class SearchDataset(data.Dataset): | |||||||
|     assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index) |     assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index) | ||||||
|     train_index = self.train_split[index] |     train_index = self.train_split[index] | ||||||
|     valid_index = random.choice( self.valid_split ) |     valid_index = random.choice( self.valid_split ) | ||||||
|     train_image, train_label = self.data[train_index] |     if self.mode_str == 'V1': | ||||||
|     valid_image, valid_label = self.data[valid_index] |       train_image, train_label = self.data[train_index] | ||||||
|  |       valid_image, valid_label = self.data[valid_index] | ||||||
|  |     elif self.mode_str == 'V2': | ||||||
|  |       train_image, train_label = self.train_data[train_index] | ||||||
|  |       valid_image, valid_label = self.valid_data[valid_index] | ||||||
|  |     else: raise ValueError('invalid mode : {:}'.format(self.mode_str)) | ||||||
|     return train_image, train_label, valid_image, valid_label |     return train_image, train_label, valid_image, valid_label | ||||||
|   | |||||||
| @@ -34,7 +34,7 @@ class PointMeta(): | |||||||
|  |  | ||||||
|   def get_box(self, return_diagonal=False): |   def get_box(self, return_diagonal=False): | ||||||
|     if self.box is None: return None |     if self.box is None: return None | ||||||
|     if return_diagonal == False: |     if not return_diagonal: | ||||||
|       return self.box.clone() |       return self.box.clone() | ||||||
|     else: |     else: | ||||||
|       W = (self.box[2]-self.box[0]).item() |       W = (self.box[2]-self.box[0]).item() | ||||||
|   | |||||||
| @@ -1,4 +1,3 @@ | |||||||
| import torch |  | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from ..cell_operations import OPS | from ..cell_operations import OPS | ||||||
|   | |||||||
| @@ -68,7 +68,7 @@ class Structure: | |||||||
|     for i, node_info in enumerate(self.nodes): |     for i, node_info in enumerate(self.nodes): | ||||||
|       sums = [] |       sums = [] | ||||||
|       for op, xin in node_info: |       for op, xin in node_info: | ||||||
|         if op == 'none' or nodes[xin] == False: x = False |         if op == 'none' or nodes[xin] is False: x = False | ||||||
|         else: x = True |         else: x = True | ||||||
|         sums.append( x ) |         sums.append( x ) | ||||||
|       nodes[i+1] = sum(sums) > 0 |       nodes[i+1] = sum(sums) > 0 | ||||||
|   | |||||||
| @@ -85,7 +85,7 @@ class SearchCell(nn.Module): | |||||||
|           candidates = self.edges[node_str] |           candidates = self.edges[node_str] | ||||||
|           select_op  = random.choice(candidates) |           select_op  = random.choice(candidates) | ||||||
|           sops.append( select_op ) |           sops.append( select_op ) | ||||||
|           if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True |           if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True | ||||||
|         if has_non_zero: break |         if has_non_zero: break | ||||||
|       inter_nodes = [] |       inter_nodes = [] | ||||||
|       for j, select_op in enumerate(sops): |       for j, select_op in enumerate(sops): | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| import math, torch | import math | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
| from ..initialization import initialize_resnet | from ..initialization import initialize_resnet | ||||||
|   | |||||||
| @@ -70,6 +70,9 @@ class NASBench102API(object): | |||||||
|   def __repr__(self): |   def __repr__(self): | ||||||
|     return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs))) |     return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs))) | ||||||
|  |  | ||||||
|  |   def random(self): | ||||||
|  |     return random.randint(0, len(self.meta_archs)-1) | ||||||
|  |  | ||||||
|   def query_index_by_arch(self, arch): |   def query_index_by_arch(self, arch): | ||||||
|     if isinstance(arch, str): |     if isinstance(arch, str): | ||||||
|       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] |       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| ################################################## | ################################################## | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
| ################################################## | ################################################## | ||||||
| import os, sys, time, torch, random, PIL, copy, numpy as np | import os, sys, torch, random, PIL, copy, numpy as np | ||||||
| from os import path as osp | from os import path as osp | ||||||
| from shutil  import copyfile | from shutil  import copyfile | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,3 +1,5 @@ | |||||||
| from .evaluation_utils import obtain_accuracy | from .evaluation_utils import obtain_accuracy | ||||||
| from .gpu_manager      import GPUManager | from .gpu_manager      import GPUManager | ||||||
| from .flop_benchmark   import get_model_infos | from .flop_benchmark   import get_model_infos | ||||||
|  | from .affine_utils     import normalize_points, denormalize_points | ||||||
|  | from .affine_utils     import identity2affine, solve2theta, affine2image | ||||||
|   | |||||||
| @@ -1,10 +1,3 @@ | |||||||
| # Copyright (c) Facebook, Inc. and its affiliates. |  | ||||||
| # All rights reserved. |  | ||||||
| # |  | ||||||
| # This source code is licensed under the license found in the |  | ||||||
| # LICENSE file in the root directory of this source tree. |  | ||||||
| # |  | ||||||
| # |  | ||||||
| # functions for affine transformation | # functions for affine transformation | ||||||
| import math, torch | import math, torch | ||||||
| import numpy as np | import numpy as np | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| import copy, torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
|   | |||||||
| @@ -27,7 +27,7 @@ class GPUManager(): | |||||||
|         find = False |         find = False | ||||||
|         for gpu in all_gpus: |         for gpu in all_gpus: | ||||||
|           if gpu['index'] == CUDA_VISIBLE_DEVICE: |           if gpu['index'] == CUDA_VISIBLE_DEVICE: | ||||||
|             assert find==False, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE) |             assert not find, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE) | ||||||
|             find = True |             find = True | ||||||
|             selected_gpus.append( gpu.copy() ) |             selected_gpus.append( gpu.copy() ) | ||||||
|             selected_gpus[-1]['index'] = '{}'.format(idx) |             selected_gpus[-1]['index'] = '{}'.format(idx) | ||||||
|   | |||||||
							
								
								
									
										52
									
								
								lib/utils/nas_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								lib/utils/nas_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | |||||||
|  | # This file is for experimental usage | ||||||
|  | import os, sys, torch, random | ||||||
|  | import numpy as np | ||||||
|  | from copy import deepcopy | ||||||
|  | from tqdm import tqdm | ||||||
|  | import torch.nn as nn | ||||||
|  |  | ||||||
|  | from utils  import obtain_accuracy | ||||||
|  | from models import CellStructure | ||||||
|  | from log_utils import time_string | ||||||
|  |  | ||||||
|  | def evaluate_one_shot(model, xloader, api, cal_mode, seed=111): | ||||||
|  |   weights = deepcopy(model.state_dict()) | ||||||
|  |   model.train(cal_mode) | ||||||
|  |   with torch.no_grad(): | ||||||
|  |     logits = nn.functional.log_softmax(model.arch_parameters, dim=-1) | ||||||
|  |     archs = CellStructure.gen_all(model.op_names, model.max_nodes, False) | ||||||
|  |     probs, accuracies, gt_accs = [], [], [] | ||||||
|  |     loader_iter = iter(xloader) | ||||||
|  |     random.seed(seed) | ||||||
|  |     random.shuffle(archs) | ||||||
|  |     for idx, arch in enumerate(archs): | ||||||
|  |       arch_index = api.query_index_by_arch( arch ) | ||||||
|  |       metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False) | ||||||
|  |       gt_accs.append( metrics['valid-accuracy'] ) | ||||||
|  |       select_logits = [] | ||||||
|  |       for i, node_info in enumerate(arch.nodes): | ||||||
|  |         for op, xin in node_info: | ||||||
|  |           node_str = '{:}<-{:}'.format(i+1, xin) | ||||||
|  |           op_index = model.op_names.index(op) | ||||||
|  |           select_logits.append( logits[model.edge2index[node_str], op_index] ) | ||||||
|  |       cur_prob = sum(select_logits).item() | ||||||
|  |       probs.append( cur_prob ) | ||||||
|  |     cor_prob = np.corrcoef(probs, gt_accs)[0,1] | ||||||
|  |     print ('correlation for probabilities : {:}'.format(cor_prob)) | ||||||
|  |        | ||||||
|  |     for idx, arch in enumerate(archs): | ||||||
|  |       model.set_cal_mode('dynamic', arch) | ||||||
|  |       try: | ||||||
|  |         inputs, targets = next(loader_iter) | ||||||
|  |       except: | ||||||
|  |         loader_iter = iter(xloader) | ||||||
|  |         inputs, targets = next(loader_iter) | ||||||
|  |       _, logits = model(inputs.cuda()) | ||||||
|  |       _, preds  = torch.max(logits, dim=-1) | ||||||
|  |       correct = (preds == targets.cuda() ).float() | ||||||
|  |       accuracies.append( correct.mean().item() ) | ||||||
|  |       if idx != 0 and (idx % 300 == 0 or idx + 1 == len(archs) or idx == 10): | ||||||
|  |         cor_accs = np.corrcoef(accuracies, gt_accs[:idx+1])[0,1] | ||||||
|  |         print ('{:} {:03d}/{:03d} mode={:5s}, correlation : accs={:.4f}, arch={:}'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs, arch)) | ||||||
|  |   model.load_state_dict(weights) | ||||||
|  |   return archs, probs, accuracies | ||||||
| @@ -1 +0,0 @@ | |||||||
| from .affine_utils import normalize_points, denormalize_points |  | ||||||
| @@ -33,6 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \ | |||||||
| 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ | ||||||
| 	--dataset ${dataset} --data_path ${data_path} \ | 	--dataset ${dataset} --data_path ${data_path} \ | ||||||
| 	--search_space_name ${space} \ | 	--search_space_name ${space} \ | ||||||
|  | 	--config_path configs/nas-benchmark/algos/DARTS.config \ | ||||||
| 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ | ||||||
| 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | 	--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ | ||||||
| 	--workers 4 --print_freq 200 --rand_seed ${seed} | 	--workers 4 --print_freq 200 --rand_seed ${seed} | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| #!/bin/bash | #!/bin/bash | ||||||
| # bash ./scripts/prepare.sh | # bash ./scripts/prepare.sh | ||||||
| datasets='cifar10 cifar100 imagenet-1k' | #datasets='cifar10 cifar100 imagenet-1k' | ||||||
| #ratios='0.5 0.8 0.9' | #ratios='0.5 0.8 0.9' | ||||||
| ratios='0.5' | ratios='0.5' | ||||||
| save_dir=./.latent-data/splits | save_dir=./.latent-data/splits | ||||||
|   | |||||||
| @@ -33,7 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \ | |||||||
| 	--procedure    basic \ | 	--procedure    basic \ | ||||||
| 	--save_dir     ${xsave_dir} \ | 	--save_dir     ${xsave_dir} \ | ||||||
| 	--cutout_length -1 \ | 	--cutout_length -1 \ | ||||||
| 	--batch_size 256 --rand_seed ${rseed} --workers 6 \ | 	--batch_size ${batch} --rand_seed ${rseed} --workers 6 \ | ||||||
| 	--eval_frequency 1 --print_freq 100 --print_freq_eval 200 | 	--eval_frequency 1 --print_freq 100 --print_freq_eval 200 | ||||||
|  |  | ||||||
| # KD training | # KD training | ||||||
| @@ -47,5 +47,5 @@ OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \ | |||||||
| 	--save_dir     ${xsave_dir} \ | 	--save_dir     ${xsave_dir} \ | ||||||
| 	--KD_alpha 0.9 --KD_temperature 4 \ | 	--KD_alpha 0.9 --KD_temperature 4 \ | ||||||
| 	--cutout_length -1 \ | 	--cutout_length -1 \ | ||||||
| 	--batch_size 256 --rand_seed ${rseed} --workers 6 \ | 	--batch_size ${batch} --rand_seed ${rseed} --workers 6 \ | ||||||
| 	--eval_frequency 1 --print_freq 100 --print_freq_eval 200 | 	--eval_frequency 1 --print_freq 100 --print_freq_eval 200 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user