From d175a361bddc19717f269c32eb44951653357e6c Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Mon, 2 Dec 2019 18:03:40 +1100
Subject: [PATCH] update affines for NAS

---
 .gitignore                              |  1 +
 configs/nas-benchmark/LESS.config       |  2 +-
 exps/AA-NAS-Bench-main.py               | 53 +++++++++++++++++++------
 exps/AA_functions.py                    | 15 +++++--
 lib/models/cell_infers/cells.py         |  4 +-
 lib/models/cell_infers/tiny_network.py  |  2 +-
 lib/models/cell_operations.py           | 37 +++++++++--------
 lib/models/cell_searchs/search_cells.py |  4 +-
 scripts-search/AA-NAS-train-archs.sh    |  1 +
 9 files changed, 78 insertions(+), 41 deletions(-)

diff --git a/.gitignore b/.gitignore
index 3a2a7fb..2e838bb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -111,3 +111,4 @@ logs
 # snapshot
 a.pth
 cal-merge*.sh
+GPU-*.sh
diff --git a/configs/nas-benchmark/LESS.config b/configs/nas-benchmark/LESS.config
index 1e3e559..05e308e 100644
--- a/configs/nas-benchmark/LESS.config
+++ b/configs/nas-benchmark/LESS.config
@@ -1,7 +1,7 @@
 {
   "scheduler": ["str", "cos"],
   "eta_min"  : ["float", "0.0"],
-  "epochs"   : ["int", "10"],
+  "epochs"   : ["int", "12"],
   "warmup"   : ["int", "0"],
   "optim"    : ["str", "SGD"],
   "LR"       : ["float", "0.1"],
diff --git a/exps/AA-NAS-Bench-main.py b/exps/AA-NAS-Bench-main.py
index e5dee74..c1dd673 100644
--- a/exps/AA-NAS-Bench-main.py
+++ b/exps/AA-NAS-Bench-main.py
@@ -15,10 +15,10 @@ from procedures import get_machine_info
 from datasets import get_datasets
 from log_utils import Logger, AverageMeter, time_string, convert_secs2time
 from models import CellStructure, CellArchitectures, get_search_spaces
-from AA_functions import evaluate_for_seed
+from AA_functions_v2 import evaluate_for_seed
 
 
-def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger):
+def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
   machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
   all_infos = {'info': machine_info}
   all_dataset_keys = []
@@ -28,10 +28,12 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
     # load the configurature
     if dataset == 'cifar10' or dataset == 'cifar100':
-      config_path = 'configs/nas-benchmark/CIFAR.config'
+      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
+      else       : config_path = 'configs/nas-benchmark/CIFAR.config'
       split_info  = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
     elif dataset.startswith('ImageNet16'):
-      config_path = 'configs/nas-benchmark/ImageNet-16.config'
+      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
+      else       : config_path = 'configs/nas-benchmark/ImageNet-16.config'
       split_info  = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
     else:
       raise ValueError('invalid dataset : {:}'.format(dataset))
@@ -41,6 +43,8 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
                          logger)
     # check whether use splited validation set
     if bool(split):
+      assert dataset == 'cifar10'
+      ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
       assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid))
       train_data_v2 = deepcopy(train_data)
       train_data_v2.transform = valid_data.transform
@@ -48,23 +52,42 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
       # data loader
       train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True)
       valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True)
+      ValLoaders['x-valid'] = valid_loader
     else:
       # data loader
       train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
       valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
-
+      if dataset == 'cifar10':
+        ValLoaders = {'ori-test': valid_loader}
+      elif dataset == 'cifar100':
+        cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True),
+                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True)
+                     }
+      elif dataset == 'ImageNet16-120':
+        imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True),
+                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True)
+                     }
+      else:
+        raise ValueError('invalid dataset : {:}'.format(dataset))
+
     dataset_key = '{:}'.format(dataset)
     if bool(split): dataset_key = dataset_key + '-valid'
     logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
     logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
-    results = evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, seed, logger)
+    for key, value in ValLoaders.items():
+      logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
+    results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
     all_infos[dataset_key] = results
     all_dataset_keys.append( dataset_key )
   all_infos['all_dataset_keys'] = all_dataset_keys
   return all_infos
 
 
-def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
+def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
   assert torch.cuda.is_available(), 'CUDA is not available.'
   torch.backends.cudnn.enabled = True
   #torch.backends.cudnn.benchmark = True
@@ -73,7 +96,10 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
 
   assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)
 
-  sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
+  if use_less:
+    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
+  else:
+    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
   logger = Logger(str(sub_dir), 0, False)
 
   all_archs = meta_info['archs']
@@ -114,7 +140,7 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
         has_continue = True
         continue
       results = evaluate_all_datasets(CellStructure.str2structure(arch), \
-                                      datasets, xpaths, splits, seed, \
+                                      datasets, xpaths, splits, use_less, seed, \
                                       arch_config, workers, logger)
       torch.save(results, to_save_name)
       logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
@@ -130,7 +156,7 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
   logger.close()
 
 
-def train_single_model(save_dir, workers, datasets, xpaths, splits, seeds, model_str, arch_config):
+def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
   assert torch.cuda.is_available(), 'CUDA is not available.'
   torch.backends.cudnn.enabled = True
   torch.backends.cudnn.deterministic = True
@@ -160,7 +186,7 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, seeds, model
       checkpoint = torch.load(to_save_name)
     else:
       logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
-      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger)
+      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger)
       torch.save(checkpoint, to_save_name)
     # log information
     logger.log('{:}'.format(checkpoint['info']))
@@ -252,6 +278,7 @@ if __name__ == '__main__':
   parser.add_argument('--datasets',   type=str,   nargs='+', help='The applied datasets.')
   parser.add_argument('--xpaths',     type=str,   nargs='+', help='The root path for this dataset.')
   parser.add_argument('--splits',     type=int,   nargs='+', help='The root path for this dataset.')
+  parser.add_argument('--use_less',   type=int,   default=0, help='Using the less-training-epoch config.')
   parser.add_argument('--seeds'  ,    type=int,   nargs='+', help='The range of models to be evaluated')
   parser.add_argument('--channel',    type=int,              help='The number of channels.')
   parser.add_argument('--num_cells',  type=int,              help='The number of cells in one stage.')
@@ -264,7 +291,7 @@ if __name__ == '__main__':
   elif args.mode.startswith('specific'):
     assert len(args.mode.split('-')) == 2, 'invalid mode : {:}'.format(args.mode)
     model_str = args.mode.split('-')[1]
-    train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
+    train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
                        tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells})
   else:
     meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node)
@@ -276,7 +303,7 @@ if __name__ == '__main__':
 
     assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
     assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)
-    main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
+    main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
          tuple(args.srange), args.arch_index, tuple(args.seeds), \
          args.mode == 'cover', meta_info, \
          {'channel': args.channel, 'num_cells': args.num_cells})
diff --git a/exps/AA_functions.py b/exps/AA_functions.py
index a5253d0..3435b0c 100644
--- a/exps/AA_functions.py
+++ b/exps/AA_functions.py
@@ -47,6 +47,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
   elif mode == 'valid': network.eval()
   else: raise ValueError("The mode is not right : {:}".format(mode))
 
+  batch_time, end = AverageMeter(), time.time()
   for i, (inputs, targets) in enumerate(xloader):
     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
 
@@ -64,7 +65,10 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
     losses.update(loss.item(), inputs.size(0))
     top1.update (prec1.item(), inputs.size(0))
     top5.update (prec5.item(), inputs.size(0))
-  return losses.avg, top1.avg, top5.avg
+    # count time
+    batch_time.update(time.time() - end)
+    end = time.time()
+  return losses.avg, top1.avg, top5.avg, batch_time.sum
 
 
 
@@ -87,18 +91,21 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
   # start training
   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
   train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
+  train_times , valid_times = {}, {}
   for epoch in range(total_epoch):
     scheduler.update(epoch, 0.0)
 
-    train_loss, train_acc1, train_acc5 = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
+    train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
     with torch.no_grad():
-      valid_loss, valid_acc1, valid_acc5 = procedure(valid_loader, network, criterion, None, None, 'valid')
+      valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(valid_loader, network, criterion, None, None, 'valid')
     train_losses[epoch] = train_loss
     train_acc1es[epoch] = train_acc1
     train_acc5es[epoch] = train_acc5
     valid_losses[epoch] = valid_loss
     valid_acc1es[epoch] = valid_acc1
     valid_acc5es[epoch] = valid_acc5
+    train_times [epoch] = train_tm
+    valid_times [epoch] = valid_tm
 
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
@@ -114,9 +121,11 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
              'train_losses': train_losses,
              'train_acc1es': train_acc1es,
              'train_acc5es': train_acc5es,
+             'train_times' : train_times,
              'valid_losses': valid_losses,
             'valid_acc1es': valid_acc1es,
              'valid_acc5es': valid_acc5es,
+             'valid_times' : valid_times,
              'net_state_dict': net.state_dict(),
              'net_string' : '{:}'.format(net),
              'finish-train': True
diff --git a/lib/models/cell_infers/cells.py b/lib/models/cell_infers/cells.py
index 4cec78a..2071d5c 100644
--- a/lib/models/cell_infers/cells.py
+++ b/lib/models/cell_infers/cells.py
@@ -19,9 +19,9 @@ class InferCell(nn.Module):
       cur_innod = []
       for (op_name, op_in) in node_info:
         if op_in == 0:
-          layer = OPS[op_name](C_in , C_out, stride)
+          layer = OPS[op_name](C_in , C_out, stride, True)
         else:
-          layer = OPS[op_name](C_out, C_out, 1)
+          layer = OPS[op_name](C_out, C_out, 1, True)
         cur_index.append( len(self.layers) )
         cur_innod.append( op_in )
         self.layers.append( layer )
diff --git a/lib/models/cell_infers/tiny_network.py b/lib/models/cell_infers/tiny_network.py
index eb5c38c..818948c 100644
--- a/lib/models/cell_infers/tiny_network.py
+++ b/lib/models/cell_infers/tiny_network.py
@@ -22,7 +22,7 @@ class TinyNetwork(nn.Module):
     self.cells = nn.ModuleList()
     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
       if reduction:
-        cell = ResNetBasicblock(C_prev, C_curr, 2)
+        cell = ResNetBasicblock(C_prev, C_curr, 2, True)
       else:
         cell = InferCell(genotype, C_prev, C_curr, 1)
       self.cells.append( cell )
diff --git a/lib/models/cell_operations.py b/lib/models/cell_operations.py
index 4e28f56..5f39a99 100644
--- a/lib/models/cell_operations.py
+++ b/lib/models/cell_operations.py
@@ -4,16 +4,16 @@
 import torch
 import torch.nn as nn
 
-__all__ = ['OPS', 'ReLUConvBN', 'ResNetBasicblock', 'SearchSpaceNames']
+__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
 
 OPS = {
-  'none'         : lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
-  'avg_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'avg'),
-  'max_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'max'),
-  'nor_conv_7x7' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1)),
-  'nor_conv_3x3' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1)),
-  'nor_conv_1x1' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1)),
-  'skip_connect' : lambda C_in, C_out, stride: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride),
+  'none'         : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
+  'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
+  'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
+  'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine),
+  'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine),
+  'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine),
+  'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
 }
 
 CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
@@ -26,12 +26,12 @@ SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
 
 class ReLUConvBN(nn.Module):
 
-  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
+  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine):
     super(ReLUConvBN, self).__init__()
     self.op = nn.Sequential(
       nn.ReLU(inplace=False),
       nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
-      nn.BatchNorm2d(C_out)
+      nn.BatchNorm2d(C_out, affine=affine)
     )
 
   def forward(self, x):
@@ -40,17 +40,17 @@ class ReLUConvBN(nn.Module):
 
 class ResNetBasicblock(nn.Module):
 
-  def __init__(self, inplanes, planes, stride):
+  def __init__(self, inplanes, planes, stride, affine=True):
     super(ResNetBasicblock, self).__init__()
     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
-    self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
-    self.conv_b = ReLUConvBN(  planes, planes, 3,      1, 1, 1)
+    self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
+    self.conv_b = ReLUConvBN(  planes, planes, 3,      1, 1, 1, affine)
     if stride == 2:
       self.downsample = nn.Sequential(
         nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
         nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
     elif inplanes != planes:
-      self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
+      self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
     else:
       self.downsample = None
     self.in_dim = inplanes
@@ -76,12 +76,12 @@ class ResNetBasicblock(nn.Module):
 
 class POOLING(nn.Module):
 
-  def __init__(self, C_in, C_out, stride, mode):
+  def __init__(self, C_in, C_out, stride, mode, affine=True):
     super(POOLING, self).__init__()
     if C_in == C_out:
       self.preprocess = None
     else:
-      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0)
+      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine)
     if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
     elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
     else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
@@ -126,7 +126,7 @@ class Zero(nn.Module):
 
 class FactorizedReduce(nn.Module):
 
-  def __init__(self, C_in, C_out, stride):
+  def __init__(self, C_in, C_out, stride, affine):
     super(FactorizedReduce, self).__init__()
     self.stride = stride
     self.C_in = C_in
@@ -141,8 +141,7 @@ class FactorizedReduce(nn.Module):
       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
     else:
       raise ValueError('Invalid stride : {:}'.format(stride))
-
-    self.bn = nn.BatchNorm2d(C_out)
+    self.bn = nn.BatchNorm2d(C_out, affine=affine)
 
   def forward(self, x):
     x = self.relu(x)
diff --git a/lib/models/cell_searchs/search_cells.py b/lib/models/cell_searchs/search_cells.py
index d510ba1..f5af162 100644
--- a/lib/models/cell_searchs/search_cells.py
+++ b/lib/models/cell_searchs/search_cells.py
@@ -23,9 +23,9 @@ class SearchCell(nn.Module):
       for j in range(i):
         node_str = '{:}<-{:}'.format(i, j)
         if j == 0:
-          xlists = [OPS[op_name](C_in , C_out, stride) for op_name in op_names]
+          xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names]
         else:
-          xlists = [OPS[op_name](C_in , C_out, 1) for op_name in op_names]
+          xlists = [OPS[op_name](C_in , C_out, 1, False) for op_name in op_names]
         self.edges[ node_str ] = nn.ModuleList( xlists )
     self.edge_keys = sorted(list(self.edges.keys()))
     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
diff --git a/scripts-search/AA-NAS-train-archs.sh b/scripts-search/AA-NAS-train-archs.sh
index f5224e8..7739e89 100644
--- a/scripts-search/AA-NAS-train-archs.sh
+++ b/scripts-search/AA-NAS-train-archs.sh
@@ -29,6 +29,7 @@ save_dir=./output/AA-NAS-BENCH-4/
 
 OMP_NUM_THREADS=4 python ./exps/AA-NAS-Bench-main.py \
   --mode ${mode} --save_dir ${save_dir} --max_node 4 \
+  --use_less 0 \
   --datasets cifar10 cifar10 cifar100 ImageNet16-120 \
   --splits 1 0 0 0 \
   --xpaths $TORCH_HOME/cifar.python \
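
---
The notes below are an appendix for readers of this patch; they are illustrative sketches in plain PyTorch and are not part of the diff above.

(1) The point of threading `affine` through every operation is to control whether BatchNorm carries learnable scale/shift parameters: search cells (search_cells.py) now build their candidate operations with affine=False, while inferred networks (cell_infers/*) keep affine=True. A minimal sketch of the difference, using only standard PyTorch:

  import torch
  import torch.nn as nn

  bn_search = nn.BatchNorm2d(16, affine=False)  # no learnable parameters
  bn_infer  = nn.BatchNorm2d(16, affine=True)   # learnable weight and bias

  print(bn_search.weight is None)                       # True
  print(sum(p.numel() for p in bn_infer.parameters()))  # 32 = 16 weights + 16 biases

  x = torch.randn(2, 16, 8, 8)
  assert bn_search(x).shape == bn_infer(x).shape == x.shape

Disabling the affine transform during search keeps the shared batch-norm layers from absorbing architecture-specific scaling, while the stand-alone training of a final architecture keeps the extra capacity.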
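(2) After this patch every OPS factory takes (C_in, C_out, stride, affine), so all call sites must pass the extra flag. A self-contained sketch of the new calling convention, with `relu_conv_bn` as a stand-in for the repo's ReLUConvBN:

  import torch.nn as nn

  def relu_conv_bn(C_in, C_out, k, stride, pad, dilation, affine):
    # stand-in for lib/models/cell_operations.ReLUConvBN
    return nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, k, stride=stride, padding=pad, dilation=dilation, bias=False),
      nn.BatchNorm2d(C_out, affine=affine))

  OPS = {'nor_conv_3x3': lambda C_in, C_out, stride, affine:
           relu_conv_bn(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine)}

  op_search = OPS['nor_conv_3x3'](16, 16, 1, False)  # as in SearchCell
  op_infer  = OPS['nor_conv_3x3'](16, 16, 1, True)   # as in InferCell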
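(3) evaluate_all_datasets now hands evaluate_for_seed a dict of named validation loaders instead of a single loader, with 'x-valid'/'x-test' carving the original test set in two via SubsetRandomSampler. A sketch with synthetic data; the real index lists come from configs/nas-benchmark/*-test-split.txt, and the 50/50 split here is illustrative:

  import torch
  from torch.utils.data import DataLoader, TensorDataset
  from torch.utils.data.sampler import SubsetRandomSampler

  valid_data = TensorDataset(torch.randn(100, 3, 16, 16), torch.randint(0, 10, (100,)))
  xvalid, xtest = list(range(0, 50)), list(range(50, 100))  # illustrative split

  ValLoaders = {
    'ori-test': DataLoader(valid_data, batch_size=32, shuffle=False),
    'x-valid' : DataLoader(valid_data, batch_size=32, sampler=SubsetRandomSampler(xvalid)),
    'x-test'  : DataLoader(valid_data, batch_size=32, sampler=SubsetRandomSampler(xtest)),
  }
  for key, value in ValLoaders.items():
    print('{:10s} with {:} batches'.format(key, len(value)))  # 4 / 2 / 2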
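(4) procedure() now returns a fourth value, batch_time.sum, so evaluate_for_seed can record per-epoch wall-clock cost in train_times/valid_times. The pattern, sketched with a minimal stand-in for the AverageMeter the repo imports from log_utils:

  import time

  class AverageMeter:
    # minimal stand-in; the repo's meter also tracks avg, val, etc.
    def __init__(self):
      self.sum, self.count = 0.0, 0
    def update(self, val, n=1):
      self.sum += val * n
      self.count += n

  batch_time, end = AverageMeter(), time.time()
  for i in range(3):            # stands in for iterating over xloader
    time.sleep(0.01)            # stands in for the forward/backward pass
    batch_time.update(time.time() - end)
    end = time.time()
  print('wall-clock for this loader pass: {:.3f}s'.format(batch_time.sum))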
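(5) The new --use_less flag (0/1) switches training to LESS.config, whose epoch count this patch bumps from 10 to 12, for cheap proxy training; output directories gain a -LESS suffix so full and reduced runs never collide. The config selection logic, sketched as a hypothetical pure function:

  def pick_config(dataset, use_less):
    if dataset in ('cifar10', 'cifar100'):
      return 'configs/nas-benchmark/LESS.config' if use_less else 'configs/nas-benchmark/CIFAR.config'
    elif dataset.startswith('ImageNet16'):
      return 'configs/nas-benchmark/LESS.config' if use_less else 'configs/nas-benchmark/ImageNet-16.config'
    raise ValueError('invalid dataset : {:}'.format(dataset))

  assert pick_config('cifar10', True).endswith('LESS.config')
  assert pick_config('ImageNet16-120', False).endswith('ImageNet-16.config')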