update affines for NAS

.gitignore (vendored):
@@ -111,3 +111,4 @@ logs
 # snapshot
 a.pth
 cal-merge*.sh
+GPU-*.sh

@@ -1,7 +1,7 @@
 {
   "scheduler": ["str",   "cos"],
   "eta_min"  : ["float", "0.0"],
-  "epochs"   : ["int",   "10"],
+  "epochs"   : ["int",   "12"],
   "warmup"   : ["int",   "0"],
   "optim"    : ["str",   "SGD"],
   "LR"       : ["float", "0.1"],
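Each entry in this config pairs a type tag with a string value. As a rough illustration of how such a ["type", "value"] file could be parsed (load_typed_config is a hypothetical stand-in; the repository's actual load_config is not shown in this diff), consider:

    import json
    from collections import namedtuple

    def load_typed_config(path):
      # cast each ["type", "value"] pair to its named type
      casts = {'int': int, 'float': float, 'str': str, 'bool': lambda s: s.lower() == 'true'}
      with open(path) as f:
        raw = json.load(f)                     # e.g. {"epochs": ["int", "12"], ...}
      data = {name: casts[t](v) for name, (t, v) in raw.items()}
      return namedtuple('Config', data.keys())(**data)

    # cfg = load_typed_config('configs/nas-benchmark/LESS.config')
    # cfg.epochs -> 12, cfg.scheduler -> 'cos'
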
@@ -15,10 +15,10 @@ from procedures   import get_machine_info
 from datasets     import get_datasets
 from log_utils    import Logger, AverageMeter, time_string, convert_secs2time
 from models       import CellStructure, CellArchitectures, get_search_spaces
-from AA_functions import evaluate_for_seed
+from AA_functions_v2 import evaluate_for_seed


-def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger):
+def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
   machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
   all_infos = {'info': machine_info}
   all_dataset_keys = []
@@ -28,10 +28,12 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
     # load the configurature
     if dataset == 'cifar10' or dataset == 'cifar100':
-      config_path = 'configs/nas-benchmark/CIFAR.config'
+      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
+      else       : config_path = 'configs/nas-benchmark/CIFAR.config'
       split_info  = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
     elif dataset.startswith('ImageNet16'):
-      config_path = 'configs/nas-benchmark/ImageNet-16.config'
+      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
+      else       : config_path = 'configs/nas-benchmark/ImageNet-16.config'
       split_info  = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
     else:
       raise ValueError('invalid dataset : {:}'.format(dataset))
@@ -41,6 +43,8 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
                             logger)
     # check whether use splited validation set
     if bool(split):
+      assert dataset == 'cifar10'
+      ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
       assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid))
       train_data_v2 = deepcopy(train_data)
       train_data_v2.transform = valid_data.transform
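For context, the split branch above drives two loaders from one underlying dataset through index samplers. A minimal, self-contained sketch of that pattern (the indices below are made up; the real ones come from cifar-split.txt):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.sampler import SubsetRandomSampler

    data = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
    train_idx, valid_idx = list(range(80)), list(range(80, 100))  # made-up split indices
    train_loader = DataLoader(data, batch_size=16, sampler=SubsetRandomSampler(train_idx))
    valid_loader = DataLoader(data, batch_size=16, sampler=SubsetRandomSampler(valid_idx))
    # each loader now samples only its own index subset of the shared dataset
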
@@ -48,23 +52,42 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
       # data loader
       train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True)
       valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True)
+      ValLoaders['x-valid'] = valid_loader
     else:
       # data loader
       train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
       valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
+      if dataset == 'cifar10':
+        ValLoaders = {'ori-test': valid_loader}
+      elif dataset == 'cifar100':
+        cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True),
+                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True)
+                     }
+      elif dataset == 'ImageNet16-120':
+        imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True),
+                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True)
+                     }
+      else:
+        raise ValueError('invalid dataset : {:}'.format(dataset))
+
     dataset_key = '{:}'.format(dataset)
     if bool(split): dataset_key = dataset_key + '-valid'
     logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
     logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
-    results = evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, seed, logger)
+    for key, value in ValLoaders.items():
+      logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
+    results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
     all_infos[dataset_key] = results
     all_dataset_keys.append( dataset_key )
   all_infos['all_dataset_keys'] = all_dataset_keys
   return all_infos


-def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
+def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
   assert torch.cuda.is_available(), 'CUDA is not available.'
   torch.backends.cudnn.enabled   = True
   #torch.backends.cudnn.benchmark = True
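The net effect of this hunk is that evaluate_for_seed now receives a dict of named validation loaders ('ori-test', plus 'x-valid'/'x-test' where a split file exists) instead of a single loader. A sketch of how such a dict can be consumed, assuming network(inputs) returns logits (simplified relative to the repository's procedure):

    import torch

    def eval_all_splits(network, ValLoaders, criterion):
      # returns {split_name: (avg_loss, top1_accuracy)} over every named split
      network.eval()
      results = {}
      with torch.no_grad():
        for name, loader in ValLoaders.items():
          total_loss, correct, count = 0.0, 0, 0
          for inputs, targets in loader:
            logits = network(inputs)                      # assumed to return logits
            total_loss += criterion(logits, targets).item() * targets.size(0)
            correct    += (logits.argmax(dim=1) == targets).sum().item()
            count      += targets.size(0)
          results[name] = (total_loss / count, 100.0 * correct / count)
      return results
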
@@ -73,6 +96,9 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,

   assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)

-  sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
+  if use_less:
+    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
+  else:
+    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
   logger  = Logger(str(sub_dir), 0, False)

@@ -114,7 +140,7 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
           has_continue = True
           continue
       results = evaluate_all_datasets(CellStructure.str2structure(arch), \
-                                        datasets, xpaths, splits, seed, \
+                                        datasets, xpaths, splits, use_less, seed, \
                                         arch_config, workers, logger)
       torch.save(results, to_save_name)
       logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
@@ -130,7 +156,7 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
   logger.close()


-def train_single_model(save_dir, workers, datasets, xpaths, splits, seeds, model_str, arch_config):
+def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
   assert torch.cuda.is_available(), 'CUDA is not available.'
   torch.backends.cudnn.enabled   = True
   torch.backends.cudnn.deterministic = True
@@ -160,7 +186,7 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, seeds, model
       checkpoint = torch.load(to_save_name)
     else:
       logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
-      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger)
+      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger)
       torch.save(checkpoint, to_save_name)
     # log information
     logger.log('{:}'.format(checkpoint['info']))
@@ -252,6 +278,7 @@ if __name__ == '__main__':
   parser.add_argument('--datasets',    type=str,   nargs='+',      help='The applied datasets.')
   parser.add_argument('--xpaths',      type=str,   nargs='+',      help='The root path for this dataset.')
   parser.add_argument('--splits',      type=int,   nargs='+',      help='The root path for this dataset.')
+  parser.add_argument('--use_less',    type=int,   default=0,      help='Using the less-training-epoch config.')
   parser.add_argument('--seeds'  ,     type=int,   nargs='+',      help='The range of models to be evaluated')
   parser.add_argument('--channel',     type=int,                   help='The number of channels.')
   parser.add_argument('--num_cells',   type=int,                   help='The number of cells in one stage.')
@@ -264,7 +291,7 @@ if __name__ == '__main__':
   elif args.mode.startswith('specific'):
     assert len(args.mode.split('-')) == 2, 'invalid mode : {:}'.format(args.mode)
     model_str = args.mode.split('-')[1]
-    train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
+    train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
                          tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells})
   else:
     meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node)
@@ -276,7 +303,7 @@ if __name__ == '__main__':
     assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
     assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)

-    main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
+    main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
           tuple(args.srange), args.arch_index, tuple(args.seeds), \
           args.mode == 'cover', meta_info, \
           {'channel': args.channel, 'num_cells': args.num_cells})

@@ -47,6 +47,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
   elif mode == 'valid': network.eval()
   else: raise ValueError("The mode is not right : {:}".format(mode))

+  batch_time, end = AverageMeter(), time.time()
   for i, (inputs, targets) in enumerate(xloader):
     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))

@@ -64,7 +65,10 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
     losses.update(loss.item(),  inputs.size(0))
     top1.update  (prec1.item(), inputs.size(0))
     top5.update  (prec5.item(), inputs.size(0))
-  return losses.avg, top1.avg, top5.avg
+    # count time
+    batch_time.update(time.time() - end)
+    end = time.time()
+  return losses.avg, top1.avg, top5.avg, batch_time.sum

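procedure now returns a fourth value, the summed per-batch wall-clock time. The meter pattern it relies on is roughly the following (a common AverageMeter re-implementation, assumed equivalent to the one imported from log_utils):

    import time

    class AverageMeter:
      """Tracks the latest value plus a running sum, count, and average."""
      def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0
      def update(self, val, n=1):
        self.val    = val
        self.sum   += val * n
        self.count += n
        self.avg    = self.sum / self.count

    batch_time, end = AverageMeter(), time.time()
    for _ in range(5):
      time.sleep(0.01)                      # stand-in for one forward/backward pass
      batch_time.update(time.time() - end)  # record seconds spent on this batch
      end = time.time()
    print('total {:.3f}s over {:} batches'.format(batch_time.sum, batch_time.count))
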
@@ -87,18 +91,21 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
   # start training
   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
   train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
+  train_times , valid_times = {}, {}
   for epoch in range(total_epoch):
     scheduler.update(epoch, 0.0)

-    train_loss, train_acc1, train_acc5 = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
+    train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
     with torch.no_grad():
-      valid_loss, valid_acc1, valid_acc5 = procedure(valid_loader, network, criterion,      None,      None, 'valid')
+      valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(valid_loader, network, criterion,      None,      None, 'valid')
     train_losses[epoch] = train_loss
     train_acc1es[epoch] = train_acc1
     train_acc5es[epoch] = train_acc5
     valid_losses[epoch] = valid_loss
     valid_acc1es[epoch] = valid_acc1
     valid_acc5es[epoch] = valid_acc5
+    train_times [epoch] = train_tm
+    valid_times [epoch] = valid_tm
+
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
@@ -114,9 +121,11 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
               'train_losses': train_losses,
               'train_acc1es': train_acc1es,
               'train_acc5es': train_acc5es,
+              'train_times' : train_times,
               'valid_losses': valid_losses,
               'valid_acc1es': valid_acc1es,
               'valid_acc5es': valid_acc5es,
+              'valid_times' : valid_times,
               'net_state_dict': net.state_dict(),
               'net_string'  : '{:}'.format(net),
               'finish-train': True

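Since every checkpoint now stores per-epoch timing dicts, the training cost of a finished run can be recovered offline; a small sketch (the file name is illustrative, the keys follow this diff):

    import torch

    ckpt = torch.load('arch-000001-seed-0777.pth', map_location='cpu')  # illustrative path
    train_cost = sum(ckpt['train_times'].values())  # seconds, summed over epochs
    valid_cost = sum(ckpt['valid_times'].values())
    print('train: {:.1f}s  valid: {:.1f}s'.format(train_cost, valid_cost))
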
@@ -19,9 +19,9 @@ class InferCell(nn.Module):
       cur_innod = []
       for (op_name, op_in) in node_info:
         if op_in == 0:
-          layer = OPS[op_name](C_in , C_out, stride)
+          layer = OPS[op_name](C_in , C_out, stride, True)
         else:
-          layer = OPS[op_name](C_out, C_out,      1)
+          layer = OPS[op_name](C_out, C_out,      1, True)
         cur_index.append( len(self.layers) )
         cur_innod.append( op_in )
         self.layers.append( layer )

@@ -22,7 +22,7 @@ class TinyNetwork(nn.Module):
     self.cells = nn.ModuleList()
     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
       if reduction:
-        cell = ResNetBasicblock(C_prev, C_curr, 2)
+        cell = ResNetBasicblock(C_prev, C_curr, 2, True)
       else:
         cell = InferCell(genotype, C_prev, C_curr, 1)
       self.cells.append( cell )

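TinyNetwork alternates stages of normal cells with stride-2 ResNetBasicblock reductions. The zip above usually iterates a schedule of this shape (a sketch of the conventional three-stage layout; the schedule itself is not shown in this diff):

    C, N = 16, 5  # base channel count and cells per stage (typical values)
    layer_channels   = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
    # zip(...) then yields N normal cells at C, a reduction to 2C, and so on
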
@@ -4,16 +4,16 @@
 import torch
 import torch.nn as nn

-__all__ = ['OPS', 'ReLUConvBN', 'ResNetBasicblock', 'SearchSpaceNames']
+__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']

 OPS = {
-  'none'         : lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
-  'avg_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'avg'),
-  'max_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'max'),
-  'nor_conv_7x7' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1)),
-  'nor_conv_3x3' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1)),
-  'nor_conv_1x1' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1)),
-  'skip_connect' : lambda C_in, C_out, stride: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride),
+  'none'         : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
+  'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
+  'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
+  'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine),
+  'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine),
+  'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine),
+  'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
 }

 CONNECT_NAS_BENCHMARK  = ['none', 'skip_connect', 'nor_conv_3x3']
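All OPS constructors now share the signature (C_in, C_out, stride, affine), even where the flag is unused (the lambdas for Zero and the pooling ops drop it), so callers can build any candidate operation uniformly. What the flag toggles is standard PyTorch BatchNorm behavior:

    import torch.nn as nn

    bn_learn = nn.BatchNorm2d(16, affine=True)   # learnable per-channel weight and bias
    bn_fixed = nn.BatchNorm2d(16, affine=False)  # normalization only, no parameters
    print(bn_learn.weight is not None)  # True
    print(bn_fixed.weight is None)      # True
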
@@ -26,12 +26,12 @@ SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,

 class ReLUConvBN(nn.Module):

-  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
+  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine):
     super(ReLUConvBN, self).__init__()
     self.op = nn.Sequential(
       nn.ReLU(inplace=False),
       nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
-      nn.BatchNorm2d(C_out)
+      nn.BatchNorm2d(C_out, affine=affine)
     )

   def forward(self, x):
@@ -40,17 +40,17 @@ class ReLUConvBN(nn.Module):

 class ResNetBasicblock(nn.Module):

-  def __init__(self, inplanes, planes, stride):
+  def __init__(self, inplanes, planes, stride, affine=True):
     super(ResNetBasicblock, self).__init__()
     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
-    self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
-    self.conv_b = ReLUConvBN(  planes, planes, 3,      1, 1, 1)
+    self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
+    self.conv_b = ReLUConvBN(  planes, planes, 3,      1, 1, 1, affine)
     if stride == 2:
       self.downsample = nn.Sequential(
                            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                            nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
     elif inplanes != planes:
-      self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
+      self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
     else:
       self.downsample = None
     self.in_dim  = inplanes
@@ -76,12 +76,12 @@ class ResNetBasicblock(nn.Module):

 class POOLING(nn.Module):

-  def __init__(self, C_in, C_out, stride, mode):
+  def __init__(self, C_in, C_out, stride, mode, affine=True):
     super(POOLING, self).__init__()
     if C_in == C_out:
       self.preprocess = None
     else:
-      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0)
+      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine)
     if mode == 'avg'  : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
     elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
     else              : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
@@ -126,7 +126,7 @@ class Zero(nn.Module):

 class FactorizedReduce(nn.Module):

-  def __init__(self, C_in, C_out, stride):
+  def __init__(self, C_in, C_out, stride, affine):
     super(FactorizedReduce, self).__init__()
     self.stride = stride
     self.C_in   = C_in
@@ -141,8 +141,7 @@ class FactorizedReduce(nn.Module):
       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
     else:
       raise ValueError('Invalid stride : {:}'.format(stride))
-
-    self.bn = nn.BatchNorm2d(C_out)
+    self.bn = nn.BatchNorm2d(C_out, affine=affine)

   def forward(self, x):
     x = self.relu(x)

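Consistent with the fragments above (the relu, pad, and bn members), a DARTS-style factorized reduce splits the output channels across two stride-2 1x1 convolutions sampled at offset pixels, then concatenates and normalizes. A sketch under that assumption, not the repository's exact code:

    import torch
    import torch.nn as nn

    class FactorizedReduceSketch(nn.Module):
      # Halves spatial size while mixing the two pixel phases lost by stride 2.
      def __init__(self, C_in, C_out, affine):
        super().__init__()
        self.relu   = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out - C_out // 2, 1, stride=2, padding=0, bias=False)
        self.pad    = nn.ConstantPad2d((0, 1, 0, 1), 0)  # shift input by one pixel
        self.bn     = nn.BatchNorm2d(C_out, affine=affine)

      def forward(self, x):
        x = self.relu(x)
        y = self.pad(x)
        out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 1:, 1:])], dim=1)
        return self.bn(out)
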
@@ -23,9 +23,9 @@ class SearchCell(nn.Module):
       for j in range(i):
         node_str = '{:}<-{:}'.format(i, j)
         if j == 0:
-          xlists = [OPS[op_name](C_in , C_out, stride) for op_name in op_names]
+          xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names]
         else:
-          xlists = [OPS[op_name](C_in , C_out,      1) for op_name in op_names]
+          xlists = [OPS[op_name](C_in , C_out,      1, False) for op_name in op_names]
         self.edges[ node_str ] = nn.ModuleList( xlists )
     self.edge_keys  = sorted(list(self.edges.keys()))
     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
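Note the contrast: the search cell builds every candidate operation with affine=False, while InferCell and TinyNetwork above pass True. A common rationale is to keep BatchNorm free of learnable scales while weights are shared across candidate ops during search, then re-enable the affine parameters when the selected architecture is retrained from scratch. Side by side (assuming OPS is imported from the operations module shown earlier):

    # search phase: shared-weight candidates, BN without learnable affine
    search_op = OPS['nor_conv_3x3'](16, 16, 1, False)
    # evaluation phase: the discovered cell trains BN affine parameters
    final_op  = OPS['nor_conv_3x3'](16, 16, 1, True)
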
@@ -29,6 +29,7 @@ save_dir=./output/AA-NAS-BENCH-4/

 OMP_NUM_THREADS=4 python ./exps/AA-NAS-Bench-main.py \
 	--mode ${mode} --save_dir ${save_dir} --max_node 4 \
+	--use_less 0 \
 	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
 	--splits   1       0       0        0 \
 	--xpaths $TORCH_HOME/cifar.python \