Update LFNA
@@ -100,9 +100,15 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
        weight_decay=args.weight_decay,
        amsgrad=True,
    )
    logger.log("Pre-train the meta-model")
    logger.log("Using the optimizer: {:}".format(optimizer))

    meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain")
    per_epoch_time, start_time = AverageMeter(), time.time()
    for iepoch in range(args.epochs):
        left_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        total_meta_losses, total_match_losses = [], []
        for ibatch in range(args.meta_batch):
            rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
@@ -151,7 +157,11 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
                final_match_loss.item(),
            )
            + ", batch={:}".format(len(total_meta_losses))
            + ", success={:}, best_score={:.4f}".format(success, -best_score)
            + " {:}".format(left_time)
        )
        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()


def main(args):
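The timing changes in this diff follow a common pattern: track the average epoch duration with a running meter, multiply it by the number of remaining epochs, and log the result as an ETA before each epoch. Below is a minimal, self-contained Python sketch of that pattern. The AverageMeter and convert_secs2time definitions here are simplified stand-ins for the repository's own utilities, and train_with_eta is a hypothetical helper used only for illustration.

import time


class AverageMeter:
    # Simplified stand-in: keeps a running average of the values passed to update().
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, value):
        self.sum += value
        self.count += 1
        self.avg = self.sum / self.count


def convert_secs2time(seconds, as_string=True):
    # Simplified stand-in: format a duration in seconds as H:MM:SS.
    seconds = int(seconds)
    hours, rem = divmod(seconds, 3600)
    minutes, secs = divmod(rem, 60)
    if as_string:
        return "{:d}:{:02d}:{:02d}".format(hours, minutes, secs)
    return hours, minutes, secs


def train_with_eta(epochs, run_one_epoch, log=print):
    # Run `epochs` iterations of `run_one_epoch`, logging an ETA before each one.
    per_epoch_time, start_time = AverageMeter(), time.time()
    for iepoch in range(epochs):
        # Estimate remaining time from the average epoch duration observed so far.
        left_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (epochs - iepoch), True)
        )
        log("[{:03d}/{:03d}] {:}".format(iepoch, epochs, left_time))
        run_one_epoch(iepoch)
        # Record how long this epoch took and reset the timer for the next one.
        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()


if __name__ == "__main__":
    train_with_eta(3, lambda iepoch: time.sleep(0.1))

On the first epoch the meter's average is still zero, so the ETA reads 0:00:00; the estimate becomes meaningful once at least one epoch has completed, which matches how the logged "Time Left" value behaves in the patched pretrain loop.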