update code styles
This commit is contained in:
parent
5ac5060a33
commit
ad34af9913
66
BASELINE.md
66
BASELINE.md
@ -40,39 +40,39 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_
|
||||
|
||||
## Performance on ImageNet
|
||||
|
||||
| Model | FLOPs (GB) | Params (M) | Top-1 Error | Top-5 Error | Optimizer |
|
||||
|:--------------:|:----------:|:----------:|:-----------:|:-----------:|:----------:|
|
||||
| ResNet-18 | 1.814 | 11.69 | 30.24 | 10.92 | Official |
|
||||
| ResNet-18 | 1.814 | 11.69 | 29.97 | 10.43 | Step-120 |
|
||||
| ResNet-18 | 1.814 | 11.69 | 29.35 | 10.13 | Cosine-120 |
|
||||
| ResNet-18 | 1.814 | 11.69 | 29.45 | 10.25 | Cosine-120 B1024 |
|
||||
| ResNet-18 | 1.814 | 11.69 | 29.44 | 10.12 |Cosine-S-120|
|
||||
| ResNet-18 (DS) | 2.053 | 11.71 | 28.53 | 9.69 |Cosine-S-120|
|
||||
| ResNet-34 | 3.663 | 21.80 | 25.65 | 8.06 |Cosine-120 |
|
||||
| ResNet-34 (DS) | 3.903 | 21.82 | 25.05 | 7.67 |Cosine-S-120|
|
||||
| ResNet-50 | 4.089 | 25.56 | 23.85 | 7.13 | Official |
|
||||
| ResNet-50 | 4.089 | 25.56 | 22.54 | 6.45 |Cosine-120 |
|
||||
| ResNet-50 | 4.089 | 25.56 | 22.71 | 6.38 |Cosine-120 B1024 |
|
||||
| ResNet-50 | 4.089 | 25.56 | 22.34 | 6.22 |Cosine-S-120|
|
||||
| ResNet-50 (DS) | 4.328 | 25.58 | 22.67 | 6.39 | Step-120 |
|
||||
| ResNet-50 (DS) | 4.328 | 25.58 | 21.94 | 6.23 | Cosine-120 |
|
||||
| ResNet-50 (DS) | 4.328 | 25.58 | 21.71 | 5.99 |Cosine-S-120|
|
||||
| ResNet-101 | 7.801 | 44.55 | 20.93 | 5.57 |Cosine-120 |
|
||||
| ResNet-101 | 7.801 | 44.55 | 20.92 | 5.58 |Cosine-120 B1024 |
|
||||
| ResNet-101 (DS)| 8.041 | 44.57 | 20.36 | 5.22 |Cosine-S-120|
|
||||
| ResNet-152 | 11.514 | 60.19 | 20.10 | 5.17 |Cosine-120 B1024 |
|
||||
| ResNet-152 (DS)| 11.753 | 60.21 | 19.83 | 5.02 |Cosine-S-120|
|
||||
| ResNet-200 | 15.007 | 64.67 | 20.06 | 4.98 |Cosine-S-120|
|
||||
| Next50-32x4d (DS)| 4.2 | 25.0 | 22.2 | - | Official |
|
||||
| Next50-32x4d (DS)| 4.470 | 25.05 | 21.16 | 5.65 |Cosine-S-120|
|
||||
| MobileNet-V2 | 0.300 | 3.40 | 28.0 | - | Official |
|
||||
| MobileNet-V2 | 0.300 | 3.50 | 27.92 | 9.50 | MobileFast |
|
||||
| MobileNet-V2 | 0.300 | 3.50 | 27.56 | 9.26 | MobileFast-Smooth |
|
||||
| ShuffleNet-V2 1.0| 0.146 | 2.28 | 30.6 | 11.1 | Official |
|
||||
| ShuffleNet-V2 1.0| 0.145 | 2.28 | | |Cosine-S-120|
|
||||
| ShuffleNet-V2 1.5| 0.299 | | 27.4 | - | Official |
|
||||
| ShuffleNet-V2 1.5| | | | |Cosine-S-120|
|
||||
| ShuffleNet-V2 2.0| | | | |Cosine-S-120|
|
||||
| Model | FLOPs (GB) | Params (M) | Top-1 Error | Top-5 Error | Optimizer |
|
||||
|:-----------------:|:----------:|:----------:|:-----------:|:-----------:|:----------:|
|
||||
| ResNet-18 | 1.814 | 11.69 | 30.24 | 10.92 | Official |
|
||||
| ResNet-18 | 1.814 | 11.69 | 29.97 | 10.43 | Step-120 |
|
||||
| ResNet-18 | 1.814 | 11.69 | 29.35 | 10.13 | Cosine-120 |
|
||||
| ResNet-18 | 1.814 | 11.69 | 29.45 | 10.25 | Cosine-120 B1024 |
|
||||
| ResNet-18 | 1.814 | 11.69 | 29.44 | 10.12 | Cosine-S-120 |
|
||||
| ResNet-18 (DS) | 2.053 | 11.71 | 28.53 | 9.69 | Cosine-S-120 |
|
||||
| ResNet-34 | 3.663 | 21.80 | 25.65 | 8.06 | Cosine-120 |
|
||||
| ResNet-34 (DS) | 3.903 | 21.82 | 25.05 | 7.67 | Cosine-S-120 |
|
||||
| ResNet-50 | 4.089 | 25.56 | 23.85 | 7.13 | Official |
|
||||
| ResNet-50 | 4.089 | 25.56 | 22.54 | 6.45 | Cosine-120 |
|
||||
| ResNet-50 | 4.089 | 25.56 | 22.71 | 6.38 | Cosine-120 B1024 |
|
||||
| ResNet-50 | 4.089 | 25.56 | 22.34 | 6.22 | Cosine-S-120 |
|
||||
| ResNet-50 (DS) | 4.328 | 25.58 | 22.67 | 6.39 | Step-120 |
|
||||
| ResNet-50 (DS) | 4.328 | 25.58 | 21.94 | 6.23 | Cosine-120 |
|
||||
| ResNet-50 (DS) | 4.328 | 25.58 | 21.71 | 5.99 | Cosine-S-120 |
|
||||
| ResNet-101 | 7.801 | 44.55 | 20.93 | 5.57 | Cosine-120 |
|
||||
| ResNet-101 | 7.801 | 44.55 | 20.92 | 5.58 | Cosine-120 B1024 |
|
||||
| ResNet-101 (DS) | 8.041 | 44.57 | 20.36 | 5.22 | Cosine-S-120 |
|
||||
| ResNet-152 | 11.514 | 60.19 | 20.10 | 5.17 | Cosine-120 B1024 |
|
||||
| ResNet-152 (DS) | 11.753 | 60.21 | 19.83 | 5.02 | Cosine-S-120 |
|
||||
| ResNet-200 | 15.007 | 64.67 | 20.06 | 4.98 | Cosine-S-120 |
|
||||
| Next50-32x4d (DS) | 4.2 | 25.0 | 22.2 | - | Official |
|
||||
| Next50-32x4d (DS) | 4.470 | 25.05 | 21.16 | 5.65 | Cosine-S-120 |
|
||||
| MobileNet-V2 | 0.300 | 3.40 | 28.0 | - | Official |
|
||||
| MobileNet-V2 | 0.300 | 3.50 | 27.92 | 9.50 | MobileFast |
|
||||
| MobileNet-V2 | 0.300 | 3.50 | 27.56 | 9.26 | MobileFast-Smooth |
|
||||
| ShuffleNet-V2 1.0 | 0.146 | 2.28 | 30.6 | 11.1 | Official |
|
||||
| ShuffleNet-V2 1.0 | 0.145 | 2.28 | | | Cosine-S-120 |
|
||||
| ShuffleNet-V2 1.5 | 0.299 | | 27.4 | - | Official |
|
||||
| ShuffleNet-V2 1.5 | | | | | Cosine-S-120 |
|
||||
| ShuffleNet-V2 2.0 | | | | | Cosine-S-120 |
|
||||
|
||||
`DS` indicates deep-stem for the first convolutional layer.
|
||||
```
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
The following is a set of guidelines for contributing to NAS-Projects.
|
||||
|
||||
#### Table Of Contents
|
||||
## Table Of Contents
|
||||
|
||||
[How Can I Contribute?](#how-can-i-contribute)
|
||||
* [Reporting Bugs](#reporting-bugs)
|
||||
|
@ -6,9 +6,9 @@ Each edge here is associated with an operation selected from a predefined operat
|
||||
For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total.
|
||||
|
||||
In this Markdown file, we provide:
|
||||
- [How to Use NAS-Bench-102](#how-to-use-nas-bench-102)
|
||||
- [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102)
|
||||
- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102)
|
||||
- [How to Use NAS-Bench-102](#how-to-use-nas-bench-102)
|
||||
- [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102)
|
||||
- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102)
|
||||
|
||||
Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
||||
|
||||
@ -140,6 +140,8 @@ This command will train 390 architectures (id from 0 to 389) using the following
|
||||
| CIFAR-100 | train | valid / test |
|
||||
| ImageNet-16-120 | train | valid / test |
|
||||
|
||||
Note that the above `train`, `valid`, and `test` indicate the proposed splits in our NAS-Bench-102, and they might differ from the original splits.
|
||||
|
||||
3. calculate the latency, merge the results of all architectures, and simplify the results.
|
||||
(see commands in `output/NAS-BENCH-102-4/meta-node-4.cal-script.txt` which is automatically generated by step-1).
|
||||
```
|
||||
@ -167,7 +169,7 @@ If researchers can provide better results with different hyper-parameters, we ar
|
||||
|
||||
**Note that** you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download).
|
||||
|
||||
- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`
|
||||
- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`, where `cifar10` can be replaced with `cifar100` or `ImageNet16-120`.
|
||||
- [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1`
|
||||
- [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1`
|
||||
- [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1`
|
||||
|
@ -8,7 +8,6 @@ from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import matplotlib
|
||||
@ -498,6 +497,8 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
|
||||
|
||||
def get_accs(xdata):
|
||||
epochs, xresults = xdata['epoch'], []
|
||||
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
|
||||
xresults.append( metrics['accuracy'] )
|
||||
for iepoch in range(epochs):
|
||||
genotype = xdata['genotypes'][iepoch]
|
||||
index = api.query_index_by_arch(genotype)
|
||||
@ -547,7 +548,6 @@ if __name__ == '__main__':
|
||||
#visualize_relative_ranking(vis_save_dir)
|
||||
|
||||
api = API(args.api_path)
|
||||
"""
|
||||
for x_maxs in [50, 250]:
|
||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
@ -555,11 +555,12 @@ if __name__ == '__main__':
|
||||
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
just_show(api)
|
||||
"""
|
||||
just_show(api)
|
||||
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
"""
|
||||
|
@ -10,7 +10,6 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Categorical
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
|
@ -121,9 +121,19 @@ def main(xargs):
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
elif xargs.dataset == 'cifar100':
|
||||
raise ValueError('not support yet : {:}'.format(xargs.dataset))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
raise ValueError('not support yet : {:}'.format(xargs.dataset))
|
||||
cifar100_test_split = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
|
||||
elif xargs.dataset == 'ImageNet16-120':
|
||||
imagenet_test_split = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
@ -168,7 +178,7 @@ def main(xargs):
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
@ -230,7 +240,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||
# channels and number-of-cells
|
||||
parser.add_argument('--config_path', type=str, help='The config paths.')
|
||||
parser.add_argument('--config_path', type=str, help='The config path.')
|
||||
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
|
@ -181,8 +181,8 @@ def main(xargs):
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
#config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
@ -233,7 +233,7 @@ def main(xargs):
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
@ -297,6 +297,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||
# channels and number-of-cells
|
||||
parser.add_argument('--config_path', type=str, help='The config path.')
|
||||
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
|
@ -3,7 +3,7 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import os, sys, time, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
@ -11,7 +11,7 @@ import torch.nn as nn
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from config_utils import load_config, dict2config
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
|
@ -1,12 +1,14 @@
|
||||
# python ./exps/vis/test.py
|
||||
import os, sys, random
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
def test_nas_api():
|
||||
from nas_102_api import ArchResults
|
||||
@ -72,7 +74,40 @@ def test_auto_grad():
|
||||
s_grads = torch.autograd.grad(grads, net.parameters())
|
||||
second_order_grads.append( s_grads )
|
||||
|
||||
|
||||
def test_one_shot_model(ckpath, use_train):
    """Load a searched one-shot checkpoint and evaluate it with `evaluate_one_shot`.

    Args:
        ckpath: path of a checkpoint readable by `torch.load`; it must contain
            the keys 'args' (the search script arguments) and 'search_model'
            (the state dict of the search network).
        use_train: anything accepted by `int()`; a value > 0 evaluates the
            network in training mode, otherwise in evaluation mode.

    Raises:
        ValueError: if the dataset stored in the checkpoint is not 'cifar10'.
    """
    # Project-local imports are kept inside the function, as in the original,
    # so that importing this script does not require them.
    from models import get_cell_based_tiny_net, get_search_spaces
    from datasets import get_datasets, SearchDataset
    from config_utils import load_config, dict2config
    from utils.nas_utils import evaluate_one_shot
    use_train = int(use_train) > 0
    #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
    #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
    print ('ckpath : {:}'.format(ckpath))
    ckp = torch.load(ckpath)
    xargs = ckp['args']
    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
    if xargs.dataset == 'cifar10':
        cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
        # Validation images are drawn from the training set (indices in
        # cifar_split.valid) but use the validation-time transform.
        xvalid_data = deepcopy(train_data)
        xvalid_data.transform = valid_data.transform
        valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True)
    else:
        # Fixed: the original read `xargs.dataseet`, which raised AttributeError
        # instead of the intended ValueError.
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                                'space'    : search_space,
                                'affine'   : False, 'track_running_stats': True}, None)
    search_model = get_cell_based_tiny_net(model_config)
    search_model.load_state_dict( ckp['search_model'] )
    search_model = search_model.cuda()
    # NOTE(review): the benchmark path is hard-coded to one user's home
    # directory; consider passing it in instead.
    api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth')
    archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#test_nas_api()
|
||||
#for i in range(200): plot('{:04d}'.format(i))
|
||||
test_auto_grad()
|
||||
#test_auto_grad()
|
||||
test_one_shot_model(sys.argv[1], sys.argv[2])
|
||||
|
@ -9,16 +9,25 @@ class SearchDataset(data.Dataset):
|
||||
|
||||
def __init__(self, name, data, train_split, valid_split, check=True):
|
||||
self.datasetname = name
|
||||
self.data = data
|
||||
self.train_split = train_split.copy()
|
||||
self.valid_split = valid_split.copy()
|
||||
if check:
|
||||
intersection = set(train_split).intersection(set(valid_split))
|
||||
assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection'
|
||||
if isinstance(data, (list, tuple)): # new type of SearchDataset
|
||||
assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
|
||||
self.train_data = data[0]
|
||||
self.valid_data = data[1]
|
||||
self.train_split = train_split.copy()
|
||||
self.valid_split = valid_split.copy()
|
||||
self.mode_str = 'V2' # new mode
|
||||
else:
|
||||
self.mode_str = 'V1' # old mode
|
||||
self.data = data
|
||||
self.train_split = train_split.copy()
|
||||
self.valid_split = valid_split.copy()
|
||||
if check:
|
||||
intersection = set(train_split).intersection(set(valid_split))
|
||||
assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection'
|
||||
self.length = len(self.train_split)
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(name={datasetname}, train={tr_L}, valid={val_L})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split)))
|
||||
return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
@ -27,6 +36,11 @@ class SearchDataset(data.Dataset):
|
||||
assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
|
||||
train_index = self.train_split[index]
|
||||
valid_index = random.choice( self.valid_split )
|
||||
train_image, train_label = self.data[train_index]
|
||||
valid_image, valid_label = self.data[valid_index]
|
||||
if self.mode_str == 'V1':
|
||||
train_image, train_label = self.data[train_index]
|
||||
valid_image, valid_label = self.data[valid_index]
|
||||
elif self.mode_str == 'V2':
|
||||
train_image, train_label = self.train_data[train_index]
|
||||
valid_image, valid_label = self.valid_data[valid_index]
|
||||
else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
|
||||
return train_image, train_label, valid_image, valid_label
|
||||
|
@ -34,7 +34,7 @@ class PointMeta():
|
||||
|
||||
def get_box(self, return_diagonal=False):
|
||||
if self.box is None: return None
|
||||
if return_diagonal == False:
|
||||
if not return_diagonal:
|
||||
return self.box.clone()
|
||||
else:
|
||||
W = (self.box[2]-self.box[0]).item()
|
||||
|
@ -1,4 +1,3 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
@ -68,7 +68,7 @@ class Structure:
|
||||
for i, node_info in enumerate(self.nodes):
|
||||
sums = []
|
||||
for op, xin in node_info:
|
||||
if op == 'none' or nodes[xin] == False: x = False
|
||||
if op == 'none' or nodes[xin] is False: x = False
|
||||
else: x = True
|
||||
sums.append( x )
|
||||
nodes[i+1] = sum(sums) > 0
|
||||
|
@ -85,7 +85,7 @@ class SearchCell(nn.Module):
|
||||
candidates = self.edges[node_str]
|
||||
select_op = random.choice(candidates)
|
||||
sops.append( select_op )
|
||||
if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True
|
||||
if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True
|
||||
if has_non_zero: break
|
||||
inter_nodes = []
|
||||
for j, select_op in enumerate(sops):
|
||||
|
@ -1,4 +1,4 @@
|
||||
import math, torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
@ -70,6 +70,9 @@ class NASBench102API(object):
|
||||
def __repr__(self):
|
||||
return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
|
||||
|
||||
def random(self):
|
||||
return random.randint(0, len(self.meta_archs)-1)
|
||||
|
||||
def query_index_by_arch(self, arch):
|
||||
if isinstance(arch, str):
|
||||
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
|
||||
|
@ -1,7 +1,7 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch, random, PIL, copy, numpy as np
|
||||
import os, sys, torch, random, PIL, copy, numpy as np
|
||||
from os import path as osp
|
||||
from shutil import copyfile
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from .evaluation_utils import obtain_accuracy
|
||||
from .gpu_manager import GPUManager
|
||||
from .flop_benchmark import get_model_infos
|
||||
from .affine_utils import normalize_points, denormalize_points
|
||||
from .affine_utils import identity2affine, solve2theta, affine2image
|
||||
|
@ -1,10 +1,3 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
#
|
||||
# functions for affine transformation
|
||||
import math, torch
|
||||
import numpy as np
|
@ -1,4 +1,4 @@
|
||||
import copy, torch
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
@ -27,7 +27,7 @@ class GPUManager():
|
||||
find = False
|
||||
for gpu in all_gpus:
|
||||
if gpu['index'] == CUDA_VISIBLE_DEVICE:
|
||||
assert find==False, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE)
|
||||
assert not find, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE)
|
||||
find = True
|
||||
selected_gpus.append( gpu.copy() )
|
||||
selected_gpus[-1]['index'] = '{}'.format(idx)
|
||||
|
52
lib/utils/nas_utils.py
Normal file
52
lib/utils/nas_utils.py
Normal file
@ -0,0 +1,52 @@
|
||||
# This file is for experimental usage
|
||||
import os, sys, torch, random
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from tqdm import tqdm
|
||||
import torch.nn as nn
|
||||
|
||||
from utils import obtain_accuracy
|
||||
from models import CellStructure
|
||||
from log_utils import time_string
|
||||
|
||||
def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
    """Correlate a one-shot model's view of each architecture with benchmark accuracy.

    For every architecture generated by `CellStructure.gen_all`, two scores
    are computed and compared against the 'valid-accuracy' that `api` reports
    for 'cifar10-valid':
      * the sum of the log-softmax architecture parameters of its chosen ops;
      * the accuracy of the shared-weight model, switched to that architecture
        via `model.set_cal_mode('dynamic', arch)`, on a single batch of `xloader`.
    Pearson correlations (`np.corrcoef`) are printed along the way.

    Args:
        model: search network exposing `arch_parameters`, `op_names`,
            `max_nodes`, `edge2index` and `set_cal_mode`; it is called with
            CUDA inputs, so presumably it already lives on the GPU.
        xloader: iterable of `(inputs, targets)` batches; it is restarted
            whenever it is exhausted.
        api: benchmark object providing `query_index_by_arch` and `get_more_info`.
        cal_mode: passed to `model.train(...)`; truthy selects training mode.
        seed: seed for the global `random` module used to shuffle the
            architectures (note: this reseeds the process-wide RNG).

    Returns:
        Tuple `(archs, probs, accuracies)`: the shuffled architectures, their
        summed log-probabilities, and their one-batch accuracies, index-aligned.
    """
    # Snapshot the weights so they can be restored after the evaluation below.
    weights = deepcopy(model.state_dict())
    model.train(cal_mode)
    with torch.no_grad():
        arch_log_probs = nn.functional.log_softmax(model.arch_parameters, dim=-1)
        archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
        probs, accuracies, gt_accs = [], [], []
        loader_iter = iter(xloader)
        random.seed(seed)
        random.shuffle(archs)
        for idx, arch in enumerate(archs):
            arch_index = api.query_index_by_arch( arch )
            metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
            gt_accs.append( metrics['valid-accuracy'] )
            # Collect the log-probability of the op chosen on every edge.
            select_logits = []
            for i, node_info in enumerate(arch.nodes):
                for op, xin in node_info:
                    node_str = '{:}<-{:}'.format(i+1, xin)
                    op_index = model.op_names.index(op)
                    select_logits.append( arch_log_probs[model.edge2index[node_str], op_index] )
            cur_prob = sum(select_logits).item()
            probs.append( cur_prob )
        cor_prob = np.corrcoef(probs, gt_accs)[0,1]
        print ('correlation for probabilities : {:}'.format(cor_prob))

        for idx, arch in enumerate(archs):
            model.set_cal_mode('dynamic', arch)
            try:
                inputs, targets = next(loader_iter)
            except StopIteration:
                # Loader exhausted: start another pass. Only StopIteration is
                # caught so real loading errors and Ctrl-C still propagate.
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)
            _, cls_logits = model(inputs.cuda())
            _, preds = torch.max(cls_logits, dim=-1)
            correct = (preds == targets.cuda() ).float()
            accuracies.append( correct.mean().item() )
            # Report the running correlation at idx == 10, every 300 archs, and at the end.
            if idx != 0 and (idx % 300 == 0 or idx + 1 == len(archs) or idx == 10):
                cor_accs = np.corrcoef(accuracies, gt_accs[:idx+1])[0,1]
                print ('{:} {:03d}/{:03d} mode={:5s}, correlation : accs={:.4f}, arch={:}'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs, arch))
    model.load_state_dict(weights)
    return archs, probs, accuracies
|
@ -1 +0,0 @@
|
||||
from .affine_utils import normalize_points, denormalize_points
|
@ -33,6 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
|
||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||
--dataset ${dataset} --data_path ${data_path} \
|
||||
--search_space_name ${space} \
|
||||
--config_path configs/nas-benchmark/algos/DARTS.config \
|
||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
# bash ./scripts/prepare.sh
|
||||
datasets='cifar10 cifar100 imagenet-1k'
|
||||
#datasets='cifar10 cifar100 imagenet-1k'
|
||||
#ratios='0.5 0.8 0.9'
|
||||
ratios='0.5'
|
||||
save_dir=./.latent-data/splits
|
||||
|
@ -33,7 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \
|
||||
--procedure basic \
|
||||
--save_dir ${xsave_dir} \
|
||||
--cutout_length -1 \
|
||||
--batch_size 256 --rand_seed ${rseed} --workers 6 \
|
||||
--batch_size ${batch} --rand_seed ${rseed} --workers 6 \
|
||||
--eval_frequency 1 --print_freq 100 --print_freq_eval 200
|
||||
|
||||
# KD training
|
||||
@ -47,5 +47,5 @@ OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \
|
||||
--save_dir ${xsave_dir} \
|
||||
--KD_alpha 0.9 --KD_temperature 4 \
|
||||
--cutout_length -1 \
|
||||
--batch_size 256 --rand_seed ${rseed} --workers 6 \
|
||||
--batch_size ${batch} --rand_seed ${rseed} --workers 6 \
|
||||
--eval_frequency 1 --print_freq 100 --print_freq_eval 200
|
||||
|
Loading…
Reference in New Issue
Block a user