Update base models
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/basic-same.py --srange 1-999 --env_version v1 --hidden_dim 16 | # python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16 | ||||||
| # python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim | # python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| @@ -38,27 +38,17 @@ def subsample(historical_x, historical_y, maxn=10000): | |||||||
| def main(args): | def main(args): | ||||||
|     logger, env_info, model_kwargs = lfna_setup(args) |     logger, env_info, model_kwargs = lfna_setup(args) | ||||||
|  |  | ||||||
|     # check indexes to be evaluated |  | ||||||
|     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) |  | ||||||
|     logger.log( |  | ||||||
|         "Evaluate {:}, which has {:} timestamps in total.".format( |  | ||||||
|             args.srange, len(to_evaluate_indexes) |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     w_container_per_epoch = dict() |     w_container_per_epoch = dict() | ||||||
|  |  | ||||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() |     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||||
|     for i, idx in enumerate(to_evaluate_indexes): |     for idx in range(env_info["total"]): | ||||||
|  |  | ||||||
|         need_time = "Time Left: {:}".format( |         need_time = "Time Left: {:}".format( | ||||||
|             convert_secs2time( |             convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True) | ||||||
|                 per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True |  | ||||||
|             ) |  | ||||||
|         ) |         ) | ||||||
|         logger.log( |         logger.log( | ||||||
|             "[{:}]".format(time_string()) |             "[{:}]".format(time_string()) | ||||||
|             + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx) |             + " [{:04d}/{:04d}]".format(idx, env_info["total"]) | ||||||
|             + " " |             + " " | ||||||
|             + need_time |             + need_time | ||||||
|         ) |         ) | ||||||
| @@ -66,7 +56,8 @@ def main(args): | |||||||
|         historical_x = env_info["{:}-x".format(idx)] |         historical_x = env_info["{:}-x".format(idx)] | ||||||
|         historical_y = env_info["{:}-y".format(idx)] |         historical_y = env_info["{:}-y".format(idx)] | ||||||
|         # build model |         # build model | ||||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) |         model = get_model(**model_kwargs) | ||||||
|  |         print(model) | ||||||
|         # build optimizer |         # build optimizer | ||||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) |         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||||
|         criterion = torch.nn.MSELoss() |         criterion = torch.nn.MSELoss() | ||||||
| @@ -180,9 +171,6 @@ if __name__ == "__main__": | |||||||
|         default=1000, |         default=1000, | ||||||
|         help="The total number of epochs.", |         help="The total number of epochs.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |  | ||||||
|         "--srange", type=str, required=True, help="The range of models to be evaluated" |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--workers", |         "--workers", | ||||||
|         type=int, |         type=int, | ||||||
|   | |||||||
| @@ -90,6 +90,7 @@ def main(args): | |||||||
|  |  | ||||||
|         final_loss = torch.stack(losses).mean() |         final_loss = torch.stack(losses).mean() | ||||||
|         final_loss.backward() |         final_loss.backward() | ||||||
|  |         torch.nn.utils.clip_grad_norm_(parameters, 1.0) | ||||||
|         optimizer.step() |         optimizer.step() | ||||||
|         lr_scheduler.step() |         lr_scheduler.step() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -28,13 +28,24 @@ def lfna_setup(args): | |||||||
|         env_info["dynamic_env"] = dynamic_env |         env_info["dynamic_env"] = dynamic_env | ||||||
|         torch.save(env_info, cache_path) |         torch.save(env_info, cache_path) | ||||||
|  |  | ||||||
|  |     """ | ||||||
|     model_kwargs = dict( |     model_kwargs = dict( | ||||||
|  |         config=dict(model_type="simple_mlp"), | ||||||
|         input_dim=1, |         input_dim=1, | ||||||
|         output_dim=1, |         output_dim=1, | ||||||
|         hidden_dim=args.hidden_dim, |         hidden_dim=args.hidden_dim, | ||||||
|         act_cls="leaky_relu", |         act_cls="leaky_relu", | ||||||
|         norm_cls="identity", |         norm_cls="identity", | ||||||
|     ) |     ) | ||||||
|  |     """ | ||||||
|  |     model_kwargs = dict( | ||||||
|  |         config=dict(model_type="norm_mlp"), | ||||||
|  |         input_dim=1, | ||||||
|  |         output_dim=1, | ||||||
|  |         hidden_dims=[args.hidden_dim] * 2, | ||||||
|  |         act_cls="gelu", | ||||||
|  |         norm_cls="layer_norm_1d", | ||||||
|  |     ) | ||||||
|     return logger, env_info, model_kwargs |     return logger, env_info, model_kwargs | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user