Update LFNA
This commit is contained in:
		| @@ -94,7 +94,10 @@ def main(args): | ||||
|     ) | ||||
|  | ||||
|     optimizer = torch.optim.Adam( | ||||
|         meta_model.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True | ||||
|         meta_model.parameters(), | ||||
|         lr=args.init_lr, | ||||
|         weight_decay=args.weight_decay, | ||||
|         amsgrad=True, | ||||
|     ) | ||||
|     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|         optimizer, | ||||
| @@ -137,7 +140,7 @@ def main(args): | ||||
|         ) | ||||
|         success, best_score = meta_model.save_best(-loss_meter.avg) | ||||
|         if success: | ||||
|             logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) | ||||
|             logger.log("Achieve the best with best-score = {:.5f}".format(best_score)) | ||||
|             last_success_epoch = iepoch | ||||
|             save_checkpoint( | ||||
|                 { | ||||
| @@ -262,6 +265,12 @@ if __name__ == "__main__": | ||||
|         default=0.005, | ||||
|         help="The initial learning rate for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--weight_decay", | ||||
|         type=float, | ||||
|         default=0.00001, | ||||
|         help="The weight decay for the optimizer (default is Adam)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_batch", | ||||
|         type=int, | ||||
| @@ -274,11 +283,11 @@ if __name__ == "__main__": | ||||
|         default=5, | ||||
|         help="Enlarge the #iterations for an epoch", | ||||
|     ) | ||||
|     parser.add_argument("--epochs", type=int, default=1000, help="The total #epochs.") | ||||
|     parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.") | ||||
|     parser.add_argument( | ||||
|         "--early_stop_thresh", | ||||
|         type=int, | ||||
|         default=25, | ||||
|         default=100, | ||||
|         help="The maximum epochs for early stop.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
| @@ -299,11 +308,13 @@ if __name__ == "__main__": | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-d{:}_{:}_{:}-e{:}-env{:}".format( | ||||
|     args.save_dir = "{:}-d{:}_{:}_{:}-lr{:}-wd{:}-e{:}-env{:}".format( | ||||
|         args.save_dir, | ||||
|         args.hidden_dim, | ||||
|         args.layer_dim, | ||||
|         args.time_dim, | ||||
|         args.init_lr, | ||||
|         args.weight_decay, | ||||
|         args.epochs, | ||||
|         args.env_version, | ||||
|     ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user