Update LFNA
This commit is contained in:
parent
72f240bf0a
commit
5e766603be
@ -94,7 +94,10 @@ def main(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
meta_model.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True
|
meta_model.parameters(),
|
||||||
|
lr=args.init_lr,
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
amsgrad=True,
|
||||||
)
|
)
|
||||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||||
optimizer,
|
optimizer,
|
||||||
@ -137,7 +140,7 @@ def main(args):
|
|||||||
)
|
)
|
||||||
success, best_score = meta_model.save_best(-loss_meter.avg)
|
success, best_score = meta_model.save_best(-loss_meter.avg)
|
||||||
if success:
|
if success:
|
||||||
logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
|
logger.log("Achieve the best with best-score = {:.5f}".format(best_score))
|
||||||
last_success_epoch = iepoch
|
last_success_epoch = iepoch
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
{
|
{
|
||||||
@ -262,6 +265,12 @@ if __name__ == "__main__":
|
|||||||
default=0.005,
|
default=0.005,
|
||||||
help="The initial learning rate for the optimizer (default is Adam)",
|
help="The initial learning rate for the optimizer (default is Adam)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--weight_decay",
|
||||||
|
type=float,
|
||||||
|
default=0.00001,
|
||||||
|
help="The weight decay for the optimizer (default is Adam)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--meta_batch",
|
"--meta_batch",
|
||||||
type=int,
|
type=int,
|
||||||
@ -274,11 +283,11 @@ if __name__ == "__main__":
|
|||||||
default=5,
|
default=5,
|
||||||
help="Enlarge the #iterations for an epoch",
|
help="Enlarge the #iterations for an epoch",
|
||||||
)
|
)
|
||||||
parser.add_argument("--epochs", type=int, default=1000, help="The total #epochs.")
|
parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--early_stop_thresh",
|
"--early_stop_thresh",
|
||||||
type=int,
|
type=int,
|
||||||
default=25,
|
default=100,
|
||||||
help="The maximum epochs for early stop.",
|
help="The maximum epochs for early stop.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -299,11 +308,13 @@ if __name__ == "__main__":
|
|||||||
if args.rand_seed is None or args.rand_seed < 0:
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||||
args.save_dir = "{:}-d{:}_{:}_{:}-e{:}-env{:}".format(
|
args.save_dir = "{:}-d{:}_{:}_{:}-lr{:}-wd{:}-e{:}-env{:}".format(
|
||||||
args.save_dir,
|
args.save_dir,
|
||||||
args.hidden_dim,
|
args.hidden_dim,
|
||||||
args.layer_dim,
|
args.layer_dim,
|
||||||
args.time_dim,
|
args.time_dim,
|
||||||
|
args.init_lr,
|
||||||
|
args.weight_decay,
|
||||||
args.epochs,
|
args.epochs,
|
||||||
args.env_version,
|
args.env_version,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user