diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py
index 89be5a9..0a34f4c 100644
--- a/exps/LFNA/lfna.py
+++ b/exps/LFNA/lfna.py
@@ -132,8 +132,8 @@ def main(args):
     )
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer,
-        milestones=[1, 2],
-        gamma=0.1,
+        milestones=[1, 2, 3, 4, 5],
+        gamma=0.2,
     )
     logger.log("The base-model is\n{:}".format(base_model))
     logger.log("The meta-model is\n{:}".format(meta_model))
@@ -223,11 +223,12 @@
                 logger,
             )
         if iepoch - last_success_epoch >= args.early_stop_thresh:
-            if lr_scheduler.last_epoch > 2:
+            if lr_scheduler.last_epoch > 4:
                 logger.log("Early stop at {:}".format(iepoch))
                 break
             else:
-                last_epoch.step()
+                last_success_epoch = iepoch
+                lr_scheduler.step()
                 logger.log("Decay the lr [{:}]".format(lr_scheduler.last_epoch))
 
         per_epoch_time.update(time.time() - start_time)
@@ -375,7 +376,7 @@ if __name__ == "__main__":
         help="The #epochs for early stop.",
     )
     parser.add_argument(
-        "--seq_length", type=int, default=5, help="The sequence length."
+        "--seq_length", type=int, default=10, help="The sequence length."
     )
     parser.add_argument(
         "--workers", type=int, default=4, help="The number of workers in parallel."
@@ -392,11 +393,12 @@
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-d{:}_{:}_{:}-lr{:}-wd{:}-e{:}-env{:}".format(
+    args.save_dir = "{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
        args.save_dir,
        args.hidden_dim,
        args.layer_dim,
        args.time_dim,
+        args.seq_length,
        args.lr,
        args.weight_decay,
        args.epochs,