import sys
from copy import deepcopy

import torch

# make the project root importable before pulling in sota.* modules
sys.path.insert(0, '../../')

from sota.cnn.operations import *
from sota.cnn.genotypes import Genotype
from sota.cnn.model_search import Network
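
# DartsNetworkProj extends the DARTS supernetwork with projection state:
# per-edge operation commitments (project_op) and per-node input-edge
# selections (project_edge) that progressively discretize the architecture.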

class DartsNetworkProj(Network):
    def __init__(self, C, num_classes, layers, criterion, primitives, args,
                 steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0):
        super(DartsNetworkProj, self).__init__(C, num_classes, layers, criterion, primitives, args,
                                               steps=steps, multiplier=multiplier,
                                               stem_multiplier=stem_multiplier,
                                               drop_path_prob=drop_path_prob)

        self._initialize_flags()
        self._initialize_proj_weights()
        self._initialize_topology_dicts()

    #### proj flags
    def _initialize_topology_dicts(self):
        # nid2eids: nodes whose input edges still need selection -> incoming
        # edge ids; the first intermediate node has only 2 inputs (edges 0, 1)
        # and always keeps both, so it is omitted
        self.nid2eids = {0: [2, 3, 4], 1: [5, 6, 7, 8], 2: [9, 10, 11, 12, 13]}
        # top-2 edges chosen for each node once its topology has been projected
        self.nid2selected_eids = {
            'normal': {0: [], 1: [], 2: []},
            'reduce': {0: [], 1: [], 2: []},
        }
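
    # Edge-id layout of a steps=4 DARTS cell, for reference: intermediate node
    # k has k + 2 incoming edges, numbered consecutively across nodes:
    #   node 0 -> edges 0, 1               (kept unconditionally)
    #   node 1 -> edges 2, 3, 4            (nid 0 above)
    #   node 2 -> edges 5, 6, 7, 8         (nid 1 above)
    #   node 3 -> edges 9, 10, 11, 12, 13  (nid 2 above)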

    def _initialize_flags(self):
        # True = the operation on this edge has not been projected yet
        self.candidate_flags = {
            'normal': torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
            'reduce': torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
        }  # must be in this order
        # True = the input-edge topology of this node has not been decided yet
        self.candidate_flags_edge = {
            'normal': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
            'reduce': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
        }

    def _initialize_proj_weights(self):
        ''' data structures used for proj '''
        if isinstance(self.alphas_normal, list):
            alphas_normal = torch.stack(self.alphas_normal, dim=0)
            alphas_reduce = torch.stack(self.alphas_reduce, dim=0)
        else:
            alphas_normal = self.alphas_normal
            alphas_reduce = self.alphas_reduce

        self.proj_weights = {  # for hard/soft assignment after project
            'normal': torch.zeros_like(alphas_normal),
            'reduce': torch.zeros_like(alphas_reduce),
        }

    #### proj function
    def project_op(self, eid, opid, cell_type):
        ''' commit edge eid to operation opid via a one-hot weight row '''
        self.proj_weights[cell_type][eid][opid] = 1  ## hard by default
        self.candidate_flags[cell_type][eid] = False

    def project_edge(self, nid, eids, cell_type):
        ''' keep only the selected input edges eids of node nid '''
        for eid in self.nid2eids[nid]:
            if eid not in eids:  # not top2
                self.proj_weights[cell_type][eid].data.fill_(0)
        self.nid2selected_eids[cell_type][nid] = deepcopy(eids)
        self.candidate_flags_edge[cell_type][nid] = False
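
    # Minimal usage sketch (the eid/opid/eids values below are hypothetical,
    # chosen only for illustration):
    #
    #   model.project_op(eid=2, opid=5, cell_type='normal')         # edge 2 -> op 5, one-hot
    #   model.project_edge(nid=1, eids=[5, 7], cell_type='normal')  # node keeps edges 5 and 7
    #
    # Subsequent calls to get_projected_weights('normal') then return a one-hot
    # row for edge 2 and all-zero rows for the pruned inputs (6, 8) of node 1.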

    #### critical function
    def get_projected_weights(self, cell_type):
        ''' used in forward and genotype '''
        weights = self.get_softmax()[cell_type]

        ## proj op: overwrite already-projected edges with their one-hot rows
        for eid in range(self.num_edges):
            if not self.candidate_flags[cell_type][eid]:
                weights[eid].data.copy_(self.proj_weights[cell_type][eid])

        ## proj edge: zero out the pruned input edges of projected nodes
        for nid in self.nid2eids:
            if not self.candidate_flags_edge[cell_type][nid]:  ## projected node
                for eid in self.nid2eids[nid]:
                    if eid not in self.nid2selected_eids[cell_type][nid]:
                        weights[eid].data.copy_(self.proj_weights[cell_type][eid])

        return weights

    def get_all_projected_weights(self, cell_type):
        ''' replace every edge with its projected (possibly all-zero) row '''
        weights = self.get_softmax()[cell_type]

        # a single pass over all edges covers both op and edge projections
        for eid in range(self.num_edges):
            weights[eid].data.copy_(self.proj_weights[cell_type][eid])

        return weights

    def forward(self, input, weights_dict=None, using_proj=False):
        if using_proj:
            weights_normal = self.get_all_projected_weights('normal')
            weights_reduce = self.get_all_projected_weights('reduce')
        else:
            if weights_dict is None or 'normal' not in weights_dict:
                weights_normal = self.get_projected_weights('normal')
            else:
                weights_normal = weights_dict['normal']
            if weights_dict is None or 'reduce' not in weights_dict:
                weights_reduce = self.get_projected_weights('reduce')
            else:
                weights_reduce = weights_dict['reduce']

        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            weights = weights_reduce if cell.reduction else weights_normal
            s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)

        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))

        return logits
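
    # How a search loop might drive forward(), as a hedged sketch (names like
    # `x` and `alphas_override` are assumptions, not defined in this module):
    #
    #   logits = model(x)                                 # softmax alphas, with any
    #                                                     # already-projected rows applied
    #   logits = model(x, weights_dict={'normal': alphas_override})  # caller-supplied weights
    #   logits = model(x, using_proj=True)                # fully discretized supernet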

    def reset_arch_parameters(self):
        self._initialize_flags()
        self._initialize_proj_weights()
        self._initialize_topology_dicts()

    #### utils
    def printing(self, logging, option='all'):
        weights_normal = self.get_projected_weights('normal')
        weights_reduce = self.get_projected_weights('reduce')

        if option in ['all', 'normal']:
            logging.info('\n%s', weights_normal)
        if option in ['all', 'reduce']:
            logging.info('\n%s', weights_reduce)

    def genotype(self):
        def _parse(weights, normal=True):
            PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']

            gene = []
            n = 2
            start = 0
            for i in range(self._steps):
                end = start + n
                W = weights[start:end].copy()

                # rank this node's input edges by their strongest non-'none' op;
                # PRIMITIVES is indexed by global edge id, hence start + x
                try:
                    edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[start + x].index('none')))[:2]
                except ValueError:  # raised when 'none' is absent from the op list
                    edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]

                for j in edges:
                    k_best = None
                    for k in range(len(W[j])):
                        if 'none' in PRIMITIVES[start + j]:
                            if k != PRIMITIVES[start + j].index('none'):
                                if k_best is None or W[j][k] > W[j][k_best]:
                                    k_best = k
                        else:
                            if k_best is None or W[j][k] > W[j][k_best]:
                                k_best = k
                    gene.append((PRIMITIVES[start + j][k_best], j))
                start = end
                n += 1
            return gene

        weights_normal = self.get_projected_weights('normal')
        weights_reduce = self.get_projected_weights('reduce')
        gene_normal = _parse(weights_normal.data.cpu().numpy(), True)
        gene_reduce = _parse(weights_reduce.data.cpu().numpy(), False)

        concat = range(2 + self._steps - self._multiplier, self._steps + 2)
        genotype = Genotype(
            normal=gene_normal, normal_concat=concat,
            reduce=gene_reduce, reduce_concat=concat
        )
        return genotype
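
    # Shape of the result with the default steps=4, multiplier=4 (op names are
    # illustrative picks from the DARTS space, not fixed outputs):
    #   Genotype(normal=[('sep_conv_3x3', 0), ('skip_connect', 1), ...],
    #            normal_concat=range(2, 6),
    #            reduce=[...], reduce_concat=range(2, 6))
    # Each (op_name, j) pair wires op_name from input j of its node; every node
    # contributes exactly two pairs because of the top-2 edge selection above.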

    def get_state_dict(self, epoch, architect, scheduler):
        model_state_dict = {
            'epoch': epoch,  ## no +1: saved before projection, at the start of an epoch
            'state_dict': self.state_dict(),
            'alpha': self.arch_parameters(),
            'optimizer': self.optimizer.state_dict(),
            'arch_optimizer': architect.optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            #### projection
            'nid2eids': self.nid2eids,
            'nid2selected_eids': self.nid2selected_eids,
            'candidate_flags': self.candidate_flags,
            'candidate_flags_edge': self.candidate_flags_edge,
            'proj_weights': self.proj_weights,
        }
        return model_state_dict

    def set_state_dict(self, architect, scheduler, checkpoint):
        #### common
        self.load_state_dict(checkpoint['state_dict'])
        self.set_arch_parameters(checkpoint['alpha'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        architect.optimizer.load_state_dict(checkpoint['arch_optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

        #### projection
        self.nid2eids = checkpoint['nid2eids']
        self.nid2selected_eids = checkpoint['nid2selected_eids']
        self.candidate_flags = checkpoint['candidate_flags']
        self.candidate_flags_edge = checkpoint['candidate_flags_edge']
        self.proj_weights = checkpoint['proj_weights']