update SETN
This commit is contained in:
parent
7b354d4c74
commit
5c73aeb50b
1
.gitignore
vendored
1
.gitignore
vendored
@ -110,3 +110,4 @@ logs
|
||||
|
||||
# snapshot
|
||||
a.pth
|
||||
cal-merge.sh
|
||||
|
@ -11,11 +11,16 @@ from models import CellStructure
|
||||
|
||||
def get_unique_matrix(archs, consider_zero):
|
||||
UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
|
||||
print ('{:} create unique-string done'.format(time_string()))
|
||||
print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
|
||||
Unique2Index = dict()
|
||||
for index, xstr in enumerate(UniquStrs):
|
||||
if xstr not in Unique2Index: Unique2Index[xstr] = list()
|
||||
Unique2Index[xstr].append( index )
|
||||
sm_matrix = torch.eye(len(archs)).bool()
|
||||
for i, _ in enumerate(UniquStrs):
|
||||
for j in range(i):
|
||||
sm_matrix[i,j] = sm_matrix[j,i] = UniquStrs[i] == UniquStrs[j]
|
||||
for _, xlist in Unique2Index.items():
|
||||
for i in xlist:
|
||||
for j in xlist:
|
||||
sm_matrix[i,j] = True
|
||||
unique_ids, unique_num = [-1 for _ in archs], 0
|
||||
for i in range(len(unique_ids)):
|
||||
if unique_ids[i] > -1: continue
|
||||
|
@ -76,22 +76,22 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
||||
def get_best_arch(xloader, network, n_samples):
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
archs, valid_accs = [], []
|
||||
archs, valid_accs = network.module.return_topK(n_samples), []
|
||||
#print ('obtain the top-{:} architectures'.format(n_samples))
|
||||
loader_iter = iter(xloader)
|
||||
for i in range(n_samples):
|
||||
for i, sampled_arch in enumerate(archs):
|
||||
network.module.set_cal_mode('dynamic', sampled_arch)
|
||||
try:
|
||||
inputs, targets = next(loader_iter)
|
||||
except:
|
||||
loader_iter = iter(xloader)
|
||||
inputs, targets = next(loader_iter)
|
||||
|
||||
sampled_arch = network.module.dync_genotype(False)
|
||||
network.module.set_cal_mode('dynamic', sampled_arch)
|
||||
_, logits = network(inputs)
|
||||
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
|
||||
|
||||
archs.append( sampled_arch )
|
||||
valid_accs.append( val_top1.item() )
|
||||
#print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1))
|
||||
|
||||
best_idx = np.argmax(valid_accs)
|
||||
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
|
||||
@ -221,11 +221,6 @@ def main(xargs):
|
||||
#logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
# check the best accuracy
|
||||
valid_accuracies[epoch] = valid_a_top1
|
||||
if valid_a_top1 > valid_accuracies['best']:
|
||||
valid_accuracies['best'] = valid_a_top1
|
||||
genotypes['best'] = search_model.genotype()
|
||||
find_best = True
|
||||
else: find_best = False
|
||||
|
||||
genotypes[epoch] = genotype
|
||||
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
|
||||
@ -244,16 +239,17 @@ def main(xargs):
|
||||
'args' : deepcopy(args),
|
||||
'last_checkpoint': save_path,
|
||||
}, logger.path('info'), logger)
|
||||
if find_best:
|
||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||
with torch.no_grad():
|
||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
|
||||
#logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
|
||||
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
|
||||
network.module.set_cal_mode('dynamic', genotype)
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))
|
||||
# sampling
|
||||
"""
|
||||
with torch.no_grad():
|
||||
|
@ -81,10 +81,10 @@ class Structure:
|
||||
if consider_zero:
|
||||
if op == 'none' or nodes[xin] == '#': x = '#' # zero
|
||||
elif op == 'skip_connect': x = nodes[xin]
|
||||
else: x = nodes[xin] + '@{:}'.format(op)
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
else:
|
||||
if op == 'skip_connect': x = nodes[xin]
|
||||
else: x = nodes[xin] + '@{:}'.format(op)
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
cur_node.append(x)
|
||||
nodes[i_node+1] = '+'.join( sorted(cur_node) )
|
||||
return nodes[ len(self.nodes) ]
|
||||
|
@ -84,7 +84,6 @@ class TinyNetworkSETN(nn.Module):
|
||||
genotypes.append( tuple(xlist) )
|
||||
return Structure( genotypes )
|
||||
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
@ -103,6 +102,26 @@ class TinyNetworkSETN(nn.Module):
|
||||
genotypes.append( tuple(xlist) )
|
||||
return Structure( genotypes )
|
||||
|
||||
def get_log_prob(self, arch):
|
||||
with torch.no_grad():
|
||||
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
|
||||
select_logits = []
|
||||
for i, node_info in enumerate(arch.nodes):
|
||||
for op, xin in node_info:
|
||||
node_str = '{:}<-{:}'.format(i+1, xin)
|
||||
op_index = self.op_names.index(op)
|
||||
select_logits.append( logits[self.edge2index[node_str], op_index] )
|
||||
return sum(select_logits).item()
|
||||
|
||||
|
||||
def return_topK(self, K):
|
||||
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
|
||||
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
|
||||
if K < 0 or K >= len(archs): K = len(archs)
|
||||
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
|
||||
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
|
||||
return return_pairs
|
||||
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
|
Loading…
Reference in New Issue
Block a user