Update Warmup
This commit is contained in:
		| @@ -2,7 +2,13 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||||
| ###################################################################################### | ###################################################################################### | ||||||
| # In this file, we aims to evaluate three kinds of channel searching strategies: | # In this file, we aims to evaluate three kinds of channel searching strategies: | ||||||
| # -  | # - channel-wise interpaltion from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" | ||||||
|  | # - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" | ||||||
|  | # - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" | ||||||
|  | # For simplicity, we use tas, fbv2, and tunas to refer these three strategies. Their official implementations are at the following links: | ||||||
|  | # - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/master/docs/NeurIPS-2019-TAS.md | ||||||
|  | # - FBV2: https://github.com/facebookresearch/mobile-vision | ||||||
|  | # - TuNAS: https://github.com/google-research/google-research/tree/master/tunas | ||||||
| #### | #### | ||||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25 | # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio 0.25 | ||||||
| #### | #### | ||||||
|   | |||||||
| @@ -26,7 +26,8 @@ from nats_bench import create | |||||||
| from log_utils import time_string | from log_utils import time_string | ||||||
|  |  | ||||||
|  |  | ||||||
| def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-AWD0.0-WARMNone'): | # def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'): | ||||||
|  | def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'): | ||||||
|   ss_dir = '{:}-{:}'.format(root_dir, search_space) |   ss_dir = '{:}-{:}'.format(root_dir, search_space) | ||||||
|   alg2name, alg2path = OrderedDict(), OrderedDict() |   alg2name, alg2path = OrderedDict(), OrderedDict() | ||||||
|   seeds = [777, 888, 999] |   seeds = [777, 888, 999] | ||||||
| @@ -39,9 +40,12 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf | |||||||
|     alg2name['ENAS'] = 'enas-affine0_BN0-None' |     alg2name['ENAS'] = 'enas-affine0_BN0-None' | ||||||
|     alg2name['SETN'] = 'setn-affine0_BN0-None' |     alg2name['SETN'] = 'setn-affine0_BN0-None' | ||||||
|   else: |   else: | ||||||
|     alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) |     # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) | ||||||
|     alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) |     # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) | ||||||
|     alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) |     # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) | ||||||
|  |     alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix) | ||||||
|  |     alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix) | ||||||
|  |     alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix) | ||||||
|   for alg, name in alg2name.items(): |   for alg, name in alg2name.items(): | ||||||
|     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth') |     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth') | ||||||
|   alg2data = OrderedDict() |   alg2data = OrderedDict() | ||||||
| @@ -99,7 +103,10 @@ def visualize_curve(api, vis_save_dir, search_space): | |||||||
|       print('plot alg : {:}'.format(alg)) |       print('plot alg : {:}'.format(alg)) | ||||||
|       xs, accuracies = [], [] |       xs, accuracies = [], [] | ||||||
|       for iepoch in range(epochs + 1): |       for iepoch in range(epochs + 1): | ||||||
|  |         try: | ||||||
|           structures, accs = [_[iepoch-1] for _ in data], [] |           structures, accs = [_[iepoch-1] for _ in data], [] | ||||||
|  |         except: | ||||||
|  |           raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset)) | ||||||
|         for structure in structures: |         for structure in structures: | ||||||
|           info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False) |           info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False) | ||||||
|           accs.append(info['test-accuracy']) |           accs.append(info['test-accuracy']) | ||||||
| @@ -131,5 +138,5 @@ if __name__ == '__main__': | |||||||
|  |  | ||||||
|   save_dir = Path(args.save_dir) |   save_dir = Path(args.save_dir) | ||||||
|  |  | ||||||
|   api = create(None, args.search_space, verbose=False) |   api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||||
|   visualize_curve(api, save_dir, args.search_space) |   visualize_curve(api, save_dir, args.search_space) | ||||||
|   | |||||||
| @@ -2,8 +2,8 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # Here, we utilized three techniques to search for the number of channels: | # Here, we utilized three techniques to search for the number of channels: | ||||||
| # - feature interpaltion from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" | # - channel-wise interpaltion from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" | ||||||
| # - masking + GumbelSoftmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" | # - masking + Gumbel-Softmax from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" | ||||||
| # - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" | # - masking + sampling from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" | ||||||
| from typing import List, Text, Any | from typing import List, Text, Any | ||||||
| import random, torch | import random, torch | ||||||
| @@ -55,7 +55,7 @@ class GenericNAS301Model(nn.Module): | |||||||
|     assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo) |     assert algo in ['fbv2', 'tunas', 'tas'], 'invalid algo : {:}'.format(algo) | ||||||
|     self._algo = algo |     self._algo = algo | ||||||
|     self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs))) |     self._arch_parameters = nn.Parameter(1e-3*torch.randn(self._max_num_Cs, len(self._candidate_Cs))) | ||||||
|     if algo == 'fbv2' or algo == 'tunas': |     # if algo == 'fbv2' or algo == 'tunas': | ||||||
|     self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))) |     self.register_buffer('_masks', torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))) | ||||||
|     for i in range(len(self._candidate_Cs)): |     for i in range(len(self._candidate_Cs)): | ||||||
|       self._masks.data[i, :self._candidate_Cs[i]] = 1 |       self._masks.data[i, :self._candidate_Cs[i]] = 1 | ||||||
|   | |||||||
| @@ -7,7 +7,6 @@ | |||||||
| # [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2                                      # | # [2020.08.31] NATS-sss-v1_0-50262.pickle.pbz2                                      # | ||||||
| ##################################################################################### | ##################################################################################### | ||||||
| import os, copy, random, numpy as np | import os, copy, random, numpy as np | ||||||
| from pathlib import Path |  | ||||||
| from typing import List, Text, Union, Dict, Optional | from typing import List, Text, Union, Dict, Optional | ||||||
| from collections import OrderedDict, defaultdict | from collections import OrderedDict, defaultdict | ||||||
| from .api_utils import time_string | from .api_utils import time_string | ||||||
| @@ -15,6 +14,8 @@ from .api_utils import pickle_load | |||||||
| from .api_utils import ArchResults | from .api_utils import ArchResults | ||||||
| from .api_utils import NASBenchMetaAPI | from .api_utils import NASBenchMetaAPI | ||||||
| from .api_utils import remap_dataset_set_names | from .api_utils import remap_dataset_set_names | ||||||
|  | from .api_utils import nats_is_dir | ||||||
|  | from .api_utils import nats_is_file | ||||||
| from .api_utils import PICKLE_EXT | from .api_utils import PICKLE_EXT | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -70,20 +71,20 @@ class NATSsize(NASBenchMetaAPI): | |||||||
|       else: |       else: | ||||||
|         file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) |         file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) | ||||||
|       print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict)) |       print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict)) | ||||||
|     if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path): |     if isinstance(file_path_or_dict, str): | ||||||
|       file_path_or_dict = str(file_path_or_dict) |       file_path_or_dict = str(file_path_or_dict) | ||||||
|       if verbose: |       if verbose: | ||||||
|         print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) |         print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) | ||||||
|       if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict): |       if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict): | ||||||
|         raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) |         raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) | ||||||
|       self.filename = Path(file_path_or_dict).name |       self.filename = os.path.basename(file_path_or_dict) | ||||||
|       if fast_mode: |       if fast_mode: | ||||||
|         if os.path.isfile(file_path_or_dict): |         if nats_is_file(file_path_or_dict): | ||||||
|           raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) |           raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) | ||||||
|         else: |         else: | ||||||
|           self._archive_dir = file_path_or_dict |           self._archive_dir = file_path_or_dict | ||||||
|       else: |       else: | ||||||
|         if os.path.isdir(file_path_or_dict): |         if nats_is_dir(file_path_or_dict): | ||||||
|           raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) |           raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) | ||||||
|         else: |         else: | ||||||
|           file_path_or_dict = pickle_load(file_path_or_dict) |           file_path_or_dict = pickle_load(file_path_or_dict) | ||||||
|   | |||||||
| @@ -7,7 +7,6 @@ | |||||||
| # [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2                                      # | # [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2                                      # | ||||||
| ##################################################################################### | ##################################################################################### | ||||||
| import os, copy, random, numpy as np | import os, copy, random, numpy as np | ||||||
| from pathlib import Path |  | ||||||
| from typing import List, Text, Union, Dict, Optional | from typing import List, Text, Union, Dict, Optional | ||||||
| from collections import OrderedDict, defaultdict | from collections import OrderedDict, defaultdict | ||||||
| import warnings | import warnings | ||||||
| @@ -16,6 +15,8 @@ from .api_utils import pickle_load | |||||||
| from .api_utils import ArchResults | from .api_utils import ArchResults | ||||||
| from .api_utils import NASBenchMetaAPI | from .api_utils import NASBenchMetaAPI | ||||||
| from .api_utils import remap_dataset_set_names | from .api_utils import remap_dataset_set_names | ||||||
|  | from .api_utils import nats_is_dir | ||||||
|  | from .api_utils import nats_is_file | ||||||
| from .api_utils import PICKLE_EXT | from .api_utils import PICKLE_EXT | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -67,20 +68,20 @@ class NATStopology(NASBenchMetaAPI): | |||||||
|       else: |       else: | ||||||
|         file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) |         file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) | ||||||
|       print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict)) |       print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict)) | ||||||
|     if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path): |     if isinstance(file_path_or_dict, str): | ||||||
|       file_path_or_dict = str(file_path_or_dict) |       file_path_or_dict = str(file_path_or_dict) | ||||||
|       if verbose: |       if verbose: | ||||||
|         print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) |         print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) | ||||||
|       if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict): |       if not nats_is_file(file_path_or_dict) and not nats_is_dir(file_path_or_dict): | ||||||
|         raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) |         raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) | ||||||
|       self.filename = Path(file_path_or_dict).name |       self.filename = os.path.basename(file_path_or_dict) | ||||||
|       if fast_mode: |       if fast_mode: | ||||||
|         if os.path.isfile(file_path_or_dict): |         if nats_is_file(file_path_or_dict): | ||||||
|           raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) |           raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict)) | ||||||
|         else: |         else: | ||||||
|           self._archive_dir = file_path_or_dict |           self._archive_dir = file_path_or_dict | ||||||
|       else: |       else: | ||||||
|         if os.path.isdir(file_path_or_dict): |         if nats_is_dir(file_path_or_dict): | ||||||
|           raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) |           raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict)) | ||||||
|         else: |         else: | ||||||
|           file_path_or_dict = pickle_load(file_path_or_dict) |           file_path_or_dict = pickle_load(file_path_or_dict) | ||||||
|   | |||||||
| @@ -17,6 +17,7 @@ from typing import List, Text, Union, Dict, Optional | |||||||
| from collections import OrderedDict, defaultdict | from collections import OrderedDict, defaultdict | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _FILE_SYSTEM = 'default' | ||||||
| PICKLE_EXT = 'pickle.pbz2' | PICKLE_EXT = 'pickle.pbz2' | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -45,6 +46,34 @@ def time_string(): | |||||||
|   return string |   return string | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def reset_file_system(lib: Text='default'): | ||||||
|  |   _FILE_SYSTEM = lib | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_file_system(lib: Text='default'): | ||||||
|  |   return _FILE_SYSTEM | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def nats_is_dir(file_path): | ||||||
|  |   if _FILE_SYSTEM == 'default': | ||||||
|  |     return os.path.isdir(file_path) | ||||||
|  |   elif _FILE_SYSTEM == 'google': | ||||||
|  |     import tensorflow as tf | ||||||
|  |     return tf.gfile.isdir(file_path) | ||||||
|  |   else: | ||||||
|  |     raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def nats_is_file(file_path): | ||||||
|  |   if _FILE_SYSTEM == 'default': | ||||||
|  |     return os.path.isfile(file_path) | ||||||
|  |   elif _FILE_SYSTEM == 'google': | ||||||
|  |     import tensorflow as tf | ||||||
|  |     return tf.gfile.exists(file_path) and not tf.gfile.isdir(file_path) | ||||||
|  |   else: | ||||||
|  |     raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def remap_dataset_set_names(dataset, metric_on_set, verbose=False): | def remap_dataset_set_names(dataset, metric_on_set, verbose=False): | ||||||
|   """re-map the metric_on_set to internal keys""" |   """re-map the metric_on_set to internal keys""" | ||||||
|   if verbose: |   if verbose: | ||||||
| @@ -146,10 +175,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | |||||||
|             time_string(), archive_root, index)) |             time_string(), archive_root, index)) | ||||||
|     if archive_root is None: |     if archive_root is None: | ||||||
|       archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(self.ALL_BASE_NAMES[-1])) |       archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(self.ALL_BASE_NAMES[-1])) | ||||||
|       if not os.path.isdir(archive_root): |       if not nats_is_dir(archive_root): | ||||||
|         warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root)) |         warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root)) | ||||||
|         archive_root = self.archive_dir |         archive_root = self.archive_dir | ||||||
|     if archive_root is None or not os.path.isdir(archive_root): |     if archive_root is None or not nats_is_dir(archive_root): | ||||||
|       raise ValueError('Invalid archive_root : {:}'.format(archive_root)) |       raise ValueError('Invalid archive_root : {:}'.format(archive_root)) | ||||||
|     if index is None: |     if index is None: | ||||||
|       indexes = list(range(len(self))) |       indexes = list(range(len(self))) | ||||||
| @@ -158,9 +187,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | |||||||
|     for idx in indexes: |     for idx in indexes: | ||||||
|       assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx) |       assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx) | ||||||
|       xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT)) |       xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT)) | ||||||
|       if not os.path.isfile(xfile_path): |       if not nats_is_file(xfile_path): | ||||||
|         xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT)) |         xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT)) | ||||||
|       assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) |       assert nats_is_file(xfile_path), 'invalid data path : {:}'.format(xfile_path) | ||||||
|       xdata = pickle_load(xfile_path) |       xdata = pickle_load(xfile_path) | ||||||
|       assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path) |       assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path) | ||||||
|       self.evaluated_indexes.add(idx) |       self.evaluated_indexes.add(idx) | ||||||
|   | |||||||
| @@ -1,10 +1,10 @@ | |||||||
| #!/bin/bash | #!/bin/bash | ||||||
| # bash ./NATS/search-size.sh 0 777 | # bash scripts-search/NATS/search-size.sh 0 0.3 777 | ||||||
| echo script name: $0 | echo script name: $0 | ||||||
| echo $# arguments | echo $# arguments | ||||||
| if [ "$#" -ne 2 ] ;then | if [ "$#" -ne 3 ] ;then | ||||||
|   echo "Input illegal number of parameters " $# |   echo "Input illegal number of parameters " $# | ||||||
|   echo "Need 2 parameters for GPU-device and seed" |   echo "Need 3 parameters for GPU-device, warmup-ratio, and seed" | ||||||
|   exit 1 |   exit 1 | ||||||
| fi | fi | ||||||
| if [ "$TORCH_HOME" = "" ]; then | if [ "$TORCH_HOME" = "" ]; then | ||||||
| @@ -15,16 +15,19 @@ else | |||||||
| fi | fi | ||||||
|  |  | ||||||
| device=$1 | device=$1 | ||||||
| seed=$2 | ratio=$2 | ||||||
|  | seed=$3 | ||||||
|  |  | ||||||
| CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed} | CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --warmup_ratio ${ratio} --rand_seed ${seed} | ||||||
| CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed ${seed} | CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --warmup_ratio ${ratio} --rand_seed ${seed} | ||||||
| CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed ${seed} | CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --warmup_ratio ${ratio} --rand_seed ${seed} | ||||||
|  |  | ||||||
| CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed} | # | ||||||
| CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed ${seed} | CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed} | ||||||
| CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed ${seed} | CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed} | ||||||
|  | CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --warmup_ratio ${ratio} --rand_seed ${seed} | ||||||
|  |  | ||||||
| CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed} | # | ||||||
| CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed ${seed} | CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed} | ||||||
| CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed ${seed} | CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed} | ||||||
|  | CUDA_VISIBLE_DEVICES=${device} python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --warmup_ratio ${ratio} --rand_seed ${seed} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user