Add early stop

This commit is contained in:
D-X-Y 2021-05-13 08:07:54 +00:00
parent d1836cbe52
commit 0138e71cf2

View File

@ -1,7 +1,7 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 20000 --init_lr 0.01
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
@ -76,6 +76,7 @@ def main(args):
# LFNA meta-training
loss_meter = AverageMeter()
per_epoch_time, start_time = AverageMeter(), time.time()
last_success = 0
for iepoch in range(args.epochs):
need_time = "Time Left: {:}".format(
@ -108,6 +109,13 @@ def main(args):
lr_scheduler.step()
loss_meter.update(final_loss.item())
success, best_score = hypernet.save_best(-loss_meter.val)
if success:
logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
last_success = iepoch
if iepoch - last_success >= args.early_stop_thresh:
logger.log("Early stop at {:}".format(iepoch))
break
if iepoch % 20 == 0:
logger.log(
head_str
@ -119,11 +127,6 @@ def main(args):
)
)
success, best_score = hypernet.save_best(-loss_meter.avg)
if success:
logger.log(
"Achieve the best with best_score = {:.3f}".format(best_score)
)
save_checkpoint(
{
"hypernet": hypernet.state_dict(),
@ -192,6 +195,12 @@ if __name__ == "__main__":
required=True,
help="The hidden dimension.",
)
parser.add_argument(
"--early_stop_thresh",
type=int,
default=100,
help="The maximum epochs for early stop.",
)
#####
parser.add_argument(
"--init_lr",