61 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			61 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| import torch.nn.functional as F
 | |
| 
 | |
| 
 | |
def drop_path(x, drop_prob):
  """Randomly zero whole samples of ``x`` and rescale the surviving ones.

  One Bernoulli(keep_prob) draw is made per entry along dim 0 and broadcast
  over every remaining dimension; kept samples are divided by ``keep_prob``.

  Args:
    x: tensor whose first dimension indexes the samples to drop/keep.
    drop_prob: probability of dropping a sample. Values ``<= 0`` disable the
      operation and ``x`` itself is returned unchanged.

  Returns:
    ``x`` (same object) when ``drop_prob <= 0``; otherwise a new tensor of the
    same shape. The input tensor is never modified in place.
  """
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    if keep_prob == 0.:
      # Everything is dropped. Dividing by keep_prob == 0 would produce
      # inf/NaN that the zero mask cannot clear (inf * 0 == NaN), so
      # short-circuit to an all-zero result instead.
      return torch.zeros_like(x)
    # One draw per sample, broadcastable against x for any rank >= 1
    # (this is (N, 1, 1, 1) for the usual 4-D input).
    mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
    mask = x.new_zeros(mask_shape).bernoulli_(keep_prob)
    # torch.div allocates a fresh tensor, so the in-place mul_ below does not
    # touch the caller's tensor.
    x = torch.div(x, keep_prob)
    x.mul_(mask)
  return x
 | |
| 
 | |
| 
 | |
def return_alphas_str(basemodel):
  """Build a multi-line text summary of the architecture weights on ``basemodel``.

  Each section is emitted only if the corresponding attribute/method exists
  on ``basemodel``:
    * ``alphas_normal`` / ``alphas_reduce``: softmax over the last dim.
    * ``get_adjacency`` (with ``connect_normal`` / ``connect_reduce`` indexed
      by ``str(i)``): softmaxed connection weights multiplied by the i-th
      adjacency entry, printed with three decimals.
    * ``alphas_connect``: columns 0 and 1 of its softmax; when the attribute
      is missing the line ``connect = None`` is emitted instead.
    * ``get_gcn_out``: softmax of every output of ``get_gcn_out(True)``.

  Returns:
    The concatenated summary string.
  """
  pieces = []

  if hasattr(basemodel, 'alphas_normal'):
    pieces.append(
        'normal [{:}] : \n-->>{:}'.format(
            basemodel.alphas_normal.size(),
            F.softmax(basemodel.alphas_normal, dim=-1)))

  if hasattr(basemodel, 'alphas_reduce'):
    reduce_probs = F.softmax(basemodel.alphas_reduce, dim=-1)
    pieces.append('\nreduce : {:}'.format(reduce_probs))

  if hasattr(basemodel, 'get_adjacency'):
    adjacency = basemodel.get_adjacency()
    # All "normal" rows are emitted first, then all "reduce" rows.
    sections = (
        ('connect_normal', '\nnormal--{:}-->{:}'),
        ('connect_reduce', '\nreduce--{:}-->{:}'),
    )
    for attr_name, template in sections:
      for idx in range(len(adjacency)):
        # Looked up lazily so nothing is touched when adjacency is empty.
        logits = getattr(basemodel, attr_name)[str(idx)]
        probs = F.softmax(logits, dim=-1)
        mixed = torch.mm(probs, adjacency[idx]).view(-1)
        cells = ', '.join('{:3.3f}'.format(v) for v in mixed.cpu().tolist())
        pieces.append(template.format(idx, cells))

  if not hasattr(basemodel, 'alphas_connect'):
    pieces.append('\nconnect = None')
  else:
    connect_probs = F.softmax(basemodel.alphas_connect, dim=-1).cpu()
    zero_col = ['{:.3f}'.format(v) for v in connect_probs[:, 0].tolist()]
    iden_col = ['{:.3f}'.format(v) for v in connect_probs[:, 1].tolist()]
    connect_shape = list(basemodel.alphas_connect.size())
    pieces.append(
        '\nconnect [{:}] : \n ->{:}\n ->{:}'.format(connect_shape, zero_col, iden_col))

  if hasattr(basemodel, 'get_gcn_out'):
    for idx, gcn_out in enumerate(basemodel.get_gcn_out(True)):
      pieces.append('\nnormal:[{:}] : {:}'.format(idx, F.softmax(gcn_out, dim=-1)))

  return ''.join(pieces)
 | |
| 
 | |
| 
 | |
def remove_duplicate_archs(all_archs):
  """Return the architectures from ``all_archs`` with duplicates removed.

  Two entries are considered duplicates when ``'{:}'.format(entry)`` yields
  the same string. The first occurrence of each is kept and the original
  order is preserved; the returned list holds the original objects, not
  their string forms.

  Args:
    all_archs: iterable of architecture objects.

  Returns:
    A new list with the first occurrence of every distinct architecture.
  """
  seen = set()  # string keys already emitted; O(1) lookup instead of a rescan
  archs = []
  for arch in all_archs:
    key = '{:}'.format(arch)
    if key not in seen:
      seen.add(key)
      archs.append(arch)
  return archs
 |