Fix test bugs

This commit is contained in:
D-X-Y 2021-05-06 16:43:31 +08:00
parent 4c14c7b85b
commit f6a024a6ff
4 changed files with 20 additions and 4 deletions

View File

@ -82,7 +82,14 @@ def main(args):
historical_x, historical_y = subsample(historical_x, historical_y) historical_x, historical_y = subsample(historical_x, historical_y)
# build model # build model
mean, std = historical_x.mean().item(), historical_x.std().item() mean, std = historical_x.mean().item(), historical_x.std().item()
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) model_kwargs = dict(
input_dim=1,
output_dim=1,
act_cls="leaky_relu",
norm_cls="simple_norm",
mean=mean,
std=std,
)
model = get_model(dict(model_type="simple_mlp"), **model_kwargs) model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
# build optimizer # build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)

View File

@ -78,7 +78,14 @@ def main(args):
historical_y = env_info["{:}-y".format(idx)] historical_y = env_info["{:}-y".format(idx)]
# build model # build model
mean, std = historical_x.mean().item(), historical_x.std().item() mean, std = historical_x.mean().item(), historical_x.std().item()
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) model_kwargs = dict(
input_dim=1,
output_dim=1,
act_cls="leaky_relu",
norm_cls="simple_norm",
mean=mean,
std=std,
)
model = get_model(dict(model_type="simple_mlp"), **model_kwargs) model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
# build optimizer # build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)

View File

@ -24,6 +24,8 @@ from models.xcore import get_model
class Population: class Population:
"""A population used to maintain models at different timestamps."""
def __init__(self): def __init__(self):
self._time2model = dict() self._time2model = dict()

View File

@ -64,7 +64,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
model.apply_verbose(True) model.apply_verbose(True)
print(model.super_run_type) print(model.super_run_type)
self.assertTrue(model[1].bias) self.assertTrue(model[2].bias)
inputs = torch.rand(20, 10) inputs = torch.rand(20, 10)
print("Input shape: {:}".format(inputs.shape)) print("Input shape: {:}".format(inputs.shape))
@ -80,6 +80,6 @@ class TestSuperSimpleNorm(unittest.TestCase):
model.set_super_run_type(super_core.SuperRunMode.Candidate) model.set_super_run_type(super_core.SuperRunMode.Candidate)
model.apply_candidate(abstract_child) model.apply_candidate(abstract_child)
output_shape = (20, abstract_child["1"]["_out_features"].value) output_shape = (20, abstract_child["2"]["_out_features"].value)
outputs = model(inputs) outputs = model(inputs)
self.assertEqual(tuple(outputs.shape), output_shape) self.assertEqual(tuple(outputs.shape), output_shape)