Update MLAML
This commit is contained in:
		| @@ -5,26 +5,16 @@ | ||||
| ##################################################### | ||||
| import unittest | ||||
|  | ||||
| from xautodl.datasets.math_core import ConstantFunc, ComposedSinSFunc | ||||
| from xautodl.datasets.synthetic_core import SyntheticDEnv | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
|  | ||||
|  | ||||
| class TestSynethicEnv(unittest.TestCase): | ||||
|     """Test the synethtic environment.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         mean_generator = ConstantFunc(constant=0.1) | ||||
|         std_generator = ConstantFunc(constant=0.5) | ||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) | ||||
|         print(dataset) | ||||
|         for timestamp, tau in dataset: | ||||
|             self.assertEqual(tau.shape, (5000, 1)) | ||||
|  | ||||
|     def test_length(self): | ||||
|         mean_generator = ComposedSinSFunc({0: 1, 1: 1, 2: 3}) | ||||
|         std_generator = ConstantFunc(constant=0.5) | ||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) | ||||
|         self.assertEqual(len(dataset), 100) | ||||
|  | ||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], mode="train") | ||||
|         self.assertEqual(len(dataset), 60) | ||||
|         versions = ["v1", "v2", "v3", "v4"] | ||||
|         for version in versions: | ||||
|             env = get_synthetic_env(version=version) | ||||
|         print(env) | ||||
|         for timestamp, tau in env: | ||||
|             self.assertEqual(tau.shape, (1000, env.ndim)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user