update NAS-Bench-102 baselines / support track_running_stats
This commit is contained in:
		| @@ -11,7 +11,7 @@ from ..cell_operations import OPS | ||||
|  | ||||
| class SearchCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, max_nodes, op_names): | ||||
|   def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True): | ||||
|     super(SearchCell, self).__init__() | ||||
|  | ||||
|     self.op_names  = deepcopy(op_names) | ||||
| @@ -23,9 +23,9 @@ class SearchCell(nn.Module): | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         if j == 0: | ||||
|           xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names] | ||||
|           xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names] | ||||
|         else: | ||||
|           xlists = [OPS[op_name](C_in , C_out,      1, False) for op_name in op_names] | ||||
|           xlists = [OPS[op_name](C_in , C_out,      1, affine, track_running_stats) for op_name in op_names] | ||||
|         self.edges[ node_str ] = nn.ModuleList( xlists ) | ||||
|     self.edge_keys  = sorted(list(self.edges.keys())) | ||||
|     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user