import torch
import torch.nn as nn

from .construct_utils import Cell, Transition


class AuxiliaryHeadCIFAR(nn.Module):
    """Auxiliary classification head applied to an intermediate feature map.

    The layer sizes are chosen for an 8x8 spatial input: the 5x5 average
    pool with stride 3 reduces it to 2x2, and the following 2x2 convolution
    reduces that to 1x1, so the flattened feature has exactly 768 values,
    matching the input width of the linear classifier.
    """

    def __init__(self, C, num_classes):
        """assuming input size 8x8

        Args:
            C: number of channels of the incoming feature map.
            num_classes: number of output logits.
        """
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            # image size = 2 x 2 (for an 8x8 input)
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 2x2 kernel on a 2x2 map -> 1x1 spatial output
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return class logits for the feature map ``x``."""
        x = self.features(x)
        # Flatten everything but the batch dimension before the linear layer.
        x = self.classifier(x.view(x.size(0), -1))
        return x


class NetworkCIFAR(nn.Module):
    """Cell-based network: a conv stem, a stack of cells, and a classifier.

    Cells at indices ``layers // 3`` and ``2 * layers // 3`` are reduction
    cells, and the channel count ``C_curr`` is doubled at each of them. An
    optional auxiliary head is attached to the output of the cell at index
    ``2 * layers // 3``.
    """

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        """Build the stem, the cell stack, and the classifier(s).

        Args:
            C: base channel count; the stem outputs ``3 * C`` channels.
            num_classes: number of output logits.
            layers: number of cells to stack.
            auxiliary: when true, build an ``AuxiliaryHeadCIFAR``.
            genotype: cell description forwarded to ``Cell``. Its ``reduce``
                attribute is read here: when it is ``None``, reduction
                positions use ``Transition`` instead of ``Cell``.
        """
        super(NetworkCIFAR, self).__init__()
        self._layers = layers

        stem_multiplier = 3
        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr),
        )

        # Both initial cell inputs come from the stem output.
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            reduction = i in (layers // 3, 2 * layers // 3)
            if reduction:
                C_curr *= 2
            if reduction and genotype.reduce is None:
                cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev)
            else:
                cell = Cell(
                    genotype, C_prev_prev, C_prev, C_curr,
                    reduction, reduction_prev,
                )
            reduction_prev = reduction
            self.cells.append(cell)
            # The cell reports its output width as multiplier * C_curr.
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if i == 2 * layers // 3:
                # Channel count seen by the auxiliary head.
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
        else:
            self.auxiliary_head = None
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        # Passed verbatim to every cell in forward(); presumably -1 means
        # "not configured yet" -- confirm against Cell's implementation.
        self.drop_path_prob = -1

    def update_drop_path(self, drop_path_prob):
        """Set the drop-path value handed to each cell in ``forward``."""
        self.drop_path_prob = drop_path_prob

    def auxiliary_param(self):
        """Return the auxiliary head's parameters as a list (empty if none)."""
        if self.auxiliary_head is None:
            return []
        return list(self.auxiliary_head.parameters())

    def forward(self, inputs):
        """Run the network on ``inputs``.

        Returns:
            ``(logits, logits_aux)`` when an auxiliary head exists and the
            module is in training mode; otherwise just ``logits``.
        """
        # Explicit None check (consistent with auxiliary_param) rather than
        # relying on the truthiness of an nn.Module.
        use_auxiliary = self.auxiliary_head is not None and self.training
        auxiliary_index = 2 * self._layers // 3
        logits_aux = None

        s0 = s1 = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if use_auxiliary and i == auxiliary_index:
                logits_aux = self.auxiliary_head(s1)

        out = self.global_pooling(s1)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        if use_auxiliary:
            return logits, logits_aux
        return logits