Update base models

This commit is contained in:
D-X-Y 2021-05-12 10:58:54 +00:00
parent 4c51f62906
commit 80ccc49d92
3 changed files with 18 additions and 18 deletions

View File

@ -1,7 +1,7 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-same.py --srange 1-999 --env_version v1 --hidden_dim 16
# python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16
# python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim
#####################################################
import sys, time, copy, torch, random, argparse
@ -38,27 +38,17 @@ def subsample(historical_x, historical_y, maxn=10000):
def main(args):
logger, env_info, model_kwargs = lfna_setup(args)
# check indexes to be evaluated
to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
logger.log(
"Evaluate {:}, which has {:} timestamps in total.".format(
args.srange, len(to_evaluate_indexes)
)
)
w_container_per_epoch = dict()
per_timestamp_time, start_time = AverageMeter(), time.time()
for i, idx in enumerate(to_evaluate_indexes):
for idx in range(env_info["total"]):
need_time = "Time Left: {:}".format(
convert_secs2time(
per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
)
convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
)
logger.log(
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx)
+ " [{:04d}/{:04d}]".format(idx, env_info["total"])
+ " "
+ need_time
)
@ -66,7 +56,8 @@ def main(args):
historical_x = env_info["{:}-x".format(idx)]
historical_y = env_info["{:}-y".format(idx)]
# build model
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
model = get_model(**model_kwargs)
print(model)
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss()
@ -180,9 +171,6 @@ if __name__ == "__main__":
default=1000,
help="The total number of epochs.",
)
parser.add_argument(
"--srange", type=str, required=True, help="The range of models to be evaluated"
)
parser.add_argument(
"--workers",
type=int,

View File

@ -90,6 +90,7 @@ def main(args):
final_loss = torch.stack(losses).mean()
final_loss.backward()
torch.nn.utils.clip_grad_norm_(parameters, 1.0)
optimizer.step()
lr_scheduler.step()

View File

@ -28,13 +28,24 @@ def lfna_setup(args):
env_info["dynamic_env"] = dynamic_env
torch.save(env_info, cache_path)
"""
model_kwargs = dict(
config=dict(model_type="simple_mlp"),
input_dim=1,
output_dim=1,
hidden_dim=args.hidden_dim,
act_cls="leaky_relu",
norm_cls="identity",
)
"""
model_kwargs = dict(
config=dict(model_type="norm_mlp"),
input_dim=1,
output_dim=1,
hidden_dims=[args.hidden_dim] * 2,
act_cls="gelu",
norm_cls="layer_norm_1d",
)
return logger, env_info, model_kwargs