Update LFNA

This commit is contained in:
D-X-Y 2021-05-15 16:31:35 +08:00
parent 72f240bf0a
commit 5e766603be

View File

@ -94,7 +94,10 @@ def main(args):
)
optimizer = torch.optim.Adam(
meta_model.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True
meta_model.parameters(),
lr=args.init_lr,
weight_decay=args.weight_decay,
amsgrad=True,
)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
@ -137,7 +140,7 @@ def main(args):
)
success, best_score = meta_model.save_best(-loss_meter.avg)
if success:
logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
logger.log("Achieve the best with best-score = {:.5f}".format(best_score))
last_success_epoch = iepoch
save_checkpoint(
{
@ -262,6 +265,12 @@ if __name__ == "__main__":
default=0.005,
help="The initial learning rate for the optimizer (default is Adam)",
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.00001,
help="The weight decay for the optimizer (default is Adam)",
)
parser.add_argument(
"--meta_batch",
type=int,
@ -274,11 +283,11 @@ if __name__ == "__main__":
default=5,
help="Enlarge the #iterations for an epoch",
)
parser.add_argument("--epochs", type=int, default=1000, help="The total #epochs.")
parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.")
parser.add_argument(
"--early_stop_thresh",
type=int,
default=25,
default=100,
help="The maximum epochs for early stop.",
)
parser.add_argument(
@ -299,11 +308,13 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-d{:}_{:}_{:}-e{:}-env{:}".format(
args.save_dir = "{:}-d{:}_{:}_{:}-lr{:}-wd{:}-e{:}-env{:}".format(
args.save_dir,
args.hidden_dim,
args.layer_dim,
args.time_dim,
args.init_lr,
args.weight_decay,
args.epochs,
args.env_version,
)