diff --git a/.gitignore b/.gitignore index a2a45a5..562720b 100644 --- a/.gitignore +++ b/.gitignore @@ -110,3 +110,4 @@ logs # snapshot a.pth +cal-merge.sh diff --git a/exps/AA-NAS-test-API.py b/exps/AA-NAS-test-API.py index 3a6a79c..6967ba8 100644 --- a/exps/AA-NAS-test-API.py +++ b/exps/AA-NAS-test-API.py @@ -11,11 +11,16 @@ from models import CellStructure def get_unique_matrix(archs, consider_zero): UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] - print ('{:} create unique-string done'.format(time_string())) + print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs))) + Unique2Index = dict() + for index, xstr in enumerate(UniquStrs): + if xstr not in Unique2Index: Unique2Index[xstr] = list() + Unique2Index[xstr].append( index ) sm_matrix = torch.eye(len(archs)).bool() - for i, _ in enumerate(UniquStrs): - for j in range(i): - sm_matrix[i,j] = sm_matrix[j,i] = UniquStrs[i] == UniquStrs[j] + for _, xlist in Unique2Index.items(): + for i in xlist: + for j in xlist: + sm_matrix[i,j] = True unique_ids, unique_num = [-1 for _ in archs], 0 for i in range(len(unique_ids)): if unique_ids[i] > -1: continue diff --git a/exps/algos/SETN.py b/exps/algos/SETN.py index 9703c34..049e600 100644 --- a/exps/algos/SETN.py +++ b/exps/algos/SETN.py @@ -76,22 +76,22 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer def get_best_arch(xloader, network, n_samples): with torch.no_grad(): network.eval() - archs, valid_accs = [], [] + archs, valid_accs = network.module.return_topK(n_samples), [] + #print ('obtain the top-{:} architectures'.format(n_samples)) loader_iter = iter(xloader) - for i in range(n_samples): + for i, sampled_arch in enumerate(archs): + network.module.set_cal_mode('dynamic', sampled_arch) try: inputs, targets = next(loader_iter) except: loader_iter = iter(xloader) inputs, targets = next(loader_iter) - sampled_arch = network.module.dync_genotype(False) - network.module.set_cal_mode('dynamic', sampled_arch) _, logits = network(inputs) val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) - archs.append( sampled_arch ) valid_accs.append( val_top1.item() ) + #print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1)) best_idx = np.argmax(valid_accs) best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] @@ -221,11 +221,6 @@ def main(xargs): #logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) # check the best accuracy valid_accuracies[epoch] = valid_a_top1 - if valid_a_top1 > valid_accuracies['best']: - valid_accuracies['best'] = valid_a_top1 - genotypes['best'] = search_model.genotype() - find_best = True - else: find_best = False genotypes[epoch] = genotype logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) @@ -244,16 +239,17 @@ def main(xargs): 'args' : deepcopy(args), 'last_checkpoint': save_path, }, logger.path('info'), logger) - if find_best: - logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) - copy_checkpoint(model_base_path, model_best_path, logger) with torch.no_grad(): logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) # measure elapsed time epoch_time.update(time.time() - start_time) start_time = time.time() - logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best'])) + #logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best'])) + genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) + network.module.set_cal_mode('dynamic', genotype) + valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) + logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1)) # sampling """ with torch.no_grad(): diff --git a/lib/models/cell_searchs/genotypes.py b/lib/models/cell_searchs/genotypes.py index 208dc0a..6bd8af8 100644 --- a/lib/models/cell_searchs/genotypes.py +++ b/lib/models/cell_searchs/genotypes.py @@ -81,10 +81,10 @@ class Structure: if consider_zero: if op == 'none' or nodes[xin] == '#': x = '#' # zero elif op == 'skip_connect': x = nodes[xin] - else: x = nodes[xin] + '@{:}'.format(op) + else: x = '('+nodes[xin]+')' + '@{:}'.format(op) else: if op == 'skip_connect': x = nodes[xin] - else: x = nodes[xin] + '@{:}'.format(op) + else: x = '('+nodes[xin]+')' + '@{:}'.format(op) cur_node.append(x) nodes[i_node+1] = '+'.join( sorted(cur_node) ) return nodes[ len(self.nodes) ] diff --git a/lib/models/cell_searchs/search_model_setn.py b/lib/models/cell_searchs/search_model_setn.py index 316c88d..6d60d55 100644 --- a/lib/models/cell_searchs/search_model_setn.py +++ b/lib/models/cell_searchs/search_model_setn.py @@ -84,7 +84,6 @@ class TinyNetworkSETN(nn.Module): genotypes.append( tuple(xlist) ) return Structure( genotypes ) - def dync_genotype(self, use_random=False): genotypes = [] with torch.no_grad(): @@ -103,6 +102,26 @@ class TinyNetworkSETN(nn.Module): genotypes.append( tuple(xlist) ) return Structure( genotypes ) + def get_log_prob(self, arch): + with torch.no_grad(): + logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) + select_logits = [] + for i, node_info in enumerate(arch.nodes): + for op, xin in node_info: + node_str = '{:}<-{:}'.format(i+1, xin) + op_index = self.op_names.index(op) + select_logits.append( logits[self.edge2index[node_str], op_index] ) + return sum(select_logits).item() + + + def return_topK(self, K): + archs = Structure.gen_all(self.op_names, self.max_nodes, False) + pairs = [(self.get_log_prob(arch), arch) for arch in archs] + if K < 0 or K >= len(archs): K = len(archs) + sorted_pairs = sorted(pairs, key=lambda x: -x[0]) + return_pairs = [sorted_pairs[_][1] for _ in range(K)] + return return_pairs + def forward(self, inputs): alphas = nn.functional.softmax(self.arch_parameters, dim=-1)