diff --git a/exps/LFNA/lfna-test-hpnet.py b/exps/LFNA/backup/lfna-test-hpnet.py
similarity index 100%
rename from exps/LFNA/lfna-test-hpnet.py
rename to exps/LFNA/backup/lfna-test-hpnet.py
diff --git a/exps/LFNA/basic-same.py b/exps/LFNA/basic-same.py
index b565f4d..d7dc9b2 100644
--- a/exps/LFNA/basic-same.py
+++ b/exps/LFNA/basic-same.py
@@ -1,8 +1,8 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16
-# python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim
+# python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
+# python exps/LFNA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -58,7 +58,6 @@ def main(args):
     # build model
     model = get_model(**model_kwargs)
     print(model)
-    model.analyze_weights()
     # build optimizer
     optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
     criterion = torch.nn.MSELoss()
@@ -85,6 +84,7 @@ def main(args):
             best_loss = loss.item()
             best_param = copy.deepcopy(model.state_dict())
     model.load_state_dict(best_param)
+    model.analyze_weights()
     with torch.no_grad():
         train_metric(preds, historical_y)
     train_results = train_metric.get_info()
diff --git a/exps/LFNA/lfna-tall-hpnet.py b/exps/LFNA/lfna-debug-hpnet.py
similarity index 56%
rename from exps/LFNA/lfna-tall-hpnet.py
rename to exps/LFNA/lfna-debug-hpnet.py
index 7d2dbee..6e3e627 100644
--- a/exps/LFNA/lfna-tall-hpnet.py
+++ b/exps/LFNA/lfna-debug-hpnet.py
@@ -1,7 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 64
+# python exps/LFNA/lfna-debug-hpnet.py --env_version v1 --hidden_dim 16 --meta_batch 64 --device cuda
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -26,7 +26,6 @@
 from xlayers import super_core, trunc_normal_
 from lfna_utils import lfna_setup, train_model, TimeData
 
-# from lfna_models import HyperNet_VX as HyperNet
 from lfna_models import HyperNet
 
 
@@ -36,19 +35,31 @@ def main(args):
     model = get_model(**model_kwargs)
     criterion = torch.nn.MSELoss()
 
-    logger.log("There are {:} weights.".format(model.numel()))
-
     shape_container = model.get_w_container().to_shape_container()
-    hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim)
-    total_bar = env_info["total"] - 1
-    task_embeds = []
-    for i in range(env_info["total"]):
-        task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim)))
-    for task_embed in task_embeds:
-        trunc_normal_(task_embed, std=0.02)
+    hypernet = HyperNet(
+        shape_container, args.hidden_dim, args.task_dim, len(dynamic_env)
+    )
+    hypernet = hypernet.to(args.device)
 
