updates
parent fac556c176
commit 7b354d4c74

138 .latent-data/BASELINE.md Normal file
@@ -0,0 +1,138 @@
# Basic Classification Models

## Performance on CIFAR

| Model | FLOPs | Params (M) | Error on CIFAR-10 | Error on CIFAR-100 | Batch-GPU |
|:------------------:|:-----------:|:----------:|:-----------------:|:------------------:|:---------:|
| ResNet-08 | 12.50 M | 0.08 | 12.14 | 40.20 | 256-2 |
| ResNet-20 | 40.81 M | 0.27 | 7.26 | 31.38 | 256-2 |
| ResNet-32 | 69.12 M | 0.47 | 6.19 | 29.56 | 256-2 |
| ResNet-56 | 125.75 M | 0.86 | 5.74 | 26.82 | 256-2 |
| ResNet-110 | 253.15 M | 1.73 | 5.14 | 25.18 | 256-2 |
| ResNet-110 | 253.15 M | 1.73 | 5.06 | 25.49 | 256-1 |
| ResNet-164 | 247.65 M | 1.70 | 4.36 | 21.48 | 256-2 |
| ResNet-1001 | 1491.00 M | 10.33 | 5.34 | 22.50 | 256-2 |
| DenseNet-BC100-12 | 287.93 M | 0.77 | 4.68 | 22.76 | 256-2 |
| DenseNet-BC100-12 | 287.93 M | 0.77 | 4.25 | 21.54 | 128-2 |
| DenseNet-BC100-12 | 287.93 M | 0.77 | 5.51 | 24.67 | 64-1 |
| WRN-28-10 | 5243.33 M | 36.48 | 3.61 | 19.65 | 256-2 |

```
bash ./scripts-cluster/local.sh 0,1 "bash ./scripts/base-train.sh cifar10 ResNet20 E300 L1 256 -1"
bash ./scripts-cluster/local.sh 0,1 "bash ./scripts/base-train.sh cifar10 ResNet56 E300 L1 256 -1"
bash ./scripts-cluster/local.sh 0,1 "bash ./scripts/base-train.sh cifar10 ResNet110 E300 L1 256 -1"
bash ./scripts-cluster/local.sh 0,1 "bash ./scripts/base-train.sh cifar10 ResNet164 E300 L1 256 -1"
bash ./scripts-cluster/local.sh 0,1 "bash ./scripts/base-train.sh cifar10 DenseBC100-12 E300 L1 256 -1"
bash ./scripts-cluster/local.sh 0,1 "bash ./scripts/base-train.sh cifar10 WRN28-10 E300 L1 256 -1"
CUDA_VISIBLE_DEVICES=0,1 python ./exps/basic-eval.py --data_path ${TORCH_HOME}/ILSVRC2012 --checkpoint
CUDA_VISIBLE_DEVICES=0,1 python ./exps/test-official-CNN.py --data_path ${TORCH_HOME}/ILSVRC2012
python ./scripts-cluster/submit.py yq01-v100-box-2-8 TEST-CIFAR10-1001 2 "bash ./scripts/base-train.sh cifar10 ResNet1001 E300 L1 256 1021"
```

Train some NAS models:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 SETN 96 -1
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 SETN 96 -1
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN1 256 -1
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k DARTS 256 -1
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_V1 256 -1
```

## Performance on ImageNet

| Model | FLOPs (G) | Params (M) | Top-1 Error | Top-5 Error | Optimizer |
|:--------------:|:----------:|:----------:|:-----------:|:-----------:|:----------:|
| ResNet-18 | 1.814 | 11.69 | 30.24 | 10.92 | Official |
| ResNet-18 | 1.814 | 11.69 | 29.97 | 10.43 | Step-120 |
| ResNet-18 | 1.814 | 11.69 | 29.35 | 10.13 | Cosine-120 |
| ResNet-18 | 1.814 | 11.69 | 29.45 | 10.25 | Cosine-120 B1024 |
| ResNet-18 | 1.814 | 11.69 | 29.44 | 10.12 | Cosine-S-120 |
| ResNet-18 (DS) | 2.053 | 11.71 | 28.53 | 9.69 | Cosine-S-120 |
| ResNet-34 | 3.663 | 21.80 | 25.65 | 8.06 | Cosine-120 |
| ResNet-34 (DS) | 3.903 | 21.82 | 25.05 | 7.67 | Cosine-S-120 |
| ResNet-50 | 4.089 | 25.56 | 23.85 | 7.13 | Official |
| ResNet-50 | 4.089 | 25.56 | 22.54 | 6.45 | Cosine-120 |
| ResNet-50 | 4.089 | 25.56 | 22.71 | 6.38 | Cosine-120 B1024 |
| ResNet-50 | 4.089 | 25.56 | 22.34 | 6.22 | Cosine-S-120 |
| ResNet-50 (DS) | 4.328 | 25.58 | 22.67 | 6.39 | Step-120 |
| ResNet-50 (DS) | 4.328 | 25.58 | 21.94 | 6.23 | Cosine-120 |
| ResNet-50 (DS) | 4.328 | 25.58 | 21.71 | 5.99 | Cosine-S-120 |
| ResNet-101 | 7.801 | 44.55 | 20.93 | 5.57 | Cosine-120 |
| ResNet-101 | 7.801 | 44.55 | 20.92 | 5.58 | Cosine-120 B1024 |
| ResNet-101 (DS)| 8.041 | 44.57 | 20.36 | 5.22 | Cosine-S-120 |
| ResNet-152 | 11.514 | 60.19 | 20.10 | 5.17 | Cosine-120 B1024 |
| ResNet-152 (DS)| 11.753 | 60.21 | 19.83 | 5.02 | Cosine-S-120 |
| ResNet-200 | 15.007 | 64.67 | 20.06 | 4.98 | Cosine-S-120 |
| Next50-32x4d (DS)| 4.2 | 25.0 | 22.2 | - | Official |
| Next50-32x4d (DS)| 4.470 | 25.05 | 21.16 | 5.65 | Cosine-S-120 |
| MobileNet-V2 | 0.300 | 3.40 | 28.0 | - | Official |
| MobileNet-V2 | 0.300 | 3.50 | 27.92 | 9.50 | MobileFast |
| MobileNet-V2 | 0.300 | 3.50 | 27.56 | 9.26 | MobileFast-Smooth |
| ShuffleNet-V2 1.0| 0.146 | 2.28 | 30.6 | 11.1 | Official |
| ShuffleNet-V2 1.0| 0.145 | 2.28 | | | Cosine-S-120 |
| ShuffleNet-V2 1.5| 0.299 | | 27.4 | - | Official |
| ShuffleNet-V2 1.5| | | | | Cosine-S-120 |
| ShuffleNet-V2 2.0| | | | | Cosine-S-120 |

`DS` indicates a deep stem for the first convolutional layer (a rough sketch of the difference is given below).
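The deep-stem variants are built by the model definitions in this repo, not by the commands alone; as a rough illustration only (the exact channel sizes here are assumptions, not taken from this repo), a deep stem replaces the single 7x7 stride-2 convolution with a stack of 3x3 convolutions:

```python
import torch
import torch.nn as nn

# Plain ResNet stem: one 7x7 stride-2 convolution.
plain_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

# Deep stem (the `DS` rows): three 3x3 convolutions with the same overall stride.
# The intermediate widths (32, 32, 64) are assumptions for illustration.
deep_stem = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU(inplace=True),
    nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU(inplace=True),
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

x = torch.randn(1, 3, 224, 224)
print(plain_stem(x).shape, deep_stem(x).shape)  # both: torch.Size([1, 64, 112, 112])
```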
```
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet18V1 Step-Soft 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet18V1 Cos-Soft 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet18V1 Cos-Soft 1024 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet18V1 Cos-Smooth 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet18V2 Cos-Smooth 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet34V2 Cos-Smooth 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet50V1 Cos-Soft 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet50V2 Step-Soft 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet50V2 Cos-Soft 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNet101V2 Cos-Smooth 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ResNext50-32x4dV2 Cos-Smooth 256 -1"
```

Training efficient models may require different hyper-parameters.
```
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh MobileNetV2-X MobileFast 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh MobileNetV2-X MobileFastS 256 -1"
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh MobileNetV2 Mobile 256 -1" (70.96 top-1, 90.05 top-5)
bash ./scripts-cluster/local.sh 0,1,2,3 "bash ./scripts/base-imagenet.sh ShuffleNetV2-X Shuffle 1024 -1"
```

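In the settings above, `Soft` presumably denotes plain softmax cross-entropy and `Smooth` a label-smoothed variant; that reading is an assumption, not something stated in this file. A minimal label-smoothing cross-entropy, as a sketch:

```python
import torch
import torch.nn.functional as F

def smooth_cross_entropy(logits, targets, num_classes, epsilon=0.1):
    # Spread epsilon of the probability mass uniformly over all classes.
    # epsilon=0.1 is a common default, not necessarily the value used by this repo.
    log_probs = F.log_softmax(logits, dim=1)
    one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1.0)
    smoothed = one_hot * (1.0 - epsilon) + epsilon / num_classes
    return (-smoothed * log_probs).sum(dim=1).mean()

logits  = torch.randn(4, 1000)
targets = torch.randint(0, 1000, (4,))
print(smooth_cross_entropy(logits, targets, 1000).item())
```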
# Train with Knowledge Distillation

ResNet-110 (teacher) -> ResNet-20 (student):
```
bash ./scripts-cluster/local.sh 0,1 "bash ./scripts/KD-train.sh cifar10 ResNet20 ResNet110 0.9 4 -1"
```

ResNet-110 (teacher) -> ResNet-110 (student):
```
bash ./scripts-cluster/local.sh 0,1 "bash ./scripts/KD-train.sh cifar10 ResNet110 ResNet110 0.9 4 -1"
```

We set alpha=0.9 and temperature=4 following `Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer, ICLR 2017`.
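`./scripts/KD-train.sh` is not shown in this commit; as a sketch of the standard distillation objective that the alpha and temperature above plug into (Hinton-style soft targets; the exact loss used by the script may differ):

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, alpha=0.9, temperature=4.0):
    # Soft-target term: KL divergence between softened teacher and student distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean',
    ) * (temperature ** 2)
    # Hard-target term: regular cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1.0 - alpha) * hard

student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
print(kd_loss(student_logits, teacher_logits, targets).item())
```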
# Linux

The following command redirects the output of the `top` command to `top.txt`.
```
top -b -n 1 > top.txt
```

## Download the ImageNet dataset

The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) dataset has 1000 categories and 1.2 million images. The images do not need to be preprocessed or packaged in any database, but the validation images need to be moved into appropriate subfolders.

1. Download the images from http://image-net.org/download-images

2. Extract the training data:
```bash
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd ..
```

3. Extract the validation data and move images to subfolders:
```bash
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
```
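After these two steps, `train/` and `val/` each contain one sub-folder per class, which is the layout that class-folder loaders such as `torchvision.datasets.ImageFolder` expect (this repo's own `get_datasets` helper may wrap this differently); a quick sanity check:

```python
import torchvision.datasets as dset
import torchvision.transforms as transforms

# Normalization values are the usual ImageNet statistics, not necessarily those
# used by this repo's configs.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
valid_data = dset.ImageFolder('val', transform)
print(len(valid_data), 'validation images across', len(valid_data.classes), 'classes')
```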
64 AA-NAS-Bench.md Normal file
@@ -0,0 +1,64 @@
# An Algorithm-Agnostic NAS Benchmark (AA-NAS-Bench)

We propose an Algorithm-Agnostic NAS Benchmark (AA-NAS-Bench) with a fixed search space, which provides a unified benchmark for almost all up-to-date NAS algorithms.
The design of our search space is inspired by the one used in the most popular cell-based search algorithms, where a cell is represented as a directed acyclic graph and each edge is associated with an operation selected from a predefined operation set. To make it applicable to all NAS algorithms, the search space defined in AA-NAS-Bench includes 4 nodes and 5 operation options, which yields 15,625 neural cell candidates in total (a quick check of this count is given below).

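The 15,625 figure follows directly from the search-space size: a cell with 4 nodes has 4 x 3 / 2 = 6 edges, and each edge independently picks one of the 5 operations (this is the same count computed in `exps/AA-NAS-Bench-main.py`):

```python
# 4 nodes in the cell DAG -> 4 * 3 / 2 = 6 edges, each picking one of 5 operations.
num_nodes, num_ops = 4, 5
num_edges = num_nodes * (num_nodes - 1) // 2
print(num_edges, num_ops ** num_edges)  # 6 15625
```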
In this Markdown file, we provide:
- Detailed instructions to reproduce AA-NAS-Bench.
- 10 NAS algorithms evaluated in our paper.

Note: please use `PyTorch >= 1.1.0` and `Python >= 3.6.0`.

## Instructions to Generate AA-NAS-Bench

1. Generate the meta file for AA-NAS-Bench using the following script, where `AA-NAS-BENCH` indicates the name and `4` indicates the maximum number of nodes in a cell.
```
bash scripts-search/AA-NAS-meta-gen.sh AA-NAS-BENCH 4
```

2. Train each architecture on a single GPU (see commands in `output/AA-NAS-BENCH-4/meta-node-4.opt-script.txt`, which is automatically generated by step 1).
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/AA-NAS-train-archs.sh 0 389 -1 '777 888 999'
```
This command will train 390 architectures (IDs from 0 to 389) using the following four kinds of splits with three random seeds (777, 888, 999).

| Dataset | Train | Eval |
|:---------------:|:-------------:|:------------:|
| CIFAR-10 | train | valid |
| CIFAR-10 | train + valid | test |
| CIFAR-100 | train | valid + test |
| ImageNet-16-120 | train | valid + test |

3. Calculate the latency, merge the results of all architectures, and simplify the results
(see commands in `output/AA-NAS-BENCH-4/meta-node-4.cal-script.txt`, which is automatically generated by step 1).
```
OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python exps/AA-NAS-statistics.py --mode cal --target_dir 000000-000389-C16-N5
```

4. Merge all results into a single file for AA-NAS-Bench-API (a sketch of how to inspect the merged file follows below).
```
OMP_NUM_THREADS=4 python exps/AA-NAS-statistics.py --mode merge
```

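Based on `exps/AA-NAS-statistics.py` in this commit, the merge step writes a single `torch.save` file (with the default `--base_save_dir`, something like `output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth`) whose top-level keys can be inspected directly; the exact path depends on your settings:

```python
import torch

# Path follows the defaults in exps/AA-NAS-statistics.py; adjust to your setup.
final_infos = torch.load('output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth',
                         map_location='cpu')
print(final_infos.keys())  # meta_archs, total_archs, arch2infos, evaluated_indexes
print(final_infos['total_archs'], 'architectures in total,',
      len(final_infos['evaluated_indexes']), 'evaluated')
```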
[Optional] Train a single architecture on a single GPU (the architecture string format used in the second command is sketched below).
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/AA-NAS-train-net.sh resnet 16 5
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/AA-NAS-train-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5
```

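The architecture string encodes the cell as groups separated by `+`, one group per node, where each `op~j` token applies operation `op` to the output of node `j`. This reading is inferred from the strings appearing in this commit (the repo's own parser is `CellStructure.str2structure`); a tiny illustrative parser:

```python
def parse_arch_str(arch_str):
    # e.g. '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|...'
    nodes = []
    for group in arch_str.split('+'):
        edges = [token for token in group.split('|') if token]
        nodes.append([tuple(token.split('~')) for token in edges])
    return nodes

arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|'
for i, edges in enumerate(parse_arch_str(arch), start=1):
    print('node', i, '<-', edges)
# node 1 <- [('nor_conv_3x3', '0')]
# node 2 <- [('nor_conv_3x3', '0'), ('nor_conv_3x3', '1')]
# node 3 <- [('skip_connect', '0'), ('skip_connect', '1'), ('skip_connect', '2')]
```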
[Optional] Load the parameters of a trained network.
```

```

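The code block above is empty in this commit; one way to load stored results, mirroring the test usage of `aa_nas_api.ArchResults` elsewhere in this commit (the exact checkpoint layout and paths may differ on your machine):

```python
import sys
from pathlib import Path

# Make the repo's lib/ importable, mirroring the pattern used by scripts in exps/.
lib_dir = (Path(__file__).parent / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))

from aa_nas_api import ArchResults

# Load the full (with parameters) record saved by `exps/AA-NAS-statistics.py --mode cal`.
arch_result = ArchResults.create_from_state_dict(
    'output/AA-NAS-BENCH-4/simplifies/architectures/000002-FULL.pth')
arch_result.show(True)
result = arch_result.query('cifar100')
```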
## To reproduce 10 baseline NAS algorithms in AA-NAS-Bench

We have tried our best to implement each method. However, some algorithms might still obtain non-optimal results, since their hyper-parameters might not fit our AA-NAS-Bench.
If researchers obtain better results with different hyper-parameters, we are happy to update our results according to the new experiments. We also welcome more NAS algorithms to be tested on our dataset and will include them accordingly.

- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`
- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1`
- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1`
- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1`
- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 -1`
- `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 -1`
279 exps/AA-NAS-Bench-main.py Normal file
@@ -0,0 +1,279 @@
import os, sys, time, torch, random, argparse
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config
|
||||
from procedures import save_checkpoint, copy_checkpoint
|
||||
from procedures import get_machine_info
|
||||
from datasets import get_datasets
|
||||
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||
from models import CellStructure, CellArchitectures, get_search_spaces
|
||||
from AA_functions import evaluate_for_seed
|
||||
|
||||
|
||||
def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger):
|
||||
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
|
||||
all_infos = {'info': machine_info}
|
||||
all_dataset_keys = []
|
||||
# loop over all the datasets
|
||||
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||
# train valid data
|
||||
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
|
||||
# load the configuration
|
||||
if dataset == 'cifar10' or dataset == 'cifar100':
|
||||
config_path = 'configs/nas-benchmark/CIFAR.config'
|
||||
split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
|
||||
elif dataset.startswith('ImageNet16'):
|
||||
config_path = 'configs/nas-benchmark/ImageNet-16.config'
|
||||
split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
config = load_config(config_path, \
|
||||
{'class_num': class_num,
|
||||
'xshape' : xshape}, \
|
||||
logger)
|
||||
# check whether to use the split validation set
|
||||
if bool(split):
|
||||
assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid))
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True)
|
||||
else:
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
|
||||
|
||||
dataset_key = '{:}'.format(dataset)
|
||||
if bool(split): dataset_key = dataset_key + '-valid'
|
||||
logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
|
||||
results = evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, seed, logger)
|
||||
all_infos[dataset_key] = results
|
||||
all_dataset_keys.append( dataset_key )
|
||||
all_infos['all_dataset_keys'] = all_dataset_keys
|
||||
return all_infos
|
||||
|
||||
|
||||
def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
#torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads( workers )
|
||||
|
||||
assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)
|
||||
|
||||
sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
|
||||
logger = Logger(str(sub_dir), 0, False)
|
||||
|
||||
all_archs = meta_info['archs']
|
||||
assert srange[1] < meta_info['total'], 'invalid range : {:}-{:} vs. {:}'.format(srange[0], srange[1], meta_info['total'])
|
||||
assert arch_index == -1 or srange[0] <= arch_index <= srange[1], 'invalid range : {:} vs. {:} vs. {:}'.format(srange[0], arch_index, srange[1])
|
||||
if arch_index == -1:
|
||||
to_evaluate_indexes = list(range(srange[0], srange[1]+1))
|
||||
else:
|
||||
to_evaluate_indexes = [arch_index]
|
||||
logger.log('xargs : seeds = {:}'.format(seeds))
|
||||
logger.log('xargs : arch_index = {:}'.format(arch_index))
|
||||
logger.log('xargs : cover_mode = {:}'.format(cover_mode))
|
||||
logger.log('-'*100)
|
||||
|
||||
logger.log('Start evaluating range =: {:06d} vs. {:06d} vs. {:06d} / {:06d} with cover-mode={:}'.format(srange[0], arch_index, srange[1], meta_info['total'], cover_mode))
|
||||
for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
|
||||
logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
|
||||
logger.log('--->>> architecture config : {:}'.format(arch_config))
|
||||
|
||||
|
||||
start_time, epoch_time = time.time(), AverageMeter()
|
||||
for i, index in enumerate(to_evaluate_indexes):
|
||||
arch = all_archs[index]
|
||||
logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seeds, '-'*15))
|
||||
#logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15))
|
||||
logger.log('{:} {:} {:}'.format('-'*15, arch, '-'*15))
|
||||
|
||||
# test this arch on different datasets with different seeds
|
||||
has_continue = False
|
||||
for seed in seeds:
|
||||
to_save_name = sub_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
|
||||
if to_save_name.exists():
|
||||
if cover_mode:
|
||||
logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name))
|
||||
os.remove(str(to_save_name))
|
||||
else :
|
||||
logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
|
||||
has_continue = True
|
||||
continue
|
||||
results = evaluate_all_datasets(CellStructure.str2structure(arch), \
|
||||
datasets, xpaths, splits, seed, \
|
||||
arch_config, workers, logger)
|
||||
torch.save(results, to_save_name)
|
||||
logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
|
||||
# measure elapsed time
|
||||
if not has_continue: epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) )
|
||||
logger.log('This arch costs : {:}'.format( convert_secs2time(epoch_time.val, True) ))
|
||||
logger.log('{:}'.format('*'*100))
|
||||
logger.log('{:} {:74s} {:}'.format('*'*10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, meta_info['total'], need_time), '*'*10))
|
||||
logger.log('{:}'.format('*'*100))
|
||||
|
||||
logger.close()
|
||||
|
||||
|
||||
def train_single_model(save_dir, workers, datasets, xpaths, splits, seeds, model_str, arch_config):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
#torch.backends.cudnn.benchmark = True
|
||||
torch.set_num_threads( workers )
|
||||
|
||||
save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}'.format(model_str, arch_config['channel'], arch_config['num_cells'])
|
||||
logger = Logger(str(save_dir), 0, False)
|
||||
if model_str in CellArchitectures:
|
||||
arch = CellArchitectures[model_str]
|
||||
logger.log('The model string is found in pre-defined architecture dict : {:}'.format(model_str))
|
||||
else:
|
||||
try:
|
||||
arch = CellStructure.str2structure(model_str)
|
||||
except:
|
||||
raise ValueError('Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
|
||||
assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
|
||||
logger.log('Start train-evaluate {:}'.format(arch.tostr()))
|
||||
logger.log('arch_config : {:}'.format(arch_config))
|
||||
|
||||
start_time, seed_time = time.time(), AverageMeter()
|
||||
for _is, seed in enumerate(seeds):
|
||||
logger.log('\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds), seed))
|
||||
to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
|
||||
if to_save_name.exists():
|
||||
logger.log('Find the existing file {:}, directly load!'.format(to_save_name))
|
||||
checkpoint = torch.load(to_save_name)
|
||||
else:
|
||||
logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
|
||||
checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger)
|
||||
torch.save(checkpoint, to_save_name)
|
||||
# log information
|
||||
logger.log('{:}'.format(checkpoint['info']))
|
||||
all_dataset_keys = checkpoint['all_dataset_keys']
|
||||
for dataset_key in all_dataset_keys:
|
||||
logger.log('\n{:} dataset : {:} {:}'.format('-'*15, dataset_key, '-'*15))
|
||||
dataset_info = checkpoint[dataset_key]
|
||||
#logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
|
||||
logger.log('Flops = {:} MB, Params = {:} MB'.format(dataset_info['flop'], dataset_info['param']))
|
||||
logger.log('config : {:}'.format(dataset_info['config']))
|
||||
logger.log('Training State (finish) = {:}'.format(dataset_info['finish-train']))
|
||||
last_epoch = dataset_info['total_epoch'] - 1
|
||||
train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
|
||||
valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
|
||||
logger.log('Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%'.format(train_acc1es[last_epoch], train_acc5es[last_epoch], 100-train_acc1es[last_epoch], valid_acc1es[last_epoch], valid_acc5es[last_epoch], 100-valid_acc1es[last_epoch]))
|
||||
# measure elapsed time
|
||||
seed_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(seed_time.avg * (len(seeds)-_is-1), True) )
|
||||
logger.log('\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed, need_time))
|
||||
logger.close()
|
||||
|
||||
|
||||
def generate_meta_info(save_dir, max_node, divide=40):
|
||||
aa_nas_bench_ss = get_search_spaces('cell', 'aa-nas')
|
||||
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
|
||||
print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2)))
|
||||
|
||||
random.seed( 88 ) # please do not change this line for reproducibility
|
||||
random.shuffle( archs )
|
||||
# to test fixed-random shuffle
|
||||
#print ('arch [0] : {:}\n---->>>> {:}'.format( archs[0], archs[0].tostr() ))
|
||||
#print ('arch [9] : {:}\n---->>>> {:}'.format( archs[9], archs[9].tostr() ))
|
||||
assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0])
|
||||
assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9])
|
||||
assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123])
|
||||
total_arch = len(archs)
|
||||
|
||||
num = 50000
|
||||
indexes_5W = list(range(num))
|
||||
random.seed( 1021 )
|
||||
random.shuffle( indexes_5W )
|
||||
train_split = sorted( list(set(indexes_5W[:num//2])) )
|
||||
valid_split = sorted( list(set(indexes_5W[num//2:])) )
|
||||
assert len(train_split) + len(valid_split) == num
|
||||
assert train_split[0] == 0 and train_split[10] == 26 and train_split[111] == 203 and valid_split[0] == 1 and valid_split[10] == 18 and valid_split[111] == 242, '{:} {:} {:} - {:} {:} {:}'.format(train_split[0], train_split[10], train_split[111], valid_split[0], valid_split[10], valid_split[111])
|
||||
splits = {num: {'train': train_split, 'valid': valid_split} }
|
||||
|
||||
info = {'archs' : [x.tostr() for x in archs],
|
||||
'total' : total_arch,
|
||||
'max_node' : max_node,
|
||||
'splits': splits}
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_name = save_dir / 'meta-node-{:}.pth'.format(max_node)
|
||||
assert not save_name.exists(), '{:} already exist'.format(save_name)
|
||||
torch.save(info, save_name)
|
||||
print ('save the meta file into {:}'.format(save_name))
|
||||
|
||||
script_name = save_dir / 'meta-node-{:}.opt-script.txt'.format(max_node)
|
||||
with open(str(script_name), 'w') as cfile:
|
||||
gaps = total_arch // divide
|
||||
for start in range(0, total_arch, gaps):
|
||||
xend = min(start+gaps, total_arch)
|
||||
cfile.write('bash ./scripts-search/AA-NAS-train-archs.sh {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1))
|
||||
print ('save the training script into {:}'.format(script_name))
|
||||
|
||||
script_name = save_dir / 'meta-node-{:}.cal-script.txt'.format(max_node)
|
||||
macro = 'OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0'
|
||||
with open(str(script_name), 'w') as cfile:
|
||||
gaps = total_arch // divide
|
||||
for start in range(0, total_arch, gaps):
|
||||
xend = min(start+gaps, total_arch)
|
||||
cfile.write('{:} python exps/AA-NAS-statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1))
|
||||
print ('save the post-processing script into {:}'.format(script_name))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()]
|
||||
parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode' , type=str, required=True, help='The script mode.')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--max_node', type=int, help='The maximum node in a cell.')
|
||||
# arguments used when training the models
|
||||
parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)')
|
||||
parser.add_argument('--srange' , type=int, nargs='+', help='The range of models to be evaluated')
|
||||
parser.add_argument('--arch_index', type=int, default=-1, help='The architecture index to be evaluated (cover mode).')
|
||||
parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.')
|
||||
parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for this dataset.')
|
||||
parser.add_argument('--splits', type=int, nargs='+', help='Whether to use the train/valid split (1) or the original data (0) for each dataset.')
|
||||
parser.add_argument('--seeds' , type=int, nargs='+', help='The random seeds to be used.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.mode in ['meta', 'new', 'cover'] or args.mode.startswith('specific-'), 'invalid mode : {:}'.format(args.mode)
|
||||
|
||||
if args.mode == 'meta':
|
||||
generate_meta_info(args.save_dir, args.max_node)
|
||||
elif args.mode.startswith('specific'):
|
||||
assert len(args.mode.split('-')) == 2, 'invalid mode : {:}'.format(args.mode)
|
||||
model_str = args.mode.split('-')[1]
|
||||
train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
|
||||
tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells})
|
||||
else:
|
||||
meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node)
|
||||
assert meta_path.exists(), '{:} does not exist.'.format(meta_path)
|
||||
meta_info = torch.load( meta_path )
|
||||
# check whether args is ok
|
||||
assert len(args.srange) == 2 and args.srange[0] <= args.srange[1], 'invalid length of srange args: {:}'.format(args.srange)
|
||||
assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds)
|
||||
assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
|
||||
assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)
|
||||
|
||||
main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
|
||||
tuple(args.srange), args.arch_index, tuple(args.seeds), \
|
||||
args.mode == 'cover', meta_info, \
|
||||
{'channel': args.channel, 'num_cells': args.num_cells})
|
288 exps/AA-NAS-statistics.py Normal file
@@ -0,0 +1,288 @@
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, argparse, collections
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from config_utils import load_config, dict2config
|
||||
from datasets import get_datasets
|
||||
# AA-NAS-Bench related module or function
|
||||
from models import CellStructure, get_cell_based_tiny_net
|
||||
from aa_nas_api import ArchResults, ResultsCount
|
||||
from AA_functions import pure_evaluate
|
||||
|
||||
|
||||
|
||||
def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict):
|
||||
information = ArchResults(arch_index, arch_str)
|
||||
|
||||
for checkpoint_path in checkpoints:
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
used_seed = checkpoint_path.name.split('-')[-1].split('.')[0]
|
||||
for dataset in datasets:
|
||||
assert dataset in checkpoint, 'Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path)
|
||||
results = checkpoint[dataset]
|
||||
assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path)
|
||||
arch_config = {'channel': results['channel'], 'num_cells': results['num_cells'], 'arch_str': arch_str, 'class_num': results['config']['class_num']}
|
||||
xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'], \
|
||||
results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None)
|
||||
if dataset == 'cifar10-valid':
|
||||
xresult.update_eval('x-valid' , results['valid_acc1es'], results['valid_losses'])
|
||||
elif dataset == 'cifar10':
|
||||
xresult.update_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
|
||||
elif dataset == 'cifar100' or dataset == 'ImageNet16-120':
|
||||
xresult.update_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
|
||||
net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'],
|
||||
'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes':arch_config['class_num']}, None)
|
||||
network = get_cell_based_tiny_net(net_config)
|
||||
network.load_state_dict(xresult.get_net_param())
|
||||
network = network.cuda()
|
||||
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'valid')], network)
|
||||
xresult.update_eval('x-valid', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
|
||||
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network)
|
||||
xresult.update_eval('x-test' , {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
|
||||
xresult.update_latency(latencies)
|
||||
else:
|
||||
raise ValueError('invalid dataset name : {:}'.format(dataset))
|
||||
information.update(dataset, int(used_seed), xresult)
|
||||
return information
|
||||
|
||||
|
||||
|
||||
def GET_DataLoaders(workers):
|
||||
|
||||
torch.set_num_threads(workers)
|
||||
|
||||
root_dir = (Path(__file__).parent / '..').resolve()
|
||||
torch_dir = Path(os.environ['TORCH_HOME'])
|
||||
# cifar
|
||||
cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config'
|
||||
cifar_config = load_config(cifar_config_path, None, None)
|
||||
print ('{:} Create data-loader for all datasets'.format(time_string()))
|
||||
print ('-'*200)
|
||||
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num))
|
||||
cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None)
|
||||
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14]
|
||||
temp_dataset = deepcopy(TRAIN_CIFAR10)
|
||||
temp_dataset.transform = VALID_CIFAR10.transform
|
||||
# data loader
|
||||
trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
train_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True)
|
||||
valid_cifar10_loader = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True)
|
||||
test__cifar10_loader = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-10 : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size))
|
||||
print ('-'*200)
|
||||
# CIFAR-100
|
||||
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num))
|
||||
cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None)
|
||||
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24]
|
||||
train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-100 : train-loader has {:3d} batch'.format(len(train_cifar100_loader)))
|
||||
print ('CIFAR-100 : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader)))
|
||||
print ('CIFAR-100 : test--loader has {:3d} batch'.format(len(test__cifar100_loader)))
|
||||
print ('-'*200)
|
||||
|
||||
imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config'
|
||||
imagenet16_config = load_config(imagenet16_config_path, None, None)
|
||||
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1)
|
||||
print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num))
|
||||
imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None)
|
||||
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20]
|
||||
train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size))
|
||||
|
||||
# 'cifar10', 'cifar100', 'ImageNet16-120'
|
||||
loaders = {'cifar10@trainval': trainval_cifar10_loader,
|
||||
'cifar10@train' : train_cifar10_loader,
|
||||
'cifar10@valid' : valid_cifar10_loader,
|
||||
'cifar10@test' : test__cifar10_loader,
|
||||
'cifar100@train' : train_cifar100_loader,
|
||||
'cifar100@valid' : valid_cifar100_loader,
|
||||
'cifar100@test' : test__cifar100_loader,
|
||||
'ImageNet16-120@train': train_imagenet_loader,
|
||||
'ImageNet16-120@valid': valid_imagenet_loader,
|
||||
'ImageNet16-120@test' : test__imagenet_loader}
|
||||
return loaders
|
||||
|
||||
|
||||
|
||||
def simplify(save_dir, meta_file, basestr, target_dir):
|
||||
meta_infos = torch.load(meta_file, map_location='cpu')
|
||||
meta_archs = meta_infos['archs'] # a list of architecture strings
|
||||
meta_num_archs = meta_infos['total']
|
||||
meta_max_node = meta_infos['max_node']
|
||||
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
|
||||
|
||||
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
|
||||
print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
|
||||
|
||||
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
|
||||
num_seeds = defaultdict(lambda: 0)
|
||||
for index, sub_dir in enumerate(sub_model_dirs):
|
||||
xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
|
||||
arch_indexes = set()
|
||||
for checkpoint in xcheckpoints:
|
||||
temp_names = checkpoint.name.split('-')
|
||||
assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
|
||||
arch_indexes.add( temp_names[1] )
|
||||
subdir2archs[sub_dir] = sorted(list(arch_indexes))
|
||||
num_evaluated_arch += len(arch_indexes)
|
||||
# count number of seeds for each architecture
|
||||
for arch_index in arch_indexes:
|
||||
num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
|
||||
print('{:} There are {:5d} architectures that have been evaluated ({:} in total).'.format(time_string(), num_evaluated_arch, meta_num_archs))
|
||||
for key in sorted( list( num_seeds.keys() ) ): print ('{:} There are {:5d} architectures that are evaluated {:} times.'.format(time_string(), num_seeds[key], key))
|
||||
|
||||
dataloader_dict = GET_DataLoaders( 6 )
|
||||
|
||||
to_save_simply = save_dir / 'simplifies'
|
||||
to_save_allarc = save_dir / 'simplifies' / 'architectures'
|
||||
if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True)
|
||||
if not to_save_allarc.exists(): to_save_allarc.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
assert (save_dir / target_dir) in subdir2archs, 'can not find {:}'.format(target_dir)
|
||||
arch2infos, datasets = {}, ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')
|
||||
evaluated_indexes = set()
|
||||
target_directory = save_dir / target_dir
|
||||
arch_indexes = subdir2archs[ target_directory ]
|
||||
num_seeds = defaultdict(lambda: 0)
|
||||
end_time = time.time()
|
||||
arch_time = AverageMeter()
|
||||
for idx, arch_index in enumerate(arch_indexes):
|
||||
checkpoints = list(target_directory.glob('arch-{:}-seed-*.pth'.format(arch_index)))
|
||||
|
||||
try:
|
||||
arch_info = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict)
|
||||
num_seeds[ len(checkpoints) ] += 1
|
||||
except:
|
||||
print('Loading {:} failed, : {:}'.format(arch_index, checkpoints))
|
||||
continue
|
||||
assert int(arch_index) not in evaluated_indexes, 'conflict arch-index : {:}'.format(arch_index)
|
||||
assert 0 <= int(arch_index) < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
|
||||
evaluated_indexes.add( int(arch_index) )
|
||||
arch2infos[int(arch_index)] = arch_info
|
||||
torch.save(arch_info.state_dict(), to_save_allarc / '{:}-FULL.pth'.format(arch_index))
|
||||
#torch.save(arch_info, to_save_allarc / '{:}-FULL.pth'.format(arch_index))
|
||||
arch_info.clear_params()
|
||||
torch.save(arch_info, to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index))
|
||||
# measure elapsed time
|
||||
arch_time.update(time.time() - end_time)
|
||||
end_time = time.time()
|
||||
need_time = '{:}'.format( convert_secs2time(arch_time.avg * (len(arch_indexes)-idx-1), True) )
|
||||
print('{:} {:} [{:03d}/{:03d}] : {:} still need {:}'.format(time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time))
|
||||
# measure time
|
||||
xstrs = ['{:}:{:03d}'.format(key, num_seeds[key]) for key in sorted( list( num_seeds.keys() ) ) ]
|
||||
print('{:} {:} done : {:}'.format(time_string(), target_dir, xstrs))
|
||||
final_infos = {'meta_archs' : meta_archs,
|
||||
'total_archs': meta_num_archs,
|
||||
'basestr' : basestr,
|
||||
'arch2infos' : arch2infos,
|
||||
'evaluated_indexes': evaluated_indexes}
|
||||
save_file_name = to_save_simply / '{:}.pth'.format(target_dir)
|
||||
torch.save(final_infos, save_file_name)
|
||||
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
|
||||
|
||||
|
||||
|
||||
def merge_all(save_dir, meta_file, basestr):
|
||||
meta_infos = torch.load(meta_file, map_location='cpu')
|
||||
meta_archs = meta_infos['archs']
|
||||
meta_num_archs = meta_infos['total']
|
||||
meta_max_node = meta_infos['max_node']
|
||||
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
|
||||
|
||||
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
|
||||
print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
|
||||
for index, sub_dir in enumerate(sub_model_dirs):
|
||||
arch_info_files = sorted( list(sub_dir.glob('arch-*-seed-*.pth') ) )
|
||||
print ('The {:02d}/{:02d}-th directory : {:} : {:} runs.'.format(index, len(sub_model_dirs), sub_dir, len(arch_info_files)))
|
||||
|
||||
subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
|
||||
num_seeds = defaultdict(lambda: 0)
|
||||
for index, sub_dir in enumerate(sub_model_dirs):
|
||||
xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
|
||||
arch_indexes = set()
|
||||
for checkpoint in xcheckpoints:
|
||||
temp_names = checkpoint.name.split('-')
|
||||
assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
|
||||
arch_indexes.add( temp_names[1] )
|
||||
subdir2archs[sub_dir] = sorted(list(arch_indexes))
|
||||
num_evaluated_arch += len(arch_indexes)
|
||||
# count number of seeds for each architecture
|
||||
for arch_index in arch_indexes:
|
||||
num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
|
||||
print('There are {:5d} architectures that have been evaluated ({:} in total).'.format(num_evaluated_arch, meta_num_archs))
|
||||
for key in sorted( list( num_seeds.keys() ) ): print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))
|
||||
|
||||
arch2infos, evaluated_indexes = dict(), set()
|
||||
for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
|
||||
ckp_path = sub_dir.parent / 'simplifies' / '{:}.pth'.format(sub_dir.name)
|
||||
if ckp_path.exists():
|
||||
sub_ckps = torch.load(ckp_path, map_location='cpu')
|
||||
assert sub_ckps['total_archs'] == meta_num_archs and sub_ckps['basestr'] == basestr
|
||||
xarch2infos = sub_ckps['arch2infos']
|
||||
xevalindexs = sub_ckps['evaluated_indexes']
|
||||
for eval_index in xevalindexs:
|
||||
assert eval_index not in evaluated_indexes and eval_index not in arch2infos
|
||||
arch2infos[eval_index] = xarch2infos[eval_index]
|
||||
evaluated_indexes.add( eval_index )
|
||||
print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), IDX, len(subdir2archs), ckp_path, len(xevalindexs)))
|
||||
else:
|
||||
print ('{:} [{:03d}/{:03d}] can not find {:}, skip.'.format(time_string(), IDX, len(subdir2archs), ckp_path))
|
||||
|
||||
evaluated_indexes = sorted( list( evaluated_indexes ) )
|
||||
print ('Finally, there are {:} models.'.format(len(evaluated_indexes)))
|
||||
|
||||
to_save_simply = save_dir / 'simplifies'
|
||||
if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True)
|
||||
final_infos = {'meta_archs' : meta_archs,
|
||||
'total_archs': meta_num_archs,
|
||||
'arch2infos' : arch2infos,
|
||||
'evaluated_indexes': evaluated_indexes}
|
||||
save_file_name = to_save_simply / '{:}-final-infos.pth'.format(basestr)
|
||||
torch.save(final_infos, save_file_name)
|
||||
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='An Algorithm-Agnostic (AA) NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode' , type=str, choices=['cal', 'merge'], help='The running mode for this script.')
|
||||
parser.add_argument('--base_save_dir', type=str, default='./output/AA-NAS-BENCH-4', help='The base-name of folder to save checkpoints and log.')
|
||||
parser.add_argument('--target_dir' , type=str, help='The target directory.')
|
||||
parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.')
|
||||
parser.add_argument('--channel' , type=int, default=16, help='The number of channels.')
|
||||
parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.')
|
||||
args = parser.parse_args()
|
||||
|
||||
save_dir = Path( args.base_save_dir )
|
||||
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
|
||||
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
|
||||
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
|
||||
print ('start the statistics of our nas-benchmark from {:} using {:}.'.format(save_dir, args.target_dir))
|
||||
basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
|
||||
|
||||
if args.mode == 'cal':
|
||||
simplify(save_dir, meta_path, basestr, args.target_dir)
|
||||
elif args.mode == 'merge':
|
||||
merge_all(save_dir, meta_path, basestr)
|
||||
else:
|
||||
raise ValueError('invalid mode : {:}'.format(args.mode))
|
@ -5,8 +5,10 @@ lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from log_utils import time_string
|
||||
from aa_nas_api import AANASBenchAPI, ArchResults
|
||||
from models import CellStructure
|
||||
|
||||
|
||||
def get_unique_matrix(archs, consider_zero):
|
||||
UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
|
||||
print ('{:} create unique-string done'.format(time_string()))
|
||||
@ -24,15 +26,37 @@ def get_unique_matrix(archs, consider_zero):
|
||||
unique_num += 1
|
||||
return sm_matrix, unique_ids, unique_num
|
||||
|
||||
|
||||
def check_unique_arch():
|
||||
print ('{:} start'.format(time_string()))
|
||||
meta_info = torch.load('./output/AA-NAS-BENCH-4/meta-node-4.pth')
|
||||
arch_strs = meta_info['archs']
|
||||
archs = [CellStructure.str2structure(arch_str) for arch_str in arch_strs]
|
||||
_, _, unique_num = get_unique_matrix(archs, False)
|
||||
"""
|
||||
for i, arch in enumerate(archs):
|
||||
if not arch.check_valid():
|
||||
print('{:05d} {:}'.format(i, arch))
|
||||
#start = int(i / 390.) * 390
|
||||
#xxend = start + 389
|
||||
#print ('/home/dxy/search-configures/output/TINY-NAS-BENCHMARK-4/{:06d}-{:06d}-C16-N5/arch-{:06d}-seed-0888.pth'.format(start, xxend, i))
|
||||
"""
|
||||
print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in archs) ))
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(archs, False)
|
||||
save_dir = './output/cell-search-tiny/same-matrix.pth'
|
||||
torch.save(sm_matrix, save_dir)
|
||||
print ('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num))
|
||||
_, _, unique_num = get_unique_matrix(archs, True)
|
||||
sm_matrix, uniqueIDs, unique_num = get_unique_matrix(archs, True)
|
||||
print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
|
||||
|
||||
|
||||
def test_aa_nas_api():
|
||||
arch_result = ArchResults.create_from_state_dict('output/AA-NAS-BENCH-4/simplifies/architectures/000002-FULL.pth')
|
||||
arch_result.show(True)
|
||||
result = arch_result.query('cifar100')
|
||||
#xfile = '/home/dxy/search-configures/output/TINY-NAS-BENCHMARK-4/simplifies/C16-N5-final-infos.pth'
|
||||
#api = AANASBenchAPI(xfile)
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_unique_arch()
|
||||
#check_unique_arch()
|
||||
test_aa_nas_api()
|
||||
|
@ -1,6 +1,3 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
|
@ -1,6 +1,6 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
########################################################
|
||||
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
||||
########################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
@ -1,6 +1,6 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
########################################################
|
||||
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
||||
########################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
@ -1,6 +1,3 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
@ -1,6 +1,3 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
236 exps/algos/RANDOM-NAS.py Normal file
@@ -0,0 +1,236 @@
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
|
||||
|
||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
network.train()
|
||||
end = time.time()
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
|
||||
scheduler.update(None, 1.0 * step / len(xloader))
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
network.module.random_genotype( True )
|
||||
w_optimizer.zero_grad()
|
||||
_, logits = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
nn.utils.clip_grad_norm_(network.parameters(), 5)
|
||||
w_optimizer.step()
|
||||
# record
|
||||
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
base_top1.update (base_prec1.item(), base_inputs.size(0))
|
||||
base_top5.update (base_prec5.item(), base_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if step % print_freq == 0 or step + 1 == len(xloader):
|
||||
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
|
||||
return base_losses.avg, base_top1.avg, base_top5.avg
|
||||
|
||||
|
||||
def valid_func(xloader, network, criterion):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
network.eval()
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# prediction
|
||||
|
||||
network.module.random_genotype( True )
|
||||
_, logits = network(arch_inputs)
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
# record
|
||||
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
def main(xargs):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads( xargs.workers )
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(split_Fpath, None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
#elif xargs.dataset.startswith('ImageNet16'):
|
||||
# # all_indexes = list(range(len(train_data))) ; random.seed(111) ; random.shuffle(all_indexes)
|
||||
# # train_split, valid_split = sorted(all_indexes[: len(train_data)//2]), sorted(all_indexes[len(train_data)//2 :])
|
||||
# # imagenet16_split = dict2config({'train': train_split, 'valid': valid_split}, None)
|
||||
# # _ = configure2str(imagenet16_split, 'temp.txt')
|
||||
# split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||
# imagenet16_split = load_config(split_Fpath, None, None)
|
||||
# train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||
# logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
logger.log('config : {:}'.format(config))
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space}, None)
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
|
||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)
|
||||
logger.log('w-optimizer : {:}'.format(w_optimizer))
|
||||
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
||||
logger.log('criterion : {:}'.format(criterion))
|
||||
|
||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||
|
||||
if last_info.exists(): # automatically resume from previous checkpoint
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
|
||||
last_info = torch.load(last_info)
|
||||
start_epoch = last_info['epoch']
|
||||
checkpoint = torch.load(last_info['last_checkpoint'])
|
||||
valid_accuracies = checkpoint['valid_accuracies']
|
||||
search_model.load_state_dict( checkpoint['search_model'] )
|
||||
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
|
||||
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies = 0, {'best': -1}
|
||||
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
||||
|
||||
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
# check the best accuracy
|
||||
valid_accuracies[epoch] = valid_a_top1
|
||||
if valid_a_top1 > valid_accuracies['best']:
|
||||
valid_accuracies['best'] = valid_a_top1
|
||||
find_best = True
|
||||
else: find_best = False
|
||||
|
||||
# save checkpoint
|
||||
save_path = save_checkpoint({'epoch' : epoch + 1,
|
||||
'args' : deepcopy(xargs),
|
||||
'search_model': search_model.state_dict(),
|
||||
'w_optimizer' : w_optimizer.state_dict(),
|
||||
'w_scheduler' : w_scheduler.state_dict(),
|
||||
'valid_accuracies' : valid_accuracies},
|
||||
model_base_path, logger)
|
||||
last_info = save_checkpoint({
|
||||
'epoch': epoch + 1,
|
||||
'args' : deepcopy(args),
|
||||
'last_checkpoint': save_path,
|
||||
}, logger.path('info'), logger)
|
||||
if find_best:
|
||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
logger.log('\n' + '-'*200)
|
||||
|
||||
best_arch, best_acc = None, -1
|
||||
for iarch in range(xargs.select_num):
|
||||
arch = search_model.random_genotype( True )
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
|
||||
logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
|
||||
|
||||
logger.log('\n' + '-'*100)
|
||||
"""
|
||||
# check the performance from the architecture dataset
|
||||
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
else:
|
||||
nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||
geno = best_arch
|
||||
logger.log('The last model is {:}'.format(geno))
|
||||
info = nas_bench.query_by_arch( geno )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||
else : logger.log('{:}'.format(info))
|
||||
logger.log('-'*100)
|
||||
logger.close()
|
||||
"""
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser("Random search for NAS.")
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||
# channels and number-of-cells
|
||||
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||
parser.add_argument('--config_path', type=str, help='The path to the configuration.')
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
|
||||
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
main(args)
|
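For orientation, the weight-sharing random search implemented by RANDOM-NAS.py boils down to (i) sampling a fresh random genotype for every mini-batch while training the shared weights, and (ii) ranking `select_num` randomly sampled architectures with those shared weights afterwards. Below is a minimal, self-contained sketch of that loop; the `ToySuperNet` class, tensor shapes, and hyper-parameters are illustrative assumptions, not code from this repository.

```
# Sketch only: a toy super-net standing in for TinyNetworkRANDOM / DataParallel.
import random
import torch
import torch.nn as nn

class ToySuperNet(nn.Module):
    def __init__(self, num_edges=3, num_ops=4, dim=16, num_classes=10):
        super().__init__()
        # each "edge" owns one weight tensor per candidate operation (weight sharing)
        self.ops = nn.ModuleList([
            nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_ops)])
            for _ in range(num_edges)])
        self.classifier = nn.Linear(dim, num_classes)
        self.arch_cache = None

    def random_genotype(self, set_cache=True):
        arch = [random.randrange(len(edge)) for edge in self.ops]  # one op per edge
        if set_cache: self.arch_cache = arch
        return arch

    def forward(self, x):
        for edge, op_idx in zip(self.ops, self.arch_cache):
            x = torch.relu(edge[op_idx](x))
        return self.classifier(x)

net, criterion = ToySuperNet(), nn.CrossEntropyLoss()
w_optimizer = torch.optim.SGD(net.parameters(), lr=0.025, momentum=0.9)

# search phase: train the shared weights under randomly sampled architectures
for step in range(100):
    inputs, targets = torch.randn(32, 16), torch.randint(0, 10, (32,))
    net.random_genotype(True)             # sample one architecture for this batch
    w_optimizer.zero_grad()
    loss = criterion(net(inputs), targets)
    loss.backward()
    w_optimizer.step()

# selection phase: score a handful of random architectures with the shared weights
best_arch, best_acc = None, -1.0
with torch.no_grad():
    for _ in range(10):                   # stands in for --select_num
        arch = net.random_genotype(True)
        inputs, targets = torch.randn(256, 16), torch.randint(0, 10, (256,))
        acc = (net(inputs).argmax(dim=1) == targets).float().mean().item()
        if acc > best_acc: best_arch, best_acc = arch, acc
print('best sampled architecture:', best_arch, 'accuracy:', best_acc)
```
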
@@ -1,8 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
##################################################
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
@@ -24,6 +22,7 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
  base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  end = time.time()
  network.train()
  for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
    scheduler.update(None, 1.0 * step / len(xloader))
    base_targets = base_targets.cuda(non_blocking=True)
@@ -32,13 +31,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
    data_time.update(time.time() - end)

    # update the weights
    network.train()
    sampled_arch = network.module.dync_genotype(True)
    network.module.set_cal_mode('dynamic', sampled_arch)
    #network.module.set_cal_mode( 'urs' )
    network.zero_grad()
    _, logits = network( torch.cat((base_inputs, arch_inputs), dim=0) )
    logits = logits[:base_inputs.size(0)]
    _, logits = network(base_inputs)
    base_loss = criterion(logits, base_targets)
    base_loss.backward()
    w_optimizer.step()
@@ -49,7 +46,6 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
    base_top5.update (base_prec5.item(), base_inputs.size(0))

    # update the architecture-weight
    network.eval()
    network.module.set_cal_mode( 'joint' )
    network.zero_grad()
    _, logits = network(arch_inputs)
@@ -257,6 +253,7 @@ def main(xargs):
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.log('During searching, the best genotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
  # sampling
  """
  with torch.no_grad():

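For context, the weight-update hunk above replaces a forward pass over the concatenated (base + arch) batch with a pass over the base batch only, which roughly halves the compute of each weight step. A minimal stand-alone sketch of the two variants follows; the `nn.Linear` module is only a stand-in for the DataParallel super-net, and the tensor names mirror the hunk.

```
import torch
import torch.nn as nn

model = nn.Linear(8, 4)                  # stand-in for the SETN super-net
base_inputs, arch_inputs = torch.randn(32, 8), torch.randn(32, 8)

# old behaviour: forward the concatenated batch, then keep only the base part
logits = model(torch.cat((base_inputs, arch_inputs), dim=0))
logits = logits[:base_inputs.size(0)]

# new behaviour: forward the base batch only
logits = model(base_inputs)
```
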
@@ -1,6 +1,3 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch, random, argparse
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

@@ -1,6 +1,3 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import sys, time, torch, random, argparse
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

@@ -1,6 +1,3 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import sys, time, torch, random, argparse
from PIL import ImageFile
from os import path as osp

@@ -1,6 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
#######################################################################
# Network Pruning via Transformable Architecture Search, NeurIPS 2019 #
#######################################################################
import sys, time, torch, random, argparse
from PIL import ImageFile
from os import path as osp

2	lib/aa_nas_api/__init__.py	Normal file
@@ -0,0 +1,2 @@
from .api import AANASBenchAPI
from .api import ArchResults, ResultsCount

290	lib/aa_nas_api/api.py	Normal file
@@ -0,0 +1,290 @@
import os, sys, copy, torch, numpy as np


def print_information(information, extra_info=None, show=False):
  dataset_names = information.get_dataset_names()
  strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
  def metric2str(loss, acc):
    return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)

  for ida, dataset in enumerate(dataset_names):
    flop, param, latency = information.get_comput_costs(dataset)
    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency > 0 else None)
    train_loss, train_acc = information.get_metrics(dataset, 'train')
    if dataset == 'cifar10-valid':
      valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid')
      str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc))
    elif dataset == 'cifar10':
      test__loss, test__acc = information.get_metrics(dataset, 'ori-test')
      str2 = '{:14s} train : [{:}], test  : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(test__loss, test__acc))
    else:
      valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid')
      test__loss, test__acc = information.get_metrics(dataset, 'x-test')
      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc), metric2str(test__loss, test__acc))
    strings += [str1, str2]
  if show: print('\n'.join(strings))
  return strings


class AANASBenchAPI(object):

  def __init__(self, file_path_or_dict):
    if isinstance(file_path_or_dict, str):
      assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
      file_path_or_dict = torch.load(file_path_or_dict)
    assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
    keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
    for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
    self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
    self.arch2infos = copy.deepcopy( file_path_or_dict['arch2infos'] )
    self.evaluated_indexes = sorted(list( copy.deepcopy( file_path_or_dict['evaluated_indexes'] ) ))
    self.archstr2index = {}
    for idx, arch in enumerate(self.meta_archs):
      assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
      self.archstr2index[ arch.tostr() ] = idx

  def __getitem__(self, index):
    return copy.deepcopy( self.meta_archs[index] )

  def __len__(self):
    return len(self.meta_archs)

  def __repr__(self):
    return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))

  def query_index_by_arch(self, arch):
    if arch.tostr() in self.archstr2index:
      arch_index = self.archstr2index[ arch.tostr() ]
    #else:
    #  arch_str = Structure.str2fullstructure( arch.tostr() ).tostr()
    #  if arch_str in self.archstr2index:
    #    arch_index = self.archstr2index[ arch_str ]
    else: arch_index = -1
    return arch_index

  def query_by_arch(self, arch):
    arch_index = self.query_index_by_arch(arch)
    if arch_index == -1: return None
    if arch_index in self.arch2infos:
      strings = print_information(self.arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index))
      return '\n'.join(strings)
    else:
      print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
      return None

  def query_by_index(self, arch_index, dataname):
    assert arch_index in self.arch2infos, 'arch_index [{:}] is not in arch2infos'.format(arch_index)
    archInfo = copy.deepcopy( self.arch2infos[ arch_index ] )
    assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
    info = archInfo.query(dataname)
    return info

  def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None):
    best_index, highest_accuracy = -1, None
    for i, idx in enumerate(self.evaluated_indexes):
      flop, param, latency = self.arch2infos[idx].get_comput_costs(dataset)
      if FLOP_max  is not None and flop  > FLOP_max : continue
      if Param_max is not None and param > Param_max: continue
      loss, accuracy = self.arch2infos[idx].get_metrics(dataset, metric_on_set)
      if best_index == -1:
        best_index, highest_accuracy = idx, accuracy
      elif highest_accuracy < accuracy:
        best_index, highest_accuracy = idx, accuracy
    return best_index

  def arch(self, index):
    assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
    return copy.deepcopy(self.meta_archs[index])

  def show(self, index=-1):
    if index == -1: # show all architectures
      print(self)
      for i, idx in enumerate(self.evaluated_indexes):
        print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
        print('arch : {:}'.format(self.meta_archs[idx]))
        strings = print_information(self.arch2infos[idx])
        print('>' * 20)
        print('\n'.join(strings))
        print('<' * 20)
    else:
      if 0 <= index < len(self.meta_archs):
        if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
        else:
          strings = print_information(self.arch2infos[index])
          print('\n'.join(strings))
      else:
        print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))


class ArchResults(object):

  def __init__(self, arch_index, arch_str):
    self.arch_index   = int(arch_index)
    self.arch_str     = copy.deepcopy(arch_str)
    self.all_results  = dict()
    self.dataset_seed = dict()
    self.clear_net_done = False

  def get_comput_costs(self, dataset):
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
    flops     = [result.flop for result in results]
    params    = [result.params for result in results]
    latencies = [result.get_latency() for result in results]
    return np.mean(flops), np.mean(params), np.mean(latencies)

  def get_metrics(self, dataset, setname, iepoch=None):
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
    loss, accuracy = [], []
    for result in results:
      if setname == 'train':
        info = result.get_train(iepoch)
      else:
        info = result.get_eval(setname, iepoch)
      loss.append( info['loss'] )
      accuracy.append( info['accuracy'] )
    return float(np.mean(loss)), float(np.mean(accuracy))

  def show(self, is_print=False):
    return print_information(self, None, is_print)

  def get_dataset_names(self):
    return list(self.dataset_seed.keys())

  def query(self, dataset, seed=None):
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
      return [self.all_results[ (dataset, seed) ] for seed in x_seeds]
    else:
      return self.all_results[ (dataset, seed) ]

  def arch_idx_str(self):
    return '{:06d}'.format(self.arch_index)

  def update(self, dataset_name, seed, result):
    if dataset_name not in self.dataset_seed:
      self.dataset_seed[dataset_name] = []
    assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
    self.dataset_seed[ dataset_name ].append( seed )
    self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
    assert (dataset_name, seed) not in self.all_results
    self.all_results[ (dataset_name, seed) ] = result
    self.clear_net_done = False

  def state_dict(self):
    state_dict = dict()
    for key, value in self.__dict__.items():
      if key == 'all_results': # contains instances of ResultsCount
        xvalue = dict()
        assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
        for _k, _v in value.items():
          assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
          xvalue[_k] = _v.state_dict()
      else:
        xvalue = value
      state_dict[key] = xvalue
    return state_dict

  def load_state_dict(self, state_dict):
    new_state_dict = dict()
    for key, value in state_dict.items():
      if key == 'all_results': # convert back into ResultsCount instances
        xvalue = dict()
        assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
        for _k, _v in value.items():
          xvalue[_k] = ResultsCount.create_from_state_dict(_v)
      else: xvalue = value
      new_state_dict[key] = xvalue
    self.__dict__.update(new_state_dict)

  @staticmethod
  def create_from_state_dict(state_dict_or_file):
    x = ArchResults(-1, -1)
    if isinstance(state_dict_or_file, str): # a file path
      state_dict = torch.load(state_dict_or_file)
    elif isinstance(state_dict_or_file, dict):
      state_dict = state_dict_or_file
    else:
      raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
    x.load_state_dict(state_dict)
    return x

  def clear_params(self):
    for key, result in self.all_results.items():
      result.net_state_dict = None
    self.clear_net_done = True

  def __repr__(self):
    return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))


class ResultsCount(object):

  def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
    self.name           = name
    self.net_state_dict = state_dict
    self.train_accs     = copy.deepcopy(train_accs)
    self.train_losses   = copy.deepcopy(train_losses)
    self.arch_config    = copy.deepcopy(arch_config)
    self.params         = params
    self.flop           = flop
    self.seed           = seed
    self.epochs         = epochs
    self.latency        = latency
    # evaluation results
    self.reset_eval()

  def reset_eval(self):
    self.eval_names  = []
    self.eval_accs   = {}
    self.eval_losses = {}

  def update_latency(self, latency):
    self.latency = copy.deepcopy( latency )

  def get_latency(self):
    if self.latency is None: return -1
    else: return sum(self.latency) / len(self.latency)

  def update_eval(self, name, accs, losses):
    assert name not in self.eval_names, '{:} has already been added'.format(name)
    self.eval_names.append( name )
    self.eval_accs[name]   = copy.deepcopy( accs )
    self.eval_losses[name] = copy.deepcopy( losses )

  def __repr__(self):
    num_eval = len(self.eval_names)
    return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets)'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval))

  def valid_evaluation_set(self):
    return self.eval_names

  def get_train(self, iepoch=None):
    if iepoch is None: iepoch = self.epochs-1
    assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
    return {'loss': self.train_losses[iepoch], 'accuracy': self.train_accs[iepoch]}

  def get_eval(self, name, iepoch=None):
    if iepoch is None: iepoch = self.epochs-1
    assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
    return {'loss': self.eval_losses[name][iepoch], 'accuracy': self.eval_accs[name][iepoch]}

  def get_net_param(self):
    return self.net_state_dict

  def state_dict(self):
    _state_dict = {key: value for key, value in self.__dict__.items()}
    return _state_dict

  def load_state_dict(self, state_dict):
    self.__dict__.update(state_dict)

  @staticmethod
  def create_from_state_dict(state_dict):
    x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
    x.load_state_dict(state_dict)
    return x

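A possible usage sketch of the benchmark API defined above follows. The checkpoint filename is a placeholder, `arch` must be a CellStructure-like object exposing tostr(), and the 'ori-test' metric name mirrors the one used in print_information; treat the whole snippet as an assumption rather than a documented interface.

```
# Assumes ./lib is on sys.path, as in the training scripts.
from aa_nas_api import AANASBenchAPI

api = AANASBenchAPI('AA-NAS-Bench-4.pth')   # placeholder checkpoint path
print(api)                                  # e.g. AANASBenchAPI(x/y architectures)
api.show(0)                                 # print the metrics of the 0-th architecture

arch  = api.arch(0)                         # a CellStructure object
index = api.query_index_by_arch(arch)       # -1 if the architecture string is unknown
info  = api.query_by_arch(arch)             # formatted string, or None if not evaluated
if info is not None: print(info)

# best evaluated CIFAR-10 architecture under a 50M-FLOP budget
best_index = api.find_best('cifar10', 'ori-test', FLOP_max=50)
print('best architecture index :', best_index)
```
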
@@ -13,7 +13,7 @@ from .cell_searchs import CellStructure, CellArchitectures

# Cell-based NAS Models
def get_cell_based_tiny_net(config):
  group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS']
  group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
  from .cell_searchs import nas_super_nets
  if config.name in group_names:
    return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)

@@ -3,10 +3,12 @@ from .search_model_darts_v2 import TinyNetworkDartsV2
from .search_model_gdas import TinyNetworkGDAS
from .search_model_setn import TinyNetworkSETN
from .search_model_enas import TinyNetworkENAS
from .search_model_random import TinyNetworkRANDOM
from .genotypes import Structure as CellStructure, architectures as CellArchitectures

nas_super_nets = {'DARTS-V1': TinyNetworkDartsV1,
                  'DARTS-V2': TinyNetworkDartsV2,
                  'GDAS'    : TinyNetworkGDAS,
                  'SETN'    : TinyNetworkSETN,
                  'ENAS'    : TinyNetworkENAS}
                  'ENAS'    : TinyNetworkENAS,
                  'RANDOM'  : TinyNetworkRANDOM}

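With the registry entries added above, the random super-net is built through the same factory as the other search models. A small sketch follows; the operation list is an assumed search space (the real one is loaded via get_search_spaces), and ./lib must be on sys.path.

```
from config_utils import dict2config
from models import get_cell_based_tiny_net

space  = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']  # assumption
config = dict2config({'name': 'RANDOM', 'C': 16, 'N': 5, 'max_nodes': 4,
                      'num_classes': 10, 'space': space}, None)
net    = get_cell_based_tiny_net(config)    # dispatches to TinyNetworkRANDOM
```
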
@@ -60,6 +60,17 @@ class Structure:
      strings.append( string )
    return '+'.join(strings)

  def check_valid(self):
    nodes = {0: True}
    for i, node_info in enumerate(self.nodes):
      sums = []
      for op, xin in node_info:
        if op == 'none' or nodes[xin] == False: x = False
        else: x = True
        sums.append( x )
      nodes[i+1] = sum(sums) > 0
    return nodes[len(self.nodes)]

  def to_unique_str(self, consider_zero=False):
    # this is used to identify the isomorphic cell, which requires prior knowledge of the operations
    # two operations are special, i.e., none and skip_connect

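check_valid() marks a cell as valid only if every intermediate node can be reached through at least one non-'none' edge from a live predecessor. A small illustration follows; the operation names are assumptions about the search space, the nested-tuple format matches the genotypes built by random_genotype, and ./lib is assumed to be on sys.path.

```
from models.cell_searchs.genotypes import Structure

ok  = Structure([(('nor_conv_3x3', 0),),
                 (('skip_connect', 0), ('nor_conv_1x1', 1))])
bad = Structure([(('none', 0),),                       # node 1 receives no signal
                 (('skip_connect', 1), ('none', 0))])  # so node 2 is dead as well
print(ok.check_valid())    # True  -- every node has at least one live input
print(bad.check_valid())   # False -- the output node cannot be reached
```
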
79	lib/models/cell_searchs/search_model_random.py	Normal file
@@ -0,0 +1,79 @@
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import SearchCell
from .genotypes import Structure


class TinyNetworkRANDOM(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space):
    super(TinyNetworkRANDOM, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                    nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_cache = None

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def random_genotype(self, set_cache):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        op_name  = random.choice( self.op_names )
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    arch = Structure( genotypes )
    if set_cache: self.arch_cache = arch
    return arch

  def forward(self, inputs):
    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        feature = cell.forward_dynamic(feature, self.arch_cache)
      else: feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)
    return out, logits

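A hedged usage sketch of TinyNetworkRANDOM above: it assumes ./lib is on sys.path, a five-operation search space whose names exist in cell_operations, and CIFAR-shaped 3x32x32 inputs.

```
import torch
from models.cell_searchs.search_model_random import TinyNetworkRANDOM

space = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']  # assumption
net   = TinyNetworkRANDOM(C=16, N=5, max_nodes=4, num_classes=10, search_space=space)

arch = net.random_genotype(set_cache=True)     # sample and cache one architecture
print('sampled architecture :', arch)
features, logits = net(torch.randn(2, 3, 32, 32))
print(logits.shape)                            # torch.Size([2, 10])
```
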
16	scripts-search/AA-NAS-meta-gen.sh	Normal file
@@ -0,0 +1,16 @@
#!/bin/bash
# bash ./scripts-search/AA-NAS-meta-gen.sh AA-NAS-BENCHMARK 4
echo script name: $0
echo $# arguments
if [ "$#" -ne 2 ] ;then
  echo "Illegal number of input parameters: " $#
  echo "Need 2 parameters for save-dir-name and maximum-node-in-cell"
  exit 1
fi

name=$1
node=$2

save_dir=./output/${name}-${node}

python ./exps/AA-NAS-Bench-main.py --mode meta --save_dir ${save_dir} --max_node ${node}

41	scripts-search/AA-NAS-train-archs.sh	Normal file
@@ -0,0 +1,41 @@
#!/bin/bash
# bash ./scripts-search/AA-NAS-train-archs.sh 0 100 -1 '777 888 999'
echo script name: $0
echo $# arguments
if [ "$#" -ne 4 ] ;then
  echo "Illegal number of input parameters: " $#
  echo "Need 4 parameters for start-index, end-index, arch-index, and seeds"
  exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
  echo "Must set TORCH_HOME environment variable for data dir saving"
  exit 1
else
  echo "TORCH_HOME : $TORCH_HOME"
fi

xstart=$1
xend=$2
arch_index=$3
all_seeds=$4

if [ ${arch_index} == "-1" ]; then
  mode=new
else
  mode=cover
fi

save_dir=./output/AA-NAS-BENCH-4/

OMP_NUM_THREADS=4 python ./exps/AA-NAS-Bench-main.py \
	--mode ${mode} --save_dir ${save_dir} --max_node 4 \
	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
	--splits 1 0 0 0 \
	--xpaths $TORCH_HOME/cifar.python \
		 $TORCH_HOME/cifar.python \
		 $TORCH_HOME/cifar.python \
		 $TORCH_HOME/cifar.python/ImageNet16 \
	--channel 16 --num_cells 5 \
	--workers 4 \
	--srange ${xstart} ${xend} --arch_index ${arch_index} \
	--seeds ${all_seeds}

33	scripts-search/AA-NAS-train-net.sh	Normal file
@@ -0,0 +1,33 @@
#!/bin/bash
# bash ./scripts-search/AA-NAS-train-net.sh resnet 16 5
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
  echo "Illegal number of input parameters: " $#
  echo "Need 3 parameters for network, channel, num-of-cells"
  exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
  echo "Must set TORCH_HOME environment variable for data dir saving"
  exit 1
else
  echo "TORCH_HOME : $TORCH_HOME"
fi

model=$1
channel=$2
num_cells=$3

save_dir=./output/AA-NAS-BENCH-4/

OMP_NUM_THREADS=4 python ./exps/AA-NAS-Bench-main.py \
	--mode specific-${model} --save_dir ${save_dir} --max_node 4 \
	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
	--splits 1 0 0 0 \
	--xpaths $TORCH_HOME/cifar.python \
		 $TORCH_HOME/cifar.python \
		 $TORCH_HOME/cifar.python \
		 $TORCH_HOME/cifar.python/ImageNet16 \
	--channel ${channel} --num_cells ${num_cells} \
	--workers 4 \
	--seeds 777 888 999

38	scripts-search/algos/RANDOM-NAS.sh	Normal file
@@ -0,0 +1,38 @@
#!/bin/bash
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019
# bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 -1
echo script name: $0
echo $# arguments
if [ "$#" -ne 2 ] ;then
  echo "Illegal number of input parameters: " $#
  echo "Need 2 parameters for dataset and seed"
  exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
  echo "Must set TORCH_HOME environment variable for data dir saving"
  exit 1
else
  echo "TORCH_HOME : $TORCH_HOME"
fi

dataset=$1
seed=$2
channel=16
num_cells=5
max_nodes=4

if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
  data_path="$TORCH_HOME/cifar.python"
else
  data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi

save_dir=./output/cell-search-tiny/RANDOM-NAS-${dataset}

OMP_NUM_THREADS=4 python ./exps/algos/RANDOM-NAS.py \
	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
	--dataset ${dataset} --data_path ${data_path} \
	--search_space_name aa-nas \
	--config_path ./configs/nas-benchmark/algos/RANDOM.config \
	--select_num 100 \
	--workers 4 --print_freq 200 --rand_seed ${seed}