51 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			51 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ##################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 | |
| ##################################################
 | |
| import math, random
 | |
| import tensorflow as tf
 | |
| from copy import deepcopy
 | |
| from ..cell_operations import OPS
 | |
| 
 | |
| 
 | |
| class SearchCell(tf.keras.layers.Layer):
 | |
| 
 | |
|   def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False):
 | |
|     super(SearchCell, self).__init__()
 | |
| 
 | |
|     self.op_names  = deepcopy(op_names)
 | |
|     self.max_nodes = max_nodes
 | |
|     self.in_dim    = C_in
 | |
|     self.out_dim   = C_out
 | |
|     self.edge_keys = []
 | |
|     for i in range(1, max_nodes):
 | |
|       for j in range(i):
 | |
|         node_str = '{:}<-{:}'.format(i, j)
 | |
|         if j == 0:
 | |
|           xlists = [OPS[op_name](C_in , C_out, stride, affine) for op_name in op_names]
 | |
|         else:
 | |
|           xlists = [OPS[op_name](C_in , C_out,      1, affine) for op_name in op_names]
 | |
|         for k, op in enumerate(xlists):
 | |
|           setattr(self, '{:}.{:}'.format(node_str, k), op)
 | |
|         self.edge_keys.append( node_str )
 | |
|     self.edge_keys  = sorted(self.edge_keys)
 | |
|     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
 | |
|     self.num_edges  = len(self.edge_keys)
 | |
| 
 | |
|   def call(self, inputs, weightss, training):
 | |
|     w_lst = tf.split(weightss, self.num_edges, 0)
 | |
|     nodes = [inputs]
 | |
|     for i in range(1, self.max_nodes):
 | |
|       inter_nodes = []
 | |
|       for j in range(i):
 | |
|         node_str = '{:}<-{:}'.format(i, j)
 | |
|         edge_idx = self.edge2index[node_str]
 | |
|         op_outps = []
 | |
|         for k, op_name in enumerate(self.op_names):
 | |
|           op = getattr(self, '{:}.{:}'.format(node_str, k))
 | |
|           op_outps.append( op(nodes[j], training) )
 | |
|         stack_op_outs = tf.stack(op_outps, axis=-1)
 | |
|         weighted_sums = tf.math.multiply(stack_op_outs, w_lst[edge_idx])
 | |
|         inter_nodes.append( tf.math.reduce_sum(weighted_sums, axis=-1) )
 | |
|       nodes.append( tf.math.add_n(inter_nodes) )
 | |
|     return nodes[-1]
 |