Update base models

parent 4c51f62906
commit 80ccc49d92
@@ -1,7 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/basic-same.py --srange 1-999 --env_version v1 --hidden_dim 16
+# python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16
 # python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim
 #####################################################
 import sys, time, copy, torch, random, argparse
@@ -38,27 +38,17 @@ def subsample(historical_x, historical_y, maxn=10000):
 def main(args):
     logger, env_info, model_kwargs = lfna_setup(args)

-    # check indexes to be evaluated
-    to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
-    logger.log(
-        "Evaluate {:}, which has {:} timestamps in total.".format(
-            args.srange, len(to_evaluate_indexes)
-        )
-    )
-
     w_container_per_epoch = dict()

     per_timestamp_time, start_time = AverageMeter(), time.time()
-    for i, idx in enumerate(to_evaluate_indexes):
+    for idx in range(env_info["total"]):

         need_time = "Time Left: {:}".format(
-            convert_secs2time(
-                per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
-            )
+            convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
         )
         logger.log(
             "[{:}]".format(time_string())
-            + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx)
+            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
             + " "
             + need_time
         )
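Note: with --srange gone, the loop above simply walks every timestamp in env_info, and the "Time Left" string is the running average of per-timestamp wall time multiplied by the number of timestamps still to process. A minimal standalone sketch of that estimate, with a hypothetical RunningAverage class standing in for the repository's AverageMeter and a dummy workload in place of the real training step:

    import time

    class RunningAverage:
        """Hypothetical stand-in for the repository's AverageMeter."""
        def __init__(self):
            self.sum, self.count = 0.0, 0
        def update(self, value):
            self.sum += value
            self.count += 1
        @property
        def avg(self):
            return self.sum / max(self.count, 1)

    meter, start = RunningAverage(), time.time()
    total = 100  # stands in for env_info["total"]
    for idx in range(total):
        # remaining timestamps times the average seconds spent on each one so far
        print("Time Left: {:.1f}s".format(meter.avg * (total - idx)))
        time.sleep(0.01)  # placeholder for the per-timestamp training work
        meter.update(time.time() - start)
        start = time.time()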
@@ -66,7 +56,8 @@ def main(args):
         historical_x = env_info["{:}-x".format(idx)]
         historical_y = env_info["{:}-y".format(idx)]
         # build model
-        model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+        model = get_model(**model_kwargs)
+        print(model)
         # build optimizer
         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
         criterion = torch.nn.MSELoss()
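Note: the model type now lives in model_kwargs["config"] (see the lfna_setup hunk further down), so the positional dict(model_type="simple_mlp") argument is dropped. A rough sketch of the per-timestamp fit this sets up, assuming a plain nn.Linear as a stand-in for get_model and made-up data in place of the cached historical_x / historical_y tensors:

    import torch

    model = torch.nn.Linear(1, 1)       # stand-in for get_model(**model_kwargs)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)
    criterion = torch.nn.MSELoss()

    historical_x = torch.randn(128, 1)  # placeholder for env_info["{:}-x".format(idx)]
    historical_y = 2.0 * historical_x + 0.1

    for _ in range(100):
        optimizer.zero_grad()
        loss = criterion(model(historical_x), historical_y)
        loss.backward()
        optimizer.step()
    print(float(loss))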
@@ -180,9 +171,6 @@ if __name__ == "__main__":
         default=1000,
         help="The total number of epochs.",
     )
-    parser.add_argument(
-        "--srange", type=str, required=True, help="The range of models to be evaluated"
-    )
     parser.add_argument(
         "--workers",
         type=int,

@@ -90,6 +90,7 @@ def main(args):

         final_loss = torch.stack(losses).mean()
         final_loss.backward()
+        torch.nn.utils.clip_grad_norm_(parameters, 1.0)
         optimizer.step()
         lr_scheduler.step()

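Note: the added clip_grad_norm_ call rescales the gradients in place so their global L2 norm does not exceed 1.0; it has to sit between backward() and optimizer.step(), which is exactly where the new line lands. A self-contained sketch of that ordering with a throwaway linear model:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    parameters = list(model.parameters())

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    # rescale gradients in place so their global L2 norm is at most 1.0
    torch.nn.utils.clip_grad_norm_(parameters, 1.0)
    optimizer.step()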
@@ -28,13 +28,24 @@ def lfna_setup(args):
         env_info["dynamic_env"] = dynamic_env
         torch.save(env_info, cache_path)

+    """
     model_kwargs = dict(
         config=dict(model_type="simple_mlp"),
         input_dim=1,
         output_dim=1,
         hidden_dim=args.hidden_dim,
         act_cls="leaky_relu",
         norm_cls="identity",
     )
+    """
+    model_kwargs = dict(
+        config=dict(model_type="norm_mlp"),
+        input_dim=1,
+        output_dim=1,
+        hidden_dims=[args.hidden_dim] * 2,
+        act_cls="gelu",
+        norm_cls="layer_norm_1d",
+    )
     return logger, env_info, model_kwargs
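Note: the old simple_mlp config (one hidden layer of hidden_dim units, leaky ReLU, no normalization) is kept only inside a commented-out string, and the active config becomes a norm_mlp with two hidden layers of args.hidden_dim units, GELU activations, and 1-D layer norm. The diff does not show get_model itself; the following is only a hypothetical plain-PyTorch module of the same shape:

    import torch.nn as nn

    def build_norm_mlp(input_dim=1, output_dim=1, hidden_dims=(16, 16)):
        """Hypothetical plain-PyTorch equivalent of the norm_mlp config above."""
        layers, in_dim = [], input_dim
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.LayerNorm(h), nn.GELU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, output_dim))
        return nn.Sequential(*layers)

    model = build_norm_mlp(hidden_dims=[16] * 2)  # mirrors hidden_dims=[args.hidden_dim] * 2
    print(model)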
|