Complete Super Linear
This commit is contained in:
		| @@ -41,14 +41,14 @@ class TestBasicSpace(unittest.TestCase): | ||||
|     def test_continuous(self): | ||||
|         random.seed(999) | ||||
|         space = Continuous(0, 1) | ||||
|         self.assertGreaterEqual(space.random(), 0) | ||||
|         self.assertGreaterEqual(1, space.random()) | ||||
|         self.assertGreaterEqual(space.random().value, 0) | ||||
|         self.assertGreaterEqual(1, space.random().value) | ||||
|  | ||||
|         lower, upper = 1.5, 4.6 | ||||
|         space = Continuous(lower, upper, log=False) | ||||
|         values = [] | ||||
|         for i in range(1000000): | ||||
|             x = space.random() | ||||
|             x = space.random().value | ||||
|             self.assertGreaterEqual(x, lower) | ||||
|             self.assertGreaterEqual(upper, x) | ||||
|             values.append(x) | ||||
| @@ -89,7 +89,7 @@ class TestBasicSpace(unittest.TestCase): | ||||
|             Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11), | ||||
|             12, | ||||
|         ) | ||||
|         print(nested_space) | ||||
|         print("\nThe nested search space:\n{:}".format(nested_space)) | ||||
|         for i in range(1, 13): | ||||
|             self.assertTrue(nested_space.has(i)) | ||||
|  | ||||
| @@ -102,6 +102,19 @@ class TestAbstractSpace(unittest.TestCase): | ||||
|     """Test the abstract search spaces.""" | ||||
|  | ||||
|     def test_continous(self): | ||||
|         print("") | ||||
|         space = Continuous(0, 1) | ||||
|         self.assertEqual(space, space.abstract()) | ||||
|         print("The abstract search space for Continuous: {:}".format(space.abstract())) | ||||
|  | ||||
|         space = Categorical(1, 2, 3) | ||||
|         self.assertEqual(len(space.abstract()), 3) | ||||
|         print(space.abstract()) | ||||
|  | ||||
|         nested_space = Categorical( | ||||
|             Categorical(1, 2, 3), | ||||
|             Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11), | ||||
|             12, | ||||
|         ) | ||||
|         abstract_nested_space = nested_space.abstract() | ||||
|         print("The abstract nested search space:\n{:}".format(abstract_nested_space)) | ||||
|   | ||||
| @@ -25,6 +25,26 @@ class TestSuperLinear(unittest.TestCase): | ||||
|         out_features = spaces.Categorical(12, 24, 36) | ||||
|         bias = spaces.Categorical(True, False) | ||||
|         model = super_core.SuperLinear(10, out_features, bias=bias) | ||||
|         print(model) | ||||
|         print("The simple super linear module is:\n{:}".format(model)) | ||||
|  | ||||
|         print(model.super_run_type) | ||||
|         print(model.abstract_search_space()) | ||||
|         self.assertTrue(model.bias) | ||||
|  | ||||
|         inputs = torch.rand(32, 10) | ||||
|         print("Input shape: {:}".format(inputs.shape)) | ||||
|         print("Weight shape: {:}".format(model._super_weight.shape)) | ||||
|         print("Bias shape: {:}".format(model._super_bias.shape)) | ||||
|         outputs = model(inputs) | ||||
|         self.assertEqual(tuple(outputs.shape), (32, 36)) | ||||
|  | ||||
|         abstract_space = model.abstract_search_space | ||||
|         abstract_child = abstract_space.random() | ||||
|         print("The abstract searc space:\n{:}".format(abstract_space)) | ||||
|         print("The abstract child program:\n{:}".format(abstract_child)) | ||||
|  | ||||
|         model.set_super_run_type(super_core.SuperRunMode.Candidate) | ||||
|         model.apply_candiate(abstract_child) | ||||
|  | ||||
|         output_shape = (32, abstract_child["_out_features"].value) | ||||
|         outputs = model(inputs) | ||||
|         self.assertEqual(tuple(outputs.shape), output_shape) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user