Update LFNA

This commit is contained in:
D-X-Y 2021-05-22 09:43:48 +00:00
parent ce787df02c
commit c8e95b0ddc

View File

@ -100,9 +100,15 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
amsgrad=True, amsgrad=True,
) )
logger.log("Pre-train the meta-model")
logger.log("Using the optimizer: {:}".format(optimizer))
meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain") meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain")
per_epoch_time, start_time = AverageMeter(), time.time()
for iepoch in range(args.epochs): for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
)
total_meta_losses, total_match_losses = [], [] total_meta_losses, total_match_losses = [], []
for ibatch in range(args.meta_batch): for ibatch in range(args.meta_batch):
rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
@ -151,7 +157,11 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
final_match_loss.item(), final_match_loss.item(),
) )
+ ", batch={:}".format(len(total_meta_losses)) + ", batch={:}".format(len(total_meta_losses))
+ ", success={:}, best_score={:.4f}".format(success, -best_score)
+ " {:}".format(left_time)
) )
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
def main(args): def main(args):