Update LFNA
The first hunk repoints the 2021.05.21 changelog entry from `5b09f05` to `8109ed1`:

```diff
@@ -9,4 +9,4 @@
 - [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
 - [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
 - [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
-- [2021.05.21] [5b09f05](https://github.com/D-X-Y/AutoDL-Projects/tree/5b09f05) `xautodl` is close to ready
+- [2021.05.21] [8109ed1](https://github.com/D-X-Y/AutoDL-Projects/tree/8109ed1) `xautodl` is close to ready
```
The remaining hunks all touch `exps/LFNA/lfna.py` (the script named in its own usage comments). Its header now carries the method's title instead of the copyright line:

```diff
@@ -1,5 +1,5 @@
 #####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
+# Learning to Generate Model One Step Ahead         #
 #####################################################
 # python exps/LFNA/lfna.py --env_version v1 --workers 0
 # python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001
```
In `pretrain_v2`, reusing an existing final checkpoint is now logged before the early return:

```diff
@@ -109,6 +109,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
     if meta_model.has_best(final_best_name):
         meta_model.load_best(final_best_name)
+        logger.log("Directly load the best model from {:}".format(final_best_name))
         return
 
     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
```
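This guard skips pretraining entirely when a final checkpoint from an earlier run already exists, and the new log line makes that shortcut visible. The checkpoint API itself (`has_best` / `load_best` / `save_best` / `set_best_name`) is not part of this diff, so the sketch below is only a guess at the contract those calls imply: it is meant to be mixed into an `nn.Module`, and its names and on-disk layout are hypothetical.

```python
import os

import torch


class BestCheckpointMixin:
    """Hypothetical sketch: track the best score seen so far and persist
    the owner's state dict under a configurable file name."""

    def __init__(self, save_dir):
        self._save_dir = save_dir
        self._best_name = None
        self._best_score = None

    def set_best_name(self, name):
        self._best_name = name

    def has_best(self, name):
        return os.path.exists(os.path.join(self._save_dir, name))

    def load_best(self, name=None):
        path = os.path.join(self._save_dir, name or self._best_name)
        self.load_state_dict(torch.load(path))

    def save_best(self, score):
        # Callers pass -loss, so "higher score is better" holds for a
        # quantity they actually want to minimize.
        improved = self._best_score is None or score > self._best_score
        if improved:
            self._best_score = score
            torch.save(
                self.state_dict(), os.path.join(self._save_dir, self._best_name)
            )
        return improved, self._best_score
```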
The pretraining objective itself is reworked: the per-batch matching loss moves inline, and a second, memory-driven generation pass contributes a new `meta_v2` loss term:

```diff
@@ -118,58 +119,64 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
-        total_meta_losses, total_match_losses = [], []
+        total_meta_v1_losses, total_meta_v2_losses, total_match_losses = [], [], []
         optimizer.zero_grad()
         for ibatch in range(args.meta_batch):
             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
             timestamps = meta_model.meta_timestamps[
                 rand_index : rand_index + xenv.seq_length
             ]
+            meta_embeds = meta_model.super_meta_embed[
+                rand_index : rand_index + xenv.seq_length
+            ]
+
-            seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
-            [seq_containers], time_embeds = meta_model(
-                torch.unsqueeze(timestamps, dim=0), None
-            )
-            # performance loss
-            losses = []
+            _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
             seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
                 args.device
             )
+            # generate models one step ahead
+            [seq_containers], time_embeds = meta_model(
+                torch.unsqueeze(timestamps, dim=0), None
+            )
             for container, inputs, targets in zip(
                 seq_containers, seq_inputs, seq_targets
             ):
                 predictions = base_model.forward_with_container(inputs, container)
-                loss = criterion(predictions, targets)
-                losses.append(loss)
-            meta_loss = torch.stack(losses).mean()
-            match_loss = criterion(
-                torch.squeeze(time_embeds, dim=0),
-                meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
-            )
-            total_meta_losses.append(meta_loss)
+                total_meta_v1_losses.append(criterion(predictions, targets))
+            # the matching loss
+            match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds)
             total_match_losses.append(match_loss)
+            # generate models via memory
+            [seq_containers], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0))
+            for container, inputs, targets in zip(
+                seq_containers, seq_inputs, seq_targets
+            ):
+                predictions = base_model.forward_with_container(inputs, container)
+                total_meta_v2_losses.append(criterion(predictions, targets))
         with torch.no_grad():
-            meta_std = torch.stack(total_meta_losses).std().item()
-        final_meta_loss = torch.stack(total_meta_losses).mean()
-        final_match_loss = torch.stack(total_match_losses).mean()
-        total_loss = final_meta_loss + final_match_loss
+            meta_std = torch.stack(total_meta_v1_losses).std().item()
+        meta_v1_loss = torch.stack(total_meta_v1_losses).mean()
+        meta_v2_loss = torch.stack(total_meta_v2_losses).mean()
+        match_loss = torch.stack(total_match_losses).mean()
+        total_loss = meta_v1_loss + meta_v2_loss + match_loss
         total_loss.backward()
         optimizer.step()
         # success
         success, best_score = meta_model.save_best(-total_loss.item())
         logger.log(
-            "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
+            "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f} (match)".format(
                 time_string(),
                 iepoch,
                 args.epochs,
                 total_loss.item(),
                 meta_std,
-                final_meta_loss.item(),
-                final_match_loss.item(),
+                meta_v1_loss.item(),
+                meta_v2_loss.item(),
+                match_loss.item(),
             )
-            + ", batch={:}".format(len(total_meta_losses))
-            + ", success={:}, best_score={:.4f}".format(success, -best_score)
-            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+            + ", batch={:}".format(len(total_meta_v1_losses))
+            + ", success={:}, best={:.4f}".format(success, -best_score)
+            + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
             + ", {:}".format(left_time)
         )
         if success:
```
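This rewritten loop is the core of the commit: every meta-batch now exercises the meta-model twice, once conditioned on raw timestamps (generating models one step ahead) and once on the stored memory embeddings, with a matching loss tying the predicted time embeddings to that memory. Stripped of logging and accumulation bookkeeping, one iteration reads roughly as follows; the `meta_model`, `base_model`, `xenv`, and `criterion` interfaces are taken from the hunk, while the function wrapper itself is illustrative.

```python
import random

import torch


def one_meta_iteration(meta_model, base_model, criterion, xenv, device):
    # Sample a random window of consecutive timestamps and the memory
    # embeddings stored for the same window.
    rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
    window = slice(rand_index, rand_index + xenv.seq_length)
    timestamps = meta_model.meta_timestamps[window]
    meta_embeds = meta_model.super_meta_embed[window]
    _, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
    seq_inputs, seq_targets = seq_inputs.to(device), seq_targets.to(device)

    def performance(containers):
        # Evaluate each generated weight container on its own time step.
        losses = [
            criterion(base_model.forward_with_container(x, c), y)
            for c, x, y in zip(containers, seq_inputs, seq_targets)
        ]
        return torch.stack(losses).mean()

    # Pass 1: generate per-timestamp weights from the timestamps themselves.
    [containers_v1], time_embeds = meta_model(
        torch.unsqueeze(timestamps, dim=0), None
    )
    meta_v1_loss = performance(containers_v1)
    # Matching loss: time-conditioned embeddings should agree with the memory.
    match_loss = criterion(torch.squeeze(time_embeds, dim=0), meta_embeds)
    # Pass 2: generate weights directly from the stored memory embeddings.
    [containers_v2], _ = meta_model(None, torch.unsqueeze(meta_embeds, dim=0))
    meta_v2_loss = performance(containers_v2)

    return meta_v1_loss + meta_v2_loss + match_loss
```

The sum returned here is what the hunk accumulates across `args.meta_batch` windows before the single `backward()` and `optimizer.step()`.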
```diff
@@ -184,6 +191,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     meta_model.set_best_name(final_best_name)
     success, _ = meta_model.save_best(best_score + 1e-6)
     assert success
+    logger.log("Save the best model into {:}".format(final_best_name))
 
 
 def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
```
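Two idioms here are easy to misread. First, `save_best` receives a negated loss, so that a maximization-style "best score" comparison tracks the smallest loss; second, the final `save_best(best_score + 1e-6)` bumps the recorded best by a tiny epsilon to force one more write, now under the final checkpoint name set just before it. A toy illustration of the negation convention:

```python
# Toy illustration (not from the repo): keeping the max of negated losses
# is the same as keeping the min of the losses themselves.
losses = [0.90, 0.75, 0.80]
best = float("-inf")
for loss in losses:
    if -loss > best:  # mirrors success, _ = meta_model.save_best(-loss)
        best = -loss
print(-best)  # 0.75, the smallest loss seen so far
```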
In `pretrain_v1`, only the log line changes, matching the new `best` and `LS` fields of `pretrain_v2`:

```diff
@@ -243,8 +251,8 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
                 final_loss.item(),
             )
             + ", batch={:}".format(len(losses))
-            + ", success={:}, best_score={:.4f}".format(success, -best_score)
-            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+            + ", success={:}, best={:.4f}".format(success, -best_score)
+            + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh)
             + " {:}".format(left_time)
         )
         if success:
```
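Both pretraining loops print an `LS` field, which this commit changes from `last_success_epoch` to `last_success_epoch - iepoch`, displayed against `early_stop_thresh`. Together with the `success` flag from `save_best`, this is a patience-style early stop: training halts once too many epochs pass without an improvement. Below is a minimal, self-contained sketch of that pattern, with the per-epoch work faked by a coin flip and the stopping rule assumed from the variable names:

```python
import random

epochs, early_stop_thresh = 50, 5
last_success_epoch = 0
for iepoch in range(epochs):
    success = random.random() < 0.3  # stand-in for save_best reporting a new best
    if success:
        last_success_epoch = iepoch
    elif iepoch - last_success_epoch >= early_stop_thresh:
        print("early stop at epoch {:}".format(iepoch))
        break
```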
In `main`, the model structures are now printed next to their weight counts:

```diff
@@ -277,6 +285,8 @@ def main(args):
 
     logger.log("The base-model has {:} weights.".format(base_model.numel()))
     logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
+    logger.log("The base-model is\n{:}".format(base_model))
+    logger.log("The meta-model is\n{:}".format(meta_model))
 
     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge)
     train_env.reset_max_seq_length(args.seq_length)
```
Pretraining moves ahead of optimizer construction, and the optimizer is built over a parameter subset:

```diff
@@ -294,9 +304,10 @@ def main(args):
         num_workers=args.workers,
         pin_memory=True,
     )
+    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
 
     optimizer = torch.optim.Adam(
-        meta_model.parameters(),
+        meta_model.get_parameters(True, True, False),  # fix hypernet
         lr=args.lr,
         weight_decay=args.weight_decay,
         amsgrad=True,
```
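The optimizer no longer sees `meta_model.parameters()`: `get_parameters(True, True, False)` hands it a subset, and the inline comment says the excluded group is the hypernet, which stays frozen at this stage. The selector's real signature and grouping are not shown in this diff; below is a plausible sketch of a boolean-gated parameter selector, with made-up module names and flag order.

```python
import itertools

import torch
import torch.nn as nn


class TinyMetaModel(nn.Module):
    """Hypothetical stand-in with three parameter groups."""

    def __init__(self):
        super().__init__()
        self.meta_embed = nn.Embedding(10, 8)  # per-timestamp memory
        self.encoder = nn.Linear(8, 8)         # time-embedding encoder
        self.hypernet = nn.Linear(8, 16)       # weight generator

    def get_parameters(self, with_embed, with_encoder, with_hypernet):
        groups = []
        if with_embed:
            groups.append(self.meta_embed.parameters())
        if with_encoder:
            groups.append(self.encoder.parameters())
        if with_hypernet:
            groups.append(self.hypernet.parameters())
        return list(itertools.chain(*groups))


model = TinyMetaModel()
# Mirrors the hunk: everything but the hypernet is trainable here.
optimizer = torch.optim.Adam(model.get_parameters(True, True, False), lr=1e-3)
print(sum(p.numel() for p in optimizer.param_groups[0]["params"]))
```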
The now-redundant printouts and the old `pretrain_v2` call site are removed further down:

```diff
@@ -306,14 +317,10 @@ def main(args):
         milestones=[1, 2, 3, 4, 5],
         gamma=0.2,
     )
-    logger.log("The base-model is\n{:}".format(base_model))
-    logger.log("The meta-model is\n{:}".format(meta_model))
     logger.log("The optimizer is\n{:}".format(optimizer))
     logger.log("The scheduler is\n{:}".format(lr_scheduler))
     logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
 
-    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
-
     if logger.path("model").exists():
         ckp_data = torch.load(logger.path("model"))
         base_model.load_state_dict(ckp_data["base_model"])
```
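Taken together, the last two hunks reorder `main`: `pretrain_v2` now runs right after the data loader is built and before the optimizer exists, so Adam is constructed over the already-pretrained meta-model with its hypernet excluded. The diff is cut off inside the resume branch; a condensed sketch of that step is below, where the `"meta_model"` key is an assumption made to parallel the visible `"base_model"` line.

```python
import torch


def maybe_resume(ckp_path, base_model, meta_model):
    """Reload both models from one checkpoint file if it exists.

    `ckp_path` is a pathlib.Path, mirroring logger.path("model") in the hunk.
    """
    if ckp_path.exists():
        ckp_data = torch.load(ckp_path)
        base_model.load_state_dict(ckp_data["base_model"])
        meta_model.load_state_dict(ckp_data["meta_model"])  # assumed key
        return True
    return False
```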