diff --git a/exps/LFNA/basic-his.py b/exps/LFNA/basic-his.py
index 2506ceb..5ba3d68 100644
--- a/exps/LFNA/basic-his.py
+++ b/exps/LFNA/basic-his.py
@@ -82,7 +82,14 @@ def main(args):
         historical_x, historical_y = subsample(historical_x, historical_y)
         # build model
         mean, std = historical_x.mean().item(), historical_x.std().item()
-        model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
+        model_kwargs = dict(
+            input_dim=1,
+            output_dim=1,
+            act_cls="leaky_relu",
+            norm_cls="simple_norm",
+            mean=mean,
+            std=std,
+        )
         model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
         # build optimizer
         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
diff --git a/exps/LFNA/basic-same.py b/exps/LFNA/basic-same.py
index 4fcdf5d..4bcb702 100644
--- a/exps/LFNA/basic-same.py
+++ b/exps/LFNA/basic-same.py
@@ -78,7 +78,14 @@ def main(args):
         historical_y = env_info["{:}-y".format(idx)]
         # build model
         mean, std = historical_x.mean().item(), historical_x.std().item()
-        model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
+        model_kwargs = dict(
+            input_dim=1,
+            output_dim=1,
+            act_cls="leaky_relu",
+            norm_cls="simple_norm",
+            mean=mean,
+            std=std,
+        )
         model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
         # build optimizer
         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
diff --git a/exps/LFNA/lfna-v1.py b/exps/LFNA/lfna-v1.py
index 9c9b90a..90ac10a 100644
--- a/exps/LFNA/lfna-v1.py
+++ b/exps/LFNA/lfna-v1.py
@@ -24,6 +24,8 @@ from models.xcore import get_model
 
 
 class Population:
+    """A population used to maintain models at different timestamps."""
+
     def __init__(self):
         self._time2model = dict()
 
diff --git a/tests/test_super_norm.py b/tests/test_super_norm.py
index d5a21d6..7e2e6f1 100644
--- a/tests/test_super_norm.py
+++ b/tests/test_super_norm.py
@@ -64,7 +64,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
         model.apply_verbose(True)
 
         print(model.super_run_type)
-        self.assertTrue(model[1].bias)
+        self.assertTrue(model[2].bias)
 
         inputs = torch.rand(20, 10)
         print("Input shape: {:}".format(inputs.shape))
@@ -80,6 +80,6 @@ class TestSuperSimpleNorm(unittest.TestCase):
         model.set_super_run_type(super_core.SuperRunMode.Candidate)
         model.apply_candidate(abstract_child)
 
-        output_shape = (20, abstract_child["1"]["_out_features"].value)
+        output_shape = (20, abstract_child["2"]["_out_features"].value)
         outputs = model(inputs)
         self.assertEqual(tuple(outputs.shape), output_shape)
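
Usage sketch (not part of the patch): the snippet below shows how the expanded model_kwargs from basic-his.py / basic-same.py would be assembled and passed to get_model after this change. It assumes the repository's "simple_mlp" model accepts the act_cls and norm_cls keywords shown in the hunks above; the input tensor and learning rate are placeholders for the experiment's env_info data and args.init_lr.

    import torch

    from models.xcore import get_model

    # Placeholder data standing in for env_info["{:}-x".format(idx)] in the scripts.
    historical_x = torch.linspace(-1.0, 1.0, steps=200).view(-1, 1)

    # Normalization statistics computed from the historical inputs, as in the scripts.
    mean, std = historical_x.mean().item(), historical_x.std().item()

    # Expanded keyword arguments: the activation and normalization layers are now
    # selected explicitly instead of relying on the model's defaults.
    model_kwargs = dict(
        input_dim=1,
        output_dim=1,
        act_cls="leaky_relu",
        norm_cls="simple_norm",
        mean=mean,
        std=std,
    )
    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)

    # Optimizer construction is unchanged; 0.01 stands in for args.init_lr.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)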