diff --git a/exps/NATS-algos/search-size.py b/exps/NATS-algos/search-size.py index c3c3dcc..d0a9931 100644 --- a/exps/NATS-algos/search-size.py +++ b/exps/NATS-algos/search-size.py @@ -2,7 +2,13 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # ###################################################################################### # In this file, we aims to evaluate three kinds of channel searching strategies: -# - +# - channel-wise interpaltion from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" +# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" +# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" +# For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links: +# - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md +# - FBV2: https://github.com/facebookresearch/mobile-vision +# - TuNAS: https://github.com/google-research/google-research/tree/master/tunas #### # python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25 #### diff --git a/exps/experimental/vis-nats-bench-ws.py b/exps/experimental/vis-nats-bench-ws.py index 563c9e6..b1d5014 100644 --- a/exps/experimental/vis-nats-bench-ws.py +++ b/exps/experimental/vis-nats-bench-ws.py @@ -26,7 +26,8 @@ from nats_bench import create from log_utils import time_string -def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'): +# def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'): +def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'): ss_dir = '{:}-{:}'.format(root_dir, search_space) alg2name, alg2path = OrderedDict(), OrderedDict() seeds = [777, 888, 999] @@ -39,9 +40,12 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf alg2name['ENAS'] = 'enas-affine0_BN0-None' alg2name['SETN'] = 'setn-affine0_BN0-None' else: - alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) - alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) - alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) + # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) + # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) + # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) + alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix) + alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix) + alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix) for alg, name in alg2name.items(): alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth') alg2data = OrderedDict() @@ -98,8 +102,11 @@ def visualize_curve(api, vis_save_dir, search_space): for idx, (alg, data) in enumerate(alg2data.items()): print('plot alg : {:}'.format(alg)) xs, accuracies = [], [] - for iepoch in range(epochs+1): - structures, accs = [_[iepoch-1] for _ in data], [] + for iepoch in range(epochs + 1): + try: + structures, accs = [_[iepoch-1] for _ in data], [] + except: + raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset)) for structure in structures: info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False) accs.append(info['test-accuracy']) @@ -131,5 +138,5 @@ if __name__ == '__main__': save_dir = Path(args.save_dir) - api = create(None, args.search_space, verbose=False) + api = create(None, args.search_space, fast_mode=True, verbose=False) visualize_curve(api, save_dir, args.search_space) diff --git a/lib/models/shape_searchs/generic_size_tiny_cell_model.py b/lib/models/shape_searchs/generic_size_tiny_cell_model.py index e1a00f9..e6e5ff3 100644 --- a/lib/models/shape_searchs/generic_size_tiny_cell_model.py +++ b/lib/models/shape_searchs/generic_size_tiny_cell_model.py @@ -2,8 +2,8 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # ##################################################### # Here, we utilized three techniques to search for the number of channels: -# - feature interpaltion from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" -# - masking + GumbelSoftmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" +# - channel-wise interpaltion from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" +# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" # - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" from typing import List, Text, Any import random, torch @@ -55,10 +55,10 @@ class GenericNAS301Model(nn.Module): assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo) self._algo = algo self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs))) - if algo == 'fbv2' or algo == 'tunas': - self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))) - for i in range(len(self._candidate_Cs)): - self._masks.data[i, :self._candidate_Cs[i]] = 1 + # if algo == 'fbv2' or algo == 'tunas': + self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))) + for i in range(len(self._candidate_Cs)): + self._masks.data[i, :self._candidate_Cs[i]] = 1 @property def tau(self): diff --git a/lib/nats_bench/api_size.py b/lib/nats_bench/api_size.py index 14cc5b5..d10425c 100644 --- a/lib/nats_bench/api_size.py +++ b/lib/nats_bench/api_size.py @@ -7,7 +7,6 @@ # [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 # ##################################################################################### import os, copy, random, numpy as np -from pathlib import Path from typing import List, Text, Union, Dict, Optional from collections import OrderedDict, defaultdict from .api_utils import time_string @@ -15,6 +14,8 @@ from .api_utils import pickle_load from .api_utils import ArchResults from .api_utils import NASBenchMetaAPI from .api_utils import remap_dataset_set_names +from .api_utils import nats_is_dir +from .api_utils import nats_is_file from .api_utils import PICKLE_EXT @@ -70,20 +71,20 @@ class NATSsize(NASBenchMetaAPI): else: file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict)) - if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path): + if isinstance(file_path_or_dict, str): file_path_or_dict = str(file_path_or_dict) if verbose: print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) - if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict): + if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict): raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) - self.filename = Path(file_path_or_dict).name + self.filename = os.path.basename(file_path_or_dict) if fast_mode: - if os.path.isfile(file_path_or_dict): + if nats_is_file(file_path_or_dict): raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) else: self._archive_dir = file_path_or_dict else: - if os.path.isdir(file_path_or_dict): + if nats_is_dir(file_path_or_dict): raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) else: file_path_or_dict = pickle_load(file_path_or_dict) diff --git a/lib/nats_bench/api_topology.py b/lib/nats_bench/api_topology.py index c8659fc..9b0dccb 100644 --- a/lib/nats_bench/api_topology.py +++ b/lib/nats_bench/api_topology.py @@ -7,7 +7,6 @@ # [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 # ##################################################################################### import os, copy, random, numpy as np -from pathlib import Path from typing import List, Text, Union, Dict, Optional from collections import OrderedDict, defaultdict import warnings @@ -16,6 +15,8 @@ from .api_utils import pickle_load from .api_utils import ArchResults from .api_utils import NASBenchMetaAPI from .api_utils import remap_dataset_set_names +from .api_utils import nats_is_dir +from .api_utils import nats_is_file from .api_utils import PICKLE_EXT @@ -67,20 +68,20 @@ class NATStopology(NASBenchMetaAPI): else: file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict)) - if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path): + if isinstance(file_path_or_dict, str): file_path_or_dict = str(file_path_or_dict) if verbose: print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) - if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict): + if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict): raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) - self.filename = Path(file_path_or_dict).name + self.filename = os.path.basename(file_path_or_dict) if fast_mode: - if os.path.isfile(file_path_or_dict): + if nats_is_file(file_path_or_dict): raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) else: self._archive_dir = file_path_or_dict else: - if os.path.isdir(file_path_or_dict): + if nats_is_dir(file_path_or_dict): raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) else: file_path_or_dict = pickle_load(file_path_or_dict) diff --git a/lib/nats_bench/api_utils.py b/lib/nats_bench/api_utils.py index 4f7ca35..aa49969 100644 --- a/lib/nats_bench/api_utils.py +++ b/lib/nats_bench/api_utils.py @@ -17,6 +17,7 @@ from typing import List, Text, Union, Dict, Optional from collections import OrderedDict, defaultdict +_FILE_SYSTEM = 'default' PICKLE_EXT = 'pickle.pbz2' @@ -45,6 +46,34 @@ def time_string(): return string +def reset_file_system(lib: Text='default'): + _FILE_SYSTEM = lib + + +def get_file_system(lib: Text='default'): + return _FILE_SYSTEM + + +def nats_is_dir(file_path): + if _FILE_SYSTEM == 'default': + return os.path.isdir(file_path) + elif _FILE_SYSTEM == 'google': + import tensorflow as tf + return tf.gfile.isdir(file_path) + else: + raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM)) + + +def nats_is_file(file_path): + if _FILE_SYSTEM == 'default': + return os.path.isfile(file_path) + elif _FILE_SYSTEM == 'google': + import tensorflow as tf + return tf.gfile.exists(file_path) and not tf.gfile.isdir(file_path) + else: + raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM)) + + def remap_dataset_set_names(dataset, metric_on_set, verbose=False): """re-map the metric_on_set to internal keys""" if verbose: @@ -146,10 +175,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): time_string(), archive_root, index)) if archive_root is None: archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(self.ALL_BASE_NAMES[-1])) - if not os.path.isdir(archive_root): + if not nats_is_dir(archive_root): warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root)) archive_root = self.archive_dir - if archive_root is None or not os.path.isdir(archive_root): + if archive_root is None or not nats_is_dir(archive_root): raise ValueError('Invalid archive_root : {:}'.format(archive_root)) if index is None: indexes = list(range(len(self))) @@ -158,9 +187,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): for idx in indexes: assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx) xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT)) - if not os.path.isfile(xfile_path): + if not nats_is_file(xfile_path): xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT)) - assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) + assert nats_is_file(xfile_path), 'invalid data path : {:}'.format(xfile_path) xdata = pickle_load(xfile_path) assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path) self.evaluated_indexes.add(idx) diff --git a/scripts-search/NATS/search-size.sh b/scripts-search/NATS/search-size.sh index 61c64c7..df5b97d 100644 --- a/scripts-search/NATS/search-size.sh +++ b/scripts-search/NATS/search-size.sh @@ -1,10 +1,10 @@ #!/bin/bash -# bash ./NATS/search-size.sh 0 777 +# bash scripts-search/NATS/search-size.sh 0 0.3 777 echo script name: $0 echo $# arguments -if [ "$#" -ne 2 ] ;then +if [ "$#" -ne 3 ] ;then echo "Input illegal number of parameters " $# - echo "Need 2 parameters for GPU-device and seed" + echo "Need 3 parameters for GPU-device, warmup-ratio, and seed" exit 1 fi if [ "$TORCH_HOME" = "" ]; then @@ -15,16 +15,19 @@ else fi device=$1 -seed=$2 +ratio=$2 +seed=$3 -CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed} -CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed} -CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed ${seed} +CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --warmup_ratio ${ratio} --rand_seed ${seed} +CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --warmup_ratio ${ratio} --rand_seed ${seed} +CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --warmup_ratio ${ratio} --rand_seed ${seed} -CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed} -CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed} -CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed ${seed} +# +CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed} +CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed} +CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed} -CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed} -CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed} -CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed ${seed} +# +CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed} +CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed} +CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}