v2

LICENCE (deleted, 21 lines)
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2020 BayesWatch

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md (60 lines changed)
@@ -1,63 +1,31 @@
# [Neural Architecture Search Without Training](https://arxiv.org/abs/2006.04647)
# Neural Architecture Search Without Training

This repository contains code for replicating our paper, [NAS Without Training](https://arxiv.org/abs/2006.04647).
> :warning: Note: this repository has been updated to reflect the second version of the paper to appear on arXiv 1 March. :warning

## Setup
## Usage

1. Download the [datasets](https://drive.google.com/drive/folders/1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7).
2. Download [NAS-Bench-201](https://drive.google.com/file/d/16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_/view).
3. Install the requirements in a conda environment with `conda env create -f environment.yml`.
Create a conda environment using the env.yml file

We also refer the reader to instructions in the official [NAS-Bench-201 README](https://github.com/D-X-Y/NAS-Bench-201).

## Reproducing our results

To reproduce our results:

```
conda activate nas-wot
./reproduce.sh 3 # average accuracy over 3 runs
./reproduce.sh 500 # average accuracy over 500 runs (this will take longer)
```bash
conda env create -f env.yml
```

Each command will finish by calling `process_results.py`, which will print a table. `./reproduce.sh 3` should print the following table:
Activate the environment and follow the instructions to install

| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
| Ours (N=10)  | 1.75  | 89.50 +- 0.51 | 92.98 +- 0.82 | 69.80 +- 2.46 | 69.86 +- 2.21 | 42.35 +- 1.19 | 42.38 +- 1.37 |
| Ours (N=100) | 17.76 | 87.44 +- 1.45 | 92.27 +- 1.53 | 70.26 +- 1.09 | 69.86 +- 0.60 | 43.30 +- 1.62 | 43.51 +- 1.40 |
Install nasbench (see https://github.com/google-research/nasbench)

`./reproduce 500` will produce the following table:
Download the NDS data from https://github.com/facebookresearch/nds and place the json files in naswot-codebase/nds_data/
Download the NASbench101 data (see https://github.com/google-research/nasbench)
Download the NASbench201 data (see https://github.com/D-X-Y/NAS-Bench-201)

| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
| Ours (N=10)  | 1.67  | 88.61 +- 1.58 | 91.58 +- 1.70 | 67.03 +- 3.01 | 67.15 +- 3.08 | 39.74 +- 4.17 | 39.76 +- 4.39 |
| Ours (N=100) | 17.12 | 88.43 +- 1.67 | 91.24 +- 1.70 | 67.04 +- 2.91 | 67.12 +- 2.98 | 40.68 +- 3.41 | 40.67 +- 3.55 |

To try different sample sizes, simply change the `--n_samples` argument in the call to `search.py`, and update the list of sample sizes [this line](https://github.com/BayesWatch/nas-without-training/blob/master/process_results.py#L51) of `process_results.py`.

Note that search times may vary from the reported result owing to hardware setup.

## Plotting histograms

In order to plot the histograms in Figure 1 of the paper, run:
Reproduce all of the results by running

```bash
./scorehook.sh
```
python plot_histograms.py
```

to produce:



The code is licensed under the MIT licence.

## Acknowledgements

This repository makes liberal use of code from the [AutoDL](https://github.com/D-X-Y/AutoDL-Projects) library. We also rely on [NAS-Bench-201](https://github.com/D-X-Y/NAS-Bench-201).

## Citing us

If you use or build on our work, please consider citing us:
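The README above drives reproduction from the shell. As a rough illustrative sketch (not part of this commit), the same sweep of run counts could also be scripted from Python; this assumes only what the README states, namely that `reproduce.sh` takes the number of runs as its single argument:

```python
# Illustrative sketch only -- not part of this commit.
# Assumes ./reproduce.sh takes the number of runs as its argument, as shown in the README.
import subprocess

for n_runs in (3, 500):
    subprocess.run(['./reproduce.sh', str(n_runs)], check=True)
```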
autodl/__init__.py (new file, 1 line)
@@ -0,0 +1 @@

autodl/nas_201_api/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
from .api_utils import ArchResults, ResultsCount
from .api_201 import NASBench201API
from .api_301 import NASBench301API

# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
# NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
# NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16]
NAS_BENCH_201_API_VERSION="v2.0" # [2020.06.30]
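A minimal sketch of how this vendored package might be used once the NAS-Bench-201 file mentioned in the README has been downloaded (the local path is an assumption):

```python
# Illustrative sketch only -- assumes the benchmark .pth file was downloaded per the README.
from autodl.nas_201_api import NASBench201API

api = NASBench201API('NAS-Bench-201-v1_1-096897.pth', verbose=False)  # path is an assumption
print(len(api))   # number of architectures in the search space
api.show(0)       # print the recorded results of the 0-th architecture
```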
autodl/nas_201_api/api_201.py (new file, 274 lines)
@@ -0,0 +1,274 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# The history of benchmark files:
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
# [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 architectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.
#
# I'm still actively enhancing this benchmark. Please feel free to contact me if you have any question w.r.t. NAS-Bench-201.
#
import os, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict

from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names


ALL_BENCHMARK_FILES = ['NAS-Bench-201-v1_0-e61699.pth', 'NAS-Bench-201-v1_1-096897.pth']
ALL_ARCHIVE_DIRS = ['NAS-Bench-201-v1_1-archive']


def print_information(information, extra_info=None, show=False):
  dataset_names = information.get_dataset_names()
  strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
  def metric2str(loss, acc):
    return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)

  for ida, dataset in enumerate(dataset_names):
    metric = information.get_compute_costs(dataset)
    flop, param, latency = metric['flops'], metric['params'], metric['latency']
    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
    train_info = information.get_metrics(dataset, 'train')
    if dataset == 'cifar10-valid':
      valid_info = information.get_metrics(dataset, 'x-valid')
      str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
    elif dataset == 'cifar10':
      test__info = information.get_metrics(dataset, 'ori-test')
      str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
    else:
      valid_info = information.get_metrics(dataset, 'x-valid')
      test__info = information.get_metrics(dataset, 'x-test')
      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
    strings += [str1, str2]
  if show: print('\n'.join(strings))
  return strings


"""
This is the class for the API of NAS-Bench-201.
"""
class NASBench201API(NASBenchMetaAPI):

  """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
  def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None,
               verbose: bool=True):
    self.filename = None
    self.reset_time()
    if file_path_or_dict is None:
      file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
      print ('Try to use the default NAS-Bench-201 path from {:}.'.format(file_path_or_dict))
    if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
      file_path_or_dict = str(file_path_or_dict)
      if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
      assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
      self.filename = Path(file_path_or_dict).name
      file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
    elif isinstance(file_path_or_dict, dict):
      file_path_or_dict = copy.deepcopy(file_path_or_dict)
    else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
    assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
    self.verbose = verbose  # [TODO] a flag indicating whether to print more logs
    keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
    for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
    self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
    # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
    self.arch2infos_dict = OrderedDict()
    self._avaliable_hps = set(['12', '200'])
    for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
      all_info = file_path_or_dict['arch2infos'][xkey]
      hp2archres = OrderedDict()
      # self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
      # self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
      hp2archres['12'] = ArchResults.create_from_state_dict(all_info['less'])
      hp2archres['200'] = ArchResults.create_from_state_dict(all_info['full'])
      self.arch2infos_dict[xkey] = hp2archres
    self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
    self.archstr2index = {}
    for idx, arch in enumerate(self.meta_archs):
      assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
      self.archstr2index[ arch ] = idx

  def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the 'index'-th architecture in the search space.
    It will load its data from 'archive_root'.
    """
    if archive_root is None:
      archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
    assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
    if index is None:
      indexes = list(range(len(self)))
    else:
      indexes = [index]
    for idx in indexes:
      assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
      xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
      assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
      xdata = torch.load(xfile_path, map_location='cpu')
      assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
      hp2archres = OrderedDict()
      hp2archres['12'] = ArchResults.create_from_state_dict(xdata['less'])
      hp2archres['200'] = ArchResults.create_from_state_dict(xdata['full'])
      self.arch2infos_dict[idx] = hp2archres

  def query_info_str_by_arch(self, arch, hp: Text='12'):
    """ This function is used to query the information of a specific architecture
    'arch' can be an architecture index or an architecture string
    When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
    When hp=200, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/200E.config'
    The difference between these configurations is the number of training epochs.
    """
    if self.verbose:
      print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
    return self._query_info_str_by_arch(arch, hp, print_information)

  # obtain the metric for the `index`-th architecture
  # `dataset` indicates the dataset:
  #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set
  #   'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set
  #   'cifar100'       : using the proposed train set of CIFAR-100 as the training set
  #   'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
  # `iepoch` indicates the index of training epochs from 0 to 11/199.
  #   When iepoch=None, it will return the metric for the last training epoch
  #   When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
  # `hp` indicates different hyper-parameters for training
  #   When hp=12, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
  #   When hp=200, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
  # `is_random`
  #   When is_random=True, the performance of a randomly selected trial will be returned
  #   When is_random=False, the performance of all trials will be averaged.
  def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
    if self.verbose:
      print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
    index = self.query_index_by_arch(index)  # in case the input is an architecture string or an arch object rather than an index
    if index not in self.arch2infos_dict:
      raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
    archresult = self.arch2infos_dict[index][str(hp)]
    # if randomly selecting one trial, pick its seed first
    if isinstance(is_random, bool) and is_random:
      seeds = archresult.get_dataset_seeds(dataset)
      is_random = random.choice(seeds)
    # collect the training information
    train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
    total = train_info['iepoch'] + 1
    xinfo = {'train-loss'    : train_info['loss'],
             'train-accuracy': train_info['accuracy'],
             'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None,
             'train-all-time': train_info['all_time']}
    # collect the evaluation information
    if dataset == 'cifar10-valid':
      valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      try:
        test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
      except:
        test_info = None
      valtest_info = None
    else:
      try: # collect results on the proposed test set
        if dataset == 'cifar10':
          test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
      except:
        test_info = None
      try: # collect results on the proposed validation set
        valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      except:
        valid_info = None
      try:
        if dataset != 'cifar10':
          valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          valtest_info = None
      except:
        valtest_info = None
    if valid_info is not None:
      xinfo['valid-loss'] = valid_info['loss']
      xinfo['valid-accuracy'] = valid_info['accuracy']
      xinfo['valid-per-time'] = valid_info['all_time'] / total if valid_info['all_time'] is not None else None
      xinfo['valid-all-time'] = valid_info['all_time']
    if test_info is not None:
      xinfo['test-loss'] = test_info['loss']
      xinfo['test-accuracy'] = test_info['accuracy']
      xinfo['test-per-time'] = test_info['all_time'] / total if test_info['all_time'] is not None else None
      xinfo['test-all-time'] = test_info['all_time']
    if valtest_info is not None:
      xinfo['valtest-loss'] = valtest_info['loss']
      xinfo['valtest-accuracy'] = valtest_info['accuracy']
      xinfo['valtest-per-time'] = valtest_info['all_time'] / total if valtest_info['all_time'] is not None else None
      xinfo['valtest-all-time'] = valtest_info['all_time']
    return xinfo
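For reference, a small sketch (not part of this file) of how `get_more_info` might be called on an already loaded `api` instance; the dictionary keys follow the code above:

```python
# Illustrative sketch only -- `api` is assumed to be a loaded NASBench201API instance.
info = api.get_more_info(0, 'cifar100', hp='200', is_random=False)  # average over all trials
print(info['train-accuracy'], info['valid-accuracy'], info['test-accuracy'])
```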
  def show(self, index: int = -1) -> None:
    """This function will print the information of a specific (or all) architecture(s)."""
    self._show(index, print_information)

  @staticmethod
  def str2lists(arch_str: Text) -> List[tuple]:
    """
    This function shows how to read the string-based architecture encoding.
    It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`

    :param
      arch_str: the input is a string indicates the architecture topology, such as
                |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
    :return: a list of tuple, contains multiple (op, input_node_index) pairs.

    :usage
      arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
      print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
      for i, node in enumerate(arch):
        print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
    """
    node_strs = arch_str.split('+')
    genotypes = []
    for i, node_str in enumerate(node_strs):
      inputs = list(filter(lambda x: x != '', node_str.split('|')))
      for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      inputs = ( xi.split('~') for xi in inputs )
      input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
      genotypes.append( input_infos )
    return genotypes
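The usage shown in the docstring above can be run directly against the static method, e.g. (sketch, not part of this file):

```python
# Illustrative sketch only.
arch = NASBench201API.str2lists('|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|')
print('there are {:} nodes in this arch'.format(len(arch) + 1))
for i, node in enumerate(arch):
    print('node {:} is the sum of {:} inputs: {:}'.format(i + 1, len(node), node))
```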
  @staticmethod
  def str2matrix(arch_str: Text,
                 search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
    """
    This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.

    :param
      arch_str: the input is a string indicates the architecture topology, such as
                |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
      search_space: a list of operation string, the default list is the search space for NAS-Bench-201
        the default value should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
    :return
      the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
    :usage
      matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
      This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
         [ [0, 0, 0, 0],  # the first line represents the input (0-th) node
           [2, 0, 0, 0],  # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
           [0, 0, 0, 0],  # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
           [0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
      In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect',
         2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
    :(NOTE)
      If a node has two input-edges from the same node, this function does not work. One edge will be overlapped.
    """
    node_strs = arch_str.split('+')
    num_nodes = len(node_strs) + 1
    matrix = np.zeros((num_nodes, num_nodes))
    for i, node_str in enumerate(node_strs):
      inputs = list(filter(lambda x: x != '', node_str.split('|')))
      for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      for xi in inputs:
        op, idx = xi.split('~')
        if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
        op_idx, node_idx = search_space.index(op), int(idx)
        matrix[i+1, node_idx] = op_idx
    return matrix
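Mirroring the `:usage` section of `str2matrix` above, a runnable sketch (not part of this file):

```python
# Illustrative sketch only.
matrix = NASBench201API.str2matrix('|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|')
print(matrix)  # a 4x4 lower-triangular matrix of operation indices
```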
autodl/nas_201_api/api_301.py (new file, 222 lines)
@@ -0,0 +1,222 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
############################################################################################
# NAS-Bench-301, coming soon.
############################################################################################
# The history of benchmark files:
# [2020.06.30] NAS-Bench-301-v1_0
#
import os, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names


ALL_BENCHMARK_FILES = ['NAS-Bench-301-v1_0-363be7.pth']
ALL_ARCHIVE_DIRS = ['NAS-Bench-301-v1_0-archive']


def print_information(information, extra_info=None, show=False):
  dataset_names = information.get_dataset_names()
  strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
  def metric2str(loss, acc):
    return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc)

  for ida, dataset in enumerate(dataset_names):
    metric = information.get_compute_costs(dataset)
    flop, param, latency = metric['flops'], metric['params'], metric['latency']
    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
    train_info = information.get_metrics(dataset, 'train')
    if dataset == 'cifar10-valid':
      valid_info = information.get_metrics(dataset, 'x-valid')
      test__info = information.get_metrics(dataset, 'ori-test')
      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(
        dataset, metric2str(train_info['loss'], train_info['accuracy']),
        metric2str(valid_info['loss'], valid_info['accuracy']),
        metric2str(test__info['loss'], test__info['accuracy']))
    elif dataset == 'cifar10':
      test__info = information.get_metrics(dataset, 'ori-test')
      str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
    else:
      valid_info = information.get_metrics(dataset, 'x-valid')
      test__info = information.get_metrics(dataset, 'x-test')
      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
    strings += [str1, str2]
  if show: print('\n'.join(strings))
  return strings


"""
This is the class for the API of NAS-Bench-301.
"""
class NASBench301API(NASBenchMetaAPI):

  """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
  def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
    self.filename = None
    self.reset_time()
    if file_path_or_dict is None:
      file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
      print ('Try to use the default NAS-Bench-301 path from {:}.'.format(file_path_or_dict))
    if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
      file_path_or_dict = str(file_path_or_dict)
      if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
      assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
      self.filename = Path(file_path_or_dict).name
      file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
    elif isinstance(file_path_or_dict, dict):
      file_path_or_dict = copy.deepcopy( file_path_or_dict )
    else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
    assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
    self.verbose = verbose  # [TODO] a flag indicating whether to print more logs
    keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
    for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
    self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
    # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
    self.arch2infos_dict = OrderedDict()
    self._avaliable_hps = set()
    for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
      all_infos = file_path_or_dict['arch2infos'][xkey]
      hp2archres = OrderedDict()
      for hp_key, results in all_infos.items():
        hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
        self._avaliable_hps.add(hp_key)  # save the available hyper-parameter
      self.arch2infos_dict[xkey] = hp2archres
    self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
    self.archstr2index = {}
    for idx, arch in enumerate(self.meta_archs):
      assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
      self.archstr2index[ arch ] = idx
    if self.verbose:
      print('Create NAS-Bench-301 done with {:}/{:} architectures avaliable.'.format(len(self.evaluated_indexes), len(self.meta_archs)))

  def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
    If index is None, overwrite all ckps.
    """
    if self.verbose:
      print('Call clear_params with archive_root={:} and index={:}'.format(archive_root, index))
    if archive_root is None:
      archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
    assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
    if index is None:
      indexes = list(range(len(self)))
    else:
      indexes = [index]
    for idx in indexes:
      assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
      xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
      if not os.path.isfile(xfile_path):
        xfile_path = os.path.join(archive_root, '{:d}-FULL.pth'.format(idx))
      assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
      xdata = torch.load(xfile_path, map_location='cpu')
      assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)

      hp2archres = OrderedDict()
      for hp_key, results in xdata.items():
        hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
      self.arch2infos_dict[idx] = hp2archres

  def query_info_str_by_arch(self, arch, hp: Text='12'):
    """ This function is used to query the information of a specific architecture
    'arch' can be an architecture index or an architecture string
    When hp=01, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/01E.config'
    When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
    When hp=90, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/90E.config'
    The difference between these three configurations is the number of training epochs.
    """
    if self.verbose:
      print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
    return self._query_info_str_by_arch(arch, hp, print_information)

  def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
    """This function will return the metric for the `index`-th architecture
    `dataset` indicates the dataset:
      'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set
      'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set
      'cifar100'       : using the proposed train set of CIFAR-100 as the training set
      'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
    `iepoch` indicates the index of training epochs from 0 to 11/199.
      When iepoch=None, it will return the metric for the last training epoch
      When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
    `hp` indicates different hyper-parameters for training
      When hp=01, it trains the network with 01 epochs and the LR decayed from 0.1 to 0 within 01 epochs
      When hp=12, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
      When hp=90, it trains the network with 90 epochs and the LR decayed from 0.1 to 0 within 90 epochs
    `is_random`
      When is_random=True, the performance of a randomly selected trial will be returned
      When is_random=False, the performance of all trials will be averaged.
    """
    if self.verbose:
      print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
    index = self.query_index_by_arch(index)  # in case the input is an architecture string or an arch object rather than an index
    if index not in self.arch2infos_dict:
      raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
    archresult = self.arch2infos_dict[index][str(hp)]
    # if randomly selecting one trial, pick its seed first
    if isinstance(is_random, bool) and is_random:
      seeds = archresult.get_dataset_seeds(dataset)
      is_random = random.choice(seeds)
    # collect the training information
    train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
    total = train_info['iepoch'] + 1
    xinfo = {'train-loss'    : train_info['loss'],
             'train-accuracy': train_info['accuracy'],
             'train-per-time': train_info['all_time'] / total,
             'train-all-time': train_info['all_time']}
    # collect the evaluation information
    if dataset == 'cifar10-valid':
      valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      try:
        test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
      except:
        test_info = None
      valtest_info = None
    else:
      try: # collect results on the proposed test set
        if dataset == 'cifar10':
          test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
      except:
        test_info = None
      try: # collect results on the proposed validation set
        valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      except:
        valid_info = None
      try:
        if dataset != 'cifar10':
          valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          valtest_info = None
      except:
        valtest_info = None
    if valid_info is not None:
      xinfo['valid-loss'] = valid_info['loss']
      xinfo['valid-accuracy'] = valid_info['accuracy']
      xinfo['valid-per-time'] = valid_info['all_time'] / total
      xinfo['valid-all-time'] = valid_info['all_time']
    if test_info is not None:
      xinfo['test-loss'] = test_info['loss']
      xinfo['test-accuracy'] = test_info['accuracy']
      xinfo['test-per-time'] = test_info['all_time'] / total
      xinfo['test-all-time'] = test_info['all_time']
    if valtest_info is not None:
      xinfo['valtest-loss'] = valtest_info['loss']
      xinfo['valtest-accuracy'] = valtest_info['accuracy']
      xinfo['valtest-per-time'] = valtest_info['all_time'] / total
      xinfo['valtest-all-time'] = valtest_info['all_time']
    return xinfo

  def show(self, index: int = -1) -> None:
    """
    This function will print the information of a specific (or all) architecture(s).

    :param index: If the index < 0: it will loop for all architectures and print their information one by one.
                  else: it will print the information of the 'index'-th architecture.
    :return: nothing
    """
    self._show(index, print_information)
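NASBench301API mirrors the 201 API but exposes the 01/12/90-epoch hyper-parameter settings. A sketch (not part of this file); the file name comes from ALL_BENCHMARK_FILES above and may not be publicly available yet, since the header marks the benchmark as "coming soon":

```python
# Illustrative sketch only -- the benchmark file name is taken from ALL_BENCHMARK_FILES above
# and its availability is an assumption ("NAS-Bench-301, coming soon").
from autodl.nas_201_api import NASBench301API

api301 = NASBench301API('NAS-Bench-301-v1_0-363be7.pth', verbose=False)
info = api301.get_more_info(0, 'cifar10', hp='90', is_random=False)
```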
autodl/nas_201_api/api_utils.py (new file, 750 lines)
@@ -0,0 +1,750 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# In this Python file, we define NASBenchMetaAPI, the abstract class for benchmark APIs.
# We also define the class ArchResults, which contains all information of a single architecture trained by one kind of hyper-parameters on three datasets.
# We also define the class ResultsCount, which contains all information of a single trial for a single architecture.
############################################################################################
# History:
# [2020.06.30] The first version.
#
import os, abc, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict


def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
  """re-map the metric_on_set to internal keys"""
  if verbose:
    print('Call internal function _remap_dataset_set_names with dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
  if dataset == 'cifar10' and metric_on_set == 'valid':
    dataset, metric_on_set = 'cifar10-valid', 'x-valid'
  elif dataset == 'cifar10' and metric_on_set == 'test':
    dataset, metric_on_set = 'cifar10', 'ori-test'
  elif dataset == 'cifar10' and metric_on_set == 'train':
    dataset, metric_on_set = 'cifar10', 'train'
  elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'valid':
    metric_on_set = 'x-valid'
  elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'test':
    metric_on_set = 'x-test'
  if verbose:
    print('  return dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
  return dataset, metric_on_set
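For reference, a few concrete mappings produced by the function above (sketch, not part of this file):

```python
# Illustrative sketch only -- traces remap_dataset_set_names as defined above.
assert remap_dataset_set_names('cifar10', 'valid') == ('cifar10-valid', 'x-valid')
assert remap_dataset_set_names('cifar10', 'test') == ('cifar10', 'ori-test')
assert remap_dataset_set_names('cifar100', 'test') == ('cifar100', 'x-test')
assert remap_dataset_set_names('ImageNet16-120', 'valid') == ('ImageNet16-120', 'x-valid')
```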
class NASBenchMetaAPI(metaclass=abc.ABCMeta):

  @abc.abstractmethod
  def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
    """The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""

  def __getitem__(self, index: int):
    return copy.deepcopy(self.meta_archs[index])

  def arch(self, index: int):
    """Return the topology structure of the `index`-th architecture."""
    if self.verbose:
      print('Call the arch function with index={:}'.format(index))
    assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
    return copy.deepcopy(self.meta_archs[index])

  def __len__(self):
    return len(self.meta_archs)

  def __repr__(self):
    return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))

  @property
  def avaliable_hps(self):
    return list(copy.deepcopy(self._avaliable_hps))

  @property
  def used_time(self):
    return self._used_time

  def reset_time(self):
    self._used_time = 0

  def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True):
    index = self.query_index_by_arch(arch)
    all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
    assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
    if dataset == 'cifar10':
      info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
    else:
      info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True)
    valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
    latency = self.get_latency(index, dataset)
    if account_time:
      self._used_time += time_cost
    return valid_acc, latency, time_cost, self._used_time
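As a sketch (not part of this file), `random()` defined just below and `simulate_train_eval` above can be combined into a simulated random-search loop with a time budget:

```python
# Illustrative sketch only -- `api` is assumed to be a loaded benchmark API instance,
# and the 12000-second budget is an arbitrary illustrative value.
api.reset_time()
best_acc, best_arch = -1.0, None
while api.used_time < 12000:
    index = api.random()
    valid_acc, latency, time_cost, total_time = api.simulate_train_eval(index, 'cifar10', hp='12')
    if valid_acc > best_acc:
        best_acc, best_arch = valid_acc, index
print(best_arch, best_acc)
```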
  def random(self):
    """Return a random index of all architectures."""
    return random.randint(0, len(self.meta_archs)-1)

  def query_index_by_arch(self, arch):
    """ This function is used to query the index of an architecture in the search space.
    In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
    or an instance that has the 'tostr' function that can generate the architecture string;
    or it is directly an architecture index, in this case, we will check whether it is valid or not.
    This function will return the index.
    If return -1, it means this architecture is not in the search space.
    Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
    """
    if self.verbose:
      print('Call query_index_by_arch with arch={:}'.format(arch))
    if isinstance(arch, int):
      if 0 <= arch < len(self):
        return arch
      else:
        raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format(arch, 0, len(self)))
    elif isinstance(arch, str):
      if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
      else                         : arch_index = -1
    elif hasattr(arch, 'tostr'):
      if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
      else                                  : arch_index = -1
    else: arch_index = -1
    return arch_index

  def query_by_arch(self, arch, hp):
    # This is to make the current version be compatible with the old version.
    return self.query_info_str_by_arch(arch, hp)

  @abc.abstractmethod
  def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
    If index is None, overwrite all ckps.
    """

  def clear_params(self, index: int, hp: Optional[Text]=None):
    """Remove the architecture's weights to save memory.
    :arg
      index: the index of the target architecture
      hp: a flag to control how to clear the parameters.
        -- None: clear all the weights in '01'/'12'/'90', which indicates the number of training epochs.
        -- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
    """
    if self.verbose:
      print('Call clear_params with index={:} and hp={:}'.format(index, hp))
    if hp is None:
      for key, result in self.arch2infos_dict[index].items():
        result.clear_params()
    else:
      if str(hp) not in self.arch2infos_dict[index]:
        raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[index].keys()), hp))
      self.arch2infos_dict[index][str(hp)].clear_params()

  @abc.abstractmethod
  def query_info_str_by_arch(self, arch, hp: Text='12'):
    """This function is used to query the information of a specific architecture."""

  def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None):
    arch_index = self.query_index_by_arch(arch)
    if arch_index in self.arch2infos_dict:
      if hp not in self.arch2infos_dict[arch_index]:
        raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
      info = self.arch2infos_dict[arch_index][hp]
      strings = print_information(info, 'arch-index={:}'.format(arch_index))
      return '\n'.join(strings)
    else:
      print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
      return None

  def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
    """Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index."""
    if self.verbose:
      print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
    if arch_index in self.arch2infos_dict:
      if hp not in self.arch2infos_dict[arch_index]:
        raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
      info = self.arch2infos_dict[arch_index][hp]
    else:
      raise ValueError('arch_index [{:}] does not in arch2infos'.format(arch_index))
    return copy.deepcopy(info)

  def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = '12'):
    """ This 'query_by_index' function is used to query information with the training of 01 epochs, 12 epochs, 90 epochs, or 200 epochs.
    ------
    If hp=01, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/01E.config)
    If hp=12, we train the model by 12 epochs (see config in configs/nas-benchmark/hyper-opts/12E.config)
    If hp=90, we train the model by 90 epochs (see config in configs/nas-benchmark/hyper-opts/90E.config)
    If hp=200, we train the model by 200 epochs (see config in configs/nas-benchmark/hyper-opts/200E.config)
    ------
    If dataname is None, return the ArchResults
    else, return a dict with all trials on that dataset (the key is the seed)
    Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
    -- cifar10-valid : training the model on the CIFAR-10 training set.
    -- cifar10 : training the model on the CIFAR-10 training + validation set.
    -- cifar100 : training the model on the CIFAR-100 training set.
    -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
    """
    if self.verbose:
      print('Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(arch_index, dataname, hp))
    info = self.query_meta_info_by_index(arch_index, hp)
    if dataname is None: return info
    else:
      if dataname not in info.get_dataset_names():
        raise ValueError('invalid dataset-name : {:} vs. {:}'.format(dataname, info.get_dataset_names()))
      return info.query(dataname)

  def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
    """Find the architecture with the highest accuracy based on some constraints."""
    if self.verbose:
      print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
    dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
    best_index, highest_accuracy = -1, None
    for i, arch_index in enumerate(self.evaluated_indexes):
      arch_info = self.arch2infos_dict[arch_index][hp]
      info = arch_info.get_compute_costs(dataset)  # the information of costs
      flop, param, latency = info['flops'], info['params'], info['latency']
      if FLOP_max  is not None and flop  > FLOP_max : continue
      if Param_max is not None and param > Param_max: continue
      xinfo = arch_info.get_metrics(dataset, metric_on_set)  # the information of loss and accuracy
      loss, accuracy = xinfo['loss'], xinfo['accuracy']
      if best_index == -1:
        best_index, highest_accuracy = arch_index, accuracy
      elif highest_accuracy < accuracy:
        best_index, highest_accuracy = arch_index, accuracy
    if self.verbose:
      print('  the best architecture : [{:}] {:} with accuracy={:.3f}%'.format(best_index, self.arch(best_index), highest_accuracy))
    return best_index, highest_accuracy
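A sketch of `find_best` in use (not part of this file); the FLOP and parameter thresholds below are arbitrary illustrative values in the same MFLOP/MB units that `print_information` reports:

```python
# Illustrative sketch only -- `api` is assumed to be a loaded benchmark API instance.
best_index, best_acc = api.find_best('cifar100', 'test', FLOP_max=100, Param_max=1.0, hp='12')
print(api.arch(best_index), best_acc)
```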
def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = '12'):
|
||||
"""
|
||||
This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
|
||||
Args [seed]:
|
||||
-- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
|
||||
-- a interger : return the weights of a specific trial, whose seed is this interger.
|
||||
Args [hp]:
|
||||
-- 01 : train the model by 01 epochs
|
||||
-- 12 : train the model by 12 epochs
|
||||
-- 90 : train the model by 90 epochs
|
||||
-- 200 : train the model by 200 epochs
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(index, dataset, seed, hp))
|
||||
info = self.query_meta_info_by_index(index, hp)
|
||||
return info.get_net_param(dataset, seed)
|
||||
|
||||
def get_net_config(self, index: int, dataset: Text):
|
||||
"""
|
||||
This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
|
||||
Args [dataset] (4 possible options):
|
||||
-- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||
-- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
-- cifar100 : training the model on the CIFAR-100 training set.
|
||||
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
This function will return a dict.
|
||||
========= Some examlpes for using this function:
|
||||
config = api.get_net_config(128, 'cifar10')
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
|
||||
if index in self.arch2infos_dict:
|
||||
info = self.arch2infos_dict[index]
|
||||
else:
|
||||
raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(arch_index))
|
||||
info = next(iter(info.values()))
|
||||
results = info.query(dataset, None)
|
||||
results = next(iter(results.values()))
|
||||
return results.get_config(None)
|
||||
|
||||
def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
|
||||
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
|
||||
if self.verbose:
|
||||
print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
|
||||
info = self.query_meta_info_by_index(index, hp)
|
||||
return info.get_compute_costs(dataset)
|
||||
|
||||
def get_latency(self, index: int, dataset: Text, hp: Text = '12') -> float:
|
||||
"""
|
||||
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
|
||||
:param index: the index of the target architecture
|
||||
:param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
|
||||
:return: return a float value in seconds
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
|
||||
cost_dict = self.get_cost_info(index, dataset, hp)
|
||||
return cost_dict['latency']
|
||||
|
||||
@abc.abstractmethod
|
||||
def show(self, index=-1):
|
||||
"""This function will print the information of a specific (or all) architecture(s)."""
|
||||
|
||||
def _show(self, index=-1, print_information=None) -> None:
|
||||
"""
|
||||
This function will print the information of a specific (or all) architecture(s).
|
||||
|
||||
:param index: If the index < 0: it will loop for all architectures and print their information one by one.
|
||||
else: it will print the information of the 'index'-th architecture.
|
||||
:return: nothing
|
||||
"""
|
||||
if index < 0: # show all architectures
|
||||
print(self)
|
||||
for i, idx in enumerate(self.evaluated_indexes):
|
||||
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
|
||||
print('arch : {:}'.format(self.meta_archs[idx]))
|
||||
for key, result in self.arch2infos_dict[index].items():
|
||||
strings = print_information(result)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
else:
|
||||
if 0 <= index < len(self.meta_archs):
|
||||
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
|
||||
else:
|
||||
arch_info = self.arch2infos_dict[index]
|
||||
for key, result in self.arch2infos_dict[index].items():
|
||||
strings = print_information(result)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
else:
|
||||
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
||||
|
||||
def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]:
|
||||
"""This function will count the number of total trials."""
|
||||
if self.verbose:
|
||||
print('Call the statistics function with dataset={:} and hp={:}.'.format(dataset, hp))
|
||||
valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
|
||||
if dataset not in valid_datasets:
|
||||
raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
|
||||
nums, hp = defaultdict(lambda: 0), str(hp)
|
||||
for index in range(len(self)):
|
||||
archInfo = self.arch2infos_dict[index][hp]
|
||||
dataset_seed = archInfo.dataset_seed
|
||||
if dataset not in dataset_seed:
|
||||
nums[0] += 1
|
||||
else:
|
||||
nums[len(dataset_seed[dataset])] += 1
|
||||
return dict(nums)
|
||||
|
||||
|
||||
class ArchResults(object):
|
||||
|
||||
def __init__(self, arch_index, arch_str):
|
||||
self.arch_index = int(arch_index)
|
||||
self.arch_str = copy.deepcopy(arch_str)
|
||||
self.all_results = dict()
|
||||
self.dataset_seed = dict()
|
||||
self.clear_net_done = False
|
||||
|
||||
def get_compute_costs(self, dataset):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
|
||||
flops = [result.flop for result in results]
|
||||
params = [result.params for result in results]
|
||||
latencies = [result.get_latency() for result in results]
|
||||
latencies = [x for x in latencies if x > 0]
|
||||
mean_latency = np.mean(latencies) if len(latencies) > 0 else None
|
||||
time_infos = defaultdict(list)
|
||||
for result in results:
|
||||
time_info = result.get_times()
|
||||
for key, value in time_info.items(): time_infos[key].append( value )
|
||||
|
||||
info = {'flops' : np.mean(flops),
|
||||
'params' : np.mean(params),
|
||||
'latency': mean_latency}
|
||||
for key, value in time_infos.items():
|
||||
if len(value) > 0 and value[0] is not None:
|
||||
info[key] = np.mean(value)
|
||||
else: info[key] = None
|
||||
return info
|
||||
|
||||
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
|
||||
"""
|
||||
This `get_metrics` function is used to obtain obtain the loss, accuracy, etc information on a specific dataset.
|
||||
If not specify, each set refer to the proposed split in NAS-Bench-201 paper.
|
||||
If some args return None or raise error, then it is not avaliable.
|
||||
========================================
|
||||
Args [dataset] (4 possible options):
|
||||
-- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||
-- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
-- cifar100 : training the model on the CIFAR-100 training set.
|
||||
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
Args [setname] (each dataset has different setnames):
|
||||
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
|
||||
------ 'train' : the metric on the training set.
|
||||
------ 'x-valid' : the metric on the validation set.
|
||||
------ 'ori-test' : the metric on the test set.
|
||||
-- When dataset = cifar10, you can use 'train', 'ori-test'.
|
||||
------ 'train' : the metric on the training + validation set.
|
||||
------ 'ori-test' : the metric on the test set.
|
||||
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
|
||||
------ 'train' : the metric on the training set.
|
||||
------ 'x-valid' : the metric on the validation set.
|
||||
------ 'x-test' : the metric on the test set.
|
||||
------ 'ori-test' : the metric on the validation + test set.
|
||||
Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
|
||||
------ None : return the metric after the last training epoch.
|
||||
------ an integer i : return the metric after the i-th training epoch.
|
||||
Args [is_random]:
|
||||
------ True : return the metric of a randomly selected trial.
|
||||
------ False : return the averaged metric of all avaliable trials.
|
||||
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
|
||||
"""
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
infos = defaultdict(list)
|
||||
for result in results:
|
||||
if setname == 'train':
|
||||
info = result.get_train(iepoch)
|
||||
else:
|
||||
info = result.get_eval(setname, iepoch)
|
||||
for key, value in info.items(): infos[key].append( value )
|
||||
return_info = dict()
|
||||
if isinstance(is_random, bool) and is_random: # randomly select one
|
||||
index = random.randint(0, len(results)-1)
|
||||
for key, value in infos.items(): return_info[key] = value[index]
|
||||
elif isinstance(is_random, bool) and not is_random: # average
|
||||
for key, value in infos.items():
|
||||
if len(value) > 0 and value[0] is not None:
|
||||
return_info[key] = np.mean(value)
|
||||
else: return_info[key] = None
|
||||
elif isinstance(is_random, int): # specify the seed
|
||||
if is_random not in x_seeds: raise ValueError('cannot find random seed ({:}) from {:}'.format(is_random, x_seeds))
|
||||
index = x_seeds.index(is_random)
|
||||
for key, value in infos.items(): return_info[key] = value[index]
|
||||
else:
|
||||
raise ValueError('invalid value for is_random: {:}'.format(is_random))
|
||||
return return_info
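A hedged usage sketch of `get_metrics`; it assumes `arch_result` is an `ArchResults` instance already loaded from the benchmark file, and the seed value 777 is only an assumption that must match a seed stored for this architecture:

```python
# Average over all available trials (is_random=False) at the last epoch (iepoch=None).
metrics = arch_result.get_metrics('cifar10-valid', 'x-valid', iepoch=None, is_random=False)
print('validation accuracy: {:.2f}%, loss: {:.4f}'.format(metrics['accuracy'], metrics['loss']))

# Pick one specific trial by its random seed.
single = arch_result.get_metrics('cifar10-valid', 'x-valid', iepoch=None, is_random=777)
```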
|
||||
|
||||
def show(self, is_print=False):
|
||||
return print_information(self, None, is_print)
|
||||
|
||||
def get_dataset_names(self):
|
||||
return list(self.dataset_seed.keys())
|
||||
|
||||
def get_dataset_seeds(self, dataset):
|
||||
return copy.deepcopy( self.dataset_seed[dataset] )
|
||||
|
||||
def get_net_param(self, dataset: Text, seed: Union[None, int] =None):
|
||||
"""
|
||||
This function will return the trained network's weights on the 'dataset'.
|
||||
:arg
|
||||
dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
|
||||
seed: an integer indicates the seed value or None that indicates returing all trials.
|
||||
"""
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
|
||||
else:
|
||||
xkey = (dataset, seed)
|
||||
if xkey in self.all_results:
|
||||
return self.all_results[xkey].get_net_param()
|
||||
else:
|
||||
raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys())))
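A short sketch of how `get_net_param` is meant to be used, assuming `arch_result` is an `ArchResults` instance and 777 is one of its stored seeds:

```python
params_per_seed = arch_result.get_net_param('cifar10')          # dict: seed -> state_dict
one_state_dict = arch_result.get_net_param('cifar10', seed=777)
# A model built from the matching config could then load these weights, e.g.:
# net.load_state_dict(one_state_dict)   # `net` is a hypothetical compatible model
```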
|
||||
|
||||
def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
|
||||
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
|
||||
if seed is None:
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
self.all_results[(dataset, seed)].update_latency([latency])
|
||||
else:
|
||||
self.all_results[(dataset, seed)].update_latency([latency])
|
||||
|
||||
def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
|
||||
"""This function is used to reset the train-times in all corresponding ResultsCount(s)."""
|
||||
if seed is None:
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
|
||||
else:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
|
||||
|
||||
def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
|
||||
"""This function is used to reset the eval-times in all corresponding ResultsCount(s)."""
|
||||
if seed is None:
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
|
||||
else:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
|
||||
|
||||
def get_latency(self, dataset: Text) -> float:
|
||||
"""Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]"""
|
||||
latencies = []
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
latency = self.all_results[(dataset, seed)].get_latency()
|
||||
if not isinstance(latency, float) or latency <= 0:
|
||||
raise ValueError('invalid latency of {:} with seed={:} : {:}'.format(dataset, seed, latency))
|
||||
latencies.append(latency)
|
||||
return sum(latencies) / len(latencies)
|
||||
|
||||
def get_total_epoch(self, dataset=None):
|
||||
"""Return the total number of training epochs."""
|
||||
if dataset is None:
|
||||
epochss = []
|
||||
for xdata, x_seeds in self.dataset_seed.items():
|
||||
epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
|
||||
elif isinstance(dataset, str):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
|
||||
else:
|
||||
raise ValueError('invalid dataset={:}'.format(dataset))
|
||||
if len(set(epochss)) > 1: raise ValueError('Each trial must have the same number of training epochs : {:}'.format(epochss))
|
||||
return epochss[-1]
|
||||
|
||||
def query(self, dataset, seed=None):
|
||||
"""Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'"""
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
|
||||
else:
|
||||
return self.all_results[(dataset, seed)]
|
||||
|
||||
def arch_idx_str(self):
|
||||
return '{:06d}'.format(self.arch_index)
|
||||
|
||||
def update(self, dataset_name, seed, result):
|
||||
if dataset_name not in self.dataset_seed:
|
||||
self.dataset_seed[dataset_name] = []
|
||||
assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
|
||||
self.dataset_seed[ dataset_name ].append( seed )
|
||||
self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
|
||||
assert (dataset_name, seed) not in self.all_results
|
||||
self.all_results[ (dataset_name, seed) ] = result
|
||||
self.clear_net_done = False
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = dict()
|
||||
for key, value in self.__dict__.items():
|
||||
if key == 'all_results': # contain the class of ResultsCount
|
||||
xvalue = dict()
|
||||
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
|
||||
for _k, _v in value.items():
|
||||
assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
|
||||
xvalue[_k] = _v.state_dict()
|
||||
else:
|
||||
xvalue = value
|
||||
state_dict[key] = xvalue
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
new_state_dict = dict()
|
||||
for key, value in state_dict.items():
|
||||
if key == 'all_results': # to convert to the class of ResultsCount
|
||||
xvalue = dict()
|
||||
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
|
||||
for _k, _v in value.items():
|
||||
xvalue[_k] = ResultsCount.create_from_state_dict(_v)
|
||||
else: xvalue = value
|
||||
new_state_dict[key] = xvalue
|
||||
self.__dict__.update(new_state_dict)
|
||||
|
||||
@staticmethod
|
||||
def create_from_state_dict(state_dict_or_file):
|
||||
x = ArchResults(-1, -1)
|
||||
if isinstance(state_dict_or_file, str): # a file path
|
||||
state_dict = torch.load(state_dict_or_file, map_location='cpu')
|
||||
elif isinstance(state_dict_or_file, dict):
|
||||
state_dict = state_dict_or_file
|
||||
else:
|
||||
raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
|
||||
x.load_state_dict(state_dict)
|
||||
return x
|
||||
|
||||
# This function is used to clear the weights saved in each 'result'
|
||||
# This can help reduce the memory footprint.
|
||||
def clear_params(self):
|
||||
for key, result in self.all_results.items():
|
||||
del result.net_state_dict
|
||||
result.net_state_dict = None
|
||||
self.clear_net_done = True
|
||||
|
||||
def debug_test(self):
|
||||
"""This function is used for me to debug and test, which will call most methods."""
|
||||
all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
|
||||
for dataset in all_dataset:
|
||||
print('---->>>> {:}'.format(dataset))
|
||||
print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset)))
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
result = self.all_results[(dataset, seed)]
|
||||
print(' ==>> result = {:}'.format(result))
|
||||
print(' ==>> cost = {:}'.format(result.get_times()))
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
|
||||
|
||||
|
||||
"""
|
||||
This class (ResultsCount) is used to save the information of one trial for a single architecture.
|
||||
I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called.
|
||||
If you have any question regarding this class, please open an issue or email me.
|
||||
"""
|
||||
class ResultsCount(object):
|
||||
|
||||
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
|
||||
self.name = name
|
||||
self.net_state_dict = state_dict
|
||||
self.train_acc1es = copy.deepcopy(train_accs)
|
||||
self.train_acc5es = None
|
||||
self.train_losses = copy.deepcopy(train_losses)
|
||||
self.train_times = None
|
||||
self.arch_config = copy.deepcopy(arch_config)
|
||||
self.params = params
|
||||
self.flop = flop
|
||||
self.seed = seed
|
||||
self.epochs = epochs
|
||||
self.latency = latency
|
||||
# evaluation results
|
||||
self.reset_eval()
|
||||
|
||||
def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
|
||||
self.train_acc1es = train_acc1es
|
||||
self.train_acc5es = train_acc5es
|
||||
self.train_losses = train_losses
|
||||
self.train_times = train_times
|
||||
|
||||
def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
|
||||
"""Assign the training times."""
|
||||
train_times = OrderedDict()
|
||||
for i in range(self.epochs):
|
||||
train_times[i] = estimated_per_epoch_time
|
||||
self.train_times = train_times
|
||||
|
||||
def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
|
||||
"""Assign the evaluation times."""
|
||||
if eval_name not in self.eval_names: raise ValueError('invalid eval name : {:}'.format(eval_name))
|
||||
for i in range(self.epochs):
|
||||
self.eval_times['{:}@{:}'.format(eval_name,i)] = estimated_per_epoch_time
|
||||
|
||||
def reset_eval(self):
|
||||
self.eval_names = []
|
||||
self.eval_acc1es = {}
|
||||
self.eval_times = {}
|
||||
self.eval_losses = {}
|
||||
|
||||
def update_latency(self, latency):
|
||||
self.latency = copy.deepcopy( latency )
|
||||
|
||||
def get_latency(self) -> float:
|
||||
"""Return the latency value in seconds. -1 represents not avaliable ; otherwise it should be a float value"""
|
||||
if self.latency is None: return -1.0
|
||||
else: return sum(self.latency) / len(self.latency)
|
||||
|
||||
def update_eval(self, accs, losses, times): # new version
|
||||
data_names = set([x.split('@')[0] for x in accs.keys()])
|
||||
for data_name in data_names:
|
||||
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
|
||||
self.eval_names.append( data_name )
|
||||
for iepoch in range(self.epochs):
|
||||
xkey = '{:}@{:}'.format(data_name, iepoch)
|
||||
self.eval_acc1es[ xkey ] = accs[ xkey ]
|
||||
self.eval_losses[ xkey ] = losses[ xkey ]
|
||||
self.eval_times [ xkey ] = times[ xkey ]
|
||||
|
||||
def update_OLD_eval(self, name, accs, losses): # old version
|
||||
assert name not in self.eval_names, '{:} has already been added'.format(name)
|
||||
self.eval_names.append( name )
|
||||
for iepoch in range(self.epochs):
|
||||
if iepoch in accs:
|
||||
self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch]
|
||||
self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch]
|
||||
|
||||
def __repr__(self):
|
||||
num_eval = len(self.eval_names)
|
||||
set_name = '[' + ', '.join(self.eval_names) + ']'
|
||||
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
|
||||
|
||||
def get_total_epoch(self):
|
||||
return copy.deepcopy(self.epochs)
|
||||
|
||||
def get_times(self):
|
||||
"""Obtain the information regarding both training and evaluation time."""
|
||||
if self.train_times is not None and isinstance(self.train_times, dict):
|
||||
train_times = list( self.train_times.values() )
|
||||
time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
|
||||
else:
|
||||
time_info = {'T-train@epoch': None, 'T-train@total': None }
|
||||
for name in self.eval_names:
|
||||
try:
|
||||
xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
|
||||
time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
|
||||
time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
|
||||
except:
|
||||
time_info['T-{:}@epoch'.format(name)] = None
|
||||
time_info['T-{:}@total'.format(name)] = None
|
||||
return time_info
|
||||
|
||||
def get_eval_set(self):
|
||||
return self.eval_names
|
||||
|
||||
# get the training information
|
||||
def get_train(self, iepoch=None):
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
if self.train_times is not None:
|
||||
xtime = self.train_times[iepoch]
|
||||
atime = sum([self.train_times[i] for i in range(iepoch+1)])
|
||||
else: xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.train_losses[iepoch],
|
||||
'accuracy': self.train_acc1es[iepoch],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
def get_eval(self, name, iepoch=None):
|
||||
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
def _internal_query(xname):
|
||||
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
|
||||
xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)]
|
||||
atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)])
|
||||
else:
|
||||
xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.eval_losses['{:}@{:}'.format(xname, iepoch)],
|
||||
'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
if name == 'valid':
|
||||
return _internal_query('x-valid')
|
||||
else:
|
||||
return _internal_query(name)
|
||||
|
||||
def get_net_param(self, clone=False):
|
||||
if clone: return copy.deepcopy(self.net_state_dict)
|
||||
else: return self.net_state_dict
|
||||
|
||||
def get_config(self, str2structure):
|
||||
"""This function is used to obtain the config dict for this architecture."""
|
||||
if str2structure is None:
|
||||
# In this case, this is NAS-Bench-301
|
||||
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
|
||||
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
|
||||
'genotype': self.arch_config['genotype'], 'num_classes': self.arch_config['class_num']}
|
||||
# In this case, this is NAS-Bench-201
|
||||
else:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
|
||||
else:
|
||||
# In this case, this is NAS-Bench-301
|
||||
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
|
||||
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
|
||||
'genotype': str2structure(self.arch_config['genotype']), 'num_classes': self.arch_config['class_num']}
|
||||
# In this case, this is NAS-Bench-201
|
||||
else:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
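A hedged sketch of `get_config`, assuming `result` is a `ResultsCount` instance; the concrete channel/cell numbers in the comment are only typical examples:

```python
# Passing None keeps the architecture as a raw string; passing a parser function
# (str -> structure) would return a parsed genotype instead.
config = result.get_config(None)
# e.g. {'name': 'infer.tiny', 'C': 16, 'N': 5, 'arch_str': '|nor_conv_3x3~0|...', 'num_classes': 10}
print(config['name'], config['num_classes'])
```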
|
||||
|
||||
def state_dict(self):
|
||||
_state_dict = {key: value for key, value in self.__dict__.items()}
|
||||
return _state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
@staticmethod
|
||||
def create_from_state_dict(state_dict):
|
||||
x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
|
||||
x.load_state_dict(state_dict)
|
||||
return x
|
||||
25
autodl/procedures/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .starts import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
|
||||
from .optimizers import get_optim_scheduler
|
||||
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
|
||||
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
|
||||
from .funcs_nasbench import get_nas_bench_loaders
|
||||
|
||||
def get_procedures(procedure):
|
||||
from .basic_main import basic_train, basic_valid
|
||||
from .search_main import search_train, search_valid
|
||||
from .search_main_v2 import search_train_v2
|
||||
from .simple_KD_main import simple_KD_train, simple_KD_valid
|
||||
|
||||
train_funcs = {'basic' : basic_train, \
|
||||
'search': search_train,'Simple-KD': simple_KD_train, \
|
||||
'search-v2': search_train_v2}
|
||||
valid_funcs = {'basic' : basic_valid, \
|
||||
'search': search_valid,'Simple-KD': simple_KD_valid, \
|
||||
'search-v2': search_valid}
|
||||
|
||||
train_func = train_funcs[procedure]
|
||||
valid_func = valid_funcs[procedure]
|
||||
return train_func, valid_func
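A minimal sketch of the lookup above; 'basic' is one of the valid keys ('basic', 'search', 'Simple-KD', 'search-v2'):

```python
train_func, valid_func = get_procedures('basic')
# The returned functions follow the signatures defined in basic_main.py, e.g.:
# loss, acc1, acc5 = train_func(loader, network, criterion, scheduler, optimizer,
#                               optim_config, extra_info, print_freq, logger)
```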
|
||||
75
autodl/procedures/basic_main.py
Normal file
@@ -0,0 +1,75 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
|
||||
|
||||
def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
|
||||
loss, acc1, acc5 = procedure(xloader, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(xloader, network, criterion, None, None, 'valid', None, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
|
||||
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == 'train':
|
||||
network.train()
|
||||
elif mode == 'valid':
|
||||
network.eval()
|
||||
else: raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
#logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
|
||||
logger.log('[{:5s}] config :: auxiliary={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1))
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if mode == 'train': optimizer.zero_grad()
|
||||
|
||||
features, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, 'logits must have {:} items instead of {:}'.format(2, len(logits))
|
||||
logits, logits_aux = logits
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
loss = criterion(logits, targets)
|
||||
if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
loss += config.auxiliary * loss_aux
|
||||
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i+1) == len(xloader):
|
||||
Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
|
||||
if scheduler is not None:
|
||||
Sstr += ' {:}'.format(scheduler.get_min_info())
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
|
||||
Istr = 'Size={:}'.format(list(inputs.size()))
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
|
||||
|
||||
logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
203
autodl/procedures/funcs_nasbench.py
Normal file
@@ -0,0 +1,203 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
import os, time, copy, torch, pathlib
|
||||
|
||||
import datasets
|
||||
from config_utils import load_config
|
||||
from autodl.procedures import prepare_seed, get_optim_scheduler
|
||||
from autodl.utils import get_model_infos, obtain_accuracy
|
||||
from autodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate', 'get_nas_bench_loaders']
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
latencies, device = [], torch.cuda.current_device()
|
||||
network.eval()
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
targets = targets.cuda(device=device, non_blocking=True)
|
||||
inputs = inputs.cuda(device=device, non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
batch_time.update(time.time() - end)
|
||||
if batch is None or batch == inputs.size(0):
|
||||
batch = inputs.size(0)
|
||||
latencies.append( batch_time.val - data_time.val )
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
end = time.time()
|
||||
if len(latencies) > 2: latencies = latencies[1:]
|
||||
return losses.avg, top1.avg, top5.avg, latencies
|
||||
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == 'train' : network.train()
|
||||
elif mode == 'valid': network.eval()
|
||||
else: raise ValueError("The mode is not right : {:}".format(mode))
|
||||
device = torch.cuda.current_device()
|
||||
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
|
||||
|
||||
targets = targets.cuda(device=device, non_blocking=True)
|
||||
if mode == 'train': optimizer.zero_grad()
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# backward
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
# count time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||
|
||||
|
||||
def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger):
|
||||
|
||||
prepare_seed(seed) # random seed
|
||||
net = get_cell_based_tiny_net(arch_config)
|
||||
#net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||
flop, param = get_model_infos(net, opt_config.xshape)
|
||||
logger.log('Network : {:}'.format(net.get_message()), False)
|
||||
logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
|
||||
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
|
||||
# train and valid
|
||||
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
|
||||
default_device = torch.cuda.current_device()
|
||||
network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device)
|
||||
criterion = criterion.cuda(device=default_device)
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup
|
||||
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
|
||||
train_times , valid_times, lrs = {}, {}, {}
|
||||
for epoch in range(total_epoch):
|
||||
scheduler.update(epoch, 0.0)
|
||||
lr = min(scheduler.get_lr())
|
||||
train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
|
||||
train_losses[epoch] = train_loss
|
||||
train_acc1es[epoch] = train_acc1
|
||||
train_acc5es[epoch] = train_acc5
|
||||
train_times [epoch] = train_tm
|
||||
lrs[epoch] = lr
|
||||
with torch.no_grad():
|
||||
for key, xloader in valid_loaders.items():
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloader, network, criterion, None, None, 'valid')
|
||||
valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss
|
||||
valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1
|
||||
valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5
|
||||
valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) )
|
||||
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5, lr))
|
||||
info_seed = {'flop' : flop,
|
||||
'param': param,
|
||||
'arch_config' : arch_config._asdict(),
|
||||
'opt_config' : opt_config._asdict(),
|
||||
'total_epoch' : total_epoch ,
|
||||
'train_losses': train_losses,
|
||||
'train_acc1es': train_acc1es,
|
||||
'train_acc5es': train_acc5es,
|
||||
'train_times' : train_times,
|
||||
'valid_losses': valid_losses,
|
||||
'valid_acc1es': valid_acc1es,
|
||||
'valid_acc5es': valid_acc5es,
|
||||
'valid_times' : valid_times,
|
||||
'learning_rates': lrs,
|
||||
'net_state_dict': net.state_dict(),
|
||||
'net_string' : '{:}'.format(net),
|
||||
'finish-train': True
|
||||
}
|
||||
return info_seed
|
||||
|
||||
|
||||
def get_nas_bench_loaders(workers):
|
||||
|
||||
torch.set_num_threads(workers)
|
||||
|
||||
root_dir = (pathlib.Path(__file__).parent / '..' / '..').resolve()
|
||||
torch_dir = pathlib.Path(os.environ['TORCH_HOME'])
|
||||
# cifar
|
||||
cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config'
|
||||
cifar_config = load_config(cifar_config_path, None, None)
|
||||
get_datasets = datasets.get_datasets # a function to return the dataset
|
||||
break_line = '-' * 150
|
||||
print ('{:} Create data-loader for all datasets'.format(time_string()))
|
||||
print (break_line)
|
||||
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num))
|
||||
cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None)
|
||||
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14]
|
||||
temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
|
||||
temp_dataset.transform = VALID_CIFAR10.transform
|
||||
# data loader
|
||||
trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
train_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True)
|
||||
valid_cifar10_loader = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True)
|
||||
test__cifar10_loader = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-10 : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size))
|
||||
print (break_line)
|
||||
# CIFAR-100
|
||||
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num))
|
||||
cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None)
|
||||
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24]
|
||||
train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-100 : train-loader has {:3d} batch'.format(len(train_cifar100_loader)))
|
||||
print ('CIFAR-100 : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader)))
|
||||
print ('CIFAR-100 : test--loader has {:3d} batch'.format(len(test__cifar100_loader)))
|
||||
print (break_line)
|
||||
|
||||
imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config'
|
||||
imagenet16_config = load_config(imagenet16_config_path, None, None)
|
||||
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1)
|
||||
print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num))
|
||||
imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None)
|
||||
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20]
|
||||
train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size))
|
||||
|
||||
# 'cifar10', 'cifar100', 'ImageNet16-120'
|
||||
loaders = {'cifar10@trainval': trainval_cifar10_loader,
|
||||
'cifar10@train' : train_cifar10_loader,
|
||||
'cifar10@valid' : valid_cifar10_loader,
|
||||
'cifar10@test' : test__cifar10_loader,
|
||||
'cifar100@train' : train_cifar100_loader,
|
||||
'cifar100@valid' : valid_cifar100_loader,
|
||||
'cifar100@test' : test__cifar100_loader,
|
||||
'ImageNet16-120@train': train_imagenet_loader,
|
||||
'ImageNet16-120@valid': valid_imagenet_loader,
|
||||
'ImageNet16-120@test' : test__imagenet_loader}
|
||||
return loaders
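A hedged usage sketch of `get_nas_bench_loaders`; it assumes `TORCH_HOME` points at a directory with the prepared `cifar.python` / `ImageNet16` data and that the `configs/nas-benchmark` files are in place:

```python
loaders = get_nas_bench_loaders(workers=4)
train_loader = loaders['cifar10@train']   # keys follow the '<dataset>@<split>' scheme
valid_loader = loaders['cifar10@valid']
for inputs, targets in train_loader:
    print(inputs.shape, targets.shape)    # inspect one batch
    break
```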
|
||||
204
autodl/procedures/optimizers.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
from bisect import bisect_right
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class _LRScheduler(object):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs):
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault('initial_lr', group['lr'])
|
||||
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
|
||||
self.max_epochs = epochs
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.current_epoch = 0
|
||||
self.current_iter = 0
|
||||
|
||||
def extra_repr(self):
|
||||
return ''
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
|
||||
+ ', {:})'.format(self.extra_repr()))
|
||||
|
||||
def state_dict(self):
|
||||
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_min_info(self):
|
||||
lrs = self.get_lr()
|
||||
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
|
||||
|
||||
def get_min_lr(self):
|
||||
return min( self.get_lr() )
|
||||
|
||||
def update(self, cur_epoch, cur_iter):
|
||||
if cur_epoch is not None:
|
||||
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
|
||||
self.current_epoch = cur_epoch
|
||||
if cur_iter is not None:
|
||||
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
|
||||
self.current_iter = cur_iter
|
||||
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
||||
param_group['lr'] = lr
|
||||
|
||||
|
||||
|
||||
class CosineAnnealingLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
#if last_epoch < self.T_max:
|
||||
#if last_epoch < self.max_epochs:
|
||||
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
|
||||
#else:
|
||||
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
|
||||
elif self.current_epoch >= self.max_epochs:
|
||||
lr = self.eta_min
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
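A small worked example of the schedule above (the concrete numbers are assumptions): with `warmup_epochs=5` and a base LR of 0.1, epoch 2 yields a warmup LR of 2/5 * 0.1 = 0.04, and after warmup the LR follows a cosine curve from 0.1 down towards `eta_min`:

```python
import torch

net = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = CosineAnnealingLR(optimizer, warmup_epochs=5, epochs=100, T_max=95, eta_min=0.0)

scheduler.update(2, 0.0)
print(scheduler.get_min_lr())   # ~0.04 (linear warmup)
scheduler.update(52, 0.0)
print(scheduler.get_min_lr())   # on the cosine curve between 0.1 and eta_min
```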
|
||||
|
||||
|
||||
|
||||
class MultiStepLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
|
||||
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
|
||||
self.milestones = milestones
|
||||
self.gammas = gammas
|
||||
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
idx = bisect_right(self.milestones, last_epoch)
|
||||
lr = base_lr
|
||||
for x in self.gammas[:idx]: lr *= x
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
class ExponentialLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
|
||||
self.gamma = gamma
|
||||
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
lr = base_lr * (self.gamma ** last_epoch)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
class LinearLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
|
||||
self.max_LR = max_LR
|
||||
self.min_LR = min_LR
|
||||
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
|
||||
lr = base_lr * (1-ratio)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, epsilon):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
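A minimal sketch of the label-smoothing loss above: with `epsilon=0.1` and 10 classes, the smoothed target places 0.91 on the true class and 0.01 on every other class:

```python
import torch

criterion = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)
logits = torch.randn(4, 10)            # a batch of 4 examples (dummy values)
targets = torch.tensor([0, 3, 7, 9])
loss = criterion(logits, targets)      # scalar tensor
```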
|
||||
|
||||
|
||||
|
||||
def get_optim_scheduler(parameters, config):
|
||||
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
|
||||
if config.optim == 'SGD':
|
||||
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
|
||||
elif config.optim == 'RMSprop':
|
||||
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
|
||||
else:
|
||||
raise ValueError('invalid optim : {:}'.format(config.optim))
|
||||
|
||||
if config.scheduler == 'cos':
|
||||
T_max = getattr(config, 'T_max', config.epochs)
|
||||
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
|
||||
elif config.scheduler == 'multistep':
|
||||
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
|
||||
elif config.scheduler == 'exponential':
|
||||
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
|
||||
elif config.scheduler == 'linear':
|
||||
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
|
||||
else:
|
||||
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
|
||||
|
||||
if config.criterion == 'Softmax':
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
elif config.criterion == 'SmoothSoftmax':
|
||||
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
|
||||
else:
|
||||
raise ValueError('invalid criterion : {:}'.format(config.criterion))
|
||||
return optim, scheduler, criterion
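A hedged sketch of how the optimizer / scheduler / criterion triple is meant to be built; the config fields mirror the attributes read above, the concrete values are assumptions for illustration, and the import path follows the package layout in this commit:

```python
from collections import namedtuple
import torch
from autodl.procedures import get_optim_scheduler

Config = namedtuple('Config', ['optim', 'LR', 'momentum', 'decay', 'nesterov',
                               'scheduler', 'warmup', 'epochs', 'eta_min', 'criterion'])
config = Config('SGD', 0.1, 0.9, 5e-4, True, 'cos', 0, 12, 0.0, 'Softmax')

model = torch.nn.Linear(8, 2)
optimizer, scheduler, criterion = get_optim_scheduler(model.parameters(), config)
for epoch in range(config.epochs):
    scheduler.update(epoch, 0.0)   # set the epoch-level learning rate
    # ... run one epoch of training with `optimizer` and `criterion` ...
```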
|
||||
126
autodl/procedures/search_main.py
Normal file
@@ -0,0 +1,126 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
|
||||
|
||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||
expected_flop = torch.mean( expected_flop )
|
||||
|
||||
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
|
||||
loss = - torch.log( expected_flop )
|
||||
#elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
|
||||
elif flop_cur > flop_need: # Too Large FLOP
|
||||
loss = torch.log( expected_flop )
|
||||
else: # Required FLOP
|
||||
loss = None
|
||||
if loss is None: return 0, 0
|
||||
else : return loss, loss.item()
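A short sketch of the FLOP penalty above: below the budget (minus the tolerance) the loss is -log(expected_flop), which rewards larger expected FLOPs; above the budget it is +log(expected_flop), which pushes them down; inside the band there is no penalty. The numbers below are assumptions:

```python
import torch

expected_flop = torch.tensor([0.8, 0.9])   # dummy stand-in for the network's FLOP estimate
loss, loss_value = get_flop_loss(expected_flop, flop_cur=150.0,
                                 flop_need=200.0, flop_tolerant=20.0)
# 150 < 200 - 20, so loss = -log(mean(expected_flop)) and loss_value is its float value
```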
|
||||
|
||||
|
||||
def search_train(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']
|
||||
|
||||
network.train()
|
||||
logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
|
||||
end = time.time()
|
||||
network.apply( change_key('search_mode', 'search') )
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
base_optimizer.zero_grad()
|
||||
logits, expected_flop = network(base_inputs)
|
||||
#network.apply( change_key('search_mode', 'basic') )
|
||||
#features, logits = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
base_optimizer.step()
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
top1.update (prec1.item(), base_inputs.size(0))
|
||||
top5.update (prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop('genotype', None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
|
||||
arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step+1) == len(search_loader):
|
||||
Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
|
||||
Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
|
||||
#Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
|
||||
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
|
||||
#print(network.module.get_arch_info())
|
||||
#print(network.module.width_attentions[0])
|
||||
#print(network.module.width_attentions[1])
|
||||
|
||||
logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
|
||||
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
|
||||
|
||||
|
||||
|
||||
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
|
||||
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
|
||||
network.eval()
|
||||
network.apply( change_key('search_mode', 'search') )
|
||||
end = time.time()
|
||||
#logger.log('Starting evaluating {:}'.format(epoch_info))
|
||||
with torch.no_grad():
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
logits, expected_flop = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i+1) == len(xloader):
|
||||
Sstr = '**VALID** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
|
||||
Istr = 'Size={:}'.format(list(inputs.size()))
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
|
||||
|
||||
logger.log(' **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
|
||||
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
87
autodl/procedures/search_main_v2.py
Normal file
@@ -0,0 +1,87 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
|

def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
  expected_flop = torch.mean( expected_flop )

  if flop_cur < flop_need - flop_tolerant:   # Too Small FLOP
    loss = - torch.log( expected_flop )
  #elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
  elif flop_cur > flop_need: # Too Large FLOP
    loss = torch.log( expected_flop )
  else: # Required FLOP
    loss = None
  if loss is None: return 0, 0
  else           : return loss, loss.item()


def search_train_v2(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
  arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
  epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']

  network.train()
  logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
  end = time.time()
  network.apply( change_key('search_mode', 'search') )
  for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
    scheduler.update(None, 1.0 * step / len(search_loader))
    # calculate prediction and loss
    base_targets = base_targets.cuda(non_blocking=True)
    arch_targets = arch_targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - end)

    # update the weights
    base_optimizer.zero_grad()
    logits, expected_flop = network(base_inputs)
    base_loss = criterion(logits, base_targets)
    base_loss.backward()
    base_optimizer.step()
    # record
    prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
    base_losses.update(base_loss.item(), base_inputs.size(0))
    top1.update (prec1.item(), base_inputs.size(0))
    top5.update (prec5.item(), base_inputs.size(0))

    # update the architecture
    arch_optimizer.zero_grad()
    logits, expected_flop = network(arch_inputs)
    flop_cur = network.module.get_flop('genotype', None, None)
    flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
    acls_loss = criterion(logits, arch_targets)
    arch_loss = acls_loss + flop_loss * flop_weight
    arch_loss.backward()
    arch_optimizer.step()

    # record
    arch_losses.update(arch_loss.item(), arch_inputs.size(0))
    arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
    arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()
    if step % print_freq == 0 or (step+1) == len(search_loader):
      Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
      Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
      logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
      #num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
      #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
      #Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
      #logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
      #print(network.module.get_arch_info())
      #print(network.module.width_attentions[0])
      #print(network.module.width_attentions[1])

  logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
  return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
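The FLOP term above only acts when the sampled network is outside the budget. A minimal, self-contained sketch of that behaviour (the tensor values below are made up for illustration): under budget, the regulariser is minus the log of the mean expected FLOPs, so its gradients are negative and a gradient-descent step pushes the expected FLOPs up.

```python
import torch

# hypothetical per-block expected FLOPs (MB) produced by a differentiable super-net
expected_flop = torch.tensor([40.0, 35.0, 30.0], requires_grad=True)

# the under-budget branch of get_flop_loss: loss = -log(mean(expected_flop))
loss = -torch.log(expected_flop.mean())
loss.backward()
print(expected_flop.grad)  # all negative, so a descent step increases the expected FLOPs
```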
autodl/procedures/simple_KD_main.py (new file, 94 lines)
@@ -0,0 +1,94 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import os, sys, time, torch
import torch.nn.functional as F
# our modules
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy


def simple_KD_train(xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
  loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
  return loss, acc1, acc5

def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
  with torch.no_grad():
    loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, None, None, 'valid', optim_config, extra_info, print_freq, logger)
  return loss, acc1, acc5


def loss_KD_fn(criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature):
  basic_loss = criterion(student_logits, targets) * (1. - alpha)
  log_student= F.log_softmax(student_logits / temperature, dim=1)
  sof_teacher= F.softmax    (teacher_logits / temperature, dim=1)
  KD_loss    = F.kl_div(log_student, sof_teacher, reduction='batchmean') * (alpha * temperature * temperature)
  return basic_loss + KD_loss


def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
  data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
  Ttop1, Ttop5 = AverageMeter(), AverageMeter()
  if mode == 'train':
    network.train()
  elif mode == 'valid':
    network.eval()
  else: raise ValueError("The mode is not right : {:}".format(mode))
  teacher.eval()

  logger.log('[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, config.KD_alpha, config.KD_temperature))
  end = time.time()
  for i, (inputs, targets) in enumerate(xloader):
    if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
    # measure data loading time
    data_time.update(time.time() - end)
    # calculate prediction and loss
    targets = targets.cuda(non_blocking=True)

    if mode == 'train': optimizer.zero_grad()

    student_f, logits = network(inputs)
    if isinstance(logits, list):
      assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits))
      logits, logits_aux = logits
    else:
      logits, logits_aux = logits, None
    with torch.no_grad():
      teacher_f, teacher_logits = teacher(inputs)

    loss = loss_KD_fn(criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature)
    if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
      loss_aux = criterion(logits_aux, targets)
      loss += config.auxiliary * loss_aux

    if mode == 'train':
      loss.backward()
      optimizer.step()

    # record
    sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
    losses.update(loss.item(),  inputs.size(0))
    top1.update (sprec1.item(), inputs.size(0))
    top5.update (sprec5.item(), inputs.size(0))
    # teacher
    tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
    Ttop1.update (tprec1.item(), inputs.size(0))
    Ttop5.update (tprec5.item(), inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if i % print_freq == 0 or (i+1) == len(xloader):
      Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
      if scheduler is not None:
        Sstr += ' {:}'.format(scheduler.get_min_info())
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
      Lstr+= ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(Ttop1.avg, Ttop5.avg)
      Istr = 'Size={:}'.format(list(inputs.size()))
      logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)

  logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg))
  logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
  return losses.avg, top1.avg, top5.avg
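For reference, a small sanity check of the distillation loss above on random tensors; it assumes `loss_KD_fn` from this file is in scope, and the feature arguments (unused by the function) are simply passed as `None`.

```python
import torch
import torch.nn as nn

student_logits = torch.randn(8, 10, requires_grad=True)   # toy batch: 8 samples, 10 classes
teacher_logits = torch.randn(8, 10)
targets        = torch.randint(0, 10, (8,))
loss = loss_KD_fn(nn.CrossEntropyLoss(), student_logits, teacher_logits,
                  None, None, targets, alpha=0.9, temperature=4.0)
loss.backward()                                            # gradients flow to the student only
print(loss.item())
```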
autodl/procedures/starts.py (new file, 64 lines)
@@ -0,0 +1,64 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile


def prepare_seed(rand_seed):
  random.seed(rand_seed)
  np.random.seed(rand_seed)
  torch.manual_seed(rand_seed)
  torch.cuda.manual_seed(rand_seed)
  torch.cuda.manual_seed_all(rand_seed)


def prepare_logger(xargs):
  args = copy.deepcopy( xargs )
  from autodl.log_utils import Logger
  logger = Logger(args.save_dir, args.rand_seed)
  logger.log('Main Function with logger : {:}'.format(logger))
  logger.log('Arguments : -------------------------------')
  for name, value in args._get_kwargs():
    logger.log('{:16} : {:}'.format(name, value))
  logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
  logger.log("Pillow Version : {:}".format(PIL.__version__))
  logger.log("PyTorch Version : {:}".format(torch.__version__))
  logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
  logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
  logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
  logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
  return logger


def get_machine_info():
  info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
  info+= "\nPillow Version : {:}".format(PIL.__version__)
  info+= "\nPyTorch Version : {:}".format(torch.__version__)
  info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
  info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
  info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
  if 'CUDA_VISIBLE_DEVICES' in os.environ:
    info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
  else:
    info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
  return info


def save_checkpoint(state, filename, logger):
  if osp.isfile(filename):
    if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(filename))
    os.remove(filename)
  torch.save(state, filename)
  assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename)
  if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
  return filename


def copy_checkpoint(src, dst, logger):
  if osp.isfile(dst):
    if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(dst))
    os.remove(dst)
  copyfile(src, dst)
  if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))
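A short usage sketch of the helpers above (assuming they are importable, e.g. from `autodl.procedures.starts`); a `logger` without a `.log` method, such as `None`, simply silences the messages. The paths below are illustrative.

```python
import torch

prepare_seed(111)                                           # seed python / numpy / torch (+CUDA)
state = {'epoch': 0, 'model': torch.nn.Linear(4, 2).state_dict()}
path = save_checkpoint(state, '/tmp/demo-checkpoint.pth', logger=None)
copy_checkpoint(path, '/tmp/demo-checkpoint-best.pth', logger=None)
```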
autodl/utils/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .evaluation_utils import obtain_accuracy
from .gpu_manager import GPUManager
from .flop_benchmark import get_model_infos, count_parameters_in_MB
from .affine_utils import normalize_points, denormalize_points
from .affine_utils import identity2affine, solve2theta, affine2image
autodl/utils/affine_utils.py (new file, 125 lines)
@@ -0,0 +1,125 @@
# functions for affine transformation
import math, torch
import numpy as np
import torch.nn.functional as F

def identity2affine(full=False):
  if not full:
    parameters = torch.zeros((2,3))
    parameters[0, 0] = parameters[1, 1] = 1
  else:
    parameters = torch.zeros((3,3))
    parameters[0, 0] = parameters[1, 1] = parameters[2, 2] = 1
  return parameters

def normalize_L(x, L):
  return -1. + 2. * x / (L-1)

def denormalize_L(x, L):
  return (x + 1.0) / 2.0 * (L-1)

def crop2affine(crop_box, W, H):
  assert len(crop_box) == 4, 'Invalid crop-box : {:}'.format(crop_box)
  parameters = torch.zeros(3,3)
  x1, y1 = normalize_L(crop_box[0], W), normalize_L(crop_box[1], H)
  x2, y2 = normalize_L(crop_box[2], W), normalize_L(crop_box[3], H)
  parameters[0,0] = (x2-x1)/2
  parameters[0,2] = (x2+x1)/2

  parameters[1,1] = (y2-y1)/2
  parameters[1,2] = (y2+y1)/2
  parameters[2,2] = 1
  return parameters

def scale2affine(scalex, scaley):
  parameters = torch.zeros(3,3)
  parameters[0,0] = scalex
  parameters[1,1] = scaley
  parameters[2,2] = 1
  return parameters

def offset2affine(offx, offy):
  parameters = torch.zeros(3,3)
  parameters[0,0] = parameters[1,1] = parameters[2,2] = 1
  parameters[0,2] = offx
  parameters[1,2] = offy
  return parameters

def horizontalmirror2affine():
  parameters = torch.zeros(3,3)
  parameters[0,0] = -1
  parameters[1,1] = parameters[2,2] = 1
  return parameters

# clockwise rotate image = counterclockwise rotate the rectangle
# degree is between [0, 360]
def rotate2affine(degree):
  assert degree >= 0 and degree <= 360, 'Invalid degree : {:}'.format(degree)
  degree = degree / 180 * math.pi
  parameters = torch.zeros(3,3)
  parameters[0,0] =  math.cos(-degree)
  parameters[0,1] = -math.sin(-degree)
  parameters[1,0] =  math.sin(-degree)
  parameters[1,1] =  math.cos(-degree)
  parameters[2,2] = 1
  return parameters

# shape is a tuple [H, W]
def normalize_points(shape, points):
  assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
  assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), 'points are wrong : {:}'.format(points.shape)
  (H, W), points = shape, points.clone()
  points[0, :] = normalize_L(points[0,:], W)
  points[1, :] = normalize_L(points[1,:], H)
  return points

# shape is a tuple [H, W]
def normalize_points_batch(shape, points):
  assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
  assert isinstance(points, torch.Tensor) and (points.size(-1) == 2), 'points are wrong : {:}'.format(points.shape)
  (H, W), points = shape, points.clone()
  x = normalize_L(points[...,0], W)
  y = normalize_L(points[...,1], H)
  return torch.stack((x,y), dim=-1)

# shape is a tuple [H, W]
def denormalize_points(shape, points):
  assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
  assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), 'points are wrong : {:}'.format(points.shape)
  (H, W), points = shape, points.clone()
  points[0, :] = denormalize_L(points[0,:], W)
  points[1, :] = denormalize_L(points[1,:], H)
  return points

# shape is a tuple [H, W]
def denormalize_points_batch(shape, points):
  assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, 'invalid shape : {:}'.format(shape)
  assert isinstance(points, torch.Tensor) and (points.shape[-1] == 2), 'points are wrong : {:}'.format(points.shape)
  (H, W), points = shape, points.clone()
  x = denormalize_L(points[...,0], W)
  y = denormalize_L(points[...,1], H)
  return torch.stack((x,y), dim=-1)

# make target * theta = source
def solve2theta(source, target):
  source, target = source.clone(), target.clone()
  oks = source[2, :] == 1
  assert torch.sum(oks).item() >= 3, 'valid points : {:} is short'.format(oks)
  if target.size(0) == 2: target = torch.cat((target, oks.unsqueeze(0).float()), dim=0)
  source, target = source[:, oks], target[:, oks]
  source, target = source.transpose(1,0), target.transpose(1,0)
  assert source.size(1) == target.size(1) == 3
  #X, residual, rank, s = np.linalg.lstsq(target.numpy(), source.numpy())
  #theta = torch.Tensor(X.T[:2, :])
  X_, qr = torch.gels(source, target)
  theta = X_[:3, :2].transpose(1, 0)
  return theta

# shape = [H,W]
def affine2image(image, theta, shape):
  C, H, W = image.size()
  theta = theta[:2, :].unsqueeze(0)
  grid_size = torch.Size([1, C, shape[0], shape[1]])
  grid  = F.affine_grid(theta, grid_size)
  affI  = F.grid_sample(image.unsqueeze(0), grid, mode='bilinear', padding_mode='border')
  return affI.squeeze(0)
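The point helpers above map pixel coordinates into the [-1, 1] range used by `affine_grid`/`grid_sample`. A quick round-trip check, assuming `normalize_points` and `denormalize_points` are importable (they are re-exported by `autodl/utils/__init__.py`); the coordinates are made up:

```python
import torch

pts = torch.tensor([[12.0, 60.0, 31.0],    # x coordinates
                    [ 8.0, 40.0, 15.0]])   # y coordinates
norm = normalize_points((64, 64), pts)     # (H, W) = (64, 64): pixels -> [-1, 1]
back = denormalize_points((64, 64), norm)  # and back to pixels
print(torch.allclose(pts, back))           # True
```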
autodl/utils/evaluation_utils.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import torch

def obtain_accuracy(output, target, topk=(1,)):
  """Computes the precision@k for the specified values of k"""
  maxk = max(topk)
  batch_size = target.size(0)

  _, pred = output.topk(maxk, 1, True, True)
  pred    = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size))
  return res
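A tiny worked example of `obtain_accuracy` (assuming the function above is in scope) on a hand-made batch of two samples: the first is correct at top-1, the second only within the top-2, so the returned precisions are 50% and 100%.

```python
import torch

logits  = torch.tensor([[0.1, 0.7, 0.2],
                        [0.5, 0.1, 0.4]])
targets = torch.tensor([1, 2])
top1, top2 = obtain_accuracy(logits, targets, topk=(1, 2))
print(top1.item(), top2.item())   # 50.0 100.0
```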
autodl/utils/flop_benchmark.py (new file, 181 lines)
@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import numpy as np


def count_parameters_in_MB(model):
  if isinstance(model, nn.Module):
    return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
  else:
    return np.sum(np.prod(v.size()) for v in model)/1e6


def get_model_infos(model, shape):
  #model = copy.deepcopy( model )

  model = add_flops_counting_methods(model)
  #model = model.cuda()
  model.eval()

  #cache_inputs = torch.zeros(*shape).cuda()
  #cache_inputs = torch.zeros(*shape)
  cache_inputs = torch.rand(*shape)
  if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
  #print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
  with torch.no_grad():
    _____ = model(cache_inputs)
  FLOPs = compute_average_flops_cost( model ) / 1e6
  Param = count_parameters_in_MB(model)

  if hasattr(model, 'auxiliary_param'):
    aux_params = count_parameters_in_MB(model.auxiliary_param())
    print ('The auxiliary params of this model is : {:}'.format(aux_params))
    print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
    Param = Param - aux_params

  #print_log('FLOPs : {:} MB'.format(FLOPs), log)
  torch.cuda.empty_cache()
  model.apply( remove_hook_function )
  return FLOPs, Param


# ---- Public functions
def add_flops_counting_methods( model ):
  model.__batch_counter__ = 0
  add_batch_counter_hook_function( model )
  model.apply( add_flops_counter_variable_or_reset )
  model.apply( add_flops_counter_hook_function )
  return model


def compute_average_flops_cost(model):
  """
  A method that will be available after add_flops_counting_methods() is called on a desired net object.
  Returns current mean flops consumption per image.
  """
  batches_count = model.__batch_counter__
  flops_sum = 0
  #or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
  for module in model.modules():
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
       or isinstance(module, torch.nn.Conv1d) \
       or hasattr(module, 'calculate_flop_self'):
      flops_sum += module.__flops__
  return flops_sum / batches_count


# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
  batch_size = inputs[0].size(0)
  kernel_size = pool_module.kernel_size
  out_C, output_height, output_width = output.shape[1:]
  assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())

  overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
  pool_module.__flops__ += overall_flops


def self_calculate_flops_counter_hook(self_module, inputs, output):
  overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
  self_module.__flops__ += overall_flops


def fc_flops_counter_hook(fc_module, inputs, output):
  batch_size = inputs[0].size(0)
  xin, xout = fc_module.in_features, fc_module.out_features
  assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
  overall_flops = batch_size * xin * xout
  if fc_module.bias is not None:
    overall_flops += batch_size * xout
  fc_module.__flops__ += overall_flops


def conv1d_flops_counter_hook(conv_module, inputs, outputs):
  batch_size = inputs[0].size(0)
  outL = outputs.shape[-1]
  [kernel] = conv_module.kernel_size
  in_channels  = conv_module.in_channels
  out_channels = conv_module.out_channels
  groups       = conv_module.groups
  conv_per_position_flops = kernel * in_channels * out_channels / groups

  active_elements_count = batch_size * outL
  overall_flops = conv_per_position_flops * active_elements_count

  if conv_module.bias is not None:
    overall_flops += out_channels * active_elements_count
  conv_module.__flops__ += overall_flops


def conv2d_flops_counter_hook(conv_module, inputs, output):
  batch_size = inputs[0].size(0)
  output_height, output_width = output.shape[2:]

  kernel_height, kernel_width = conv_module.kernel_size
  in_channels  = conv_module.in_channels
  out_channels = conv_module.out_channels
  groups       = conv_module.groups
  conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups

  active_elements_count = batch_size * output_height * output_width
  overall_flops = conv_per_position_flops * active_elements_count

  if conv_module.bias is not None:
    overall_flops += out_channels * active_elements_count
  conv_module.__flops__ += overall_flops


def batch_counter_hook(module, inputs, output):
  # Can have multiple inputs, getting the first one
  inputs = inputs[0]
  batch_size = inputs.shape[0]
  module.__batch_counter__ += batch_size


def add_batch_counter_hook_function(module):
  if not hasattr(module, '__batch_counter_handle__'):
    handle = module.register_forward_hook(batch_counter_hook)
    module.__batch_counter_handle__ = handle


def add_flops_counter_variable_or_reset(module):
  if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
     or isinstance(module, torch.nn.Conv1d) \
     or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
     or hasattr(module, 'calculate_flop_self'):
    module.__flops__ = 0


def add_flops_counter_hook_function(module):
  if isinstance(module, torch.nn.Conv2d):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(conv2d_flops_counter_hook)
      module.__flops_handle__ = handle
  elif isinstance(module, torch.nn.Conv1d):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(conv1d_flops_counter_hook)
      module.__flops_handle__ = handle
  elif isinstance(module, torch.nn.Linear):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(fc_flops_counter_hook)
      module.__flops_handle__ = handle
  elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(pool_flops_counter_hook)
      module.__flops_handle__ = handle
  elif hasattr(module, 'calculate_flop_self'): # self-defined module
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(self_calculate_flops_counter_hook)
      module.__flops_handle__ = handle


def remove_hook_function(module):
  hookers = ['__batch_counter_handle__', '__flops_handle__']
  for hooker in hookers:
    if hasattr(module, hooker):
      handle = getattr(module, hooker)
      handle.remove()
  keys = ['__flops__', '__batch_counter__', '__flops__'] + hookers
  for ckey in keys:
    if hasattr(module, ckey): delattr(module, ckey)
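A hedged usage sketch of the counter above on a toy network (the model and input shape are made up for illustration); the hooks run one dummy forward pass and are removed afterwards.

```python
import torch.nn as nn

toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                    nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
flops, params = get_model_infos(toy, shape=(1, 3, 32, 32))   # MFLOPs and M-params per image
print('{:.3f} MFLOPs, {:.3f} M params'.format(flops, params))
```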
autodl/utils/gpu_manager.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import os

class GPUManager():
  queries = ('index', 'gpu_name', 'memory.free', 'memory.used', 'memory.total', 'power.draw', 'power.limit')

  def __init__(self):
    all_gpus = self.query_gpu(False)

  def get_info(self, ctype):
    cmd = 'nvidia-smi --query-gpu={} --format=csv,noheader'.format(ctype)
    lines = os.popen(cmd).readlines()
    lines = [line.strip('\n') for line in lines]
    return lines

  def query_gpu(self, show=True):
    num_gpus = len( self.get_info('index') )
    all_gpus = [ {} for i in range(num_gpus) ]
    for query in self.queries:
      infos = self.get_info(query)
      for idx, info in enumerate(infos):
        all_gpus[idx][query] = info

    if 'CUDA_VISIBLE_DEVICES' in os.environ:
      CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
      selected_gpus = []
      for idx, CUDA_VISIBLE_DEVICE in enumerate(CUDA_VISIBLE_DEVICES):
        find = False
        for gpu in all_gpus:
          if gpu['index'] == CUDA_VISIBLE_DEVICE:
            assert not find, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE)
            find = True
            selected_gpus.append( gpu.copy() )
            selected_gpus[-1]['index'] = '{}'.format(idx)
        assert find, 'Does not find the device : {}'.format(CUDA_VISIBLE_DEVICE)
      all_gpus = selected_gpus

    if show:
      allstrings = ''
      for gpu in all_gpus:
        string = '| '
        for query in self.queries:
          if query.find('memory') == 0: xinfo = '{:>9}'.format(gpu[query])
          else:                         xinfo = gpu[query]
          string = string + query + ' : ' + xinfo + ' | '
        allstrings = allstrings + string + '\n'
      return allstrings
    else:
      return all_gpus

  def select_by_memory(self, numbers=1):
    all_gpus = self.query_gpu(False)
    assert numbers <= len(all_gpus), 'Require {} gpus more than you have'.format(numbers)
    alls = []
    for idx, gpu in enumerate(all_gpus):
      free_memory = gpu['memory.free']
      free_memory = free_memory.split(' ')[0]
      free_memory = int(free_memory)
      index = gpu['index']
      alls.append((free_memory, index))
    alls.sort(reverse = True)
    alls = [ int(alls[i][1]) for i in range(numbers) ]
    return sorted(alls)

"""
if __name__ == '__main__':
  manager = GPUManager()
  manager.query_gpu(True)
  indexes = manager.select_by_memory(3)
  print (indexes)
"""
autodl/utils/nas_utils.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# This file is for experimental usage
import torch, random
import numpy as np
from copy import deepcopy
import torch.nn as nn

# from utils import obtain_accuracy
from models import CellStructure
from log_utils import time_string


def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
  print ('This is an old version of codes to use NAS-Bench-API, and should be modified to align with the new version. Please contact me for more details if you use this function.')
  weights = deepcopy(model.state_dict())
  model.train(cal_mode)
  with torch.no_grad():
    logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
    archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
    probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
    loader_iter = iter(xloader)
    random.seed(seed)
    random.shuffle(archs)
    for idx, arch in enumerate(archs):
      arch_index = api.query_index_by_arch( arch )
      metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
      gt_accs_10_valid.append( metrics['valid-accuracy'] )
      metrics = api.get_more_info(arch_index, 'cifar10', None, False, False)
      gt_accs_10_test.append( metrics['test-accuracy'] )
      select_logits = []
      for i, node_info in enumerate(arch.nodes):
        for op, xin in node_info:
          node_str = '{:}<-{:}'.format(i+1, xin)
          op_index = model.op_names.index(op)
          select_logits.append( logits[model.edge2index[node_str], op_index] )
      cur_prob = sum(select_logits).item()
      probs.append( cur_prob )
    cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0,1]
    cor_prob_test  = np.corrcoef(probs, gt_accs_10_test )[0,1]
    print ('{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test'.format(time_string(), cor_prob_valid, cor_prob_test))

    for idx, arch in enumerate(archs):
      model.set_cal_mode('dynamic', arch)
      try:
        inputs, targets = next(loader_iter)
      except:
        loader_iter = iter(xloader)
        inputs, targets = next(loader_iter)
      _, logits = model(inputs.cuda())
      _, preds  = torch.max(logits, dim=-1)
      correct = (preds == targets.cuda() ).float()
      accuracies.append( correct.mean().item() )
      if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
        cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[:idx+1])[0,1]
        cor_accs_test  = np.corrcoef(accuracies, gt_accs_10_test [:idx+1])[0,1]
        print ('{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs_valid, cor_accs_test))
  model.load_state_dict(weights)
  return archs, probs, accuracies
autodl/utils/weight_watcher.py (new file, 319 lines)
@@ -0,0 +1,319 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.03 #
#####################################################
# Reformulate the codes in https://github.com/CalculatedContent/WeightWatcher
#####################################################
import numpy as np
from typing import List
import torch.nn as nn
from collections import OrderedDict
from sklearn.decomposition import TruncatedSVD


def available_module_types():
  return (nn.Conv2d, nn.Linear)


def get_conv2D_Wmats(tensor: np.ndarray) -> List[np.ndarray]:
  """
  Extract W slices from a 4 index conv2D tensor of shape: (N,M,i,j) or (M,N,i,j).
  Return ij (N x M) matrices
  """
  mats = []
  N, M, imax, jmax = tensor.shape
  assert N + M >= imax + jmax, 'invalid tensor shape detected: {}x{} (NxM), {}x{} (i,j)'.format(N, M, imax, jmax)
  for i in range(imax):
    for j in range(jmax):
      w = tensor[:, :, i, j]
      if N < M: w = w.T
      mats.append(w)
  return mats


def glorot_norm_check(W, N, M, rf_size, lower=0.5, upper=1.5):
  """Check if this layer needs Glorot Normalization Fix"""

  kappa = np.sqrt(2 / ((N + M) * rf_size))
  norm = np.linalg.norm(W)

  check1 = norm / np.sqrt(N * M)
  check2 = norm / (kappa * np.sqrt(N * M))

  if (rf_size > 1) and (check2 > lower) and (check2 < upper):
    return check2, True
  elif (check1 > lower) & (check1 < upper):
    return check1, True
  else:
    if rf_size > 1: return check2, False
    else          : return check1, False

def glorot_norm_fix(w, n, m, rf_size):
  """Apply Glorot Normalization Fix."""
  kappa = np.sqrt(2 / ((n + m) * rf_size))
  w = w / kappa
  return w


def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix):
  results = OrderedDict()
  count = len(weights)
  if count == 0: return results

  for i, weight in enumerate(weights):
    M, N = np.min(weight.shape), np.max(weight.shape)
    Q = N / M
    results[i] = cur_res = OrderedDict(N=N, M=M, Q=Q)
    check, checkTF = glorot_norm_check(weight, N, M, count)
    cur_res['check'] = check
    cur_res['checkTF'] = checkTF
    # assume receptive field size is count
    if glorot_fix:
      weight = glorot_norm_fix(weight, N, M, count)
    else:
      # probably never needed since we always fix for glorot
      weight = weight * np.sqrt(count / 2.0)

    if spectralnorms:  # spectralnorm is the max eigenvalues
      svd = TruncatedSVD(n_components=1, n_iter=7, random_state=10)
      svd.fit(weight)
      sv = svd.singular_values_
      sv_max = np.max(sv)
      if normalize:
        evals = sv * sv / N
      else:
        evals = sv * sv
      lambda0 = evals[0]
      cur_res["spectralnorm"] = lambda0
      cur_res["logspectralnorm"] = np.log10(lambda0)
    else:
      lambda0 = None

    if M < min_size:
      summary = "Weight matrix {}/{} ({},{}): Skipping: too small (<{})".format(i + 1, count, M, N, min_size)
      cur_res["summary"] = summary
      continue
    elif max_size > 0 and M > max_size:
      summary = "Weight matrix {}/{} ({},{}): Skipping: too big (testing) (>{})".format(i + 1, count, M, N, max_size)
      cur_res["summary"] = summary
      continue
    else:
      summary = []
    if alphas:
      import powerlaw
      svd = TruncatedSVD(n_components=M - 1, n_iter=7, random_state=10)
      svd.fit(weight.astype(float))
      sv = svd.singular_values_
      if normalize: evals = sv * sv / N
      else        : evals = sv * sv

      lambda_max = np.max(evals)
      fit = powerlaw.Fit(evals, xmax=lambda_max, verbose=False)
      alpha = fit.alpha
      cur_res["alpha"] = alpha
      D = fit.D
      cur_res["D"] = D
      cur_res["lambda_min"] = np.min(evals)
      cur_res["lambda_max"] = lambda_max
      alpha_weighted = alpha * np.log10(lambda_max)
      cur_res["alpha_weighted"] = alpha_weighted
      tolerance = lambda_max * M * np.finfo(np.max(sv)).eps
      cur_res["rank_loss"] = np.count_nonzero(sv > tolerance, axis=-1)

      logpnorm = np.log10(np.sum([ev ** alpha for ev in evals]))
      cur_res["logpnorm"] = logpnorm

      summary.append(
        "Weight matrix {}/{} ({},{}): Alpha: {}, Alpha Weighted: {}, D: {}, pNorm {}".format(i + 1, count, M, N, alpha,
                                                                                             alpha_weighted, D,
                                                                                             logpnorm))

    if lognorms:
      norm = np.linalg.norm(weight)  # Frobenius Norm
      cur_res["norm"] = norm
      lognorm = np.log10(norm)
      cur_res["lognorm"] = lognorm

      X = np.dot(weight.T, weight)
      if normalize: X = X / N
      normX = np.linalg.norm(X)  # Frobenius Norm
      cur_res["normX"] = normX
      lognormX = np.log10(normX)
      cur_res["lognormX"] = lognormX

      summary.append(
        "Weight matrix {}/{} ({},{}): LogNorm: {} ; LogNormX: {}".format(i + 1, count, M, N, lognorm, lognormX))

      if softranks:
        # the soft-rank metrics use `norm` (from lognorms) and `sv_max` (from spectralnorms) computed above
        softrank = norm ** 2 / sv_max ** 2
        softranklog = np.log10(softrank)
        softranklogratio = lognorm / np.log10(sv_max)
        cur_res["softrank"] = softrank
        cur_res["softranklog"] = softranklog
        cur_res["softranklogratio"] = softranklogratio
        summary.append("Softrank: {}. Softrank log: {}. Softrank log ratio: {}".format(softrank, softranklog, softranklogratio))
    cur_res["summary"] = "\n".join(summary)
  return results


def compute_details(results):
  """
  Return a pandas data frame.
  """
  final_summary = OrderedDict()

  metrics = {
    # key in "results" : pretty print name
    "check": "Check",
    "checkTF": "CheckTF",
    "norm": "Norm",
    "lognorm": "LogNorm",
    "normX": "Norm X",
    "lognormX": "LogNorm X",
    "alpha": "Alpha",
    "alpha_weighted": "Alpha Weighted",
    "spectralnorm": "Spectral Norm",
    "logspectralnorm": "Log Spectral Norm",
    "softrank": "Softrank",
    "softranklog": "Softrank Log",
    "softranklogratio": "Softrank Log Ratio",
    "sigma_mp": "Marchenko-Pastur (MP) fit sigma",
    "numofSpikes": "Number of spikes per MP fit",
    "ratio_numofSpikes": "aka, percent_mass, Number of spikes / total number of evals",
    "softrank_mp": "Softrank for MP fit",
    "logpnorm": "alpha pNorm"
  }

  metrics_stats = []
  for metric in metrics:
    metrics_stats.append("{}_min".format(metric))
    metrics_stats.append("{}_max".format(metric))
    metrics_stats.append("{}_avg".format(metric))

    metrics_stats.append("{}_compound_min".format(metric))
    metrics_stats.append("{}_compound_max".format(metric))
    metrics_stats.append("{}_compound_avg".format(metric))

  columns = ["layer_id", "layer_type", "N", "M", "layer_count", "slice",
             "slice_count", "level", "comment"] + [*metrics] + metrics_stats

  metrics_values = {}
  metrics_values_compound = {}

  for metric in metrics:
    metrics_values[metric] = []
    metrics_values_compound[metric] = []

  layer_count = 0
  for layer_id, result in results.items():
    layer_count += 1

    layer_type = np.NAN
    if "layer_type" in result:
      layer_type = str(result["layer_type"]).replace("LAYER_TYPE.", "")

    compounds = {}  # temp var
    for metric in metrics:
      compounds[metric] = []

    slice_count, Ntotal, Mtotal = 0, 0, 0
    for slice_id, summary in result.items():
      if not str(slice_id).isdigit():
        continue
      slice_count += 1

      N = np.NAN
      if "N" in summary:
        N = summary["N"]
        Ntotal += N

      M = np.NAN
      if "M" in summary:
        M = summary["M"]
        Mtotal += M

      data = {"layer_id": layer_id, "layer_type": layer_type, "N": N, "M": M, "slice": slice_id, "level": "SLICE",
              "comment": "Slice level"}
      for metric in metrics:
        if metric in summary:
          value = summary[metric]
          if value is not None:
            metrics_values[metric].append(value)
            compounds[metric].append(value)
            data[metric] = value

    data = {"layer_id": layer_id, "layer_type": layer_type, "N": Ntotal, "M": Mtotal, "slice_count": slice_count,
            "level": "LAYER", "comment": "Layer level"}
    # Compute the compound value over the slices
    for metric, value in compounds.items():
      count = len(value)
      if count == 0:
        continue

      compound = np.mean(value)
      metrics_values_compound[metric].append(compound)
      data[metric] = compound

  data = {"layer_count": layer_count, "level": "NETWORK", "comment": "Network Level"}
  for metric, metric_name in metrics.items():
    if metric not in metrics_values or len(metrics_values[metric]) == 0:
      continue

    values = metrics_values[metric]
    minimum = min(values)
    maximum = max(values)
    avg = np.mean(values)
    final_summary[metric] = avg
    # print("{}: min: {}, max: {}, avg: {}".format(metric_name, minimum, maximum, avg))
    data["{}_min".format(metric)] = minimum
    data["{}_max".format(metric)] = maximum
    data["{}_avg".format(metric)] = avg

    values = metrics_values_compound[metric]
    minimum = min(values)
    maximum = max(values)
    avg = np.mean(values)
    final_summary["{}_compound".format(metric)] = avg
    # print("{} compound: min: {}, max: {}, avg: {}".format(metric_name, minimum, maximum, avg))
    data["{}_compound_min".format(metric)] = minimum
    data["{}_compound_max".format(metric)] = maximum
    data["{}_compound_avg".format(metric)] = avg

  return final_summary


def analyze(model: nn.Module, min_size=50, max_size=0,
            alphas: bool = False, lognorms: bool = True, spectralnorms: bool = False,
            softranks: bool = False, normalize: bool = False, glorot_fix: bool = False):
  """
  Analyze the weight matrices of a model.
  :param model: A PyTorch model
  :param min_size: The minimum weight matrix size to analyze.
  :param max_size: The maximum weight matrix size to analyze (0 = no limit).
  :param alphas: Compute the power laws (alpha) of the weight matrices.
                 Time consuming so disabled by default (use lognorm if you want speed)
  :param lognorms: Compute the log norms of the weight matrices.
  :param spectralnorms: Compute the spectral norm (max eigenvalue) of the weight matrices.
  :param softranks: Compute the soft norm (i.e. StableRank) of the weight matrices.
  :param normalize: Normalize or not.
  :param glorot_fix:
  :return: (a dict of all layers' results, a dict of the summarized info)
  """
  names, modules = [], []
  for name, module in model.named_modules():
    if isinstance(module, available_module_types()):
      names.append(name)
      modules.append(module)
  # print('There are {:} layers to be analyzed in this model.'.format(len(modules)))
  all_results = OrderedDict()
  for index, module in enumerate(modules):
    if isinstance(module, nn.Linear):
      weights = [module.weight.cpu().detach().numpy()]
    else:
      weights = get_conv2D_Wmats(module.weight.cpu().detach().numpy())
    results = analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix)
    results['id'] = index
    results['type'] = type(module)
    all_results[index] = results
  summary = compute_details(all_results)
  return all_results, summary
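A small usage sketch of `analyze` above on a throw-away MLP (the layer sizes are illustrative): with the default `lognorms=True` it returns per-layer results plus a summary of averaged norm-style metrics.

```python
import torch.nn as nn

mlp = nn.Sequential(nn.Linear(128, 256), nn.ReLU(),
                    nn.Linear(256, 64),  nn.ReLU(),
                    nn.Linear(64, 10))
all_results, summary = analyze(mlp, min_size=50)
print(summary)    # e.g. averaged 'norm', 'lognorm', 'check', ... over the analysed layers
```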
File diff suppressed because one or more lines are too long
@@ -3,3 +3,4 @@
##################################################
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset
from .data import get_data

datasets/data.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from datasets import get_datasets
from config_utils import load_config
import torch
import torchvision

class AddGaussianNoise(object):
    def __init__(self, mean=0., std=0.001):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class RepeatSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, samp, repeat):
        self.samp = samp
        self.repeat = repeat
    def __iter__(self):
        for i in self.samp:
            for j in range(self.repeat):
                yield i
    def __len__(self):
        return self.repeat*len(self.samp)


def get_data(dataset, data_loc, trainval, batch_size, augtype, repeat, args, pin_memory=True):
    train_data, valid_data, xshape, class_num = get_datasets(dataset, data_loc, cutout=0)
    if augtype == 'gaussnoise':
        train_data.transform.transforms = train_data.transform.transforms[2:]
        train_data.transform.transforms.append(AddGaussianNoise(std=args.sigma))
    elif augtype == 'cutout':
        train_data.transform.transforms = train_data.transform.transforms[2:]
        train_data.transform.transforms.append(torchvision.transforms.RandomErasing(p=0.9, scale=(0.02, 0.04)))
    elif augtype == 'none':
        train_data.transform.transforms = train_data.transform.transforms[2:]

    if dataset == 'cifar10':
        acc_type = 'ori-test'
        val_acc_type = 'x-valid'
    else:
        acc_type = 'x-test'
        val_acc_type = 'x-valid'

    if trainval and 'cifar10' in dataset:
        cifar_split = load_config('config_utils/cifar-split.txt', None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        if repeat > 0:
            train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                                       num_workers=0, pin_memory=pin_memory, sampler=RepeatSampler(torch.utils.data.sampler.SubsetRandomSampler(train_split), repeat))
        else:
            train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                                       num_workers=0, pin_memory=pin_memory, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split))
    else:
        if repeat > 0:
            train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, #shuffle=True,
                                                       num_workers=0, pin_memory=pin_memory, sampler=RepeatSampler(torch.utils.data.sampler.SubsetRandomSampler(range(len(train_data))), repeat))
        else:
            train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                                       num_workers=0, pin_memory=pin_memory)
    return train_loader
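`RepeatSampler` above simply replays every index drawn by the wrapped sampler `repeat` times in a row, so the same examples can be fed several times per epoch. A quick check, assuming the class above is importable:

```python
import torch

base = torch.utils.data.SubsetRandomSampler([0, 1, 2])
print(list(RepeatSampler(base, repeat=2)))   # e.g. [2, 2, 0, 0, 1, 1]
```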
@@ -16,7 +16,9 @@ from config_utils import load_config

Dataset2Class = {'cifar10' : 10,
                 'cifar100': 100,
                 'fake'    : 10,
                 'imagenet-1k-s':1000,
                 'imagenette2'  : 10,
                 'imagenet-1k'  : 1000,
                 'ImageNet16'   : 1000,
                 'ImageNet16-150': 150,
@@ -98,8 +100,13 @@ def get_datasets(name, root, cutout):
  elif name == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std  = [x / 255 for x in [68.2, 65.4, 70.4]]
  elif name == 'fake':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std  = [x / 255 for x in [68.2, 65.4, 70.4]]
  elif name.startswith('imagenet-1k'):
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
  elif name.startswith('imagenette'):
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
  elif name.startswith('ImageNet16'):
    mean = [x / 255 for x in [122.68, 116.66, 104.01]]
    std  = [x / 255 for x in [63.22, 61.26, 65.09]]
@@ -113,6 +120,12 @@ def get_datasets(name, root, cutout):
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 32, 32)
  elif name == 'fake':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if cutout > 0 : lists += [CUTOUT(cutout)]
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 32, 32)
  elif name.startswith('ImageNet16'):
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if cutout > 0 : lists += [CUTOUT(cutout)]
@@ -125,6 +138,15 @@ def get_datasets(name, root, cutout):
    train_transform = transforms.Compose(lists)
    test_transform  = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
    xshape = (1, 3, 32, 32)
  elif name.startswith('imagenette'):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    xlists = []
    xlists.append( transforms.ToTensor() )
    xlists.append( normalize )
    #train_transform = transforms.Compose(xlists)
    train_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
    test_transform  = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
    xshape = (1, 3, 224, 224)
  elif name.startswith('imagenet-1k'):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if name == 'imagenet-1k':
@@ -156,6 +178,12 @@ def get_datasets(name, root, cutout):
    train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
    test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
    assert len(train_data) == 50000 and len(test_data) == 10000
  elif name == 'fake':
    train_data = dset.FakeData(size=50000, image_size=(3, 32, 32), transform=train_transform)
    test_data  = dset.FakeData(size=10000, image_size=(3, 32, 32), transform=test_transform)
  elif name.startswith('imagenette2'):
    train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
    test_data  = dset.ImageFolder(osp.join(root, 'val'), test_transform)
  elif name.startswith('imagenet-1k'):
    train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
    test_data  = dset.ImageFolder(osp.join(root, 'val'), test_transform)

env.yml (new file, 24 lines)
@@ -0,0 +1,24 @@
name: naswot2
channels:
  - conda-forge
  - pytorch
dependencies:
  - python=3.7
  - numpy
  - matplotlib
  - seaborn
  - pandas
  - xlrd
  - scipy
  - pip
  - scikit-learn
  - scikit-image
  - pytorch::pytorch==1.6.0
  - pytorch::torchvision==0.7.0
  - cudatoolkit=9.2
  - tqdm
  - pip:
    - tensorflow-gpu==1.15
    - yacs
    - simplejson
    - "--editable=git+https://github.com/google-research/nasbench#egg=nasbench-master"
@@ -1,54 +0,0 @@
name: nas-wot
channels:
  - pytorch
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - blas=1.0=mkl
  - ca-certificates=2020.1.1=0
  - certifi=2020.4.5.1=py38_0
  - cudatoolkit=10.2.89=hfd86e86_1
  - freetype=2.9.1=h8a8886c_1
  - intel-openmp=2020.1=217
  - jpeg=9b=h024ee3a_2
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.3=he6710b0_1
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.1.0=h2733197_1
  - lz4-c=1.9.2=he6710b0_0
  - mkl=2020.1=217
  - mkl-service=2.3.0=py38he904b0f_0
  - mkl_fft=1.0.15=py38ha843d7b_0
  - mkl_random=1.1.1=py38h0573a6f_0
  - ncurses=6.2=he6710b0_1
  - ninja=1.9.0=py38hfd86e86_0
  - numpy=1.18.1=py38h4f9e942_0
  - numpy-base=1.18.1=py38hde5b4d6_1
  - olefile=0.46=py_0
  - openssl=1.1.1g=h7b6447c_0
  - pandas=1.0.3=py38h0573a6f_0
  - pillow=7.1.2=py38hb39fc2d_0
  - pip=20.0.2=py38_3
  - python=3.8.3=hcff3b4d_0
  - python-dateutil=2.8.1=py_0
  - pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0
  - pytz=2020.1=py_0
  - readline=8.0=h7b6447c_0
  - setuptools=46.4.0=py38_0
  - six=1.14.0=py38_0
  - sqlite=3.31.1=h62c20be_1
  - tk=8.6.8=hbc83047_0
  - torchvision=0.6.0=py38_cu102
  - tqdm=4.46.0=py_0
  - wheel=0.34.2=py38_0
  - xz=5.2.5=h7b6447c_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.4=h0b5b093_3
  - pip:
    - argparse==1.4.0
    - nas-bench-201==1.3
    - tabulate==0.8.7
@@ -55,4 +55,4 @@ class TinyNetwork(nn.Module):
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
    return logits, out

nas_101_api/__init__.py (new file, 1 line)
@@ -0,0 +1 @@

nas_101_api/base_ops.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""Base operations used by the modules in this search space."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBnRelu(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
    super(ConvBnRelu, self).__init__()

    self.conv_bn_relu = nn.Sequential(
        #nn.ReLU(),
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        #nn.ReLU(inplace=True)
        nn.ReLU()
    )

  def forward(self, x):
    return self.conv_bn_relu(x)

class Conv3x3BnRelu(nn.Module):
  """3x3 convolution with batch norm and ReLU activation."""
  def __init__(self, in_channels, out_channels):
    super(Conv3x3BnRelu, self).__init__()

    self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1)

  def forward(self, x):
    x = self.conv3x3(x)
    return x

class Conv1x1BnRelu(nn.Module):
  """1x1 convolution with batch norm and ReLU activation."""
  def __init__(self, in_channels, out_channels):
    super(Conv1x1BnRelu, self).__init__()

    self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0)

  def forward(self, x):
    x = self.conv1x1(x)
    return x

class MaxPool3x3(nn.Module):
  """3x3 max pool with no subsampling."""
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    super(MaxPool3x3, self).__init__()

    self.maxpool = nn.MaxPool2d(kernel_size, stride, padding)
    #self.maxpool = nn.AvgPool2d(kernel_size, stride, padding)

  def forward(self, x):
    x = self.maxpool(x)
    return x

# Commas should not be used in op names
OP_MAP = {
    'conv3x3-bn-relu': Conv3x3BnRelu,
    'conv1x1-bn-relu': Conv1x1BnRelu,
    'maxpool3x3': MaxPool3x3
}
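The `OP_MAP` registry above maps NAS-Bench-101 op names to module classes; a quick shape check (the channel sizes are arbitrary):

```python
import torch

op = OP_MAP['conv3x3-bn-relu'](in_channels=16, out_channels=32)
print(op(torch.randn(2, 16, 8, 8)).shape)   # torch.Size([2, 32, 8, 8])
```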
167
nas_101_api/graph_util.py
Normal file
167
nas_101_api/graph_util.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright 2019 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utility functions used by generate_graph.py."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import hashlib
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def gen_is_edge_fn(bits):
|
||||
"""Generate a boolean function for the edge connectivity.
|
||||
|
||||
Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is
|
||||
[[0, A, B, D],
|
||||
[0, 0, C, E],
|
||||
[0, 0, 0, F],
|
||||
[0, 0, 0, 0]]
|
||||
|
||||
Note that this function is agnostic to the actual matrix dimension due to
|
||||
order in which elements are filled out (column-major, starting from least
|
||||
significant bit). For example, the same FEDCBA bitstring (0-padded) on a 5x5
|
||||
matrix is
|
||||
[[0, A, B, D, 0],
|
||||
[0, 0, C, E, 0],
|
||||
[0, 0, 0, F, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]]
|
||||
|
||||
Args:
|
||||
bits: integer which will be interpreted as a bit mask.
|
||||
|
||||
Returns:
|
||||
vectorized function that returns True when an edge is present.
|
||||
"""
|
||||
def is_edge(x, y):
|
||||
"""Is there an edge from x to y (0-indexed)?"""
|
||||
if x >= y:
|
||||
return 0
|
||||
# Map x, y to index into bit string
|
||||
index = x + (y * (y - 1) // 2)
|
||||
return (bits >> index) % 2 == 1
|
||||
|
||||
return np.vectorize(is_edge)
|
||||
|
||||
|
||||
def is_full_dag(matrix):
|
||||
"""Full DAG == all vertices on a path from vert 0 to (V-1).
|
||||
|
||||
i.e. no disconnected or "hanging" vertices.
|
||||
|
||||
It is sufficient to check for:
|
||||
1) no rows of 0 except for row V-1 (only output vertex has no out-edges)
|
||||
2) no cols of 0 except for col 0 (only input vertex has no in-edges)
|
||||
|
||||
Args:
|
||||
matrix: V x V upper-triangular adjacency matrix
|
||||
|
||||
Returns:
|
||||
True if the there are no dangling vertices.
|
||||
"""
|
||||
shape = np.shape(matrix)
|
||||
|
||||
rows = matrix[:shape[0]-1, :] == 0
|
||||
rows = np.all(rows, axis=1) # Any row with all 0 will be True
|
||||
rows_bad = np.any(rows)
|
||||
|
||||
cols = matrix[:, 1:] == 0
|
||||
cols = np.all(cols, axis=0) # Any col with all 0 will be True
|
||||
cols_bad = np.any(cols)
|
||||
|
||||
return (not rows_bad) and (not cols_bad)
|
||||
|
||||
|
||||
def num_edges(matrix):
|
||||
"""Computes number of edges in adjacency matrix."""
|
||||
return np.sum(matrix)
|
||||
|
||||
|
||||
def hash_module(matrix, labeling):
|
||||
"""Computes a graph-invariance MD5 hash of the matrix and label pair.
|
||||
|
||||
Args:
|
||||
matrix: np.ndarray square upper-triangular adjacency matrix.
|
||||
labeling: list of int labels of length equal to both dimensions of
|
||||
matrix.
|
||||
|
||||
Returns:
|
||||
MD5 hash of the matrix and labeling.
|
||||
"""
|
||||
vertices = np.shape(matrix)[0]
|
||||
in_edges = np.sum(matrix, axis=0).tolist()
|
||||
out_edges = np.sum(matrix, axis=1).tolist()
|
||||
|
||||
assert len(in_edges) == len(out_edges) == len(labeling)
|
||||
hashes = list(zip(out_edges, in_edges, labeling))
|
||||
hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]
|
||||
# Computing this up to the diameter is probably sufficient but since the
|
||||
# operation is fast, it is okay to repeat more times.
|
||||
for _ in range(vertices):
|
||||
new_hashes = []
|
||||
for v in range(vertices):
|
||||
in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]]
|
||||
out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]]
|
||||
new_hashes.append(hashlib.md5(
|
||||
(''.join(sorted(in_neighbors)) + '|' +
|
||||
''.join(sorted(out_neighbors)) + '|' +
|
||||
hashes[v]).encode('utf-8')).hexdigest())
|
||||
hashes = new_hashes
|
||||
fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()
|
||||
|
||||
return fingerprint
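# Isomorphism-invariance sketch (hand-made graphs, not in the original file):
# swapping the two interior vertices together with their labels leaves the
# fingerprint unchanged, so duplicate cells can be detected by hash alone.
#   m = np.array([[0, 1, 1, 0],
#                 [0, 0, 0, 1],
#                 [0, 0, 0, 1],
#                 [0, 0, 0, 0]])
#   hash_module(m, [-1, 0, 1, -2]) == hash_module(m, [-1, 1, 0, -2])   # True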
|
||||
|
||||
|
||||
def permute_graph(graph, label, permutation):
|
||||
"""Permutes the graph and labels based on permutation.
|
||||
|
||||
Args:
|
||||
graph: np.ndarray adjacency matrix.
|
||||
label: list of labels of same length as graph dimensions.
|
||||
permutation: a permutation list of ints of same length as graph dimensions.
|
||||
|
||||
Returns:
|
||||
np.ndarray where vertex permutation[v] is vertex v from the original graph
|
||||
"""
|
||||
# vertex permutation[v] in new graph is vertex v in the old graph
|
||||
forward_perm = zip(permutation, list(range(len(permutation))))
|
||||
inverse_perm = [x[1] for x in sorted(forward_perm)]
|
||||
edge_fn = lambda x, y: graph[inverse_perm[x], inverse_perm[y]] == 1
|
||||
new_matrix = np.fromfunction(np.vectorize(edge_fn),
|
||||
(len(label), len(label)),
|
||||
dtype=np.int8)
|
||||
new_label = [label[inverse_perm[i]] for i in range(len(label))]
|
||||
return new_matrix, new_label
|
||||
|
||||
|
||||
def is_isomorphic(graph1, graph2):
|
||||
"""Exhaustively checks if 2 graphs are isomorphic."""
|
||||
matrix1, label1 = np.array(graph1[0]), graph1[1]
|
||||
matrix2, label2 = np.array(graph2[0]), graph2[1]
|
||||
assert np.shape(matrix1) == np.shape(matrix2)
|
||||
assert len(label1) == len(label2)
|
||||
|
||||
vertices = np.shape(matrix1)[0]
|
||||
# Note: input and output in our constrained graphs always map to themselves
|
||||
# but this script does not enforce that.
|
||||
for perm in itertools.permutations(range(0, vertices)):
|
||||
pmatrix1, plabel1 = permute_graph(matrix1, label1, perm)
|
||||
if np.array_equal(pmatrix1, matrix2) and plabel1 == label2:
|
||||
return True
|
||||
|
||||
return False
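# Sketch (hand-made graphs, not in the original file): is_isomorphic accepts the
# relabeled copy from the example above but rejects a graph whose label multiset differs.
#   m = [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]]
#   is_isomorphic((m, [-1, 0, 1, -2]), (m, [-1, 1, 0, -2]))   # True
#   is_isomorphic((m, [-1, 0, 1, -2]), (m, [-1, 0, 0, -2]))   # False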
|
||||
252
nas_101_api/model.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Builds the Pytorch computational graph.
|
||||
|
||||
Tensors flowing into a single vertex are added together for all vertices
|
||||
except the output, which is concatenated instead. Tensors flowing out of input
|
||||
are always added.
|
||||
|
||||
If interior edge channels don't match, drop the extra channels (channels are
|
||||
guaranteed non-decreasing). Tensors flowing out of the input are always
|
||||
projected instead.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
from .base_ops import *
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Network(nn.Module):
|
||||
def __init__(self, spec, args, searchspace=[]):
|
||||
super(Network, self).__init__()
self.spec = spec  # stored so genotype() below can report the cell specification
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
in_channels = 3
|
||||
out_channels = args.stem_out_channels
|
||||
|
||||
# initial stem convolution
|
||||
stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1)
|
||||
self.layers.append(stem_conv)
|
||||
|
||||
in_channels = out_channels
|
||||
for stack_num in range(args.num_stacks):
|
||||
if stack_num > 0:
|
||||
#downsample = nn.MaxPool2d(kernel_size=3, stride=2)
|
||||
downsample = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
#downsample = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
#downsample = nn.Conv2d(in_channels, out_channels, kernel_size=(2, 2), stride=2)
|
||||
self.layers.append(downsample)
|
||||
|
||||
out_channels *= 2
|
||||
|
||||
for module_num in range(args.num_modules_per_stack):
|
||||
cell = Cell(spec, in_channels, out_channels)
|
||||
self.layers.append(cell)
|
||||
in_channels = out_channels
|
||||
|
||||
self.classifier = nn.Linear(out_channels, args.num_labels)
|
||||
|
||||
# for DARTS search
|
||||
num_edge = np.shape(spec.matrix)[0]
|
||||
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(searchspace)))
|
||||
|
||||
self._initialize_weights()
|
||||
|
||||
def forward(self, x, get_ints=True):
|
||||
ints = []
|
||||
for _, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
ints.append(x)
|
||||
out = torch.mean(x, (2, 3))
|
||||
ints.append(out)
|
||||
out = self.classifier(out)
|
||||
if get_ints:
|
||||
return out, ints[-1]
|
||||
else:
|
||||
return out
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / n))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
pass
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
pass
|
||||
elif isinstance(m, nn.Linear):
|
||||
n = m.weight.size(1)
|
||||
m.weight.data.normal_(0, 0.01)
|
||||
m.bias.data.zero_()
|
||||
pass
|
||||
|
||||
def get_weights(self):
|
||||
xlist = []
|
||||
for m in self.modules():
|
||||
xlist.append(m.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def genotype(self):
|
||||
return str(self.spec)
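# Usage sketch (assumptions: `spec` is a valid NAS-Bench-101 ModelSpec and `args`
# provides the fields read in __init__, e.g. via argparse or SimpleNamespace):
#   from types import SimpleNamespace
#   args = SimpleNamespace(stem_out_channels=128, num_stacks=3,
#                          num_modules_per_stack=3, num_labels=10)
#   net = Network(spec, args)
#   logits, penultimate = net(torch.randn(2, 3, 32, 32))   # get_ints=True by default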
|
||||
|
||||
|
||||
class Cell(nn.Module):
|
||||
"""
|
||||
Builds the model using the adjacency matrix and op labels specified. Channels
|
||||
controls the module output channel count but the interior channels are
|
||||
determined via equally splitting the channel count whenever there is a
|
||||
concatenation of Tensors.
|
||||
"""
|
||||
def __init__(self, spec, in_channels, out_channels):
|
||||
super(Cell, self).__init__()
|
||||
|
||||
self.spec = spec
|
||||
self.num_vertices = np.shape(self.spec.matrix)[0]
|
||||
|
||||
# vertex_channels[i] = number of output channels of vertex i
|
||||
self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.spec.matrix)
|
||||
#self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)
|
||||
|
||||
# operation for each node
|
||||
self.vertex_op = nn.ModuleList([None])
|
||||
for t in range(1, self.num_vertices-1):
|
||||
op = OP_MAP[spec.ops[t]](self.vertex_channels[t], self.vertex_channels[t])
|
||||
self.vertex_op.append(op)
|
||||
|
||||
# operation for input on each vertex
|
||||
self.input_op = nn.ModuleList([None])
|
||||
for t in range(1, self.num_vertices):
|
||||
if self.spec.matrix[0, t]:
|
||||
self.input_op.append(Projection(in_channels, self.vertex_channels[t]))
|
||||
else:
|
||||
self.input_op.append(None)
|
||||
|
||||
def forward(self, x):
|
||||
tensors = [x]
|
||||
out_concat = []
|
||||
for t in range(1, self.num_vertices-1):
|
||||
fan_in = [Truncate(tensors[src], self.vertex_channels[t]) for src in range(1, t) if self.spec.matrix[src, t]]
|
||||
fan_in_inds = [src for src in range(1, t) if self.spec.matrix[src, t]]
|
||||
|
||||
if self.spec.matrix[0, t]:
|
||||
fan_in.append(self.input_op[t](x))
|
||||
fan_in_inds = [0] + fan_in_inds
|
||||
|
||||
# perform operation on node
|
||||
#vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
|
||||
vertex_input = sum(fan_in)
|
||||
#vertex_input = sum(fan_in) / len(fan_in)
|
||||
vertex_output = self.vertex_op[t](vertex_input)
|
||||
|
||||
tensors.append(vertex_output)
|
||||
if self.spec.matrix[t, self.num_vertices-1]:
|
||||
out_concat.append(tensors[t])
|
||||
|
||||
if not out_concat: # empty list
|
||||
assert self.spec.matrix[0, self.num_vertices-1]
|
||||
outputs = self.input_op[self.num_vertices-1](tensors[0])
|
||||
else:
|
||||
if len(out_concat) == 1:
|
||||
outputs = out_concat[0]
|
||||
else:
|
||||
outputs = torch.cat(out_concat, 1)
|
||||
|
||||
if self.spec.matrix[0, self.num_vertices-1]:
|
||||
outputs += self.input_op[self.num_vertices-1](tensors[0])
|
||||
|
||||
#if self.spec.matrix[0, self.num_vertices-1]:
|
||||
# out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
|
||||
#outputs = sum(out_concat) / len(out_concat)
|
||||
|
||||
return outputs
|
||||
|
||||
def Projection(in_channels, out_channels):
|
||||
"""1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
|
||||
return ConvBnRelu(in_channels, out_channels, 1)
|
||||
|
||||
def Truncate(inputs, channels):
|
||||
"""Slice the inputs to channels if necessary."""
|
||||
input_channels = inputs.size()[1]
|
||||
if input_channels < channels:
|
||||
raise ValueError('input channel < output channels for truncate')
|
||||
elif input_channels == channels:
|
||||
return inputs # No truncation necessary
|
||||
else:
|
||||
# Truncation should only be necessary when channel division leads to
|
||||
# vertices with +1 channels. The input vertex should always be projected to
|
||||
# the minimum channel count.
|
||||
assert input_channels - channels == 1
|
||||
return inputs[:, :channels, :, :]
|
||||
|
||||
def ComputeVertexChannels(in_channels, out_channels, matrix):
|
||||
"""Computes the number of channels at every vertex.
|
||||
|
||||
Given the input channels and output channels, this calculates the number of
|
||||
channels at each interior vertex. Interior vertices have the same number of
|
||||
channels as the max of the channels of the vertices it feeds into. The output
|
||||
channels are divided amongst the vertices that are directly connected to it.
|
||||
When the division is not even, some vertices may receive an extra channel to
|
||||
compensate.
|
||||
|
||||
Returns:
|
||||
list of channel counts, in order of the vertices.
|
||||
"""
|
||||
num_vertices = np.shape(matrix)[0]
|
||||
|
||||
vertex_channels = [0] * num_vertices
|
||||
vertex_channels[0] = in_channels
|
||||
vertex_channels[num_vertices - 1] = out_channels
|
||||
|
||||
if num_vertices == 2:
|
||||
# Edge case where module only has input and output vertices
|
||||
return vertex_channels
|
||||
|
||||
# Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is
|
||||
# the dst vertex. Summing over 0 gives the in-degree count of each vertex.
|
||||
in_degree = np.sum(matrix[1:], axis=0)
|
||||
interior_channels = out_channels // in_degree[num_vertices - 1]
|
||||
correction = out_channels % in_degree[num_vertices - 1] # Remainder to add
|
||||
|
||||
# Set channels of vertices that flow directly to output
|
||||
for v in range(1, num_vertices - 1):
|
||||
if matrix[v, num_vertices - 1]:
|
||||
vertex_channels[v] = interior_channels
|
||||
if correction:
|
||||
vertex_channels[v] += 1
|
||||
correction -= 1
|
||||
|
||||
# Set channels for all other vertices to the max of the out edges, going
|
||||
# backwards. (num_vertices - 2) index skipped because it only connects to
|
||||
# output.
|
||||
for v in range(num_vertices - 3, 0, -1):
|
||||
if not matrix[v, num_vertices - 1]:
|
||||
for dst in range(v + 1, num_vertices - 1):
|
||||
if matrix[v, dst]:
|
||||
vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
|
||||
assert vertex_channels[v] > 0
|
||||
|
||||
# Sanity check, verify that channels never increase and final channels add up.
|
||||
final_fan_in = 0
|
||||
for v in range(1, num_vertices - 1):
|
||||
if matrix[v, num_vertices - 1]:
|
||||
final_fan_in += vertex_channels[v]
|
||||
for dst in range(v + 1, num_vertices - 1):
|
||||
if matrix[v, dst]:
|
||||
assert vertex_channels[v] >= vertex_channels[dst]
|
||||
assert final_fan_in == out_channels or num_vertices == 2
|
||||
# num_vertices == 2 means only input/output nodes, so 0 fan-in
|
||||
|
||||
return vertex_channels
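# Worked example (hand-made matrix, not in the original file): for the cell
# input -> v1 -> output and input -> v2 -> output with in_channels=16 and
# out_channels=33, the two vertices feeding the output split the 33 channels
# as 17 and 16 (the remainder goes to the lower-indexed vertex):
#   m = np.array([[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]])
#   ComputeVertexChannels(16, 33, m)   # [16, 17, 16, 33]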
|
||||
152
nas_101_api/model_spec.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Model specification for module connectivity individuals.
|
||||
|
||||
This module handles pruning the unused parts of the computation graph but should
|
||||
avoid creating any TensorFlow models (this is done inside model_builder.py).
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
|
||||
from . import graph_util
|
||||
|
||||
# Graphviz is optional and only required for visualization.
|
||||
try:
|
||||
import graphviz # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class ModelSpec(object):
|
||||
"""Model specification given adjacency matrix and labeling."""
|
||||
|
||||
def __init__(self, matrix, ops, data_format='channels_last'):
|
||||
"""Initialize the module spec.
|
||||
|
||||
Args:
|
||||
matrix: ndarray or nested list with shape [V, V] for the adjacency matrix.
|
||||
ops: V-length list of labels for the base ops used. The first and last
|
||||
elements are ignored because they are the input and output vertices
|
||||
which have no operations. The elements are retained to keep consistent
|
||||
indexing.
|
||||
data_format: channels_last or channels_first.
|
||||
|
||||
Raises:
|
||||
ValueError: invalid matrix or ops
|
||||
"""
|
||||
if not isinstance(matrix, np.ndarray):
|
||||
matrix = np.array(matrix)
|
||||
shape = np.shape(matrix)
|
||||
if len(shape) != 2 or shape[0] != shape[1]:
|
||||
raise ValueError('matrix must be square')
|
||||
if shape[0] != len(ops):
|
||||
raise ValueError('length of ops must match matrix dimensions')
|
||||
if not is_upper_triangular(matrix):
|
||||
raise ValueError('matrix must be upper triangular')
|
||||
|
||||
# Both the original and pruned matrices are deep copies of the matrix and
|
||||
# ops so any changes to those after initialization are not recognized by the
|
||||
# spec.
|
||||
self.original_matrix = copy.deepcopy(matrix)
|
||||
self.original_ops = copy.deepcopy(ops)
|
||||
|
||||
self.matrix = copy.deepcopy(matrix)
|
||||
self.ops = copy.deepcopy(ops)
|
||||
self.valid_spec = True
|
||||
self._prune()
|
||||
|
||||
self.data_format = data_format
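# Construction sketch (hand-made spec, not in the original file):
#   spec = ModelSpec(matrix=[[0, 1, 1, 0],
#                            [0, 0, 0, 1],
#                            [0, 0, 0, 1],
#                            [0, 0, 0, 0]],
#                    ops=['input', 'conv3x3-bn-relu', 'maxpool3x3', 'output'])
#   spec.valid_spec   # True; _prune() below drops vertices that are not on an
#                     # input->output path and marks disconnected specs invalid.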
|
||||
|
||||
def _prune(self):
|
||||
"""Prune the extraneous parts of the graph.
|
||||
|
||||
General procedure:
|
||||
1) Remove parts of graph not connected to input.
|
||||
2) Remove parts of graph not connected to output.
|
||||
3) Reorder the vertices so that they are consecutive after steps 1 and 2.
|
||||
|
||||
These 3 steps can be combined by deleting the rows and columns of the
|
||||
vertices that are not reachable from both the input and output (in reverse).
|
||||
"""
|
||||
num_vertices = np.shape(self.original_matrix)[0]
|
||||
|
||||
# DFS forward from input
|
||||
visited_from_input = set([0])
|
||||
frontier = [0]
|
||||
while frontier:
|
||||
top = frontier.pop()
|
||||
for v in range(top + 1, num_vertices):
|
||||
if self.original_matrix[top, v] and v not in visited_from_input:
|
||||
visited_from_input.add(v)
|
||||
frontier.append(v)
|
||||
|
||||
# DFS backward from output
|
||||
visited_from_output = set([num_vertices - 1])
|
||||
frontier = [num_vertices - 1]
|
||||
while frontier:
|
||||
top = frontier.pop()
|
||||
for v in range(0, top):
|
||||
if self.original_matrix[v, top] and v not in visited_from_output:
|
||||
visited_from_output.add(v)
|
||||
frontier.append(v)
|
||||
|
||||
# Any vertex that isn't connected to both input and output is extraneous to
|
||||
# the computation graph.
|
||||
extraneous = set(range(num_vertices)).difference(
|
||||
visited_from_input.intersection(visited_from_output))
|
||||
|
||||
# If the non-extraneous graph is less than 2 vertices, the input is not
|
||||
# connected to the output and the spec is invalid.
|
||||
if len(extraneous) > num_vertices - 2:
|
||||
self.matrix = None
|
||||
self.ops = None
|
||||
self.valid_spec = False
|
||||
return
|
||||
|
||||
self.matrix = np.delete(self.matrix, list(extraneous), axis=0)
|
||||
self.matrix = np.delete(self.matrix, list(extraneous), axis=1)
|
||||
for index in sorted(extraneous, reverse=True):
|
||||
del self.ops[index]
|
||||
|
||||
def hash_spec(self, canonical_ops):
|
||||
"""Computes the isomorphism-invariant graph hash of this spec.
|
||||
|
||||
Args:
|
||||
canonical_ops: list of operations in the canonical ordering which they
|
||||
were assigned (i.e. the order provided in the config['available_ops']).
|
||||
|
||||
Returns:
|
||||
MD5 hash of this spec which can be used to query the dataset.
|
||||
"""
|
||||
# Invert the operations back to integer label indices used in graph gen.
|
||||
labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2]
|
||||
return graph_util.hash_module(self.matrix, labeling)
|
||||
|
||||
def visualize(self):
|
||||
"""Creates a dot graph. Can be visualized in colab directly."""
|
||||
num_vertices = np.shape(self.matrix)[0]
|
||||
g = graphviz.Digraph()
|
||||
g.node(str(0), 'input')
|
||||
for v in range(1, num_vertices - 1):
|
||||
g.node(str(v), self.ops[v])
|
||||
g.node(str(num_vertices - 1), 'output')
|
||||
|
||||
for src in range(num_vertices - 1):
|
||||
for dst in range(src + 1, num_vertices):
|
||||
if self.matrix[src, dst]:
|
||||
g.edge(str(src), str(dst))
|
||||
|
||||
return g
|
||||
|
||||
|
||||
def is_upper_triangular(matrix):
|
||||
"""True if matrix is 0 on diagonal and below."""
|
||||
for src in range(np.shape(matrix)[0]):
|
||||
for dst in range(0, src + 1):
|
||||
if matrix[src, dst] != 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
15
nas_201_api/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################################
|
||||
# This API will not be updated after 2020.09.16. #
|
||||
# Please use our new API in NATS-Bench, which is #
|
||||
# more efficient and contains info of more architecture candidates. #
|
||||
#####################################################################
|
||||
from .api_utils import ArchResults, ResultsCount
|
||||
from .api_201 import NASBench201API
|
||||
|
||||
# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
|
||||
# NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
|
||||
# NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16]
|
||||
NAS_BENCH_201_API_VERSION="v2.0" # [2020.06.30]
|
||||
|
||||
830
nas_201_api/api.py
Normal file
@@ -0,0 +1,830 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
############################################################################################
|
||||
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
||||
############################################################################################
|
||||
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
||||
# [2020.03.08] Next version (coming soon)
|
||||
#
|
||||
#
|
||||
import os, copy, random, torch, numpy as np
|
||||
from typing import List, Text, Union, Dict
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
|
||||
def print_information(information, extra_info=None, show=False):
|
||||
dataset_names = information.get_dataset_names()
|
||||
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
|
||||
def metric2str(loss, acc):
|
||||
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
|
||||
|
||||
for ida, dataset in enumerate(dataset_names):
|
||||
metric = information.get_compute_costs(dataset)
|
||||
flop, param, latency = metric['flops'], metric['params'], metric['latency']
|
||||
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
|
||||
train_info = information.get_metrics(dataset, 'train')
|
||||
if dataset == 'cifar10-valid':
|
||||
valid_info = information.get_metrics(dataset, 'x-valid')
|
||||
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
|
||||
elif dataset == 'cifar10':
|
||||
test__info = information.get_metrics(dataset, 'ori-test')
|
||||
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
|
||||
else:
|
||||
valid_info = information.get_metrics(dataset, 'x-valid')
|
||||
test__info = information.get_metrics(dataset, 'x-test')
|
||||
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
|
||||
strings += [str1, str2]
|
||||
if show: print('\n'.join(strings))
|
||||
return strings
|
||||
|
||||
"""
|
||||
This is the class for API of NAS-Bench-201.
|
||||
"""
|
||||
class NASBench201API(object):
|
||||
|
||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||
def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
|
||||
if isinstance(file_path_or_dict, str):
|
||||
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||
file_path_or_dict = torch.load(file_path_or_dict)
|
||||
elif isinstance(file_path_or_dict, dict):
|
||||
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
||||
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
|
||||
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
|
||||
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
||||
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
||||
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
|
||||
self.arch2infos_less = OrderedDict()
|
||||
self.arch2infos_full = OrderedDict()
|
||||
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
||||
all_info = file_path_or_dict['arch2infos'][xkey]
|
||||
self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
|
||||
self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
|
||||
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
|
||||
self.archstr2index = {}
|
||||
for idx, arch in enumerate(self.meta_archs):
|
||||
#assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()])
|
||||
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
|
||||
self.archstr2index[ arch ] = idx
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
return copy.deepcopy( self.meta_archs[index] )
|
||||
|
||||
def __len__(self):
|
||||
return len(self.meta_archs)
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
|
||||
|
||||
def random(self):
|
||||
"""Return a random index of all architectures."""
|
||||
return random.randint(0, len(self.meta_archs)-1)
|
||||
|
||||
# This function is used to query the index of an architecture in the search space.
|
||||
# The input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
|
||||
# or an instance that has the 'tostr' function that can generate the architecture string.
|
||||
# This function will return the index.
|
||||
# If -1 is returned, this architecture is not in the search space.
|
||||
# Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
|
||||
def query_index_by_arch(self, arch):
|
||||
if isinstance(arch, str):
|
||||
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
|
||||
else : arch_index = -1
|
||||
elif hasattr(arch, 'tostr'):
|
||||
if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
|
||||
else : arch_index = -1
|
||||
else: arch_index = -1
|
||||
return arch_index
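# Usage sketch (assumes the NAS-Bench-201 benchmark file has been downloaded):
#   api = NASBench201API('NAS-Bench-201-v1_0-e61699.pth')
#   idx = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
#   # idx == -1 means the string is not a member of the search space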
|
||||
|
||||
# Overwrite all information of the 'index'-th architecture in the search space.
|
||||
# It will load its data from 'archive_root'.
|
||||
def reload(self, archive_root: Text, index: int):
|
||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
|
||||
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
|
||||
xdata = torch.load(xfile_path)
|
||||
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
|
||||
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
|
||||
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
|
||||
|
||||
# This function is used to query the information of a specific architecture
|
||||
# 'arch' can be an architecture index or an architecture string
|
||||
# When use_12epochs_result=True, the hyper-parameters used to train a model are in 'configs/nas-benchmark/CIFAR.config'
|
||||
# When use_12epochs_result=False, the hyper-parameters used to train a model are in 'configs/nas-benchmark/LESS.config'
|
||||
# The difference between these two configurations are the number of training epochs, which is 200 in CIFAR.config and 12 in LESS.config.
|
||||
def query_by_arch(self, arch, use_12epochs_result=False):
|
||||
if isinstance(arch, int):
|
||||
arch_index = arch
|
||||
else:
|
||||
arch_index = self.query_index_by_arch(arch)
|
||||
if arch_index == -1: return None # the following two lines are used to support few training epochs
|
||||
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||
else : arch2infos = self.arch2infos_full
|
||||
if arch_index in arch2infos:
|
||||
strings = print_information(arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index))
|
||||
return '\n'.join(strings)
|
||||
else:
|
||||
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
||||
return None
|
||||
|
||||
# This 'query_by_index' function is used to query information with the training of 12 epochs or 200 epochs.
|
||||
# ------
|
||||
# If use_12epochs_result=True, we train the model by 12 epochs (see config in configs/nas-benchmark/LESS.config)
|
||||
# If use_12epochs_result=False, we train the model by 200 epochs (see config in configs/nas-benchmark/CIFAR.config)
|
||||
# ------
|
||||
# If dataname is None, return the ArchResults
|
||||
# else, return a dict with all trials on that dataset (the key is the seed)
|
||||
# Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
|
||||
# -- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||
# -- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
# -- cifar100 : training the model on the CIFAR-100 training set.
|
||||
# -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None,
|
||||
use_12epochs_result: bool = False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
assert arch_index in arch2infos, 'arch_index [{:}] is not in arch2infos with {:}'.format(arch_index, basestr)
|
||||
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
|
||||
if dataname is None: return archInfo
|
||||
else:
|
||||
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
|
||||
info = archInfo.query(dataname)
|
||||
return info
|
||||
|
||||
def query_meta_info_by_index(self, arch_index, use_12epochs_result=False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
assert arch_index in arch2infos, 'arch_index [{:}] is not in arch2infos with {:}'.format(arch_index, basestr)
|
||||
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
|
||||
return archInfo
|
||||
|
||||
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
best_index, highest_accuracy = -1, None
|
||||
for i, idx in enumerate(self.evaluated_indexes):
|
||||
info = arch2infos[idx].get_compute_costs(dataset)
|
||||
flop, param, latency = info['flops'], info['params'], info['latency']
|
||||
if FLOP_max is not None and flop > FLOP_max : continue
|
||||
if Param_max is not None and param > Param_max: continue
|
||||
xinfo = arch2infos[idx].get_metrics(dataset, metric_on_set)
|
||||
loss, accuracy = xinfo['loss'], xinfo['accuracy']
|
||||
if best_index == -1:
|
||||
best_index, highest_accuracy = idx, accuracy
|
||||
elif highest_accuracy < accuracy:
|
||||
best_index, highest_accuracy = idx, accuracy
|
||||
return best_index, highest_accuracy
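# Usage sketch (assumes `api` was created as above; FLOPs are reported in millions):
#   best_index, best_acc = api.find_best('cifar10', 'ori-test', FLOP_max=50)
#   # best architecture on the CIFAR-10 test set among models under 50 MFLOPs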
|
||||
|
||||
|
||||
def arch(self, index: int):
|
||||
"""Return the topology structure of the `index`-th architecture."""
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
||||
return copy.deepcopy(self.meta_archs[index])
|
||||
|
||||
def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
|
||||
"""
|
||||
This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
|
||||
Args [seed]:
|
||||
-- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
|
||||
-- an integer : return the weights of a specific trial, whose seed is this integer.
|
||||
Args [use_12epochs_result]:
|
||||
-- True : train the model by 12 epochs
|
||||
-- False : train the model by 200 epochs
|
||||
"""
|
||||
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||
else: arch2infos = self.arch2infos_full
|
||||
arch_result = arch2infos[index]
|
||||
return arch_result.get_net_param(dataset, seed)
|
||||
|
||||
|
||||
def get_net_config(self, index: int, dataset: Text):
|
||||
"""
|
||||
This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
|
||||
Args [dataset] (4 possible options):
|
||||
-- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||
-- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
-- cifar100 : training the model on the CIFAR-100 training set.
|
||||
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
This function will return a dict.
|
||||
========= Some examples of using this function:
|
||||
config = api.get_net_config(128, 'cifar10')
|
||||
"""
|
||||
archresult = self.arch2infos_full[index]
|
||||
all_results = archresult.query(dataset, None)
|
||||
if len(all_results) == 0: raise ValueError('can not find one valid trial for the {:}-th architecture on {:}'.format(index, dataset))
|
||||
for seed, result in all_results.items():
|
||||
return result.get_config(None)
|
||||
#print ('SEED [{:}] : {:}'.format(seed, result))
|
||||
raise ValueError('Impossible to reach here!')
|
||||
|
||||
|
||||
def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]:
|
||||
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
|
||||
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||
else: arch2infos = self.arch2infos_full
|
||||
arch_result = arch2infos[index]
|
||||
return arch_result.get_compute_costs(dataset)
|
||||
|
||||
|
||||
def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float:
|
||||
"""
|
||||
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
|
||||
:param index: the index of the target architecture
|
||||
:param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
|
||||
:return: return a float value in seconds
|
||||
"""
|
||||
cost_dict = self.get_cost_info(index, dataset, use_12epochs_result)
|
||||
return cost_dict['latency']
|
||||
|
||||
|
||||
# obtain the metric for the `index`-th architecture
|
||||
# `dataset` indicates the dataset:
|
||||
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
|
||||
# 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
|
||||
# 'cifar100' : using the proposed train set of CIFAR-100 as the training set
|
||||
# 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
|
||||
# `iepoch` indicates the index of training epochs from 0 to 11/199.
|
||||
# When iepoch=None, it will return the metric for the last training epoch
|
||||
# When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
|
||||
# `use_12epochs_result` indicates different hyper-parameters for training
|
||||
# When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
|
||||
# When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
|
||||
# `is_random`
|
||||
# When is_random=True, the performance of one randomly selected trial will be returned
|
||||
# When is_random=False, the performance of all trials will be averaged.
|
||||
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
archresult = arch2infos[index]
|
||||
# if randomly select one trial, select the seed at first
|
||||
if isinstance(is_random, bool) and is_random:
|
||||
seeds = archresult.get_dataset_seeds(dataset)
|
||||
is_random = random.choice(seeds)
|
||||
if dataset == 'cifar10-valid':
|
||||
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random)
|
||||
try:
|
||||
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
test__info = None
|
||||
total = train_info['iepoch'] + 1
|
||||
xifo = {'train-loss' : train_info['loss'],
|
||||
'train-accuracy': train_info['accuracy'],
|
||||
'train-per-time': None if train_info['all_time'] is None else train_info['all_time'] / total,
|
||||
'train-all-time': train_info['all_time'],
|
||||
'valid-loss' : valid_info['loss'],
|
||||
'valid-accuracy': valid_info['accuracy'],
|
||||
'valid-all-time': valid_info['all_time'],
|
||||
'valid-per-time': None if valid_info['all_time'] is None else valid_info['all_time'] / total}
|
||||
if test__info is not None:
|
||||
xifo['test-loss'] = test__info['loss']
|
||||
xifo['test-accuracy'] = test__info['accuracy']
|
||||
return xifo
|
||||
else:
|
||||
train_info = archresult.get_metrics(dataset, 'train' , iepoch=iepoch, is_random=is_random)
|
||||
try:
|
||||
if dataset == 'cifar10':
|
||||
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||
else:
|
||||
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
test__info = None
|
||||
try:
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
valid_info = None
|
||||
try:
|
||||
est_valid_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
est_valid_info = None
|
||||
xifo = {'train-loss' : train_info['loss'],
|
||||
'train-accuracy': train_info['accuracy']}
|
||||
if test__info is not None:
|
||||
xifo['test-loss'] = test__info['loss']
|
||||
xifo['test-accuracy'] = test__info['accuracy']
|
||||
if valid_info is not None:
|
||||
xifo['valid-loss'] = valid_info['loss']
|
||||
xifo['valid-accuracy'] = valid_info['accuracy']
|
||||
if est_valid_info is not None:
|
||||
xifo['est-valid-loss'] = est_valid_info['loss']
|
||||
xifo['est-valid-accuracy'] = est_valid_info['accuracy']
|
||||
return xifo
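# Usage sketch (assumes `api` was created as above): fetch the 12-epoch training
# statistics of architecture 100 on CIFAR-100, averaged over all trials.
#   info = api.get_more_info(100, 'cifar100', iepoch=None,
#                            use_12epochs_result=True, is_random=False)
#   info['train-accuracy'], info.get('valid-accuracy')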
|
||||
|
||||
|
||||
def show(self, index: int = -1):
|
||||
"""
This function will print the information of a specific (or all) architecture(s).

:param index: If the index < 0: it will loop over all architectures and print their information one by one.
              else: it will print the information of the 'index'-th architecture.
:return: nothing
"""
return_flag = 0
|
||||
if index < 0: # show all architectures
|
||||
print(self)
|
||||
for i, idx in enumerate(self.evaluated_indexes):
|
||||
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
|
||||
print('arch : {:}'.format(self.meta_archs[idx]))
|
||||
strings = print_information(self.arch2infos_full[idx])
|
||||
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[idx].get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
strings = print_information(self.arch2infos_less[idx])
|
||||
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[idx].get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
else:
|
||||
if 0 <= index < len(self.meta_archs):
|
||||
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
|
||||
else:
|
||||
return_flag = 1
|
||||
out = []
|
||||
strings = print_information(self.arch2infos_full[index])
|
||||
out.append(strings)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[index].get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
strings = print_information(self.arch2infos_less[index])
|
||||
out.append(strings)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[index].get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
|
||||
else:
|
||||
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
||||
|
||||
if return_flag:
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def str2lists(arch_str: Text) -> List[tuple]:
|
||||
"""
|
||||
This function shows how to read the string-based architecture encoding.
|
||||
It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
|
||||
|
||||
:param
|
||||
arch_str: the input is a string that indicates the architecture topology, such as
|
||||
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
|
||||
:return: a list of tuples; each tuple contains the (op, input_node_index) pairs for one node.
|
||||
|
||||
:usage
|
||||
arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
|
||||
print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
|
||||
for i, node in enumerate(arch):
|
||||
print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
|
||||
"""
|
||||
node_strs = arch_str.split('+')
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(node_strs):
|
||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||
inputs = ( xi.split('~') for xi in inputs )
|
||||
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
|
||||
genotypes.append( input_infos )
|
||||
return genotypes
|
||||
|
||||
|
||||
@staticmethod
|
||||
def str2matrix(arch_str: Text,
|
||||
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
|
||||
"""
|
||||
This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
|
||||
|
||||
:param
|
||||
arch_str: the input is a string that indicates the architecture topology, such as
|
||||
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
|
||||
search_space: a list of operation string, the default list is the search space for NAS-Bench-201
|
||||
the default value should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
|
||||
:return
|
||||
the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
|
||||
:usage
|
||||
matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
|
||||
This matrix is a 4-by-4 matrix representing a cell with 4 nodes (only the lower-left triangle is useful).
|
||||
[ [0, 0, 0, 0], # the first line represents the input (0-th) node
|
||||
[2, 0, 0, 0], # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
|
||||
[0, 0, 0, 0], # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
|
||||
[0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
|
||||
In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect',
|
||||
2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
|
||||
:(NOTE)
|
||||
If a node has two input-edges from the same node, this function does not work: one edge will overwrite the other.
|
||||
"""
|
||||
node_strs = arch_str.split('+')
|
||||
num_nodes = len(node_strs) + 1
|
||||
matrix = np.zeros((num_nodes, num_nodes))
|
||||
for i, node_str in enumerate(node_strs):
|
||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||
for xi in inputs:
|
||||
op, idx = xi.split('~')
|
||||
if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
|
||||
op_idx, node_idx = search_space.index(op), int(idx)
|
||||
matrix[i+1, node_idx] = op_idx
|
||||
return matrix
|
||||
|
||||
|
||||
|
||||
class ArchResults(object):
|
||||
|
||||
def __init__(self, arch_index, arch_str):
|
||||
self.arch_index = int(arch_index)
|
||||
self.arch_str = copy.deepcopy(arch_str)
|
||||
self.all_results = dict()
|
||||
self.dataset_seed = dict()
|
||||
self.clear_net_done = False
|
||||
|
||||
def get_compute_costs(self, dataset):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
|
||||
flops = [result.flop for result in results]
|
||||
params = [result.params for result in results]
|
||||
latencies = [result.get_latency() for result in results]
|
||||
latencies = [x for x in latencies if x > 0]
|
||||
mean_latency = np.mean(latencies) if len(latencies) > 0 else None
|
||||
time_infos = defaultdict(list)
|
||||
for result in results:
|
||||
time_info = result.get_times()
|
||||
for key, value in time_info.items(): time_infos[key].append( value )
|
||||
|
||||
info = {'flops' : np.mean(flops),
|
||||
'params' : np.mean(params),
|
||||
'latency': mean_latency}
|
||||
for key, value in time_infos.items():
|
||||
if len(value) > 0 and value[0] is not None:
|
||||
info[key] = np.mean(value)
|
||||
else: info[key] = None
|
||||
return info
|
||||
|
||||
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
|
||||
"""
|
||||
This `get_metrics` function is used to obtain the loss, accuracy, etc. on a specific dataset.
If not specified, each set refers to the proposed split in the NAS-Bench-201 paper.
If some args return None or raise an error, then that information is not available.
|
||||
========================================
|
||||
Args [dataset] (4 possible options):
|
||||
-- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||
-- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
-- cifar100 : training the model on the CIFAR-100 training set.
|
||||
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
Args [setname] (each dataset has different setnames):
|
||||
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
|
||||
------ 'train' : the metric on the training set.
|
||||
------ 'x-valid' : the metric on the validation set.
|
||||
------ 'ori-test' : the metric on the test set.
|
||||
-- When dataset = cifar10, you can use 'train', 'ori-test'.
|
||||
------ 'train' : the metric on the training + validation set.
|
||||
------ 'ori-test' : the metric on the test set.
|
||||
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
|
||||
------ 'train' : the metric on the training set.
|
||||
------ 'x-valid' : the metric on the validation set.
|
||||
------ 'x-test' : the metric on the test set.
|
||||
------ 'ori-test' : the metric on the validation + test set.
|
||||
Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
|
||||
------ None : return the metric after the last training epoch.
|
||||
------ an integer i : return the metric after the i-th training epoch.
|
||||
Args [is_random]:
|
||||
------ True : return the metric of a randomly selected trial.
|
||||
------ False : return the averaged metric of all avaliable trials.
|
||||
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
|
||||
"""
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
infos = defaultdict(list)
|
||||
for result in results:
|
||||
if setname == 'train':
|
||||
info = result.get_train(iepoch)
|
||||
else:
|
||||
info = result.get_eval(setname, iepoch)
|
||||
for key, value in info.items(): infos[key].append( value )
|
||||
return_info = dict()
|
||||
if isinstance(is_random, bool) and is_random: # randomly select one
|
||||
index = random.randint(0, len(results)-1)
|
||||
for key, value in infos.items(): return_info[key] = value[index]
|
||||
elif isinstance(is_random, bool) and not is_random: # average
|
||||
for key, value in infos.items():
|
||||
if len(value) > 0 and value[0] is not None:
|
||||
return_info[key] = np.mean(value)
|
||||
else: return_info[key] = None
|
||||
elif isinstance(is_random, int): # specify the seed
|
||||
if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
|
||||
index = x_seeds.index(is_random)
|
||||
for key, value in infos.items(): return_info[key] = value[index]
|
||||
else:
|
||||
raise ValueError('invalid value for is_random: {:}'.format(is_random))
|
||||
return return_info
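# Usage sketch (assumes `archresult` is an ArchResults entry held by the API):
#   archresult.get_metrics('cifar10-valid', 'x-valid', iepoch=None, is_random=False)
#   # -> dict with averaged 'loss' / 'accuracy' (plus timing keys when recorded)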
|
||||
|
||||
def show(self, is_print=False):
|
||||
return print_information(self, None, is_print)
|
||||
|
||||
def get_dataset_names(self):
|
||||
return list(self.dataset_seed.keys())
|
||||
|
||||
def get_dataset_seeds(self, dataset):
|
||||
return copy.deepcopy( self.dataset_seed[dataset] )
|
||||
|
||||
def get_net_param(self, dataset: Text, seed: Union[None, int] =None):
|
||||
"""
|
||||
This function will return the trained network's weights on the 'dataset'.
|
||||
:arg
|
||||
dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
|
||||
seed: an integer that indicates the seed value, or None to return all trials.
|
||||
"""
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
|
||||
else:
|
||||
return self.all_results[(dataset, seed)].get_net_param()
|
||||
|
||||
def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
|
||||
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
|
||||
if seed is None:
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
self.all_results[(dataset, seed)].update_latency([latency])
|
||||
else:
|
||||
self.all_results[(dataset, seed)].update_latency([latency])
|
||||
|
||||
def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
|
||||
"""This function is used to reset the train-times in all corresponding ResultsCount(s)."""
|
||||
if seed is None:
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
|
||||
else:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
|
||||
|
||||
def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
|
||||
"""This function is used to reset the eval-times in all corresponding ResultsCount(s)."""
|
||||
if seed is None:
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
|
||||
else:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
|
||||
|
||||
def get_latency(self, dataset: Text) -> float:
|
||||
"""Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]"""
|
||||
latencies = []
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
latency = self.all_results[(dataset, seed)].get_latency()
|
||||
if not isinstance(latency, float) or latency <= 0:
|
||||
raise ValueError('invalid latency of {:} for {:} with seed {:}'.format(latency, dataset, seed))
|
||||
latencies.append(latency)
|
||||
return sum(latencies) / len(latencies)
|
||||
|
||||
def get_total_epoch(self, dataset=None):
|
||||
"""Return the total number of training epochs."""
|
||||
if dataset is None:
|
||||
epochss = []
|
||||
for xdata, x_seeds in self.dataset_seed.items():
|
||||
epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
|
||||
elif isinstance(dataset, str):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
|
||||
else:
|
||||
raise ValueError('invalid dataset={:}'.format(dataset))
|
||||
if len(set(epochss)) > 1: raise ValueError('Each trial must have the same number of training epochs : {:}'.format(epochss))
|
||||
return epochss[-1]
|
||||
|
||||
def query(self, dataset, seed=None):
|
||||
"""Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'"""
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
|
||||
else:
|
||||
return self.all_results[(dataset, seed)]
|
||||
|
||||
def arch_idx_str(self):
|
||||
return '{:06d}'.format(self.arch_index)
|
||||
|
||||
def update(self, dataset_name, seed, result):
|
||||
if dataset_name not in self.dataset_seed:
|
||||
self.dataset_seed[dataset_name] = []
|
||||
assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
|
||||
self.dataset_seed[ dataset_name ].append( seed )
|
||||
self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
|
||||
assert (dataset_name, seed) not in self.all_results
|
||||
self.all_results[ (dataset_name, seed) ] = result
|
||||
self.clear_net_done = False
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = dict()
|
||||
for key, value in self.__dict__.items():
|
||||
if key == 'all_results': # contain the class of ResultsCount
|
||||
xvalue = dict()
|
||||
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
|
||||
for _k, _v in value.items():
|
||||
assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
|
||||
xvalue[_k] = _v.state_dict()
|
||||
else:
|
||||
xvalue = value
|
||||
state_dict[key] = xvalue
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
new_state_dict = dict()
|
||||
for key, value in state_dict.items():
|
||||
if key == 'all_results': # to convert to the class of ResultsCount
|
||||
xvalue = dict()
|
||||
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
|
||||
for _k, _v in value.items():
|
||||
xvalue[_k] = ResultsCount.create_from_state_dict(_v)
|
||||
else: xvalue = value
|
||||
new_state_dict[key] = xvalue
|
||||
self.__dict__.update(new_state_dict)
|
||||
|
||||
@staticmethod
|
||||
def create_from_state_dict(state_dict_or_file):
|
||||
x = ArchResults(-1, -1)
|
||||
if isinstance(state_dict_or_file, str): # a file path
|
||||
state_dict = torch.load(state_dict_or_file)
|
||||
elif isinstance(state_dict_or_file, dict):
|
||||
state_dict = state_dict_or_file
|
||||
else:
|
||||
raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
|
||||
x.load_state_dict(state_dict)
|
||||
return x
|
||||
|
||||
# This function is used to clear the weights saved in each 'result'
|
||||
# This can help reduce the memory footprint.
|
||||
def clear_params(self):
|
||||
for key, result in self.all_results.items():
|
||||
result.net_state_dict = None
|
||||
self.clear_net_done = True
|
||||
|
||||
def debug_test(self):
|
||||
"""This function is used for me to debug and test, which will call most methods."""
|
||||
all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
|
||||
for dataset in all_dataset:
|
||||
print('---->>>> {:}'.format(dataset))
|
||||
print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset)))
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
result = self.all_results[(dataset, seed)]
|
||||
print(' ==>> result = {:}'.format(result))
|
||||
print(' ==>> cost = {:}'.format(result.get_times()))
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
|
||||
|
||||
|
||||
|
||||
"""
|
||||
This class (ResultsCount) is used to save the information of one trial for a single architecture.
|
||||
I did not write many comments for this class, because it is the lowest-level class in the NAS-Bench-201 API and will rarely be called directly.
|
||||
If you have any question regarding this class, please open an issue or email me.
|
||||
"""
|
||||
class ResultsCount(object):
|
||||
|
||||
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
|
||||
self.name = name
|
||||
self.net_state_dict = state_dict
|
||||
self.train_acc1es = copy.deepcopy(train_accs)
|
||||
self.train_acc5es = None
|
||||
self.train_losses = copy.deepcopy(train_losses)
|
||||
self.train_times = None
|
||||
self.arch_config = copy.deepcopy(arch_config)
|
||||
self.params = params
|
||||
self.flop = flop
|
||||
self.seed = seed
|
||||
self.epochs = epochs
|
||||
self.latency = latency
|
||||
# evaluation results
|
||||
self.reset_eval()
|
||||
|
||||
def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
|
||||
self.train_acc1es = train_acc1es
|
||||
self.train_acc5es = train_acc5es
|
||||
self.train_losses = train_losses
|
||||
self.train_times = train_times
|
||||
|
||||
def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
|
||||
"""Assign the training times."""
|
||||
train_times = OrderedDict()
|
||||
for i in range(self.epochs):
|
||||
train_times[i] = estimated_per_epoch_time
|
||||
self.train_times = train_times
|
||||
|
||||
def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
|
||||
"""Assign the evaluation times."""
|
||||
if eval_name not in self.eval_names: raise ValueError('invalid eval name : {:}'.format(eval_name))
|
||||
for i in range(self.epochs):
|
||||
self.eval_times['{:}@{:}'.format(eval_name,i)] = estimated_per_epoch_time
|
||||
|
||||
def reset_eval(self):
|
||||
self.eval_names = []
|
||||
self.eval_acc1es = {}
|
||||
self.eval_times = {}
|
||||
self.eval_losses = {}
|
||||
|
||||
def update_latency(self, latency):
|
||||
self.latency = copy.deepcopy( latency )
|
||||
|
||||
def get_latency(self) -> float:
|
||||
"""Return the latency value in seconds. -1 represents not avaliable ; otherwise it should be a float value"""
|
||||
if self.latency is None: return -1.0
|
||||
else: return sum(self.latency) / len(self.latency)
|
||||
|
||||
def update_eval(self, accs, losses, times): # new version
|
||||
data_names = set([x.split('@')[0] for x in accs.keys()])
|
||||
for data_name in data_names:
|
||||
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
|
||||
self.eval_names.append( data_name )
|
||||
for iepoch in range(self.epochs):
|
||||
xkey = '{:}@{:}'.format(data_name, iepoch)
|
||||
self.eval_acc1es[ xkey ] = accs[ xkey ]
|
||||
self.eval_losses[ xkey ] = losses[ xkey ]
|
||||
self.eval_times [ xkey ] = times[ xkey ]
|
||||
|
||||
def update_OLD_eval(self, name, accs, losses): # old version
|
||||
assert name not in self.eval_names, '{:} has already added'.format(name)
|
||||
self.eval_names.append( name )
|
||||
for iepoch in range(self.epochs):
|
||||
if iepoch in accs:
|
||||
self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch]
|
||||
self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch]
|
||||
|
||||
def __repr__(self):
|
||||
num_eval = len(self.eval_names)
|
||||
set_name = '[' + ', '.join(self.eval_names) + ']'
|
||||
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
|
||||
|
||||
def get_total_epoch(self):
|
||||
return copy.deepcopy(self.epochs)
|
||||
|
||||
def get_times(self):
|
||||
"""Obtain the information regarding both training and evaluation time."""
|
||||
if self.train_times is not None and isinstance(self.train_times, dict):
|
||||
train_times = list( self.train_times.values() )
|
||||
time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
|
||||
else:
|
||||
time_info = {'T-train@epoch': None, 'T-train@total': None }
|
||||
for name in self.eval_names:
|
||||
try:
|
||||
xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
|
||||
time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
|
||||
time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
|
||||
except:
|
||||
time_info['T-{:}@epoch'.format(name)] = None
|
||||
time_info['T-{:}@total'.format(name)] = None
|
||||
return time_info
|
||||
|
||||
def get_eval_set(self):
|
||||
return self.eval_names
|
||||
|
||||
# get the training information
|
||||
def get_train(self, iepoch=None):
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
if self.train_times is not None:
|
||||
xtime = self.train_times[iepoch]
|
||||
atime = sum([self.train_times[i] for i in range(iepoch+1)])
|
||||
else: xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.train_losses[iepoch],
|
||||
'accuracy': self.train_acc1es[iepoch],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
# get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
|
||||
def get_eval(self, name, iepoch=None):
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
|
||||
xtime = self.eval_times['{:}@{:}'.format(name,iepoch)]
|
||||
atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)])
|
||||
else: xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.eval_losses['{:}@{:}'.format(name,iepoch)],
|
||||
'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
def get_net_param(self, clone=False):
|
||||
if clone: return copy.deepcopy(self.net_state_dict)
|
||||
else: return self.net_state_dict
|
||||
|
||||
# This function is used to obtain the config dict for this architecture.
|
||||
def get_config(self, str2structure):
|
||||
if str2structure is None:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
|
||||
else:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
|
||||
|
||||
def state_dict(self):
|
||||
_state_dict = {key: value for key, value in self.__dict__.items()}
|
||||
return _state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
@staticmethod
|
||||
def create_from_state_dict(state_dict):
|
||||
x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
|
||||
x.load_state_dict(state_dict)
|
||||
return x
|
||||
274
nas_201_api/api_201.py
Normal file
@@ -0,0 +1,274 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# The history of benchmark files:
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
# [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 architectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.
#
# I am still actively enhancing this benchmark; for future benchmark files, please follow the news from NATS-Bench (an extended version of NAS-Bench-201).
#
import os, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict

from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
from .api_utils import remap_dataset_set_names


ALL_BENCHMARK_FILES = ['NAS-Bench-201-v1_0-e61699.pth', 'NAS-Bench-201-v1_1-096897.pth']
ALL_ARCHIVE_DIRS = ['NAS-Bench-201-v1_1-archive']


def print_information(information, extra_info=None, show=False):
  dataset_names = information.get_dataset_names()
  strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
  def metric2str(loss, acc):
    return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)

  for ida, dataset in enumerate(dataset_names):
    metric = information.get_compute_costs(dataset)
    flop, param, latency = metric['flops'], metric['params'], metric['latency']
    str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
    train_info = information.get_metrics(dataset, 'train')
    if dataset == 'cifar10-valid':
      valid_info = information.get_metrics(dataset, 'x-valid')
      str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
    elif dataset == 'cifar10':
      test__info = information.get_metrics(dataset, 'ori-test')
      str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
    else:
      valid_info = information.get_metrics(dataset, 'x-valid')
      test__info = information.get_metrics(dataset, 'x-test')
      str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
    strings += [str1, str2]
  if show: print('\n'.join(strings))
  return strings


"""
This is the class for the API of NAS-Bench-201.
"""
class NASBench201API(NASBenchMetaAPI):

  """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
  def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None,
               verbose: bool=True):
    self.filename = None
    self.reset_time()
    if file_path_or_dict is None:
      file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
      print('Try to use the default NAS-Bench-201 path from {:}.'.format(file_path_or_dict))
    if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
      file_path_or_dict = str(file_path_or_dict)
      if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
      assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
      self.filename = Path(file_path_or_dict).name
      file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
    elif isinstance(file_path_or_dict, dict):
      file_path_or_dict = copy.deepcopy(file_path_or_dict)
    else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
    assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
    self.verbose = verbose # [TODO] a flag indicating whether to print more logs
    keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
    for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
    self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
    # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
    self.arch2infos_dict = OrderedDict()
    self._avaliable_hps = set(['12', '200'])
    for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
      all_info = file_path_or_dict['arch2infos'][xkey]
      hp2archres = OrderedDict()
      # self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
      # self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
      hp2archres['12'] = ArchResults.create_from_state_dict(all_info['less'])
      hp2archres['200'] = ArchResults.create_from_state_dict(all_info['full'])
      self.arch2infos_dict[xkey] = hp2archres
    self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
    self.archstr2index = {}
    for idx, arch in enumerate(self.meta_archs):
      assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
      self.archstr2index[ arch ] = idx
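For orientation, a minimal usage sketch of constructing this class and inspecting the search space (illustrative only, not part of the committed file; it assumes the package exports `NASBench201API` and that one of the `ALL_BENCHMARK_FILES` entries has been downloaded to the working directory):

```python
# Illustrative only; not part of the committed file.
from nas_201_api import NASBench201API

api = NASBench201API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
print(len(api))                          # number of architectures in the search space
print(api.query_index_by_arch(api[0]))   # map an architecture string back to its index
```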
  def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the 'index'-th architecture in the search space.
    It will load its data from 'archive_root'.
    """
    if archive_root is None:
      archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
    assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
    if index is None:
      indexes = list(range(len(self)))
    else:
      indexes = [index]
    for idx in indexes:
      assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
      xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
      assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
      xdata = torch.load(xfile_path, map_location='cpu')
      assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
      hp2archres = OrderedDict()
      hp2archres['12'] = ArchResults.create_from_state_dict(xdata['less'])
      hp2archres['200'] = ArchResults.create_from_state_dict(xdata['full'])
      self.arch2infos_dict[idx] = hp2archres

  def query_info_str_by_arch(self, arch, hp: Text='12'):
    """ This function is used to query the information of a specific architecture.
    'arch' can be an architecture index or an architecture string.
    When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
    When hp=200, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/200E.config'
    The difference between these configurations is the number of training epochs.
    """
    if self.verbose:
      print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
    return self._query_info_str_by_arch(arch, hp, print_information)

  # obtain the metric for the `index`-th architecture
  # `dataset` indicates the dataset:
  #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set
  #   'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set
  #   'cifar100'       : using the proposed train set of CIFAR-100 as the training set
  #   'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
  # `iepoch` indicates the index of training epochs from 0 to 11/199.
  #   When iepoch=None, it will return the metric for the last training epoch
  #   When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
  # `hp` indicates the hyper-parameters used for training:
  #   When hp='12', the network is trained for 12 epochs, with the LR decayed from 0.1 to 0 within 12 epochs
  #   When hp='200', the network is trained for 200 epochs, with the LR decayed from 0.1 to 0 within 200 epochs
  # `is_random`
  #   When is_random=True, the metrics of one randomly selected trial will be returned
  #   When is_random=False, the metrics of all trials will be averaged.
  def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
    if self.verbose:
      print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
    index = self.query_index_by_arch(index)  # in case the input is an architecture string or an arch instance rather than an index
    if index not in self.arch2infos_dict:
      raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
    archresult = self.arch2infos_dict[index][str(hp)]
    # if we randomly select one trial, choose its seed first
    if isinstance(is_random, bool) and is_random:
      seeds = archresult.get_dataset_seeds(dataset)
      is_random = random.choice(seeds)
    # collect the training information
    train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
    total = train_info['iepoch'] + 1
    xinfo = {'train-loss'    : train_info['loss'],
             'train-accuracy': train_info['accuracy'],
             'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None,
             'train-all-time': train_info['all_time']}
    # collect the evaluation information
    if dataset == 'cifar10-valid':
      valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      try:
        test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
      except:
        test_info = None
      valtest_info = None
    else:
      try: # collect results on the proposed test set
        if dataset == 'cifar10':
          test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
      except:
        test_info = None
      try: # collect results on the proposed validation set
        valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      except:
        valid_info = None
      try:
        if dataset != 'cifar10':
          valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
        else:
          valtest_info = None
      except:
        valtest_info = None
    if valid_info is not None:
      xinfo['valid-loss'] = valid_info['loss']
      xinfo['valid-accuracy'] = valid_info['accuracy']
      xinfo['valid-per-time'] = valid_info['all_time'] / total if valid_info['all_time'] is not None else None
      xinfo['valid-all-time'] = valid_info['all_time']
    if test_info is not None:
      xinfo['test-loss'] = test_info['loss']
      xinfo['test-accuracy'] = test_info['accuracy']
      xinfo['test-per-time'] = test_info['all_time'] / total if test_info['all_time'] is not None else None
      xinfo['test-all-time'] = test_info['all_time']
    if valtest_info is not None:
      xinfo['valtest-loss'] = valtest_info['loss']
      xinfo['valtest-accuracy'] = valtest_info['accuracy']
      xinfo['valtest-per-time'] = valtest_info['all_time'] / total if valtest_info['all_time'] is not None else None
      xinfo['valtest-all-time'] = valtest_info['all_time']
    return xinfo
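A hedged usage sketch for the method above, reusing the `api` object from the earlier example (the index is arbitrary and the keys follow the dictionary built in the code; no specific numbers are claimed):

```python
# Illustrative only; not part of the committed file.
info = api.get_more_info(0, 'cifar10', iepoch=None, hp='200', is_random=False)
print(info['train-accuracy'], info.get('test-accuracy'))  # averaged over all trials
```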
  def show(self, index: int = -1) -> None:
    """This function will print the information of a specific (or all) architecture(s)."""
    self._show(index, print_information)

  @staticmethod
  def str2lists(arch_str: Text) -> List[tuple]:
    """
    This function shows how to read the string-based architecture encoding.
      It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`

    :param
      arch_str: a string that indicates the architecture topology, such as
                    |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
    :return: a list of tuples; each tuple contains the (op, input_node_index) pairs of one node.

    :usage
      arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
      print('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
      for i, node in enumerate(arch):
        print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
    """
    node_strs = arch_str.split('+')
    genotypes = []
    for i, node_str in enumerate(node_strs):
      inputs = list(filter(lambda x: x != '', node_str.split('|')))
      for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      inputs = ( xi.split('~') for xi in inputs )
      input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
      genotypes.append( input_infos )
    return genotypes

  @staticmethod
  def str2matrix(arch_str: Text,
                 search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
    """
    This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.

    :param
      arch_str: a string that indicates the architecture topology, such as
                    |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
      search_space: a list of operation strings; the default list is the search space for NAS-Bench-201
        the default value should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
    :return
      the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
    :usage
      matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
      This matrix is a 4-by-4 matrix representing a cell with 4 nodes (only the lower-left triangle is useful).
         [ [0, 0, 0, 0],  # the first row represents the input (0-th) node
           [2, 0, 0, 0],  # the second row represents the 1-st node, which is calculated by 2-th-op( 0-th-node )
           [0, 0, 0, 0],  # the third row represents the 2-nd node, which is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
           [0, 0, 1, 0] ] # the fourth row represents the 3-rd node, which is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
      In the NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect',
         2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
    :(NOTE)
      If a node has two input-edges from the same node, this function does not work correctly, because one edge will overwrite the other.
    """
    node_strs = arch_str.split('+')
    num_nodes = len(node_strs) + 1
    matrix = np.zeros((num_nodes, num_nodes))
    for i, node_str in enumerate(node_strs):
      inputs = list(filter(lambda x: x != '', node_str.split('|')))
      for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      for xi in inputs:
        op, idx = xi.split('~')
        if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
        op_idx, node_idx = search_space.index(op), int(idx)
        matrix[i+1, node_idx] = op_idx
    return matrix
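Since both helpers above are static methods, they can be exercised without loading any benchmark file; a small sketch (illustrative only, and assuming the same import path as before; the architecture string is the one used in the docstrings):

```python
# Illustrative only; not part of the committed file.
from nas_201_api import NASBench201API

xstr = '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|'
print(NASBench201API.str2lists(xstr))   # per-node tuples of (op, input_node_index)
print(NASBench201API.str2matrix(xstr))  # 4x4 lower-triangular operation matrix
```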
750
nas_201_api/api_utils.py
Normal file
@@ -0,0 +1,750 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
############################################################################################
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
############################################################################################
# In this Python file, we define NASBenchMetaAPI, the abstract class for benchmark APIs.
# We also define the class ArchResults, which contains all information of a single architecture trained by one kind of hyper-parameters on three datasets.
# We also define the class ResultsCount, which contains all information of a single trial for a single architecture.
############################################################################################
#
import os, abc, copy, random, torch, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict


def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
  """re-map the metric_on_set to internal keys"""
  if verbose:
    print('Call internal function _remap_dataset_set_names with dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
  if dataset == 'cifar10' and metric_on_set == 'valid':
    dataset, metric_on_set = 'cifar10-valid', 'x-valid'
  elif dataset == 'cifar10' and metric_on_set == 'test':
    dataset, metric_on_set = 'cifar10', 'ori-test'
  elif dataset == 'cifar10' and metric_on_set == 'train':
    dataset, metric_on_set = 'cifar10', 'train'
  elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'valid':
    metric_on_set = 'x-valid'
  elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'test':
    metric_on_set = 'x-test'
  if verbose:
    print('  return dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
  return dataset, metric_on_set
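A quick sanity check of the remapping above; the expected values follow directly from the branches in the function (illustrative only, assuming the module path `nas_201_api.api_utils`):

```python
# Illustrative only; not part of the committed file.
from nas_201_api.api_utils import remap_dataset_set_names

assert remap_dataset_set_names('cifar10', 'valid') == ('cifar10-valid', 'x-valid')
assert remap_dataset_set_names('cifar100', 'test') == ('cifar100', 'x-test')
```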
class NASBenchMetaAPI(metaclass=abc.ABCMeta):

  @abc.abstractmethod
  def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
    """The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""

  def __getitem__(self, index: int):
    return copy.deepcopy(self.meta_archs[index])

  def arch(self, index: int):
    """Return the topology structure of the `index`-th architecture."""
    if self.verbose:
      print('Call the arch function with index={:}'.format(index))
    assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
    return copy.deepcopy(self.meta_archs[index])

  def __len__(self):
    return len(self.meta_archs)

  def __repr__(self):
    return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))

  @property
  def avaliable_hps(self):
    return list(copy.deepcopy(self._avaliable_hps))

  @property
  def used_time(self):
    return self._used_time

  def reset_time(self):
    self._used_time = 0

  def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True):
    index = self.query_index_by_arch(arch)
    all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
    assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
    if dataset == 'cifar10':
      info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
    else:
      info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True)
    valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
    latency = self.get_latency(index, dataset)
    if account_time:
      self._used_time += time_cost
    return valid_acc, latency, time_cost, self._used_time
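A minimal sketch of how `simulate_train_eval` is typically used to charge simulated search time, e.g. inside a random-search loop (illustrative only; `api` is assumed to be an instantiated `NASBench201API`, and the budget value is arbitrary):

```python
# Illustrative only; not part of the committed file.
import random

api.reset_time()
best_acc, best_index = -1.0, -1
while api.used_time < 12000:  # simulated budget in seconds (arbitrary)
  index = random.randint(0, len(api) - 1)
  acc, latency, cost, total_time = api.simulate_train_eval(index, 'cifar10', hp='12')
  if acc > best_acc:
    best_acc, best_index = acc, index
print(best_index, best_acc, api.used_time)
```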

  def random(self):
    """Return a random index of all architectures."""
    return random.randint(0, len(self.meta_archs)-1)

  def query_index_by_arch(self, arch):
    """ This function is used to query the index of an architecture in the search space.
      In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
        or an instance with a 'tostr' method that generates the architecture string;
        or directly an architecture index, in which case we check whether it is valid.
      This function will return the index.
        If it returns -1, this architecture is not in the search space.
        Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
    """
    if self.verbose:
      print('Call query_index_by_arch with arch={:}'.format(arch))
    if isinstance(arch, int):
      if 0 <= arch < len(self):
        return arch
      else:
        raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format(arch, 0, len(self)))
    elif isinstance(arch, str):
      if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
      else: arch_index = -1
    elif hasattr(arch, 'tostr'):
      if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
      else: arch_index = -1
    else: arch_index = -1
    return arch_index

  def query_by_arch(self, arch, hp):
    # This is to keep the current version compatible with the old version.
    return self.query_info_str_by_arch(arch, hp)

  @abc.abstractmethod
  def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
    If index is None, overwrite all checkpoints.
    """

  def clear_params(self, index: int, hp: Optional[Text]=None):
    """Remove the architecture's weights to save memory.
    :arg
      index: the index of the target architecture
      hp: a flag to control how to clear the parameters.
      -- None: clear all the weights in '01'/'12'/'90', which indicates the number of training epochs.
      -- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
    """
    if self.verbose:
      print('Call clear_params with index={:} and hp={:}'.format(index, hp))
    if hp is None:
      for key, result in self.arch2infos_dict[index].items():
        result.clear_params()
    else:
      if str(hp) not in self.arch2infos_dict[index]:
        raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[index].keys()), hp))
      self.arch2infos_dict[index][str(hp)].clear_params()

  @abc.abstractmethod
  def query_info_str_by_arch(self, arch, hp: Text='12'):
    """This function is used to query the information of a specific architecture."""

  def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None):
    arch_index = self.query_index_by_arch(arch)
    if arch_index in self.arch2infos_dict:
      if hp not in self.arch2infos_dict[arch_index]:
        raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
      info = self.arch2infos_dict[arch_index][hp]
      strings = print_information(info, 'arch-index={:}'.format(arch_index))
      return '\n'.join(strings)
    else:
      print('Found arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
      return None

  def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
    """Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index."""
    if self.verbose:
      print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
    if arch_index in self.arch2infos_dict:
      if hp not in self.arch2infos_dict[arch_index]:
        raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
      info = self.arch2infos_dict[arch_index][hp]
    else:
      raise ValueError('arch_index [{:}] is not in arch2infos'.format(arch_index))
    return copy.deepcopy(info)

  def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = '12'):
    """ This 'query_by_index' function is used to query information with the training of 01 epochs, 12 epochs, 90 epochs, or 200 epochs.
    ------
    If hp=01, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/01E.config)
    If hp=12, we train the model by 12 epochs (see config in configs/nas-benchmark/hyper-opts/12E.config)
    If hp=90, we train the model by 90 epochs (see config in configs/nas-benchmark/hyper-opts/90E.config)
    If hp=200, we train the model by 200 epochs (see config in configs/nas-benchmark/hyper-opts/200E.config)
    ------
    If dataname is None, return the ArchResults;
      else, return a dict with all trials on that dataset (the key is the seed).
    Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
    -- cifar10-valid : training the model on the CIFAR-10 training set.
    -- cifar10 : training the model on the CIFAR-10 training + validation set.
    -- cifar100 : training the model on the CIFAR-100 training set.
    -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
    """
    if self.verbose:
      print('Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(arch_index, dataname, hp))
    info = self.query_meta_info_by_index(arch_index, hp)
    if dataname is None: return info
    else:
      if dataname not in info.get_dataset_names():
        raise ValueError('invalid dataset-name : {:} vs. {:}'.format(dataname, info.get_dataset_names()))
      return info.query(dataname)
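A short usage sketch for `query_by_index` (illustrative only; it assumes an instantiated `api` and an evaluated index):

```python
# Illustrative only; not part of the committed file.
results = api.query_by_index(1, 'cifar100', hp='200')   # dict: seed -> ResultsCount
for seed, result in results.items():
  print(seed, result.get_train()['accuracy'], result.get_eval('x-test')['accuracy'])
```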

  def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
    """Find the architecture with the highest accuracy based on some constraints."""
    if self.verbose:
      print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
    dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
    best_index, highest_accuracy = -1, None
    for i, arch_index in enumerate(self.evaluated_indexes):
      arch_info = self.arch2infos_dict[arch_index][hp]
      info = arch_info.get_compute_costs(dataset)  # the information of costs
      flop, param, latency = info['flops'], info['params'], info['latency']
      if FLOP_max is not None and flop > FLOP_max : continue
      if Param_max is not None and param > Param_max: continue
      xinfo = arch_info.get_metrics(dataset, metric_on_set)  # the information of loss and accuracy
      loss, accuracy = xinfo['loss'], xinfo['accuracy']
      if best_index == -1:
        best_index, highest_accuracy = arch_index, accuracy
      elif highest_accuracy < accuracy:
        best_index, highest_accuracy = arch_index, accuracy
    if self.verbose:
      print('  the best architecture : [{:}] {:} with accuracy={:.3f}%'.format(best_index, self.arch(best_index), highest_accuracy))
    return best_index, highest_accuracy
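For example, the constrained lookup above can be exercised like this (illustrative only; the FLOP/parameter thresholds are arbitrary and `api` is the object from the earlier sketches):

```python
# Illustrative only; not part of the committed file.
index, acc = api.find_best('cifar10', 'test', FLOP_max=30, Param_max=1.0, hp='200')
print(index, acc, api.arch(index))
```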

  def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = '12'):
    """
    This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`.
    Args [seed]:
      -- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
      -- an integer : return the weights of a specific trial, whose seed is this integer.
    Args [hp]:
      -- 01 : train the model by 01 epochs
      -- 12 : train the model by 12 epochs
      -- 90 : train the model by 90 epochs
      -- 200 : train the model by 200 epochs
    """
    if self.verbose:
      print('Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(index, dataset, seed, hp))
    info = self.query_meta_info_by_index(index, hp)
    return info.get_net_param(dataset, seed)

  def get_net_config(self, index: int, dataset: Text):
    """
    This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
    Args [dataset] (4 possible options):
      -- cifar10-valid : training the model on the CIFAR-10 training set.
      -- cifar10 : training the model on the CIFAR-10 training + validation set.
      -- cifar100 : training the model on the CIFAR-100 training set.
      -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
    This function will return a dict.
    ========= Some examples for using this function:
    config = api.get_net_config(128, 'cifar10')
    """
    if self.verbose:
      print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
    if index in self.arch2infos_dict:
      info = self.arch2infos_dict[index]
    else:
      raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(index))
    info = next(iter(info.values()))
    results = info.query(dataset, None)
    results = next(iter(results.values()))
    return results.get_config(None)

  def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
    """To obtain the cost metric for the `index`-th architecture on a dataset."""
    if self.verbose:
      print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
    info = self.query_meta_info_by_index(index, hp)
    return info.get_compute_costs(dataset)

  def get_latency(self, index: int, dataset: Text, hp: Text = '12') -> float:
    """
    To obtain the latency of the network (by default it will return the latency with the batch size of 256).
    :param index: the index of the target architecture
    :param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
    :return: a float value in seconds
    """
    if self.verbose:
      print('Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
    cost_dict = self.get_cost_info(index, dataset, hp)
    return cost_dict['latency']

  @abc.abstractmethod
  def show(self, index=-1):
    """This function will print the information of a specific (or all) architecture(s)."""

  def _show(self, index=-1, print_information=None) -> None:
    """
    This function will print the information of a specific (or all) architecture(s).

    :param index: If the index < 0: it will loop over all architectures and print their information one by one.
                  else: it will print the information of the 'index'-th architecture.
    :return: nothing
    """
    if index < 0: # show all architectures
      print(self)
      for i, idx in enumerate(self.evaluated_indexes):
        print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
        print('arch : {:}'.format(self.meta_archs[idx]))
        for key, result in self.arch2infos_dict[idx].items():
          strings = print_information(result)
          print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
          print('\n'.join(strings))
        print('<' * 40 + '------------' + '<' * 40)
    else:
      if 0 <= index < len(self.meta_archs):
        if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
        else:
          arch_info = self.arch2infos_dict[index]
          for key, result in self.arch2infos_dict[index].items():
            strings = print_information(result)
            print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
            print('\n'.join(strings))
          print('<' * 40 + '------------' + '<' * 40)
      else:
        print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))

  def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]:
    """This function counts how many architectures have 0, 1, 2, ... trials on `dataset`."""
    if self.verbose:
      print('Call the statistics function with dataset={:} and hp={:}.'.format(dataset, hp))
    valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
    if dataset not in valid_datasets:
      raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
    nums, hp = defaultdict(lambda: 0), str(hp)
    for index in range(len(self)):
      archInfo = self.arch2infos_dict[index][hp]
      dataset_seed = archInfo.dataset_seed
      if dataset not in dataset_seed:
        nums[0] += 1
      else:
        nums[len(dataset_seed[dataset])] += 1
    return dict(nums)


class ArchResults(object):

  def __init__(self, arch_index, arch_str):
    self.arch_index = int(arch_index)
    self.arch_str = copy.deepcopy(arch_str)
    self.all_results = dict()
    self.dataset_seed = dict()
    self.clear_net_done = False

  def get_compute_costs(self, dataset):
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]

    flops = [result.flop for result in results]
    params = [result.params for result in results]
    latencies = [result.get_latency() for result in results]
    latencies = [x for x in latencies if x > 0]
    mean_latency = np.mean(latencies) if len(latencies) > 0 else None
    time_infos = defaultdict(list)
    for result in results:
      time_info = result.get_times()
      for key, value in time_info.items(): time_infos[key].append( value )

    info = {'flops'  : np.mean(flops),
            'params' : np.mean(params),
            'latency': mean_latency}
    for key, value in time_infos.items():
      if len(value) > 0 and value[0] is not None:
        info[key] = np.mean(value)
      else: info[key] = None
    return info

  def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
    """
    This `get_metrics` function is used to obtain the loss, accuracy, and other information on a specific dataset.
    If not specified, each set refers to the proposed split in the NAS-Bench-201 paper.
    If some args return None or raise an error, the corresponding information is not available.
    ========================================
    Args [dataset] (4 possible options):
      -- cifar10-valid : training the model on the CIFAR-10 training set.
      -- cifar10 : training the model on the CIFAR-10 training + validation set.
      -- cifar100 : training the model on the CIFAR-100 training set.
      -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
    Args [setname] (each dataset has different setnames):
      -- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
      ------ 'train' : the metric on the training set.
      ------ 'x-valid' : the metric on the validation set.
      ------ 'ori-test' : the metric on the test set.
      -- When dataset = cifar10, you can use 'train', 'ori-test'.
      ------ 'train' : the metric on the training + validation set.
      ------ 'ori-test' : the metric on the test set.
      -- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
      ------ 'train' : the metric on the training set.
      ------ 'x-valid' : the metric on the validation set.
      ------ 'x-test' : the metric on the test set.
      ------ 'ori-test' : the metric on the validation + test set.
    Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)):
      ------ None : return the metric after the last training epoch.
      ------ an integer i : return the metric after the i-th training epoch.
    Args [is_random]:
      ------ True : return the metric of a randomly selected trial.
      ------ False : return the averaged metric of all available trials.
      ------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
    """
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
    infos = defaultdict(list)
    for result in results:
      if setname == 'train':
        info = result.get_train(iepoch)
      else:
        info = result.get_eval(setname, iepoch)
      for key, value in info.items(): infos[key].append( value )
    return_info = dict()
    if isinstance(is_random, bool) and is_random: # randomly select one
      index = random.randint(0, len(results)-1)
      for key, value in infos.items(): return_info[key] = value[index]
    elif isinstance(is_random, bool) and not is_random: # average
      for key, value in infos.items():
        if len(value) > 0 and value[0] is not None:
          return_info[key] = np.mean(value)
        else: return_info[key] = None
    elif isinstance(is_random, int): # specify the seed
      if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
      index = x_seeds.index(is_random)
      for key, value in infos.items(): return_info[key] = value[index]
    else:
      raise ValueError('invalid value for is_random: {:}'.format(is_random))
    return return_info
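A brief sketch of calling `get_metrics` on an `ArchResults` object obtained through the API above (illustrative only; it assumes the `api` object from the earlier examples and an evaluated index):

```python
# Illustrative only; not part of the committed file.
arch_results = api.query_meta_info_by_index(0, hp='200')   # an ArchResults instance
metrics = arch_results.get_metrics('cifar100', 'x-valid', iepoch=None, is_random=False)
print(metrics['loss'], metrics['accuracy'])                # averaged over all available trials
```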

  def show(self, is_print=False):
    return print_information(self, None, is_print)

  def get_dataset_names(self):
    return list(self.dataset_seed.keys())

  def get_dataset_seeds(self, dataset):
    return copy.deepcopy( self.dataset_seed[dataset] )

  def get_net_param(self, dataset: Text, seed: Union[None, int] = None):
    """
    This function will return the trained network's weights on the 'dataset'.
    :arg
      dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
      seed: an integer that indicates the seed value, or None to return the weights of all trials.
    """
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
      return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
    else:
      xkey = (dataset, seed)
      if xkey in self.all_results:
        return self.all_results[xkey].get_net_param()
      else:
        raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys())))

  def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
    """This function is used to reset the latency in all corresponding ResultsCount(s)."""
    if seed is None:
      for seed in self.dataset_seed[dataset]:
        self.all_results[(dataset, seed)].update_latency([latency])
    else:
      self.all_results[(dataset, seed)].update_latency([latency])

  def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
    """This function is used to reset the train-times in all corresponding ResultsCount(s)."""
    if seed is None:
      for seed in self.dataset_seed[dataset]:
        self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
    else:
      self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)

  def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
    """This function is used to reset the eval-times in all corresponding ResultsCount(s)."""
    if seed is None:
      for seed in self.dataset_seed[dataset]:
        self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
    else:
      self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)

  def get_latency(self, dataset: Text) -> float:
    """Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]"""
    latencies = []
    for seed in self.dataset_seed[dataset]:
      latency = self.all_results[(dataset, seed)].get_latency()
      if not isinstance(latency, float) or latency <= 0:
        raise ValueError('invalid latency of {:} with seed={:} : {:}'.format(dataset, seed, latency))
      latencies.append(latency)
    return sum(latencies) / len(latencies)

  def get_total_epoch(self, dataset=None):
    """Return the total number of training epochs."""
    if dataset is None:
      epochss = []
      for xdata, x_seeds in self.dataset_seed.items():
        epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
    elif isinstance(dataset, str):
      x_seeds = self.dataset_seed[dataset]
      epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
    else:
      raise ValueError('invalid dataset={:}'.format(dataset))
    if len(set(epochss)) > 1: raise ValueError('Each trial must have the same number of training epochs : {:}'.format(epochss))
    return epochss[-1]

  def query(self, dataset, seed=None):
    """Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'."""
    if seed is None:
      #print(self.dataset_seed.keys())
      #print(dataset)
      x_seeds = self.dataset_seed[dataset]
      return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
    else:
      return self.all_results[(dataset, seed)]

  def arch_idx_str(self):
    return '{:06d}'.format(self.arch_index)
|
||||
def update(self, dataset_name, seed, result):
|
||||
if dataset_name not in self.dataset_seed:
|
||||
self.dataset_seed[dataset_name] = []
|
||||
assert seed not in self.dataset_seed[dataset_name], '{:}-th arch alreadly has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
|
||||
self.dataset_seed[ dataset_name ].append( seed )
|
||||
self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
|
||||
assert (dataset_name, seed) not in self.all_results
|
||||
self.all_results[ (dataset_name, seed) ] = result
|
||||
self.clear_net_done = False
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = dict()
|
||||
for key, value in self.__dict__.items():
|
||||
if key == 'all_results': # contain the class of ResultsCount
|
||||
xvalue = dict()
|
||||
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
|
||||
for _k, _v in value.items():
|
||||
assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
|
||||
xvalue[_k] = _v.state_dict()
|
||||
else:
|
||||
xvalue = value
|
||||
state_dict[key] = xvalue
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
new_state_dict = dict()
|
||||
for key, value in state_dict.items():
|
||||
if key == 'all_results': # to convert to the class of ResultsCount
|
||||
xvalue = dict()
|
||||
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
|
||||
for _k, _v in value.items():
|
||||
xvalue[_k] = ResultsCount.create_from_state_dict(_v)
|
||||
else: xvalue = value
|
||||
new_state_dict[key] = xvalue
|
||||
self.__dict__.update(new_state_dict)
|
||||
|
||||
@staticmethod
|
||||
def create_from_state_dict(state_dict_or_file):
|
||||
x = ArchResults(-1, -1)
|
||||
if isinstance(state_dict_or_file, str): # a file path
|
||||
state_dict = torch.load(state_dict_or_file, map_location='cpu')
|
||||
elif isinstance(state_dict_or_file, dict):
|
||||
state_dict = state_dict_or_file
|
||||
else:
|
||||
raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
|
||||
x.load_state_dict(state_dict)
|
||||
return x
|
||||
|
||||
# This function is used to clear the weights saved in each 'result'
|
||||
# This can help reduce the memory footprint.
|
||||
def clear_params(self):
|
||||
for key, result in self.all_results.items():
|
||||
del result.net_state_dict
|
||||
result.net_state_dict = None
|
||||
self.clear_net_done = True
|
||||
|
||||
def debug_test(self):
|
||||
"""This function is used for me to debug and test, which will call most methods."""
|
||||
all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
|
||||
for dataset in all_dataset:
|
||||
print('---->>>> {:}'.format(dataset))
|
||||
print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset)))
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
result = self.all_results[(dataset, seed)]
|
||||
print(' ==>> result = {:}'.format(result))
|
||||
print(' ==>> cost = {:}'.format(result.get_times()))
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
|
||||
|
||||
|
||||
"""
|
||||
This class (ResultsCount) is used to save the information of one trial for a single architecture.
|
||||
I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called.
|
||||
If you have any question regarding this class, please open an issue or email me.
|
||||
"""
|
||||
class ResultsCount(object):
|
||||
|
||||
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
|
||||
self.name = name
|
||||
self.net_state_dict = state_dict
|
||||
self.train_acc1es = copy.deepcopy(train_accs)
|
||||
self.train_acc5es = None
|
||||
self.train_losses = copy.deepcopy(train_losses)
|
||||
self.train_times = None
|
||||
self.arch_config = copy.deepcopy(arch_config)
|
||||
self.params = params
|
||||
self.flop = flop
|
||||
self.seed = seed
|
||||
self.epochs = epochs
|
||||
self.latency = latency
|
||||
# evaluation results
|
||||
self.reset_eval()
|
||||
|
||||
def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
|
||||
self.train_acc1es = train_acc1es
|
||||
self.train_acc5es = train_acc5es
|
||||
self.train_losses = train_losses
|
||||
self.train_times = train_times
|
||||
|
||||
def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
|
||||
"""Assign the training times."""
|
||||
train_times = OrderedDict()
|
||||
for i in range(self.epochs):
|
||||
train_times[i] = estimated_per_epoch_time
|
||||
self.train_times = train_times
|
||||
|
||||
def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
|
||||
"""Assign the evaluation times."""
|
||||
if eval_name not in self.eval_names: raise ValueError('invalid eval name : {:}'.format(eval_name))
|
||||
for i in range(self.epochs):
|
||||
self.eval_times['{:}@{:}'.format(eval_name,i)] = estimated_per_epoch_time
|
||||
|
||||
def reset_eval(self):
|
||||
self.eval_names = []
|
||||
self.eval_acc1es = {}
|
||||
self.eval_times = {}
|
||||
self.eval_losses = {}
|
||||
|
||||
def update_latency(self, latency):
|
||||
self.latency = copy.deepcopy( latency )
|
||||
|
||||
def get_latency(self) -> float:
|
||||
"""Return the latency value in seconds. -1 represents not avaliable ; otherwise it should be a float value"""
|
||||
if self.latency is None: return -1.0
|
||||
else: return sum(self.latency) / len(self.latency)
|
||||
|
||||
def update_eval(self, accs, losses, times): # new version
|
||||
data_names = set([x.split('@')[0] for x in accs.keys()])
|
||||
for data_name in data_names:
|
||||
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
|
||||
self.eval_names.append( data_name )
|
||||
for iepoch in range(self.epochs):
|
||||
xkey = '{:}@{:}'.format(data_name, iepoch)
|
||||
self.eval_acc1es[ xkey ] = accs[ xkey ]
|
||||
self.eval_losses[ xkey ] = losses[ xkey ]
|
||||
self.eval_times [ xkey ] = times[ xkey ]
|
||||
|
||||
def update_OLD_eval(self, name, accs, losses): # old version
|
||||
assert name not in self.eval_names, '{:} has already been added into eval-names'.format(name)
|
||||
self.eval_names.append( name )
|
||||
for iepoch in range(self.epochs):
|
||||
if iepoch in accs:
|
||||
self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch]
|
||||
self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch]
|
||||
|
||||
def __repr__(self):
|
||||
num_eval = len(self.eval_names)
|
||||
set_name = '[' + ', '.join(self.eval_names) + ']'
|
||||
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
|
||||
|
||||
def get_total_epoch(self):
|
||||
return copy.deepcopy(self.epochs)
|
||||
|
||||
def get_times(self):
|
||||
"""Obtain the information regarding both training and evaluation time."""
|
||||
if self.train_times is not None and isinstance(self.train_times, dict):
|
||||
train_times = list( self.train_times.values() )
|
||||
time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
|
||||
else:
|
||||
time_info = {'T-train@epoch': None, 'T-train@total': None }
|
||||
for name in self.eval_names:
|
||||
try:
|
||||
xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
|
||||
time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
|
||||
time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
|
||||
except:
|
||||
time_info['T-{:}@epoch'.format(name)] = None
|
||||
time_info['T-{:}@total'.format(name)] = None
|
||||
return time_info
|
||||
|
||||
def get_eval_set(self):
|
||||
return self.eval_names
|
||||
|
||||
# get the training information
|
||||
def get_train(self, iepoch=None):
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:}, expected 0 <= iepoch < {:}'.format(iepoch, self.epochs)
|
||||
if self.train_times is not None:
|
||||
xtime = self.train_times[iepoch]
|
||||
atime = sum([self.train_times[i] for i in range(iepoch+1)])
|
||||
else: xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.train_losses[iepoch],
|
||||
'accuracy': self.train_acc1es[iepoch],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
def get_eval(self, name, iepoch=None):
|
||||
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:}, expected 0 <= iepoch < {:}'.format(iepoch, self.epochs)
|
||||
def _internal_query(xname):
|
||||
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
|
||||
xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)]
|
||||
atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)])
|
||||
else:
|
||||
xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.eval_losses['{:}@{:}'.format(xname, iepoch)],
|
||||
'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
if name == 'valid':
|
||||
return _internal_query('x-valid')
|
||||
else:
|
||||
return _internal_query(name)
|
||||
|
||||
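# Example (added sketch): evaluation entries are keyed as '<set-name>@<epoch>',
# e.g. 'x-valid@11' or 'ori-test@199'; get_eval('valid') is a convenience alias
# for the 'x-valid' set. The epoch indices and the `result` variable below are
# illustrative only.
#
#   last_valid = result.get_eval('valid')              # loss / accuracy / timings of the last epoch
#   first_test = result.get_eval('ori-test', iepoch=0)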
def get_net_param(self, clone=False):
|
||||
if clone: return copy.deepcopy(self.net_state_dict)
|
||||
else: return self.net_state_dict
|
||||
|
||||
def get_config(self, str2structure):
|
||||
"""This function is used to obtain the config dict for this architecture."""
|
||||
if str2structure is None:
|
||||
# In this case, this is to handle the size search space.
|
||||
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
|
||||
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
|
||||
'genotype': self.arch_config['genotype'], 'num_classes': self.arch_config['class_num']}
|
||||
# In this case, this is NAS-Bench-201
|
||||
else:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
|
||||
else:
|
||||
# In this case, this is to handle the size search space.
|
||||
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
|
||||
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
|
||||
'genotype': str2structure(self.arch_config['genotype']), 'num_classes': self.arch_config['class_num']}
|
||||
# In this case, this is NAS-Bench-201
|
||||
else:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
|
||||
|
||||
def state_dict(self):
|
||||
_state_dict = {key: value for key, value in self.__dict__.items()}
|
||||
return _state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
@staticmethod
|
||||
def create_from_state_dict(state_dict):
|
||||
x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
|
||||
x.load_state_dict(state_dict)
|
||||
return x
|
||||
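# Example (added sketch): ResultsCount round-trips through a plain dict, which is
# how the benchmark file stores it; the `result` variable below is illustrative.
#
#   saved = result.state_dict()                           # plain dict of attributes
#   restored = ResultsCount.create_from_state_dict(saved)
#   print(restored.get_latency(), restored.get_total_epoch())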
360
nasspace.py
Normal file
360
nasspace.py
Normal file
@@ -0,0 +1,360 @@
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nas_201_api import NASBench201API as API
|
||||
from nasbench import api as nasbench101api
|
||||
from nas_101_api.model import Network
|
||||
from nas_101_api.model_spec import ModelSpec
|
||||
import itertools
|
||||
import random
|
||||
import numpy as np
|
||||
from models.cell_searchs.genotypes import Structure
|
||||
from copy import deepcopy
|
||||
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
|
||||
from pycls.models.anynet import AnyNet
|
||||
from pycls.models.nas.genotypes import GENOTYPES, Genotype
|
||||
import json
|
||||
import torch
|
||||
|
||||
|
||||
class Nasbench201:
|
||||
def __init__(self, dataset, apiloc):
|
||||
self.dataset = dataset
|
||||
self.api = API(apiloc, verbose=False)
|
||||
self.epochs = '12'
|
||||
def get_network(self, uid):
|
||||
#config = self.api.get_net_config(uid, self.dataset)
|
||||
config = self.api.get_net_config(uid, 'cifar10-valid')
|
||||
config['num_classes'] = 1
|
||||
network = get_cell_based_tiny_net(config)
|
||||
return network
|
||||
def __iter__(self):
|
||||
for uid in range(len(self)):
|
||||
network = self.get_network(uid)
|
||||
yield uid, network
|
||||
def __getitem__(self, index):
|
||||
return index
|
||||
def __len__(self):
|
||||
return 15625
|
||||
def num_activations(self):
|
||||
network = self.get_network(0)
|
||||
return network.classifier.in_features
|
||||
#def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
|
||||
# archinfo = self.api.query_meta_info_by_index(uid)
|
||||
# if (self.dataset == 'cifar10' or traincifar10) and trainval:
|
||||
# #return archinfo.get_metrics('cifar10-valid', acc_type, iepoch=12)['accuracy']
|
||||
# return archinfo.get_metrics('cifar10-valid', 'x-valid', iepoch=12)['accuracy']
|
||||
# elif traincifar10:
|
||||
# return archinfo.get_metrics('cifar10', acc_type, iepoch=12)['accuracy']
|
||||
# else:
|
||||
# return archinfo.get_metrics(self.dataset, 'ori-test', iepoch=12)['accuracy']
|
||||
def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
|
||||
#archinfo = self.api.query_meta_info_by_index(uid)
|
||||
#if (self.dataset == 'cifar10' and trainval) or traincifar10:
|
||||
info = self.api.get_more_info(uid, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
|
||||
#else:
|
||||
# info = self.api.get_more_info(uid, self.dataset, iepoch=None, hp=self.epochs, is_random=True)
|
||||
return info['valid-accuracy']
|
||||
def get_final_accuracy(self, uid, acc_type, trainval):
|
||||
#archinfo = self.api.query_meta_info_by_index(uid)
|
||||
if self.dataset == 'cifar10' and trainval:
|
||||
info = self.api.query_meta_info_by_index(uid, hp='200').get_metrics('cifar10-valid', 'x-valid')
|
||||
#info = self.api.query_by_index(uid, 'cifar10-valid', hp='200')
|
||||
#info = self.api.get_more_info(uid, 'cifar10-valid', iepoch=None, hp='200', is_random=True)
|
||||
else:
|
||||
info = self.api.query_meta_info_by_index(uid, hp='200').get_metrics(self.dataset, acc_type)
|
||||
#info = self.api.query_by_index(uid, self.dataset, hp='200')
|
||||
#info = self.api.get_more_info(uid, self.dataset, iepoch=None, hp='200', is_random=True)
|
||||
return info['accuracy']
|
||||
#return info['valid-accuracy']
|
||||
#if self.dataset == 'cifar10' and trainval:
|
||||
# return archinfo.get_metrics('cifar10-valid', acc_type, iepoch=11)['accuracy']
|
||||
#else:
|
||||
# #return archinfo.get_metrics(self.dataset, 'ori-test', iepoch=12)['accuracy']
|
||||
# return archinfo.get_metrics(self.dataset, 'x-test', iepoch=11)['accuracy']
|
||||
##dataset = self.dataset
|
||||
##if self.dataset == 'cifar10' and trainval:
|
||||
## dataset = 'cifar10-valid'
|
||||
##archinfo = self.api.get_more_info(uid, dataset, iepoch=None, use_12epochs_result=True, is_random=True)
|
||||
##return archinfo['valid-accuracy']
|
||||
|
||||
def get_accuracy(self, uid, acc_type, trainval=True):
|
||||
archinfo = self.api.query_meta_info_by_index(uid)
|
||||
if self.dataset == 'cifar10' and trainval:
|
||||
return archinfo.get_metrics('cifar10-valid', acc_type)['accuracy']
|
||||
else:
|
||||
return archinfo.get_metrics(self.dataset, acc_type)['accuracy']
|
||||
|
||||
def get_accuracy_for_all_datasets(self, uid):
|
||||
archinfo = self.api.query_meta_info_by_index(uid,hp='200')
|
||||
|
||||
c10 = archinfo.get_metrics('cifar10', 'ori-test')['accuracy']
|
||||
c10_val = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy']
|
||||
|
||||
c100 = archinfo.get_metrics('cifar100', 'x-test')['accuracy']
|
||||
c100_val = archinfo.get_metrics('cifar100', 'x-valid')['accuracy']
|
||||
|
||||
imagenet = archinfo.get_metrics('ImageNet16-120', 'x-test')['accuracy']
|
||||
imagenet_val = archinfo.get_metrics('ImageNet16-120', 'x-valid')['accuracy']
|
||||
|
||||
return c10, c10_val, c100, c100_val, imagenet, imagenet_val
|
||||
|
||||
#def train_and_eval(self, arch, dataname, acc_type, trainval=True):
|
||||
# unique_hash = self.__getitem__(arch)
|
||||
# time = self.get_training_time(unique_hash)
|
||||
# acc12 = self.get_12epoch_accuracy(unique_hash, acc_type, trainval)
|
||||
# acc = self.get_final_accuracy(unique_hash, acc_type, trainval)
|
||||
# return acc12, acc, time
|
||||
def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
|
||||
unique_hash = self.__getitem__(arch)
|
||||
time = self.get_training_time(unique_hash)
|
||||
acc12 = self.get_12epoch_accuracy(unique_hash, acc_type, trainval, traincifar10)
|
||||
acc = self.get_final_accuracy(unique_hash, acc_type, trainval)
|
||||
return acc12, acc, time
|
||||
def random_arch(self):
|
||||
return random.randint(0, len(self)-1)
|
||||
def get_training_time(self, unique_hash):
|
||||
#info = self.api.get_more_info(unique_hash, 'cifar10-valid' if self.dataset == 'cifar10' else self.dataset, iepoch=None, use_12epochs_result=True, is_random=True)
|
||||
|
||||
|
||||
#info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, use_12epochs_result=True, is_random=True)
|
||||
info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, hp='12', is_random=True)
|
||||
return info['train-all-time'] + info['valid-per-time']
|
||||
#if self.dataset == 'cifar10' and trainval:
|
||||
# info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
|
||||
#else:
|
||||
# info = self.api.get_more_info(unique_hash, self.dataset, iepoch=None, hp=self.epochs, is_random=True)
|
||||
|
||||
##info = self.api.get_more_info(unique_hash, 'cifar10-valid', iepoch=None, use_12epochs_result=True, is_random=True)
|
||||
#return info['train-all-time'] + info['valid-per-time']
|
||||
def mutate_arch(self, arch):
|
||||
op_names = get_search_spaces('cell', 'nas-bench-201')
|
||||
#config = self.api.get_net_config(arch, self.dataset)
|
||||
config = self.api.get_net_config(arch, 'cifar10-valid')
|
||||
parent_arch = Structure(self.api.str2lists(config['arch_str']))
|
||||
child_arch = deepcopy( parent_arch )
|
||||
node_id = random.randint(0, len(child_arch.nodes)-1)
|
||||
node_info = list( child_arch.nodes[node_id] )
|
||||
snode_id = random.randint(0, len(node_info)-1)
|
||||
xop = random.choice( op_names )
|
||||
while xop == node_info[snode_id][0]:
|
||||
xop = random.choice( op_names )
|
||||
node_info[snode_id] = (xop, node_info[snode_id][1])
|
||||
child_arch.nodes[node_id] = tuple( node_info )
|
||||
arch_index = self.api.query_index_by_arch( child_arch )
|
||||
return arch_index
|
||||
|
||||
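# Example (added sketch, not used elsewhere in the repository): a minimal random
# search over the wrapper above. The `acc_type` default and sample count are
# illustrative assumptions.
def _example_random_search_nb201(space, n_samples=10, acc_type='ori-test'):
    """Sample architectures uniformly and return (best_uid, best_final_accuracy)."""
    best_uid, best_acc = None, -1.
    for _ in range(n_samples):
        uid = space.random_arch()
        acc = space.get_final_accuracy(uid, acc_type, trainval=False)
        if acc > best_acc:
            best_uid, best_acc = uid, acc
    return best_uid, best_acc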
class Nasbench101:
|
||||
def __init__(self, dataset, apiloc, args):
|
||||
self.dataset = dataset
|
||||
self.api = nasbench101api.NASBench(apiloc)
|
||||
self.args = args
|
||||
def get_accuracy(self, unique_hash, acc_type, trainval=True):
|
||||
spec = self.get_spec(unique_hash)
|
||||
_, stats = self.api.get_metrics_from_spec(spec)
|
||||
maxacc = 0.
|
||||
for ep in stats:
|
||||
for statmap in stats[ep]:
|
||||
newacc = statmap['final_test_accuracy']
|
||||
if newacc > maxacc:
|
||||
maxacc = newacc
|
||||
return maxacc
|
||||
def get_final_accuracy(self, uid, acc_type, trainval):
|
||||
return self.get_accuracy(uid, acc_type, trainval)
|
||||
def get_training_time(self, unique_hash):
|
||||
spec = self.get_spec(unique_hash)
|
||||
_, stats = self.api.get_metrics_from_spec(spec)
|
||||
maxacc = -1.
|
||||
maxtime = 0.
|
||||
for ep in stats:
|
||||
for statmap in stats[ep]:
|
||||
newacc = statmap['final_test_accuracy']
|
||||
if newacc > maxacc:
|
||||
maxacc = newacc
|
||||
maxtime = statmap['final_training_time']
|
||||
return maxtime
|
||||
def get_network(self, unique_hash):
|
||||
spec = self.get_spec(unique_hash)
|
||||
network = Network(spec, self.args)
|
||||
return network
|
||||
def get_spec(self, unique_hash):
|
||||
matrix = self.api.fixed_statistics[unique_hash]['module_adjacency']
|
||||
operations = self.api.fixed_statistics[unique_hash]['module_operations']
|
||||
spec = ModelSpec(matrix, operations)
|
||||
return spec
|
||||
def __iter__(self):
|
||||
for unique_hash in self.api.hash_iterator():
|
||||
network = self.get_network(unique_hash)
|
||||
yield unique_hash, network
|
||||
def __getitem__(self, index):
|
||||
return next(itertools.islice(self.api.hash_iterator(), index, None))
|
||||
def __len__(self):
|
||||
return len(self.api.hash_iterator())
|
||||
def num_activations(self):
|
||||
for unique_hash in self.api.hash_iterator():
|
||||
network = self.get_network(unique_hash)
|
||||
return network.classifier.in_features
|
||||
def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
|
||||
unique_hash = self.__getitem__(arch)
|
||||
time = 12. * self.get_training_time(unique_hash) / 108.  # rescale the stored (108-epoch) NAS-Bench-101 training time to a 12-epoch budget
|
||||
acc = self.get_accuracy(unique_hash, acc_type, trainval)
|
||||
return acc, acc, time
|
||||
def random_arch(self):
|
||||
return random.randint(0, len(self)-1)
|
||||
def mutate_arch(self, arch):
|
||||
unique_hash = self.__getitem__(arch)
|
||||
matrix = self.api.fixed_statistics[unique_hash]['module_adjacency']
|
||||
operations = self.api.fixed_statistics[unique_hash]['module_operations']
|
||||
coords = [ (i, j) for i in range(matrix.shape[0]) for j in range(i+1, matrix.shape[1])]
|
||||
random.shuffle(coords)
|
||||
# loop through candidate changes until we find one that's allowed (its hash must exist in the benchmark)
|
||||
for i, j in coords:
|
||||
# try the ops in a particular order
|
||||
for k in [m for m in np.unique(matrix) if m != matrix[i, j]]:
|
||||
newmatrix = matrix.copy()
|
||||
newmatrix[i, j] = k
|
||||
spec = ModelSpec(newmatrix, operations)
|
||||
try:
|
||||
newhash = self.api._hash_spec(spec)
|
||||
if newhash in self.api.fixed_statistics:
|
||||
return [n for n, m in enumerate(self.api.fixed_statistics.keys()) if m == newhash][0]
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
class ReturnFeatureLayer(torch.nn.Module):
|
||||
def __init__(self, mod):
|
||||
super(ReturnFeatureLayer, self).__init__()
|
||||
self.mod = mod
|
||||
def forward(self, x):
|
||||
return self.mod(x), x
|
||||
|
||||
|
||||
def return_feature_layer(network, prefix=''):
|
||||
#for attr_str in dir(network):
|
||||
# target_attr = getattr(network, attr_str)
|
||||
# if isinstance(target_attr, torch.nn.Linear):
|
||||
# setattr(network, attr_str, ReturnFeatureLayer(target_attr))
|
||||
for n, ch in list(network.named_children()):
|
||||
if isinstance(ch, torch.nn.Linear):
|
||||
setattr(network, n, ReturnFeatureLayer(ch))
|
||||
else:
|
||||
return_feature_layer(ch, prefix + '\t')
|
||||
|
||||
|
||||
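# Example (added sketch): after `return_feature_layer(network)` every trailing
# torch.nn.Linear is wrapped so that calling it returns (logits, pre-classifier
# features). The toy network below is an illustrative assumption.
#
#   net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(32, 10))
#   return_feature_layer(net)
#   logits, feats = net[1](torch.randn(4, 32))   # the wrapped Linear now returns a pair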
class NDS:
|
||||
def __init__(self, searchspace):
|
||||
self.searchspace = searchspace
|
||||
data = json.load(open(f'nds_data/{searchspace}.json', 'r'))
|
||||
try:
|
||||
data = data['top'] + data['mid']
|
||||
except Exception as e:
|
||||
pass
|
||||
self.data = data
|
||||
def __iter__(self):
|
||||
for unique_hash in range(len(self)):
|
||||
network = self.get_network(unique_hash)
|
||||
yield unique_hash, network
|
||||
def get_network_config(self, uid):
|
||||
return self.data[uid]['net']
|
||||
def get_network_optim_config(self, uid):
|
||||
return self.data[uid]['optim']
|
||||
def get_network(self, uid):
|
||||
netinfo = self.data[uid]
|
||||
config = netinfo['net']
|
||||
#print(config)
|
||||
if 'genotype' in config:
|
||||
#print('geno')
|
||||
gen = config['genotype']
|
||||
genotype = Genotype(normal=gen['normal'], normal_concat=gen['normal_concat'], reduce=gen['reduce'], reduce_concat=gen['reduce_concat'])
|
||||
if '_in' in self.searchspace:
|
||||
network = NetworkImageNet(config['width'], 1, config['depth'], config['aux'], genotype)
|
||||
else:
|
||||
network = NetworkCIFAR(config['width'], 1, config['depth'], config['aux'], genotype)
|
||||
network.drop_path_prob = 0.
|
||||
#print(config)
|
||||
#print('genotype')
|
||||
L = config['depth']
|
||||
else:
|
||||
if 'bot_muls' in config and 'bms' not in config:
|
||||
config['bms'] = config['bot_muls']
|
||||
del config['bot_muls']
|
||||
if 'num_gs' in config and 'gws' not in config:
|
||||
config['gws'] = config['num_gs']
|
||||
del config['num_gs']
|
||||
config['nc'] = 1
|
||||
config['se_r'] = None
|
||||
config['stem_w'] = 12
|
||||
L = sum(config['ds'])
|
||||
if 'ResN' in self.searchspace:
|
||||
config['stem_type'] = 'res_stem_in'
|
||||
else:
|
||||
config['stem_type'] = 'simple_stem_in'
|
||||
#"res_stem_cifar": ResStemCifar,
|
||||
#"res_stem_in": ResStemIN,
|
||||
#"simple_stem_in": SimpleStemIN,
|
||||
if config['block_type'] == 'double_plain_block':
|
||||
config['block_type'] = 'vanilla_block'
|
||||
network = AnyNet(**config)
|
||||
return_feature_layer(network)
|
||||
return network
|
||||
def __getitem__(self, index):
|
||||
return index
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
def random_arch(self):
|
||||
return random.randint(0, len(self.data)-1)
|
||||
def get_final_accuracy(self, uid, acc_type, trainval):
|
||||
return 100. - self.data[uid]['test_ep_top1'][-1]  # NDS stores the final top-1 test error; convert it to an accuracy
|
||||
|
||||
|
||||
def get_search_space(args):
|
||||
if args.nasspace == 'nasbench201':
|
||||
return Nasbench201(args.dataset, args.api_loc)
|
||||
elif args.nasspace == 'nasbench101':
|
||||
return Nasbench101(args.dataset, args.api_loc, args)
|
||||
elif args.nasspace == 'nds_resnet':
|
||||
return NDS('ResNet')
|
||||
elif args.nasspace == 'nds_amoeba':
|
||||
return NDS('Amoeba')
|
||||
elif args.nasspace == 'nds_amoeba_in':
|
||||
return NDS('Amoeba_in')
|
||||
elif args.nasspace == 'nds_darts_in':
|
||||
return NDS('DARTS_in')
|
||||
elif args.nasspace == 'nds_darts':
|
||||
return NDS('DARTS')
|
||||
elif args.nasspace == 'nds_darts_fix-w-d':
|
||||
return NDS('DARTS_fix-w-d')
|
||||
elif args.nasspace == 'nds_darts_lr-wd':
|
||||
return NDS('DARTS_lr-wd')
|
||||
elif args.nasspace == 'nds_enas':
|
||||
return NDS('ENAS')
|
||||
elif args.nasspace == 'nds_enas_in':
|
||||
return NDS('ENAS_in')
|
||||
elif args.nasspace == 'nds_enas_fix-w-d':
|
||||
return NDS('ENAS_fix-w-d')
|
||||
elif args.nasspace == 'nds_pnas':
|
||||
return NDS('PNAS')
|
||||
elif args.nasspace == 'nds_pnas_fix-w-d':
|
||||
return NDS('PNAS_fix-w-d')
|
||||
elif args.nasspace == 'nds_pnas_in':
|
||||
return NDS('PNAS_in')
|
||||
elif args.nasspace == 'nds_nasnet':
|
||||
return NDS('NASNet')
|
||||
elif args.nasspace == 'nds_nasnet_in':
|
||||
return NDS('NASNet_in')
|
||||
elif args.nasspace == 'nds_resnext-a':
|
||||
return NDS('ResNeXt-A')
|
||||
elif args.nasspace == 'nds_resnext-a_in':
|
||||
return NDS('ResNeXt-A_in')
|
||||
elif args.nasspace == 'nds_resnext-b':
|
||||
return NDS('ResNeXt-B')
|
||||
elif args.nasspace == 'nds_resnext-b_in':
|
||||
return NDS('ResNeXt-B_in')
|
||||
elif args.nasspace == 'nds_vanilla':
|
||||
return NDS('Vanilla')
|
||||
elif args.nasspace == 'nds_vanilla_lr-wd':
|
||||
return NDS('Vanilla_lr-wd')
|
||||
elif args.nasspace == 'nds_vanilla_lr-wd_in':
|
||||
return NDS('Vanilla_lr-wd_in')
|
||||
|
||||
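# Example (added sketch): `get_search_space` only inspects `args.nasspace` plus,
# for the benchmark spaces, `args.dataset` / `args.api_loc` (and the full `args`
# for NAS-Bench-101), so a plain namespace suffices. The snippet assumes the NDS
# json files live under nds_data/.
#
#   import argparse
#   args = argparse.Namespace(nasspace='nds_darts', dataset='cifar10', api_loc='')
#   space = get_search_space(args)
#   uid = space.random_arch()
#   net = space.get_network(uid)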
@@ -1,144 +0,0 @@
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from datasets import get_datasets
|
||||
from config_utils import load_config
|
||||
|
||||
from nas_201_api import NASBench201API as API
|
||||
from models import get_cell_based_tiny_net
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_batch_jacobian(net, data_loader, device):
|
||||
data_iterator = iter(data_loader)
|
||||
x, target = next(data_iterator)
|
||||
x = x.to(device)
|
||||
net.zero_grad()
|
||||
x.requires_grad_(True)
|
||||
_, y = net(x)
|
||||
y.backward(torch.ones_like(y))
|
||||
jacob = x.grad.detach()
|
||||
return jacob, target.detach()
|
||||
|
||||
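# Note (added comment): y.backward(torch.ones_like(y)) backpropagates the sum of
# all network outputs, so x.grad holds, for each image in the batch, the gradient
# of that summed output with respect to its input pixels; plot_hist below then
# correlates these per-image gradient vectors across the batch.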
def plot_hist(jacob, ax, colour):
|
||||
xx = jacob.reshape(jacob.size(0), -1).cpu().numpy()
|
||||
corrs = np.corrcoef(xx)
|
||||
ax.hist(corrs.flatten(), bins=100, color=colour)
|
||||
|
||||
def decide_plot(acc, plt_cts, num_rows, boundaries=[60., 70., 80., 90.]):
|
||||
if acc < boundaries[0]:
|
||||
plt_col = 0
|
||||
accrange = f'< {boundaries[0]}%'
|
||||
elif acc < boundaries[1]:
|
||||
plt_col = 1
|
||||
accrange = f'[{boundaries[0]}% , {boundaries[1]}%)'
|
||||
elif acc < boundaries[2]:
|
||||
plt_col = 2
|
||||
accrange = f'[{boundaries[1]}% , {boundaries[2]}%)'
|
||||
elif acc < boundaries[3]:
|
||||
accrange = f'[{boundaries[2]}% , {boundaries[3]}%)'
|
||||
plt_col = 3
|
||||
else:
|
||||
accrange = f'>= {boundaries[3]}%'
|
||||
plt_col = 4
|
||||
|
||||
can_plot = False
|
||||
plt_row = 0
|
||||
if plt_cts[plt_col] < num_rows:
|
||||
can_plot = True
|
||||
plt_row = plt_cts[plt_col]
|
||||
plt_cts[plt_col] += 1
|
||||
|
||||
return can_plot, plt_row, plt_col, accrange
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Plot histograms of correlation matrix')
|
||||
parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder')
|
||||
parser.add_argument('--api_loc', default='NAS-Bench-201-v1_1-096897.pth',
|
||||
type=str, help='path to API')
|
||||
parser.add_argument('--arch_start', default=0, type=int)
|
||||
parser.add_argument('--arch_end', default=15625, type=int)
|
||||
parser.add_argument('--seed', default=42, type=int)
|
||||
parser.add_argument('--GPU', default='0', type=str)
|
||||
parser.add_argument('--batch_size', default=256, type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
|
||||
|
||||
# Reproducibility
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
ARCH_START = args.arch_start
|
||||
ARCH_END = args.arch_end
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
train_data, valid_data, xshape, class_num = get_datasets('cifar10', args.data_loc, 0)
|
||||
|
||||
cifar_split = load_config('config_utils/cifar-split.txt', None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
|
||||
num_workers=0, pin_memory=True, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split))
|
||||
|
||||
scores = []
|
||||
accs = []
|
||||
|
||||
plot_shape = (25, 5)
|
||||
num_plots = plot_shape[0]*plot_shape[1]
|
||||
fig, axes = plt.subplots(*plot_shape, sharex=True, figsize=(9, 9) )
|
||||
plt_cts = [0 for i in range(plot_shape[1])]
|
||||
|
||||
api = API(args.api_loc)
|
||||
|
||||
archs = list(range(ARCH_START, ARCH_END))
|
||||
colours = ['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B']
|
||||
|
||||
strs = []
|
||||
random.shuffle(archs)
|
||||
for arch in archs:
|
||||
try:
|
||||
config = api.get_net_config(arch, 'cifar10')
|
||||
archinfo = api.query_meta_info_by_index(arch)
|
||||
acc = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy']
|
||||
|
||||
network = get_cell_based_tiny_net(config)
|
||||
network = network.to(device)
|
||||
jacobs, labels = get_batch_jacobian(network, train_loader, device)
|
||||
|
||||
boundaries = [60., 70., 80., 90.]
|
||||
can_plt, row, col, accrange = decide_plot(acc, plt_cts, plot_shape[0], boundaries)
|
||||
if not can_plt:
|
||||
continue
|
||||
axes[row, col].axis('off')
|
||||
|
||||
plot_hist(jacobs, axes[row, col], colours[col])
|
||||
if row == 0:
|
||||
axes[row, col].set_title(f'{accrange}')
|
||||
|
||||
if row + 1 == plot_shape[0]:
|
||||
axes[row, col].axis('on')
|
||||
plt.setp(axes[row, col].get_xticklabels(), fontsize=12)
|
||||
axes[row, col].spines["top"].set_visible(False)
|
||||
axes[row, col].spines["right"].set_visible(False)
|
||||
axes[row, col].spines["left"].set_visible(False)
|
||||
axes[row, col].set_yticks([])
|
||||
|
||||
if sum(plt_cts) == num_plots:
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'results/histograms_cifar10val_batch{args.batch_size}.png')
|
||||
plt.show()
|
||||
break
|
||||
except Exception as e:
|
||||
plt_cts[col] -= 1
|
||||
continue
|
||||
285
plot_scores.py
Normal file
285
plot_scores.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mp
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
from decimal import Decimal
|
||||
from scipy.special import logit, expit
|
||||
from scipy import stats
|
||||
import seaborn as sns
|
||||
|
||||
'''
|
||||
font = {
|
||||
'size' : 18}
|
||||
|
||||
matplotlib.rc('font', **font)
|
||||
'''
|
||||
SMALL_SIZE = 10
|
||||
MEDIUM_SIZE = 12
|
||||
BIGGER_SIZE = 14
|
||||
|
||||
plt.rc('font', size=MEDIUM_SIZE) # controls default text sizes
|
||||
plt.rc('axes', titlesize=BIGGER_SIZE) # fontsize of the axes title
|
||||
plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
|
||||
plt.rc('xtick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
|
||||
plt.rc('ytick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
|
||||
plt.rc('legend', fontsize=MEDIUM_SIZE) # legend fontsize
|
||||
plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS Without Training')
|
||||
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
|
||||
parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
|
||||
type=str, help='path to API')
|
||||
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
|
||||
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
|
||||
parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
|
||||
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
|
||||
parser.add_argument('--batch_size', default=128, type=int)
|
||||
parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch')
|
||||
parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
|
||||
parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
|
||||
parser.add_argument('--init', default='', type=str)
|
||||
parser.add_argument('--GPU', default='0', type=str)
|
||||
parser.add_argument('--seed', default=1, type=int)
|
||||
parser.add_argument('--trainval', action='store_true')
|
||||
parser.add_argument('--dropout', action='store_true')
|
||||
parser.add_argument('--dataset', default='cifar10', type=str)
|
||||
parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network')
|
||||
parser.add_argument('--n_samples', default=100, type=int)
|
||||
parser.add_argument('--n_runs', default=500, type=int)
|
||||
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
|
||||
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
|
||||
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
|
||||
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f'{args.batch_size}')
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}{"_" + args.init + "_" if args.init != "" else args.init}_{"_dropout" if args.dropout else ""}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}.npy'
|
||||
accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{args.dataset}_{args.trainval}.npy'
|
||||
|
||||
from matplotlib.colors import hsv_to_rgb
|
||||
print(filename)
|
||||
scores = np.load(filename)
|
||||
accs = np.load(accfilename)
|
||||
|
||||
def make_colours_by_hue(h, v=1.):
|
||||
return [hsv_to_rgb((h1 if h1 < 1. else h1-1., s, v)) for h1, s,v in zip(np.linspace(h, h+0.05, 5), np.linspace(1., .6, 5), np.linspace(0.1, 1., 5))]
|
||||
print(f'NETWORK accuracy with highest score {accs[np.argmax(scores)]}')
|
||||
|
||||
make_colours = lambda cols: [mp.colors.to_rgba(c) for c in cols]
|
||||
oranges = make_colours(['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B'])
|
||||
blues = make_colours(['#190C30', '#241147', '#34208C', '#4882FA', '#81BAFC'])
|
||||
print(blues)
|
||||
print(make_colours_by_hue(0.9))
|
||||
if args.nasspace == 'nasbench101':
|
||||
#colours = blues
|
||||
colours = make_colours_by_hue(0.9)
|
||||
elif 'darts' in args.nasspace:
|
||||
#colours = sns.color_palette("BuGn_r", n_colors=5)
|
||||
colours = make_colours_by_hue(0.0)
|
||||
elif 'pnas' in args.nasspace:
|
||||
#colours = sns.color_palette("PuRd", n_colors=5)
|
||||
colours = make_colours_by_hue(0.1)
|
||||
elif args.nasspace == 'nasbench201':
|
||||
#colours = oranges
|
||||
colours = make_colours_by_hue(0.3)
|
||||
elif 'enas' in args.nasspace:
|
||||
#colours = oranges
|
||||
colours = make_colours_by_hue(0.4)
|
||||
elif 'resnet' in args.nasspace:
|
||||
#colours = sns.color_palette("viridis_r", n_colors=5)
|
||||
colours = make_colours_by_hue(0.5)
|
||||
elif 'amoeba' in args.nasspace:
|
||||
#colours = sns.color_palette("viridis_r", n_colors=5)
|
||||
colours = make_colours_by_hue(0.6)
|
||||
elif 'nasnet' in args.nasspace:
|
||||
#colours = sns.color_palette("viridis_r", n_colors=5)
|
||||
colours = make_colours_by_hue(0.7)
|
||||
elif 'resnext-b' in args.nasspace:
|
||||
#colours = sns.color_palette("viridis_r", n_colors=5)
|
||||
colours = make_colours_by_hue(0.8)
|
||||
else:
|
||||
from zlib import crc32
|
||||
|
||||
def bytes_to_float(b):
|
||||
return float(crc32(b) & 0xffffffff) / 2**32
|
||||
def str_to_float(s, encoding="utf-8"):
|
||||
return bytes_to_float(s.encode(encoding))
|
||||
#colours = sns.color_palette("Purples_r", n_colors=5)
|
||||
colours = make_colours_by_hue(str_to_float(args.nasspace))
|
||||
|
||||
def make_colordict(colours, points):
|
||||
cdict = {'red': [[pt, colour[0], colour[0]] for pt, colour in zip(points, colours)],
|
||||
'green':[[pt, colour[1], colour[1]] for pt, colour in zip(points, colours)],
|
||||
'blue':[[pt, colour[2], colour[2]] for pt, colour in zip(points, colours)]}
|
||||
return cdict
|
||||
|
||||
def make_colormap(dataset, space, colours):
|
||||
if dataset == 'cifar10' and 'resn' in space:
|
||||
points = [0., 0.85, 0.9, 0.95, 1.0, 1.0]
|
||||
colours = [colours[0]] + colours
|
||||
elif dataset == 'cifar10' and 'nds_darts' in space:
|
||||
points = [0., 0.8, 0.85, 0.9, 0.95, 1.0]
|
||||
colours = [colours[0]] + colours
|
||||
elif dataset == 'cifar10' and 'pnas' in space:
|
||||
points = [0., 0.875, 0.9, 0.925, 0.95, 1.0]
|
||||
colours = [colours[0]] + colours
|
||||
elif dataset == 'cifar10':
|
||||
points = [0., 0.6, 0.7, 0.8, 0.9, 1.0]
|
||||
colours = [colours[0]] + colours
|
||||
#cdict = {'red': [[0., colours[0][0], colours[0][0]]] + [[0.1*i + 0.6, colours[i][0], colours[i][0]] for i in range(len(colours))],
|
||||
# 'green':[[0., colours[0][1], colours[0][1]]] + [[0.1*i + 0.6, colours[i][1], colours[i][1]] for i in range(len(colours))],
|
||||
# 'blue':[[0., colours[0][2], colours[0][2]]] + [[0.1*i + 0.6, colours[i][2], colours[i][2]] for i in range(len(colours))]}
|
||||
elif dataset == 'cifar100':
|
||||
points = [0., 0.3, 0.4, 0.5, 0.6, 0.7, 1.0]
|
||||
colours = [colours[0]] + colours + [colours[-1]]
|
||||
|
||||
#cdict = {'red': [[0., colours[0][0], colours[0][0]]] + [[0.1*i + 0.3, colours[i][0], colours[i][0]] for i in range(len(colours))] + [[1., colours[-1][0], colours[-1][0]]] ,
|
||||
# 'green':[[0., colours[0][1], colours[0][1]]] + [[0.1*i + 0.3, colours[i][1], colours[i][1]] for i in range(len(colours))] + [[1., colours[-1][1], colours[-1][1]]] ,
|
||||
# 'blue':[[0., colours[0][2], colours[0][2]]] + [[0.1*i + 0.3, colours[i][2], colours[i][2]] for i in range(len(colours))] + [[1., colours[-1][2], colours[-1][2]]] }
|
||||
else:
|
||||
points = [0., 0.1, 0.2, 0.3, 0.4, 1.0]
|
||||
colours = colours + [colours[-1]]
|
||||
|
||||
#cdict = {'red': [[0.1*i, colours[i][0], colours[i][0]] for i in range(len(colours))] + [[1., colours[-1][0], colours[-1][0]]] ,
|
||||
# 'green': [[0.1*i, colours[i][1], colours[i][1]] for i in range(len(colours))] + [[1., colours[-1][1], colours[-1][1]]] ,
|
||||
# 'blue': [[0.1*i, colours[i][2], colours[i][2]] for i in range(len(colours))] + [[1., colours[-1][2], colours[-1][2]]] }
|
||||
|
||||
cdict = make_colordict(colours, points)
|
||||
return cdict
|
||||
cdict = make_colormap(args.dataset, args.nasspace, colours)
|
||||
newcmp = mp.colors.LinearSegmentedColormap('testCmap', segmentdata=cdict, N=256)
|
||||
|
||||
if args.nasspace == 'nasbench101':
|
||||
accs = accs[:10000]
|
||||
scores = scores[:10000]
|
||||
inds = accs > 0.5
|
||||
accs = accs[inds]
|
||||
scores = scores[inds]
|
||||
print(accs.shape)
|
||||
elif args.nasspace == 'nds_amoeba' or args.nasspace == 'nds_darts_fix-w-d':
|
||||
print(accs.shape)
|
||||
inds = accs > 15.
|
||||
accs = accs[inds]
|
||||
scores = scores[inds]
|
||||
print(accs.shape)
|
||||
elif args.nasspace == 'nds_darts':
|
||||
inds = accs > 15.
|
||||
from nasspace import get_search_space
|
||||
searchspace = get_search_space(args)
|
||||
accs = accs[inds]
|
||||
scores = scores[inds]
|
||||
print(accs.shape)
|
||||
else:
|
||||
print(accs.shape)
|
||||
inds = accs > 15.
|
||||
accs = accs[inds]
|
||||
scores = scores[inds]
|
||||
print(accs.shape)
|
||||
|
||||
inds = scores == 0.
|
||||
accs = accs[~inds]
|
||||
scores = scores[~inds]
|
||||
|
||||
|
||||
|
||||
if accs.size > 1000:
|
||||
inds = np.random.choice(accs.size, 1000, replace=False)
|
||||
accs = accs[inds]
|
||||
scores = scores[inds]
|
||||
|
||||
inds = np.isnan(scores)
|
||||
accs = accs[~inds]
|
||||
scores = scores[~inds]
|
||||
|
||||
tau, p = stats.kendalltau(accs, scores)
|
||||
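# Note (added comment): Kendall's tau is the rank correlation between the
# training-free score and the final accuracy; it is the value reported in the
# plot title further down.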
|
||||
if args.nasspace == 'nasbench101':
|
||||
fig, ax = plt.subplots(1, 1, figsize=(5,5))
|
||||
else:
|
||||
fig, ax = plt.subplots(1, 1, figsize=(5,5))
|
||||
|
||||
def scale(x):
|
||||
return 2.**(10*x) - 1.
|
||||
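# Note (added comment): scale() maps accuracies in [0, 1] onto 2**(10x) - 1,
# an exponential stretch that spreads out the crowded high-accuracy end of the
# scatter plot; the axis ticks below are placed with the same mapping.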
|
||||
if args.score == 'svd':
|
||||
score_scale = lambda x: 10.0**x
|
||||
else:
|
||||
score_scale = lambda x: x
|
||||
|
||||
if args.nasspace == 'nonetwork':
|
||||
ax.scatter(scale(accs/100.), score_scale(scores), c=newcmp(accs/100., depths))
|
||||
else:
|
||||
ax.scatter(scale(accs/100. if args.nasspace == 'nasbench201' or 'nds' in args.nasspace else accs), score_scale(scores), c=newcmp(accs/100. if args.nasspace == 'nasbench201' or 'nds' in args.nasspace else accs))
|
||||
|
||||
|
||||
if args.dataset == 'cifar100':
|
||||
ax.set_xticks([scale(float(a)/100.) for a in [40, 60, 70]])
|
||||
ax.set_xticklabels([f'{a}' for a in [40, 60, 70]])
|
||||
elif args.dataset == 'imagenette2':
|
||||
ax.set_xticks([scale(float(a)/100.) for a in [40, 50, 60, 70]])
|
||||
ax.set_xticklabels([f'{a}' for a in [40, 50, 60, 70]])
|
||||
elif args.dataset == 'ImageNet16-120':
|
||||
ax.set_xticks([scale(float(a)/100.) for a in [20, 30, 40, 45]])
|
||||
ax.set_xticklabels([f'{a}' for a in [20, 30, 40, 45]])
|
||||
elif args.nasspace == 'nasbench101' and args.dataset == 'cifar10':
|
||||
ax.set_xticks([scale(float(a)/100.) for a in [50, 80, 90, 95]])
|
||||
ax.set_xticklabels([f'{a}' for a in [50, 80, 90, 95]])
|
||||
elif args.nasspace == 'nasbench201' and args.dataset == 'cifar10' and args.score == 'svd':
|
||||
ax.set_xticks([scale(float(a)/100.) for a in [50, 80, 90, 95]])
|
||||
ax.set_xticklabels([f'{a}' for a in [50, 80, 90, 95]])
|
||||
elif 'nds_resne' in args.nasspace and args.dataset == 'cifar10':
|
||||
ax.set_xticks([scale(float(a)/100.) for a in [85, 88, 91, 94]])
|
||||
ax.set_xticklabels([f'{a}' for a in [85, 88, 91, 94]])
|
||||
elif args.nasspace == 'nds_darts' and args.dataset == 'cifar10':
|
||||
ax.set_xticks([scale(float(a)/100.) for a in [80, 85, 90, 95]])
|
||||
ax.set_xticklabels([f'{a}' for a in [80, 85, 90, 95]])
|
||||
elif args.nasspace == 'nds_pnas' and args.dataset == 'cifar10':
|
||||
ax.set_xticks([scale(float(a)/100.) for a in [90., 91.5, 93, 94.5]])
|
||||
ax.set_xticklabels([f'{a}' for a in [90., 91.5, 93, 94.5]])
|
||||
else:
|
||||
ax.set_xticks([scale(float(a)/100.) for a in [50, 80, 90]])
|
||||
ax.set_xticklabels([f'{a}' for a in [50, 80, 90]])
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
|
||||
nasspacenames = {
|
||||
'nds_resnext-a_in': 'NDS-ResNeXt-A(ImageNet)',
|
||||
'nds_resnext-b_in': 'NDS-ResNeXt-B(ImageNet)',
|
||||
'nds_resnext-a': 'NDS-ResNeXt-A(CIFAR10)',
|
||||
'nds_resnext-b': 'NDS-ResNeXt-B(CIFAR10)',
|
||||
'nds_nasnet': 'NDS-NASNet(CIFAR10)',
|
||||
'nds_nasnet_in': 'NDS-NASNet(ImageNet)',
|
||||
'nds_enas': 'NDS-ENAS(CIFAR10)',
|
||||
'nds_enas_in': 'NDS-ENAS(ImageNet)',
|
||||
'nds_amoeba': 'NDS-Amoeba(CIFAR10)',
|
||||
'nds_amoeba_in': 'NDS-Amoeba(ImageNet)',
|
||||
'nds_resnet': 'NDS-ResNet(CIFAR10)',
|
||||
'nds_pnas': 'NDS-PNAS(CIFAR10)',
|
||||
'nds_pnas_in': 'NDS-PNAS(ImageNet)',
|
||||
'nds_darts': 'NDS-DARTS(CIFAR10)',
|
||||
'nds_darts_in': 'NDS-DARTS(ImageNet)',
|
||||
'nds_darts_fix-w-d': 'NDS-DARTS fixed width/depth (CIFAR10)',
|
||||
'nds_darts_in_fix-w-d': 'NDS-DARTS fixed width/depth (ImageNet)',
|
||||
'nasbench101': 'NAS-Bench-101',
|
||||
'nasbench201': 'NAS-Bench-201'
|
||||
}
|
||||
|
||||
ax.set_ylabel('Score')
|
||||
ax.set_xlabel(f'{"Test" if not args.trainval else "Validation"} accuracy')
|
||||
ax.set_title(f'{nasspacenames[args.nasspace]} {args.dataset} \n $\\tau=${tau:.3f}')
|
||||
|
||||
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}{"_" + args.init + "_" if args.init != "" else args.init}{"_dropout" if args.dropout else ""}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}'
|
||||
print(filename)
|
||||
plt.tight_layout()
|
||||
plt.savefig(filename + '.pdf')
|
||||
plt.savefig(filename + '.png')
|
||||
|
||||
plt.show()
|
||||
@@ -1,87 +0,0 @@
|
||||
import numpy as np
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import pandas as pd
|
||||
from collections import OrderedDict
|
||||
|
||||
import tabulate
|
||||
parser = argparse.ArgumentParser(description='Produce tables')
|
||||
parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder')
|
||||
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
|
||||
|
||||
parser.add_argument('--batch_size', default=256, type=int)
|
||||
parser.add_argument('--GPU', default='0', type=str)
|
||||
|
||||
parser.add_argument('--seed', default=1, type=int)
|
||||
parser.add_argument('--trainval', action='store_true')
|
||||
|
||||
parser.add_argument('--n_runs', default=500, type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
|
||||
|
||||
from statistics import mean, median, stdev as std
|
||||
|
||||
import torch
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
df = []
|
||||
|
||||
datasets = OrderedDict()
|
||||
|
||||
datasets['CIFAR-10 (val)'] = ('cifar10-valid', 'x-valid', True)
|
||||
datasets['CIFAR-10 (test)'] = ('cifar10', 'ori-test', False)
|
||||
|
||||
### CIFAR-100
|
||||
datasets['CIFAR-100 (val)'] = ('cifar100', 'x-valid', False)
|
||||
datasets['CIFAR-100 (test)'] = ('cifar100', 'x-test', False)
|
||||
|
||||
datasets['ImageNet16-120 (val)'] = ('ImageNet16-120', 'x-valid', False)
|
||||
datasets['ImageNet16-120 (test)'] = ('ImageNet16-120', 'x-test', False)
|
||||
|
||||
|
||||
dataset_top1s = OrderedDict()
|
||||
|
||||
for n_samples in [10, 100]:
|
||||
method = f"Ours (N={n_samples})"
|
||||
|
||||
time = 0.
|
||||
|
||||
for dataset, params in datasets.items():
|
||||
top1s = []
|
||||
|
||||
dset = params[0]
|
||||
acc_type = 'accs' if 'test' in params[1] else 'val_accs'
|
||||
filename = f"{args.save_loc}/{dset}_{args.n_runs}_{n_samples}_{args.seed}.t7"
|
||||
|
||||
full_scores = torch.load(filename)
|
||||
if dataset == 'CIFAR-10 (test)':
|
||||
time = median(full_scores['times'])
|
||||
time = f"{time:.2f}"
|
||||
accs = []
|
||||
for n in range(args.n_runs):
|
||||
acc = full_scores[acc_type][n]
|
||||
accs.append(acc)
|
||||
dataset_top1s[dataset] = accs
|
||||
|
||||
cifar10_val = f"{mean(dataset_top1s['CIFAR-10 (val)']):.2f} +- {std(dataset_top1s['CIFAR-10 (val)']):.2f}"
|
||||
cifar10_test = f"{mean(dataset_top1s['CIFAR-10 (test)']):.2f} +- {std(dataset_top1s['CIFAR-10 (test)']):.2f}"
|
||||
|
||||
cifar100_val = f"{mean(dataset_top1s['CIFAR-100 (val)']):.2f} +- {std(dataset_top1s['CIFAR-100 (val)']):.2f}"
|
||||
cifar100_test = f"{mean(dataset_top1s['CIFAR-100 (test)']):.2f} +- {std(dataset_top1s['CIFAR-100 (test)']):.2f}"
|
||||
|
||||
imagenet_val = f"{mean(dataset_top1s['ImageNet16-120 (val)']):.2f} +- {std(dataset_top1s['ImageNet16-120 (val)']):.2f}"
|
||||
imagenet_test = f"{mean(dataset_top1s['ImageNet16-120 (test)']):.2f} +- {std(dataset_top1s['ImageNet16-120 (test)']):.2f}"
|
||||
|
||||
df.append([method, time, cifar10_val, cifar10_test, cifar100_val, cifar100_test, imagenet_val, imagenet_test])
|
||||
|
||||
|
||||
df = pd.DataFrame(df, columns=['Method','Search time (s)','CIFAR-10 (val)','CIFAR-10 (test)','CIFAR-100 (val)','CIFAR-100 (test)','ImageNet16-120 (val)','ImageNet16-120 (test)' ])
|
||||
|
||||
print(tabulate.tabulate(df.values,df.columns, tablefmt="pipe"))
|
||||
0
pycls/core/__init__.py
Normal file
0
pycls/core/__init__.py
Normal file
136
pycls/core/benchmark.py
Normal file
136
pycls/core/benchmark.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Benchmarking functions."""
|
||||
|
||||
import pycls.core.logging as logging
|
||||
import pycls.datasets.loader as loader
|
||||
import torch
|
||||
from pycls.core.config import cfg
|
||||
from pycls.core.timer import Timer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_time_eval(model):
|
||||
"""Computes precise model forward test time using dummy data."""
|
||||
# Use eval mode
|
||||
model.eval()
|
||||
# Generate a dummy mini-batch and copy data to GPU
|
||||
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
|
||||
if cfg.TASK == "jig":
|
||||
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
|
||||
else:
|
||||
inputs = torch.zeros(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
|
||||
# Compute precise forward pass time
|
||||
timer = Timer()
|
||||
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
|
||||
for cur_iter in range(total_iter):
|
||||
# Reset the timers after the warmup phase
|
||||
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
|
||||
timer.reset()
|
||||
# Forward
|
||||
timer.tic()
|
||||
model(inputs)
|
||||
torch.cuda.synchronize()
|
||||
timer.toc()
|
||||
return timer.average_time
|
||||
|
||||
|
||||
def compute_time_train(model, loss_fun):
|
||||
"""Computes precise model forward + backward time using dummy data."""
|
||||
# Use train mode
|
||||
model.train()
|
||||
# Generate a dummy mini-batch and copy data to GPU
|
||||
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
|
||||
if cfg.TASK == "jig":
|
||||
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
|
||||
else:
|
||||
inputs = torch.rand(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
|
||||
if cfg.TASK in ['col', 'seg']:
|
||||
labels = torch.zeros(batch_size, im_size, im_size, dtype=torch.int64).cuda(non_blocking=False)
|
||||
else:
|
||||
labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
|
||||
# Cache BatchNorm2D running stats so the dummy timing batches below do not corrupt them
|
||||
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
|
||||
bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
|
||||
# Compute precise forward backward pass time
|
||||
fw_timer, bw_timer = Timer(), Timer()
|
||||
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
|
||||
for cur_iter in range(total_iter):
|
||||
# Reset the timers after the warmup phase
|
||||
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
|
||||
fw_timer.reset()
|
||||
bw_timer.reset()
|
||||
# Forward
|
||||
fw_timer.tic()
|
||||
preds = model(inputs)
|
||||
if isinstance(preds, tuple):
|
||||
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
|
||||
preds = preds[0]
|
||||
else:
|
||||
loss = loss_fun(preds, labels)
|
||||
torch.cuda.synchronize()
|
||||
fw_timer.toc()
|
||||
# Backward
|
||||
bw_timer.tic()
|
||||
loss.backward()
|
||||
torch.cuda.synchronize()
|
||||
bw_timer.toc()
|
||||
# Restore BatchNorm2D running stats
|
||||
for bn, (mean, var) in zip(bns, bn_stats):
|
||||
bn.running_mean, bn.running_var = mean, var
|
||||
return fw_timer.average_time, bw_timer.average_time
|
||||
|
||||
|
||||
def compute_time_loader(data_loader):
|
||||
"""Computes loader time."""
|
||||
timer = Timer()
|
||||
loader.shuffle(data_loader, 0)
|
||||
data_loader_iterator = iter(data_loader)
|
||||
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
|
||||
total_iter = min(total_iter, len(data_loader))
|
||||
for cur_iter in range(total_iter):
|
||||
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
|
||||
timer.reset()
|
||||
timer.tic()
|
||||
next(data_loader_iterator)
|
||||
timer.toc()
|
||||
return timer.average_time
|
||||
|
||||
|
||||
def compute_time_full(model, loss_fun, train_loader, test_loader):
|
||||
"""Times model and data loader."""
|
||||
logger.info("Computing model and loader timings...")
|
||||
# Compute timings
|
||||
test_fw_time = compute_time_eval(model)
|
||||
train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
|
||||
train_fw_bw_time = train_fw_time + train_bw_time
|
||||
train_loader_time = compute_time_loader(train_loader)
|
||||
# Output iter timing
|
||||
iter_times = {
|
||||
"test_fw_time": test_fw_time,
|
||||
"train_fw_time": train_fw_time,
|
||||
"train_bw_time": train_bw_time,
|
||||
"train_fw_bw_time": train_fw_bw_time,
|
||||
"train_loader_time": train_loader_time,
|
||||
}
|
||||
logger.info(logging.dump_log_data(iter_times, "iter_times"))
|
||||
# Output epoch timing
|
||||
epoch_times = {
|
||||
"test_fw_time": test_fw_time * len(test_loader),
|
||||
"train_fw_time": train_fw_time * len(train_loader),
|
||||
"train_bw_time": train_bw_time * len(train_loader),
|
||||
"train_fw_bw_time": train_fw_bw_time * len(train_loader),
|
||||
"train_loader_time": train_loader_time * len(train_loader),
|
||||
}
|
||||
logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
|
||||
# Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
|
||||
overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
|
||||
logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
|
||||
88
pycls/core/builders.py
Normal file
88
pycls/core/builders.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Model and loss construction functions."""
|
||||
|
||||
import torch
|
||||
from pycls.core.config import cfg
|
||||
from pycls.models.anynet import AnyNet
|
||||
from pycls.models.effnet import EffNet
|
||||
from pycls.models.regnet import RegNet
|
||||
from pycls.models.resnet import ResNet
|
||||
from pycls.models.nas.nas import NAS
|
||||
from pycls.models.nas.nas_search import NAS_Search
|
||||
from pycls.models.nas_bench.model_builder import NAS_Bench
|
||||
|
||||
|
||||
class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
|
||||
"""CrossEntropyLoss with label smoothing."""
|
||||
def __init__(self):
|
||||
super(LabelSmoothedCrossEntropyLoss, self).__init__()
|
||||
self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS
|
||||
self.num_classes = cfg.MODEL.NUM_CLASSES
|
||||
|
||||
def forward(self, logits, target):
|
||||
pred = logits.log_softmax(dim=-1)
|
||||
with torch.no_grad():
|
||||
target_dist = torch.ones_like(pred) * self.eps / (self.num_classes - 1)
|
||||
target_dist.scatter_(-1, target.unsqueeze(-1), 1 - self.eps)
|
||||
return (-target_dist * pred).sum(dim=-1).mean()
|
||||
|
||||
|
||||
# Supported models
|
||||
_models = {
|
||||
"anynet": AnyNet,
|
||||
"effnet": EffNet,
|
||||
"resnet": ResNet,
|
||||
"regnet": RegNet,
|
||||
"nas": NAS,
|
||||
"nas_search": NAS_Search,
|
||||
"nas_bench": NAS_Bench,
|
||||
}
|
||||
|
||||
# Supported loss functions
|
||||
_loss_funs = {
|
||||
"cross_entropy": torch.nn.CrossEntropyLoss,
|
||||
"label_smoothed_cross_entropy": LabelSmoothedCrossEntropyLoss,
|
||||
}
|
||||
|
||||
|
||||
def get_model():
|
||||
"""Gets the model class specified in the config."""
|
||||
err_str = "Model type '{}' not supported"
|
||||
assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
|
||||
return _models[cfg.MODEL.TYPE]
|
||||
|
||||
|
||||
def get_loss_fun():
|
||||
"""Gets the loss function class specified in the config."""
|
||||
err_str = "Loss function type '{}' not supported"
|
||||
assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
|
||||
return _loss_funs[cfg.MODEL.LOSS_FUN]
|
||||
|
||||
|
||||
def build_model():
|
||||
"""Builds the model."""
|
||||
return get_model()()
|
||||
|
||||
|
||||
def build_loss_fun():
|
||||
"""Build the loss function."""
|
||||
if cfg.TASK == "seg":
|
||||
return get_loss_fun()(ignore_index=255)
|
||||
else:
|
||||
return get_loss_fun()()
|
||||
|
||||
|
||||
def register_model(name, ctor):
|
||||
"""Registers a model dynamically."""
|
||||
_models[name] = ctor
|
||||
|
||||
|
||||
def register_loss_fun(name, ctor):
|
||||
"""Registers a loss function dynamically."""
|
||||
_loss_funs[name] = ctor
|
||||
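# Example (added sketch): registering a custom model and constructing it through
# the config; `MyNet` and the in-place cfg edit are illustrative assumptions.
#
#   class MyNet(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.body = torch.nn.Identity()
#       def forward(self, x):
#           return self.body(x)
#
#   register_model("my_net", MyNet)
#   cfg.MODEL.TYPE = "my_net"
#   model = build_model()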
98  pycls/core/checkpoint.py  Normal file
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions that handle saving and loading of checkpoints."""

import os

import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg


# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"


def get_checkpoint_dir():
    """Retrieves the location for storing checkpoints."""
    return os.path.join(cfg.OUT_DIR, _DIR_NAME)


def get_checkpoint(epoch):
    """Retrieves the path to a checkpoint file."""
    name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
    return os.path.join(get_checkpoint_dir(), name)


def get_last_checkpoint():
    """Retrieves the most recent checkpoint (highest epoch number)."""
    checkpoint_dir = get_checkpoint_dir()
    # Checkpoint file names are in lexicographic order
    checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
    last_checkpoint_name = sorted(checkpoints)[-1]
    return os.path.join(checkpoint_dir, last_checkpoint_name)


def has_checkpoint():
    """Determines if there are checkpoints available."""
    checkpoint_dir = get_checkpoint_dir()
    if not os.path.exists(checkpoint_dir):
        return False
    return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))


def save_checkpoint(model, optimizer, epoch):
    """Saves a checkpoint."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    os.makedirs(get_checkpoint_dir(), exist_ok=True)
    # Omit the DDP wrapper in the multi-gpu setting
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    # Record the state
    if isinstance(optimizer, list):
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_w_state": optimizer[0].state_dict(),
            "optimizer_a_state": optimizer[1].state_dict(),
            "cfg": cfg.dump(),
        }
    else:
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_state": optimizer.state_dict(),
            "cfg": cfg.dump(),
        }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file


def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    # Load the checkpoint on CPU to avoid GPU mem spike
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # Account for the DDP wrapper in the multi-gpu setting
    ms = model.module if cfg.NUM_GPUS > 1 else model
    ms.load_state_dict(checkpoint["model_state"])
    # Load the optimizer state (commonly not done when fine-tuning)
    if optimizer:
        if isinstance(optimizer, list):
            optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
            optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
        else:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
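

# Illustrative sketch (not part of the original file): a minimal save/resume round trip,
# assuming cfg.OUT_DIR is writable and cfg.NUM_GPUS == 1 (single, master process).
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    file = save_checkpoint(model, optimizer, epoch=0)
    if has_checkpoint():
        start_epoch = load_checkpoint(get_last_checkpoint(), model, optimizer) + 1
        print("Resuming from epoch", start_epoch, "using", file)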
500  pycls/core/config.py  Normal file
@@ -0,0 +1,500 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Configuration file (powered by YACS)."""

import argparse
import os
import sys

from pycls.core.io import cache_url
from yacs.config import CfgNode as CfgNode


# Global config object
_C = CfgNode()

# Example usage:
# from core.config import cfg
cfg = _C


# ------------------------------------------------------------------------------------ #
# Model options
# ------------------------------------------------------------------------------------ #
_C.MODEL = CfgNode()

# Model type
_C.MODEL.TYPE = ""

# Number of weight layers
_C.MODEL.DEPTH = 0

# Number of input channels
_C.MODEL.INPUT_CHANNELS = 3

# Number of classes
_C.MODEL.NUM_CLASSES = 10

# Loss function (see pycls/core/builders.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"

# Label smoothing eps
_C.MODEL.LABEL_SMOOTHING_EPS = 0.0

# ASPP channels
_C.MODEL.ASPP_CHANNELS = 256

# ASPP dilation rates
_C.MODEL.ASPP_RATES = [6, 12, 18]


# ------------------------------------------------------------------------------------ #
# ResNet options
# ------------------------------------------------------------------------------------ #
_C.RESNET = CfgNode()

# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"

# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1

# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64

# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True


# ------------------------------------------------------------------------------------ #
# AnyNet options
# ------------------------------------------------------------------------------------ #
_C.ANYNET = CfgNode()

# Stem type
_C.ANYNET.STEM_TYPE = "simple_stem_in"

# Stem width
_C.ANYNET.STEM_W = 32

# Block type
_C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"

# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []

# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []

# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []

# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []

# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False

# SE ratio
_C.ANYNET.SE_R = 0.25


# ------------------------------------------------------------------------------------ #
# RegNet options
# ------------------------------------------------------------------------------------ #
_C.REGNET = CfgNode()

# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"

# Stem width
_C.REGNET.STEM_W = 32

# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"

# Stride of each stage
_C.REGNET.STRIDE = 2

# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25

# Depth
_C.REGNET.DEPTH = 10

# Initial width
_C.REGNET.W0 = 32

# Slope
_C.REGNET.WA = 5.0

# Quantization
_C.REGNET.WM = 2.5

# Group width
_C.REGNET.GROUP_W = 16

# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0


# ------------------------------------------------------------------------------------ #
# EfficientNet options
# ------------------------------------------------------------------------------------ #
_C.EN = CfgNode()

# Stem width
_C.EN.STEM_W = 32

# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []

# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []

# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25

# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []

# Kernel sizes for each stage
_C.EN.KERNELS = []

# Head width
_C.EN.HEAD_W = 1280

# Drop connect ratio
_C.EN.DC_RATIO = 0.0

# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0


# ---------------------------------------------------------------------------- #
# NAS options
# ---------------------------------------------------------------------------- #
_C.NAS = CfgNode()

# Cell genotype
_C.NAS.GENOTYPE = 'nas'

# Custom genotype
_C.NAS.CUSTOM_GENOTYPE = []

# Base NAS width
_C.NAS.WIDTH = 16

# Total number of cells
_C.NAS.DEPTH = 20

# Auxiliary heads
_C.NAS.AUX = False

# Weight for auxiliary heads
_C.NAS.AUX_WEIGHT = 0.4

# Drop path probability
_C.NAS.DROP_PROB = 0.0

# Matrix in NAS Bench
_C.NAS.MATRIX = []

# Operations in NAS Bench
_C.NAS.OPS = []

# Number of stacks in NAS Bench
_C.NAS.NUM_STACKS = 3

# Number of modules per stack in NAS Bench
_C.NAS.NUM_MODULES_PER_STACK = 3


# ------------------------------------------------------------------------------------ #
# Batch norm options
# ------------------------------------------------------------------------------------ #
_C.BN = CfgNode()

# BN epsilon
_C.BN.EPS = 1e-5

# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1

# Precise BN stats
_C.BN.USE_PRECISE_STATS = False
_C.BN.NUM_SAMPLES_PRECISE = 1024

# Initialize the gamma of the final BN of each block to zero
_C.BN.ZERO_INIT_FINAL_GAMMA = False

# Use a different weight decay for BN layers
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0


# ------------------------------------------------------------------------------------ #
# Optimizer options
# ------------------------------------------------------------------------------------ #
_C.OPTIM = CfgNode()

# Base learning rate
_C.OPTIM.BASE_LR = 0.1

# Learning rate policy select from {'cos', 'exp', 'steps'}
_C.OPTIM.LR_POLICY = "cos"

# Exponential decay factor
_C.OPTIM.GAMMA = 0.1

# Steps for 'steps' policy (in epochs)
_C.OPTIM.STEPS = []

# Learning rate multiplier for 'steps' policy
_C.OPTIM.LR_MULT = 0.1

# Maximal number of epochs
_C.OPTIM.MAX_EPOCH = 200

# Momentum
_C.OPTIM.MOMENTUM = 0.9

# Momentum dampening
_C.OPTIM.DAMPENING = 0.0

# Nesterov momentum
_C.OPTIM.NESTEROV = True

# L2 regularization
_C.OPTIM.WEIGHT_DECAY = 5e-4

# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
_C.OPTIM.WARMUP_FACTOR = 0.1

# Gradually warm up the OPTIM.BASE_LR over this number of epochs
_C.OPTIM.WARMUP_EPOCHS = 0

# Update the learning rate per iter
_C.OPTIM.ITER_LR = False

# Base learning rate for arch
_C.OPTIM.ARCH_BASE_LR = 0.0003

# L2 regularization for arch
_C.OPTIM.ARCH_WEIGHT_DECAY = 0.001

# Optimizer for arch
_C.OPTIM.ARCH_OPTIM = 'adam'

# Epoch to start optimizing arch
_C.OPTIM.ARCH_EPOCH = 0.0


# ------------------------------------------------------------------------------------ #
# Training options
# ------------------------------------------------------------------------------------ #
_C.TRAIN = CfgNode()

# Dataset and split
_C.TRAIN.DATASET = ""
_C.TRAIN.SPLIT = "train"

# Total mini-batch size
_C.TRAIN.BATCH_SIZE = 128

# Image size
_C.TRAIN.IM_SIZE = 224

# Evaluate model on test data every eval period epochs
_C.TRAIN.EVAL_PERIOD = 1

# Save model checkpoint every checkpoint period epochs
_C.TRAIN.CHECKPOINT_PERIOD = 1

# Resume training from the latest checkpoint in the output directory
_C.TRAIN.AUTO_RESUME = True

# Weights to start training from
_C.TRAIN.WEIGHTS = ""

# Percentage of gray images in jig
_C.TRAIN.GRAY_PERCENTAGE = 0.0

# Portion to create trainA/trainB split
_C.TRAIN.PORTION = 1.0


# ------------------------------------------------------------------------------------ #
# Testing options
# ------------------------------------------------------------------------------------ #
_C.TEST = CfgNode()

# Dataset and split
_C.TEST.DATASET = ""
_C.TEST.SPLIT = "val"

# Total mini-batch size
_C.TEST.BATCH_SIZE = 200

# Image size
_C.TEST.IM_SIZE = 256

# Weights to use for testing
_C.TEST.WEIGHTS = ""


# ------------------------------------------------------------------------------------ #
# Common train/test data loader options
# ------------------------------------------------------------------------------------ #
_C.DATA_LOADER = CfgNode()

# Number of data loader workers per process
_C.DATA_LOADER.NUM_WORKERS = 8

# Load data to pinned host memory
_C.DATA_LOADER.PIN_MEMORY = True


# ------------------------------------------------------------------------------------ #
# Memory options
# ------------------------------------------------------------------------------------ #
_C.MEM = CfgNode()

# Perform ReLU inplace
_C.MEM.RELU_INPLACE = True


# ------------------------------------------------------------------------------------ #
# CUDNN options
# ------------------------------------------------------------------------------------ #
_C.CUDNN = CfgNode()

# Perform benchmarking to select the fastest CUDNN algorithms to use
# Note that this may increase the memory usage and will likely not result
# in overall speedups when variable size inputs are used (e.g. COCO training)
_C.CUDNN.BENCHMARK = True


# ------------------------------------------------------------------------------------ #
# Precise timing options
# ------------------------------------------------------------------------------------ #
_C.PREC_TIME = CfgNode()

# Number of iterations to warm up the caches
_C.PREC_TIME.WARMUP_ITER = 3

# Number of iterations to compute avg time
_C.PREC_TIME.NUM_ITER = 30


# ------------------------------------------------------------------------------------ #
# Misc options
# ------------------------------------------------------------------------------------ #

# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1

# Task (cls, seg, rot, col, jig)
_C.TASK = "cls"

# Grid in Jigsaw (2, 3); no effect if TASK is not jig
_C.JIGSAW_GRID = 3

# Output directory
_C.OUT_DIR = "/tmp"

# Config destination (in OUT_DIR)
_C.CFG_DEST = "config.yaml"

# Note that non-determinism may still be present due to non-deterministic
# operator implementations in GPU operator libraries
_C.RNG_SEED = 1

# Log destination ('stdout' or 'file')
_C.LOG_DEST = "stdout"

# Log period in iters
_C.LOG_PERIOD = 10

# Distributed backend
_C.DIST_BACKEND = "nccl"

# Hostname and port for initializing multi-process groups
_C.HOST = "localhost"
_C.PORT = 10001

# Models weights referred to by URL are downloaded to this local cache
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"


# ------------------------------------------------------------------------------------ #
# Deprecated keys
# ------------------------------------------------------------------------------------ #

_C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
_C.register_deprecated_key("PREC_TIME.ENABLED")


def assert_and_infer_cfg(cache_urls=True):
    """Checks config values invariants."""
    err_str = "The first lr step must start at 0"
    assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
    data_splits = ["train", "val", "test"]
    err_str = "Data split '{}' not supported"
    assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
    assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
    err_str = "Mini-batch size should be a multiple of NUM_GPUS."
    assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
    assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
    err_str = "Precise BN stats computation not verified for > 1 GPU"
    assert not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1, err_str
    err_str = "Log destination '{}' not supported"
    assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
    if cache_urls:
        cache_cfg_urls()


def cache_cfg_urls():
    """Download URLs in config, cache them, and rewrite cfg to use cached file."""
    _C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
    _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)


def dump_cfg():
    """Dumps the config to the output directory."""
    cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
    with open(cfg_file, "w") as f:
        _C.dump(stream=f)


def load_cfg(out_dir, cfg_dest="config.yaml"):
    """Loads config from specified output directory."""
    cfg_file = os.path.join(out_dir, cfg_dest)
    _C.merge_from_file(cfg_file)


def load_cfg_fom_args(description="Config file options."):
    """Load config from command line arguments and set any specified options."""
    parser = argparse.ArgumentParser(description=description)
    help_s = "Config file location"
    parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str)
    help_s = "See pycls/core/config.py for all options"
    parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    _C.merge_from_file(args.cfg_file)
    _C.merge_from_list(args.opts)
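

# Illustrative sketch (not part of the original file): the load/override/validate flow.
# Overrides are normally supplied on the command line; merge_from_list shows the same
# mechanism without needing a YAML file (yacs coerces the string values to the key types).
if __name__ == "__main__":
    cfg.merge_from_list(["MODEL.NUM_CLASSES", "100", "OPTIM.MAX_EPOCH", "10"])
    assert_and_infer_cfg(cache_urls=False)
    print(cfg.MODEL.NUM_CLASSES, cfg.OPTIM.MAX_EPOCH)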
157  pycls/core/distributed.py  Normal file
@@ -0,0 +1,157 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Distributed helpers."""

import multiprocessing
import os
import signal
import threading
import traceback

import torch
from pycls.core.config import cfg


def is_master_proc():
    """Determines if the current process is the master process.

    Master process is responsible for logging, writing and loading checkpoints. In
    the multi GPU setting, we assign the master role to the rank 0 process. When
    training using a single GPU, there is a single process which is considered master.
    """
    return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0


def init_process_group(proc_rank, world_size):
    """Initializes the default process group."""
    # Set the GPU to use
    torch.cuda.set_device(proc_rank)
    # Initialize the process group
    torch.distributed.init_process_group(
        backend=cfg.DIST_BACKEND,
        init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
        world_size=world_size,
        rank=proc_rank,
    )


def destroy_process_group():
    """Destroys the default process group."""
    torch.distributed.destroy_process_group()


def scaled_all_reduce(tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of the
    process group (equivalent to cfg.NUM_GPUS).
    """
    # There is no need for reduction in the single-proc case
    if cfg.NUM_GPUS == 1:
        return tensors
    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / cfg.NUM_GPUS)
    return tensors


class ChildException(Exception):
    """Wraps an exception from a child process."""

    def __init__(self, child_trace):
        super(ChildException, self).__init__(child_trace)


class ErrorHandler(object):
    """Multiprocessing error handler (based on fairseq's).

    Listens for errors in child processes and propagates the tracebacks to the parent.
    """

    def __init__(self, error_queue):
        # Shared error queue
        self.error_queue = error_queue
        # Children processes sharing the error queue
        self.children_pids = []
        # Start a thread listening to errors
        self.error_listener = threading.Thread(target=self.listen, daemon=True)
        self.error_listener.start()
        # Register the signal handler
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Registers a child process."""
        self.children_pids.append(pid)

    def listen(self):
        """Listens for errors in the error queue."""
        # Wait until there is an error in the queue
        child_trace = self.error_queue.get()
        # Put the error back for the signal handler
        self.error_queue.put(child_trace)
        # Invoke the signal handler
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, _sig_num, _stack_frame):
        """Signal handler."""
        # Kill children processes
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)
        # Propagate the error from the child process
        raise ChildException(self.error_queue.get())


def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
    """Runs a function from a child process."""
    try:
        # Initialize the process group
        init_process_group(proc_rank, world_size)
        # Run the function
        fun(*fun_args, **fun_kwargs)
    except KeyboardInterrupt:
        # Killed by the parent process
        pass
    except Exception:
        # Propagate exception to the parent process
        error_queue.put(traceback.format_exc())
    finally:
        # Destroy the process group
        destroy_process_group()


def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
    """Runs a function in a multi-proc setting (unless num_proc == 1)."""
    # There is no need for multi-proc in the single-proc case
    fun_kwargs = fun_kwargs if fun_kwargs else {}
    if num_proc == 1:
        fun(*fun_args, **fun_kwargs)
        return
    # Handle errors from training subprocesses
    error_queue = multiprocessing.SimpleQueue()
    error_handler = ErrorHandler(error_queue)
    # Run each training subprocess
    ps = []
    for i in range(num_proc):
        p_i = multiprocessing.Process(
            target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
        )
        ps.append(p_i)
        p_i.start()
        error_handler.add_child(p_i.pid)
    # Wait for each subprocess to finish
    for p in ps:
        p.join()
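

# Illustrative sketch (not part of the original file): with num_proc == 1 the target
# function simply runs in-process; with num_proc > 1 each worker would first join a
# process group sized cfg.NUM_GPUS. The worker function here is a trivial placeholder.
if __name__ == "__main__":
    def hello(msg="hello from a worker"):
        print(msg)

    multi_proc_run(num_proc=1, fun=hello, fun_kwargs={"msg": "single process run"})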
77  pycls/core/io.py  Normal file
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""IO utilities (adapted from Detectron)"""

import logging
import os
import re
import sys
from urllib import request as urlrequest


logger = logging.getLogger(__name__)

_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"


def cache_url(url_or_file, cache_dir):
    """Download the file specified by the URL to the cache_dir and return the path to
    the cached file. If the argument is not a URL, simply return it as is.
    """
    is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
    if not is_url:
        return url_or_file
    url = url_or_file
    err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}"
    assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL)
    cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
    if os.path.exists(cache_file_path):
        return cache_file_path
    cache_file_dir = os.path.dirname(cache_file_path)
    if not os.path.exists(cache_file_dir):
        os.makedirs(cache_file_dir)
    logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
    download_url(url, cache_file_path)
    return cache_file_path


def _progress_bar(count, total):
    """Report download progress. Credit:
    https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
    """
    bar_len = 60
    filled_len = int(round(bar_len * count / float(total)))
    percents = round(100.0 * count / float(total), 1)
    bar = "=" * filled_len + "-" * (bar_len - filled_len)
    sys.stdout.write(
        " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
    )
    sys.stdout.flush()
    if count >= total:
        sys.stdout.write("\n")


def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
    """Download url and write it to dst_file_path. Credit:
    https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
    """
    req = urlrequest.Request(url)
    response = urlrequest.urlopen(req)
    total_size = response.info().get("Content-Length").strip()
    total_size = int(total_size)
    bytes_so_far = 0
    with open(dst_file_path, "wb") as f:
        while 1:
            chunk = response.read(chunk_size)
            bytes_so_far += len(chunk)
            if not chunk:
                break
            if progress_hook:
                progress_hook(bytes_so_far, total_size)
            f.write(chunk)
    return bytes_so_far
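

# Illustrative sketch (not part of the original file): local paths pass through
# unchanged, while URLs must live under _PYCLS_BASE_URL and are downloaded once into
# the cache directory. The local path below is hypothetical.
if __name__ == "__main__":
    print(cache_url("/local/weights.pyth", "/tmp/pycls-download-cache"))  # returned as is
    # A real download would pass a URL of the form _PYCLS_BASE_URL + "/<model>.pyth"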
138  pycls/core/logging.py  Normal file
@@ -0,0 +1,138 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Logging."""

import builtins
import decimal
import logging
import os
import sys

import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg


# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"

# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"

# Data output with dump_log_data(data, data_type) will be tagged w/ this
_TAG = "json_stats: "

# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
_TYPE = "_type"


def _suppress_print():
    """Suppresses printing from the current process."""

    def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
        pass

    builtins.print = ignore


def setup_logging():
    """Sets up the logging."""
    # Enable logging only for the master process
    if dist.is_master_proc():
        # Clear the root logger to prevent any existing logging config
        # (e.g. set by another module) from messing with our setup
        logging.root.handlers = []
        # Construct logging configuration
        logging_config = {"level": logging.INFO, "format": _FORMAT}
        # Log either to stdout or to a file
        if cfg.LOG_DEST == "stdout":
            logging_config["stream"] = sys.stdout
        else:
            logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
        # Configure logging
        logging.basicConfig(**logging_config)
    else:
        _suppress_print()


def get_logger(name):
    """Retrieves the logger."""
    return logging.getLogger(name)


def dump_log_data(data, data_type, prec=4):
    """Convert data (a dictionary) into tagged json string for logging."""
    data[_TYPE] = data_type
    data = float_to_decimal(data, prec)
    data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
    return "{:s}{:s}".format(_TAG, data_json)


def float_to_decimal(data, prec=4):
    """Convert floats to decimals which allows for fixed width json."""
    if isinstance(data, dict):
        return {k: float_to_decimal(v, prec) for k, v in data.items()}
    if isinstance(data, float):
        return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
    else:
        return data


def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
    """Get all log files in directory containing subdirs of trained models."""
    names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
    files = [os.path.join(log_dir, n, log_file) for n in names]
    f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
    files, names = zip(*f_n_ps) if f_n_ps else ([], [])
    return files, names


def load_log_data(log_file, data_types_to_skip=()):
    """Loads log data into a dictionary of the form data[data_type][metric][index]."""
    # Load log_file
    assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
    with open(log_file, "r") as f:
        lines = f.readlines()
    # Extract and parse lines that start with _TAG and have a type specified
    lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
    lines = [simplejson.loads(l) for l in lines]
    lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip]
    # Generate data structure accessed by data[data_type][index][metric]
    data_types = [l[_TYPE] for l in lines]
    data = {t: [] for t in data_types}
    for t, line in zip(data_types, lines):
        del line[_TYPE]
        data[t].append(line)
    # Generate data structure accessed by data[data_type][metric][index]
    for t in data:
        metrics = sorted(data[t][0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
        assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
        data[t] = {m: [d[m] for d in data[t]] for m in metrics}
    return data


def sort_log_data(data):
    """Sort each data[data_type][metric] by epoch or keep only first instance."""
    for t in data:
        if "epoch" in data[t]:
            assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
            data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
            data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
            epoch = data[t]["epoch_ind"]
            if "iter" in data[t]:
                assert "iter_ind" not in data[t] and "iter_max" not in data[t]
                data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
                data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
                itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
                epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
            for m in data[t]:
                data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
        else:
            data[t] = {m: d[0] for m, d in data[t].items()}
    return data
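

# Illustrative sketch (not part of the original file): stats dictionaries are tagged and
# serialized as fixed-precision JSON lines, which load_log_data() can later parse back.
# Assumes cfg.NUM_GPUS == 1 so the current process counts as master.
if __name__ == "__main__":
    setup_logging()
    logger = get_logger(__name__)
    logger.info(dump_log_data({"epoch": "1/10", "top1_err": 42.123456}, "test_epoch"))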
435  pycls/core/meters.py  Normal file
@@ -0,0 +1,435 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Meters."""

from collections import deque

import numpy as np
import pycls.core.logging as logging
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer


logger = logging.get_logger(__name__)


def time_string(seconds):
    """Converts time in seconds to a fixed-width string format."""
    days, rem = divmod(int(seconds), 24 * 3600)
    hrs, rem = divmod(rem, 3600)
    mins, secs = divmod(rem, 60)
    return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)


def inter_union(preds, labels, num_classes):
    """Computes per-class intersection and union areas (used for mean IoU)."""
    _, preds = torch.max(preds, 1)
    preds = preds.type(torch.uint8) + 1
    labels = labels.type(torch.uint8) + 1
    preds = preds * (labels > 0).type(torch.uint8)

    inter = preds * (preds == labels).type(torch.uint8)
    area_inter = torch.histc(inter.type(torch.int64), bins=num_classes, min=1, max=num_classes)
    area_preds = torch.histc(preds.type(torch.int64), bins=num_classes, min=1, max=num_classes)
    area_labels = torch.histc(labels.type(torch.int64), bins=num_classes, min=1, max=num_classes)
    area_union = area_preds + area_labels - area_inter

    return [area_inter.type(torch.float64) / labels.size(0), area_union.type(torch.float64) / labels.size(0)]


def topk_errors(preds, labels, ks):
    """Computes the top-k error for each k."""
    err_str = "Batch dim of predictions and labels must match"
    assert preds.size(0) == labels.size(0), err_str
    # Find the top max_k predictions for each sample
    _top_max_k_vals, top_max_k_inds = torch.topk(
        preds, max(ks), dim=1, largest=True, sorted=True
    )
    # (batch_size, max_k) -> (max_k, batch_size)
    top_max_k_inds = top_max_k_inds.t()
    # (batch_size, ) -> (max_k, batch_size)
    rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
    # (i, j) = 1 if top i-th prediction for the j-th sample is correct
    top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
    # Compute the number of topk correct predictions for each k
    topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks]
    return [(1.0 - x / preds.size(0)) * 100.0 for x in topks_correct]


def gpu_mem_usage():
    """Computes the GPU memory usage for the current device (MB)."""
    mem_usage_bytes = torch.cuda.max_memory_allocated()
    return mem_usage_bytes / 1024 / 1024


class ScalarMeter(object):
    """Measures a scalar value (adapted from Detectron)."""

    def __init__(self, window_size):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def reset(self):
        self.deque.clear()
        self.total = 0.0
        self.count = 0

    def add_value(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    def get_win_median(self):
        return np.median(self.deque)

    def get_win_avg(self):
        return np.mean(self.deque)

    def get_global_avg(self):
        return self.total / self.count


class TrainMeter(object):
    """Measures training stats."""

    def __init__(self, epoch_iters):
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, timer=False):
        if timer:
            self.iter_timer.reset()
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
        # Current minibatch stats
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.loss.add_value(loss)
        self.lr = lr
        # Aggregate stats
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "eta": time_string(eta_sec),
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "train_iter"))

    def get_epoch_stats(self, cur_epoch):
        cur_iter_total = (cur_epoch + 1) * self.epoch_iters
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        avg_loss = self.loss_total / self.num_samples
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "eta": time_string(eta_sec),
            "top1_err": top1_err,
            "top5_err": top5_err,
            "loss": avg_loss,
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "train_epoch"))


class TestMeter(object):
    """Measures testing stats."""

    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Min errors (over the full test set)
        self.min_top1_err = 100.0
        self.min_top5_err = 100.0
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, min_errs=False):
        if min_errs:
            self.min_top1_err = 100.0
            self.min_top5_err = 100.0
        self.iter_timer.reset()
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, mb_size):
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        mem_usage = gpu_mem_usage()
        iter_stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "mem": int(np.ceil(mem_usage)),
        }
        return iter_stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "test_iter"))

    def get_epoch_stats(self, cur_epoch):
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        self.min_top1_err = min(self.min_top1_err, top1_err)
        self.min_top5_err = min(self.min_top5_err, top5_err)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "top1_err": top1_err,
            "top5_err": top5_err,
            "min_top1_err": self.min_top1_err,
            "min_top5_err": self.min_top5_err,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "test_epoch"))


class TrainMeterIoU(object):
    """Measures training stats (mean IoU variant for segmentation)."""

    def __init__(self, epoch_iters):
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None

        self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)

        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def reset(self, timer=False):
        if timer:
            self.iter_timer.reset()
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_miou.reset()
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, inter, union, loss, lr, mb_size):
        # Current minibatch stats
        self.mb_miou.add_value((inter / (union + 1e-10)).mean())
        self.loss.add_value(loss)
        self.lr = lr
        # Aggregate stats
        self.num_inter += inter * mb_size
        self.num_union += union * mb_size
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "eta": time_string(eta_sec),
            "miou": self.mb_miou.get_win_median(),
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "train_iter"))

    def get_epoch_stats(self, cur_epoch):
        cur_iter_total = (cur_epoch + 1) * self.epoch_iters
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        miou = (self.num_inter / (self.num_union + 1e-10)).mean()
        avg_loss = self.loss_total / self.num_samples
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "eta": time_string(eta_sec),
            "miou": miou,
            "loss": avg_loss,
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "train_epoch"))


class TestMeterIoU(object):
    """Measures testing stats (mean IoU variant for segmentation)."""

    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()

        self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)

        self.max_miou = 0.0

        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def reset(self, min_errs=False):
        if min_errs:
            self.max_miou = 0.0
        self.iter_timer.reset()
        self.mb_miou.reset()
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, inter, union, mb_size):
        self.mb_miou.add_value((inter / (union + 1e-10)).mean())
        self.num_inter += inter * mb_size
        self.num_union += union * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        mem_usage = gpu_mem_usage()
        iter_stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "miou": self.mb_miou.get_win_median(),
            "mem": int(np.ceil(mem_usage)),
        }
        return iter_stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "test_iter"))

    def get_epoch_stats(self, cur_epoch):
        miou = (self.num_inter / (self.num_union + 1e-10)).mean()
        self.max_miou = max(self.max_miou, miou)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "miou": miou,
            "max_miou": self.max_miou,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "test_epoch"))
129  pycls/core/net.py  Normal file
@@ -0,0 +1,129 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions for manipulating networks."""

import itertools
import math

import torch
import torch.nn as nn
from pycls.core.config import cfg


def init_weights(m):
    """Performs ResNet-style weight initialization."""
    if isinstance(m, nn.Conv2d):
        # Note that there is no bias due to BN
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
        zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
        m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(mean=0.0, std=0.01)
        m.bias.data.zero_()


@torch.no_grad()
def compute_precise_bn_stats(model, loader):
    """Computes precise BN stats on training data."""
    # Compute the number of minibatches to use
    num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    # Initialize stats storage
    mus = [torch.zeros_like(bn.running_mean) for bn in bns]
    sqs = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values
    moms = [bn.momentum for bn in bns]
    # Disable momentum
    for bn in bns:
        bn.momentum = 1.0
    # Accumulate the stats across the data samples
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        # Accumulate the stats for each BN layer
        for i, bn in enumerate(bns):
            m, v = bn.running_mean, bn.running_var
            sqs[i] += (v + m * m) / num_iter
            mus[i] += m / num_iter
    # Set the stats and restore momentum values
    for i, bn in enumerate(bns):
        bn.running_var = sqs[i] - mus[i] * mus[i]
        bn.running_mean = mus[i]
        bn.momentum = moms[i]


def reset_bn_stats(model):
    """Resets running BN stats."""
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.reset_running_stats()


def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
    """Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts)."""
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h = (h + 2 * padding - k) // stride + 1
    w = (w + 2 * padding - k) // stride + 1
    flops += k * k * w_in * w_out * h * w // groups
    params += k * k * w_in * w_out // groups
    flops += w_out if bias else 0
    params += w_out if bias else 0
    acts += w_out * h * w
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def complexity_batchnorm2d(cx, w_in):
    """Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts)."""
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    params += 2 * w_in
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def complexity_maxpool2d(cx, k, stride, padding):
    """Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h = (h + 2 * padding - k) // stride + 1
    w = (w + 2 * padding - k) // stride + 1
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}


def complexity(model):
    """Compute model complexity (model can be model instance or model class)."""
    size = cfg.TRAIN.IM_SIZE
    cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
    cx = model.complexity(cx)
    return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}


def drop_connect(x, drop_ratio):
    """Drop connect (adapted from DARTS)."""
    keep_ratio = 1.0 - drop_ratio
    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    mask.bernoulli_(keep_ratio)
    x.div_(keep_ratio)
    x.mul_(mask)
    return x


def get_flat_weights(model):
    """Gets all model weights as a single flat vector."""
    return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)


def set_flat_weights(model, flat_weights):
    """Sets all model weights from a single flat vector."""
    k = 0
    for p in model.parameters():
        n = p.data.numel()
        p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
        k += n
    assert k == flat_weights.numel()
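

# Illustrative sketch (not part of the original file): accounting for a single 3x3
# stride-1 conv followed by BN on a 32x32 input using the helpers above; a full model
# chains these calls inside its complexity() hook.
if __name__ == "__main__":
    cx = {"h": 32, "w": 32, "flops": 0, "params": 0, "acts": 0}
    cx = complexity_conv2d(cx, w_in=16, w_out=32, k=3, stride=1, padding=1)
    cx = complexity_batchnorm2d(cx, w_in=32)
    print(cx)  # flops = 3*3*16*32*32*32, params = 3*3*16*32 + 2*32, acts = 32*32*32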
95  pycls/core/optimizer.py  Normal file
@@ -0,0 +1,95 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Optimizer."""

import numpy as np
import torch
from pycls.core.config import cfg


def construct_optimizer(model):
    """Constructs the optimizer.

    Note that the momentum update in PyTorch differs from the one in Caffe2.
    In particular,

    Caffe2:
        V := mu * V + lr * g
        p := p - V

    PyTorch:
        V := mu * V + g
        p := p - lr * V

    where V is the velocity, mu is the momentum factor, lr is the learning rate,
    g is the gradient and p are the parameters.

    Since V is defined independently of the learning rate in PyTorch,
    when the learning rate is changed there is no need to perform the
    momentum correction by scaling V (unlike in the Caffe2 case).
    """
    if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
        # Apply different weight decay to Batchnorm and non-batchnorm parameters.
        p_bn = [p for n, p in model.named_parameters() if "bn" in n]
        p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
        optim_params = [
            {"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
            {"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
        ]
    else:
        optim_params = model.parameters()
    return torch.optim.SGD(
        optim_params,
        lr=cfg.OPTIM.BASE_LR,
        momentum=cfg.OPTIM.MOMENTUM,
        weight_decay=cfg.OPTIM.WEIGHT_DECAY,
        dampening=cfg.OPTIM.DAMPENING,
        nesterov=cfg.OPTIM.NESTEROV,
    )


def lr_fun_steps(cur_epoch):
    """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
    ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
    return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)


def lr_fun_exp(cur_epoch):
    """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
    return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)


def lr_fun_cos(cur_epoch):
    """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
    base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
    return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))


def get_lr_fun():
    """Retrieves the specified lr policy function"""
    lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
    if lr_fun not in globals():
        raise NotImplementedError("Unknown LR policy:" + cfg.OPTIM.LR_POLICY)
    return globals()[lr_fun]


def get_epoch_lr(cur_epoch):
    """Retrieves the lr for the given epoch according to the policy."""
    lr = get_lr_fun()(cur_epoch)
    # Linear warmup
    if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
        alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
        warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
        lr *= warmup_factor
    return lr


def set_lr(optimizer, new_lr):
    """Sets the optimizer lr to the specified value."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr
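

# Illustrative sketch (not part of the original file): how the schedule helpers drive
# training. With the default config (cosine policy, no warmup) this prints the per-epoch
# learning rate for a short 5-epoch run on a toy linear model.
if __name__ == "__main__":
    import torch.nn as nn

    cfg.OPTIM.MAX_EPOCH = 5
    optimizer = construct_optimizer(nn.Linear(4, 2))
    for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
        lr = get_epoch_lr(cur_epoch)
        set_lr(optimizer, lr)
        print(cur_epoch, round(lr, 4))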
132  pycls/core/plotting.py  Normal file
@@ -0,0 +1,132 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Plotting functions."""

import colorlover as cl
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as offline
import pycls.core.logging as logging


def get_plot_colors(max_colors, color_format="pyplot"):
    """Generate colors for plotting."""
    colors = cl.scales["11"]["qual"]["Paired"]
    if max_colors > len(colors):
        colors = cl.to_rgb(cl.interp(colors, max_colors))
    if color_format == "pyplot":
        return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
    return colors


def prepare_plot_data(log_files, names, metric="top1_err"):
    """Load logs and extract data for plotting error curves."""
    plot_data = []
    for file, name in zip(log_files, names):
        d, data = {}, logging.sort_log_data(logging.load_log_data(file))
        for phase in ["train", "test"]:
            x = data[phase + "_epoch"]["epoch_ind"]
            y = data[phase + "_epoch"][metric]
            d["x_" + phase], d["y_" + phase] = x, y
            d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
        plot_data.append(d)
    assert len(plot_data) > 0, "No data to plot"
    return plot_data


def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
    """Plot error curves using plotly and save to file."""
    plot_data = prepare_plot_data(log_files, names, metric)
    colors = get_plot_colors(len(plot_data), "plotly")
    # Prepare data for plots (3 sets, train duplicated w and w/o legend)
    data = []
    for i, d in enumerate(plot_data):
        s = str(i)
        line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
        line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
        data.append(
            go.Scatter(
                x=d["x_train"],
                y=d["y_train"],
                mode="lines",
                name=d["train_label"],
                line=line_train,
                legendgroup=s,
                visible=True,
                showlegend=False,
            )
        )
        data.append(
            go.Scatter(
                x=d["x_test"],
                y=d["y_test"],
                mode="lines",
                name=d["test_label"],
                line=line_test,
                legendgroup=s,
                visible=True,
                showlegend=True,
            )
        )
        data.append(
            go.Scatter(
                x=d["x_train"],
                y=d["y_train"],
                mode="lines",
                name=d["train_label"],
                line=line_train,
                legendgroup=s,
                visible=False,
                showlegend=True,
            )
        )
    # Prepare layout w ability to toggle 'all', 'train', 'test'
    titlefont = {"size": 18, "color": "#7f7f7f"}
    vis = [[True, True, False], [False, False, True], [False, True, False]]
    buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
    buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
    layout = go.Layout(
        title=metric + " vs. epoch<br>[dash=train, solid=test]",
        xaxis={"title": "epoch", "titlefont": titlefont},
        yaxis={"title": metric, "titlefont": titlefont},
        showlegend=True,
        hoverlabel={"namelength": -1},
        updatemenus=[
            {
                "buttons": buttons,
                "direction": "down",
                "showactive": True,
                "x": 1.02,
                "xanchor": "left",
                "y": 1.08,
                "yanchor": "top",
            }
        ],
    )
    # Create plotly plot
    offline.plot({"data": data, "layout": layout}, filename=filename)


def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
    """Plot error curves using matplotlib.pyplot and save to file."""
    plot_data = prepare_plot_data(log_files, names, metric)
    colors = get_plot_colors(len(names))
    for ind, d in enumerate(plot_data):
        c, lbl = colors[ind], d["test_label"]
        plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
        plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
    plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
    plt.xlabel("epoch", fontsize=14)
    plt.ylabel(metric, fontsize=14)
    plt.grid(alpha=0.4)
    plt.legend()
    if filename:
        plt.savefig(filename)
        plt.clf()
    else:
|
||||
plt.show()
|
||||
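A hypothetical call to the pyplot helper above; the log paths and run names are placeholders (any log files written by `pycls.core.logging` would do), and passing `filename` saves the figure instead of showing it:

```python
from pycls.core.plotting import plot_error_curves_pyplot

plot_error_curves_pyplot(
    log_files=["/tmp/run_a/stdout.log", "/tmp/run_b/stdout.log"],  # placeholder paths
    names=["run_a", "run_b"],
    filename="top1_err.png",
    metric="top1_err",
)
```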
39
pycls/core/timer.py
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Timer."""
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class Timer(object):
|
||||
"""A simple timer (adapted from Detectron)."""
|
||||
|
||||
def __init__(self):
|
||||
self.total_time = None
|
||||
self.calls = None
|
||||
self.start_time = None
|
||||
self.diff = None
|
||||
self.average_time = None
|
||||
self.reset()
|
||||
|
||||
def tic(self):
|
||||
# Use time.time instead of time.clock, as time.clock does not normalize for multithreading
|
||||
self.start_time = time.time()
|
||||
|
||||
def toc(self):
|
||||
self.diff = time.time() - self.start_time
|
||||
self.total_time += self.diff
|
||||
self.calls += 1
|
||||
self.average_time = self.total_time / self.calls
|
||||
|
||||
def reset(self):
|
||||
self.total_time = 0.0
|
||||
self.calls = 0
|
||||
self.start_time = 0.0
|
||||
self.diff = 0.0
|
||||
self.average_time = 0.0
|
||||
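A small usage sketch for `Timer`, assuming the package is importable as `pycls`; `time.sleep` stands in for a training iteration:

```python
import time
from pycls.core.timer import Timer

timer = Timer()
for _ in range(3):
    timer.tic()
    time.sleep(0.01)  # stand-in for one iteration of work
    timer.toc()
print(timer.average_time)  # running mean over the timed iterations
```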
419
pycls/core/trainer.py
Normal file
@@ -0,0 +1,419 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Tools for training and testing a model."""
|
||||
|
||||
import os
|
||||
from thop import profile
|
||||
|
||||
import numpy as np
|
||||
import pycls.core.benchmark as benchmark
|
||||
import pycls.core.builders as builders
|
||||
import pycls.core.checkpoint as checkpoint
|
||||
import pycls.core.config as config
|
||||
import pycls.core.distributed as dist
|
||||
import pycls.core.logging as logging
|
||||
import pycls.core.meters as meters
|
||||
import pycls.core.net as net
|
||||
import pycls.core.optimizer as optim
|
||||
import pycls.datasets.loader as loader
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def setup_env():
|
||||
"""Sets up environment for training or testing."""
|
||||
if dist.is_master_proc():
|
||||
# Ensure that the output dir exists
|
||||
os.makedirs(cfg.OUT_DIR, exist_ok=True)
|
||||
# Save the config
|
||||
config.dump_cfg()
|
||||
# Setup logging
|
||||
logging.setup_logging()
|
||||
# Log the config as both human readable and as a json
|
||||
logger.info("Config:\n{}".format(cfg))
|
||||
logger.info(logging.dump_log_data(cfg, "cfg"))
|
||||
# Fix the RNG seeds (see RNG comment in core/config.py for discussion)
|
||||
np.random.seed(cfg.RNG_SEED)
|
||||
torch.manual_seed(cfg.RNG_SEED)
|
||||
# Configure the CUDNN backend
|
||||
torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
|
||||
|
||||
|
||||
def setup_model():
|
||||
"""Sets up a model for training or testing and log the results."""
|
||||
# Build the model
|
||||
model = builders.build_model()
|
||||
logger.info("Model:\n{}".format(model))
|
||||
# Log model complexity
|
||||
# logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
|
||||
if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes":
|
||||
h, w = 1025, 2049
|
||||
else:
|
||||
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
|
||||
if cfg.TASK == "jig":
|
||||
x = torch.randn(1, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, h, w)
|
||||
else:
|
||||
x = torch.randn(1, cfg.MODEL.INPUT_CHANNELS, h, w)
|
||||
macs, params = profile(model, inputs=(x, ), verbose=False)
|
||||
logger.info("Params: {:,}".format(params))
|
||||
logger.info("Flops: {:,}".format(macs))
|
||||
# Transfer the model to the current GPU device
|
||||
err_str = "Cannot use more GPU devices than available"
|
||||
assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
|
||||
cur_device = torch.cuda.current_device()
|
||||
model = model.cuda(device=cur_device)
|
||||
# Use multi-process data parallel model in the multi-gpu setting
|
||||
if cfg.NUM_GPUS > 1:
|
||||
# Make model replica operate on the current device
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
module=model, device_ids=[cur_device], output_device=cur_device
|
||||
)
|
||||
# Set complexity function to be module's complexity function
|
||||
# model.complexity = model.module.complexity
|
||||
return model
|
||||
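# setup_model above measures complexity with thop.profile on a dummy input.
# A minimal sketch of that measurement on a stand-in model (the tiny network
# below is illustrative, not one built by this repo):
#
#     import torch
#     import torch.nn as nn
#     from thop import profile
#
#     model = nn.Sequential(
#         nn.Conv2d(3, 8, 3, padding=1),
#         nn.ReLU(inplace=True),
#         nn.Conv2d(8, 10, 1),
#     )
#     x = torch.randn(1, 3, 32, 32)
#     macs, params = profile(model, inputs=(x,), verbose=False)
#     print("Params: {:,}".format(int(params)))
#     print("Flops: {:,}".format(int(macs)))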
|
||||
|
||||
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
|
||||
"""Performs one epoch of training."""
|
||||
# Update drop path prob for NAS
|
||||
if cfg.MODEL.TYPE == "nas":
|
||||
m = model.module if cfg.NUM_GPUS > 1 else model
|
||||
m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
|
||||
# Shuffle the data
|
||||
loader.shuffle(train_loader, cur_epoch)
|
||||
# Update the learning rate per epoch
|
||||
if not cfg.OPTIM.ITER_LR:
|
||||
lr = optim.get_epoch_lr(cur_epoch)
|
||||
optim.set_lr(optimizer, lr)
|
||||
# Enable training mode
|
||||
model.train()
|
||||
train_meter.iter_tic()
|
||||
for cur_iter, (inputs, labels) in enumerate(train_loader):
|
||||
# Update the learning rate per iter
|
||||
if cfg.OPTIM.ITER_LR:
|
||||
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader))
|
||||
optim.set_lr(optimizer, lr)
|
||||
# Transfer the data to the current GPU device
|
||||
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
|
||||
# Perform the forward pass
|
||||
preds = model(inputs)
|
||||
# Compute the loss
|
||||
if isinstance(preds, tuple):
|
||||
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
|
||||
preds = preds[0]
|
||||
else:
|
||||
loss = loss_fun(preds, labels)
|
||||
# Perform the backward pass
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# Update the parameters
|
||||
optimizer.step()
|
||||
# Compute the errors
|
||||
if cfg.TASK == "col":
|
||||
preds = preds.permute(0, 2, 3, 1)
|
||||
preds = preds.reshape(-1, preds.size(3))
|
||||
labels = labels.reshape(-1)
|
||||
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
|
||||
else:
|
||||
mb_size = inputs.size(0) * cfg.NUM_GPUS
|
||||
if cfg.TASK == "seg":
|
||||
# top1_err is in fact inter; top5_err is in fact union
|
||||
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
|
||||
else:
|
||||
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
|
||||
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
|
||||
# Combine the stats across the GPUs (no reduction if 1 GPU used)
|
||||
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
|
||||
# Copy the stats from GPU to CPU (sync point)
|
||||
loss = loss.item()
|
||||
if cfg.TASK == "seg":
|
||||
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
|
||||
else:
|
||||
top1_err, top5_err = top1_err.item(), top5_err.item()
|
||||
train_meter.iter_toc()
|
||||
# Update and log stats
|
||||
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
|
||||
train_meter.log_iter_stats(cur_epoch, cur_iter)
|
||||
train_meter.iter_tic()
|
||||
# Log epoch stats
|
||||
train_meter.log_epoch_stats(cur_epoch)
|
||||
train_meter.reset()
|
||||
|
||||
|
||||
def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
|
||||
"""Performs one epoch of differentiable architecture search."""
|
||||
m = model.module if cfg.NUM_GPUS > 1 else model
|
||||
# Shuffle the data
|
||||
loader.shuffle(train_loader[0], cur_epoch)
|
||||
loader.shuffle(train_loader[1], cur_epoch)
|
||||
# Update the learning rate per epoch
|
||||
if not cfg.OPTIM.ITER_LR:
|
||||
lr = optim.get_epoch_lr(cur_epoch)
|
||||
optim.set_lr(optimizer[0], lr)
|
||||
# Enable training mode
|
||||
model.train()
|
||||
train_meter.iter_tic()
|
||||
trainB_iter = iter(train_loader[1])
|
||||
for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
|
||||
# Update the learning rate per iter
|
||||
if cfg.OPTIM.ITER_LR:
|
||||
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader[0]))
|
||||
optim.set_lr(optimizer[0], lr)
|
||||
# Transfer the data to the current GPU device
|
||||
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
|
||||
# Update architecture
|
||||
if cur_epoch + cur_iter / len(train_loader[0]) >= cfg.OPTIM.ARCH_EPOCH:
|
||||
try:
|
||||
inputsB, labelsB = next(trainB_iter)
|
||||
except StopIteration:
|
||||
trainB_iter = iter(train_loader[1])
|
||||
inputsB, labelsB = next(trainB_iter)
|
||||
inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
|
||||
optimizer[1].zero_grad()
|
||||
loss = m._loss(inputsB, labelsB)
|
||||
loss.backward()
|
||||
optimizer[1].step()
|
||||
# Perform the forward pass
|
||||
preds = model(inputs)
|
||||
# Compute the loss
|
||||
loss = loss_fun(preds, labels)
|
||||
# Perform the backward pass
|
||||
optimizer[0].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
|
||||
# Update the parameters
|
||||
optimizer[0].step()
|
||||
# Compute the errors
|
||||
if cfg.TASK == "col":
|
||||
preds = preds.permute(0, 2, 3, 1)
|
||||
preds = preds.reshape(-1, preds.size(3))
|
||||
labels = labels.reshape(-1)
|
||||
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
|
||||
else:
|
||||
mb_size = inputs.size(0) * cfg.NUM_GPUS
|
||||
if cfg.TASK == "seg":
|
||||
# top1_err is in fact inter; top5_err is in fact union
|
||||
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
|
||||
else:
|
||||
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
|
||||
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
|
||||
# Combine the stats across the GPUs (no reduction if 1 GPU used)
|
||||
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
|
||||
# Copy the stats from GPU to CPU (sync point)
|
||||
loss = loss.item()
|
||||
if cfg.TASK == "seg":
|
||||
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
|
||||
else:
|
||||
top1_err, top5_err = top1_err.item(), top5_err.item()
|
||||
train_meter.iter_toc()
|
||||
# Update and log stats
|
||||
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
|
||||
train_meter.log_iter_stats(cur_epoch, cur_iter)
|
||||
train_meter.iter_tic()
|
||||
# Log epoch stats
|
||||
train_meter.log_epoch_stats(cur_epoch)
|
||||
train_meter.reset()
|
||||
# Log genotype
|
||||
genotype = m.genotype()
|
||||
logger.info("genotype = %s", genotype)
|
||||
logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
|
||||
logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_epoch(test_loader, model, test_meter, cur_epoch):
|
||||
"""Evaluates the model on the test set."""
|
||||
# Enable eval mode
|
||||
model.eval()
|
||||
test_meter.iter_tic()
|
||||
for cur_iter, (inputs, labels) in enumerate(test_loader):
|
||||
# Transfer the data to the current GPU device
|
||||
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
|
||||
# Compute the predictions
|
||||
preds = model(inputs)
|
||||
# Compute the errors
|
||||
if cfg.TASK == "col":
|
||||
preds = preds.permute(0, 2, 3, 1)
|
||||
preds = preds.reshape(-1, preds.size(3))
|
||||
labels = labels.reshape(-1)
|
||||
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
|
||||
else:
|
||||
mb_size = inputs.size(0) * cfg.NUM_GPUS
|
||||
if cfg.TASK == "seg":
|
||||
# top1_err is in fact inter; top5_err is in fact union
|
||||
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
|
||||
else:
|
||||
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
|
||||
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
|
||||
# Combine the errors across the GPUs (no reduction if 1 GPU used)
|
||||
top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
|
||||
# Copy the errors from GPU to CPU (sync point)
|
||||
if cfg.TASK == "seg":
|
||||
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
|
||||
else:
|
||||
top1_err, top5_err = top1_err.item(), top5_err.item()
|
||||
test_meter.iter_toc()
|
||||
# Update and log stats
|
||||
test_meter.update_stats(top1_err, top5_err, mb_size)
|
||||
test_meter.log_iter_stats(cur_epoch, cur_iter)
|
||||
test_meter.iter_tic()
|
||||
# Log epoch stats
|
||||
test_meter.log_epoch_stats(cur_epoch)
|
||||
test_meter.reset()
|
||||
|
||||
|
||||
def train_model():
|
||||
"""Trains the model."""
|
||||
# Setup training/testing environment
|
||||
setup_env()
|
||||
# Construct the model, loss_fun, and optimizer
|
||||
model = setup_model()
|
||||
loss_fun = builders.build_loss_fun().cuda()
|
||||
if "search" in cfg.MODEL.TYPE:
|
||||
params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
|
||||
params_a = [v for k, v in model.named_parameters() if "alphas" in k]
|
||||
optimizer_w = torch.optim.SGD(
|
||||
params=params_w,
|
||||
lr=cfg.OPTIM.BASE_LR,
|
||||
momentum=cfg.OPTIM.MOMENTUM,
|
||||
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
|
||||
dampening=cfg.OPTIM.DAMPENING,
|
||||
nesterov=cfg.OPTIM.NESTEROV
|
||||
)
|
||||
if cfg.OPTIM.ARCH_OPTIM == "adam":
|
||||
optimizer_a = torch.optim.Adam(
|
||||
params=params_a,
|
||||
lr=cfg.OPTIM.ARCH_BASE_LR,
|
||||
betas=(0.5, 0.999),
|
||||
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY
|
||||
)
|
||||
elif cfg.OPTIM.ARCH_OPTIM == "sgd":
|
||||
optimizer_a = torch.optim.SGD(
|
||||
params=params_a,
|
||||
lr=cfg.OPTIM.ARCH_BASE_LR,
|
||||
momentum=cfg.OPTIM.MOMENTUM,
|
||||
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
|
||||
dampening=cfg.OPTIM.DAMPENING,
|
||||
nesterov=cfg.OPTIM.NESTEROV
|
||||
)
|
||||
optimizer = [optimizer_w, optimizer_a]
|
||||
else:
|
||||
optimizer = optim.construct_optimizer(model)
|
||||
# Load checkpoint or initial weights
|
||||
start_epoch = 0
|
||||
if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
|
||||
last_checkpoint = checkpoint.get_last_checkpoint()
|
||||
checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
|
||||
logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
|
||||
start_epoch = checkpoint_epoch + 1
|
||||
elif cfg.TRAIN.WEIGHTS:
|
||||
checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
|
||||
logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
|
||||
# Create data loaders and meters
|
||||
if cfg.TRAIN.PORTION < 1:
|
||||
if "search" in cfg.MODEL.TYPE:
|
||||
train_loader = [loader._construct_loader(
|
||||
dataset_name=cfg.TRAIN.DATASET,
|
||||
split=cfg.TRAIN.SPLIT,
|
||||
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
portion=cfg.TRAIN.PORTION,
|
||||
side="l"
|
||||
),
|
||||
loader._construct_loader(
|
||||
dataset_name=cfg.TRAIN.DATASET,
|
||||
split=cfg.TRAIN.SPLIT,
|
||||
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
portion=cfg.TRAIN.PORTION,
|
||||
side="r"
|
||||
)]
|
||||
else:
|
||||
train_loader = loader._construct_loader(
|
||||
dataset_name=cfg.TRAIN.DATASET,
|
||||
split=cfg.TRAIN.SPLIT,
|
||||
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
portion=cfg.TRAIN.PORTION,
|
||||
side="l"
|
||||
)
|
||||
test_loader = loader._construct_loader(
|
||||
dataset_name=cfg.TRAIN.DATASET,
|
||||
split=cfg.TRAIN.SPLIT,
|
||||
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
portion=cfg.TRAIN.PORTION,
|
||||
side="r"
|
||||
)
|
||||
else:
|
||||
train_loader = loader.construct_train_loader()
|
||||
test_loader = loader.construct_test_loader()
|
||||
train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
|
||||
test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
|
||||
l = train_loader[0] if isinstance(train_loader, list) else train_loader
|
||||
train_meter = train_meter_type(len(l))
|
||||
test_meter = test_meter_type(len(test_loader))
|
||||
# Compute model and loader timings
|
||||
if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
|
||||
l = train_loader[0] if isinstance(train_loader, list) else train_loader
|
||||
benchmark.compute_time_full(model, loss_fun, l, test_loader)
|
||||
# Perform the training loop
|
||||
logger.info("Start epoch: {}".format(start_epoch + 1))
|
||||
for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
|
||||
# Train for one epoch
|
||||
f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch
|
||||
f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
|
||||
# Compute precise BN stats
|
||||
if cfg.BN.USE_PRECISE_STATS:
|
||||
net.compute_precise_bn_stats(model, train_loader)
|
||||
# Save a checkpoint
|
||||
if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
|
||||
checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
|
||||
logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
|
||||
# Evaluate the model
|
||||
next_epoch = cur_epoch + 1
|
||||
if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
|
||||
test_epoch(test_loader, model, test_meter, cur_epoch)
|
||||
|
||||
|
||||
def test_model():
|
||||
"""Evaluates a trained model."""
|
||||
# Setup training/testing environment
|
||||
setup_env()
|
||||
# Construct the model
|
||||
model = setup_model()
|
||||
# Load model weights
|
||||
checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
|
||||
logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
|
||||
# Create data loaders and meters
|
||||
test_loader = loader.construct_test_loader()
|
||||
test_meter = meters.TestMeter(len(test_loader))
|
||||
# Evaluate the model
|
||||
test_epoch(test_loader, model, test_meter, 0)
|
||||
|
||||
|
||||
def time_model():
|
||||
"""Times model and data loader."""
|
||||
# Setup training/testing environment
|
||||
setup_env()
|
||||
# Construct the model and loss_fun
|
||||
model = setup_model()
|
||||
loss_fun = builders.build_loss_fun().cuda()
|
||||
# Create data loaders
|
||||
train_loader = loader.construct_train_loader()
|
||||
test_loader = loader.construct_test_loader()
|
||||
# Compute model and loader timings
|
||||
benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
|
||||
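`search_epoch` above alternates two optimizers: the architecture parameters (the `alphas`) are updated on a held-out split, then the network weights are updated on the training split with gradient clipping. A schematic of one such step, with illustrative names (the actual code calls `m._loss` for the architecture update):

```python
import torch

def search_step(model, loss_fun, optimizer_w, optimizer_a, batch_w, batch_a):
    x_w, y_w = batch_w  # split driving the weight update
    x_a, y_a = batch_a  # held-out split driving the architecture update
    # 1) Architecture step on the held-out split
    optimizer_a.zero_grad()
    loss_fun(model(x_a), y_a).backward()
    optimizer_a.step()
    # 2) Weight step on the training split, with gradient clipping
    optimizer_w.zero_grad()
    loss = loss_fun(model(x_w), y_w)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
    optimizer_w.step()
    return loss.item()
```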
0
pycls/models/__init__.py
Normal file
406
pycls/models/anynet.py
Normal file
@@ -0,0 +1,406 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""AnyNet models."""
|
||||
|
||||
import pycls.core.net as net
|
||||
import torch.nn as nn
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
def get_stem_fun(stem_type):
|
||||
"""Retrieves the stem function by name."""
|
||||
stem_funs = {
|
||||
"res_stem_cifar": ResStemCifar,
|
||||
"res_stem_in": ResStemIN,
|
||||
"simple_stem_in": SimpleStemIN,
|
||||
}
|
||||
err_str = "Stem type '{}' not supported"
|
||||
assert stem_type in stem_funs.keys(), err_str.format(stem_type)
|
||||
return stem_funs[stem_type]
|
||||
|
||||
|
||||
def get_block_fun(block_type):
|
||||
"""Retrieves the block function by name."""
|
||||
block_funs = {
|
||||
"vanilla_block": VanillaBlock,
|
||||
"res_basic_block": ResBasicBlock,
|
||||
"res_bottleneck_block": ResBottleneckBlock,
|
||||
}
|
||||
err_str = "Block type '{}' not supported"
|
||||
assert block_type in block_funs.keys(), err_str.format(block_type)
|
||||
return block_funs[block_type]
|
||||
|
||||
|
||||
class AnyHead(nn.Module):
|
||||
"""AnyNet head: AvgPool, 1x1."""
|
||||
|
||||
def __init__(self, w_in, nc):
|
||||
super(AnyHead, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(w_in, nc, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.avg_pool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, nc):
|
||||
cx["h"], cx["w"] = 1, 1
|
||||
cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
|
||||
return cx
|
||||
|
||||
|
||||
class VanillaBlock(nn.Module):
|
||||
"""Vanilla block: [3x3 conv, BN, Relu] x2."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
|
||||
err_str = "Vanilla block does not support bm, gw, and se_r options"
|
||||
assert bm is None and gw is None and se_r is None, err_str
|
||||
super(VanillaBlock, self).__init__()
|
||||
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
|
||||
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
|
||||
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
|
||||
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
|
||||
err_str = "Vanilla block does not support bm, gw, and se_r options"
|
||||
assert bm is None and gw is None and se_r is None, err_str
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class BasicTransform(nn.Module):
|
||||
"""Basic transformation: [3x3 conv, BN, Relu] x2."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride):
|
||||
super(BasicTransform, self).__init__()
|
||||
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
|
||||
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
|
||||
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
|
||||
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.b_bn.final_bn = True
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride):
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class ResBasicBlock(nn.Module):
|
||||
"""Residual basic block: x + F(x), F = basic transform."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
|
||||
err_str = "Basic transform does not support bm, gw, and se_r options"
|
||||
assert bm is None and gw is None and se_r is None, err_str
|
||||
super(ResBasicBlock, self).__init__()
|
||||
self.proj_block = (w_in != w_out) or (stride != 1)
|
||||
if self.proj_block:
|
||||
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.f = BasicTransform(w_in, w_out, stride)
|
||||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
|
||||
|
||||
def forward(self, x):
|
||||
if self.proj_block:
|
||||
x = self.bn(self.proj(x)) + self.f(x)
|
||||
else:
|
||||
x = x + self.f(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
|
||||
err_str = "Basic transform does not support bm, gw, and se_r options"
|
||||
assert bm is None and gw is None and se_r is None, err_str
|
||||
proj_block = (w_in != w_out) or (stride != 1)
|
||||
if proj_block:
|
||||
h, w = cx["h"], cx["w"]
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
cx["h"], cx["w"] = h, w # parallel branch
|
||||
cx = BasicTransform.complexity(cx, w_in, w_out, stride)
|
||||
return cx
|
||||
|
||||
|
||||
class SE(nn.Module):
|
||||
"""Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU, FC, Sigmoid."""
|
||||
|
||||
def __init__(self, w_in, w_se):
|
||||
super(SE, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.f_ex = nn.Sequential(
|
||||
nn.Conv2d(w_in, w_se, 1, bias=True),
|
||||
nn.ReLU(inplace=cfg.MEM.RELU_INPLACE),
|
||||
nn.Conv2d(w_se, w_in, 1, bias=True),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.f_ex(self.avg_pool(x))
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_se):
|
||||
h, w = cx["h"], cx["w"]
|
||||
cx["h"], cx["w"] = 1, 1
|
||||
cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
|
||||
cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
|
||||
cx["h"], cx["w"] = h, w
|
||||
return cx
|
||||
|
||||
|
||||
class BottleneckTransform(nn.Module):
|
||||
"""Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride, bm, gw, se_r):
|
||||
super(BottleneckTransform, self).__init__()
|
||||
w_b = int(round(w_out * bm))
|
||||
g = w_b // gw
|
||||
self.a = nn.Conv2d(w_in, w_b, 1, stride=1, padding=0, bias=False)
|
||||
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
|
||||
self.b = nn.Conv2d(w_b, w_b, 3, stride=stride, padding=1, groups=g, bias=False)
|
||||
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
|
||||
if se_r:
|
||||
w_se = int(round(w_in * se_r))
|
||||
self.se = SE(w_b, w_se)
|
||||
self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
|
||||
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.c_bn.final_bn = True
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
|
||||
w_b = int(round(w_out * bm))
|
||||
g = w_b // gw
|
||||
cx = net.complexity_conv2d(cx, w_in, w_b, 1, 1, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_b)
|
||||
cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride, 1, g)
|
||||
cx = net.complexity_batchnorm2d(cx, w_b)
|
||||
if se_r:
|
||||
w_se = int(round(w_in * se_r))
|
||||
cx = SE.complexity(cx, w_b, w_se)
|
||||
cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class ResBottleneckBlock(nn.Module):
|
||||
"""Residual bottleneck block: x + F(x), F = bottleneck transform."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
|
||||
super(ResBottleneckBlock, self).__init__()
|
||||
# Use skip connection with projection if shape changes
|
||||
self.proj_block = (w_in != w_out) or (stride != 1)
|
||||
if self.proj_block:
|
||||
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
|
||||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
|
||||
|
||||
def forward(self, x):
|
||||
if self.proj_block:
|
||||
x = self.bn(self.proj(x)) + self.f(x)
|
||||
else:
|
||||
x = x + self.f(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
|
||||
proj_block = (w_in != w_out) or (stride != 1)
|
||||
if proj_block:
|
||||
h, w = cx["h"], cx["w"]
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
cx["h"], cx["w"] = h, w # parallel branch
|
||||
cx = BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
|
||||
return cx
|
||||
|
||||
|
||||
class ResStemCifar(nn.Module):
|
||||
"""ResNet stem for CIFAR: 3x3, BN, ReLU."""
|
||||
|
||||
def __init__(self, w_in, w_out):
|
||||
super(ResStemCifar, self).__init__()
|
||||
self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out):
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class ResStemIN(nn.Module):
|
||||
"""ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""
|
||||
|
||||
def __init__(self, w_in, w_out):
|
||||
super(ResStemIN, self).__init__()
|
||||
self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
|
||||
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out):
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
cx = net.complexity_maxpool2d(cx, 3, 2, 1)
|
||||
return cx
|
||||
|
||||
|
||||
class SimpleStemIN(nn.Module):
|
||||
"""Simple stem for ImageNet: 3x3, BN, ReLU."""
|
||||
|
||||
def __init__(self, w_in, w_out):
|
||||
super(SimpleStemIN, self).__init__()
|
||||
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out):
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class AnyStage(nn.Module):
|
||||
"""AnyNet stage (sequence of blocks w/ the same output shape)."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
|
||||
super(AnyStage, self).__init__()
|
||||
for i in range(d):
|
||||
b_stride = stride if i == 0 else 1
|
||||
b_w_in = w_in if i == 0 else w_out
|
||||
name = "b{}".format(i + 1)
|
||||
self.add_module(name, block_fun(b_w_in, w_out, b_stride, bm, gw, se_r))
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.children():
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
|
||||
for i in range(d):
|
||||
b_stride = stride if i == 0 else 1
|
||||
b_w_in = w_in if i == 0 else w_out
|
||||
cx = block_fun.complexity(cx, b_w_in, w_out, b_stride, bm, gw, se_r)
|
||||
return cx
|
||||
|
||||
|
||||
class AnyNet(nn.Module):
|
||||
"""AnyNet model."""
|
||||
|
||||
@staticmethod
|
||||
def get_args():
|
||||
return {
|
||||
"stem_type": cfg.ANYNET.STEM_TYPE,
|
||||
"stem_w": cfg.ANYNET.STEM_W,
|
||||
"block_type": cfg.ANYNET.BLOCK_TYPE,
|
||||
"ds": cfg.ANYNET.DEPTHS,
|
||||
"ws": cfg.ANYNET.WIDTHS,
|
||||
"ss": cfg.ANYNET.STRIDES,
|
||||
"bms": cfg.ANYNET.BOT_MULS,
|
||||
"gws": cfg.ANYNET.GROUP_WS,
|
||||
"se_r": cfg.ANYNET.SE_R if cfg.ANYNET.SE_ON else None,
|
||||
"nc": cfg.MODEL.NUM_CLASSES,
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(AnyNet, self).__init__()
|
||||
kwargs = self.get_args() if not kwargs else kwargs
|
||||
#print(kwargs)
|
||||
self._construct(**kwargs)
|
||||
self.apply(net.init_weights)
|
||||
|
||||
def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
|
||||
# Generate dummy bot muls and gs for models that do not use them
|
||||
bms = bms if bms else [None for _d in ds]
|
||||
gws = gws if gws else [None for _d in ds]
|
||||
stage_params = list(zip(ds, ws, ss, bms, gws))
|
||||
stem_fun = get_stem_fun(stem_type)
|
||||
self.stem = stem_fun(3, stem_w)
|
||||
block_fun = get_block_fun(block_type)
|
||||
prev_w = stem_w
|
||||
for i, (d, w, s, bm, gw) in enumerate(stage_params):
|
||||
name = "s{}".format(i + 1)
|
||||
self.add_module(name, AnyStage(prev_w, w, s, d, block_fun, bm, gw, se_r))
|
||||
prev_w = w
|
||||
self.head = AnyHead(w_in=prev_w, nc=nc)
|
||||
|
||||
def forward(self, x, get_ints=False):
|
||||
for module in self.children():
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, **kwargs):
|
||||
"""Computes model complexity. If you alter the model, make sure to update."""
|
||||
kwargs = AnyNet.get_args() if not kwargs else kwargs
|
||||
return AnyNet._complexity(cx, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _complexity(cx, stem_type, stem_w, block_type, ds, ws, ss, bms, gws, se_r, nc):
|
||||
bms = bms if bms else [None for _d in ds]
|
||||
gws = gws if gws else [None for _d in ds]
|
||||
stage_params = list(zip(ds, ws, ss, bms, gws))
|
||||
stem_fun = get_stem_fun(stem_type)
|
||||
cx = stem_fun.complexity(cx, 3, stem_w)
|
||||
block_fun = get_block_fun(block_type)
|
||||
prev_w = stem_w
|
||||
for d, w, s, bm, gw in stage_params:
|
||||
cx = AnyStage.complexity(cx, prev_w, w, s, d, block_fun, bm, gw, se_r)
|
||||
prev_w = w
|
||||
cx = AnyHead.complexity(cx, prev_w, nc)
|
||||
return cx
|
||||
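AnyNet can be built either from the global `cfg` (via `get_args`) or from explicit keyword arguments. A sketch of the latter with made-up depths/widths, assuming the default config values load; the printed shape is what the stem/stage/head wiring above should produce:

```python
import torch
from pycls.models.anynet import AnyNet

model = AnyNet(
    stem_type="res_stem_cifar", stem_w=16,
    block_type="res_basic_block",
    ds=[2, 2, 2], ws=[16, 32, 64], ss=[1, 2, 2],
    bms=None, gws=None, se_r=None, nc=10,
)
y = model(torch.randn(2, 3, 32, 32))
print(y.shape)  # expected: torch.Size([2, 10])
```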
108
pycls/models/common.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
def Preprocess(x):
|
||||
if cfg.TASK == 'jig':
|
||||
assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
|
||||
assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
|
||||
x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]])
|
||||
return x
|
||||
|
||||
|
||||
class Classifier(nn.Module):
|
||||
def __init__(self, channels, num_classes):
|
||||
super(Classifier, self).__init__()
|
||||
if cfg.TASK == 'jig':
|
||||
self.jig_sq = cfg.JIGSAW_GRID ** 2
|
||||
self.pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
|
||||
elif cfg.TASK == 'col':
|
||||
self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
|
||||
elif cfg.TASK == 'seg':
|
||||
self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
|
||||
else:
|
||||
self.pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(channels, num_classes)
|
||||
|
||||
def forward(self, x, shape):
|
||||
if cfg.TASK == 'jig':
|
||||
x = self.pooling(x)
|
||||
x = x.view([x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq, x.shape[2], x.shape[3]])
|
||||
x = self.classifier(x.view(x.size(0), -1))
|
||||
elif cfg.TASK in ['col', 'seg']:
|
||||
x = self.classifier(x)
|
||||
x = nn.Upsample(shape, mode='bilinear', align_corners=True)(x)
|
||||
else:
|
||||
x = self.pooling(x)
|
||||
x = self.classifier(x.view(x.size(0), -1))
|
||||
return x
|
||||
|
||||
|
||||
class ASPP(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_classes, rates):
|
||||
super(ASPP, self).__init__()
|
||||
assert len(rates) in [1, 3]
|
||||
self.rates = rates
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.aspp1 = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, 1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.aspp2 = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[0],
|
||||
padding=rates[0], bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
if len(self.rates) == 3:
|
||||
self.aspp3 = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[1],
|
||||
padding=rates[1], bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.aspp4 = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[2],
|
||||
padding=rates[2], bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.aspp5 = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, 1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
|
||||
bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, num_classes, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.aspp1(x)
|
||||
x2 = self.aspp2(x)
|
||||
x5 = self.global_pooling(x)
|
||||
x5 = self.aspp5(x5)
|
||||
x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
|
||||
align_corners=True)(x5)
|
||||
if len(self.rates) == 3:
|
||||
x3 = self.aspp3(x)
|
||||
x4 = self.aspp4(x)
|
||||
x = torch.cat((x1, x2, x3, x4, x5), 1)
|
||||
else:
|
||||
x = torch.cat((x1, x2, x5), 1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
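A shape walk-through of the jigsaw path above (`cfg.TASK == 'jig'`) using plain tensors; the grid size, feature width, and class count are illustrative:

```python
import torch

N, G, C, H, W = 2, 3, 3, 64, 64            # batch, grid side, channels, height, width
x = torch.randn(N, G * G, C, H, W)         # each image arrives as G*G tiles
x = x.view(N * G * G, C, H, W)             # Preprocess: fold tiles into the batch dim
feats = torch.randn(N * G * G, 128, 1, 1)  # stand-in for pooled backbone features
feats = feats.view(N, 128 * G * G, 1, 1)   # Classifier: regroup the tiles per image
logits = torch.nn.Linear(128 * G * G, 1000)(feats.view(N, -1))
print(logits.shape)                        # torch.Size([2, 1000])
```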
232
pycls/models/effnet.py
Normal file
@@ -0,0 +1,232 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""EfficientNet models."""
|
||||
|
||||
import pycls.core.net as net
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
class EffHead(nn.Module):
|
||||
"""EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC."""
|
||||
|
||||
def __init__(self, w_in, w_out, nc):
|
||||
super(EffHead, self).__init__()
|
||||
self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
|
||||
self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.conv_swish = Swish()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
if cfg.EN.DROPOUT_RATIO > 0.0:
|
||||
self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
|
||||
self.fc = nn.Linear(w_out, nc, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_swish(self.conv_bn(self.conv(x)))
|
||||
x = self.avg_pool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.dropout(x) if hasattr(self, "dropout") else x
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, nc):
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
cx["h"], cx["w"] = 1, 1
|
||||
cx = net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True)
|
||||
return cx
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
"""Swish activation function: x * sigmoid(x)."""
|
||||
|
||||
def __init__(self):
|
||||
super(Swish, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SE(nn.Module):
|
||||
"""Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""
|
||||
|
||||
def __init__(self, w_in, w_se):
|
||||
super(SE, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.f_ex = nn.Sequential(
|
||||
nn.Conv2d(w_in, w_se, 1, bias=True),
|
||||
Swish(),
|
||||
nn.Conv2d(w_se, w_in, 1, bias=True),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.f_ex(self.avg_pool(x))
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_se):
|
||||
h, w = cx["h"], cx["w"]
|
||||
cx["h"], cx["w"] = 1, 1
|
||||
cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
|
||||
cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
|
||||
cx["h"], cx["w"] = h, w
|
||||
return cx
|
||||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
"""Mobile inverted bottleneck block w/ SE (MBConv)."""
|
||||
|
||||
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
|
||||
# expansion, 3x3 dwise, BN, Swish, SE, 1x1, BN, skip_connection
|
||||
super(MBConv, self).__init__()
|
||||
self.exp = None
|
||||
w_exp = int(w_in * exp_r)
|
||||
if w_exp != w_in:
|
||||
self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
|
||||
self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.exp_swish = Swish()
|
||||
dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False}
|
||||
self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args)
|
||||
self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.dwise_swish = Swish()
|
||||
self.se = SE(w_exp, int(w_in * se_r))
|
||||
self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
|
||||
self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
# Skip connection if in and out shapes are the same (MN-V2 style)
|
||||
self.has_skip = stride == 1 and w_in == w_out
|
||||
|
||||
def forward(self, x):
|
||||
f_x = x
|
||||
if self.exp:
|
||||
f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
|
||||
f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
|
||||
f_x = self.se(f_x)
|
||||
f_x = self.lin_proj_bn(self.lin_proj(f_x))
|
||||
if self.has_skip:
|
||||
if self.training and cfg.EN.DC_RATIO > 0.0:
|
||||
f_x = net.drop_connect(f_x, cfg.EN.DC_RATIO)
|
||||
f_x = x + f_x
|
||||
return f_x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out):
|
||||
w_exp = int(w_in * exp_r)
|
||||
if w_exp != w_in:
|
||||
cx = net.complexity_conv2d(cx, w_in, w_exp, 1, 1, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_exp)
|
||||
padding = (kernel - 1) // 2
|
||||
cx = net.complexity_conv2d(cx, w_exp, w_exp, kernel, stride, padding, w_exp)
|
||||
cx = net.complexity_batchnorm2d(cx, w_exp)
|
||||
cx = SE.complexity(cx, w_exp, int(w_in * se_r))
|
||||
cx = net.complexity_conv2d(cx, w_exp, w_out, 1, 1, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class EffStage(nn.Module):
|
||||
"""EfficientNet stage."""
|
||||
|
||||
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
|
||||
super(EffStage, self).__init__()
|
||||
for i in range(d):
|
||||
b_stride = stride if i == 0 else 1
|
||||
b_w_in = w_in if i == 0 else w_out
|
||||
name = "b{}".format(i + 1)
|
||||
self.add_module(name, MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out))
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.children():
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out, d):
|
||||
for i in range(d):
|
||||
b_stride = stride if i == 0 else 1
|
||||
b_w_in = w_in if i == 0 else w_out
|
||||
cx = MBConv.complexity(cx, b_w_in, exp_r, kernel, b_stride, se_r, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class StemIN(nn.Module):
|
||||
"""EfficientNet stem for ImageNet: 3x3, BN, Swish."""
|
||||
|
||||
def __init__(self, w_in, w_out):
|
||||
super(StemIN, self).__init__()
|
||||
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.swish = Swish()
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out):
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class EffNet(nn.Module):
|
||||
"""EfficientNet model."""
|
||||
|
||||
@staticmethod
|
||||
def get_args():
|
||||
return {
|
||||
"stem_w": cfg.EN.STEM_W,
|
||||
"ds": cfg.EN.DEPTHS,
|
||||
"ws": cfg.EN.WIDTHS,
|
||||
"exp_rs": cfg.EN.EXP_RATIOS,
|
||||
"se_r": cfg.EN.SE_R,
|
||||
"ss": cfg.EN.STRIDES,
|
||||
"ks": cfg.EN.KERNELS,
|
||||
"head_w": cfg.EN.HEAD_W,
|
||||
"nc": cfg.MODEL.NUM_CLASSES,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
err_str = "Dataset {} is not supported"
|
||||
assert cfg.TRAIN.DATASET in ["imagenet"], err_str.format(cfg.TRAIN.DATASET)
|
||||
assert cfg.TEST.DATASET in ["imagenet"], err_str.format(cfg.TEST.DATASET)
|
||||
super(EffNet, self).__init__()
|
||||
self._construct(**EffNet.get_args())
|
||||
self.apply(net.init_weights)
|
||||
|
||||
def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
|
||||
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
|
||||
self.stem = StemIN(3, stem_w)
|
||||
prev_w = stem_w
|
||||
for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
|
||||
name = "s{}".format(i + 1)
|
||||
self.add_module(name, EffStage(prev_w, exp_r, kernel, stride, se_r, w, d))
|
||||
prev_w = w
|
||||
self.head = EffHead(prev_w, head_w, nc)
|
||||
|
||||
def forward(self, x):
|
||||
for module in self.children():
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx):
|
||||
"""Computes model complexity. If you alter the model, make sure to update."""
|
||||
return EffNet._complexity(cx, **EffNet.get_args())
|
||||
|
||||
@staticmethod
|
||||
def _complexity(cx, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
|
||||
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
|
||||
cx = StemIN.complexity(cx, 3, stem_w)
|
||||
prev_w = stem_w
|
||||
for d, w, exp_r, stride, kernel in stage_params:
|
||||
cx = EffStage.complexity(cx, prev_w, exp_r, kernel, stride, se_r, w, d)
|
||||
prev_w = w
|
||||
cx = EffHead.complexity(cx, prev_w, head_w, nc)
|
||||
return cx
|
||||
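A stand-alone version of the Swish activation and the SE gating used by the EfficientNet blocks above, written with plain torch for illustration; the channel counts are made up:

```python
import torch
import torch.nn as nn

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

w_in, w_se = 32, 8
se = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Conv2d(w_in, w_se, 1, bias=True), Swish(),
    nn.Conv2d(w_se, w_in, 1, bias=True), nn.Sigmoid(),
)
x = torch.randn(4, w_in, 16, 16)
y = x * se(x)   # channel-wise re-weighting, same shape as the input
print(y.shape)  # torch.Size([4, 32, 16, 16])
```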
634
pycls/models/nas/genotypes.py
Normal file
@@ -0,0 +1,634 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""NAS genotypes (adopted from DARTS)."""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
|
||||
|
||||
|
||||
# NASNet ops
|
||||
NASNET_OPS = [
|
||||
'skip_connect',
|
||||
'conv_3x1_1x3',
|
||||
'conv_7x1_1x7',
|
||||
'dil_conv_3x3',
|
||||
'avg_pool_3x3',
|
||||
'max_pool_3x3',
|
||||
'max_pool_5x5',
|
||||
'max_pool_7x7',
|
||||
'conv_1x1',
|
||||
'conv_3x3',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'sep_conv_7x7',
|
||||
]
|
||||
|
||||
# ENAS ops
|
||||
ENAS_OPS = [
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'avg_pool_3x3',
|
||||
'max_pool_3x3',
|
||||
]
|
||||
|
||||
# AmoebaNet ops
|
||||
AMOEBA_OPS = [
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'sep_conv_7x7',
|
||||
'avg_pool_3x3',
|
||||
'max_pool_3x3',
|
||||
'dil_sep_conv_3x3',
|
||||
'conv_7x1_1x7',
|
||||
]
|
||||
|
||||
# NAO ops
|
||||
NAO_OPS = [
|
||||
'skip_connect',
|
||||
'conv_1x1',
|
||||
'conv_3x3',
|
||||
'conv_3x1_1x3',
|
||||
'conv_7x1_1x7',
|
||||
'max_pool_2x2',
|
||||
'max_pool_3x3',
|
||||
'max_pool_5x5',
|
||||
'avg_pool_2x2',
|
||||
'avg_pool_3x3',
|
||||
'avg_pool_5x5',
|
||||
]
|
||||
|
||||
# PNAS ops
|
||||
PNAS_OPS = [
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'sep_conv_7x7',
|
||||
'conv_7x1_1x7',
|
||||
'skip_connect',
|
||||
'avg_pool_3x3',
|
||||
'max_pool_3x3',
|
||||
'dil_conv_3x3',
|
||||
]
|
||||
|
||||
# DARTS ops
|
||||
DARTS_OPS = [
|
||||
'none',
|
||||
'max_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'dil_conv_3x3',
|
||||
'dil_conv_5x5',
|
||||
]
|
||||
|
||||
|
||||
NASNet = Genotype(
|
||||
normal=[
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('avg_pool_3x3', 0),
|
||||
('avg_pool_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 1),
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5, 6],
|
||||
reduce=[
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_7x7', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_7x7', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('skip_connect', 3),
|
||||
('avg_pool_3x3', 2),
|
||||
('sep_conv_3x3', 2),
|
||||
('max_pool_3x3', 1),
|
||||
],
|
||||
reduce_concat=[4, 5, 6],
|
||||
)
|
||||
|
||||
|
||||
PNASNet = Genotype(
|
||||
normal=[
|
||||
('sep_conv_5x5', 0),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_7x7', 1),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 4),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5, 6],
|
||||
reduce=[
|
||||
('sep_conv_5x5', 0),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_7x7', 1),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 4),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5, 6],
|
||||
)
|
||||
|
||||
|
||||
AmoebaNet = Genotype(
|
||||
normal=[
|
||||
('avg_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_5x5', 2),
|
||||
('sep_conv_3x3', 0),
|
||||
('avg_pool_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 1),
|
||||
('skip_connect', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
],
|
||||
normal_concat=[4, 5, 6],
|
||||
reduce=[
|
||||
('avg_pool_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_7x7', 2),
|
||||
('sep_conv_7x7', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('conv_7x1_1x7', 0),
|
||||
('sep_conv_3x3', 5),
|
||||
],
|
||||
reduce_concat=[3, 4, 6]
|
||||
)
|
||||
|
||||
|
||||
DARTS_V1 = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 2)
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('skip_connect', 2),
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 0),
|
||||
('skip_connect', 2),
|
||||
('skip_connect', 2),
|
||||
('avg_pool_3x3', 0)
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5]
|
||||
)
|
||||
|
||||
|
||||
DARTS_V2 = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('skip_connect', 0),
|
||||
('dil_conv_3x3', 2)
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('skip_connect', 2),
|
||||
('max_pool_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('skip_connect', 2),
|
||||
('skip_connect', 2),
|
||||
('max_pool_3x3', 1)
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5]
|
||||
)
|
||||
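In each genotype, `normal` and `reduce` list two `(op, input_index)` pairs per intermediate node of the cell (indices 0 and 1 are the outputs of the two preceding cells), and the `*_concat` ranges name the nodes whose outputs are concatenated. A toy genotype making the encoding explicit (the ops chosen below are arbitrary):

```python
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

toy = Genotype(
    normal=[('sep_conv_3x3', 0), ('skip_connect', 1)] * 4,  # 4 nodes x 2 edges each
    normal_concat=range(2, 6),
    reduce=[('max_pool_3x3', 0), ('skip_connect', 1)] * 4,
    reduce_concat=range(2, 6),
)
print(len(toy.normal), list(toy.normal_concat))  # 8 edges, nodes [2, 3, 4, 5]
```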
|
||||
PDARTS = Genotype(
|
||||
normal=[
|
||||
('skip_connect', 0),
|
||||
('dil_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 0),
|
||||
('dil_conv_5x5', 4)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('avg_pool_3x3', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('dil_conv_5x5', 2),
|
||||
('max_pool_3x3', 0),
|
||||
('dil_conv_3x3', 1),
|
||||
('dil_conv_3x3', 1),
|
||||
('dil_conv_5x5', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
PCDARTS_C10 = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('dil_conv_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('avg_pool_3x3', 0),
|
||||
('dil_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_5x5', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_5x5', 2),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
PCDARTS_IN1K = Genotype(
|
||||
normal=[
|
||||
('skip_connect', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('dil_conv_5x5', 4)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET_CLS = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('dil_conv_5x5', 2),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 4),
|
||||
('dil_conv_5x5', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET_ROT = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 4),
|
||||
('sep_conv_5x5', 2)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET_COL = Genotype(
|
||||
normal=[
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_5x5', 3),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_3x3', 4)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET_JIG = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_5x5', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET22K_CLS = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('max_pool_3x3', 0),
|
||||
('dil_conv_5x5', 3),
|
||||
('dil_conv_5x5', 2),
|
||||
('dil_conv_5x5', 4),
|
||||
('dil_conv_5x5', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET22K_ROT = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('sep_conv_5x5', 0),
|
||||
('dil_conv_5x5', 3),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 4),
|
||||
('sep_conv_3x3', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET22K_COL = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 4),
|
||||
('sep_conv_5x5', 1)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET22K_JIG = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 4)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_5x5', 0),
|
||||
('skip_connect', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_5x5', 3),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_5x5', 4)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_CITYSCAPES_SEG = Genotype(
|
||||
normal=[
|
||||
('skip_connect', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('avg_pool_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 4),
|
||||
('sep_conv_5x5', 2)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_CITYSCAPES_ROT = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_5x5', 2),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_5x5', 3),
|
||||
('dil_conv_5x5', 2),
|
||||
('sep_conv_5x5', 2),
|
||||
('sep_conv_5x5', 0)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_CITYSCAPES_COL = Genotype(
|
||||
normal=[
|
||||
('dil_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_5x5', 2),
|
||||
('dil_conv_3x3', 3),
|
||||
('skip_connect', 0),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('avg_pool_3x3', 1),
|
||||
('avg_pool_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('avg_pool_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('avg_pool_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('skip_connect', 4)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_CITYSCAPES_JIG = Genotype(
|
||||
normal=[
|
||||
('dil_conv_5x5', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 0),
|
||||
('dil_conv_5x5', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('avg_pool_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('dil_conv_5x5', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('dil_conv_5x5', 2),
|
||||
('dil_conv_5x5', 0),
|
||||
('dil_conv_5x5', 3),
|
||||
('dil_conv_5x5', 2)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
|
||||
# Supported genotypes
GENOTYPES = {
    'nas': NASNet,
    'pnas': PNASNet,
    'amoeba': AmoebaNet,
    'darts_v1': DARTS_V1,
    'darts_v2': DARTS_V2,
    'pdarts': PDARTS,
    'pcdarts_c10': PCDARTS_C10,
    'pcdarts_in1k': PCDARTS_IN1K,
    'unnas_imagenet_cls': UNNAS_IMAGENET_CLS,
    'unnas_imagenet_rot': UNNAS_IMAGENET_ROT,
    'unnas_imagenet_col': UNNAS_IMAGENET_COL,
    'unnas_imagenet_jig': UNNAS_IMAGENET_JIG,
    'unnas_imagenet22k_cls': UNNAS_IMAGENET22K_CLS,
    'unnas_imagenet22k_rot': UNNAS_IMAGENET22K_ROT,
    'unnas_imagenet22k_col': UNNAS_IMAGENET22K_COL,
    'unnas_imagenet22k_jig': UNNAS_IMAGENET22K_JIG,
    'unnas_cityscapes_seg': UNNAS_CITYSCAPES_SEG,
    'unnas_cityscapes_rot': UNNAS_CITYSCAPES_ROT,
    'unnas_cityscapes_col': UNNAS_CITYSCAPES_COL,
    'unnas_cityscapes_jig': UNNAS_CITYSCAPES_JIG,
    'custom': None,
}
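For reference, a predefined cell can be looked up in this table by name; the snippet below is a minimal sketch (not part of the commit), assuming the module import path `pycls.models.nas.genotypes` used elsewhere in this repository.

```python
# Minimal sketch, not repository code: inspect one of the predefined genotypes.
from pycls.models.nas.genotypes import GENOTYPES

geno = GENOTYPES['darts_v2']
print(geno.normal)         # [(op_name, input_node_index), ...] -- two edges per intermediate node
print(geno.normal_concat)  # indices of the intermediate nodes concatenated into the cell output
```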
299
pycls/models/nas/nas.py
Normal file
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""NAS network (adopted from DARTS)."""
|
||||
|
||||
from torch.autograd import Variable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import pycls.core.logging as logging
|
||||
|
||||
from pycls.core.config import cfg
|
||||
from pycls.models.common import Preprocess
|
||||
from pycls.models.common import Classifier
|
||||
from pycls.models.nas.genotypes import GENOTYPES
|
||||
from pycls.models.nas.genotypes import Genotype
|
||||
from pycls.models.nas.operations import FactorizedReduce
|
||||
from pycls.models.nas.operations import OPS
|
||||
from pycls.models.nas.operations import ReLUConvBN
|
||||
from pycls.models.nas.operations import Identity
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
"""Drop path (ported from DARTS)."""
|
||||
if drop_prob > 0.:
|
||||
keep_prob = 1.-drop_prob
|
||||
mask = Variable(
|
||||
torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
|
||||
)
|
||||
x.div_(keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
|
||||
|
||||
|
||||
class Cell(nn.Module):
|
||||
"""NAS cell (ported from DARTS)."""
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(Cell, self).__init__()
|
||||
logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
|
||||
|
||||
if reduction:
|
||||
op_names, indices = zip(*genotype.reduce)
|
||||
concat = genotype.reduce_concat
|
||||
else:
|
||||
op_names, indices = zip(*genotype.normal)
|
||||
concat = genotype.normal_concat
|
||||
self._compile(C, op_names, indices, concat, reduction)
|
||||
|
||||
def _compile(self, C, op_names, indices, concat, reduction):
|
||||
assert len(op_names) == len(indices)
|
||||
self._steps = len(op_names) // 2
|
||||
self._concat = concat
|
||||
self.multiplier = len(concat)
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
for name, index in zip(op_names, indices):
|
||||
stride = 2 if reduction and index < 2 else 1
|
||||
op = OPS[name](C, stride, True)
|
||||
self._ops += [op]
|
||||
self._indices = indices
|
||||
|
||||
def forward(self, s0, s1, drop_prob):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
h1 = states[self._indices[2*i]]
|
||||
h2 = states[self._indices[2*i+1]]
|
||||
op1 = self._ops[2*i]
|
||||
op2 = self._ops[2*i+1]
|
||||
h1 = op1(h1)
|
||||
h2 = op2(h2)
|
||||
if self.training and drop_prob > 0.:
|
||||
if not isinstance(op1, Identity):
|
||||
h1 = drop_path(h1, drop_prob)
|
||||
if not isinstance(op2, Identity):
|
||||
h2 = drop_path(h2, drop_prob)
|
||||
s = h1 + h2
|
||||
states += [s]
|
||||
return torch.cat([states[i] for i in self._concat], dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHeadCIFAR(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 8x8"""
|
||||
super(AuxiliaryHeadCIFAR, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0),-1))
|
||||
return x
|
||||
|
||||
|
||||
class AuxiliaryHeadImageNet(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 14x14"""
|
||||
super(AuxiliaryHeadImageNet, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
|
||||
# Commenting it out for consistency with the experiments in the paper.
|
||||
# nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0),-1))
|
||||
return x
|
||||
|
||||
|
||||
class NetworkCIFAR(nn.Module):
|
||||
"""CIFAR network (ported from DARTS)."""
|
||||
|
||||
def __init__(self, C, num_classes, layers, auxiliary, genotype):
|
||||
super(NetworkCIFAR, self).__init__()
|
||||
self._layers = layers
|
||||
self._auxiliary = auxiliary
|
||||
|
||||
stem_multiplier = 3
|
||||
C_curr = stem_multiplier*C
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr)
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
|
||||
self.cells = nn.ModuleList()
|
||||
reduction_prev = False
|
||||
for i in range(layers):
|
||||
if i in [layers//3, 2*layers//3]:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
|
||||
if i == 2*layers//3:
|
||||
C_to_auxiliary = C_prev
|
||||
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
|
||||
self.classifier = Classifier(C_prev, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
input = Preprocess(input)
|
||||
logits_aux = None
|
||||
s0 = s1 = self.stem(input)
|
||||
for i, cell in enumerate(self.cells):
|
||||
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
|
||||
if i == 2*self._layers//3:
|
||||
if self._auxiliary and self.training:
|
||||
logits_aux = self.auxiliary_head(s1)
|
||||
logits = self.classifier(s1, input.shape[2:])
|
||||
if self._auxiliary and self.training:
|
||||
return logits, logits_aux
|
||||
return logits
|
||||
|
||||
|
||||
class NetworkImageNet(nn.Module):
|
||||
"""ImageNet network (ported from DARTS)."""
|
||||
|
||||
def __init__(self, C, num_classes, layers, auxiliary, genotype):
|
||||
super(NetworkImageNet, self).__init__()
|
||||
self._layers = layers
|
||||
self._auxiliary = auxiliary
|
||||
|
||||
self.stem0 = nn.Sequential(
|
||||
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
self.stem1 = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C, C, C
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
reduction_prev = True
|
||||
reduction_layers = [layers//3] if cfg.TASK == 'seg' else [layers//3, 2*layers//3]
|
||||
for i in range(layers):
|
||||
if i in reduction_layers:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
|
||||
if i == 2 * layers // 3:
|
||||
C_to_auxiliary = C_prev
|
||||
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
|
||||
self.classifier = Classifier(C_prev, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
input = Preprocess(input)
|
||||
logits_aux = None
|
||||
s0 = self.stem0(input)
|
||||
s1 = self.stem1(s0)
|
||||
for i, cell in enumerate(self.cells):
|
||||
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
|
||||
if i == 2 * self._layers // 3:
|
||||
if self._auxiliary and self.training:
|
||||
logits_aux = self.auxiliary_head(s1)
|
||||
logits = self.classifier(s1, input.shape[2:])
|
||||
if self._auxiliary and self.training:
|
||||
return logits, logits_aux
|
||||
return logits
|
||||
|
||||
|
||||
class NAS(nn.Module):
|
||||
"""NAS net wrapper (delegates to nets from DARTS)."""
|
||||
|
||||
def __init__(self):
|
||||
assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
|
||||
'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
|
||||
assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
|
||||
'Testing on {} is not supported'.format(cfg.TEST.DATASET)
|
||||
assert cfg.NAS.GENOTYPE in GENOTYPES, \
|
||||
'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
|
||||
super(NAS, self).__init__()
|
||||
logger.info('Constructing NAS: {}'.format(cfg.NAS))
|
||||
# Use a custom or predefined genotype
|
||||
if cfg.NAS.GENOTYPE == 'custom':
|
||||
genotype = Genotype(
|
||||
normal=cfg.NAS.CUSTOM_GENOTYPE[0],
|
||||
normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
|
||||
reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
|
||||
reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
|
||||
)
|
||||
else:
|
||||
genotype = GENOTYPES[cfg.NAS.GENOTYPE]
|
||||
# Determine the network constructor for dataset
|
||||
if 'cifar' in cfg.TRAIN.DATASET:
|
||||
net_ctor = NetworkCIFAR
|
||||
else:
|
||||
net_ctor = NetworkImageNet
|
||||
# Construct the network
|
||||
self.net_ = net_ctor(
|
||||
C=cfg.NAS.WIDTH,
|
||||
num_classes=cfg.MODEL.NUM_CLASSES,
|
||||
layers=cfg.NAS.DEPTH,
|
||||
auxiliary=cfg.NAS.AUX,
|
||||
genotype=genotype
|
||||
)
|
||||
# Drop path probability (set / annealed based on epoch)
|
||||
self.net_.drop_path_prob = 0.0
|
||||
|
||||
def set_drop_path_prob(self, drop_path_prob):
|
||||
self.net_.drop_path_prob = drop_path_prob
|
||||
|
||||
def forward(self, x):
|
||||
return self.net_.forward(x)
|
||||
201
pycls/models/nas/operations.py
Normal file
@@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
"""NAS ops (adopted from DARTS)."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
OPS = {
|
||||
'none': lambda C, stride, affine:
|
||||
Zero(stride),
|
||||
'avg_pool_2x2': lambda C, stride, affine:
|
||||
nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
|
||||
'avg_pool_3x3': lambda C, stride, affine:
|
||||
nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
|
||||
'avg_pool_5x5': lambda C, stride, affine:
|
||||
nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
|
||||
'max_pool_2x2': lambda C, stride, affine:
|
||||
nn.MaxPool2d(2, stride=stride, padding=0),
|
||||
'max_pool_3x3': lambda C, stride, affine:
|
||||
nn.MaxPool2d(3, stride=stride, padding=1),
|
||||
'max_pool_5x5': lambda C, stride, affine:
|
||||
nn.MaxPool2d(5, stride=stride, padding=2),
|
||||
'max_pool_7x7': lambda C, stride, affine:
|
||||
nn.MaxPool2d(7, stride=stride, padding=3),
|
||||
'skip_connect': lambda C, stride, affine:
|
||||
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
|
||||
'conv_1x1': lambda C, stride, affine:
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
'conv_3x3': lambda C, stride, affine:
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
'sep_conv_3x3': lambda C, stride, affine:
|
||||
SepConv(C, C, 3, stride, 1, affine=affine),
|
||||
'sep_conv_5x5': lambda C, stride, affine:
|
||||
SepConv(C, C, 5, stride, 2, affine=affine),
|
||||
'sep_conv_7x7': lambda C, stride, affine:
|
||||
SepConv(C, C, 7, stride, 3, affine=affine),
|
||||
'dil_conv_3x3': lambda C, stride, affine:
|
||||
DilConv(C, C, 3, stride, 2, 2, affine=affine),
|
||||
'dil_conv_5x5': lambda C, stride, affine:
|
||||
DilConv(C, C, 5, stride, 4, 2, affine=affine),
|
||||
'dil_sep_conv_3x3': lambda C, stride, affine:
|
||||
DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
|
||||
'conv_3x1_1x3': lambda C, stride, affine:
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
|
||||
nn.Conv2d(C, C, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
'conv_7x1_1x7': lambda C, stride, affine:
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
|
||||
nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_out, kernel_size, stride=stride,
|
||||
padding=padding, bias=False
|
||||
),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DilConv(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
|
||||
):
|
||||
super(DilConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=1,
|
||||
padding=padding, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DilSepConv(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
|
||||
):
|
||||
super(DilSepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=1,
|
||||
padding=padding, dilation=dilation, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
|
||||
def __init__(self, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 1:
|
||||
return x.mul(0.)
|
||||
return x[:,:,::self.stride,::self.stride].mul(0.)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, affine=True):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
assert C_out % 2 == 0
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
|
||||
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
|
||||
self.bn = nn.BatchNorm2d(C_out, affine=affine)
|
||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
out = torch.cat([self.conv_1(x), self.conv_2(y[:,:,1:,1:])], dim=1)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
89
pycls/models/regnet.py
Normal file
@@ -0,0 +1,89 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""RegNet models."""

import numpy as np
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet


def quantize_float(f, q):
    """Converts a float to closest non-zero int divisible by q."""
    return int(round(f / q) * q)


def adjust_ws_gs_comp(ws, bms, gs):
    """Adjusts the compatibility of widths and groups."""
    ws_bot = [int(w * b) for w, b in zip(ws, bms)]
    gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
    ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
    ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
    return ws, gs


def get_stages_from_blocks(ws, rs):
    """Gets ws/ds of network at each stage from per block values."""
    ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
    ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
    s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
    s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
    return s_ws, s_ds


def generate_regnet(w_a, w_0, w_m, d, q=8):
    """Generates per block ws from RegNet parameters."""
    assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
    ws_cont = np.arange(d) * w_a + w_0
    ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
    ws = w_0 * np.power(w_m, ks)
    ws = np.round(np.divide(ws, q)) * q
    num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
    ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
    return ws, num_stages, max_stage, ws_cont


class RegNet(AnyNet):
    """RegNet model."""

    @staticmethod
    def get_args():
        """Convert RegNet to AnyNet parameter format."""
        # Generate RegNet ws per block
        w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
        ws, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
        # Convert to per stage format
        s_ws, s_ds = get_stages_from_blocks(ws, ws)
        # Use the same gw, bm and ss for each stage
        s_gs = [cfg.REGNET.GROUP_W for _ in range(num_stages)]
        s_bs = [cfg.REGNET.BOT_MUL for _ in range(num_stages)]
        s_ss = [cfg.REGNET.STRIDE for _ in range(num_stages)]
        # Adjust the compatibility of ws and gws
        s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs)
        # Get AnyNet arguments defining the RegNet
        return {
            "stem_type": cfg.REGNET.STEM_TYPE,
            "stem_w": cfg.REGNET.STEM_W,
            "block_type": cfg.REGNET.BLOCK_TYPE,
            "ds": s_ds,
            "ws": s_ws,
            "ss": s_ss,
            "bms": s_bs,
            "gws": s_gs,
            "se_r": cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self):
        kwargs = RegNet.get_args()
        super(RegNet, self).__init__(**kwargs)

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update."""
        kwargs = RegNet.get_args() if not kwargs else kwargs
        return AnyNet.complexity(cx, **kwargs)
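To make the quantized width schedule concrete, the sketch below (not part of the commit) calls `generate_regnet` and `get_stages_from_blocks` directly; the parameter values are illustrative placeholders, not taken from any RegNet config in this repository, and the import assumes `pycls` is on the Python path.

```python
# Illustrative parameter values only -- not from any config shipped with this repo.
from pycls.models.regnet import generate_regnet, get_stages_from_blocks

ws, num_stages, max_stage, ws_cont = generate_regnet(w_a=24.0, w_0=24, w_m=2.5, d=16)
print(ws)           # per-block widths, rounded to multiples of q=8
print(num_stages)   # number of distinct widths, i.e. number of stages
s_ws, s_ds = get_stages_from_blocks(ws, ws)
print(s_ws, s_ds)   # per-stage widths and per-stage depths
```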
280
pycls/models/resnet.py
Normal file
@@ -0,0 +1,280 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""ResNe(X)t models."""
|
||||
|
||||
import pycls.core.net as net
|
||||
import torch.nn as nn
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
# Stage depths for ImageNet models
|
||||
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
|
||||
|
||||
|
||||
def get_trans_fun(name):
|
||||
"""Retrieves the transformation function by name."""
|
||||
trans_funs = {
|
||||
"basic_transform": BasicTransform,
|
||||
"bottleneck_transform": BottleneckTransform,
|
||||
}
|
||||
err_str = "Transformation function '{}' not supported"
|
||||
assert name in trans_funs.keys(), err_str.format(name)
|
||||
return trans_funs[name]
|
||||
|
||||
|
||||
class ResHead(nn.Module):
|
||||
"""ResNet head: AvgPool, 1x1."""
|
||||
|
||||
def __init__(self, w_in, nc):
|
||||
super(ResHead, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(w_in, nc, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.avg_pool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, nc):
|
||||
cx["h"], cx["w"] = 1, 1
|
||||
cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
|
||||
return cx
|
||||
|
||||
|
||||
class BasicTransform(nn.Module):
|
||||
"""Basic transformation: 3x3, BN, ReLU, 3x3, BN."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
|
||||
err_str = "Basic transform does not support w_b and num_gs options"
|
||||
assert w_b is None and num_gs == 1, err_str
|
||||
super(BasicTransform, self).__init__()
|
||||
self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
|
||||
self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
|
||||
self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
|
||||
self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.b_bn.final_bn = True
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1):
|
||||
err_str = "Basic transform does not support w_b and num_gs options"
|
||||
assert w_b is None and num_gs == 1, err_str
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class BottleneckTransform(nn.Module):
|
||||
"""Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride, w_b, num_gs):
|
||||
super(BottleneckTransform, self).__init__()
|
||||
# MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
|
||||
(s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
|
||||
self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False)
|
||||
self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
|
||||
self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False)
|
||||
self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
|
||||
self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
|
||||
self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.c_bn.final_bn = True
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride, w_b, num_gs):
|
||||
(s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
|
||||
cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_b)
|
||||
cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs)
|
||||
cx = net.complexity_batchnorm2d(cx, w_b)
|
||||
cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Residual block: x + F(x)."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
|
||||
super(ResBlock, self).__init__()
|
||||
# Use skip connection with projection if shape changes
|
||||
self.proj_block = (w_in != w_out) or (stride != 1)
|
||||
if self.proj_block:
|
||||
self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
|
||||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
|
||||
|
||||
def forward(self, x):
|
||||
if self.proj_block:
|
||||
x = self.bn(self.proj(x)) + self.f(x)
|
||||
else:
|
||||
x = x + self.f(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs):
|
||||
proj_block = (w_in != w_out) or (stride != 1)
|
||||
if proj_block:
|
||||
h, w = cx["h"], cx["w"]
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
cx["h"], cx["w"] = h, w # parallel branch
|
||||
cx = trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs)
|
||||
return cx
|
||||
|
||||
|
||||
class ResStage(nn.Module):
|
||||
"""Stage of ResNet."""
|
||||
|
||||
def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
|
||||
super(ResStage, self).__init__()
|
||||
for i in range(d):
|
||||
b_stride = stride if i == 0 else 1
|
||||
b_w_in = w_in if i == 0 else w_out
|
||||
trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN)
|
||||
res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs)
|
||||
self.add_module("b{}".format(i + 1), res_block)
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.children():
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1):
|
||||
for i in range(d):
|
||||
b_stride = stride if i == 0 else 1
|
||||
b_w_in = w_in if i == 0 else w_out
|
||||
trans_f = get_trans_fun(cfg.RESNET.TRANS_FUN)
|
||||
cx = ResBlock.complexity(cx, b_w_in, w_out, b_stride, trans_f, w_b, num_gs)
|
||||
return cx
|
||||
|
||||
|
||||
class ResStemCifar(nn.Module):
|
||||
"""ResNet stem for CIFAR: 3x3, BN, ReLU."""
|
||||
|
||||
def __init__(self, w_in, w_out):
|
||||
super(ResStemCifar, self).__init__()
|
||||
self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out):
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
return cx
|
||||
|
||||
|
||||
class ResStemIN(nn.Module):
|
||||
"""ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""
|
||||
|
||||
def __init__(self, w_in, w_out):
|
||||
super(ResStemIN, self).__init__()
|
||||
self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
|
||||
self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
|
||||
self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
|
||||
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.children():
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, w_in, w_out):
|
||||
cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
|
||||
cx = net.complexity_batchnorm2d(cx, w_out)
|
||||
cx = net.complexity_maxpool2d(cx, 3, 2, 1)
|
||||
return cx
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
"""ResNet model."""
|
||||
|
||||
def __init__(self):
|
||||
datasets = ["cifar10", "imagenet"]
|
||||
err_str = "Dataset {} is not supported"
|
||||
assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET)
|
||||
assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET)
|
||||
super(ResNet, self).__init__()
|
||||
if "cifar" in cfg.TRAIN.DATASET:
|
||||
self._construct_cifar()
|
||||
else:
|
||||
self._construct_imagenet()
|
||||
self.apply(net.init_weights)
|
||||
|
||||
def _construct_cifar(self):
|
||||
err_str = "Model depth should be of the format 6n + 2 for cifar"
|
||||
assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str
|
||||
d = int((cfg.MODEL.DEPTH - 2) / 6)
|
||||
self.stem = ResStemCifar(3, 16)
|
||||
self.s1 = ResStage(16, 16, stride=1, d=d)
|
||||
self.s2 = ResStage(16, 32, stride=2, d=d)
|
||||
self.s3 = ResStage(32, 64, stride=2, d=d)
|
||||
self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES)
|
||||
|
||||
def _construct_imagenet(self):
|
||||
g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
|
||||
(d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
|
||||
w_b = gw * g
|
||||
self.stem = ResStemIN(3, 64)
|
||||
self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
|
||||
self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
|
||||
self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
|
||||
self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
|
||||
self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES)
|
||||
|
||||
def forward(self, x):
|
||||
for module in self.children():
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx):
|
||||
"""Computes model complexity. If you alter the model, make sure to update."""
|
||||
if "cifar" in cfg.TRAIN.DATASET:
|
||||
d = int((cfg.MODEL.DEPTH - 2) / 6)
|
||||
cx = ResStemCifar.complexity(cx, 3, 16)
|
||||
cx = ResStage.complexity(cx, 16, 16, stride=1, d=d)
|
||||
cx = ResStage.complexity(cx, 16, 32, stride=2, d=d)
|
||||
cx = ResStage.complexity(cx, 32, 64, stride=2, d=d)
|
||||
cx = ResHead.complexity(cx, 64, nc=cfg.MODEL.NUM_CLASSES)
|
||||
else:
|
||||
g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
|
||||
(d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
|
||||
w_b = gw * g
|
||||
cx = ResStemIN.complexity(cx, 3, 64)
|
||||
cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g)
|
||||
cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g)
|
||||
cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g)
|
||||
cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g)
|
||||
cx = ResHead.complexity(cx, 2048, nc=cfg.MODEL.NUM_CLASSES)
|
||||
return cx
|
||||
13
reproduce.sh
@@ -1,13 +0,0 @@
#!/bin/bash

python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 10
python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 10
python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs $1 --n_samples 10
python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs $1 --n_samples 10

python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 100
python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 100
python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs $1 --n_samples 100
python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs $1 --n_samples 100

python process_results.py --n_runs $1
Binary file not shown (image removed; previous size 30 KiB).
164
score_networks.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import argparse
|
||||
import nasspace
|
||||
import datasets
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
from scores import get_score_func
|
||||
from scipy import stats
|
||||
from pycls.models.nas.nas import Cell
|
||||
from utils import add_dropout, init_network
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS Without Training')
|
||||
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
|
||||
parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
|
||||
type=str, help='path to API')
|
||||
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
|
||||
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
|
||||
parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
|
||||
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
|
||||
parser.add_argument('--batch_size', default=128, type=int)
|
||||
parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch')
|
||||
parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
|
||||
parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
|
||||
parser.add_argument('--GPU', default='0', type=str)
|
||||
parser.add_argument('--seed', default=1, type=int)
|
||||
parser.add_argument('--init', default='', type=str)
|
||||
parser.add_argument('--trainval', action='store_true')
|
||||
parser.add_argument('--dropout', action='store_true')
|
||||
parser.add_argument('--dataset', default='cifar10', type=str)
|
||||
parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network')
|
||||
parser.add_argument('--n_samples', default=100, type=int)
|
||||
parser.add_argument('--n_runs', default=500, type=int)
|
||||
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
|
||||
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
|
||||
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
|
||||
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
|
||||
|
||||
args = parser.parse_args()
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
|
||||
|
||||
# Reproducibility
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
|
||||
def get_batch_jacobian(net, x, target, device, args=None):
|
||||
net.zero_grad()
|
||||
x.requires_grad_(True)
|
||||
y, out = net(x)
|
||||
y.backward(torch.ones_like(y))
|
||||
jacob = x.grad.detach()
|
||||
return jacob, target.detach(), y.detach(), out.detach()
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
savedataset = args.dataset
|
||||
dataset = 'fake' if 'fake' in args.dataset else args.dataset
|
||||
args.dataset = args.dataset.replace('fake', '')
|
||||
if args.dataset == 'cifar10':
|
||||
args.dataset = args.dataset + '-valid'
|
||||
searchspace = nasspace.get_search_space(args)
|
||||
if 'valid' in args.dataset:
|
||||
args.dataset = args.dataset.replace('-valid', '')
|
||||
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
||||
os.makedirs(args.save_loc, exist_ok=True)
|
||||
|
||||
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{savedataset}{"_" + args.init + "_" if args.init != "" else args.init}_{"_dropout" if args.dropout else ""}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}'
|
||||
accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{savedataset}_{args.trainval}'
|
||||
|
||||
if args.dataset == 'cifar10':
|
||||
acc_type = 'ori-test'
|
||||
val_acc_type = 'x-valid'
|
||||
else:
|
||||
acc_type = 'x-test'
|
||||
val_acc_type = 'x-valid'
|
||||
|
||||
|
||||
scores = np.zeros(len(searchspace))
|
||||
try:
|
||||
accs = np.load(accfilename + '.npy')
|
||||
except:
|
||||
accs = np.zeros(len(searchspace))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
for i, (uid, network) in enumerate(searchspace):
|
||||
# Reproducibility
|
||||
try:
|
||||
if args.dropout:
|
||||
add_dropout(network, args.sigma)
|
||||
if args.init != '':
|
||||
init_network(network, args.init)
|
||||
if 'hook_' in args.score:
|
||||
network.K = np.zeros((args.batch_size, args.batch_size))
|
||||
def counting_forward_hook(module, inp, out):
|
||||
try:
|
||||
if not module.visited_backwards:
|
||||
return
|
||||
if isinstance(inp, tuple):
|
||||
inp = inp[0]
|
||||
inp = inp.view(inp.size(0), -1)
|
||||
x = (inp > 0).float()
|
||||
K = x @ x.t()
|
||||
K2 = (1.-x) @ (1.-x.t())
|
||||
network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def counting_backward_hook(module, inp, out):
|
||||
module.visited_backwards = True
|
||||
|
||||
|
||||
for name, module in network.named_modules():
|
||||
if 'ReLU' in str(type(module)):
|
||||
#hooks[name] = module.register_forward_hook(counting_hook)
|
||||
module.register_forward_hook(counting_forward_hook)
|
||||
module.register_backward_hook(counting_backward_hook)
|
||||
|
||||
network = network.to(device)
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
s = []
|
||||
for j in range(args.maxofn):
|
||||
data_iterator = iter(train_loader)
|
||||
x, target = next(data_iterator)
|
||||
x2 = torch.clone(x)
|
||||
x2 = x2.to(device)
|
||||
x, target = x.to(device), target.to(device)
|
||||
jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if 'hook_' in args.score:
|
||||
network(x2.to(device))
|
||||
s.append(get_score_func(args.score)(network.K, target))
|
||||
else:
|
||||
s.append(get_score_func(args.score)(jacobs, labels))
|
||||
scores[i] = np.mean(s)
|
||||
accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval)
|
||||
accs_ = accs[~np.isnan(scores)]
|
||||
scores_ = scores[~np.isnan(scores)]
|
||||
numnan = np.isnan(scores).sum()
|
||||
tau, p = stats.kendalltau(accs_[:max(i-numnan, 1)], scores_[:max(i-numnan, 1)])
|
||||
print(f'{tau}')
|
||||
if i % 1000 == 0:
|
||||
np.save(filename, scores)
|
||||
np.save(accfilename, accs)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval)
|
||||
scores[i] = np.nan
|
||||
np.save(filename, scores)
|
||||
np.save(accfilename, accs)
|
||||
32
scorehook.sh
Normal file
@@ -0,0 +1,32 @@
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset cifar10
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset cifar100 --data_loc ../cifar100/
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset ImageNet16-120 --data_loc ../imagenet16/Imagenet16/
|
||||
|
||||
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_pnas --batch_size 128 --GPU 3
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_enas --batch_size 128 --GPU 3
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_darts --batch_size 128 --GPU 3
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_darts_fix-w-d --batch_size 128 --GPU 3
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_nasnet --batch_size 128 --GPU 3
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_amoeba --batch_size 128 --GPU 3
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnet --batch_size 128 --GPU 3
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnext-a --batch_size 128 --GPU 3
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnext-b --batch_size 128 --GPU 3
|
||||
|
||||
|
||||
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace amoeba_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_amoeba_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_darts_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_nasnet_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_pnas_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_enas_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nds_resnext-a_in --batch_size 128 --GPU 3 --dataset imagenette2 --data_loc ../imagenette2/
|
||||
|
||||
|
||||
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset cifar100 --data_loc ../cifar100/
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench201 --batch_size 128 --GPU 3 --dataset ImageNet16-120 --data_loc ../imagenet16/Imagenet16/
|
||||
|
||||
python score_networks.py --trainval --augtype none --repeat 1 --score hook_logdet --sigma 0.05 --nasspace nasbench101 --batch_size 128 --GPU 3 --api_loc ../nasbench_only108.tfrecord
|
||||
|
||||
21
scores.py
Normal file
@@ -0,0 +1,21 @@
import numpy as np
import torch


def hooklogdet(K, labels=None):
    s, ld = np.linalg.slogdet(K)
    return ld


def random_score(jacob, label=None):
    return np.random.normal()


_scores = {
    'hook_logdet': hooklogdet,
    'random': random_score
}


def get_score_func(score_name):
    return _scores[score_name]
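The `hook_logdet` score is the log-determinant of the ReLU activation-pattern kernel K that `score_networks.py` accumulates in its forward hooks. The sketch below (not repository code) rebuilds such a kernel from fake activations and scores it with `hooklogdet`, just to show the shapes involved; the batch size and feature width are arbitrary.

```python
# Minimal sketch, not repository code: mimic counting_forward_hook from
# score_networks.py on fake data, then score the resulting kernel.
import numpy as np
from scores import hooklogdet

rng = np.random.default_rng(0)
acts = rng.standard_normal((8, 64))     # fake pre-activations for a batch of 8 inputs
x = (acts > 0).astype(np.float64)       # binary codes: which units a ReLU keeps active
K = x @ x.T + (1.0 - x) @ (1.0 - x).T   # pairwise agreement between activation patterns
print(hooklogdet(K))                    # log|K|: larger when the patterns are more distinct
```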
206
search.py
@@ -1,35 +1,49 @@
import os
import time
import argparse
import nasspace
import datasets
import random
import numpy as np
import torch
import os
from scores import get_score_func
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import trange
from statistics import mean
import time
from utils import add_dropout


parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../datasets/cifar', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../datasets/NAS-Bench-201-v1_1-096897.pth',
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
                    type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--save_loc', default='results/ICML', type=str, help='folder to save results')
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--kernel', action='store_true')
parser.add_argument('--dropout', action='store_true')
parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch')
parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--init', default='', type=str)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--activations', action='store_true')
parser.add_argument('--cosine', action='store_true')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--n_samples', default=100, type=int)
parser.add_argument('--n_runs', default=500, type=int)
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torch.optim as optim

from models import get_cell_based_tiny_net

# Reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
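As a rough illustration of how this configuration resolves, here is a self-contained sketch (a cut-down parser reproducing only a few of the flags above, with hypothetical override values such as a short smoke-test run might use):

```python
import argparse

# Sketch only: three of the flags above with their defaults, overridden the way
# a short debugging run might override them.
parser = argparse.ArgumentParser(description='NAS Without Training (sketch)')
parser.add_argument('--score', default='hook_logdet', type=str)
parser.add_argument('--n_samples', default=100, type=int)
parser.add_argument('--n_runs', default=500, type=int)

args = parser.parse_args(['--n_samples', '10', '--n_runs', '3'])
print(args.score, args.n_samples, args.n_runs)  # -> hook_logdet 10 3
```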
@@ -37,120 +51,140 @@ random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

import torchvision.transforms as transforms
from datasets import get_datasets
from config_utils import load_config
from nas_201_api import NASBench201API as API

def get_batch_jacobian(net, x, target, to, device, args=None):
def get_batch_jacobian(net, x, target, device, args=None):
    net.zero_grad()

    x.requires_grad_(True)

    _, y = net(x)

    y, ints = net(x)
    y.backward(torch.ones_like(y))
    jacob = x.grad.detach()

    return jacob, target.detach()


def eval_score(jacob, labels=None):
    corrs = np.corrcoef(jacob)
    v, _ = np.linalg.eig(corrs)
    k = 1e-5
    return -np.sum(np.log(v + k) + 1./(v + k))

    return jacob, target.detach(), y.detach(), ints.detach()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
THE_START = time.time()
api = API(args.api_loc)

searchspace = nasspace.get_search_space(args)
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
os.makedirs(args.save_loc, exist_ok=True)

train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_loc, cutout=0)

if args.dataset == 'cifar10':
    acc_type = 'ori-test'
    val_acc_type = 'x-valid'

else:
    acc_type = 'x-test'
    val_acc_type = 'x-valid'

if args.trainval:
    cifar_split = load_config('config_utils/cifar-split.txt', None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                                               num_workers=0, pin_memory=True, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split))

else:
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                               num_workers=0, pin_memory=True)

times = []
chosen = []
acc = []
val_acc = []
topscores = []

dset = args.dataset if not args.trainval else 'cifar10-valid'

order_fn = np.nanargmax


if args.dataset == 'cifar10':
    acc_type = 'ori-test'
    val_acc_type = 'x-valid'
else:
    acc_type = 'x-test'
    val_acc_type = 'x-valid'


runs = trange(args.n_runs, desc='acc: ')
for N in runs:
    start = time.time()
    indices = np.random.randint(0, 15625, args.n_samples)
    indices = np.random.randint(0, len(searchspace), args.n_samples)
    scores = []

    npstate = np.random.get_state()
    ranstate = random.getstate()
    torchstate = torch.random.get_rng_state()
    for arch in indices:

        data_iterator = iter(train_loader)
        x, target = next(data_iterator)
        x, target = x.to(device), target.to(device)

        config = api.get_net_config(arch, args.dataset)
        config['num_classes'] = 1

        network = get_cell_based_tiny_net(config)  # create the network from configuration
        network = network.to(device)

        jacobs, labels = get_batch_jacobian(network, x, target, 1, device, args)
        jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()

        try:
            s = eval_score(jacobs, labels)
            uid = searchspace[arch]
            network = searchspace.get_network(uid)
            network.to(device)
            if args.dropout:
                add_dropout(network, args.sigma)
            if args.init != '':
                init_network(network, args.init)
            if 'hook_' in args.score:
                network.K = np.zeros((args.batch_size, args.batch_size))
                def counting_forward_hook(module, inp, out):
                    try:
                        if not module.visited_backwards:
                            return
                        if isinstance(inp, tuple):
                            inp = inp[0]
                        inp = inp.view(inp.size(0), -1)
                        x = (inp > 0).float()
                        K = x @ x.t()
                        K2 = (1.-x) @ (1.-x.t())
                        network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
                    except:
                        pass


                def counting_backward_hook(module, inp, out):
                    module.visited_backwards = True


                for name, module in network.named_modules():
                    if 'ReLU' in str(type(module)):
                        #hooks[name] = module.register_forward_hook(counting_hook)
                        module.register_forward_hook(counting_forward_hook)
                        module.register_backward_hook(counting_backward_hook)

            random.setstate(ranstate)
            np.random.set_state(npstate)
            torch.set_rng_state(torchstate)

            data_iterator = iter(train_loader)
            x, target = next(data_iterator)
            x2 = torch.clone(x)
            x2 = x2.to(device)
            x, target = x.to(device), target.to(device)
            jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args)

            if args.kernel:
                s = get_score_func(args.score)(out, labels)
            elif 'hook_' in args.score:
                network(x2.to(device))
                s = get_score_func(args.score)(network.K, target)
            elif args.repeat < args.batch_size:
                s = get_score_func(args.score)(jacobs, labels, args.repeat)
            else:
                s = get_score_func(args.score)(jacobs, labels)

        except Exception as e:
            print(e)
            s = np.nan
            s = 0.

        scores.append(s)

        #print(len(scores))
        #print(scores)
        #print(order_fn(scores))


    best_arch = indices[order_fn(scores)]
    info = api.query_by_index(best_arch)
    uid = searchspace[best_arch]
    topscores.append(scores[order_fn(scores)])
    chosen.append(best_arch)
    acc.append(info.get_metrics(dset, acc_type)['accuracy'])
    #acc.append(searchspace.get_accuracy(uid, acc_type, args.trainval))
    acc.append(searchspace.get_final_accuracy(uid, acc_type, False))

    if not args.dataset == 'cifar10' or args.trainval:
        val_acc.append(info.get_metrics(dset, val_acc_type)['accuracy'])
        val_acc.append(searchspace.get_final_accuracy(uid, val_acc_type, args.trainval))
        # val_acc.append(info.get_metrics(dset, val_acc_type)['accuracy'])

    times.append(time.time()-start)
    runs.set_description(f"acc: {mean(acc if not args.trainval else val_acc):.2f}%")
    runs.set_description(f"acc: {mean(acc):.2f}% time:{mean(times):.2f}")

print(f"Final mean test accuracy: {np.mean(acc)}")
if len(val_acc) > 1:
    print(f"Final mean validation accuracy: {np.mean(val_acc)}")
#if len(val_acc) > 1:
#    print(f"Final mean validation accuracy: {np.mean(val_acc)}")

state = {'accs': acc,
         'val_accs': val_acc,
         'chosen': chosen,
         'times': times,
         'topscores': topscores,
         }

dset = args.dataset if not args.trainval else 'cifar10-valid'
fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
dset = args.dataset if not (args.trainval and args.dataset == 'cifar10') else 'cifar10-valid'
fname = f"{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{dset}_{args.kernel}_{args.dropout}_{args.augtype}_{args.sigma}_{args.repeat}_{args.batch_size}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
torch.save(state, fname)
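The core of the loop above is the pair of hooks: the backward hook marks the ReLU modules that the single backward pass actually reaches, and the forward hook then binarises each reached ReLU's input and accumulates the agreement kernel `network.K`, which `hook_logdet` turns into a score. A self-contained sketch of that mechanism on a toy MLP (the architecture, sizes and seed are made up, and the `visited_backwards` guard is omitted for brevity):

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size = 8
net = nn.Sequential(nn.Linear(32, 64), nn.ReLU(),
                    nn.Linear(64, 64), nn.ReLU(),
                    nn.Linear(64, 10))
net.K = np.zeros((batch_size, batch_size))

def counting_forward_hook(module, inp, out):
    # Binarise the ReLU input and add this layer's agreement kernel to net.K.
    if isinstance(inp, tuple):
        inp = inp[0]
    x = (inp.view(inp.size(0), -1) > 0).float()
    net.K += (x @ x.t()).cpu().numpy() + ((1. - x) @ (1. - x.t())).cpu().numpy()

for module in net.modules():
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(counting_forward_hook)

with torch.no_grad():
    net(torch.randn(batch_size, 32))

sign, logdet = np.linalg.slogdet(net.K)  # the same quantity hook_logdet returns
print(logdet)
```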
100  utils.py  Normal file
@@ -0,0 +1,100 @@
import torch
from pycls.models.nas.nas import Cell


class DropChannel(torch.nn.Module):
    def __init__(self, p, mod):
        super(DropChannel, self).__init__()
        self.mod = mod
        self.p = p
    def forward(self, s0, s1, droppath):
        ret = self.mod(s0, s1, droppath)
        return ret


class DropConnect(torch.nn.Module):
    def __init__(self, p):
        super(DropConnect, self).__init__()
        self.p = p
    def forward(self, inputs):
        batch_size = inputs.shape[0]
        dim1 = inputs.shape[2]
        dim2 = inputs.shape[3]
        channel_size = inputs.shape[1]
        keep_prob = 1 - self.p
        # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
        random_tensor = keep_prob
        random_tensor += torch.rand([batch_size, channel_size, 1, 1], dtype=inputs.dtype, device=inputs.device)
        binary_tensor = torch.floor(random_tensor)
        output = inputs / keep_prob * binary_tensor
        return output


def add_dropout(network, p, prefix=''):
    #p = 0.5
    for attr_str in dir(network):
        target_attr = getattr(network, attr_str)
        if isinstance(target_attr, torch.nn.Conv2d):
            setattr(network, attr_str, torch.nn.Sequential(target_attr, DropConnect(p)))
        elif isinstance(target_attr, Cell):
            setattr(network, attr_str, DropChannel(p, target_attr))
    for n, ch in list(network.named_children()):
        #print(f'{prefix}add_dropout {n}')
        if isinstance(ch, torch.nn.Conv2d):
            setattr(network, n, torch.nn.Sequential(ch, DropConnect(p)))
        elif isinstance(ch, Cell):
            setattr(network, n, DropChannel(p, ch))
        else:
            add_dropout(ch, p, prefix + '\t')


def orth_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.orthogonal_(m.weight)

def uni_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.uniform_(m.weight)

def uni2_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.uniform_(m.weight, -1., 1.)

def uni3_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.uniform_(m.weight, -.5, .5)

def norm_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.normal_(m.weight)  # torch's in-place normal initialiser is normal_ (norm_ does not exist)

def eye_init(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.eye_(m.weight)
    elif isinstance(m, torch.nn.Conv2d):
        torch.nn.init.dirac_(m.weight)


def fixup_init(m):
    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.zeros_(m.weight)  # zeros_ is the in-place zero initialiser (zero_ does not exist)
    elif isinstance(m, torch.nn.Linear):
        torch.nn.init.zeros_(m.weight)
        torch.nn.init.zeros_(m.bias)


def init_network(network, init):
    if init == 'orthogonal':
        network.apply(orth_init)
    elif init == 'uniform':
        print('uniform')
        network.apply(uni_init)
    elif init == 'uniform2':
        network.apply(uni2_init)
    elif init == 'uniform3':
        network.apply(uni3_init)
    elif init == 'normal':
        network.apply(norm_init)
    elif init == 'identity':
        network.apply(eye_init)
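For intuition, here is a small self-contained sketch of what the `DropConnect` wrapper does once `add_dropout` has wrapped a convolution: whole `(sample, channel)` slices of the output are zeroed with probability `p` and the survivors rescaled by `1/(1-p)`, so the expected activation is unchanged. (The class is repeated in the sketch only so it runs without the `pycls` dependency that `utils.py` imports; the toy sizes are made up.)

```python
import torch
import torch.nn as nn

class DropConnect(nn.Module):
    # Same behaviour as the DropConnect above: a per-(sample, channel) binary mask.
    def __init__(self, p):
        super().__init__()
        self.p = p
    def forward(self, inputs):
        keep_prob = 1. - self.p
        mask = torch.floor(keep_prob + torch.rand(inputs.shape[0], inputs.shape[1], 1, 1,
                                                  dtype=inputs.dtype, device=inputs.device))
        return inputs / keep_prob * mask

torch.manual_seed(0)
conv = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), DropConnect(p=0.5))  # what add_dropout builds
out = conv(torch.randn(4, 3, 16, 16))

dropped = (out.flatten(2).abs().sum(dim=2) == 0).float().mean()
print(out.shape, dropped.item())  # roughly half of the (sample, channel) slices are zeroed
```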
@@ -1,81 +0,0 @@
import re
from graphviz import Digraph
import pandas as pd
import time
import argparse

parser = argparse.ArgumentParser(description='Fast cell visualisation')
parser.add_argument('--arch', default=1, type=int)
parser.add_argument('--save', action='store_true')
args = parser.parse_args()


def set_none(bit):
    print(bit)
    tmp = bit.split('~')
    tmp[0] = 'none'
    print('~'.join(tmp))
    return '~'.join(tmp)


def remove_pointless_ops(archstr):
    old = None
    new = archstr
    while old != new:
        old = new
        bits = old.strip('|').split('|')
        if 'none~' in bits[0]:  # node 1 has no connections to it
            bits[3] = set_none(bits[3])  # node 1 -> 2 now none
            bits[6] = set_none(bits[6])  # node 1 -> 3 now none
        if 'none~' in bits[2] and 'none~' in bits[3]:  # node 2 has no connections to it
            bits[7] = set_none(bits[7])  # node 2 -> 3 now none
        if 'none~' in bits[7]:  # doesn't matter what comes through node 2
            bits[2] = set_none(bits[2])  # node 0 -> 2 now none
            bits[3] = set_none(bits[3])  # node 1 -> 2 now none
        if 'none~' in bits[6] and 'none~' in bits[7]:  # doesn't matter what comes through node 1
            bits[0] = set_none(bits[0])  # node 0 -> 1 now none
        new = '|'.join(bits)
        print(new)
    return new


df = pd.read_pickle('results/arch_score_acc.pd')

nodestr = df.iloc[args.arch]['cellstr']
nodestr = nodestr[1:-1]  # remove leading and trailing bars |

nodestr = remove_pointless_ops(nodestr)
nodes = nodestr.split("|+|")

dot = Digraph(
    format='pdf',
    edge_attr=dict(fontsize='12'),
    node_attr=dict(fixedsize='true', shape="circle", height='0.5', width='0.5'),
    engine='dot')

dot.body.extend(['rankdir=LR'])

OPS = ['conv_3x3', 'avg_pool_3x3', 'skip_connect', 'conv_1x1', 'none']

dot.node('0', 'in')

## ops are separated by bars (|)
for i, node in enumerate(nodes):

    # if node 3 then label as output
    if (i+1) == 3:
        dot.node(str(i+1), 'out')
    else:
        dot.node(str(i+1))

    for op_str in node.split('|'):
        op_name = [o for o in OPS if o in op_str][0]
        if op_name == 'none':
            break
        connect = re.findall('~[0-9]', op_str)[0]
        connect = connect[1:]
        dot.edge(connect, str(i+1), label=op_name)

dot.render(view=True)


if args.save:
    dot.render(f'outputs/{args.arch}.gv')
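The pruning above leans on the fixed layout of a NAS-Bench-201 cell string; a small sketch with a hypothetical cell makes the index arithmetic concrete (after stripping the outer bars and splitting on `|`, positions 1 and 4 are the `+` separators and the rest are edges):

```python
# Hypothetical cell string (outer bars already removed). Index -> edge:
#   bits[0]: 0->1   bits[2]: 0->2   bits[3]: 1->2
#   bits[5]: 0->3   bits[6]: 1->3   bits[7]: 2->3
cell = 'none~0|+|skip_connect~0|conv_3x3~1|+|conv_1x1~0|avg_pool_3x3~1|skip_connect~2'
bits = cell.strip('|').split('|')

# Node 1 only receives a 'none' edge (bits[0]), so remove_pointless_ops would
# also set its outgoing edges bits[3] (1->2) and bits[6] (1->3) to 'none'.
print(bits[0], bits[3], bits[6])  # -> none~0 conv_3x3~1 avg_pool_3x3~1
```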