Update Warmup

D-X-Y 2020-10-08 10:19:34 +11:00
parent ad5d6e28b9
commit ab801cbf14
7 changed files with 90 additions and 43 deletions

View File

@@ -2,7 +2,13 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
 ######################################################################################
 # In this file, we aim to evaluate three kinds of channel searching strategies:
-# -
+# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
+# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
+# For simplicity, we use tas, fbv2, and tunas to refer to these three strategies. Their official implementations are at the following links:
+# - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md
+# - FBV2: https://github.com/facebookresearch/mobile-vision
+# - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
 ####
 # python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25
 ####
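
Note: the header above names three strategies without showing their mechanics. The following self-contained sketch (not part of the commit; names such as `candidate_Cs` and `arch_param` are invented for illustration) contrasts how each strategy turns architecture logits into a channel decision:

import torch
import torch.nn.functional as F

candidate_Cs = [8, 16, 24, 32, 40, 48, 56, 64]       # candidate channel counts
arch_param = 1e-3 * torch.randn(len(candidate_Cs))   # learnable logits, one per candidate
probs = F.softmax(arch_param, dim=0)

# (1) tas, channel-wise interpolation: mix the outputs of the two most likely
# widths, weighted by their re-normalized probabilities.
top2_prob, top2_idx = torch.topk(probs, 2)
top2_prob = top2_prob / top2_prob.sum()
print('tas mixes widths', [candidate_Cs[int(i)] for i in top2_idx], 'with weights', top2_prob)

# (2) fbv2, masking + Gumbel-Softmax: binary masks over the maximum width,
# combined with differentiable Gumbel-Softmax weights.
masks = torch.zeros(len(candidate_Cs), max(candidate_Cs))
for i, c in enumerate(candidate_Cs):
  masks[i, :c] = 1
weights = F.gumbel_softmax(arch_param, tau=1.0, hard=False)
soft_mask = (weights.unsqueeze(-1) * masks).sum(dim=0)  # differentiable channel mask
print('fbv2 soft mask of the first 10 channels:', soft_mask[:10])

# (3) tunas, masking + sampling: sample one candidate and apply its hard mask;
# the logits are then updated with a REINFORCE-style gradient (not shown).
idx = int(torch.multinomial(probs, 1))
print('tunas sampled width:', candidate_Cs[idx], '-> mask keeps', int(masks[idx].sum()), 'channels')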

View File

@@ -26,7 +26,8 @@ from nats_bench import create
 from log_utils import time_string
 
-def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'):
+# def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'):
+def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'):
   ss_dir = '{:}-{:}'.format(root_dir, search_space)
   alg2name, alg2path = OrderedDict(), OrderedDict()
   seeds = [777, 888, 999]
@@ -39,9 +40,12 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'):
     alg2name['ENAS'] = 'enas-affine0_BN0-None'
     alg2name['SETN'] = 'setn-affine0_BN0-None'
   else:
-    alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
-    alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
-    alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
+    # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
+    # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
+    # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
+    alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
+    alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix)
+    alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix)
   for alg, name in alg2name.items():
     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
   alg2data = OrderedDict()
@@ -98,8 +102,11 @@ def visualize_curve(api, vis_save_dir, search_space):
   for idx, (alg, data) in enumerate(alg2data.items()):
     print('plot alg : {:}'.format(alg))
     xs, accuracies = [], []
-    for iepoch in range(epochs+1):
-      structures, accs = [_[iepoch-1] for _ in data], []
+    for iepoch in range(epochs + 1):
+      try:
+        structures, accs = [_[iepoch-1] for _ in data], []
+      except:
+        raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset))
       for structure in structures:
         info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False)
         accs.append(info['test-accuracy'])
@@ -131,5 +138,5 @@ if __name__ == '__main__':
   save_dir = Path(args.save_dir)
-  api = create(None, args.search_space, verbose=False)
+  api = create(None, args.search_space, fast_mode=True, verbose=False)
   visualize_curve(api, save_dir, args.search_space)
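
Note: with the defaults above, `fetch_data` now resolves result folders whose names embed both the architecture weight decay and the warmup ratio, e.g. `tas-affine0_BN0-AWD0.001-WARM0.3`. A hypothetical driver (assuming the corresponding search runs have finished and written `seed-*-last-info.pth` checkpoints) would look like:

alg2data = fetch_data(root_dir='./output/search', search_space='sss',
                      dataset='cifar10', suffix='-WARM0.3')
for alg, data in alg2data.items():
  # one entry per seed in [777, 888, 999] whose checkpoint exists on disk
  print('{:30s} -> loaded {:} seed(s)'.format(alg, len(data)))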

View File

@@ -2,8 +2,8 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
 #####################################################
 # Here, we utilized three techniques to search for the number of channels:
-# - feature interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
-# - masking + GumbelSoftmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
+# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
+# - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
 # - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
 from typing import List, Text, Any
 import random, torch
@@ -55,10 +55,10 @@ class GenericNAS301Model(nn.Module):
     assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo)
     self._algo = algo
     self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs)))
-    if algo == 'fbv2' or algo == 'tunas':
-      self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)))
-      for i in range(len(self._candidate_Cs)):
-        self._masks.data[i, :self._candidate_Cs[i]] = 1
+    # if algo == 'fbv2' or algo == 'tunas':
+    self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)))
+    for i in range(len(self._candidate_Cs)):
+      self._masks.data[i, :self._candidate_Cs[i]] = 1
 
   @property
   def tau(self):
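
Note: with the `if` commented out, the `_masks` buffer is now registered for every algorithm, including tas. Each row is a binary mask that keeps the first `candidate_Cs[i]` channels; a standalone sketch with illustrative widths:

import torch

candidate_Cs = [8, 16, 24, 32]                      # illustrative candidate widths
masks = torch.zeros(len(candidate_Cs), max(candidate_Cs))
for i, c in enumerate(candidate_Cs):
  masks[i, :c] = 1                                  # row i keeps the first c channels

feature = torch.randn(2, max(candidate_Cs), 4, 4)   # an (N, C, H, W) feature map
gated = feature * masks[1].view(1, -1, 1, 1)        # zero every channel beyond 16
assert int(gated[:, 16:].abs().sum()) == 0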

View File

@@ -7,7 +7,6 @@
 # [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2 #
 #####################################################################################
 import os, copy, random, numpy as np
-from pathlib import Path
 from typing import List, Text, Union, Dict, Optional
 from collections import OrderedDict, defaultdict
 from .api_utils import time_string
@@ -15,6 +14,8 @@ from .api_utils import pickle_load
 from .api_utils import ArchResults
 from .api_utils import NASBenchMetaAPI
 from .api_utils import remap_dataset_set_names
+from .api_utils import nats_is_dir
+from .api_utils import nats_is_file
 from .api_utils import PICKLE_EXT
@@ -70,20 +71,20 @@ class NATSsize(NASBenchMetaAPI):
     else:
       file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
       print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict))
-    if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
+    if isinstance(file_path_or_dict, str):
       file_path_or_dict = str(file_path_or_dict)
       if verbose:
         print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
-      if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
+      if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict):
         raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
-      self.filename = Path(file_path_or_dict).name
+      self.filename = os.path.basename(file_path_or_dict)
       if fast_mode:
-        if os.path.isfile(file_path_or_dict):
+        if nats_is_file(file_path_or_dict):
           raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict))
         else:
           self._archive_dir = file_path_or_dict
       else:
-        if os.path.isdir(file_path_or_dict):
+        if nats_is_dir(file_path_or_dict):
           raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict))
         else:
           file_path_or_dict = pickle_load(file_path_or_dict)
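
Note: after this hunk, all path checks go through the pluggable `nats_is_file`/`nats_is_dir` helpers, and `pathlib.Path` inputs are no longer special-cased, so callers should pass plain strings. A sketch of the two creation modes (assuming the benchmark files already sit under `$TORCH_HOME`):

from nats_bench import create

# fast_mode=True expects a directory of per-architecture pickle files and
# loads them lazily.
api = create(None, 'sss', fast_mode=True, verbose=False)

# fast_mode=False expects a single archive file such as
# $TORCH_HOME/NATS-sss-v1_0-50262.pickle.pbz2 (pass a str, not a pathlib.Path).
# api = create(None, 'sss', fast_mode=False, verbose=False)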

View File

@@ -7,7 +7,6 @@
 # [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2 #
 #####################################################################################
 import os, copy, random, numpy as np
-from pathlib import Path
 from typing import List, Text, Union, Dict, Optional
 from collections import OrderedDict, defaultdict
 import warnings
@@ -16,6 +15,8 @@ from .api_utils import pickle_load
 from .api_utils import ArchResults
 from .api_utils import NASBenchMetaAPI
 from .api_utils import remap_dataset_set_names
+from .api_utils import nats_is_dir
+from .api_utils import nats_is_file
 from .api_utils import PICKLE_EXT
@@ -67,20 +68,20 @@ class NATStopology(NASBenchMetaAPI):
     else:
       file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
       print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict))
-    if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
+    if isinstance(file_path_or_dict, str):
       file_path_or_dict = str(file_path_or_dict)
       if verbose:
         print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
-      if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
+      if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict):
         raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
-      self.filename = Path(file_path_or_dict).name
+      self.filename = os.path.basename(file_path_or_dict)
       if fast_mode:
-        if os.path.isfile(file_path_or_dict):
+        if nats_is_file(file_path_or_dict):
           raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict))
         else:
           self._archive_dir = file_path_or_dict
       else:
-        if os.path.isdir(file_path_or_dict):
+        if nats_is_dir(file_path_or_dict):
           raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict))
         else:
           file_path_or_dict = pickle_load(file_path_or_dict)

View File

@@ -17,6 +17,7 @@ from typing import List, Text, Union, Dict, Optional
 from collections import OrderedDict, defaultdict
 
+_FILE_SYSTEM = 'default'
 
 PICKLE_EXT = 'pickle.pbz2'
@@ -45,6 +46,34 @@ def time_string():
   return string
 
 
+def reset_file_system(lib: Text='default'):
+  # 'global' is required here; otherwise the assignment would only bind a
+  # function-local name and never update the module-level flag.
+  global _FILE_SYSTEM
+  _FILE_SYSTEM = lib
+
+
+def get_file_system():
+  return _FILE_SYSTEM
+
+
+def nats_is_dir(file_path):
+  if _FILE_SYSTEM == 'default':
+    return os.path.isdir(file_path)
+  elif _FILE_SYSTEM == 'google':
+    import tensorflow as tf  # deferred import: only needed for the gfile backend
+    return tf.gfile.isdir(file_path)
+  else:
+    raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM))
+
+
+def nats_is_file(file_path):
+  if _FILE_SYSTEM == 'default':
+    return os.path.isfile(file_path)
+  elif _FILE_SYSTEM == 'google':
+    import tensorflow as tf  # deferred import: only needed for the gfile backend
+    return tf.gfile.exists(file_path) and not tf.gfile.isdir(file_path)
+  else:
+    raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM))
+
+
 def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
   """re-map the metric_on_set to internal keys"""
   if verbose:
@@ -146,10 +175,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
         time_string(), archive_root, index))
     if archive_root is None:
       archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(self.ALL_BASE_NAMES[-1]))
-      if not os.path.isdir(archive_root):
+      if not nats_is_dir(archive_root):
         warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
         archive_root = self.archive_dir
-    if archive_root is None or not os.path.isdir(archive_root):
+    if archive_root is None or not nats_is_dir(archive_root):
       raise ValueError('Invalid archive_root : {:}'.format(archive_root))
     if index is None:
       indexes = list(range(len(self)))
@@ -158,9 +187,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
     for idx in indexes:
       assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
       xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
-      if not os.path.isfile(xfile_path):
+      if not nats_is_file(xfile_path):
         xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
-      assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
+      assert nats_is_file(xfile_path), 'invalid data path : {:}'.format(xfile_path)
       xdata = pickle_load(xfile_path)
       assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
       self.evaluated_indexes.add(idx)
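
Note: a hedged usage sketch of the new file-system switch (the import path `nats_bench.api_utils` is assumed from the surrounding relative imports; the 'google' branch requires TensorFlow with the TF1-style `tf.gfile` API and lets the same checks work on remote stores such as gs:// buckets):

from nats_bench.api_utils import reset_file_system, get_file_system, nats_is_file

print(get_file_system())            # 'default' -> plain os.path checks
print(nats_is_file('/etc/hosts'))   # local check via os.path.isfile

reset_file_system('google')         # route the same checks through tf.gfile
print(get_file_system())            # 'google'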

View File

@@ -1,10 +1,10 @@
 #!/bin/bash
-# bash ./NATS/search-size.sh 0 777
+# bash scripts-search/NATS/search-size.sh 0 0.3 777
 echo script name: $0
 echo $# arguments
-if [ "$#" -ne 2 ] ;then
+if [ "$#" -ne 3 ] ;then
   echo "Input illegal number of parameters " $#
-  echo "Need 2 parameters for GPU-device and seed"
+  echo "Need 3 parameters for GPU-device, warmup-ratio, and seed"
   exit 1
 fi
 if [ "$TORCH_HOME" = "" ]; then
if [ "$TORCH_HOME" = "" ]; then if [ "$TORCH_HOME" = "" ]; then
@@ -15,16 +15,19 @@ else
 fi
 
 device=$1
-seed=$2
+ratio=$2
+seed=$3
 
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed}
-CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --warmup_ratio ${ratio} --rand_seed ${seed}
+#
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed}
+#
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}
+CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed}