NAS-sharing-parameters support 3 datasets / update ops / update pypi
This commit is contained in:
		| @@ -13,16 +13,22 @@ OPS = { | ||||
|   'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats), | ||||
|   'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), | ||||
|   'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats), | ||||
|   'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), | ||||
|   'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats), | ||||
|   'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats:     SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats), | ||||
|   'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats:     SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats), | ||||
|   'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), | ||||
| } | ||||
|  | ||||
| CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] | ||||
| NAS_BENCH_102         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||
| DARTS_SPACE           = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3'] | ||||
|  | ||||
| SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK, | ||||
|                     'aa-nas'       : NAS_BENCH_102, | ||||
|                     'nas-bench-102': NAS_BENCH_102, | ||||
|                     'full'         : sorted(list(OPS.keys()))} | ||||
|                     'darts'        : DARTS_SPACE} | ||||
|                     #'full'         : sorted(list(OPS.keys()))} | ||||
|  | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
| @@ -39,6 +45,34 @@ class ReLUConvBN(nn.Module): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class SepConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): | ||||
|     super(SepConv, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats), | ||||
|       ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class DualSepConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): | ||||
|     super(DualSepConv, self).__init__() | ||||
|     self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats) | ||||
|     self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.op_a(x) | ||||
|     x = self.op_b(x) | ||||
|     return x | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|  | ||||
|   def __init__(self, inplanes, planes, stride, affine=True): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user