Fix bugs
This commit is contained in:
		| @@ -74,8 +74,12 @@ def main(args): | ||||
|         ) | ||||
|         # train the same data | ||||
|         assert idx != 0 | ||||
|         historical_x = env_info["{:}-x".format(idx)] | ||||
|         historical_y = env_info["{:}-y".format(idx)] | ||||
|         historical_x, historical_y = [], [] | ||||
|         for past_i in range(idx): | ||||
|             historical_x.append(env_info["{:}-x".format(past_i)]) | ||||
|             historical_y.append(env_info["{:}-y".format(past_i)]) | ||||
|         historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) | ||||
|         historical_x, historical_y = subsample(historical_x, historical_y) | ||||
|         # build model | ||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||
|         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||
| @@ -153,7 +157,7 @@ if __name__ == "__main__": | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/use-same-timestamp", | ||||
|         default="./outputs/lfna-synthetic/use-all-past-data", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|   | ||||
| @@ -74,12 +74,8 @@ def main(args): | ||||
|         ) | ||||
|         # train the same data | ||||
|         assert idx != 0 | ||||
|         historical_x, historical_y = [], [] | ||||
|         for past_i in range(idx): | ||||
|             historical_x.append(env_info["{:}-x".format(past_i)]) | ||||
|             historical_y.append(env_info["{:}-y".format(past_i)]) | ||||
|         historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) | ||||
|         historical_x, historical_y = subsample(historical_x, historical_y) | ||||
|         historical_x = env_info["{:}-x".format(idx)] | ||||
|         historical_y = env_info["{:}-y".format(idx)] | ||||
|         # build model | ||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||
|         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||
| @@ -153,11 +149,11 @@ def main(args): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use data at the same timestamp.") | ||||
|     parser = argparse.ArgumentParser("Use the data in the past.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/use-all-past-data", | ||||
|         default="./outputs/lfna-synthetic/use-same-timestamp", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user