#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS network (adopted from DARTS)."""
import torch
import torch.nn as nn
import pycls.core.logging as logging
from pycls.core.config import cfg
from pycls.models.common import Preprocess
from pycls.models.common import Classifier
from pycls.models.nas.genotypes import GENOTYPES
from pycls.models.nas.genotypes import Genotype
from pycls.models.nas.operations import FactorizedReduce
from pycls.models.nas.operations import OPS
from pycls.models.nas.operations import ReLUConvBN
from pycls.models.nas.operations import Identity

logger = logging.get_logger(__name__)


def drop_path(x, drop_prob):
    """Drop path (ported from DARTS)."""
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Sample the per-example mask on the same device/dtype as x; the
        # original torch.cuda.FloatTensor construction assumed a CUDA tensor
        # and the deprecated Variable wrapper.
        mask = torch.empty(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        # Inverted scaling preserves the expected activation magnitude.
        x.div_(keep_prob)
        x.mul_(mask)
    return x
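
# Illustration (hypothetical values): drop_path zeroes an entire example's
# branch with probability drop_prob and rescales survivors by 1 / keep_prob,
# so the output matches the input in expectation:
#
#   h = drop_path(op(h), drop_prob=0.2)  # ~20% of examples skip this branch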


class Cell(nn.Module):
    """NAS cell (ported from DARTS)."""

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
        if reduction:
            op_names, indices = zip(*genotype.reduce)
            concat = genotype.reduce_concat
        else:
            op_names, indices = zip(*genotype.normal)
            concat = genotype.normal_concat
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        assert len(op_names) == len(indices)
        self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)
        self._ops = nn.ModuleList()
        for name, index in zip(op_names, indices):
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True)
            self._ops += [op]
        self._indices = indices

    def forward(self, s0, s1, drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = [s0, s1]
        for i in range(self._steps):
            h1 = states[self._indices[2 * i]]
            h2 = states[self._indices[2 * i + 1]]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            s = h1 + h2
            states += [s]
        return torch.cat([states[i] for i in self._concat], dim=1)
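
# For reference, a genotype pairs each op name with the index of the state it
# consumes, two (op, index) pairs per step. An abbreviated DARTS-style sketch
# (not an exact entry from GENOTYPES):
#
#   Genotype(
#       normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1),   # step 0 reads states 0 and 1
#               ('skip_connect', 0), ('sep_conv_3x3', 1),   # step 1
#               ...],
#       normal_concat=[2, 3, 4, 5],  # concatenate the four step outputs
#       reduce=[...],
#       reduce_concat=[2, 3, 4, 5],
#   )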


class AuxiliaryHeadCIFAR(nn.Module):

    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x
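
# Shape check for the CIFAR head: an 8x8 input pools to 2x2, since
# floor((8 - 5) / 3) + 1 = 2, and the trailing 2x2 convolution reduces that to
# 1x1, so the flattened feature fed to the linear layer is exactly 768-d.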


class AuxiliaryHeadImageNet(nn.Module):

    def __init__(self, C, num_classes):
        """assuming input size 14x14"""
        super(AuxiliaryHeadImageNet, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
            # Commenting it out for consistency with the experiments in the paper.
            # nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x
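
# Note: with a 224x224 classification input and both reduction cells enabled,
# the feature map reaching this head is 7x7 (the stems and two reductions
# divide the resolution by 32), which pools to 2x2 and convolves to 1x1,
# matching nn.Linear(768, num_classes); the "14x14" docstring is inherited
# from DARTS.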


class NetworkCIFAR(nn.Module):
    """CIFAR network (ported from DARTS)."""

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkCIFAR, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary
        stem_multiplier = 3
        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                C_to_auxiliary = C_prev
        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        input = Preprocess(input)
        logits_aux = None
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2 * self._layers // 3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits

    def _loss(self, input, target, return_logits=False):
        # NB: assumes forward() returns bare logits, i.e. the auxiliary head is
        # disabled or the model is in eval mode (otherwise forward() returns a
        # (logits, logits_aux) tuple).
        logits = self(input)
        loss = self._criterion(logits, target)
        return (loss, logits) if return_logits else loss

    def step(self, input, target, args, shared=None, return_grad=False):
        # Note: _criterion, optimizer, and get_weights() are not defined in
        # this class; they are expected to be attached to the instance by the
        # training loop before step() is called.
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()
        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt
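
# A minimal construction sketch (hypothetical hyperparameters; assumes a
# predefined genotype key such as 'darts_v2' exists in GENOTYPES):
#
#   net = NetworkCIFAR(C=36, num_classes=10, layers=20, auxiliary=True,
#                      genotype=GENOTYPES['darts_v2'])
#   net.drop_path_prob = 0.2
#   logits, logits_aux = net(images)  # tuple only in training mode with aux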


class NetworkImageNet(nn.Module):
    """ImageNet network (ported from DARTS)."""

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(NetworkImageNet, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary
        self.stem0 = nn.Sequential(
            nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        self.stem1 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        C_prev_prev, C_prev, C_curr = C, C, C
        self.cells = nn.ModuleList()
        reduction_prev = True
        reduction_layers = [layers // 3] if cfg.TASK == 'seg' else [layers // 3, 2 * layers // 3]
        for i in range(layers):
            if i in reduction_layers:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                C_to_auxiliary = C_prev
        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
        self.classifier = Classifier(C_prev, num_classes)

    def forward(self, input):
        input = Preprocess(input)
        logits_aux = None
        s0 = self.stem0(input)
        s1 = self.stem1(s0)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2 * self._layers // 3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        logits = self.classifier(s1, input.shape[2:])
        if self._auxiliary and self.training:
            return logits, logits_aux
        return logits

    def _loss(self, input, target, return_logits=False):
        # NB: as in NetworkCIFAR, assumes forward() returns bare logits.
        logits = self(input)
        loss = self._criterion(logits, target)
        return (loss, logits) if return_logits else loss

    def step(self, input, target, args, shared=None, return_grad=False):
        # Note: _criterion, optimizer, and get_weights() are expected to be
        # attached to the instance by the training loop (see NetworkCIFAR.step).
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()
        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt


class NAS(nn.Module):
    """NAS net wrapper (delegates to nets from DARTS)."""

    def __init__(self):
        assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
            'Testing on {} is not supported'.format(cfg.TEST.DATASET)
        assert cfg.NAS.GENOTYPE in GENOTYPES, \
            'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
        super(NAS, self).__init__()
        logger.info('Constructing NAS: {}'.format(cfg.NAS))
        # Use a custom or predefined genotype
        if cfg.NAS.GENOTYPE == 'custom':
            genotype = Genotype(
                normal=cfg.NAS.CUSTOM_GENOTYPE[0],
                normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
                reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
                reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
            )
        else:
            genotype = GENOTYPES[cfg.NAS.GENOTYPE]
        # Determine the network constructor for the dataset
        if 'cifar' in cfg.TRAIN.DATASET:
            net_ctor = NetworkCIFAR
        else:
            net_ctor = NetworkImageNet
        # Construct the network
        self.net_ = net_ctor(
            C=cfg.NAS.WIDTH,
            num_classes=cfg.MODEL.NUM_CLASSES,
            layers=cfg.NAS.DEPTH,
            auxiliary=cfg.NAS.AUX,
            genotype=genotype
        )
        # Drop path probability (set / annealed based on epoch)
        self.net_.drop_path_prob = 0.0

    def set_drop_path_prob(self, drop_path_prob):
        self.net_.drop_path_prob = drop_path_prob

    def forward(self, x):
        return self.net_.forward(x)
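
# Minimal usage sketch for the wrapper (assumes cfg has been loaded from a
# pycls config; the values below are hypothetical):
#
#   model = NAS()
#   model.set_drop_path_prob(0.2 * epoch / max_epochs)  # anneal per epoch
#   logits = model(images)  # images: (N, cfg.MODEL.INPUT_CHANNELS, H, W)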