update SETN
This commit is contained in:
		| @@ -81,10 +81,10 @@ class Structure: | ||||
|         if consider_zero: | ||||
|           if op == 'none' or nodes[xin] == '#': x = '#' # zero | ||||
|           elif op == 'skip_connect': x = nodes[xin] | ||||
|           else: x = nodes[xin] + '@{:}'.format(op) | ||||
|           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||
|         else: | ||||
|           if op == 'skip_connect': x = nodes[xin] | ||||
|           else: x = nodes[xin] + '@{:}'.format(op) | ||||
|           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||
|         cur_node.append(x) | ||||
|       nodes[i_node+1] = '+'.join( sorted(cur_node) ) | ||||
|     return nodes[ len(self.nodes) ] | ||||
|   | ||||
| @@ -84,7 +84,6 @@ class TinyNetworkSETN(nn.Module): | ||||
|       genotypes.append( tuple(xlist) ) | ||||
|     return Structure( genotypes ) | ||||
|  | ||||
|  | ||||
|   def dync_genotype(self, use_random=False): | ||||
|     genotypes = [] | ||||
|     with torch.no_grad(): | ||||
| @@ -103,6 +102,26 @@ class TinyNetworkSETN(nn.Module): | ||||
|       genotypes.append( tuple(xlist) ) | ||||
|     return Structure( genotypes ) | ||||
|  | ||||
|   def get_log_prob(self, arch): | ||||
|     with torch.no_grad(): | ||||
|       logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) | ||||
|     select_logits = [] | ||||
|     for i, node_info in enumerate(arch.nodes): | ||||
|       for op, xin in node_info: | ||||
|         node_str = '{:}<-{:}'.format(i+1, xin) | ||||
|         op_index = self.op_names.index(op) | ||||
|         select_logits.append( logits[self.edge2index[node_str], op_index] ) | ||||
|     return sum(select_logits).item() | ||||
|  | ||||
|  | ||||
|   def return_topK(self, K): | ||||
|     archs = Structure.gen_all(self.op_names, self.max_nodes, False) | ||||
|     pairs = [(self.get_log_prob(arch), arch) for arch in archs] | ||||
|     if K < 0 or K >= len(archs): K = len(archs) | ||||
|     sorted_pairs = sorted(pairs, key=lambda x: -x[0]) | ||||
|     return_pairs = [sorted_pairs[_][1] for _ in range(K)] | ||||
|     return return_pairs | ||||
|  | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     alphas  = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user