From f7c2bb5e32b3130b8559b1aced3d5b4a68dd297f Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 29 Apr 2021 19:11:48 +0800 Subject: [PATCH] Fix bugs --- exps/LFNA/basic-his.py | 10 +++++++--- exps/LFNA/basic-same.py | 12 ++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/exps/LFNA/basic-his.py b/exps/LFNA/basic-his.py index fca229f..c8b369b 100644 --- a/exps/LFNA/basic-his.py +++ b/exps/LFNA/basic-his.py @@ -74,8 +74,12 @@ def main(args): ) # train the same data assert idx != 0 - historical_x = env_info["{:}-x".format(idx)] - historical_y = env_info["{:}-y".format(idx)] + historical_x, historical_y = [], [] + for past_i in range(idx): + historical_x.append(env_info["{:}-x".format(past_i)]) + historical_y.append(env_info["{:}-y".format(past_i)]) + historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) + historical_x, historical_y = subsample(historical_x, historical_y) # build model mean, std = historical_x.mean().item(), historical_x.std().item() model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) @@ -153,7 +157,7 @@ if __name__ == "__main__": parser.add_argument( "--save_dir", type=str, - default="./outputs/lfna-synthetic/use-same-timestamp", + default="./outputs/lfna-synthetic/use-all-past-data", help="The checkpoint directory.", ) parser.add_argument( diff --git a/exps/LFNA/basic-same.py b/exps/LFNA/basic-same.py index 578c6f7..0a889a9 100644 --- a/exps/LFNA/basic-same.py +++ b/exps/LFNA/basic-same.py @@ -74,12 +74,8 @@ def main(args): ) # train the same data assert idx != 0 - historical_x, historical_y = [], [] - for past_i in range(idx): - historical_x.append(env_info["{:}-x".format(past_i)]) - historical_y.append(env_info["{:}-y".format(past_i)]) - historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) - historical_x, historical_y = subsample(historical_x, historical_y) + historical_x = env_info["{:}-x".format(idx)] + historical_y = env_info["{:}-y".format(idx)] # build model mean, std = historical_x.mean().item(), historical_x.std().item() model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) @@ -153,11 +149,11 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser("Use data at the same timestamp.") + parser = argparse.ArgumentParser("Use the data in the past.") parser.add_argument( "--save_dir", type=str, - default="./outputs/lfna-synthetic/use-all-past-data", + default="./outputs/lfna-synthetic/use-same-timestamp", help="The checkpoint directory.", ) parser.add_argument(