update SETN
This commit is contained in:
		| @@ -11,11 +11,16 @@ from models     import CellStructure | ||||
|  | ||||
| def get_unique_matrix(archs, consider_zero): | ||||
|   UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] | ||||
|   print ('{:} create unique-string done'.format(time_string())) | ||||
|   print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs))) | ||||
|   Unique2Index = dict() | ||||
|   for index, xstr in enumerate(UniquStrs): | ||||
|     if xstr not in Unique2Index: Unique2Index[xstr] = list() | ||||
|     Unique2Index[xstr].append( index ) | ||||
|   sm_matrix = torch.eye(len(archs)).bool() | ||||
|   for i, _ in enumerate(UniquStrs): | ||||
|     for j in range(i): | ||||
|       sm_matrix[i,j] = sm_matrix[j,i] = UniquStrs[i] == UniquStrs[j] | ||||
|   for _, xlist in Unique2Index.items(): | ||||
|     for i in xlist: | ||||
|       for j in xlist: | ||||
|         sm_matrix[i,j] = True | ||||
|   unique_ids, unique_num = [-1 for _ in archs], 0 | ||||
|   for i in range(len(unique_ids)): | ||||
|     if unique_ids[i] > -1: continue | ||||
|   | ||||
| @@ -76,22 +76,22 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer | ||||
| def get_best_arch(xloader, network, n_samples): | ||||
|   with torch.no_grad(): | ||||
|     network.eval() | ||||
|     archs, valid_accs = [], [] | ||||
|     archs, valid_accs = network.module.return_topK(n_samples), [] | ||||
|     #print ('obtain the top-{:} architectures'.format(n_samples)) | ||||
|     loader_iter = iter(xloader) | ||||
|     for i in range(n_samples): | ||||
|     for i, sampled_arch in enumerate(archs): | ||||
|       network.module.set_cal_mode('dynamic', sampled_arch) | ||||
|       try: | ||||
|         inputs, targets = next(loader_iter) | ||||
|       except: | ||||
|         loader_iter = iter(xloader) | ||||
|         inputs, targets = next(loader_iter) | ||||
|  | ||||
|       sampled_arch = network.module.dync_genotype(False) | ||||
|       network.module.set_cal_mode('dynamic', sampled_arch) | ||||
|       _, logits = network(inputs) | ||||
|       val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) | ||||
|  | ||||
|       archs.append( sampled_arch ) | ||||
|       valid_accs.append( val_top1.item() ) | ||||
|       #print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1)) | ||||
|  | ||||
|     best_idx = np.argmax(valid_accs) | ||||
|     best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||
| @@ -221,11 +221,6 @@ def main(xargs): | ||||
|     #logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||
|     # check the best accuracy | ||||
|     valid_accuracies[epoch] = valid_a_top1 | ||||
|     if valid_a_top1 > valid_accuracies['best']: | ||||
|       valid_accuracies['best'] = valid_a_top1 | ||||
|       genotypes['best']        = search_model.genotype() | ||||
|       find_best = True | ||||
|     else: find_best = False | ||||
|  | ||||
|     genotypes[epoch] = genotype | ||||
|     logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) | ||||
| @@ -244,16 +239,17 @@ def main(xargs): | ||||
|           'args' : deepcopy(args), | ||||
|           'last_checkpoint': save_path, | ||||
|           }, logger.path('info'), logger) | ||||
|     if find_best: | ||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best'])) | ||||
|   #logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best'])) | ||||
|   genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) | ||||
|   network.module.set_cal_mode('dynamic', genotype) | ||||
|   valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) | ||||
|   logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1)) | ||||
|   # sampling | ||||
|   """ | ||||
|   with torch.no_grad(): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user