#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""NAS network (adapted from DARTS)."""

import torch
import torch.nn as nn

import pycls.core.logging as logging
from pycls.core.config import cfg
from pycls.models.common import Classifier
from pycls.models.common import Preprocess
from pycls.models.nas.genotypes import GENOTYPES
from pycls.models.nas.genotypes import Genotype
from pycls.models.nas.operations import FactorizedReduce
from pycls.models.nas.operations import Identity
from pycls.models.nas.operations import OPS
from pycls.models.nas.operations import ReLUConvBN

logger = logging.get_logger(__name__)


def drop_path(x, drop_prob):
    """Drop path in place (ported from DARTS).

    Zeroes each example with probability drop_prob and rescales the survivors
    by 1 / keep_prob, so the expected activation is unchanged.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Per-example binary mask on the same device/dtype as x. The original
        # used the deprecated Variable(torch.cuda.FloatTensor(...)) pattern,
        # which breaks on CPU-only runs.
        mask = torch.empty(
            x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device
        ).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


class Cell(nn.Module):
    """NAS cell (ported from DARTS)."""

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        logger.info('Cell channels (C_prev_prev, C_prev, C): {}, {}, {}'.format(
            C_prev_prev, C_prev, C
        ))

        # If the previous cell reduced, s0 has twice the spatial resolution of
        # s1 and must be downsampled to match.
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        if reduction:
            op_names, indices = zip(*genotype.reduce)
            concat = genotype.reduce_concat
        else:
            op_names, indices = zip(*genotype.normal)
            concat = genotype.normal_concat
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        assert len(op_names) == len(indices)
        self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)

        self._ops = nn.ModuleList()
        for name, index in zip(op_names, indices):
            # In a reduction cell, only ops reading the two cell inputs stride.
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True)
            self._ops += [op]
        self._indices = indices

    def forward(self, s0, s1, drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i in range(self._steps):
            h1 = states[self._indices[2 * i]]
            h2 = states[self._indices[2 * i + 1]]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            s = h1 + h2
            states += [s]
        return torch.cat([states[i] for i in self._concat], dim=1)
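
# For reference, a genotype describes one cell as (op_name, input_state) pairs,
# two per intermediate node, plus the state indices concatenated at the output.
# A minimal sketch in that format (a hypothetical toy example, not a searched
# architecture; op names must be keys of OPS):
#
#   toy_genotype = Genotype(
#       normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
#               ('skip_connect', 0), ('sep_conv_3x3', 2)],
#       normal_concat=[2, 3],
#       reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
#               ('skip_connect', 2), ('max_pool_3x3', 1)],
#       reduce_concat=[2, 3],
#   )
#
# Cell._compile then infers _steps = len(op_names) // 2 intermediate nodes, and
# the cell output width is multiplier * C = len(concat) * C channels.
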
class AuxiliaryHeadCIFAR(nn.Module):

    def __init__(self, C, num_classes):
        """Assuming input size 8x8."""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class AuxiliaryHeadImageNet(nn.Module):

    def __init__(self, C, num_classes):
        """Assuming input size 14x14."""
        super(AuxiliaryHeadImageNet, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            # NOTE: This batchnorm was omitted in my earlier implementation due
            # to a typo. Commenting it out for consistency with the experiments
            # in the paper.
            # nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class NetworkCIFAR(nn.Module):
    """CIFAR network (ported from DARTS)."""

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkCIFAR, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary

        stem_multiplier = 3
        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            # Reduction cells at 1/3 and 2/3 of the depth double the width.
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        input = Preprocess(input)
        logits_aux = None
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2 * self._layers // 3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits

    def _loss(self, input, target, return_logits=False):
        # NOTE: self._criterion is not defined in this class; it is expected
        # to be attached externally by the training harness.
        logits = self(input)
        loss = self._criterion(logits, target)
        return (loss, logits) if return_logits else loss

    def step(self, input, target, args, shared=None, return_grad=False):
        # NOTE: self.optimizer and self.get_weights() are likewise expected to
        # be attached externally.
        loss, logits = self._loss(input, target, return_logits=True)
        loss.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()
        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logits, loss, grad
        return logits, loss


class NetworkImageNet(nn.Module):
    """ImageNet network (ported from DARTS)."""

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkImageNet, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary

        self.stem0 = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        self.stem1 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        C_prev_prev, C_prev, C_curr = C, C, C

        self.cells = nn.ModuleList()
        reduction_prev = True
        # Segmentation keeps one reduction cell instead of two, preserving
        # more spatial resolution for dense prediction.
        reduction_layers = [layers // 3] if cfg.TASK == 'seg' else [layers // 3, 2 * layers // 3]
        for i in range(layers):
            if i in reduction_layers:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        input = Preprocess(input)
        logits_aux = None
        s0 = self.stem0(input)
        s1 = self.stem1(s0)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2 * self._layers // 3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits

    def _loss(self, input, target, return_logits=False):
        # See NetworkCIFAR._loss: self._criterion is attached externally.
        logits = self(input)
        loss = self._criterion(logits, target)
        return (loss, logits) if return_logits else loss

    def step(self, input, target, args, shared=None, return_grad=False):
        loss, logits = self._loss(input, target, return_logits=True)
        loss.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()
        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logits, loss, grad
        return logits, loss
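
# When the auxiliary head is enabled, forward() returns (logits, logits_aux)
# in training mode and the caller combines the two losses. A sketch of the
# usual recipe (the 0.4 auxiliary weight follows DARTS; `criterion`, `model`,
# `images`, and `target` are stand-ins living outside this module):
#
#   logits, logits_aux = model(images)
#   loss = criterion(logits, target) + 0.4 * criterion(logits_aux, target)
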
class NAS(nn.Module):
    """NAS net wrapper (delegates to nets from DARTS)."""

    def __init__(self):
        assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Testing on {} is not supported'.format(cfg.TEST.DATASET)
        assert cfg.NAS.GENOTYPE in GENOTYPES, \
            'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
        super(NAS, self).__init__()
        logger.info('Constructing NAS: {}'.format(cfg.NAS))
        # Use a custom or predefined genotype
        if cfg.NAS.GENOTYPE == 'custom':
            genotype = Genotype(
                normal=cfg.NAS.CUSTOM_GENOTYPE[0],
                normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
                reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
                reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
            )
        else:
            genotype = GENOTYPES[cfg.NAS.GENOTYPE]
        # Determine the network constructor for the dataset
        if 'cifar' in cfg.TRAIN.DATASET:
            net_ctor = NetworkCIFAR
        else:
            net_ctor = NetworkImageNet
        # Construct the network
        self.net_ = net_ctor(
            C=cfg.NAS.WIDTH,
            num_classes=cfg.MODEL.NUM_CLASSES,
            layers=cfg.NAS.DEPTH,
            auxiliary=cfg.NAS.AUX,
            genotype=genotype
        )
        # Drop path probability (set / annealed based on the epoch)
        self.net_.drop_path_prob = 0.0

    def set_drop_path_prob(self, drop_path_prob):
        self.net_.drop_path_prob = drop_path_prob

    def forward(self, x):
        return self.net_.forward(x)
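
# Minimal usage sketch (a hypothetical training-loop fragment; `max_drop_prob`,
# `num_epochs`, `loader`, and `train_epoch` are stand-ins, and the linear
# drop-path schedule follows DARTS):
#
#   model = NAS()
#   for epoch in range(num_epochs):
#       model.set_drop_path_prob(max_drop_prob * epoch / num_epochs)
#       train_epoch(model, loader)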