This commit is contained in:
D-X-Y 2021-04-29 19:11:48 +08:00
parent 184f2326bb
commit f7c2bb5e32
2 changed files with 11 additions and 11 deletions

View File

@ -74,8 +74,12 @@ def main(args):
)
# train the same data
assert idx != 0
historical_x = env_info["{:}-x".format(idx)]
historical_y = env_info["{:}-y".format(idx)]
historical_x, historical_y = [], []
for past_i in range(idx):
historical_x.append(env_info["{:}-x".format(past_i)])
historical_y.append(env_info["{:}-y".format(past_i)])
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
historical_x, historical_y = subsample(historical_x, historical_y)
# build model
mean, std = historical_x.mean().item(), historical_x.std().item()
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
@ -153,7 +157,7 @@ if __name__ == "__main__":
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/lfna-synthetic/use-same-timestamp",
default="./outputs/lfna-synthetic/use-all-past-data",
help="The checkpoint directory.",
)
parser.add_argument(

View File

@ -74,12 +74,8 @@ def main(args):
)
# train the same data
assert idx != 0
historical_x, historical_y = [], []
for past_i in range(idx):
historical_x.append(env_info["{:}-x".format(past_i)])
historical_y.append(env_info["{:}-y".format(past_i)])
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
historical_x, historical_y = subsample(historical_x, historical_y)
historical_x = env_info["{:}-x".format(idx)]
historical_y = env_info["{:}-y".format(idx)]
# build model
mean, std = historical_x.mean().item(), historical_x.std().item()
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
@ -153,11 +149,11 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser("Use data at the same timestamp.")
parser = argparse.ArgumentParser("Use the data in the past.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/lfna-synthetic/use-all-past-data",
default="./outputs/lfna-synthetic/use-same-timestamp",
help="The checkpoint directory.",
)
parser.add_argument(