Update LFNA

This commit is contained in:
D-X-Y 2021-05-17 12:01:58 +00:00
parent 5c851ac25a
commit 85f7f1a400

View File

@ -132,8 +132,8 @@ def main(args):
)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[1, 2],
gamma=0.1,
milestones=[1, 2, 3, 4, 5],
gamma=0.2,
)
logger.log("The base-model is\n{:}".format(base_model))
logger.log("The meta-model is\n{:}".format(meta_model))
@ -223,11 +223,12 @@ def main(args):
logger,
)
if iepoch - last_success_epoch >= args.early_stop_thresh:
if lr_scheduler.last_epoch > 2:
if lr_scheduler.last_epoch > 4:
logger.log("Early stop at {:}".format(iepoch))
break
else:
last_epoch.step()
last_success_epoch = iepoch
lr_scheduler.step()
logger.log("Decay the lr [{:}]".format(lr_scheduler.last_epoch))
per_epoch_time.update(time.time() - start_time)
@ -375,7 +376,7 @@ if __name__ == "__main__":
help="The #epochs for early stop.",
)
parser.add_argument(
"--seq_length", type=int, default=5, help="The sequence length."
"--seq_length", type=int, default=10, help="The sequence length."
)
parser.add_argument(
"--workers", type=int, default=4, help="The number of workers in parallel."
@ -392,11 +393,12 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-d{:}_{:}_{:}-lr{:}-wd{:}-e{:}-env{:}".format(
args.save_dir = "{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
args.save_dir,
args.hidden_dim,
args.layer_dim,
args.time_dim,
args.seq_length,
args.lr,
args.weight_decay,
args.epochs,