183 lines · 6.9 KiB · Python
import math, random, torch
|
|
import warnings
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from copy import deepcopy
|
|
import sys
|
|
sys.path.insert(0, '../')
|
|
from nasbench201.cell_operations import OPS
|
|
|
|
|
|
# This module is used for NAS-Bench-201; it represents a small search space whose cells are complete DAGs.
|
|
class NAS201SearchCell(nn.Module):
    """NAS-Bench-201 search cell: a complete DAG over ``max_nodes`` nodes.

    Each edge ``i<-j`` (j < i) holds one candidate module per name in
    ``op_names``.  The ``forward_*`` variants combine those candidates as
    required by the supported NAS algorithms:

    * ``forward`` / ``_forward``  -- DARTS-style weighted sum,
    * ``forward_gdas``            -- GDAS hard (one-hot) sampling,
    * ``forward_joint``           -- plain weighted sum over all ops,
    * ``forward_urs``             -- uniform random op per edge (SETN),
    * ``forward_select``          -- argmax op per edge,
    * ``forward_dynamic``         -- a fixed, explicitly given architecture.
    """

    def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
        """Instantiate every candidate operation on every DAG edge.

        Args:
            C_in: input channel count.
            C_out: output channel count.
            stride: stride used on edges leaving node 0; all other edges use 1.
            op_names: operation names; each must be a key of ``OPS``.
            max_nodes: number of nodes in the cell DAG.
            affine: forwarded to the ops' normalization layers.
            track_running_stats: forwarded to the ops' normalization layers.
        """
        super(NAS201SearchCell, self).__init__()

        self.op_names = deepcopy(op_names)
        self.edges = nn.ModuleDict()
        self.max_nodes = max_nodes
        self.in_dim = C_in
        self.out_dim = C_out
        for i in range(1, max_nodes):
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                if j == 0:
                    # Only edges leaving the cell input may change resolution.
                    xlists = [OPS[op_name](C_in, C_out, stride, affine, track_running_stats) for op_name in op_names]
                else:
                    xlists = [OPS[op_name](C_in, C_out, 1, affine, track_running_stats) for op_name in op_names]
                self.edges[node_str] = nn.ModuleList(xlists)
        self.edge_keys = sorted(list(self.edges.keys()))
        self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
        self.num_edges = len(self.edges)

    def extra_repr(self):
        """Return a short topology summary used by ``repr(module)``."""
        string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
        return string

    def forward(self, inputs, weightss):
        """DARTS-style forward; ``weightss[edge_index]`` weights each op."""
        return self._forward(inputs, weightss)

    def _forward(self, inputs, weightss):
        # FIX: the original ran this whole pass inside
        # torch.autograd.set_detect_anomaly(True).  Anomaly detection is a
        # debugging aid that drastically slows every forward/backward and
        # toggles global autograd state, so it has been removed.
        nodes = [inputs]
        for i in range(1, self.max_nodes):
            inter_nodes = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                weights = weightss[self.edge2index[node_str]]
                # When a weight is exactly zero the op is called with
                # block_input=True -- presumably a project-specific hook in the
                # nasbench201 OPS modules to short-circuit blocked inputs;
                # TODO confirm against nasbench201.cell_operations.
                inter_nodes.append(sum(layer(nodes[j], block_input=True) * w if w == 0 else layer(nodes[j]) * w
                                       for layer, w in zip(self.edges[node_str], weights)))
            nodes.append(sum(inter_nodes))
        return nodes[-1]

    # GDAS
    def forward_gdas(self, inputs, hardwts, index):
        """GDAS forward: evaluate only the sampled (argmax) op per edge.

        Non-selected ops contribute just their straight-through weight, so
        gradients still reach the architecture parameters.
        """
        nodes = [inputs]
        for i in range(1, self.max_nodes):
            inter_nodes = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                weights = hardwts[self.edge2index[node_str]]
                argmaxs = index[self.edge2index[node_str]].item()
                weigsum = sum(weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie]
                              for _ie, edge in enumerate(self.edges[node_str]))
                inter_nodes.append(weigsum)
            nodes.append(sum(inter_nodes))
        return nodes[-1]

    # joint
    def forward_joint(self, inputs, weightss):
        """Weighted-sum forward that always evaluates every candidate op."""
        nodes = [inputs]
        for i in range(1, self.max_nodes):
            inter_nodes = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                weights = weightss[self.edge2index[node_str]]
                aggregation = sum(layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights))
                inter_nodes.append(aggregation)
            nodes.append(sum(inter_nodes))
        return nodes[-1]

    # uniform random sampling per iteration, SETN
    def forward_urs(self, inputs):
        """Sample one op uniformly at random per edge (SETN's URS).

        Re-samples a node's incoming edges until at least one selected op is
        not the zero op, so no node output collapses to all-zero.
        """
        nodes = [inputs]
        for i in range(1, self.max_nodes):
            while True:  # avoid selecting the zero op on every incoming edge
                sops, has_non_zero = [], False
                for j in range(i):
                    node_str = '{:}<-{:}'.format(i, j)
                    candidates = self.edges[node_str]
                    select_op = random.choice(candidates)
                    sops.append(select_op)
                    if not hasattr(select_op, 'is_zero') or select_op.is_zero is False:
                        has_non_zero = True
                if has_non_zero:
                    break
            inter_nodes = []
            for j, select_op in enumerate(sops):
                inter_nodes.append(select_op(nodes[j]))
            nodes.append(sum(inter_nodes))
        return nodes[-1]

    # select the argmax
    def forward_select(self, inputs, weightss):
        """Evaluate only the argmax-weighted op on every edge."""
        nodes = [inputs]
        for i in range(1, self.max_nodes):
            inter_nodes = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                weights = weightss[self.edge2index[node_str]]
                inter_nodes.append(self.edges[node_str][weights.argmax().item()](nodes[j]))
            nodes.append(sum(inter_nodes))
        return nodes[-1]

    # forward with a specific structure
    def forward_dynamic(self, inputs, structure):
        """Run one concrete architecture.

        ``structure.nodes[i-1]`` yields ``(op_name, j)`` pairs describing the
        incoming edges of node ``i``.
        """
        nodes = [inputs]
        for i in range(1, self.max_nodes):
            cur_op_node = structure.nodes[i - 1]
            inter_nodes = []
            for op_name, j in cur_op_node:
                node_str = '{:}<-{:}'.format(i, j)
                op_index = self.op_names.index(op_name)
                inter_nodes.append(self.edges[node_str][op_index](nodes[j]))
            nodes.append(sum(inter_nodes))
        return nodes[-1]
