Update LFNA
This commit is contained in:
		| @@ -1,5 +1,5 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| # Learning to Generate Model One Step Ahead         # | ||||
| ##################################################### | ||||
| # python exps/LFNA/lfna.py --env_version v1 --workers 0 | ||||
| # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001 | ||||
| @@ -109,6 +109,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||
|     final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed) | ||||
|     if meta_model.has_best(final_best_name): | ||||
|         meta_model.load_best(final_best_name) | ||||
|         logger.log("Directly load the best model from {:}".format(final_best_name)) | ||||
|         return | ||||
|  | ||||
|     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) | ||||
| @@ -118,58 +119,64 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||
|         left_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||
|         ) | ||||
|         total_meta_losses, total_match_losses = [], [] | ||||
|         total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], [] | ||||
|         optimizer.zero_grad() | ||||
|         for ibatch in range(args.meta_batch): | ||||
|             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) | ||||
|             timestamps = meta_model.meta_timestamps[ | ||||
|                 rand_index : rand_index + xenv.seq_length | ||||
|             ] | ||||
|             meta_embeds = meta_model.super_meta_embed[ | ||||
|                 rand_index : rand_index + xenv.seq_length | ||||
|             ] | ||||
|  | ||||
|             seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) | ||||
|             [seq_containers], time_embeds = meta_model( | ||||
|                 torch.unsqueeze(timestamps, dim=0), None | ||||
|             ) | ||||
|             # performance loss | ||||
|             losses = [] | ||||
|             _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) | ||||
|             seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to( | ||||
|                 args.device | ||||
|             ) | ||||
|             # generate models one step ahead | ||||
|             [seq_containers], time_embeds = meta_model( | ||||
|                 torch.unsqueeze(timestamps, dim=0), None | ||||
|             ) | ||||
|             for container, inputs, targets in zip( | ||||
|                 seq_containers, seq_inputs, seq_targets | ||||
|             ): | ||||
|                 predictions = base_model.forward_with_container(inputs, container) | ||||
|                 loss = criterion(predictions, targets) | ||||
|                 losses.append(loss) | ||||
|             meta_loss = torch.stack(losses).mean() | ||||
|             match_loss = criterion( | ||||
|                 torch.squeeze(time_embeds, dim=0), | ||||
|                 meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length], | ||||
|             ) | ||||
|             total_meta_losses.append(meta_loss) | ||||
|                 total_meta_v1_losses.append(criterion(predictions, targets)) | ||||
|             # the matching loss | ||||
|             match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds) | ||||
|             total_match_losses.append(match_loss) | ||||
|             # generate models via memory | ||||
|             [seq_containers], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0)) | ||||
|             for container, inputs, targets in zip( | ||||
|                 seq_containers, seq_inputs, seq_targets | ||||
|             ): | ||||
|                 predictions = base_model.forward_with_container(inputs, container) | ||||
|                 total_meta_v2_losses.append(criterion(predictions, targets)) | ||||
|         with torch.no_grad(): | ||||
|             meta_std = torch.stack(total_meta_losses).std().item() | ||||
|         final_meta_loss = torch.stack(total_meta_losses).mean() | ||||
|         final_match_loss = torch.stack(total_match_losses).mean() | ||||
|         total_loss = final_meta_loss + final_match_loss | ||||
|             meta_std = torch.stack(total_meta_v1_losses).std().item() | ||||
|         meta_v1_loss = torch.stack(total_meta_v1_losses).mean() | ||||
|         meta_v2_loss = torch.stack(total_meta_v2_losses).mean() | ||||
|         match_loss = torch.stack(total_match_losses).mean() | ||||
|         total_loss = meta_v1_loss + meta_v2_loss + match_loss | ||||
|         total_loss.backward() | ||||
|         optimizer.step() | ||||
|         # success | ||||
|         success, best_score = meta_model.save_best(-total_loss.item()) | ||||
|         logger.log( | ||||
|             "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format( | ||||
|             "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format( | ||||
|                 time_string(), | ||||
|                 iepoch, | ||||
|                 args.epochs, | ||||
|                 total_loss.item(), | ||||
|                 meta_std, | ||||
|                 final_meta_loss.item(), | ||||
|                 final_match_loss.item(), | ||||
|                 meta_v1_loss.item(), | ||||
|                 meta_v2_loss.item(), | ||||
|                 match_loss.item(), | ||||
|             ) | ||||
|             + ", batch={:}".format(len(total_meta_losses)) | ||||
|             + ", success={:}, best_score={:.4f}".format(success, -best_score) | ||||
|             + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh) | ||||
|             + ", batch={:}".format(len(total_meta_v1_losses)) | ||||
|             + ", success={:}, best={:.4f}".format(success, -best_score) | ||||
|             + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh) | ||||
|             + ", {:}".format(left_time) | ||||
|         ) | ||||
|         if success: | ||||
| @@ -184,6 +191,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||
|     meta_model.set_best_name(final_best_name) | ||||
|     success, _ = meta_model.save_best(best_score + 1e-6) | ||||
|     assert success | ||||
|     logger.log("Save the best model into {:}".format(final_best_name)) | ||||
|  | ||||
|  | ||||
| def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger): | ||||
| @@ -243,8 +251,8 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger): | ||||
|                 final_loss.item(), | ||||
|             ) | ||||
|             + ", batch={:}".format(len(losses)) | ||||
|             + ", success={:}, best_score={:.4f}".format(success, -best_score) | ||||
|             + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh) | ||||
|             + ", success={:}, best={:.4f}".format(success, -best_score) | ||||
|             + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh) | ||||
|             + " {:}".format(left_time) | ||||
|         ) | ||||
|         if success: | ||||
| @@ -277,6 +285,8 @@ def main(args): | ||||
|  | ||||
|     logger.log("The base-model has {:} weights.".format(base_model.numel())) | ||||
|     logger.log("The meta-model has {:} weights.".format(meta_model.numel())) | ||||
|     logger.log("The base-model is\n{:}".format(base_model)) | ||||
|     logger.log("The meta-model is\n{:}".format(meta_model)) | ||||
|  | ||||
|     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) | ||||
|     train_env.reset_max_seq_length(args.seq_length) | ||||
| @@ -294,9 +304,10 @@ def main(args): | ||||
|         num_workers=args.workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) | ||||
|  | ||||
|     optimizer = torch.optim.Adam( | ||||
|         meta_model.parameters(), | ||||
|         meta_model.get_parameters(True, True, False),  # fix hypernet | ||||
|         lr=args.lr, | ||||
|         weight_decay=args.weight_decay, | ||||
|         amsgrad=True, | ||||
| @@ -306,14 +317,10 @@ def main(args): | ||||
|         milestones=[1, 2, 3, 4, 5], | ||||
|         gamma=0.2, | ||||
|     ) | ||||
|     logger.log("The base-model is\n{:}".format(base_model)) | ||||
|     logger.log("The meta-model is\n{:}".format(meta_model)) | ||||
|     logger.log("The optimizer is\n{:}".format(optimizer)) | ||||
|     logger.log("The scheduler is\n{:}".format(lr_scheduler)) | ||||
|     logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) | ||||
|  | ||||
|     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) | ||||
|  | ||||
|     if logger.path("model").exists(): | ||||
|         ckp_data = torch.load(logger.path("model")) | ||||
|         base_model.load_state_dict(ckp_data["base_model"]) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user