update SETN
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -110,3 +110,4 @@ logs | |||||||
|  |  | ||||||
| # snapshot | # snapshot | ||||||
| a.pth | a.pth | ||||||
|  | cal-merge.sh | ||||||
|   | |||||||
| @@ -11,11 +11,16 @@ from models     import CellStructure | |||||||
|  |  | ||||||
| def get_unique_matrix(archs, consider_zero): | def get_unique_matrix(archs, consider_zero): | ||||||
|   UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] |   UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs] | ||||||
|   print ('{:} create unique-string done'.format(time_string())) |   print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs))) | ||||||
|  |   Unique2Index = dict() | ||||||
|  |   for index, xstr in enumerate(UniquStrs): | ||||||
|  |     if xstr not in Unique2Index: Unique2Index[xstr] = list() | ||||||
|  |     Unique2Index[xstr].append( index ) | ||||||
|   sm_matrix = torch.eye(len(archs)).bool() |   sm_matrix = torch.eye(len(archs)).bool() | ||||||
|   for i, _ in enumerate(UniquStrs): |   for _, xlist in Unique2Index.items(): | ||||||
|     for j in range(i): |     for i in xlist: | ||||||
|       sm_matrix[i,j] = sm_matrix[j,i] = UniquStrs[i] == UniquStrs[j] |       for j in xlist: | ||||||
|  |         sm_matrix[i,j] = True | ||||||
|   unique_ids, unique_num = [-1 for _ in archs], 0 |   unique_ids, unique_num = [-1 for _ in archs], 0 | ||||||
|   for i in range(len(unique_ids)): |   for i in range(len(unique_ids)): | ||||||
|     if unique_ids[i] > -1: continue |     if unique_ids[i] > -1: continue | ||||||
|   | |||||||
| @@ -76,22 +76,22 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer | |||||||
| def get_best_arch(xloader, network, n_samples): | def get_best_arch(xloader, network, n_samples): | ||||||
|   with torch.no_grad(): |   with torch.no_grad(): | ||||||
|     network.eval() |     network.eval() | ||||||
|     archs, valid_accs = [], [] |     archs, valid_accs = network.module.return_topK(n_samples), [] | ||||||
|  |     #print ('obtain the top-{:} architectures'.format(n_samples)) | ||||||
|     loader_iter = iter(xloader) |     loader_iter = iter(xloader) | ||||||
|     for i in range(n_samples): |     for i, sampled_arch in enumerate(archs): | ||||||
|  |       network.module.set_cal_mode('dynamic', sampled_arch) | ||||||
|       try: |       try: | ||||||
|         inputs, targets = next(loader_iter) |         inputs, targets = next(loader_iter) | ||||||
|       except: |       except: | ||||||
|         loader_iter = iter(xloader) |         loader_iter = iter(xloader) | ||||||
|         inputs, targets = next(loader_iter) |         inputs, targets = next(loader_iter) | ||||||
|  |  | ||||||
|       sampled_arch = network.module.dync_genotype(False) |  | ||||||
|       network.module.set_cal_mode('dynamic', sampled_arch) |  | ||||||
|       _, logits = network(inputs) |       _, logits = network(inputs) | ||||||
|       val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) |       val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) | ||||||
|  |  | ||||||
|       archs.append( sampled_arch ) |  | ||||||
|       valid_accs.append( val_top1.item() ) |       valid_accs.append( val_top1.item() ) | ||||||
|  |       #print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1)) | ||||||
|  |  | ||||||
|     best_idx = np.argmax(valid_accs) |     best_idx = np.argmax(valid_accs) | ||||||
|     best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] |     best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||||
| @@ -221,11 +221,6 @@ def main(xargs): | |||||||
|     #logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) |     #logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||||
|     # check the best accuracy |     # check the best accuracy | ||||||
|     valid_accuracies[epoch] = valid_a_top1 |     valid_accuracies[epoch] = valid_a_top1 | ||||||
|     if valid_a_top1 > valid_accuracies['best']: |  | ||||||
|       valid_accuracies['best'] = valid_a_top1 |  | ||||||
|       genotypes['best']        = search_model.genotype() |  | ||||||
|       find_best = True |  | ||||||
|     else: find_best = False |  | ||||||
|  |  | ||||||
|     genotypes[epoch] = genotype |     genotypes[epoch] = genotype | ||||||
|     logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) |     logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch])) | ||||||
| @@ -244,16 +239,17 @@ def main(xargs): | |||||||
|           'args' : deepcopy(args), |           'args' : deepcopy(args), | ||||||
|           'last_checkpoint': save_path, |           'last_checkpoint': save_path, | ||||||
|           }, logger.path('info'), logger) |           }, logger.path('info'), logger) | ||||||
|     if find_best: |  | ||||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) |  | ||||||
|       copy_checkpoint(model_base_path, model_best_path, logger) |  | ||||||
|     with torch.no_grad(): |     with torch.no_grad(): | ||||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) |       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||||
|     # measure elapsed time |     # measure elapsed time | ||||||
|     epoch_time.update(time.time() - start_time) |     epoch_time.update(time.time() - start_time) | ||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
|  |  | ||||||
|   logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best'])) |   #logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best'])) | ||||||
|  |   genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) | ||||||
|  |   network.module.set_cal_mode('dynamic', genotype) | ||||||
|  |   valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) | ||||||
|  |   logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1)) | ||||||
|   # sampling |   # sampling | ||||||
|   """ |   """ | ||||||
|   with torch.no_grad(): |   with torch.no_grad(): | ||||||
|   | |||||||
| @@ -81,10 +81,10 @@ class Structure: | |||||||
|         if consider_zero: |         if consider_zero: | ||||||
|           if op == 'none' or nodes[xin] == '#': x = '#' # zero |           if op == 'none' or nodes[xin] == '#': x = '#' # zero | ||||||
|           elif op == 'skip_connect': x = nodes[xin] |           elif op == 'skip_connect': x = nodes[xin] | ||||||
|           else: x = nodes[xin] + '@{:}'.format(op) |           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||||
|         else: |         else: | ||||||
|           if op == 'skip_connect': x = nodes[xin] |           if op == 'skip_connect': x = nodes[xin] | ||||||
|           else: x = nodes[xin] + '@{:}'.format(op) |           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||||
|         cur_node.append(x) |         cur_node.append(x) | ||||||
|       nodes[i_node+1] = '+'.join( sorted(cur_node) ) |       nodes[i_node+1] = '+'.join( sorted(cur_node) ) | ||||||
|     return nodes[ len(self.nodes) ] |     return nodes[ len(self.nodes) ] | ||||||
|   | |||||||
| @@ -84,7 +84,6 @@ class TinyNetworkSETN(nn.Module): | |||||||
|       genotypes.append( tuple(xlist) ) |       genotypes.append( tuple(xlist) ) | ||||||
|     return Structure( genotypes ) |     return Structure( genotypes ) | ||||||
|  |  | ||||||
|  |  | ||||||
|   def dync_genotype(self, use_random=False): |   def dync_genotype(self, use_random=False): | ||||||
|     genotypes = [] |     genotypes = [] | ||||||
|     with torch.no_grad(): |     with torch.no_grad(): | ||||||
| @@ -103,6 +102,26 @@ class TinyNetworkSETN(nn.Module): | |||||||
|       genotypes.append( tuple(xlist) ) |       genotypes.append( tuple(xlist) ) | ||||||
|     return Structure( genotypes ) |     return Structure( genotypes ) | ||||||
|  |  | ||||||
|  |   def get_log_prob(self, arch): | ||||||
|  |     with torch.no_grad(): | ||||||
|  |       logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) | ||||||
|  |     select_logits = [] | ||||||
|  |     for i, node_info in enumerate(arch.nodes): | ||||||
|  |       for op, xin in node_info: | ||||||
|  |         node_str = '{:}<-{:}'.format(i+1, xin) | ||||||
|  |         op_index = self.op_names.index(op) | ||||||
|  |         select_logits.append( logits[self.edge2index[node_str], op_index] ) | ||||||
|  |     return sum(select_logits).item() | ||||||
|  |  | ||||||
|  |  | ||||||
|  |   def return_topK(self, K): | ||||||
|  |     archs = Structure.gen_all(self.op_names, self.max_nodes, False) | ||||||
|  |     pairs = [(self.get_log_prob(arch), arch) for arch in archs] | ||||||
|  |     if K < 0 or K >= len(archs): K = len(archs) | ||||||
|  |     sorted_pairs = sorted(pairs, key=lambda x: -x[0]) | ||||||
|  |     return_pairs = [sorted_pairs[_][1] for _ in range(K)] | ||||||
|  |     return return_pairs | ||||||
|  |  | ||||||
|  |  | ||||||
|   def forward(self, inputs): |   def forward(self, inputs): | ||||||
|     alphas  = nn.functional.softmax(self.arch_parameters, dim=-1) |     alphas  = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user