MeCo/sota/cnn/model_search_darts_proj.py
HamsterMimi 189df25fd3 upload
2023-05-04 13:09:03 +08:00

213 lines
8.5 KiB
Python

import torch
from copy import deepcopy
from sota.cnn.operations import *
from sota.cnn.genotypes import Genotype
import sys
sys.path.insert(0, '../../')
from sota.cnn.model_search import Network
class DartsNetworkProj(Network):
    """Search network whose architecture weights can be projected.

    Adds bookkeeping on top of ``Network`` so that

    * an edge can be pinned to a single operation (``project_op``), and
    * a tracked node can keep only a chosen subset of its incoming edges
      (``project_edge``).

    Pinned values are stored in ``self.proj_weights`` and substituted for the
    softmax weights by ``get_projected_weights``.
    """

    def __init__(self, C, num_classes, layers, criterion, primitives, args,
                 steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0):
        """Build the parent network, then create the projection state.

        All arguments are forwarded unchanged to ``Network.__init__``.
        """
        super(DartsNetworkProj, self).__init__(
            C, num_classes, layers, criterion, primitives, args,
            steps=steps, multiplier=multiplier,
            stem_multiplier=stem_multiplier, drop_path_prob=drop_path_prob)

        self._initialize_flags()
        self._initialize_proj_weights()
        self._initialize_topology_dicts()

    #### proj flags
    def _initialize_topology_dicts(self):
        """Create the node -> edge lookup tables used for edge projection."""
        # node id -> ids of the edges that belong to that node.
        # NOTE(review): hard-coded; edge ids 0 and 1 are not tracked and the
        # layout looks like it assumes the default steps=4 -- confirm.
        self.nid2eids = {0: [2, 3, 4], 1: [5, 6, 7, 8], 2: [9, 10, 11, 12, 13]}
        # cell type -> node id -> edge ids kept by project_edge (empty until
        # that node is projected).
        self.nid2selected_eids = {
            'normal': {0: [], 1: [], 2: []},
            'reduce': {0: [], 1: [], 2: []},
        }

    def _initialize_flags(self):
        """Mark every edge and every tracked node as not yet projected.

        A ``True`` entry means "still a candidate"; ``project_op`` and
        ``project_edge`` flip entries to ``False``.
        """
        # One flag per edge: True until project_op pins an op on that edge.
        self.candidate_flags = {
            'normal': torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
            'reduce': torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
        }  # must be in this order
        # One flag per tracked node (three of them, matching nid2eids).
        self.candidate_flags_edge = {
            'normal': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
            'reduce': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
        }

    def _initialize_proj_weights(self):
        ''' data structures used for proj '''
        # The parent may keep the alphas either as a list of per-edge tensors
        # or as one tensor; stack the list form so zeros_like sees one tensor.
        if isinstance(self.alphas_normal, list):
            alphas_normal = torch.stack(self.alphas_normal, dim=0)
            alphas_reduce = torch.stack(self.alphas_reduce, dim=0)
        else:
            alphas_normal = self.alphas_normal
            alphas_reduce = self.alphas_reduce

        self.proj_weights = {  # for hard/soft assignment after project
            'normal': torch.zeros_like(alphas_normal),
            'reduce': torch.zeros_like(alphas_reduce),
        }

    #### proj function
    def project_op(self, eid, opid, cell_type):
        """Pin edge ``eid`` of ``cell_type`` to operation ``opid``.

        Sets the corresponding projected weight to 1 and clears the edge's
        candidate flag so ``get_projected_weights`` uses the pinned row.
        """
        self.proj_weights[cell_type][eid][opid] = 1  ## hard by default
        self.candidate_flags[cell_type][eid] = False

    def project_edge(self, nid, eids, cell_type):
        """Keep only the edges ``eids`` for node ``nid`` of ``cell_type``.

        Every other edge of the node gets an all-zero projected row.  The
        kept edge ids are stored (as a copy) and the node's candidate flag is
        cleared.
        """
        for eid in self.nid2eids[nid]:
            if eid not in eids:  # not top2
                self.proj_weights[cell_type][eid].data.fill_(0)
        self.nid2selected_eids[cell_type][nid] = deepcopy(eids)
        self.candidate_flags_edge[cell_type][nid] = False

    #### critical function
    def get_projected_weights(self, cell_type):
        ''' used in forward and genotype '''
        weights = self.get_softmax()[cell_type]

        ## proj op
        # Edges already pinned by project_op take their projected row.
        for eid in range(self.num_edges):
            if not self.candidate_flags[cell_type][eid]:
                weights[eid].data.copy_(self.proj_weights[cell_type][eid])

        ## proj edge
        # For projected nodes, edges that were not kept take their projected
        # row (zeroed by project_edge).
        for nid in self.nid2eids:
            if not self.candidate_flags_edge[cell_type][nid]:  ## projected node
                for eid in self.nid2eids[nid]:
                    if eid not in self.nid2selected_eids[cell_type][nid]:
                        weights[eid].data.copy_(self.proj_weights[cell_type][eid])

        return weights

    def get_all_projected_weights(self, cell_type):
        """Return the weights for ``cell_type`` with every row projected.

        Unlike ``get_projected_weights`` the candidate flags are ignored:
        each edge row is overwritten with its entry in ``proj_weights``.
        """
        weights = self.get_softmax()[cell_type]

        # This loop visits every edge id, so a further per-node pass over
        # nid2eids would only re-copy identical rows (assuming the ids in
        # nid2eids lie within range(num_edges)).
        for eid in range(self.num_edges):
            weights[eid].data.copy_(self.proj_weights[cell_type][eid])

        return weights

    def forward(self, input, weights_dict=None, using_proj=False):
        """Run the network and return the classifier logits.

        Args:
            input: batch fed to ``self.stem``.
            weights_dict: optional mapping that may contain ``'normal'``
                and/or ``'reduce'`` weights; a present key replaces the
                weights this model would compute for that cell type.
            using_proj: when true, both cell types use
                ``get_all_projected_weights`` and ``weights_dict`` is ignored.

        Returns:
            Output of ``self.classifier`` on the flattened pooled features.
        """
        if using_proj:
            weights_normal = self.get_all_projected_weights('normal')
            weights_reduce = self.get_all_projected_weights('reduce')
        else:
            if weights_dict is None or 'normal' not in weights_dict:
                weights_normal = self.get_projected_weights('normal')
            else:
                weights_normal = weights_dict['normal']
            if weights_dict is None or 'reduce' not in weights_dict:
                weights_reduce = self.get_projected_weights('reduce')
            else:
                weights_reduce = weights_dict['reduce']

        s0 = s1 = self.stem(input)
        for cell in self.cells:
            weights = weights_reduce if cell.reduction else weights_normal
            s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))

        return logits

    def reset_arch_parameters(self):
        """Re-create the flags, projected weights and topology tables."""
        self._initialize_flags()
        self._initialize_proj_weights()
        self._initialize_topology_dicts()

    #### utils
    def printing(self, logging, option='all'):
        """Log the projected weights.

        Args:
            logging: object exposing ``info(fmt, *args)``.
            option: ``'all'``, ``'normal'`` or ``'reduce'``; selects which
                cell type(s) are logged.
        """
        weights_normal = self.get_projected_weights('normal')
        weights_reduce = self.get_projected_weights('reduce')

        if option in ['all', 'normal']:
            logging.info('\n%s', weights_normal)
        if option in ['all', 'reduce']:
            logging.info('\n%s', weights_reduce)

    def genotype(self):
        """Derive a ``Genotype`` from the current projected weights.

        For each of the ``self._steps`` nodes, the two incoming edges with
        the largest non-'none' weight are kept, and on each kept edge the
        non-'none' operation with the largest weight is chosen.
        """
        def _parse(weights, normal=True):
            PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']

            gene = []
            n = 2
            start = 0
            for i in range(self._steps):
                end = start + n
                W = weights[start:end].copy()

                # Selectable op indices per incoming edge.  The primitive
                # list is looked up with the *global* edge id (start + j) --
                # the same id used below to read the op name -- so the 'none'
                # position always refers to the right edge.  Edges whose list
                # has no 'none' simply keep all of their ops.
                allowed = []
                for j in range(i + 2):
                    prims = PRIMITIVES[start + j]
                    none_id = prims.index('none') if 'none' in prims else None
                    allowed.append([k for k in range(len(W[j])) if k != none_id])

                # Strongest selectable weight on each edge; keep the top two
                # edges (stable sort, so ties favour the lower edge index).
                strength = [max(W[j][k] for k in allowed[j]) for j in range(i + 2)]
                edges = sorted(range(i + 2), key=lambda j: -strength[j])[:2]

                for j in edges:
                    # max() returns the first maximal element, i.e. ties
                    # favour the lower op index.
                    k_best = max(allowed[j], key=lambda k: W[j][k])
                    gene.append((PRIMITIVES[start + j][k_best], j))

                start = end
                n += 1

            return gene

        weights_normal = self.get_projected_weights('normal')
        weights_reduce = self.get_projected_weights('reduce')
        gene_normal = _parse(weights_normal.data.cpu().numpy(), True)
        gene_reduce = _parse(weights_reduce.data.cpu().numpy(), False)

        # Node indices listed as the concat set of both cell types.
        concat = range(2 + self._steps - self._multiplier, self._steps + 2)
        genotype = Genotype(
            normal=gene_normal, normal_concat=concat,
            reduce=gene_reduce, reduce_concat=concat
        )

        return genotype

    def get_state_dict(self, epoch, architect, scheduler):
        """Collect everything needed to resume the search into one dict.

        Args:
            epoch: epoch number stored as-is.
            architect: object whose ``optimizer`` state is saved.
            scheduler: object whose ``state_dict()`` is saved.

        Returns:
            Dict holding model, optimizer and scheduler state together with
            the projection bookkeeping of this class.
        """
        model_state_dict = {
            'epoch': epoch,  ## no +1 because we are saving before projection / at the beginning of an epoch
            'state_dict': self.state_dict(),
            'alpha': self.arch_parameters(),
            'optimizer': self.optimizer.state_dict(),
            'arch_optimizer': architect.optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            #### projection
            'nid2eids': self.nid2eids,
            'nid2selected_eids': self.nid2selected_eids,
            'candidate_flags': self.candidate_flags,
            'candidate_flags_edge': self.candidate_flags_edge,
            'proj_weights': self.proj_weights,
        }
        return model_state_dict

    def set_state_dict(self, architect, scheduler, checkpoint):
        """Restore the state saved by ``get_state_dict``.

        Args:
            architect: object whose ``optimizer`` state is restored.
            scheduler: object restored via ``load_state_dict``.
            checkpoint: dict with the keys written by ``get_state_dict``
                (``'epoch'`` is not read here).
        """
        #### common
        self.load_state_dict(checkpoint['state_dict'])
        self.set_arch_parameters(checkpoint['alpha'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        architect.optimizer.load_state_dict(checkpoint['arch_optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

        #### projection
        self.nid2eids = checkpoint['nid2eids']
        self.nid2selected_eids = checkpoint['nid2selected_eids']
        self.candidate_flags = checkpoint['candidate_flags']
        self.candidate_flags_edge = checkpoint['candidate_flags_edge']
        self.proj_weights = checkpoint['proj_weights']