diff --git a/.gitignore b/.gitignore index 9acd847..7c47a90 100755 --- a/.gitignore +++ b/.gitignore @@ -102,3 +102,4 @@ main_main.py # Device scripts-nas/.nfs00* */.nfs00* +*.DS_Store diff --git a/README.md b/README.md index 7c1984f..6ac42aa 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,16 @@ -# GDAS -By Xuanyi Dong and Yi Yang +# Searching for A Robust Neural Architecture in Four GPU Hours -University of Technology Sydney +We propose A Gradient-based neural architecture search approach using Differentiable Architecture Sampler (GDAS). -Requirements -- PyTorch 1.0 +## Requirements +- PyTorch 1.0.1 - Python 3.6 - opencv ``` conda install pytorch torchvision cuda100 -c pytorch ``` -## Algorithm +## Usages Train the searched CNN on CIFAR ``` @@ -26,6 +25,11 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-imagenet.sh GDAS_F1 52 14 CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-imagenet.sh GDAS_V1 50 14 ``` +Evaluate a trained CNN model +``` +CUDA_VISIBLE_DEVICES=0 python ./exps-cnn/evaluate.py --data_path $TORCH_HOME/cifar.python --checkpoint ${checkpoint-path} +CUDA_VISIBLE_DEVICES=0 python ./exps-cnn/evaluate.py --data_path $TORCH_HOME/ILSVRC2012 --checkpoint ${checkpoint-path} +``` Train the searched RNN ``` @@ -36,3 +40,13 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-WT2.sh DARTS_V1 CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-WT2.sh DARTS_V2 CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-WT2.sh GDAS ``` + +## Citation +``` +@inproceedings{dong2019search, + title={Searching for A Robust Neural Architecture in Four GPU Hours}, + author={Dong, Xuanyi and Yang, Yi}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2019} +} +``` diff --git a/exps-cnn/evaluate.py b/exps-cnn/evaluate.py new file mode 100644 index 0000000..be59b44 --- /dev/null +++ b/exps-cnn/evaluate.py @@ -0,0 +1,49 @@ +import os, sys, time, glob, random, argparse +import numpy as np +from copy import deepcopy +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.datasets as dset +import torch.backends.cudnn as cudnn +import torchvision.transforms as transforms +from pathlib import Path +lib_dir = (Path(__file__).parent / '..' 
/ 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from utils import AverageMeter, time_string, convert_secs2time +from utils import print_log, obtain_accuracy +from utils import Cutout, count_parameters_in_MB +from nas import model_types as models +from train_utils import main_procedure +from train_utils_imagenet import main_procedure_imagenet +from scheduler import load_config + + +parser = argparse.ArgumentParser("Evaluate-CNN") +parser.add_argument('--data_path', type=str, help='Path to dataset.') +parser.add_argument('--checkpoint', type=str, help='Choose between Cifar10/100 and ImageNet.') +args = parser.parse_args() + +assert torch.cuda.is_available(), 'torch.cuda is not available' + + +def main(): + + assert os.path.isdir( args.data_path ), 'invalid data-path : {:}'.format(args.data_path) + assert os.path.isfile( args.checkpoint ), 'invalid checkpoint : {:}'.format(args.checkpoint) + + checkpoint = torch.load( args.checkpoint ) + xargs = checkpoint['args'] + config = load_config(xargs.model_config) + genotype = models[xargs.arch] + + # clear GPU cache + torch.cuda.empty_cache() + if xargs.dataset == 'imagenet': + main_procedure_imagenet(config, args.data_path, xargs, genotype, xargs.init_channels, xargs.layers, checkpoint['state_dict'], None) + else: + main_procedure(config, xargs.dataset, args.data_path, xargs, genotype, xargs.init_channels, xargs.layers, checkpoint['state_dict'], None) + + +if __name__ == '__main__': + main() diff --git a/exps-cnn/train_base.py b/exps-cnn/train_base.py index 18e8379..d25d0a0 100644 --- a/exps-cnn/train_base.py +++ b/exps-cnn/train_base.py @@ -19,7 +19,7 @@ from train_utils_imagenet import main_procedure_imagenet from scheduler import load_config -parser = argparse.ArgumentParser("cifar") +parser = argparse.ArgumentParser("Train-CNN") parser.add_argument('--data_path', type=str, help='Path to dataset') parser.add_argument('--dataset', type=str, choices=['imagenet', 'cifar10', 'cifar100'], help='Choose between Cifar10/100 and ImageNet.') parser.add_argument('--arch', type=str, choices=models.keys(), help='the searched model.') @@ -38,6 +38,7 @@ args = parser.parse_args() assert torch.cuda.is_available(), 'torch.cuda is not available' + if args.manualSeed is None: args.manualSeed = random.randint(1, 10000) random.seed(args.manualSeed) @@ -72,9 +73,9 @@ def main(): # clear GPU cache torch.cuda.empty_cache() if args.dataset == 'imagenet': - main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, log) + main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, None, log) else: - main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, log) + main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, None, log) log.close() diff --git a/exps-cnn/train_utils.py b/exps-cnn/train_utils.py index c5e29fb..cdb0efe 100644 --- a/exps-cnn/train_utils.py +++ b/exps-cnn/train_utils.py @@ -2,7 +2,7 @@ import os, sys, time from copy import deepcopy import torch import torchvision.transforms as transforms - +from shutil import copyfile from utils import print_log, obtain_accuracy, AverageMeter from utils import time_string, convert_secs2time @@ -11,6 +11,7 @@ from utils import Cutout from nas import NetworkCIFAR as Network from datasets import get_datasets + def obtain_best(accuracies): if len(accuracies) == 0: return (0, 0) tops = [value for key, value in 
accuracies.items()] @@ -18,7 +19,7 @@ def obtain_best(accuracies): return s2b[-1] -def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, log): +def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, pure_evaluate, log): train_data, test_data, class_num = get_datasets(dataset, data_path, config.cutout) @@ -57,10 +58,17 @@ def main_procedure(config, dataset, data_path, args, genotype, init_channels, la else: raise ValueError('Can not find the schedular type : {:}'.format(config.type)) - checkpoint_path = os.path.join(args.save_path, 'checkpoint-{:}-model.pth'.format(dataset)) - if os.path.isfile(checkpoint_path): - checkpoint = torch.load( checkpoint_path ) + checkpoint_path = os.path.join(args.save_path, 'checkpoint-{:}-model.pth'.format(dataset)) + checkpoint_best = os.path.join(args.save_path, 'checkpoint-{:}-best.pth'.format(dataset)) + if pure_evaluate: + print_log('-'*20 + 'Pure Evaluation' + '-'*20, log) + basemodel.load_state_dict( pure_evaluate ) + with torch.no_grad(): + valid_acc1, valid_acc5, valid_los = _train(test_loader, model, criterion, optimizer, 'test', -1, config, args.print_freq, log) + return (valid_acc1, valid_acc5) + elif os.path.isfile(checkpoint_path): + checkpoint = torch.load( checkpoint_path ) start_epoch = checkpoint['epoch'] basemodel.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) @@ -96,12 +104,14 @@ def main_procedure(config, dataset, data_path, args, genotype, init_channels, la 'accuracies': accuracies}, checkpoint_path) best_acc = obtain_best( accuracies ) + if accuracies[epoch] == best_acc: copyfile(checkpoint_path, checkpoint_best) print_log('----> Best Accuracy : Acc@1={:.2f}, Acc@5={:.2f}, Error@1={:.2f}, Error@5={:.2f}'.format(best_acc[0], best_acc[1], 100-best_acc[0], 100-best_acc[1]), log) print_log('----> Save into {:}'.format(checkpoint_path), log) # measure elapsed time epoch_time.update(time.time() - start_time) start_time = time.time() + return obtain_best( accuracies ) def _train(xloader, model, criterion, optimizer, mode, epoch, config, print_freq, log): diff --git a/exps-cnn/train_utils_imagenet.py b/exps-cnn/train_utils_imagenet.py index 76fdd42..24357cc 100644 --- a/exps-cnn/train_utils_imagenet.py +++ b/exps-cnn/train_utils_imagenet.py @@ -3,7 +3,7 @@ from copy import deepcopy import torch import torch.nn as nn import torchvision.transforms as transforms - +from shutil import copyfile from utils import print_log, obtain_accuracy, AverageMeter from utils import time_string, convert_secs2time @@ -37,7 +37,7 @@ class CrossEntropyLabelSmooth(nn.Module): return loss -def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, log): +def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, pure_evaluate, log): # training data and testing data train_data, valid_data, class_num = get_datasets('imagenet-1k', data_path, -1) @@ -48,8 +48,6 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la valid_queue = torch.utils.data.DataLoader( valid_data, batch_size=config.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers) - class_num = 1000 - print_log('-------------------------------------- main-procedure', log) print_log('config : {:}'.format(config), log) print_log('genotype : {:}'.format(genotype), log) @@ -84,9 +82,16 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la checkpoint_path = 
os.path.join(args.save_path, 'checkpoint-imagenet-model.pth') - if os.path.isfile(checkpoint_path): - checkpoint = torch.load( checkpoint_path ) + checkpoint_best = os.path.join(args.save_path, 'checkpoint-imagenet-best.pth') + if pure_evaluate: + print_log('-'*20 + 'Pure Evaluation' + '-'*20, log) + basemodel.load_state_dict( pure_evaluate ) + with torch.no_grad(): + valid_acc1, valid_acc5, valid_los = _train(valid_queue, model, criterion, None, 'test' , -1, config, args.print_freq, log) + return (valid_acc1, valid_acc5) + elif os.path.isfile(checkpoint_path): + checkpoint = torch.load( checkpoint_path ) start_epoch = checkpoint['epoch'] basemodel.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) @@ -122,12 +127,14 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la 'accuracies': accuracies}, checkpoint_path) best_acc = obtain_best( accuracies ) + if accuracies[epoch] == best_acc: copyfile(checkpoint_path, checkpoint_best) print_log('----> Best Accuracy : Acc@1={:.2f}, Acc@5={:.2f}, Error@1={:.2f}, Error@5={:.2f}'.format(best_acc[0], best_acc[1], 100-best_acc[0], 100-best_acc[1]), log) print_log('----> Save into {:}'.format(checkpoint_path), log) # measure elapsed time epoch_time.update(time.time() - start_time) start_time = time.time() + return obtain_best( accuracies ) def _train(xloader, model, criterion, optimizer, mode, epoch, config, print_freq, log): diff --git a/lib/datasets/get_dataset_with_transform.py b/lib/datasets/get_dataset_with_transform.py index 6112b7a..4e5e0b6 100644 --- a/lib/datasets/get_dataset_with_transform.py +++ b/lib/datasets/get_dataset_with_transform.py @@ -7,6 +7,7 @@ import torchvision.transforms as transforms from utils import Cutout from .TieredImageNet import TieredImageNet + Dataset2Class = {'cifar10' : 10, 'cifar100': 100, 'tiered' : -1, @@ -59,11 +60,11 @@ def get_datasets(name, root, cutout): else: raise TypeError("Unknow dataset : {:}".format(name)) if name == 'cifar10': - train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True) - test_data = dset.CIFAR10(root, train=True, transform=test_transform , download=True) + train_data = dset.CIFAR10(root, train=True , transform=train_transform, download=True) + test_data = dset.CIFAR10(root, train=False, transform=test_transform , download=True) elif name == 'cifar100': - train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True) - test_data = dset.CIFAR100(root, train=True, transform=test_transform , download=True) + train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) + test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True) elif name == 'imagenet-1k' or name == 'imagenet-100': train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) test_data = dset.ImageFolder(osp.join(root, 'val'), train_transform) diff --git a/lib/nas/__init__.py b/lib/nas/__init__.py index e092e54..fd347f4 100644 --- a/lib/nas/__init__.py +++ b/lib/nas/__init__.py @@ -1,12 +1,5 @@ from .model_search import Network -from .model_search_v1 import NetworkV1 -from .model_search_f1 import NetworkF1 # acceleration model -from .model_search_f1_acc2 import NetworkFACC1 -from .model_search_acc2 import NetworkACC2 -from .model_search_v3 import NetworkV3 -from .model_search_v4 import NetworkV4 -from .model_search_v5 import NetworkV5 from .CifarNet import NetworkCIFAR from .ImageNet import NetworkImageNet diff --git 
a/lib/nas/construct_utils.py b/lib/nas/construct_utils.py index f7a0a67..0207e1b 100644 --- a/lib/nas/construct_utils.py +++ b/lib/nas/construct_utils.py @@ -128,7 +128,7 @@ class Transition(nn.Module): self.ops2 = nn.ModuleList( [nn.Sequential( - nn.MaxPool2d(3, stride=1, padding=1), + nn.MaxPool2d(3, stride=2, padding=1), nn.BatchNorm2d(C, affine=True)), nn.Sequential( nn.MaxPool2d(3, stride=2, padding=1), @@ -144,7 +144,8 @@ class Transition(nn.Module): if self.training and drop_prob > 0.: X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob) - X2 = self.ops2[0] (X0+X1) + #X2 = self.ops2[0] (X0+X1) + X2 = self.ops2[0] (s0) X3 = self.ops2[1] (s1) if self.training and drop_prob > 0.: X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob) diff --git a/lib/nas/model_search_acc2.py b/lib/nas/model_search_acc2.py deleted file mode 100644 index 9a1ca90..0000000 --- a/lib/nas/model_search_acc2.py +++ /dev/null @@ -1,180 +0,0 @@ -# gumbel softmax -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.parameter import Parameter - -from .operations import OPS, FactorizedReduce, ReLUConvBN -from .genotypes import PRIMITIVES, Genotype - - -class MixedOp(nn.Module): - - def __init__(self, C, stride): - super(MixedOp, self).__init__() - self._ops = nn.ModuleList() - for primitive in PRIMITIVES: - op = OPS[primitive](C, stride, False) - self._ops.append(op) - - def forward(self, x, weights, cpu_weights): - use_sum = sum([abs(_) > 1e-10 for _ in cpu_weights]) - if use_sum > 3: - return sum(w * op(x) for w, op in zip(weights, self._ops)) - else: - clist = [] - for j, cpu_weight in enumerate(cpu_weights): - if abs(cpu_weight) > 1e-10: - clist.append( weights[j] * self._ops[j](x) ) - assert len(clist) > 0, 'invalid length : {:}'.format(cpu_weights) - return sum(clist) - - -class Cell(nn.Module): - - def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): - super(Cell, self).__init__() - self.reduction = reduction - - if reduction_prev: - self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False) - else: - self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) - self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) - self._steps = steps - self._multiplier = multiplier - - self._ops = nn.ModuleList() - for i in range(self._steps): - for j in range(2+i): - stride = 2 if reduction and j < 2 else 1 - op = MixedOp(C, stride) - self._ops.append(op) - - def forward(self, s0, s1, weights): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) - - cpu_weights = weights.tolist() - states = [s0, s1] - offset = 0 - for i in range(self._steps): - clist = [] - for j, h in enumerate(states): - x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j]) - clist.append( x ) - s = sum(clist) - offset += len(states) - states.append(s) - - return torch.cat(states[-self._multiplier:], dim=1) - - -class NetworkACC2(nn.Module): - - def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3): - super(NetworkACC2, self).__init__() - self._C = C - self._num_classes = num_classes - self._layers = layers - self._steps = steps - self._multiplier = multiplier - - C_curr = stem_multiplier*C - self.stem = nn.Sequential( - nn.Conv2d(3, C_curr, 3, padding=1, bias=False), - nn.BatchNorm2d(C_curr) - ) - - C_prev_prev, C_prev, C_curr = C_curr, C_curr, C - reduction_prev, cells = False, [] - for i in range(layers): - if i in [layers//3, 2*layers//3]: - C_curr *= 2 - reduction = True - else: - reduction = 
False - cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) - reduction_prev = reduction - cells.append( cell ) - C_prev_prev, C_prev = C_prev, multiplier*C_curr - self.cells = nn.ModuleList(cells) - - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.tau = 5 - self.use_gumbel = True - - # initialize architecture parameters - k = sum(1 for i in range(self._steps) for n in range(2+i)) - num_ops = len(PRIMITIVES) - - self.alphas_normal = Parameter(torch.Tensor(k, num_ops)) - self.alphas_reduce = Parameter(torch.Tensor(k, num_ops)) - nn.init.normal_(self.alphas_normal, 0, 0.001) - nn.init.normal_(self.alphas_reduce, 0, 0.001) - - def set_gumbel(self, use_gumbel): - self.use_gumbel = use_gumbel - - def set_tau(self, tau): - self.tau = tau - - def get_tau(self): - return self.tau - - def arch_parameters(self): - return [self.alphas_normal, self.alphas_reduce] - - def base_parameters(self): - lists = list(self.stem.parameters()) + list(self.cells.parameters()) - lists += list(self.global_pooling.parameters()) - lists += list(self.classifier.parameters()) - return lists - - def forward(self, inputs): - batch, C, H, W = inputs.size() - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: - if self.use_gumbel : weights = F.gumbel_softmax(self.alphas_reduce, self.tau, True) - else : weights = F.softmax(self.alphas_reduce, dim=-1) - else: - if self.use_gumbel : weights = F.gumbel_softmax(self.alphas_normal, self.tau, True) - else : weights = F.softmax(self.alphas_normal, dim=-1) - - s0, s1 = s1, cell(s0, s1, weights) - out = self.global_pooling(s1) - out = out.view(batch, -1) - logits = self.classifier(out) - return logits - - def genotype(self): - - def _parse(weights): - gene, n, start = [], 2, 0 - for i in range(self._steps): - end = start + n - W = weights[start:end].copy() - edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2] - for j in edges: - k_best = None - for k in range(len(W[j])): - if k != PRIMITIVES.index('none'): - if k_best is None or W[j][k] > W[j][k_best]: - k_best = k - gene.append((PRIMITIVES[k_best], j, float(W[j][k_best]))) - start = end - n += 1 - return gene - - with torch.no_grad(): - gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy()) - - concat = range(2+self._steps-self._multiplier, self._steps+2) - genotype = Genotype( - normal=gene_normal, normal_concat=concat, - reduce=gene_reduce, reduce_concat=concat - ) - return genotype diff --git a/lib/nas/model_search_f1.py b/lib/nas/model_search_f1.py deleted file mode 100644 index 198dfdb..0000000 --- a/lib/nas/model_search_f1.py +++ /dev/null @@ -1,167 +0,0 @@ -# share parameters -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.parameter import Parameter - -from .operations import OPS, FactorizedReduce, ReLUConvBN -from .construct_utils import Transition -from .genotypes import PRIMITIVES, Genotype - - -class MixedOp(nn.Module): - - def __init__(self, C, stride): - super(MixedOp, self).__init__() - self._ops = nn.ModuleList() - for primitive in PRIMITIVES: - op = OPS[primitive](C, stride, False) - self._ops.append(op) - - def forward(self, x, weights): - return sum(w * op(x) for w, op in zip(weights, self._ops)) - - -class Cell(nn.Module): - - def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, 
reduction, reduction_prev): - super(Cell, self).__init__() - self.reduction = reduction - - if reduction_prev: - self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False) - else: - self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) - self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) - self._steps = steps - self._multiplier = multiplier - - self._ops = nn.ModuleList() - for i in range(self._steps): - for j in range(2+i): - stride = 2 if reduction and j < 2 else 1 - op = MixedOp(C, stride) - self._ops.append(op) - - def forward(self, s0, s1, weights): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) - - states = [s0, s1] - offset = 0 - for i in range(self._steps): - clist = [] - for j, h in enumerate(states): - x = self._ops[offset+j](h, weights[offset+j]) - clist.append( x ) - s = sum(clist) - offset += len(states) - states.append(s) - - return torch.cat(states[-self._multiplier:], dim=1) - - -class NetworkF1(nn.Module): - - def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3): - super(NetworkF1, self).__init__() - self._C = C - self._num_classes = num_classes - self._layers = layers - self._steps = steps - self._multiplier = multiplier - - C_curr = stem_multiplier*C - self.stem = nn.Sequential( - nn.Conv2d(3, C_curr, 3, padding=1, bias=False), - nn.BatchNorm2d(C_curr) - ) - - C_prev_prev, C_prev, C_curr = C_curr, C_curr, C - reduction_prev, cells = False, [] - for i in range(layers): - if i in [layers//3, 2*layers//3]: - C_curr *= 2 - reduction = True - else: - reduction = False - if reduction: - cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev, multiplier) - else: - cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) - reduction_prev = reduction - cells.append( cell ) - C_prev_prev, C_prev = C_prev, multiplier*C_curr - self.cells = nn.ModuleList(cells) - - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - - # initialize architecture parameters - k = sum(1 for i in range(self._steps) for n in range(2+i)) - num_ops = len(PRIMITIVES) - - self.alphas_normal = Parameter(torch.Tensor(k, num_ops)) - #self.alphas_reduce = Parameter(torch.Tensor(k, num_ops)) - nn.init.normal_(self.alphas_normal, 0, 0.001) - #nn.init.normal_(self.alphas_reduce, 0, 0.001) - - def set_tau(self, tau): - return -1 - - def get_tau(self): - return -1 - - def arch_parameters(self): - return [self.alphas_normal] - - def base_parameters(self): - lists = list(self.stem.parameters()) + list(self.cells.parameters()) - lists += list(self.global_pooling.parameters()) - lists += list(self.classifier.parameters()) - return lists - - def forward(self, inputs): - batch, C, H, W = inputs.size() - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: - s0, s1 = s1, cell(s0, s1) - else: - weights = F.softmax(self.alphas_normal, dim=-1) - s0, s1 = s1, cell(s0, s1, weights) - #print('{:} : s0 : {:}, s1 : {:}'.format(i, s0.size(), s1.size())) - out = self.global_pooling(s1) - out = out.view(batch, -1) - logits = self.classifier(out) - return logits - - def genotype(self): - - def _parse(weights): - gene, n, start = [], 2, 0 - for i in range(self._steps): - end = start + n - W = weights[start:end].copy() - edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2] - for j in edges: - k_best = None - for k in range(len(W[j])): - if k != PRIMITIVES.index('none'): - if 
k_best is None or W[j][k] > W[j][k_best]: - k_best = k - gene.append((PRIMITIVES[k_best], j, float(W[j][k_best]))) - start = end - n += 1 - return gene - - with torch.no_grad(): - gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - #gene_reduce = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - - concat = range(2+self._steps-self._multiplier, self._steps+2) - genotype = Genotype( - normal=gene_normal, normal_concat=concat, - reduce=None , reduce_concat=concat - ) - return genotype diff --git a/lib/nas/model_search_f1_acc2.py b/lib/nas/model_search_f1_acc2.py deleted file mode 100644 index 99e2e85..0000000 --- a/lib/nas/model_search_f1_acc2.py +++ /dev/null @@ -1,183 +0,0 @@ -# share parameters -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.parameter import Parameter - -from .operations import OPS, FactorizedReduce, ReLUConvBN -from .construct_utils import Transition -from .genotypes import PRIMITIVES, Genotype - - -class MixedOp(nn.Module): - - def __init__(self, C, stride): - super(MixedOp, self).__init__() - self._ops = nn.ModuleList() - for primitive in PRIMITIVES: - op = OPS[primitive](C, stride, False) - self._ops.append(op) - - def forward(self, x, weights, cpu_weights): - use_sum = sum([abs(_) > 1e-10 for _ in cpu_weights]) - if use_sum > 3: - return sum(w * op(x) for w, op in zip(weights, self._ops)) - else: - clist = [] - for j, cpu_weight in enumerate(cpu_weights): - if abs(cpu_weight) > 1e-10: - clist.append( weights[j] * self._ops[j](x) ) - assert len(clist) > 0, 'invalid length : {:}'.format(cpu_weights) - return sum(clist) - - -class Cell(nn.Module): - - def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): - super(Cell, self).__init__() - self.reduction = reduction - - if reduction_prev: - self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False) - else: - self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) - self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) - self._steps = steps - self._multiplier = multiplier - - self._ops = nn.ModuleList() - for i in range(self._steps): - for j in range(2+i): - stride = 2 if reduction and j < 2 else 1 - op = MixedOp(C, stride) - self._ops.append(op) - - def forward(self, s0, s1, weights): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) - - cpu_weights = weights.tolist() - states = [s0, s1] - offset = 0 - for i in range(self._steps): - clist = [] - for j, h in enumerate(states): - x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j]) - clist.append( x ) - s = sum(clist) - offset += len(states) - states.append(s) - - return torch.cat(states[-self._multiplier:], dim=1) - - -class NetworkFACC1(nn.Module): - - def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3): - super(NetworkFACC1, self).__init__() - self._C = C - self._num_classes = num_classes - self._layers = layers - self._steps = steps - self._multiplier = multiplier - self.tau = 5 - self.use_gumbel = True - - C_curr = stem_multiplier*C - self.stem = nn.Sequential( - nn.Conv2d(3, C_curr, 3, padding=1, bias=False), - nn.BatchNorm2d(C_curr) - ) - - C_prev_prev, C_prev, C_curr = C_curr, C_curr, C - reduction_prev, cells = False, [] - for i in range(layers): - if i in [layers//3, 2*layers//3]: - C_curr *= 2 - reduction = True - else: - reduction = False - if reduction: - cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev, multiplier) - else: - cell = Cell(steps, multiplier, 
C_prev_prev, C_prev, C_curr, reduction, reduction_prev) - reduction_prev = reduction - cells.append( cell ) - C_prev_prev, C_prev = C_prev, multiplier*C_curr - self.cells = nn.ModuleList(cells) - - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - - # initialize architecture parameters - k = sum(1 for i in range(self._steps) for n in range(2+i)) - num_ops = len(PRIMITIVES) - - self.alphas_normal = Parameter(torch.Tensor(k, num_ops)) - #self.alphas_reduce = Parameter(torch.Tensor(k, num_ops)) - nn.init.normal_(self.alphas_normal, 0, 0.001) - #nn.init.normal_(self.alphas_reduce, 0, 0.001) - - def set_gumbel(self, use_gumbel): - self.use_gumbel = use_gumbel - - def set_tau(self, tau): - self.tau = tau - - def get_tau(self): - return self.tau - - def arch_parameters(self): - return [self.alphas_normal] - - def base_parameters(self): - lists = list(self.stem.parameters()) + list(self.cells.parameters()) - lists += list(self.global_pooling.parameters()) - lists += list(self.classifier.parameters()) - return lists - - def forward(self, inputs): - batch, C, H, W = inputs.size() - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: - s0, s1 = s1, cell(s0, s1) - else: - if self.use_gumbel : weights = F.gumbel_softmax(self.alphas_normal, self.tau, True) - else : weights = F.softmax(self.alphas_normal, dim=-1) - s0, s1 = s1, cell(s0, s1, weights) - #print('{:} : s0 : {:}, s1 : {:}'.format(i, s0.size(), s1.size())) - out = self.global_pooling(s1) - out = out.view(batch, -1) - logits = self.classifier(out) - return logits - - def genotype(self): - - def _parse(weights): - gene, n, start = [], 2, 0 - for i in range(self._steps): - end = start + n - W = weights[start:end].copy() - edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2] - for j in edges: - k_best = None - for k in range(len(W[j])): - if k != PRIMITIVES.index('none'): - if k_best is None or W[j][k] > W[j][k_best]: - k_best = k - gene.append((PRIMITIVES[k_best], j, float(W[j][k_best]))) - start = end - n += 1 - return gene - - with torch.no_grad(): - gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - #gene_reduce = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - - concat = range(2+self._steps-self._multiplier, self._steps+2) - genotype = Genotype( - normal=gene_normal, normal_concat=concat, - reduce=None , reduce_concat=concat - ) - return genotype diff --git a/lib/nas/model_search_v1.py b/lib/nas/model_search_v1.py deleted file mode 100644 index 18f2509..0000000 --- a/lib/nas/model_search_v1.py +++ /dev/null @@ -1,161 +0,0 @@ -# share parameters -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.parameter import Parameter -from .operations import OPS, FactorizedReduce, ReLUConvBN -from .genotypes import PRIMITIVES, Genotype - - -class MixedOp(nn.Module): - - def __init__(self, C, stride): - super(MixedOp, self).__init__() - self._ops = nn.ModuleList() - for primitive in PRIMITIVES: - op = OPS[primitive](C, stride, False) - self._ops.append(op) - - def forward(self, x, weights): - return sum(w * op(x) for w, op in zip(weights, self._ops)) - - -class Cell(nn.Module): - - def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): - super(Cell, self).__init__() - self.reduction = reduction - - if reduction_prev: - self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False) - else: - 
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) - self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) - self._steps = steps - self._multiplier = multiplier - - self._ops = nn.ModuleList() - for i in range(self._steps): - for j in range(2+i): - stride = 2 if reduction and j < 2 else 1 - op = MixedOp(C, stride) - self._ops.append(op) - - def forward(self, s0, s1, weights): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) - - states = [s0, s1] - offset = 0 - for i in range(self._steps): - clist = [] - for j, h in enumerate(states): - x = self._ops[offset+j](h, weights[offset+j]) - clist.append( x ) - s = sum(clist) - offset += len(states) - states.append(s) - - return torch.cat(states[-self._multiplier:], dim=1) - - -class NetworkV1(nn.Module): - - def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3): - super(NetworkV1, self).__init__() - self._C = C - self._num_classes = num_classes - self._layers = layers - self._steps = steps - self._multiplier = multiplier - - C_curr = stem_multiplier*C - self.stem = nn.Sequential( - nn.Conv2d(3, C_curr, 3, padding=1, bias=False), - nn.BatchNorm2d(C_curr) - ) - - C_prev_prev, C_prev, C_curr = C_curr, C_curr, C - reduction_prev, cells = False, [] - for i in range(layers): - if i in [layers//3, 2*layers//3]: - C_curr *= 2 - reduction = True - else: - reduction = False - cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) - reduction_prev = reduction - cells.append( cell ) - C_prev_prev, C_prev = C_prev, multiplier*C_curr - self.cells = nn.ModuleList(cells) - - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - - # initialize architecture parameters - k = sum(1 for i in range(self._steps) for n in range(2+i)) - num_ops = len(PRIMITIVES) - - self.alphas_normal = Parameter(torch.Tensor(k, num_ops)) - #self.alphas_reduce = Parameter(torch.Tensor(k, num_ops)) - nn.init.normal_(self.alphas_normal, 0, 0.001) - #nn.init.normal_(self.alphas_reduce, 0, 0.001) - - def set_tau(self, tau): - return -1 - - def get_tau(self): - return -1 - - def arch_parameters(self): - return [self.alphas_normal] - - def base_parameters(self): - lists = list(self.stem.parameters()) + list(self.cells.parameters()) - lists += list(self.global_pooling.parameters()) - lists += list(self.classifier.parameters()) - return lists - - def forward(self, inputs): - batch, C, H, W = inputs.size() - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: - weights = F.softmax(self.alphas_normal, dim=-1) - else: - weights = F.softmax(self.alphas_normal, dim=-1) - s0, s1 = s1, cell(s0, s1, weights) - out = self.global_pooling(s1) - out = out.view(batch, -1) - logits = self.classifier(out) - return logits - - def genotype(self): - - def _parse(weights): - gene, n, start = [], 2, 0 - for i in range(self._steps): - end = start + n - W = weights[start:end].copy() - edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2] - for j in edges: - k_best = None - for k in range(len(W[j])): - if k != PRIMITIVES.index('none'): - if k_best is None or W[j][k] > W[j][k_best]: - k_best = k - gene.append((PRIMITIVES[k_best], j, float(W[j][k_best]))) - start = end - n += 1 - return gene - - with torch.no_grad(): - gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - gene_reduce = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - - concat = 
range(2+self._steps-self._multiplier, self._steps+2) - genotype = Genotype( - normal=gene_normal, normal_concat=concat, - reduce=gene_reduce, reduce_concat=concat - ) - return genotype diff --git a/lib/nas/model_search_v3.py b/lib/nas/model_search_v3.py deleted file mode 100644 index 903fa09..0000000 --- a/lib/nas/model_search_v3.py +++ /dev/null @@ -1,171 +0,0 @@ -# random selection -import torch -import random -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.parameter import Parameter -from .operations import OPS, FactorizedReduce, ReLUConvBN -from .genotypes import PRIMITIVES, Genotype -from .construct_utils import random_select, all_select - - - -class MixedOp(nn.Module): - - def __init__(self, C, stride): - super(MixedOp, self).__init__() - self._ops = nn.ModuleList() - for primitive in PRIMITIVES: - op = OPS[primitive](C, stride, False) - self._ops.append(op) - - def forward(self, x, weights, cpu_weights): - return sum(w * op(x) for w, op in zip(weights, self._ops)) - - -class Cell(nn.Module): - - def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): - super(Cell, self).__init__() - self.reduction = reduction - - if reduction_prev: - self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False) - else: - self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) - self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) - self._steps = steps - self._multiplier = multiplier - - self._ops = nn.ModuleList() - for i in range(self._steps): - for j in range(2+i): - stride = 2 if reduction and j < 2 else 1 - op = MixedOp(C, stride) - self._ops.append(op) - - def forward(self, s0, s1, weights): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) - - cpu_weights = weights.tolist() - states = [s0, s1] - offset = 0 - for i in range(self._steps): - clist = [] - if i == 0: - indicator = all_select( len(states) ) - else: - indicator = random_select( len(states), 0.5 ) - for j, h in enumerate(states): - if indicator[j] == 0: continue - x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j]) - clist.append( x ) - s = sum(clist) / sum(indicator) - offset += len(states) - states.append(s) - - return torch.cat(states[-self._multiplier:], dim=1) - - -class NetworkV3(nn.Module): - - def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3): - super(NetworkV3, self).__init__() - self._C = C - self._num_classes = num_classes - self._layers = layers - self._steps = steps - self._multiplier = multiplier - - C_curr = stem_multiplier*C - self.stem = nn.Sequential( - nn.Conv2d(3, C_curr, 3, padding=1, bias=False), - nn.BatchNorm2d(C_curr) - ) - - C_prev_prev, C_prev, C_curr = C_curr, C_curr, C - reduction_prev, cells = False, [] - for i in range(layers): - if i in [layers//3, 2*layers//3]: - C_curr *= 2 - reduction = True - else: - reduction = False - cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) - reduction_prev = reduction - cells.append( cell ) - C_prev_prev, C_prev = C_prev, multiplier*C_curr - self.cells = nn.ModuleList(cells) - - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.tau = 5 - - # initialize architecture parameters - k = sum(1 for i in range(self._steps) for n in range(2+i)) - num_ops = len(PRIMITIVES) - - self.alphas_normal = Parameter(torch.Tensor(k, num_ops)) - self.alphas_reduce = Parameter(torch.Tensor(k, num_ops)) - nn.init.normal_(self.alphas_normal, 0, 0.001) - 
nn.init.normal_(self.alphas_reduce, 0, 0.001) - - def set_tau(self, tau): - self.tau = tau - - def get_tau(self): - return self.tau - - def arch_parameters(self): - return [self.alphas_normal, self.alphas_reduce] - - def base_parameters(self): - lists = list(self.stem.parameters()) + list(self.cells.parameters()) - lists += list(self.global_pooling.parameters()) - lists += list(self.classifier.parameters()) - return lists - - def forward(self, inputs): - batch, C, H, W = inputs.size() - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: - weights = F.softmax(self.alphas_reduce, dim=-1) - else: - weights = F.softmax(self.alphas_reduce, dim=-1) - s0, s1 = s1, cell(s0, s1, weights) - out = self.global_pooling(s1) - out = out.view(batch, -1) - logits = self.classifier(out) - return logits - - def genotype(self): - - def _parse(weights): - gene, n, start = [], 2, 0 - for i in range(self._steps): - end = start + n - W = weights[start:end].copy() - edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2] - for j in edges: - k_best = None - for k in range(len(W[j])): - if k != PRIMITIVES.index('none'): - if k_best is None or W[j][k] > W[j][k_best]: - k_best = k - gene.append((PRIMITIVES[k_best], j, float(W[j][k_best]))) - start = end - n += 1 - return gene - - with torch.no_grad(): - gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy()) - - concat = range(2+self._steps-self._multiplier, self._steps+2) - genotype = Genotype( - normal=gene_normal, normal_concat=concat, - reduce=gene_reduce, reduce_concat=concat - ) - return genotype diff --git a/lib/nas/model_search_v4.py b/lib/nas/model_search_v4.py deleted file mode 100644 index 760b9df..0000000 --- a/lib/nas/model_search_v4.py +++ /dev/null @@ -1,176 +0,0 @@ -# random selection -import torch -import random -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.parameter import Parameter -from .operations import OPS, FactorizedReduce, ReLUConvBN -from .genotypes import PRIMITIVES, Genotype -from .construct_utils import random_select, all_select - - -class MixedOp(nn.Module): - - def __init__(self, C, stride): - super(MixedOp, self).__init__() - self._ops = nn.ModuleList() - for primitive in PRIMITIVES: - op = OPS[primitive](C, stride, False) - self._ops.append(op) - - def forward(self, x, weights, cpu_weights): - indicators = random_select( len(cpu_weights), 0.5 ) - clist, ws = [], [] - for w, indicator, op in zip(weights, indicators, self._ops): - if indicator: - clist.append( w * op(x) ) - ws.append( w ) - return sum(clist) / sum(ws) - - -class Cell(nn.Module): - - def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): - super(Cell, self).__init__() - self.reduction = reduction - - if reduction_prev: - self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False) - else: - self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) - self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) - self._steps = steps - self._multiplier = multiplier - - self._ops = nn.ModuleList() - for i in range(self._steps): - for j in range(2+i): - stride = 2 if reduction and j < 2 else 1 - op = MixedOp(C, stride) - self._ops.append(op) - - def forward(self, s0, s1, weights): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) - - cpu_weights = weights.tolist() - states = [s0, s1] - offset = 0 - 
for i in range(self._steps): - clist = [] - if i == 0: - indicator = all_select( len(states) ) - else: - indicator = random_select( len(states), 0.5 ) - for j, h in enumerate(states): - if indicator[j] == 0: continue - x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j]) - clist.append( x ) - s = sum(clist) / sum(indicator) - offset += len(states) - states.append(s) - - return torch.cat(states[-self._multiplier:], dim=1) - - -class NetworkV4(nn.Module): - - def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3): - super(NetworkV4, self).__init__() - self._C = C - self._num_classes = num_classes - self._layers = layers - self._steps = steps - self._multiplier = multiplier - - C_curr = stem_multiplier*C - self.stem = nn.Sequential( - nn.Conv2d(3, C_curr, 3, padding=1, bias=False), - nn.BatchNorm2d(C_curr) - ) - - C_prev_prev, C_prev, C_curr = C_curr, C_curr, C - reduction_prev, cells = False, [] - for i in range(layers): - if i in [layers//3, 2*layers//3]: - C_curr *= 2 - reduction = True - else: - reduction = False - cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) - reduction_prev = reduction - cells.append( cell ) - C_prev_prev, C_prev = C_prev, multiplier*C_curr - self.cells = nn.ModuleList(cells) - - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.tau = 5 - - # initialize architecture parameters - k = sum(1 for i in range(self._steps) for n in range(2+i)) - num_ops = len(PRIMITIVES) - - self.alphas_normal = Parameter(torch.Tensor(k, num_ops)) - self.alphas_reduce = Parameter(torch.Tensor(k, num_ops)) - nn.init.normal_(self.alphas_normal, 0, 0.001) - nn.init.normal_(self.alphas_reduce, 0, 0.001) - - def set_tau(self, tau): - self.tau = tau - - def get_tau(self): - return self.tau - - def arch_parameters(self): - return [self.alphas_normal, self.alphas_reduce] - - def base_parameters(self): - lists = list(self.stem.parameters()) + list(self.cells.parameters()) - lists += list(self.global_pooling.parameters()) - lists += list(self.classifier.parameters()) - return lists - - def forward(self, inputs): - batch, C, H, W = inputs.size() - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: - weights = F.softmax(self.alphas_reduce, dim=-1) - else: - weights = F.softmax(self.alphas_reduce, dim=-1) - s0, s1 = s1, cell(s0, s1, weights) - out = self.global_pooling(s1) - out = out.view(batch, -1) - logits = self.classifier(out) - return logits - - def genotype(self): - - def _parse(weights): - gene, n, start = [], 2, 0 - for i in range(self._steps): - end = start + n - W = weights[start:end].copy() - edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2] - for j in edges: - k_best = None - for k in range(len(W[j])): - if k != PRIMITIVES.index('none'): - if k_best is None or W[j][k] > W[j][k_best]: - k_best = k - gene.append((PRIMITIVES[k_best], j, float(W[j][k_best]))) - start = end - n += 1 - return gene - - with torch.no_grad(): - gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy()) - - concat = range(2+self._steps-self._multiplier, self._steps+2) - genotype = Genotype( - normal=gene_normal, normal_concat=concat, - reduce=gene_reduce, reduce_concat=concat - ) - return genotype diff --git a/lib/nas/model_search_v5.py b/lib/nas/model_search_v5.py deleted file mode 
100644 index 5903b21..0000000 --- a/lib/nas/model_search_v5.py +++ /dev/null @@ -1,174 +0,0 @@ -# gumbel softmax -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.parameter import Parameter -from .operations import OPS, FactorizedReduce, ReLUConvBN -from .genotypes import PRIMITIVES, Genotype -from .construct_utils import random_select, all_select - - -class MixedOp(nn.Module): - - def __init__(self, C, stride): - super(MixedOp, self).__init__() - self._ops = nn.ModuleList() - for primitive in PRIMITIVES: - op = OPS[primitive](C, stride, False) - self._ops.append(op) - - def forward(self, x, weights, cpu_weights): - clist = [] - for j, cpu_weight in enumerate(cpu_weights): - if abs(cpu_weight) > 1e-10: - clist.append( weights[j] * self._ops[j](x) ) - assert len(clist) > 0, 'invalid length : {:}'.format(cpu_weights) - if len(clist) == 1: return clist[0] - else : return sum(clist) - - -class Cell(nn.Module): - - def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): - super(Cell, self).__init__() - self.reduction = reduction - - if reduction_prev: - self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False) - else: - self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) - self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) - self._steps = steps - self._multiplier = multiplier - - self._ops = nn.ModuleList() - for i in range(self._steps): - for j in range(2+i): - stride = 2 if reduction and j < 2 else 1 - op = MixedOp(C, stride) - self._ops.append(op) - - def forward(self, s0, s1, weights): - s0 = self.preprocess0(s0) - s1 = self.preprocess1(s1) - - cpu_weights = weights.tolist() - states = [s0, s1] - offset = 0 - for i in range(self._steps): - clist = [] - if i == 0: indicator = all_select( len(states) ) - else : indicator = random_select( len(states), 0.6 ) - - for j, h in enumerate(states): - if indicator[j] == 0: continue - x = self._ops[offset+j](h, weights[offset+j], cpu_weights[offset+j]) - clist.append( x ) - s = sum(clist) - offset += len(states) - states.append(s) - - return torch.cat(states[-self._multiplier:], dim=1) - - -class NetworkV5(nn.Module): - - def __init__(self, C, num_classes, layers, steps=4, multiplier=4, stem_multiplier=3): - super(NetworkV5, self).__init__() - self._C = C - self._num_classes = num_classes - self._layers = layers - self._steps = steps - self._multiplier = multiplier - - C_curr = stem_multiplier*C - self.stem = nn.Sequential( - nn.Conv2d(3, C_curr, 3, padding=1, bias=False), - nn.BatchNorm2d(C_curr) - ) - - C_prev_prev, C_prev, C_curr = C_curr, C_curr, C - reduction_prev, cells = False, [] - for i in range(layers): - if i in [layers//3, 2*layers//3]: - C_curr *= 2 - reduction = True - else: - reduction = False - cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) - reduction_prev = reduction - cells.append( cell ) - C_prev_prev, C_prev = C_prev, multiplier*C_curr - self.cells = nn.ModuleList(cells) - - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.tau = 5 - - # initialize architecture parameters - k = sum(1 for i in range(self._steps) for n in range(2+i)) - num_ops = len(PRIMITIVES) - - self.alphas_normal = Parameter(torch.Tensor(k, num_ops)) - self.alphas_reduce = Parameter(torch.Tensor(k, num_ops)) - nn.init.normal_(self.alphas_normal, 0, 0.001) - nn.init.normal_(self.alphas_reduce, 0, 0.001) - - def set_tau(self, tau): - self.tau = tau - - def 
get_tau(self): - return self.tau - - def arch_parameters(self): - return [self.alphas_normal, self.alphas_reduce] - - def base_parameters(self): - lists = list(self.stem.parameters()) + list(self.cells.parameters()) - lists += list(self.global_pooling.parameters()) - lists += list(self.classifier.parameters()) - return lists - - def forward(self, inputs): - batch, C, H, W = inputs.size() - s0 = s1 = self.stem(inputs) - for i, cell in enumerate(self.cells): - if cell.reduction: - weights = F.gumbel_softmax(self.alphas_reduce, self.tau, True) - else: - weights = F.gumbel_softmax(self.alphas_normal, self.tau, True) - s0, s1 = s1, cell(s0, s1, weights) - out = self.global_pooling(s1) - out = out.view(batch, -1) - logits = self.classifier(out) - return logits - - def genotype(self): - - def _parse(weights): - gene, n, start = [], 2, 0 - for i in range(self._steps): - end = start + n - W = weights[start:end].copy() - edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2] - for j in edges: - k_best = None - for k in range(len(W[j])): - if k != PRIMITIVES.index('none'): - if k_best is None or W[j][k] > W[j][k_best]: - k_best = k - gene.append((PRIMITIVES[k_best], j, float(W[j][k_best]))) - start = end - n += 1 - return gene - - with torch.no_grad(): - gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).cpu().numpy()) - gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).cpu().numpy()) - - concat = range(2+self._steps-self._multiplier, self._steps+2) - genotype = Genotype( - normal=gene_normal, normal_concat=concat, - reduce=gene_reduce, reduce_concat=concat - ) - return genotype diff --git a/lib/utils/model_utils.py b/lib/utils/model_utils.py index 2c5372a..5e97bc9 100644 --- a/lib/utils/model_utils.py +++ b/lib/utils/model_utils.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn import numpy as np + def count_parameters_in_MB(model): if isinstance(model, nn.Module): return np.sum(np.prod(v.size()) for v in model.parameters())/1e6 diff --git a/scripts-cluster/README.md b/scripts-cluster/README.md index 4db4a7b..996c962 100644 --- a/scripts-cluster/README.md +++ b/scripts-cluster/README.md @@ -9,4 +9,5 @@ bash scripts-cluster/submit.sh yq01-v100-box-idl-2-8 PTB-GDAS 1 "bash ./scripts- ## CNN ``` bash scripts-cluster/submit.sh yq01-v100-box-idl-2-8 CIFAR10-CUT-GDAS-F1 1 "bash ./scripts-cnn/train-cifar.sh GDAS_F1 cifar10 cut" +bash scripts-cluster/submit.sh yq01-v100-box-idl-2-8 IMAGENET-GDAS-F1 1 "bash ./scripts-cnn/train-imagenet.sh GDAS_F1 52 14" ``` diff --git a/scripts-cluster/job-script.sh b/scripts-cluster/job-script.sh index d770ed1..2480d0a 100644 --- a/scripts-cluster/job-script.sh +++ b/scripts-cluster/job-script.sh @@ -6,9 +6,11 @@ sh /home/HGCP_Program/software-install/afs_mount/bin/afs_mount.sh \ `pwd`/hadoop-data \ afs://xingtian.afs.baidu.com:9902/user/COMM_KM_Data/dongxuanyi/datasets -tar xvf ./hadoop-data/cifar.python.tar -C ./data/data/ +export TORCH_HOME="./data/data/" +tar xvf ./hadoop-data/cifar.python.tar -C ${TORCH_HOME} +#tar xvf ./hadoop-data/ILSVRC2012.tar -C ${TORCH_HOME} -cifar_dir="./data/data/cifar.python" +cifar_dir="${TORCH_HOME}/cifar.python" if [ -d ${cifar_dir} ]; then echo "Find cifar-dir: "${cifar_dir} else @@ -16,7 +18,6 @@ else exit 1 fi echo "CHECK-DATA-DIR DONE" -export TORCH_HOME="./data/data/" # config python diff --git a/scripts-cnn/train-imagenet.sh b/scripts-cnn/train-imagenet.sh index db1042a..a0152ed 100644 --- a/scripts-cnn/train-imagenet.sh +++ 
b/scripts-cnn/train-imagenet.sh @@ -24,6 +24,8 @@ if [ ! -f ${PY_C} ]; then PY_C="python" else echo "Cluster Run with Python: "${PY_C} + echo "Unzip ILSVRC2012" + tar xvf ./hadoop-data/ILSVRC2012.tar -C ${TORCH_HOME} fi ${PY_C} --version
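
The new `exps-cnn/evaluate.py` and the `pure_evaluate` argument threaded through `main_procedure()` / `main_procedure_imagenet()` boil down to: if a `state_dict` is supplied, skip training, load the weights, and run one no-grad pass over the test loader. Below is a minimal sketch of that path, not the repository's code; the tiny linear model and random `TensorDataset` are stand-ins so the snippet runs on its own.
```
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def pure_evaluation(model, state_dict, test_loader, criterion):
  # load the trained weights and switch to inference mode
  model.load_state_dict(state_dict)
  model.eval()
  correct, total, loss_sum = 0, 0, 0.0
  with torch.no_grad():
    for inputs, targets in test_loader:
      logits    = model(inputs)
      loss_sum += criterion(logits, targets).item() * targets.size(0)
      correct  += (logits.argmax(dim=1) == targets).sum().item()
      total    += targets.size(0)
  # return top-1 accuracy (%) and the average loss over the split
  return 100.0 * correct / total, loss_sum / total

if __name__ == '__main__':
  model   = nn.Linear(8, 3)
  dataset = TensorDataset(torch.randn(32, 8), torch.randint(0, 3, (32,)))
  loader  = DataLoader(dataset, batch_size=8)
  acc1, loss = pure_evaluation(model, model.state_dict(), loader, nn.CrossEntropyLoss())
  print('top-1 = {:.2f}% | loss = {:.4f}'.format(acc1, loss))
```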
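
Both training loops now copy the rolling `checkpoint-*-model.pth` to a separate `checkpoint-*-best.pth` whenever the current epoch's (top-1, top-5) pair equals the best recorded so far. The sketch below restates that pattern with an illustrative `accuracies` dict and file names; `obtain_best` here is a simplified version of the helper in `train_utils.py`.
```
import os
from shutil import copyfile

def obtain_best(accuracies):
  # accuracies maps epoch -> (top-1, top-5); an empty history yields (0, 0)
  if len(accuracies) == 0: return (0, 0)
  return sorted(accuracies.values())[-1]

accuracies = {0: (91.2, 99.5), 1: (92.7, 99.6), 2: (92.1, 99.6)}
epoch           = 1
checkpoint_path = 'checkpoint-cifar10-model.pth'
checkpoint_best = 'checkpoint-cifar10-best.pth'
# duplicate the checkpoint only when this epoch is the best one so far
if accuracies[epoch] == obtain_best(accuracies) and os.path.isfile(checkpoint_path):
  copyfile(checkpoint_path, checkpoint_best)
```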
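
The one-word fix in `lib/datasets/get_dataset_with_transform.py` matters because the evaluation split must be built with `train=False`; otherwise "test" accuracy is measured on the 50,000 training images instead of the 10,000 held-out ones. A minimal check, with an illustrative root path (`download=True` fetches CIFAR-10 if it is missing):
```
import torchvision.datasets as dset
import torchvision.transforms as transforms

root           = './data/cifar.python'
test_transform = transforms.Compose([transforms.ToTensor()])
train_data = dset.CIFAR10(root, train=True , transform=test_transform, download=True)
test_data  = dset.CIFAR10(root, train=False, transform=test_transform, download=True)
print(len(train_data), len(test_data))   # 50000 10000
```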
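
The removed `model_search_*` variants were experiments around the same core idea that the retained sampler in `lib/nas/model_search.py` (not shown in this diff) uses: sample one candidate operation per edge with Gumbel-Softmax instead of mixing all of them. The following is a minimal, self-contained sketch of that mechanism, assuming an illustrative candidate set rather than the real `PRIMITIVES`/`OPS` tables.
```
import torch
import torch.nn as nn
import torch.nn.functional as F

CANDIDATES = {
  'skip_connect': lambda C: nn.Sequential(),                             # identity
  'conv_3x3'    : lambda C: nn.Conv2d(C, C, 3, padding=1, bias=False),
  'avg_pool_3x3': lambda C: nn.AvgPool2d(3, stride=1, padding=1),
}

class SampledMixedOp(nn.Module):
  def __init__(self, C):
    super(SampledMixedOp, self).__init__()
    self._ops   = nn.ModuleList([fn(C) for fn in CANDIDATES.values()])
    # one row of architecture logits for this edge, one logit per candidate op
    self.alphas = nn.Parameter(1e-3 * torch.randn(1, len(CANDIDATES)))
    self.tau    = 5.0   # temperature; annealed towards a small value during search

  def forward(self, x):
    # hard=True gives a straight-through one-hot sample: only the sampled operation
    # contributes to the output, yet gradients still reach every logit in alphas
    weights = F.gumbel_softmax(self.alphas, tau=self.tau, hard=True)[0]
    return sum(w * op(x) for w, op in zip(weights, self._ops))

if __name__ == '__main__':
  edge = SampledMixedOp(C=16)
  out  = edge(torch.randn(2, 16, 8, 8))
  out.sum().backward()
  print(out.shape, edge.alphas.grad.shape)
```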