MeCo/nasbench201/cell_infers/tiny_network.py

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn

from ..cell_operations import ResNetBasicblock
from .cells import InferCell

# The macro structure for architectures in NAS-Bench-201: a convolutional
# stem, three stages of searched cells separated by fixed reduction blocks,
# and a global-pooling classifier head.
class TinyNetwork(nn.Module):

    def __init__(self, C, N, genotype, num_classes):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C))
        # Three stages of N searched cells each; between stages, a residual
        # reduction block doubles the channels and halves the resolution.
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
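        # Worked example (the standard NAS-Bench-201 setting C=16, N=5):
        # 5 cells at 16 channels | reduction to 32 | 5 cells | reduction to 64 |
        # 5 cells -- 17 cells in total, with a 32x32 input pooled down to 8x8.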
        C_prev = C
        self.cells = nn.ModuleList()
        for C_curr, reduction in zip(layer_channels, layer_reductions):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
            else:
                cell = InferCell(genotype, C_prev, C_curr, 1)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self._Layer = len(self.cells)
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        # When True, forward() returns the pooled feature alongside the logits.
        self.requires_feature = True

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
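
    # Example get_message() output (cell lines come from each cell's extra_repr):
    #   TinyNetwork(C=16, N=5, L=17)
    #    00/17 :: InferCell(...)
    #    01/17 :: InferCell(...)
    #    ...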

    def forward(self, inputs):
        feature = self.stem(inputs)
        for cell in self.cells:
            feature = cell(feature)
        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        if self.requires_feature:
            return logits, out
        else:
            return logits

    def _loss(self, input, target, return_logits=False):
        # Assumes requires_feature is True, so forward() returns (logits, feature);
        # self._criterion must be attached by the training code before calling this.
        logits, _ = self(input)
        loss = self._criterion(logits, target)
        return (loss, logits) if return_logits else loss

    def step(self, input, target, args, shared=None, return_grad=False):
        # One training step on (input, target). Gradients accumulate onto whatever
        # is already stored, so the caller is expected to zero them beforehand;
        # self.optimizer and get_weights() must likewise be set up externally.
        Lt, logit_t = self._loss(input, target, return_logits=True)
        Lt.backward()
        if args.grad_clip != 0:
            nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
        self.optimizer.step()
        if return_grad:
            grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
            return logit_t, Lt, grad
        else:
            return logit_t, Lt
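

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative; the import path and architecture
# string below are assumptions, not part of this module). NAS-Bench-201
# genotypes are plain strings parsed by the repo's Structure helper:
#
#     from nasbench201.genotypes import Structure  # hypothetical path
#     genotype = Structure.str2structure(
#         '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|'
#         '+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
#     net = TinyNetwork(C=16, N=5, genotype=genotype, num_classes=10)
#     logits, feature = net(torch.randn(2, 3, 32, 32))
# ----------------------------------------------------------------------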