Fix bugs
This commit is contained in:
parent
184f2326bb
commit
f7c2bb5e32
@ -74,8 +74,12 @@ def main(args):
|
||||
)
|
||||
# train the same data
|
||||
assert idx != 0
|
||||
historical_x = env_info["{:}-x".format(idx)]
|
||||
historical_y = env_info["{:}-y".format(idx)]
|
||||
historical_x, historical_y = [], []
|
||||
for past_i in range(idx):
|
||||
historical_x.append(env_info["{:}-x".format(past_i)])
|
||||
historical_y.append(env_info["{:}-y".format(past_i)])
|
||||
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
|
||||
historical_x, historical_y = subsample(historical_x, historical_y)
|
||||
# build model
|
||||
mean, std = historical_x.mean().item(), historical_x.std().item()
|
||||
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
|
||||
@ -153,7 +157,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/lfna-synthetic/use-same-timestamp",
|
||||
default="./outputs/lfna-synthetic/use-all-past-data",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -74,12 +74,8 @@ def main(args):
|
||||
)
|
||||
# train the same data
|
||||
assert idx != 0
|
||||
historical_x, historical_y = [], []
|
||||
for past_i in range(idx):
|
||||
historical_x.append(env_info["{:}-x".format(past_i)])
|
||||
historical_y.append(env_info["{:}-y".format(past_i)])
|
||||
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
|
||||
historical_x, historical_y = subsample(historical_x, historical_y)
|
||||
historical_x = env_info["{:}-x".format(idx)]
|
||||
historical_y = env_info["{:}-y".format(idx)]
|
||||
# build model
|
||||
mean, std = historical_x.mean().item(), historical_x.std().item()
|
||||
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
|
||||
@ -153,11 +149,11 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Use data at the same timestamp.")
|
||||
parser = argparse.ArgumentParser("Use the data in the past.")
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/lfna-synthetic/use-all-past-data",
|
||||
default="./outputs/lfna-synthetic/use-same-timestamp",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
Loading…
Reference in New Issue
Block a user