Update test weights and shapes

exps/NAS-Bench-201/test-weights.py
@@ -4,18 +4,15 @@
 # Before running these commands, the files must be properly placed.
 # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_0-e61699
 # python exps/NAS-Bench-201/test-weights.py --base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 --dataset cifar10-valid --use_12 1 --use_valid 1
-# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 1
+# bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1
 ###############################################################################################
 import os, gc, sys, time, glob, random, argparse
 import numpy as np
 import torch
-import torch.nn as nn
 from pathlib import Path
 from collections import OrderedDict
-from tqdm import tqdm
 lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
-from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
 from nas_201_api import NASBench201API as API
 from log_utils import time_string
 from models import get_cell_based_tiny_net
@@ -34,19 +31,22 @@ def tostr(accdict, norms):
   return ' '.join(xstr)
 
 
-def evaluate(api, weight_dir, data: str, use_12epochs_result: bool, valid_or_test: bool):
+def evaluate(api, weight_dir, data: str, use_12epochs_result: bool):
   print('\nEvaluate dataset={:}'.format(data))
-  norms, accs = [], []
-  final_accs = OrderedDict({'cifar10-valid': [], 'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
+  norms = []
+  final_val_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
+  final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
   for idx in range(len(api)):
     info = api.get_more_info(idx, data, use_12epochs_result=use_12epochs_result, is_random=False)
-    if valid_or_test:
-      accs.append(info['valid-accuracy'])
-    else:
-      accs.append(info['test-accuracy'])
-    for key in final_accs.keys():
+    for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
       info = api.get_more_info(idx, key, use_12epochs_result=False, is_random=False)
-      final_accs[key].append(info['test-accuracy'])
+      if key == 'cifar10-valid':
+        final_val_accs['cifar10'].append(info['valid-accuracy'])
+      elif key == 'cifar10':
+        final_test_accs['cifar10'].append(info['test-accuracy'])
+      else:
+        final_test_accs[key].append(info['test-accuracy'])
+        final_val_accs[key].append(info['valid-accuracy'])
     config = api.get_net_config(idx, data)
     net = get_cell_based_tiny_net(config)
     api.reload(weight_dir, idx)
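
Note on the new bookkeeping above: 'cifar10-valid' is the CIFAR-10 entry whose validation accuracy is meaningful, plain 'cifar10' was trained on train+val and so only contributes a test accuracy, and cifar100 / ImageNet16-120 contribute both splits. A standalone sketch of that mapping, with a stubbed get_more_info and illustrative numbers rather than real benchmark values:

from collections import OrderedDict

def fake_get_more_info(idx, dataset):
  # Stand-in for NASBench201API.get_more_info; the real call returns many more fields.
  return {'valid-accuracy': 80.0 + idx, 'test-accuracy': 90.0 + idx}

final_val_accs  = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
final_test_accs = OrderedDict({'cifar10': [], 'cifar100': [], 'ImageNet16-120': []})
for idx in range(2):
  for key in ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']:
    info = fake_get_more_info(idx, key)
    if key == 'cifar10-valid':   # only the validation split is kept for this entry
      final_val_accs['cifar10'].append(info['valid-accuracy'])
    elif key == 'cifar10':       # trained on train+val, so only the test split is kept
      final_test_accs['cifar10'].append(info['test-accuracy'])
    else:                        # cifar100 / ImageNet16-120 keep both splits
      final_test_accs[key].append(info['test-accuracy'])
      final_val_accs[key].append(info['valid-accuracy'])
print(final_val_accs['cifar10'])   # [80.0, 81.0]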

@@ -60,14 +60,15 @@ def evaluate(api, weight_dir, data: str, use_12epochs_result: bool, valid_or_test: bool):
     norms.append( float(np.mean(cur_norms)) )
     api.clear_params(idx, use_12epochs_result)
     if idx % 200 == 199 or idx + 1 == len(api):
-      correlation = get_cor(norms, accs)
       head = '{:05d}/{:05d}'.format(idx, len(api))
-      stem = tostr(final_accs, norms)
-      print('{:} {:} {:} with {:} epochs on {:} : the correlation is {:.3f}. {:}'.format(time_string(), head, data, 12 if use_12epochs_result else 200, 'valid' if valid_or_test else 'test', correlation, stem))
+      stem_val = tostr(final_val_accs, norms)
+      stem_test = tostr(final_test_accs, norms)
+      print('{:} {:} {:} with {:} epochs'.format(time_string(), head, data, 12 if use_12epochs_result else 200))
+      print('    -->>  {:}  ||  {:}'.format(stem_val, stem_test))
       torch.cuda.empty_cache() ; gc.collect()
 
 
-def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result, valid_or_test):
+def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result):
   api = API(meta_file)
   datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
   print(time_string() + ' ' + '='*50)
@@ -83,7 +84,7 @@ def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result, valid_or_test):
   print(time_string() + ' ' + '='*50)
 
   #evaluate(api, weight_dir, 'cifar10-valid', False, True)
-  evaluate(api, weight_dir, xdata, use_12epochs_result, valid_or_test)
+  evaluate(api, weight_dir, xdata, use_12epochs_result)
 
   print('{:} finish this test.'.format(time_string()))
 
@@ -94,7 +95,6 @@ if __name__ == '__main__':
   parser.add_argument('--base_path',  type=str, default=None, help='The path to the NAS-Bench-201 benchmark file and weight dir.')
   parser.add_argument('--dataset'  ,  type=str, default=None, help='.')
   parser.add_argument('--use_12'   ,  type=int, default=None, help='.')
-  parser.add_argument('--use_valid',  type=int, default=None, help='.')
   args = parser.parse_args()
 
   save_dir = Path(args.save_dir)
@@ -104,5 +104,5 @@ if __name__ == '__main__':
   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
   assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
 
-  main(str(meta_file), weight_dir, save_dir, args.dataset, bool(args.use_12), bool(args.use_valid))
+  main(str(meta_file), weight_dir, save_dir, args.dataset, bool(args.use_12))
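
The removed call correlation = get_cor(norms, accs) relied on a helper defined earlier in test-weights.py that this diff does not show. A minimal stand-in, assuming it returns the Pearson correlation between the per-architecture weight norms and accuracies:

import numpy as np

def get_cor(A, B):
  # Assumed implementation: Pearson correlation of two equal-length sequences.
  return float(np.corrcoef(np.asarray(A), np.asarray(B))[0, 1])

print(get_cor([1.0, 2.0, 3.0], [82.1, 83.4, 85.0]))  # ~0.998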

@@ -88,11 +88,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Text], ...
 
 def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
          splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any],
-         srange: tuple, cover_mode: bool):
-  assert torch.cuda.is_available(), 'CUDA is not available.'
-  torch.backends.cudnn.enabled = True
-  torch.backends.cudnn.deterministic = True
-  torch.set_num_threads(workers)
+         to_evaluate_indexes: tuple, cover_mode: bool):
 
   log_dir = save_dir / 'logs'
   log_dir.mkdir(parents=True, exist_ok=True)
@@ -103,13 +99,13 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
   logger.log('-' * 100)
 
   logger.log(
-    'Start evaluating range =: {:06d} - {:06d} / {:06d} with cover-mode={:}'.format(srange[0], srange[1], len(nets),
-                                                                                    cover_mode))
+    'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes))
+   +'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode))
   for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
     logger.log(
       '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
   logger.log('--->>> optimization config : {:}'.format(opt_config))
-  to_evaluate_indexes = list(range(srange[0], srange[1] + 1))
+  #to_evaluate_indexes = list(range(srange[0], srange[1] + 1))
 
   start_time, epoch_time = time.time(), AverageMeter()
   for i, index in enumerate(to_evaluate_indexes):
@@ -158,21 +154,55 @@ def traverse_net(candidates: List[int], N: int):
   return nets
 
 
+def filter_indexes(xlist, mode, save_dir, seeds):
+  all_indexes = []
+  for index in xlist:
+    if mode == 'cover':
+      all_indexes.append(index)
+    else:
+      for seed in seeds:
+        temp_path = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
+        if not temp_path.exists():
+          all_indexes.append(index)
+          break
+  print('{:} [FILTER-INDEXES] : there are {:} architectures in total'.format(time_string(), len(all_indexes)))
+
+  SLURM_PROCID, SLURM_NTASKS = 'SLURM_PROCID', 'SLURM_NTASKS'
+  if SLURM_PROCID in os.environ and SLURM_NTASKS in os.environ:  # run on the slurm
+    proc_id, ntasks = int(os.environ[SLURM_PROCID]), int(os.environ[SLURM_NTASKS])
+    assert 0 <= proc_id < ntasks, 'invalid proc_id {:} vs ntasks {:}'.format(proc_id, ntasks)
+    scales = [int(float(i)/ntasks*len(all_indexes)) for i in range(ntasks)] + [len(all_indexes)]
+    per_job = []
+    for i in range(ntasks):
+      xs, xe = min(max(scales[i],0), len(all_indexes)-1), min(max(scales[i+1]-1,0), len(all_indexes)-1)
+      per_job.append((xs, xe))
+    for i, srange in enumerate(per_job):
+      print('  -->> {:2d}/{:02d} : {:}'.format(i, ntasks, srange))
+    current_range = per_job[proc_id]
+    all_indexes = [all_indexes[i] for i in range(current_range[0], current_range[1]+1)]
+    # set the device id
+    device = proc_id % torch.cuda.device_count()
+    torch.cuda.set_device(device)
+    print('  set the device id = {:}'.format(device))
+  print('{:} [FILTER-INDEXES] : after filtering there are {:} architectures in total'.format(time_string(), len(all_indexes)))
+  return all_indexes
+
+
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-  parser.add_argument('--mode',        type=str,   required=True, choices=['new', 'cover'], help='The script mode.')
-  parser.add_argument('--save_dir',    type=str,   default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
-  parser.add_argument('--candidateC',  type=int,   nargs='+', default=[8, 16, 24, 32, 40, 48, 56, 64], help='.')
-  parser.add_argument('--num_layers',  type=int,   default=5,      help='The number of layers in a network.')
-  parser.add_argument('--check_N',     type=int,   default=32768,  help='For safety.')
+  parser.add_argument('--mode',        type=str, required=True, choices=['new', 'cover'], help='The script mode.')
+  parser.add_argument('--save_dir',    type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
+  parser.add_argument('--candidateC',  type=int, nargs='+', default=[8, 16, 24, 32, 40, 48, 56, 64], help='.')
+  parser.add_argument('--num_layers',  type=int, default=5,      help='The number of layers in a network.')
+  parser.add_argument('--check_N',     type=int, default=32768,  help='For safety.')
   # use for train the model
-  parser.add_argument('--workers',     type=int,   default=8,      help='The number of data loading workers (default: 2)')
-  parser.add_argument('--srange' ,     type=str,   required=True,  help='The range of models to be evaluated')
-  parser.add_argument('--datasets',    type=str,   nargs='+',      help='The applied datasets.')
-  parser.add_argument('--xpaths',      type=str,   nargs='+',      help='The root path for this dataset.')
-  parser.add_argument('--splits',      type=int,   nargs='+',      help='The root path for this dataset.')
-  parser.add_argument('--hyper',       type=str,   default='12', choices=['12', '90'], help='The tag for hyper-parameters.')
-  parser.add_argument('--seeds'  ,     type=int,   nargs='+',      help='The range of models to be evaluated')
+  parser.add_argument('--workers',     type=int, default=8,      help='The number of data loading workers (default: 2)')
+  parser.add_argument('--srange' ,     type=str, required=True,  help='The range of models to be evaluated')
+  parser.add_argument('--datasets',    type=str, nargs='+',      help='The applied datasets.')
+  parser.add_argument('--xpaths',      type=str, nargs='+',      help='The root path for this dataset.')
+  parser.add_argument('--splits',      type=int, nargs='+',      help='The root path for this dataset.')
+  parser.add_argument('--hyper',       type=str, default='12', choices=['12', '90'], help='The tag for hyper-parameters.')
+  parser.add_argument('--seeds'  ,     type=int, nargs='+',      help='The range of models to be evaluated')
   args = parser.parse_args()
 
   nets = traverse_net(args.candidateC, args.num_layers)
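
The scales / per_job arithmetic in the new filter_indexes splits the filtered index list into near-equal contiguous chunks, one per SLURM task, and each task then pins itself to GPU proc_id % torch.cuda.device_count(). Re-running just that arithmetic with illustrative numbers (10 architectures, 3 tasks) shows the partition:

n, ntasks = 10, 3  # stand-ins for len(all_indexes) and SLURM_NTASKS
scales = [int(float(i) / ntasks * n) for i in range(ntasks)] + [n]   # [0, 3, 6, 10]
per_job = []
for i in range(ntasks):
  xs = min(max(scales[i], 0), n - 1)
  xe = min(max(scales[i + 1] - 1, 0), n - 1)
  per_job.append((xs, xe))
print(per_job)  # [(0, 2), (3, 5), (6, 9)] -> task 0 takes positions 0-2, and so on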

@@ -182,15 +212,31 @@ if __name__ == '__main__':
   if not os.path.isfile(opt_config): raise ValueError('{:} is not a file.'.format(opt_config))
   save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper)
   save_dir.mkdir(parents=True, exist_ok=True)
-  if not isinstance(args.srange, str) or len(args.srange.split('-')) != 2:
+  if not isinstance(args.srange, str):
     raise ValueError('Invalid scheme for {:}'.format(args.srange))
-  srange = args.srange.split('-')
-  srange = (int(srange[0]), int(srange[1]))
-  assert 0 <= srange[0] <= srange[1] < args.check_N, '{:} vs {:} vs {:}'.format(srange[0], srange[1], args.check_N)
+  srangestr = "".join(args.srange.split())
+  to_evaluate_indexes = set()
+  for srange in srangestr.split(','):
+    srange = srange.split('-')
+    if len(srange) != 2: raise ValueError('invalid srange : {:}'.format(srange))
+    assert len(srange[0]) == len(srange[1]) == 5, 'invalid srange : {:}'.format(srange)
+    srange = (int(srange[0]), int(srange[1]))
+    if not (0 <= srange[0] <= srange[1] < args.check_N):
+      raise ValueError('{:} vs {:} vs {:}'.format(srange[0], srange[1], args.check_N))
+    for i in range(srange[0], srange[1]+1):
+      to_evaluate_indexes.add(i)
+
   assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds)
-  assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
+  if not (len(args.datasets) == len(args.xpaths) == len(args.splits)):
+    raise ValueError('invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits)))
   assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)
 
-  main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config,
-       srange, args.mode == 'cover')
+  target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds)
+
+  assert torch.cuda.is_available(), 'CUDA is not available.'
+  torch.backends.cudnn.enabled = True
+  torch.backends.cudnn.deterministic = True
+  torch.set_num_threads(args.workers)
+
+  main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config, target_indexes, args.mode == 'cover')
 
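
With this change, --srange goes from a single start-end pair to a comma-separated list of zero-padded, inclusive 5-digit ranges. A hypothetical helper mirroring the inline loop above, convenient for testing the expansion in isolation:

def parse_srange(spec, check_n=32768):
  # Mirrors the argument parsing above: whitespace stripped, comma-separated,
  # 5-digit zero-padded, inclusive ranges.
  indexes = set()
  for srange in "".join(spec.split()).split(','):
    parts = srange.split('-')
    if len(parts) != 2: raise ValueError('invalid srange : {:}'.format(parts))
    assert len(parts[0]) == len(parts[1]) == 5, 'invalid srange : {:}'.format(parts)
    lo, hi = int(parts[0]), int(parts[1])
    if not (0 <= lo <= hi < check_n):
      raise ValueError('{:} vs {:} vs {:}'.format(lo, hi, check_n))
    indexes.update(range(lo, hi + 1))
  return indexes

print(sorted(parse_srange('00000-00002,00010-00011')))  # [0, 1, 2, 10, 11]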

@@ -17,13 +17,13 @@ __all__ = ['evaluate_for_seed', 'pure_evaluate', 'get_nas_bench_loaders']
 def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
   data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
   losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
-  latencies = []
+  latencies, device = [], torch.cuda.current_device()
   network.eval()
   with torch.no_grad():
     end = time.time()
     for i, (inputs, targets) in enumerate(xloader):
-      targets = targets.cuda(non_blocking=True)
-      inputs  = inputs.cuda(non_blocking=True)
+      targets = targets.cuda(device=device, non_blocking=True)
+      inputs  = inputs.cuda(device=device, non_blocking=True)
       data_time.update(time.time() - end)
       # forward
       features, logits = network(inputs)
@@ -48,12 +48,12 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
   if mode == 'train'  : network.train()
   elif mode == 'valid': network.eval()
   else: raise ValueError("The mode is not right : {:}".format(mode))
-
+  device = torch.cuda.current_device()
   data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
   for i, (inputs, targets) in enumerate(xloader):
     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
 
-    targets = targets.cuda(non_blocking=True)
+    targets = targets.cuda(device=device, non_blocking=True)
     if mode == 'train': optimizer.zero_grad()
     # forward
     features, logits = network(inputs)
@@ -84,7 +84,9 @@ def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed
   logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
   # train and valid
   optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
-  network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
+  default_device = torch.cuda.current_device()
+  network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device)
+  criterion = criterion.cuda(device=default_device)
   # start training
   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup
   train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
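
All three hunks in this file follow the same pattern: instead of a bare .cuda(), which implicitly targets GPU 0, the tensors, the DataParallel wrapper, and the criterion are pinned to torch.cuda.current_device(), which honors the torch.cuda.set_device(...) call made in filter_indexes when several SLURM tasks share one node. A minimal sketch of the pattern with a toy model, guarded so it only runs where CUDA is present:

import torch

if torch.cuda.is_available():
  device = torch.cuda.current_device()   # honors any earlier torch.cuda.set_device(...)
  net = torch.nn.Linear(8, 2)            # toy stand-in for the sampled architecture
  network = torch.nn.DataParallel(net, device_ids=[device]).cuda(device=device)
  criterion = torch.nn.CrossEntropyLoss().cuda(device=device)
  inputs  = torch.randn(4, 8).cuda(device=device, non_blocking=True)
  targets = torch.randint(0, 2, (4,)).cuda(device=device, non_blocking=True)
  loss = criterion(network(inputs), targets)
  print(loss.device)  # the pinned device, not necessarily cuda:0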

scripts-search/NAS-Bench-201/test-weights.sh
| @@ -1,12 +1,13 @@ | |||||||
| #!/bin/bash | #!/bin/bash | ||||||
| # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 1 | # bash ./scripts-search/NAS-Bench-201/test-weights.sh cifar10-valid 1 | ||||||
| echo script name: $0 | echo script name: $0 | ||||||
| echo $# arguments | echo $# arguments | ||||||
| if [ "$#" -ne 3 ] ;then | if [ "$#" -ne 2 ] ;then | ||||||
|   echo "Input illegal number of parameters " $# |   echo "Input illegal number of parameters " $# | ||||||
|   echo "Need 3 parameters for dataset, use_12_epoch, and use_validation_set" |   echo "Need 2 parameters for dataset and use_12_epoch" | ||||||
|   exit 1 |   exit 1 | ||||||
| fi | fi | ||||||
|  |  | ||||||
| if [ "$TORCH_HOME" = "" ]; then | if [ "$TORCH_HOME" = "" ]; then | ||||||
|   echo "Must set TORCH_HOME envoriment variable for data dir saving" |   echo "Must set TORCH_HOME envoriment variable for data dir saving" | ||||||
|   exit 1 |   exit 1 | ||||||
@@ -17,4 +18,4 @@ fi
 OMP_NUM_THREADS=4 python exps/NAS-Bench-201/test-weights.py \
 	--base_path $HOME/.torch/NAS-Bench-201-v1_1-096897 \
 	--dataset $1 \
-	--use_12 $2 --use_valid $3
+	--use_12 $2

scripts-search/X-X/train-shapes.sh
@@ -3,14 +3,16 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 #
 #####################################################
 # [mars6] CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/X-X/train-shapes.sh 00000-05000 12 777
-# [mars6] bash ./scripts-search/X-X/train-shapes.sh 05001-10000 12 777
-# [mars20] bash ./scripts-search/X-X/train-shapes.sh 10001-14500 12 777
-# [mars20] bash ./scripts-search/X-X/train-shapes.sh 14501-19500 12 777
-# bash ./scripts-search/X-X/train-shapes.sh 19501-23500 12 777
-# bash ./scripts-search/X-X/train-shapes.sh 23501-27500 12 777
-# bash ./scripts-search/X-X/train-shapes.sh 27501-30000 12 777
-# bash ./scripts-search/X-X/train-shapes.sh 30001-32767 12 777
+# [mars6]   bash ./scripts-search/X-X/train-shapes.sh 05001-10000 12 777
+# [mars20]  bash ./scripts-search/X-X/train-shapes.sh 10001-14500 12 777
+# [mars20]  bash ./scripts-search/X-X/train-shapes.sh 14501-19500 12 777
+# [saturn4] bash ./scripts-search/X-X/train-shapes.sh 19501-23500 12 777
+# [saturn4] bash ./scripts-search/X-X/train-shapes.sh 23501-27500 12 777
+# [saturn4] bash ./scripts-search/X-X/train-shapes.sh 27501-30000 12 777
+# [saturn4] bash ./scripts-search/X-X/train-shapes.sh 30001-32767 12 777
 #
+# CUDA_VISIBLE_DEVICES=2 bash ./scripts-search/X-X/train-shapes.sh 01000-03999,04050-05000,06000-09000,11000-14500,15000-18500,20000-23500,25000-27500,29000-30000 12 777
+# SLURM_PROCID=1 SLURM_NTASKS=5 bash ./scripts-search/X-X/train-shapes.sh 01000-03999,04050-05000,06000-09000,11000-14500,15000-18500,20000-23500,25000-27500,29000-30000 90 777
 echo script name: $0
 echo $# arguments
 if [ "$#" -ne 3 ] ;then