Add early stop

This commit is contained in:
D-X-Y 2021-05-13 08:07:54 +00:00
parent d1836cbe52
commit 0138e71cf2

View File

@ -1,7 +1,7 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
##################################################### #####################################################
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 20000 --init_lr 0.01 # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda
##################################################### #####################################################
import sys, time, copy, torch, random, argparse import sys, time, copy, torch, random, argparse
@ -76,6 +76,7 @@ def main(args):
# LFNA meta-training # LFNA meta-training
loss_meter = AverageMeter() loss_meter = AverageMeter()
per_epoch_time, start_time = AverageMeter(), time.time() per_epoch_time, start_time = AverageMeter(), time.time()
last_success = 0
for iepoch in range(args.epochs): for iepoch in range(args.epochs):
need_time = "Time Left: {:}".format( need_time = "Time Left: {:}".format(
@ -108,6 +109,13 @@ def main(args):
lr_scheduler.step() lr_scheduler.step()
loss_meter.update(final_loss.item()) loss_meter.update(final_loss.item())
success, best_score = hypernet.save_best(-loss_meter.val)
if success:
logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
last_success = iepoch
if iepoch - last_success >= args.early_stop_thresh:
logger.log("Early stop at {:}".format(iepoch))
break
if iepoch % 20 == 0: if iepoch % 20 == 0:
logger.log( logger.log(
head_str head_str
@ -119,11 +127,6 @@ def main(args):
) )
) )
success, best_score = hypernet.save_best(-loss_meter.avg)
if success:
logger.log(
"Achieve the best with best_score = {:.3f}".format(best_score)
)
save_checkpoint( save_checkpoint(
{ {
"hypernet": hypernet.state_dict(), "hypernet": hypernet.state_dict(),
@ -192,6 +195,12 @@ if __name__ == "__main__":
required=True, required=True,
help="The hidden dimension.", help="The hidden dimension.",
) )
parser.add_argument(
"--early_stop_thresh",
type=int,
default=100,
help="The maximum epochs for early stop.",
)
##### #####
parser.add_argument( parser.add_argument(
"--init_lr", "--init_lr",