Update LFNA
This commit is contained in:
parent
72f240bf0a
commit
5e766603be
@ -94,7 +94,10 @@ def main(args):
|
||||
)
|
||||
|
||||
optimizer = torch.optim.Adam(
|
||||
meta_model.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True
|
||||
meta_model.parameters(),
|
||||
lr=args.init_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
amsgrad=True,
|
||||
)
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
@ -137,7 +140,7 @@ def main(args):
|
||||
)
|
||||
success, best_score = meta_model.save_best(-loss_meter.avg)
|
||||
if success:
|
||||
logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
|
||||
logger.log("Achieve the best with best-score = {:.5f}".format(best_score))
|
||||
last_success_epoch = iepoch
|
||||
save_checkpoint(
|
||||
{
|
||||
@ -262,6 +265,12 @@ if __name__ == "__main__":
|
||||
default=0.005,
|
||||
help="The initial learning rate for the optimizer (default is Adam)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight_decay",
|
||||
type=float,
|
||||
default=0.00001,
|
||||
help="The weight decay for the optimizer (default is Adam)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--meta_batch",
|
||||
type=int,
|
||||
@ -274,11 +283,11 @@ if __name__ == "__main__":
|
||||
default=5,
|
||||
help="Enlarge the #iterations for an epoch",
|
||||
)
|
||||
parser.add_argument("--epochs", type=int, default=1000, help="The total #epochs.")
|
||||
parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.")
|
||||
parser.add_argument(
|
||||
"--early_stop_thresh",
|
||||
type=int,
|
||||
default=25,
|
||||
default=100,
|
||||
help="The maximum epochs for early stop.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -299,11 +308,13 @@ if __name__ == "__main__":
|
||||
if args.rand_seed is None or args.rand_seed < 0:
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||
args.save_dir = "{:}-d{:}_{:}_{:}-e{:}-env{:}".format(
|
||||
args.save_dir = "{:}-d{:}_{:}_{:}-lr{:}-wd{:}-e{:}-env{:}".format(
|
||||
args.save_dir,
|
||||
args.hidden_dim,
|
||||
args.layer_dim,
|
||||
args.time_dim,
|
||||
args.init_lr,
|
||||
args.weight_decay,
|
||||
args.epochs,
|
||||
args.env_version,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user