Update base models

This commit is contained in:
D-X-Y 2021-05-12 10:58:54 +00:00
parent 4c51f62906
commit 80ccc49d92
3 changed files with 18 additions and 18 deletions

View File

@ -1,7 +1,7 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
##################################################### #####################################################
# python exps/LFNA/basic-same.py --srange 1-999 --env_version v1 --hidden_dim 16 # python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16
# python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim # python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim
##################################################### #####################################################
import sys, time, copy, torch, random, argparse import sys, time, copy, torch, random, argparse
@ -38,27 +38,17 @@ def subsample(historical_x, historical_y, maxn=10000):
def main(args): def main(args):
logger, env_info, model_kwargs = lfna_setup(args) logger, env_info, model_kwargs = lfna_setup(args)
# check indexes to be evaluated
to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
logger.log(
"Evaluate {:}, which has {:} timestamps in total.".format(
args.srange, len(to_evaluate_indexes)
)
)
w_container_per_epoch = dict() w_container_per_epoch = dict()
per_timestamp_time, start_time = AverageMeter(), time.time() per_timestamp_time, start_time = AverageMeter(), time.time()
for i, idx in enumerate(to_evaluate_indexes): for idx in range(env_info["total"]):
need_time = "Time Left: {:}".format( need_time = "Time Left: {:}".format(
convert_secs2time( convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
)
) )
logger.log( logger.log(
"[{:}]".format(time_string()) "[{:}]".format(time_string())
+ " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx) + " [{:04d}/{:04d}]".format(idx, env_info["total"])
+ " " + " "
+ need_time + need_time
) )
@ -66,7 +56,8 @@ def main(args):
historical_x = env_info["{:}-x".format(idx)] historical_x = env_info["{:}-x".format(idx)]
historical_y = env_info["{:}-y".format(idx)] historical_y = env_info["{:}-y".format(idx)]
# build model # build model
model = get_model(dict(model_type="simple_mlp"), **model_kwargs) model = get_model(**model_kwargs)
print(model)
# build optimizer # build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss() criterion = torch.nn.MSELoss()
@ -180,9 +171,6 @@ if __name__ == "__main__":
default=1000, default=1000,
help="The total number of epochs.", help="The total number of epochs.",
) )
parser.add_argument(
"--srange", type=str, required=True, help="The range of models to be evaluated"
)
parser.add_argument( parser.add_argument(
"--workers", "--workers",
type=int, type=int,

View File

@ -90,6 +90,7 @@ def main(args):
final_loss = torch.stack(losses).mean() final_loss = torch.stack(losses).mean()
final_loss.backward() final_loss.backward()
torch.nn.utils.clip_grad_norm_(parameters, 1.0)
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()

View File

@ -28,13 +28,24 @@ def lfna_setup(args):
env_info["dynamic_env"] = dynamic_env env_info["dynamic_env"] = dynamic_env
torch.save(env_info, cache_path) torch.save(env_info, cache_path)
"""
model_kwargs = dict( model_kwargs = dict(
config=dict(model_type="simple_mlp"),
input_dim=1, input_dim=1,
output_dim=1, output_dim=1,
hidden_dim=args.hidden_dim, hidden_dim=args.hidden_dim,
act_cls="leaky_relu", act_cls="leaky_relu",
norm_cls="identity", norm_cls="identity",
) )
"""
model_kwargs = dict(
config=dict(model_type="norm_mlp"),
input_dim=1,
output_dim=1,
hidden_dims=[args.hidden_dim] * 2,
act_cls="gelu",
norm_cls="layer_norm_1d",
)
return logger, env_info, model_kwargs return logger, env_info, model_kwargs