Add early stop
This commit is contained in:
parent
d1836cbe52
commit
0138e71cf2
@ -1,7 +1,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||||
#####################################################
|
#####################################################
|
||||||
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 20000 --init_lr 0.01
|
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01
|
||||||
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda
|
# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, time, copy, torch, random, argparse
|
import sys, time, copy, torch, random, argparse
|
||||||
@ -76,6 +76,7 @@ def main(args):
|
|||||||
# LFNA meta-training
|
# LFNA meta-training
|
||||||
loss_meter = AverageMeter()
|
loss_meter = AverageMeter()
|
||||||
per_epoch_time, start_time = AverageMeter(), time.time()
|
per_epoch_time, start_time = AverageMeter(), time.time()
|
||||||
|
last_success = 0
|
||||||
for iepoch in range(args.epochs):
|
for iepoch in range(args.epochs):
|
||||||
|
|
||||||
need_time = "Time Left: {:}".format(
|
need_time = "Time Left: {:}".format(
|
||||||
@ -108,6 +109,13 @@ def main(args):
|
|||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
loss_meter.update(final_loss.item())
|
loss_meter.update(final_loss.item())
|
||||||
|
success, best_score = hypernet.save_best(-loss_meter.val)
|
||||||
|
if success:
|
||||||
|
logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
|
||||||
|
last_success = iepoch
|
||||||
|
if iepoch - last_success >= args.early_stop_thresh:
|
||||||
|
logger.log("Early stop at {:}".format(iepoch))
|
||||||
|
break
|
||||||
if iepoch % 20 == 0:
|
if iepoch % 20 == 0:
|
||||||
logger.log(
|
logger.log(
|
||||||
head_str
|
head_str
|
||||||
@ -119,11 +127,6 @@ def main(args):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
success, best_score = hypernet.save_best(-loss_meter.avg)
|
|
||||||
if success:
|
|
||||||
logger.log(
|
|
||||||
"Achieve the best with best_score = {:.3f}".format(best_score)
|
|
||||||
)
|
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
{
|
{
|
||||||
"hypernet": hypernet.state_dict(),
|
"hypernet": hypernet.state_dict(),
|
||||||
@ -192,6 +195,12 @@ if __name__ == "__main__":
|
|||||||
required=True,
|
required=True,
|
||||||
help="The hidden dimension.",
|
help="The hidden dimension.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--early_stop_thresh",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="The maximum epochs for early stop.",
|
||||||
|
)
|
||||||
#####
|
#####
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--init_lr",
|
"--init_lr",
|
||||||
|
Loading…
Reference in New Issue
Block a user