|
|
|
|
|
|
def channel_shuffle(x, groups):
    """Interleave channels across ``groups`` (ShuffleNet-style shuffle).

    Args:
        x: 4-D tensor of shape (batch, channels, height, width); ``channels``
           must be divisible by ``groups``.
        groups: number of channel groups to interleave.

    Returns:
        Tensor of the same shape with channel order (g0c0, g1c0, ..., g0c1, ...).
    """
    # FIX: was x.data.size() -- ``.data`` bypasses autograd's version tracking
    # and is deprecated; plain .size() is the safe equivalent.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # reshape into (batch, groups, channels_per_group, H, W) ...
    x = x.view(batchsize, groups,
               channels_per_group, height, width)
    # ... swap group/channel axes so channels interleave ...
    x = torch.transpose(x, 1, 2).contiguous()
    # ... and flatten back to (batch, channels, H, W)
    x = x.view(batchsize, -1, height, width)
    return x
|
|
|
|
|
|
class NAS201SearchCell_PartialChannel(NAS201SearchCell):
    """Partial-channel (PC-DARTS-style) variant of :class:`NAS201SearchCell`.

    Only ``1/k`` of each edge's input channels flow through the weighted
    candidate ops; the remaining channels bypass the edge untouched, and a
    channel shuffle re-mixes the two streams afterwards.
    """

    def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True, k=4):
        """Build candidate ops on every edge, each at ``C/k`` channels.

        Args:
            C_in / C_out: full input/output channel counts of the cell.
            stride: stride used on edges leaving node 0; all other edges use 1.
            max_nodes: number of nodes in the cell DAG.
            op_names: operation names; each must be a key of ``OPS``.
            affine / track_running_stats: forwarded to the ops' norm layers.
            k: partial-channel factor -- ops see only ``C // k`` channels.
        """
        # Deliberately skip NAS201SearchCell.__init__ (it would build
        # full-channel ops); initialize nn.Module directly and rebuild the
        # edges at C/k channels below.
        super(NAS201SearchCell, self).__init__()

        self.k = k
        self.op_names = deepcopy(op_names)
        self.edges = nn.ModuleDict()
        self.max_nodes = max_nodes
        self.in_dim = C_in
        self.out_dim = C_out
        for dst in range(1, max_nodes):
            for src in range(dst):
                node_str = '{:}<-{:}'.format(dst, src)
                # Only edges leaving the cell input may change resolution.
                edge_stride = stride if src == 0 else 1
                self.edges[node_str] = nn.ModuleList(
                    OPS[name](C_in // self.k, C_out // self.k, edge_stride, affine, track_running_stats)
                    for name in op_names)
        self.edge_keys = sorted(list(self.edges.keys()))
        self.edge2index = {key: idx for idx, key in enumerate(self.edge_keys)}
        self.num_edges = len(self.edges)

    def MixedOp(self, x, ops, weights):
        """Weighted mix of ``ops`` over the first ``1/k`` of x's channels.

        The remaining channels pass through unchanged; the concatenated
        result is channel-shuffled so every channel eventually visits the
        op path across layers.
        """
        total_ch = x.shape[1]
        split = total_ch // self.k
        active = x[:, :split, :, :]
        passive = x[:, split:, :, :]
        mixed = sum(w * op(active) for w, op in zip(weights, ops))
        joined = torch.cat([mixed, passive], dim=1)
        return channel_shuffle(joined, self.k)

    def forward(self, inputs, weightss):
        """DARTS-style forward using the partial-channel mixed op per edge."""
        nodes = [inputs]
        for dst in range(1, self.max_nodes):
            contributions = []
            for src in range(dst):
                node_str = '{:}<-{:}'.format(dst, src)
                edge_weights = weightss[self.edge2index[node_str]]
                contributions.append(
                    self.MixedOp(x=nodes[src], ops=self.edges[node_str], weights=edge_weights))
            nodes.append(sum(contributions))
        return nodes[-1]
|
|
|