Fix 1-element-input bug in norm layers
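In get_model, every hidden block used to start with norm_cls(last_dim, elementwise_affine=False), even when last_dim == 1. Normalizing a single element is degenerate: the element equals its own mean, so the layer outputs zero and erases its input. A minimal sketch of the failure, assuming norm_cls resolves to torch.nn.LayerNorm (suggested by the elementwise_affine keyword; the super_name2norm mapping is not part of this diff):

import torch

# a norm over a 1-element feature dimension, as built before this fix
norm = torch.nn.LayerNorm(1, elementwise_affine=False)
x = torch.randn(4, 1)  # a batch of 4 samples with a single input feature
print(norm(x))  # all zeros: each element is its own mean

Any model built with input_dim == 1 therefore ignored its input entirely; the second hunk below guards against this.
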
@@ -58,6 +58,7 @@ def main(args):
         # build model
         model = get_model(**model_kwargs)
         print(model)
+        model.analyze_weights()
         # build optimizer
         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
         criterion = torch.nn.MSELoss()
@@ -168,7 +169,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--epochs",
         type=int,
-        default=1000,
+        default=300,
         help="The total number of epochs.",
     )
     parser.add_argument(

@@ -40,13 +40,10 @@ def get_model(config: Dict[Text, Any], **kwargs):
         norm_cls = super_name2norm[kwargs["norm_cls"]]
         sub_layers, last_dim = [], kwargs["input_dim"]
         for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
-            sub_layers.extend(
-                [
-                    norm_cls(last_dim, elementwise_affine=False),
-                    SuperLinear(last_dim, hidden_dim),
-                    act_cls(),
-                ]
-            )
+            if last_dim > 1:
+                sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
+            sub_layers.append(SuperLinear(last_dim, hidden_dim))
+            sub_layers.append(act_cls())
             last_dim = hidden_dim
         sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
         model = SuperSequential(*sub_layers)

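A standalone sketch of the guarded loop above, with torch.nn.Linear and torch.nn.ReLU as hypothetical stand-ins for SuperLinear and act_cls (neither Super class appears in this diff):

import torch.nn as nn

def build_layers(input_dim, hidden_dims, output_dim):
    sub_layers, last_dim = [], input_dim
    for hidden_dim in hidden_dims:
        if last_dim > 1:  # the fix: skip the degenerate 1-element norm
            sub_layers.append(nn.LayerNorm(last_dim, elementwise_affine=False))
        sub_layers.append(nn.Linear(last_dim, hidden_dim))  # stand-in for SuperLinear
        sub_layers.append(nn.ReLU())  # stand-in for act_cls()
        last_dim = hidden_dim
    sub_layers.append(nn.Linear(last_dim, output_dim))
    return nn.Sequential(*sub_layers)

# with input_dim == 1 the first block has no norm; wider hidden blocks keep theirs
print(build_layers(1, [16, 16], 1))

The first linear layer now sees the raw input instead of a zeroed tensor, while the hidden blocks, whose widths exceed 1, are normalized as before.
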
@@ -66,6 +66,15 @@ class SuperModule(abc.ABC, nn.Module):
             container.append(name, buf, False)
         return container
 
+    def analyze_weights(self):
+        with torch.no_grad():
+            for name, param in self.named_parameters():
+                shapestr = "[{:10s}] shape={:}".format(name, list(param.shape))
+                finalstr = shapestr + "{:.2f} +- {:.2f}".format(
+                    param.mean(), param.std()
+                )
+                print(finalstr)
+
     @property
     def abstract_search_space(self):
         raise NotImplementedError

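The call added in the first hunk runs this method once, right after the model is built. A sketch of the expected output on a plain nn.Linear, mirroring the method body above (note the method concatenates the shape string and the statistics without a separator, so the two run together in the printout):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
with torch.no_grad():
    for name, param in layer.named_parameters():
        shapestr = "[{:10s}] shape={:}".format(name, list(param.shape))
        # 0-dim tensors such as param.mean() support float formatting directly
        print(shapestr + "{:.2f} +- {:.2f}".format(param.mean(), param.std()))

As a caveat, param.std() is itself NaN for a parameter with a single element (the unbiased estimator divides by n - 1), the same degeneracy the norm guard above works around.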