diff --git a/exps/LFNA/lfna-test-hpnet.py b/exps/LFNA/lfna-test-hpnet.py index 1f14642..334c8ac 100644 --- a/exps/LFNA/lfna-test-hpnet.py +++ b/exps/LFNA/lfna-test-hpnet.py @@ -1,7 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 20000 --init_lr 0.01 +# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda ##################################################### import sys, time, copy, torch, random, argparse @@ -76,6 +76,7 @@ def main(args): # LFNA meta-training loss_meter = AverageMeter() per_epoch_time, start_time = AverageMeter(), time.time() + last_success = 0 for iepoch in range(args.epochs): need_time = "Time Left: {:}".format( @@ -108,6 +109,13 @@ def main(args): lr_scheduler.step() loss_meter.update(final_loss.item()) + success, best_score = hypernet.save_best(-loss_meter.val) + if success: + logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) + last_success = iepoch + if iepoch - last_success >= args.early_stop_thresh: + logger.log("Early stop at {:}".format(iepoch)) + break if iepoch % 20 == 0: logger.log( head_str @@ -119,11 +127,6 @@ def main(args): ) ) - success, best_score = hypernet.save_best(-loss_meter.avg) - if success: - logger.log( - "Achieve the best with best_score = {:.3f}".format(best_score) - ) save_checkpoint( { "hypernet": hypernet.state_dict(), @@ -192,6 +195,12 @@ if __name__ == "__main__": required=True, help="The hidden dimension.", ) + parser.add_argument( + "--early_stop_thresh", + type=int, + default=100, + help="The maximum epochs for early stop.", + ) ##### parser.add_argument( "--init_lr",