Add early stop
This commit is contained in:
		| @@ -1,7 +1,7 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 20000 --init_lr 0.01 | ||||
| # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 | ||||
| # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| @@ -76,6 +76,7 @@ def main(args): | ||||
|     # LFNA meta-training | ||||
|     loss_meter = AverageMeter() | ||||
|     per_epoch_time, start_time = AverageMeter(), time.time() | ||||
|     last_success = 0 | ||||
|     for iepoch in range(args.epochs): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
| @@ -108,6 +109,13 @@ def main(args): | ||||
|         lr_scheduler.step() | ||||
|  | ||||
|         loss_meter.update(final_loss.item()) | ||||
|         success, best_score = hypernet.save_best(-loss_meter.val) | ||||
|         if success: | ||||
|             logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) | ||||
|             last_success = iepoch | ||||
|         if iepoch - last_success >= args.early_stop_thresh: | ||||
|             logger.log("Early stop at {:}".format(iepoch)) | ||||
|             break | ||||
|         if iepoch % 20 == 0: | ||||
|             logger.log( | ||||
|                 head_str | ||||
| @@ -119,11 +127,6 @@ def main(args): | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|             success, best_score = hypernet.save_best(-loss_meter.avg) | ||||
|             if success: | ||||
|                 logger.log( | ||||
|                     "Achieve the best with best_score = {:.3f}".format(best_score) | ||||
|                 ) | ||||
|             save_checkpoint( | ||||
|                 { | ||||
|                     "hypernet": hypernet.state_dict(), | ||||
| @@ -192,6 +195,12 @@ if __name__ == "__main__": | ||||
|         required=True, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--early_stop_thresh", | ||||
|         type=int, | ||||
|         default=100, | ||||
|         help="The maximum epochs for early stop.", | ||||
|     ) | ||||
|     ##### | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user