-    parameters = list(hypernet.parameters()) + task_embeds
-    optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
+    logger.log(
+        "{:} There are {:} weights in the base-model.".format(
+            time_string(), model.numel()
+        )
+    )
+    logger.log(
+        "{:} There are {:} weights in the meta-model.".format(
+            time_string(), hypernet.numel()
+        )
+    )
+
+    for i in range(len(dynamic_env)):
+        env_info["{:}-x".format(i)] = env_info["{:}-x".format(i)].to(args.device)
env_info["{:}-x".format(i)] = env_info["{:}-x".format(i)].to(args.device) + env_info["{:}-y".format(i)] = env_info["{:}-y".format(i)].to(args.device) + logger.log("{:} Convert to device-{:} done".format(time_string(), args.device)) + + optimizer = torch.optim.Adam( + hypernet.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True + ) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=[ @@ -59,8 +70,8 @@ def main(args): ) # LFNA meta-training - loss_meter = AverageMeter() per_epoch_time, start_time = AverageMeter(), time.time() + last_success_epoch = 0 for iepoch in range(args.epochs): need_time = "Time Left: {:}".format( @@ -70,65 +81,65 @@ def main(args): "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) + need_time ) + # One Epoch + loss_meter = AverageMeter() + for istep in range(args.per_epoch_step): + losses = [] + for ibatch in range(args.meta_batch): + cur_time = random.randint(0, len(dynamic_env) - 1) + cur_container = hypernet(cur_time) + cur_x = env_info["{:}-x".format(cur_time)] + cur_y = env_info["{:}-y".format(cur_time)] + cur_dataset = TimeData(cur_time, cur_x, cur_y) - limit_bar = float(iepoch + 1) / args.epochs * total_bar - limit_bar = min(max(32, int(limit_bar)), total_bar) - losses = [] - for ibatch in range(args.meta_batch): - cur_time = random.randint(0, limit_bar) - cur_task_embed = task_embeds[cur_time] - cur_container = hypernet(cur_task_embed) - cur_x = env_info["{:}-x".format(cur_time)] - cur_y = env_info["{:}-y".format(cur_time)] - cur_dataset = TimeData(cur_time, cur_x, cur_y) + preds = model.forward_with_container(cur_dataset.x, cur_container) + optimizer.zero_grad() + loss = criterion(preds, cur_dataset.y) - preds = model.forward_with_container(cur_dataset.x, cur_container) - optimizer.zero_grad() - loss = criterion(preds, cur_dataset.y) - - losses.append(loss) - - final_loss = torch.stack(losses).mean() - final_loss.backward() - torch.nn.utils.clip_grad_norm_(parameters, 1.0) - optimizer.step() - lr_scheduler.step() - - loss_meter.update(final_loss.item()) - if iepoch % 200 == 0: - logger.log( - head_str - + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}, limit={:}".format( - loss_meter.avg, - loss_meter.val, - min(lr_scheduler.get_last_lr()), - len(losses), - limit_bar, - ) + losses.append(loss) + final_loss = torch.stack(losses).mean() + final_loss.backward() + optimizer.step() + lr_scheduler.step() + loss_meter.update(final_loss.item()) + success, best_score = hypernet.save_best(-loss_meter.avg) + if success: + logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) + last_success_epoch = iepoch + if iepoch - last_success_epoch >= args.early_stop_thresh: + logger.log("Early stop at {:}".format(iepoch)) + break + logger.log( + head_str + + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format( + loss_meter.avg, + loss_meter.val, + min(lr_scheduler.get_last_lr()), + len(losses), ) + ) - save_checkpoint( - { - "hypernet": hypernet.state_dict(), - "task_embeds": task_embeds, - "lr_scheduler": lr_scheduler.state_dict(), - "iepoch": iepoch, - }, - logger.path("model"), - logger, - ) - loss_meter.reset() + save_checkpoint( + { + "hypernet": hypernet.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "iepoch": iepoch, + }, + logger.path("model"), + logger, + ) per_epoch_time.update(time.time() - start_time) start_time = time.time() print(model) print(hypernet) + hypernet.load_best() + w_container_per_epoch = dict() for idx in range(0, env_info["total"]): - future_time = 
env_info["{:}-timestamp".format(idx)] future_x = env_info["{:}-x".format(idx)] future_y = env_info["{:}-y".format(idx)] - future_container = hypernet(task_embeds[idx]) + future_container = hypernet(idx) w_container_per_epoch[idx] = future_container.no_grad_clone() with torch.no_grad(): future_y_hat = model.forward_with_container( @@ -152,7 +163,7 @@ if __name__ == "__main__": parser.add_argument( "--save_dir", type=str, - default="./outputs/lfna-synthetic/lfna-tall-hpnet", + default="./outputs/lfna-synthetic/lfna-debug-hpnet", help="The checkpoint directory.", ) parser.add_argument( @@ -171,7 +182,7 @@ if __name__ == "__main__": parser.add_argument( "--init_lr", type=float, - default=0.1, + default=0.01, help="The initial learning rate for the optimizer (default is Adam)", ) parser.add_argument( @@ -180,12 +191,30 @@ if __name__ == "__main__": default=64, help="The batch size for the meta-model", ) + parser.add_argument( + "--early_stop_thresh", + type=int, + default=100, + help="The maximum epochs for early stop.", + ) parser.add_argument( "--epochs", type=int, default=2000, help="The total number of epochs.", ) + parser.add_argument( + "--per_epoch_step", + type=int, + default=20, + help="The total number of epochs.", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="", + ) # Random Seed parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") args = parser.parse_args() diff --git a/exps/LFNA/lfna_models_v2.py b/exps/LFNA/lfna_models_v2.py index 8cdbe97..ad1f91f 100644 --- a/exps/LFNA/lfna_models_v2.py +++ b/exps/LFNA/lfna_models_v2.py @@ -39,10 +39,10 @@ class HyperNet(super_core.SuperModule): config=dict(model_type="dual_norm_mlp"), input_dim=layer_embeding + task_embedding, output_dim=max(self._numel_per_layer), - hidden_dims=[layer_embeding * 4] * 3, + hidden_dims=[(layer_embeding + task_embedding) * 2] * 3, act_cls="gelu", norm_cls="layer_norm_1d", - dropout=0.1, + dropout=0.2, ) import pdb