Fix test bugs
This commit is contained in:
		| @@ -82,7 +82,14 @@ def main(args): | ||||
|         historical_x, historical_y = subsample(historical_x, historical_y) | ||||
|         # build model | ||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||
|         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||
|         model_kwargs = dict( | ||||
|             input_dim=1, | ||||
|             output_dim=1, | ||||
|             act_cls="leaky_relu", | ||||
|             norm_cls="simple_norm", | ||||
|             mean=mean, | ||||
|             std=std, | ||||
|         ) | ||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         # build optimizer | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|   | ||||
| @@ -78,7 +78,14 @@ def main(args): | ||||
|         historical_y = env_info["{:}-y".format(idx)] | ||||
|         # build model | ||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||
|         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||
|         model_kwargs = dict( | ||||
|             input_dim=1, | ||||
|             output_dim=1, | ||||
|             act_cls="leaky_relu", | ||||
|             norm_cls="simple_norm", | ||||
|             mean=mean, | ||||
|             std=std, | ||||
|         ) | ||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         # build optimizer | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|   | ||||
| @@ -24,6 +24,8 @@ from models.xcore import get_model | ||||
|  | ||||
|  | ||||
| class Population: | ||||
|     """A population used to maintain models at different timestamps.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._time2model = dict() | ||||
|  | ||||
|   | ||||
| @@ -64,7 +64,7 @@ class TestSuperSimpleNorm(unittest.TestCase): | ||||
|         model.apply_verbose(True) | ||||
|  | ||||
|         print(model.super_run_type) | ||||
|         self.assertTrue(model[1].bias) | ||||
|         self.assertTrue(model[2].bias) | ||||
|  | ||||
|         inputs = torch.rand(20, 10) | ||||
|         print("Input shape: {:}".format(inputs.shape)) | ||||
| @@ -80,6 +80,6 @@ class TestSuperSimpleNorm(unittest.TestCase): | ||||
|         model.set_super_run_type(super_core.SuperRunMode.Candidate) | ||||
|         model.apply_candidate(abstract_child) | ||||
|  | ||||
|         output_shape = (20, abstract_child["1"]["_out_features"].value) | ||||
|         output_shape = (20, abstract_child["2"]["_out_features"].value) | ||||
|         outputs = model(inputs) | ||||
|         self.assertEqual(tuple(outputs.shape), output_shape) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user