update NAS-Bench-102 baselines / support track_running_stats
@@ -19,9 +19,9 @@ class InferCell(nn.Module):
       cur_innod = []
       for (op_name, op_in) in node_info:
         if op_in == 0:
-          layer = OPS[op_name](C_in , C_out, stride, True)
+          layer = OPS[op_name](C_in , C_out, stride, True, True)
         else:
-          layer = OPS[op_name](C_out, C_out,      1, True)
+          layer = OPS[op_name](C_out, C_out,      1, True, True)
         cur_index.append( len(self.layers) )
         cur_innod.append( op_in )
         self.layers.append( layer )

@@ -7,13 +7,13 @@ import torch.nn as nn
 __all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
 
 OPS = {
-  'none'         : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
-  'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
-  'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
-  'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine),
-  'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine),
-  'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine),
-  'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
+  'none'         : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
+  'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
+  'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
+  'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
+  'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
+  'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
+  'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
 }
 
 CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
@@ -27,12 +27,12 @@ SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK,
 
 class ReLUConvBN(nn.Module):
 
-  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine):
+  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
     super(ReLUConvBN, self).__init__()
     self.op = nn.Sequential(
       nn.ReLU(inplace=False),
       nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
-      nn.BatchNorm2d(C_out, affine=affine)
+      nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
     )
 
   def forward(self, x):
@@ -77,12 +77,12 @@ class ResNetBasicblock(nn.Module):
 
 class POOLING(nn.Module):
 
-  def __init__(self, C_in, C_out, stride, mode, affine=True):
+  def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
     super(POOLING, self).__init__()
     if C_in == C_out:
       self.preprocess = None
     else:
-      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine)
+      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine, track_running_stats)
     if mode == 'avg'  : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
     elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
     else              : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
@@ -127,7 +127,7 @@ class Zero(nn.Module):
 
 class FactorizedReduce(nn.Module):
 
-  def __init__(self, C_in, C_out, stride, affine):
+  def __init__(self, C_in, C_out, stride, affine, track_running_stats):
     super(FactorizedReduce, self).__init__()
     self.stride = stride
     self.C_in   = C_in
@@ -142,7 +142,7 @@ class FactorizedReduce(nn.Module):
       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
     else:
       raise ValueError('Invalid stride : {:}'.format(stride))
-    self.bn = nn.BatchNorm2d(C_out, affine=affine)
+    self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
 
   def forward(self, x):
     x = self.relu(x)

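For orientation, a minimal sketch of the extended OPS interface above (illustrative only, not part of the commit; it assumes the OPS dictionary from this file is importable, and the channel sizes and input shape are made up):

    # Illustrative sketch: the two trailing booleans are `affine` and
    # `track_running_stats`, forwarded to nn.BatchNorm2d.
    import torch

    C_in, C_out, stride = 16, 16, 1

    # Previous behaviour: BatchNorm keeps running mean/var buffers.
    op_tracked = OPS['nor_conv_3x3'](C_in, C_out, stride, True, True)

    # New option: no running statistics are registered, so BatchNorm
    # normalises with per-batch statistics in both train and eval mode.
    op_batchwise = OPS['nor_conv_3x3'](C_in, C_out, stride, True, False)

    x = torch.randn(2, C_in, 32, 32)
    assert op_tracked(x).shape == op_batchwise(x).shape == (2, C_out, 32, 32)
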
@@ -11,7 +11,7 @@ from ..cell_operations import OPS
 
 class SearchCell(nn.Module):
 
-  def __init__(self, C_in, C_out, stride, max_nodes, op_names):
+  def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
     super(SearchCell, self).__init__()
 
     self.op_names  = deepcopy(op_names)
@@ -23,9 +23,9 @@ class SearchCell(nn.Module):
       for j in range(i):
         node_str = '{:}<-{:}'.format(i, j)
         if j == 0:
-          xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names]
+          xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
         else:
-          xlists = [OPS[op_name](C_in , C_out,      1, False) for op_name in op_names]
+          xlists = [OPS[op_name](C_in , C_out,      1, affine, track_running_stats) for op_name in op_names]
         self.edges[ node_str ] = nn.ModuleList( xlists )
     self.edge_keys  = sorted(list(self.edges.keys()))
     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}

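A similarly hedged usage sketch of the extended SearchCell constructor (names taken from the diff; channel counts, stride, and max_nodes are illustrative). The new keywords default to the previous hard-coded behaviour, i.e. affine=False with running statistics tracked, so existing call sites keep working unchanged:

    # Illustrative sketch, not part of the commit.
    op_names = ['none', 'skip_connect', 'nor_conv_3x3']  # e.g. CONNECT_NAS_BENCHMARK

    # Defaults reproduce the old behaviour: affine=False, running stats tracked.
    cell_default  = SearchCell(16, 16, 1, 4, op_names)

    # Variant whose BatchNorm layers use batch statistics only.
    cell_no_stats = SearchCell(16, 16, 1, 4, op_names,
                               affine=False, track_running_stats=False)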