Fix the potential memory leak in NAS-Bench-201 clear_param
This commit is contained in:
		
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -121,3 +121,5 @@ lib/NAS-Bench-*-v1_0.pth | |||||||
| others/TF | others/TF | ||||||
| scripts-search/l2s-algos | scripts-search/l2s-algos | ||||||
| TEMP-L.sh | TEMP-L.sh | ||||||
|  |  | ||||||
|  | .nfs00* | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||||
| ##################################################### | ##################################################### | ||||||
| import os, sys, time, torch | import time, torch | ||||||
| from procedures   import prepare_seed, get_optim_scheduler | from procedures   import prepare_seed, get_optim_scheduler | ||||||
| from utils        import get_model_infos, obtain_accuracy | from utils        import get_model_infos, obtain_accuracy | ||||||
| from config_utils import dict2config | from config_utils import dict2config | ||||||
| @@ -9,11 +9,9 @@ from log_utils    import AverageMeter, time_string, convert_secs2time | |||||||
| from models       import get_cell_based_tiny_net | from models       import get_cell_based_tiny_net | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ['evaluate_for_seed', 'pure_evaluate'] | __all__ = ['evaluate_for_seed', 'pure_evaluate'] | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||||
|   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None |   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||||
|   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() |   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||||
|   | |||||||
| @@ -28,7 +28,7 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c | |||||||
|   for dataset, xpath, split in zip(datasets, xpaths, splits): |   for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||||
|     # train valid data |     # train valid data | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) |     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||||
|     # load the configurature |     # load the configuration | ||||||
|     if dataset == 'cifar10' or dataset == 'cifar100': |     if dataset == 'cifar10' or dataset == 'cifar100': | ||||||
|       if use_less: config_path = 'configs/nas-benchmark/LESS.config' |       if use_less: config_path = 'configs/nas-benchmark/LESS.config' | ||||||
|       else       : config_path = 'configs/nas-benchmark/CIFAR.config' |       else       : config_path = 'configs/nas-benchmark/CIFAR.config' | ||||||
|   | |||||||
| @@ -3,7 +3,7 @@ | |||||||
| ################################################################################################ | ################################################################################################ | ||||||
| # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # | # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # | ||||||
| ################################################################################################ | ################################################################################################ | ||||||
| import os, sys, time, glob, random, argparse | import sys, argparse | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
|   | |||||||
| @@ -6,7 +6,7 @@ | |||||||
| # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 | # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1 | ||||||
| # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 | # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 | ||||||
| ############################################################################################### | ############################################################################################### | ||||||
| import os, gc, sys, time, glob, random, argparse | import os, gc, sys, argparse, psutil | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch | import torch | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @@ -33,7 +33,7 @@ def tostr(accdict, norms): | |||||||
|  |  | ||||||
| def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | ||||||
|   print('\nEvaluate dataset={:}'.format(data)) |   print('\nEvaluate dataset={:}'.format(data)) | ||||||
|   norms = [] |   norms, process = [], psutil.Process(os.getpid()) | ||||||
|   final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) |   final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||||
|   final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) |   final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []}) | ||||||
|   for idx in range(len(api)): |   for idx in range(len(api)): | ||||||
| @@ -56,16 +56,17 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool): | |||||||
|       with torch.no_grad(): |       with torch.no_grad(): | ||||||
|         net.load_state_dict(param) |         net.load_state_dict(param) | ||||||
|         _, summary = weight_watcher.analyze(net, alphas=False) |         _, summary = weight_watcher.analyze(net, alphas=False) | ||||||
|         cur_norms.append( summary['lognorm'] ) |         cur_norms.append(summary['lognorm']) | ||||||
|     norms.append( float(np.mean(cur_norms)) ) |     norms.append( float(np.mean(cur_norms)) ) | ||||||
|     api.clear_params(idx, use_12epochs_result) |     api.clear_params(idx, None) | ||||||
|     if idx % 200 == 199 or idx + 1 == len(api): |     if idx % 200 == 199 or idx + 1 == len(api): | ||||||
|       head = '{:05d}/{:05d}'.format(idx, len(api)) |       head = '{:05d}/{:05d}'.format(idx, len(api)) | ||||||
|       stem_val = tostr(final_val_accs, norms) |       stem_val = tostr(final_val_accs, norms) | ||||||
|       stem_test = tostr(final_test_accs, norms) |       stem_test = tostr(final_test_accs, norms) | ||||||
|       print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}'.format(time_string(), head, data, 12 if use_12epochs_result else 200)) |       print('{:} {:} {:} with {:} epochs ({:.2f} MB memory)'.format(time_string(), head, data, 12 if use_12epochs_result else 200, process.memory_info().rss / 1e6)) | ||||||
|       print('    -->>  {:}  ||  {:}'.format(stem_val, stem_test)) |       print('  [Valid] -->>  {:}'.format(stem_val)) | ||||||
|       torch.cuda.empty_cache() ; gc.collect() |       print('  [Test.] -->>  {:}'.format(stem_test)) | ||||||
|  |       gc.collect() | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result): | def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result): | ||||||
|   | |||||||
| @@ -3,7 +3,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | # python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth | ||||||
| ##################################################### | ##################################################### | ||||||
| import os, sys, time, argparse, collections | import sys, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| import numpy as np | import numpy as np | ||||||
|   | |||||||
| @@ -24,11 +24,11 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex | |||||||
|   machine_info = get_machine_info() |   machine_info = get_machine_info() | ||||||
|   all_infos = {'info': machine_info} |   all_infos = {'info': machine_info} | ||||||
|   all_dataset_keys = [] |   all_dataset_keys = [] | ||||||
|   # look all the datasets |   # look all the dataset | ||||||
|   for dataset, xpath, split in zip(datasets, xpaths, splits): |   for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||||
|     # train valid data |     # the train and valid data | ||||||
|     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) |     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||||
|     # load the configurature |     # load the configuration | ||||||
|     if dataset == 'cifar10' or dataset == 'cifar100': |     if dataset == 'cifar10' or dataset == 'cifar100': | ||||||
|       split_info  = load_config('configs/nas-benchmark/cifar-split.txt', None, None) |       split_info  = load_config('configs/nas-benchmark/cifar-split.txt', None, None) | ||||||
|     elif dataset.startswith('ImageNet16'): |     elif dataset.startswith('ImageNet16'): | ||||||
| @@ -36,7 +36,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex | |||||||
|     else: |     else: | ||||||
|       raise ValueError('invalid dataset : {:}'.format(dataset)) |       raise ValueError('invalid dataset : {:}'.format(dataset)) | ||||||
|     config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) |     config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) | ||||||
|     # check whether use splited validation set |     # check whether use the splitted validation set | ||||||
|     if bool(split): |     if bool(split): | ||||||
|       assert dataset == 'cifar10' |       assert dataset == 'cifar10' | ||||||
|       ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)} |       ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)} | ||||||
| @@ -92,7 +92,7 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], | |||||||
|  |  | ||||||
|   log_dir = save_dir / 'logs' |   log_dir = save_dir / 'logs' | ||||||
|   log_dir.mkdir(parents=True, exist_ok=True) |   log_dir.mkdir(parents=True, exist_ok=True) | ||||||
|   logger = Logger(str(log_dir), 0, False) |   logger = Logger(str(log_dir), os.getpid(), False) | ||||||
|  |  | ||||||
|   logger.log('xargs : seeds      = {:}'.format(seeds)) |   logger.log('xargs : seeds      = {:}'.format(seeds)) | ||||||
|   logger.log('xargs : cover_mode = {:}'.format(cover_mode)) |   logger.log('xargs : cover_mode = {:}'.format(cover_mode)) | ||||||
|   | |||||||
| @@ -114,15 +114,27 @@ class NASBench201API(object): | |||||||
|     assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) |     assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) | ||||||
|     xdata = torch.load(xfile_path, map_location='cpu') |     xdata = torch.load(xfile_path, map_location='cpu') | ||||||
|     assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path) |     assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path) | ||||||
|  |     if index in self.arch2infos_less: del self.arch2infos_less[index] | ||||||
|  |     if index in self.arch2infos_full: del self.arch2infos_full[index] | ||||||
|     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) |     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) | ||||||
|     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) |     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) | ||||||
|  |  | ||||||
|   def clear_params(self, index: int, use_12epochs_result: bool): |   def clear_params(self, index: int, use_12epochs_result: Union[bool, None]): | ||||||
|     """Remove the architecture's weights to save memory.""" |     """Remove the architecture's weights to save memory. | ||||||
|  |     :arg | ||||||
|  |       index: the index of the target architecture | ||||||
|  |       use_12epochs_result: a flag to controll how to clear the parameters. | ||||||
|  |         -- None: clear all the weights in both `less` and `full`, which indicates the training hyper-parameters. | ||||||
|  |         -- True: clear all the weights in arch2infos_less, which by default is 12-epoch-training result. | ||||||
|  |         -- False: clear all the weights in arch2infos_full, which by default is 200-epoch-training result. | ||||||
|  |     """ | ||||||
|  |     if use_12epochs_result is None: | ||||||
|  |       self.arch2infos_less[index].clear_params() | ||||||
|  |       self.arch2infos_full[index].clear_params() | ||||||
|  |     else: | ||||||
|       if use_12epochs_result: arch2infos = self.arch2infos_less |       if use_12epochs_result: arch2infos = self.arch2infos_less | ||||||
|       else                  : arch2infos = self.arch2infos_full |       else                  : arch2infos = self.arch2infos_full | ||||||
|     archresult = arch2infos[index] |       arch2infos[index].clear_params() | ||||||
|     archresult.clear_params() |  | ||||||
|    |    | ||||||
|   # This function is used to query the information of a specific archiitecture |   # This function is used to query the information of a specific archiitecture | ||||||
|   # 'arch' can be an architecture index or an architecture string |   # 'arch' can be an architecture index or an architecture string | ||||||
| @@ -193,7 +205,6 @@ class NASBench201API(object): | |||||||
|         best_index, highest_accuracy = idx, accuracy |         best_index, highest_accuracy = idx, accuracy | ||||||
|     return best_index, highest_accuracy |     return best_index, highest_accuracy | ||||||
|  |  | ||||||
|  |  | ||||||
|   def arch(self, index: int): |   def arch(self, index: int): | ||||||
|     """Return the topology structure of the `index`-th architecture.""" |     """Return the topology structure of the `index`-th architecture.""" | ||||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) |     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||||
| @@ -214,7 +225,6 @@ class NASBench201API(object): | |||||||
|     arch_result = arch2infos[index] |     arch_result = arch2infos[index] | ||||||
|     return arch_result.get_net_param(dataset, seed) |     return arch_result.get_net_param(dataset, seed) | ||||||
|  |  | ||||||
|  |  | ||||||
|   def get_net_config(self, index: int, dataset: Text): |   def get_net_config(self, index: int, dataset: Text): | ||||||
|     """ |     """ | ||||||
|       This function is used to obtain the configuration for the `index`-th architecture on `dataset`. |       This function is used to obtain the configuration for the `index`-th architecture on `dataset`. | ||||||
| @@ -235,7 +245,6 @@ class NASBench201API(object): | |||||||
|       #print ('SEED [{:}] : {:}'.format(seed, result)) |       #print ('SEED [{:}] : {:}'.format(seed, result)) | ||||||
|     raise ValueError('Impossible to reach here!') |     raise ValueError('Impossible to reach here!') | ||||||
|  |  | ||||||
|  |  | ||||||
|   def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]: |   def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]: | ||||||
|     """To obtain the cost metric for the `index`-th architecture on a dataset.""" |     """To obtain the cost metric for the `index`-th architecture on a dataset.""" | ||||||
|     if use_12epochs_result: arch2infos = self.arch2infos_less |     if use_12epochs_result: arch2infos = self.arch2infos_less | ||||||
| @@ -243,7 +252,6 @@ class NASBench201API(object): | |||||||
|     arch_result = arch2infos[index] |     arch_result = arch2infos[index] | ||||||
|     return arch_result.get_compute_costs(dataset) |     return arch_result.get_compute_costs(dataset) | ||||||
|  |  | ||||||
|  |  | ||||||
|   def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float: |   def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float: | ||||||
|     """ |     """ | ||||||
|     To obtain the latency of the network (by default it will return the latency with the batch size of 256). |     To obtain the latency of the network (by default it will return the latency with the batch size of 256). | ||||||
| @@ -254,7 +262,6 @@ class NASBench201API(object): | |||||||
|     cost_dict = self.get_cost_info(index, dataset, use_12epochs_result) |     cost_dict = self.get_cost_info(index, dataset, use_12epochs_result) | ||||||
|     return cost_dict['latency'] |     return cost_dict['latency'] | ||||||
|  |  | ||||||
|  |  | ||||||
|   # obtain the metric for the `index`-th architecture |   # obtain the metric for the `index`-th architecture | ||||||
|   # `dataset` indicates the dataset: |   # `dataset` indicates the dataset: | ||||||
|   #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set |   #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set | ||||||
| @@ -388,7 +395,6 @@ class NASBench201API(object): | |||||||
|       return xifo |       return xifo | ||||||
|   """ |   """ | ||||||
|  |  | ||||||
|  |  | ||||||
|   def show(self, index: int = -1) -> None: |   def show(self, index: int = -1) -> None: | ||||||
|     """ |     """ | ||||||
|     This function will print the information of a specific (or all) architecture(s). |     This function will print the information of a specific (or all) architecture(s). | ||||||
| @@ -423,7 +429,6 @@ class NASBench201API(object): | |||||||
|       else: |       else: | ||||||
|         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) |         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) | ||||||
|  |  | ||||||
|  |  | ||||||
|   def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]: |   def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]: | ||||||
|     """ |     """ | ||||||
|     This function will count the number of total trials. |     This function will count the number of total trials. | ||||||
| @@ -443,7 +448,6 @@ class NASBench201API(object): | |||||||
|         nums[len(dataset_seed[dataset])] += 1 |         nums[len(dataset_seed[dataset])] += 1 | ||||||
|     return dict(nums) |     return dict(nums) | ||||||
|  |  | ||||||
|  |  | ||||||
|   @staticmethod |   @staticmethod | ||||||
|   def str2lists(arch_str: Text) -> List[tuple]: |   def str2lists(arch_str: Text) -> List[tuple]: | ||||||
|     """ |     """ | ||||||
| @@ -471,7 +475,6 @@ class NASBench201API(object): | |||||||
|       genotypes.append( input_infos ) |       genotypes.append( input_infos ) | ||||||
|     return genotypes |     return genotypes | ||||||
|  |  | ||||||
|  |  | ||||||
|   @staticmethod |   @staticmethod | ||||||
|   def str2matrix(arch_str: Text, |   def str2matrix(arch_str: Text, | ||||||
|                  search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray: |                  search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray: | ||||||
| @@ -511,7 +514,6 @@ class NASBench201API(object): | |||||||
|     return matrix |     return matrix | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ArchResults(object): | class ArchResults(object): | ||||||
|  |  | ||||||
|   def __init__(self, arch_index, arch_str): |   def __init__(self, arch_index, arch_str): | ||||||
| @@ -754,7 +756,6 @@ class ArchResults(object): | |||||||
|     return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done)) |     return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| """ | """ | ||||||
| This class (ResultsCount) is used to save the information of one trial for a single architecture. | This class (ResultsCount) is used to save the information of one trial for a single architecture. | ||||||
| I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called. | I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called. | ||||||
| @@ -872,8 +873,8 @@ class ResultsCount(object): | |||||||
|             'cur_time': xtime, |             'cur_time': xtime, | ||||||
|             'all_time': atime} |             'all_time': atime} | ||||||
|  |  | ||||||
|   # get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument). |  | ||||||
|   def get_eval(self, name, iepoch=None): |   def get_eval(self, name, iepoch=None): | ||||||
|  |     """Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).""" | ||||||
|     if iepoch is None: iepoch = self.epochs-1 |     if iepoch is None: iepoch = self.epochs-1 | ||||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) |     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||||
|     if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: |     if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: | ||||||
| @@ -890,8 +891,8 @@ class ResultsCount(object): | |||||||
|     if clone: return copy.deepcopy(self.net_state_dict) |     if clone: return copy.deepcopy(self.net_state_dict) | ||||||
|     else: return self.net_state_dict |     else: return self.net_state_dict | ||||||
|  |  | ||||||
|   # This function is used to obtain the config dict for this architecture. |  | ||||||
|   def get_config(self, str2structure): |   def get_config(self, str2structure): | ||||||
|  |     """This function is used to obtain the config dict for this architecture.""" | ||||||
|     if str2structure is None: |     if str2structure is None: | ||||||
|       return {'name': 'infer.tiny', 'C': self.arch_config['channel'], |       return {'name': 'infer.tiny', 'C': self.arch_config['channel'], | ||||||
|               'N'   : self.arch_config['num_cells'], |               'N'   : self.arch_config['num_cells'], | ||||||
|   | |||||||
| @@ -15,7 +15,7 @@ else | |||||||
|   echo "TORCH_HOME : $TORCH_HOME" |   echo "TORCH_HOME : $TORCH_HOME" | ||||||
| fi | fi | ||||||
|  |  | ||||||
| OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \ | CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \ | ||||||
| 	--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \ | 	--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \ | ||||||
| 	--dataset $1 \ | 	--dataset $1 \ | ||||||
| 	--use_12 $2 | 	--use_12 $2 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user