Update SuperMLP
This commit is contained in:
		| @@ -30,32 +30,37 @@ class TestSuperLinear(unittest.TestCase): | ||||
|         print(model.super_run_type) | ||||
|         self.assertTrue(model.bias) | ||||
|  | ||||
|         inputs = torch.rand(32, 10) | ||||
|         inputs = torch.rand(20, 10) | ||||
|         print("Input shape: {:}".format(inputs.shape)) | ||||
|         print("Weight shape: {:}".format(model._super_weight.shape)) | ||||
|         print("Bias shape: {:}".format(model._super_bias.shape)) | ||||
|         outputs = model(inputs) | ||||
|         self.assertEqual(tuple(outputs.shape), (32, 36)) | ||||
|         self.assertEqual(tuple(outputs.shape), (20, 36)) | ||||
|  | ||||
|         abstract_space = model.abstract_search_space | ||||
|         abstract_space.clean_last() | ||||
|         abstract_child = abstract_space.random() | ||||
|         print("The abstract searc space:\n{:}".format(abstract_space)) | ||||
|         print("The abstract child program:\n{:}".format(abstract_child)) | ||||
|  | ||||
|         model.set_super_run_type(super_core.SuperRunMode.Candidate) | ||||
|         model.apply_candiate(abstract_child) | ||||
|         model.apply_candidate(abstract_child) | ||||
|  | ||||
|         output_shape = (32, abstract_child["_out_features"].value) | ||||
|         output_shape = (20, abstract_child["_out_features"].value) | ||||
|         outputs = model(inputs) | ||||
|         self.assertEqual(tuple(outputs.shape), output_shape) | ||||
|  | ||||
|     def test_super_mlp(self): | ||||
|         hidden_features = spaces.Categorical(12, 24, 36) | ||||
|         out_features = spaces.Categorical(12, 24, 36) | ||||
|         out_features = spaces.Categorical(24, 36, 48) | ||||
|         mlp = super_core.SuperMLP(10, hidden_features, out_features) | ||||
|         print(mlp) | ||||
|         self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features) | ||||
|  | ||||
|         inputs = torch.rand(4, 10) | ||||
|         outputs = mlp(inputs) | ||||
|         self.assertEqual(tuple(outputs.shape), (4, 48)) | ||||
|  | ||||
|         abstract_space = mlp.abstract_search_space | ||||
|         print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space)) | ||||
|         self.assertEqual( | ||||
| @@ -67,10 +72,16 @@ class TestSuperLinear(unittest.TestCase): | ||||
|             is abstract_space["fc2"]["_in_features"] | ||||
|         ) | ||||
|  | ||||
|         abstract_space.clean_last_sample() | ||||
|         abstract_space.clean_last() | ||||
|         abstract_child = abstract_space.random(reuse_last=True) | ||||
|         print("The abstract child program is:\n{:}".format(abstract_child)) | ||||
|         self.assertEqual( | ||||
|             abstract_child["fc1"]["_out_features"].value, | ||||
|             abstract_child["fc2"]["_in_features"].value, | ||||
|         ) | ||||
|  | ||||
|         mlp.set_super_run_type(super_core.SuperRunMode.Candidate) | ||||
|         mlp.apply_candidate(abstract_child) | ||||
|         outputs = mlp(inputs) | ||||
|         output_shape = (4, abstract_child["fc2"]["_out_features"].value) | ||||
|         self.assertEqual(tuple(outputs.shape), output_shape) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user