From 9057011781a2b653aa337a1d606fbf8e5840452f Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Wed, 26 May 2021 02:41:36 +0000 Subject: [PATCH] Update codes --- exps/GMOA/lfna-debug-hpnet.py | 228 ----------- exps/GMOA/lfna-debug.py | 476 ----------------------- exps/{GMOA => GeMOSA}/basic-his.py | 0 exps/{GMOA => GeMOSA}/basic-maml.py | 0 exps/{GMOA => GeMOSA}/basic-prev.py | 0 exps/{GMOA => GeMOSA}/basic-same.py | 4 +- exps/{GMOA => GeMOSA}/lfna_meta_model.py | 11 +- exps/{GMOA => GeMOSA}/lfna_models.py | 0 exps/{GMOA => GeMOSA}/lfna_utils.py | 0 exps/{GMOA/lfna.py => GeMOSA/main.py} | 123 ++---- exps/{GMOA => GeMOSA}/vis-synthetic.py | 0 11 files changed, 42 insertions(+), 800 deletions(-) delete mode 100644 exps/GMOA/lfna-debug-hpnet.py delete mode 100644 exps/GMOA/lfna-debug.py rename exps/{GMOA => GeMOSA}/basic-his.py (100%) rename exps/{GMOA => GeMOSA}/basic-maml.py (100%) rename exps/{GMOA => GeMOSA}/basic-prev.py (100%) rename exps/{GMOA => GeMOSA}/basic-same.py (97%) rename exps/{GMOA => GeMOSA}/lfna_meta_model.py (98%) rename exps/{GMOA => GeMOSA}/lfna_models.py (100%) rename exps/{GMOA => GeMOSA}/lfna_utils.py (100%) rename exps/{GMOA/lfna.py => GeMOSA/main.py} (71%) rename exps/{GMOA => GeMOSA}/vis-synthetic.py (100%) diff --git a/exps/GMOA/lfna-debug-hpnet.py b/exps/GMOA/lfna-debug-hpnet.py deleted file mode 100644 index 6e3e627..0000000 --- a/exps/GMOA/lfna-debug-hpnet.py +++ /dev/null @@ -1,228 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # -##################################################### -# python exps/LFNA/lfna-debug-hpnet.py --env_version v1 --hidden_dim 16 --meta_batch 64 --device cuda -##################################################### -import sys, time, copy, torch, random, argparse -from tqdm import tqdm -from copy import deepcopy -from pathlib import Path - -lib_dir = (Path(__file__).parent / ".." / ".." 
/ "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint -from log_utils import time_string -from log_utils import AverageMeter, convert_secs2time - -from utils import split_str2indexes - -from procedures.advanced_main import basic_train_fn, basic_eval_fn -from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric -from datasets.synthetic_core import get_synthetic_env -from models.xcore import get_model -from xlayers import super_core, trunc_normal_ - - -from lfna_utils import lfna_setup, train_model, TimeData - -from lfna_models import HyperNet - - -def main(args): - logger, env_info, model_kwargs = lfna_setup(args) - dynamic_env = env_info["dynamic_env"] - model = get_model(**model_kwargs) - criterion = torch.nn.MSELoss() - - shape_container = model.get_w_container().to_shape_container() - hypernet = HyperNet( - shape_container, args.hidden_dim, args.task_dim, len(dynamic_env) - ) - hypernet = hypernet.to(args.device) - - logger.log( - "{:} There are {:} weights in the base-model.".format( - time_string(), model.numel() - ) - ) - logger.log( - "{:} There are {:} weights in the meta-model.".format( - time_string(), hypernet.numel() - ) - ) - - for i in range(len(dynamic_env)): - env_info["{:}-x".format(i)] = env_info["{:}-x".format(i)].to(args.device) - env_info["{:}-y".format(i)] = env_info["{:}-y".format(i)].to(args.device) - logger.log("{:} Convert to device-{:} done".format(time_string(), args.device)) - - optimizer = torch.optim.Adam( - hypernet.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True - ) - lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, - milestones=[ - int(args.epochs * 0.8), - int(args.epochs * 0.9), - ], - gamma=0.1, - ) - - # LFNA meta-training - per_epoch_time, start_time = AverageMeter(), time.time() - last_success_epoch = 0 - for iepoch in range(args.epochs): - - need_time = "Time Left: {:}".format( - convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) - ) - head_str = ( - "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) - + need_time - ) - # One Epoch - loss_meter = AverageMeter() - for istep in range(args.per_epoch_step): - losses = [] - for ibatch in range(args.meta_batch): - cur_time = random.randint(0, len(dynamic_env) - 1) - cur_container = hypernet(cur_time) - cur_x = env_info["{:}-x".format(cur_time)] - cur_y = env_info["{:}-y".format(cur_time)] - cur_dataset = TimeData(cur_time, cur_x, cur_y) - - preds = model.forward_with_container(cur_dataset.x, cur_container) - optimizer.zero_grad() - loss = criterion(preds, cur_dataset.y) - - losses.append(loss) - final_loss = torch.stack(losses).mean() - final_loss.backward() - optimizer.step() - lr_scheduler.step() - loss_meter.update(final_loss.item()) - success, best_score = hypernet.save_best(-loss_meter.avg) - if success: - logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) - last_success_epoch = iepoch - if iepoch - last_success_epoch >= args.early_stop_thresh: - logger.log("Early stop at {:}".format(iepoch)) - break - logger.log( - head_str - + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format( - loss_meter.avg, - loss_meter.val, - min(lr_scheduler.get_last_lr()), - len(losses), - ) - ) - - save_checkpoint( - { - "hypernet": hypernet.state_dict(), - "lr_scheduler": lr_scheduler.state_dict(), - "iepoch": iepoch, - }, - logger.path("model"), - logger, - ) - 
per_epoch_time.update(time.time() - start_time) - start_time = time.time() - - print(model) - print(hypernet) - hypernet.load_best() - - w_container_per_epoch = dict() - for idx in range(0, env_info["total"]): - future_x = env_info["{:}-x".format(idx)] - future_y = env_info["{:}-y".format(idx)] - future_container = hypernet(idx) - w_container_per_epoch[idx] = future_container.no_grad_clone() - with torch.no_grad(): - future_y_hat = model.forward_with_container( - future_x, w_container_per_epoch[idx] - ) - future_loss = criterion(future_y_hat, future_y) - logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())) - - save_checkpoint( - {"w_container_per_epoch": w_container_per_epoch}, - logger.path(None) / "final-ckp.pth", - logger, - ) - - logger.log("-" * 200 + "\n") - logger.close() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser("Use the data in the past.") - parser.add_argument( - "--save_dir", - type=str, - default="./outputs/lfna-synthetic/lfna-debug-hpnet", - help="The checkpoint directory.", - ) - parser.add_argument( - "--env_version", - type=str, - required=True, - help="The synthetic enviornment version.", - ) - parser.add_argument( - "--hidden_dim", - type=int, - required=True, - help="The hidden dimension.", - ) - ##### - parser.add_argument( - "--init_lr", - type=float, - default=0.01, - help="The initial learning rate for the optimizer (default is Adam)", - ) - parser.add_argument( - "--meta_batch", - type=int, - default=64, - help="The batch size for the meta-model", - ) - parser.add_argument( - "--early_stop_thresh", - type=int, - default=100, - help="The maximum epochs for early stop.", - ) - parser.add_argument( - "--epochs", - type=int, - default=2000, - help="The total number of epochs.", - ) - parser.add_argument( - "--per_epoch_step", - type=int, - default=20, - help="The total number of epochs.", - ) - parser.add_argument( - "--device", - type=str, - default="cpu", - help="", - ) - # Random Seed - parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, "The save dir argument can not be None" - args.task_dim = args.hidden_dim - args.save_dir = "{:}-{:}-d{:}".format( - args.save_dir, args.env_version, args.hidden_dim - ) - main(args) diff --git a/exps/GMOA/lfna-debug.py b/exps/GMOA/lfna-debug.py deleted file mode 100644 index 0802e7e..0000000 --- a/exps/GMOA/lfna-debug.py +++ /dev/null @@ -1,476 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # -##################################################### -# python exps/LFNA/lfna-debug.py --env_version v1 --workers 0 -# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001 -# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 -##################################################### -import sys, time, copy, torch, random, argparse -from tqdm import tqdm -from copy import deepcopy -from pathlib import Path - -lib_dir = (Path(__file__).parent / ".." 
/ "..").resolve() -print("LIB-DIR: {:}".format(lib_dir)) -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) - -from xautodl.procedures import ( - prepare_seed, - prepare_logger, - save_checkpoint, - copy_checkpoint, -) -from xautodl.log_utils import time_string -from xautodl.log_utils import AverageMeter, convert_secs2time - -from xautodl.utils import split_str2indexes - -from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn -from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric -from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler -from xautodl.models.xcore import get_model -from xautodl.xlayers import super_core, trunc_normal_ - -from lfna_utils import lfna_setup, train_model, TimeData -from lfna_meta_model import LFNA_Meta - - -def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger): - base_model.train() - meta_model.train() - loss_meter = AverageMeter() - for ibatch, batch_data in enumerate(loader): - timestamps, (batch_seq_inputs, batch_seq_targets) = batch_data - timestamps = timestamps.squeeze(dim=-1).to(device) - batch_seq_inputs = batch_seq_inputs.to(device) - batch_seq_targets = batch_seq_targets.to(device) - - optimizer.zero_grad() - - batch_seq_containers = meta_model(timestamps) - losses = [] - for seq_containers, seq_inputs, seq_targets in zip( - batch_seq_containers, batch_seq_inputs, batch_seq_targets - ): - for container, inputs, targets in zip( - seq_containers, seq_inputs, seq_targets - ): - predictions = base_model.forward_with_container(inputs, container) - loss = criterion(predictions, targets) - losses.append(loss) - final_loss = torch.stack(losses).mean() - final_loss.backward() - optimizer.step() - loss_meter.update(final_loss.item()) - return loss_meter - - -def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): - with torch.no_grad(): - base_model.eval() - meta_model.eval() - loss_meter = AverageMeter() - for ibatch, batch_data in enumerate(loader): - timestamps, (batch_seq_inputs, batch_seq_targets) = batch_data - timestamps = timestamps.squeeze(dim=-1).to(device) - batch_seq_inputs = batch_seq_inputs.to(device) - batch_seq_targets = batch_seq_targets.to(device) - - batch_seq_containers = meta_model(timestamps) - losses = [] - for seq_containers, seq_inputs, seq_targets in zip( - batch_seq_containers, batch_seq_inputs, batch_seq_targets - ): - for container, inputs, targets in zip( - seq_containers, seq_inputs, seq_targets - ): - predictions = base_model.forward_with_container(inputs, container) - loss = criterion(predictions, targets) - losses.append(loss) - final_loss = torch.stack(losses).mean() - loss_meter.update(final_loss.item()) - return loss_meter - - -def pretrain(base_model, meta_model, criterion, xenv, args, logger): - base_model.train() - meta_model.train() - - optimizer = torch.optim.Adam( - meta_model.parameters(), - lr=args.lr, - weight_decay=args.weight_decay, - amsgrad=True, - ) - logger.log("Pre-train the meta-model") - logger.log("Using the optimizer: {:}".format(optimizer)) - - meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain") - rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) - for iepoch in range(args.epochs): - left_time = "Time Left: {:}".format( - convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) - ) - losses = [] - for ibatch in range(args.meta_batch): - timestamps = meta_model.meta_timestamps[ - rand_index : rand_index + 
xenv.seq_length - ] - seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) - time_embeds = meta_model.super_meta_embed[ - rand_index : rand_index + xenv.seq_length - ] - [seq_containers], time_embeds = meta_model( - None, torch.unsqueeze(time_embeds, dim=0) - ) - seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to( - args.device - ) - for container, inputs, targets in zip( - seq_containers, seq_inputs, seq_targets - ): - predictions = base_model.forward_with_container(inputs, container) - loss = criterion(predictions, targets) - losses.append(loss) - final_loss = torch.stack(losses).mean() - final_loss.backward() - optimizer.step() - # success - success, best_score = meta_model.save_best(-final_loss.item()) - logger.log( - "{:} [{:04d}/{:}] loss : {:.5f}".format( - time_string(), - iepoch, - args.epochs, - final_loss.item(), - ) - + ", batch={:}".format(len(losses)) - + ", success={:}, best_score={:.4f}".format(success, -best_score) - + " {:}".format(left_time) - ) - per_epoch_time.update(time.time() - start_time) - start_time = time.time() - - -def main(args): - logger, env_info, model_kwargs = lfna_setup(args) - train_env = get_synthetic_env(mode="train", version=args.env_version) - valid_env = get_synthetic_env(mode="valid", version=args.env_version) - logger.log("training enviornment: {:}".format(train_env)) - logger.log("validation enviornment: {:}".format(valid_env)) - - base_model = get_model(**model_kwargs) - base_model = base_model.to(args.device) - criterion = torch.nn.MSELoss() - - shape_container = base_model.get_w_container().to_shape_container() - - # pre-train the hypernetwork - timestamps = train_env.get_timestamp(None) - meta_model = LFNA_Meta(shape_container, args.layer_dim, args.time_dim, timestamps) - meta_model = meta_model.to(args.device) - - logger.log("The base-model has {:} weights.".format(base_model.numel())) - logger.log("The meta-model has {:} weights.".format(meta_model.numel())) - - batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) - train_env.reset_max_seq_length(args.seq_length) - valid_env.reset_max_seq_length(args.seq_length) - valid_env_loader = torch.utils.data.DataLoader( - valid_env, - batch_size=args.meta_batch, - shuffle=True, - num_workers=args.workers, - pin_memory=True, - ) - train_env_loader = torch.utils.data.DataLoader( - train_env, - batch_sampler=batch_sampler, - num_workers=args.workers, - pin_memory=True, - ) - - optimizer = torch.optim.Adam( - meta_model.parameters(), - lr=args.lr, - weight_decay=args.weight_decay, - amsgrad=True, - ) - lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, - milestones=[1, 2, 3, 4, 5], - gamma=0.2, - ) - logger.log("The base-model is\n{:}".format(base_model)) - logger.log("The meta-model is\n{:}".format(meta_model)) - logger.log("The optimizer is\n{:}".format(optimizer)) - logger.log("The scheduler is\n{:}".format(lr_scheduler)) - logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) - - pretrain(base_model, meta_model, criterion, train_env, args, logger) - - if logger.path("model").exists(): - ckp_data = torch.load(logger.path("model")) - base_model.load_state_dict(ckp_data["base_model"]) - meta_model.load_state_dict(ckp_data["meta_model"]) - optimizer.load_state_dict(ckp_data["optimizer"]) - lr_scheduler.load_state_dict(ckp_data["lr_scheduler"]) - last_success_epoch = ckp_data["last_success_epoch"] - start_epoch = ckp_data["iepoch"] + 1 - check_strs = [ - "epochs", - "env_version", - "hidden_dim", - "lr", - 
"layer_dim", - "time_dim", - "seq_length", - ] - for xstr in check_strs: - cx = getattr(args, xstr) - px = getattr(ckp_data["args"], xstr) - assert cx == px, "[{:}] {:} vs {:}".format(xstr, cx, ps) - success, _ = meta_model.save_best(ckp_data["cur_score"]) - logger.log("Load ckp from {:}".format(logger.path("model"))) - if success: - logger.log( - "Re-save the best model with score={:}".format(ckp_data["cur_score"]) - ) - else: - start_epoch, last_success_epoch = 0, 0 - - # LFNA meta-train - meta_model.set_best_dir(logger.path(None) / "checkpoint") - per_epoch_time, start_time = AverageMeter(), time.time() - for iepoch in range(start_epoch, args.epochs): - - head_str = "[{:}] [{:04d}/{:04d}] ".format( - time_string(), iepoch, args.epochs - ) + "Time Left: {:}".format( - convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) - ) - - loss_meter = epoch_train( - train_env_loader, - meta_model, - base_model, - optimizer, - criterion, - args.device, - logger, - ) - - valid_loss_meter = epoch_evaluate( - valid_env_loader, meta_model, base_model, criterion, args.device, logger - ) - logger.log( - head_str - + " meta-train-loss: {meter.avg:.4f} ({meter.count:.0f})".format( - meter=loss_meter - ) - + " meta-valid-loss: {meter.val:.4f}".format(meter=valid_loss_meter) - + " :: lr={:.5f}".format(min(lr_scheduler.get_last_lr())) - + " :: last-success={:}".format(last_success_epoch) - ) - success, best_score = meta_model.save_best(-loss_meter.avg) - if success: - logger.log("Achieve the best with best-score = {:.5f}".format(best_score)) - last_success_epoch = iepoch - save_checkpoint( - { - "meta_model": meta_model.state_dict(), - "base_model": base_model.state_dict(), - "optimizer": optimizer.state_dict(), - "lr_scheduler": lr_scheduler.state_dict(), - "last_success_epoch": last_success_epoch, - "cur_score": -loss_meter.avg, - "iepoch": iepoch, - "args": args, - }, - logger.path("model"), - logger, - ) - if iepoch - last_success_epoch >= args.early_stop_thresh: - if lr_scheduler.last_epoch > 4: - logger.log("Early stop at {:}".format(iepoch)) - break - else: - last_success_epoch = iepoch - lr_scheduler.step() - logger.log("Decay the lr [{:}]".format(lr_scheduler.last_epoch)) - - per_epoch_time.update(time.time() - start_time) - start_time = time.time() - - # meta-test - meta_model.load_best() - eval_env = env_info["dynamic_env"] - w_container_per_epoch = dict() - for idx in range(args.seq_length, len(eval_env)): - # build-timestamp - future_time = env_info["{:}-timestamp".format(idx)].item() - time_seqs = [] - for iseq in range(args.seq_length): - time_seqs.append(future_time - iseq * eval_env.timestamp_interval) - time_seqs.reverse() - with torch.no_grad(): - meta_model.eval() - base_model.eval() - time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device) - [seq_containers] = meta_model(time_seqs) - future_container = seq_containers[-1] - w_container_per_epoch[idx] = future_container.no_grad_clone() - # evaluation - future_x = env_info["{:}-x".format(idx)].to(args.device) - future_y = env_info["{:}-y".format(idx)].to(args.device) - future_y_hat = base_model.forward_with_container( - future_x, w_container_per_epoch[idx] - ) - future_loss = criterion(future_y_hat, future_y) - logger.log( - "meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()) - ) - - # creating the new meta-time-embedding - distance = meta_model.get_closest_meta_distance(future_time) - if distance < eval_env.timestamp_interval: - continue - # - new_param = meta_model.create_meta_embed() - optimizer = 
torch.optim.Adam( - [new_param], lr=args.refine_lr, weight_decay=1e-5, amsgrad=True - ) - meta_model.replace_append_learnt( - torch.Tensor([future_time]).to(args.device), new_param - ) - meta_model.eval() - base_model.train() - for iepoch in range(args.refine_epochs): - optimizer.zero_grad() - [seq_containers] = meta_model(time_seqs) - future_container = seq_containers[-1] - future_y_hat = base_model.forward_with_container(future_x, future_container) - future_loss = criterion(future_y_hat, future_y) - future_loss.backward() - optimizer.step() - logger.log( - "post-meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()) - ) - with torch.no_grad(): - meta_model.replace_append_learnt(None, None) - meta_model.append_fixed(torch.Tensor([future_time]), new_param) - - save_checkpoint( - {"w_container_per_epoch": w_container_per_epoch}, - logger.path(None) / "final-ckp.pth", - logger, - ) - - logger.log("-" * 200 + "\n") - logger.close() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(".") - parser.add_argument( - "--save_dir", - type=str, - default="./outputs/lfna-synthetic/lfna-battle", - help="The checkpoint directory.", - ) - parser.add_argument( - "--env_version", - type=str, - required=True, - help="The synthetic enviornment version.", - ) - parser.add_argument( - "--hidden_dim", - type=int, - default=16, - help="The hidden dimension.", - ) - parser.add_argument( - "--layer_dim", - type=int, - default=16, - help="The layer chunk dimension.", - ) - parser.add_argument( - "--time_dim", - type=int, - default=16, - help="The timestamp dimension.", - ) - ##### - parser.add_argument( - "--lr", - type=float, - default=0.002, - help="The initial learning rate for the optimizer (default is Adam)", - ) - parser.add_argument( - "--weight_decay", - type=float, - default=0.00001, - help="The weight decay for the optimizer (default is Adam)", - ) - parser.add_argument( - "--meta_batch", - type=int, - default=64, - help="The batch size for the meta-model", - ) - parser.add_argument( - "--sampler_enlarge", - type=int, - default=5, - help="Enlarge the #iterations for an epoch", - ) - parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.") - parser.add_argument( - "--refine_lr", - type=float, - default=0.005, - help="The learning rate for the optimizer, during refine", - ) - parser.add_argument( - "--refine_epochs", type=int, default=1000, help="The final refine #epochs." - ) - parser.add_argument( - "--early_stop_thresh", - type=int, - default=20, - help="The #epochs for early stop.", - ) - parser.add_argument( - "--seq_length", type=int, default=10, help="The sequence length." - ) - parser.add_argument( - "--workers", type=int, default=4, help="The number of workers in parallel." 
- ) - parser.add_argument( - "--device", - type=str, - default="cpu", - help="", - ) - # Random Seed - parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, "The save dir argument can not be None" - args.save_dir = "{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( - args.save_dir, - args.hidden_dim, - args.layer_dim, - args.time_dim, - args.seq_length, - args.lr, - args.weight_decay, - args.epochs, - args.env_version, - ) - main(args) diff --git a/exps/GMOA/basic-his.py b/exps/GeMOSA/basic-his.py similarity index 100% rename from exps/GMOA/basic-his.py rename to exps/GeMOSA/basic-his.py diff --git a/exps/GMOA/basic-maml.py b/exps/GeMOSA/basic-maml.py similarity index 100% rename from exps/GMOA/basic-maml.py rename to exps/GeMOSA/basic-maml.py diff --git a/exps/GMOA/basic-prev.py b/exps/GeMOSA/basic-prev.py similarity index 100% rename from exps/GMOA/basic-prev.py rename to exps/GeMOSA/basic-prev.py diff --git a/exps/GMOA/basic-same.py b/exps/GeMOSA/basic-same.py similarity index 97% rename from exps/GMOA/basic-same.py rename to exps/GeMOSA/basic-same.py index 1ca25a6..0d06a9d 100644 --- a/exps/GMOA/basic-same.py +++ b/exps/GeMOSA/basic-same.py @@ -1,8 +1,8 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/LFNA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 -# python exps/LFNA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 +# python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 +# python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm diff --git a/exps/GMOA/lfna_meta_model.py b/exps/GeMOSA/lfna_meta_model.py similarity index 98% rename from exps/GMOA/lfna_meta_model.py rename to exps/GeMOSA/lfna_meta_model.py index dc61b47..2ef4286 100644 --- a/exps/GMOA/lfna_meta_model.py +++ b/exps/GeMOSA/lfna_meta_model.py @@ -181,21 +181,20 @@ class MetaModelV1(super_core.SuperModule): timestamp_v_embed, mask, ) - return timestamp_embeds + return timestamp_embeds[:, -1, :] def forward_raw(self, timestamps, time_embeds, tembed_only=False): if time_embeds is None: time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) B, S = time_seq.shape time_embeds = self._obtain_time_embed(time_seq) - else: + else: # use the hyper-net only time_seq = None - B, S, _ = time_embeds.shape - # create joint embed - num_layer, _ = self._super_layer_embed.shape - time_embeds = time_embeds[:, -1, :] + B, _ = time_embeds.shape if tembed_only: return time_embeds + # create joint embed + num_layer, _ = self._super_layer_embed.shape # The shape of `joint_embed` is batch * num-layers * input-dim joint_embeds = torch.cat( ( diff --git a/exps/GMOA/lfna_models.py b/exps/GeMOSA/lfna_models.py similarity index 100% rename from exps/GMOA/lfna_models.py rename to exps/GeMOSA/lfna_models.py diff --git a/exps/GMOA/lfna_utils.py b/exps/GeMOSA/lfna_utils.py similarity index 100% rename from exps/GMOA/lfna_utils.py rename to exps/GeMOSA/lfna_utils.py diff --git a/exps/GMOA/lfna.py b/exps/GeMOSA/main.py similarity index 71% rename from exps/GMOA/lfna.py rename to 
exps/GeMOSA/main.py index 004896b..a78f8fd 100644 --- a/exps/GMOA/lfna.py +++ b/exps/GeMOSA/main.py @@ -1,10 +1,10 @@ ##################################################### # Learning to Generate Model One Step Ahead # ##################################################### -# python exps/GMOA/lfna.py --env_version v1 --workers 0 -# python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.001 -# python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 -# python exps/GMOA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 +# python exps/GeMOSA/lfna.py --env_version v1 --workers 0 +# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001 +# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 +# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -38,63 +38,6 @@ from lfna_utils import lfna_setup, train_model, TimeData from lfna_meta_model import MetaModelV1 -def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger): - base_model.train() - meta_model.train() - loss_meter = AverageMeter() - for ibatch, batch_data in enumerate(loader): - timestamps, (batch_seq_inputs, batch_seq_targets) = batch_data - timestamps = timestamps.squeeze(dim=-1).to(device) - batch_seq_inputs = batch_seq_inputs.to(device) - batch_seq_targets = batch_seq_targets.to(device) - - optimizer.zero_grad() - - batch_seq_containers = meta_model(timestamps) - losses = [] - for seq_containers, seq_inputs, seq_targets in zip( - batch_seq_containers, batch_seq_inputs, batch_seq_targets - ): - for container, inputs, targets in zip( - seq_containers, seq_inputs, seq_targets - ): - predictions = base_model.forward_with_container(inputs, container) - loss = criterion(predictions, targets) - losses.append(loss) - final_loss = torch.stack(losses).mean() - final_loss.backward() - optimizer.step() - loss_meter.update(final_loss.item()) - return loss_meter - - -def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): - with torch.no_grad(): - base_model.eval() - meta_model.eval() - loss_meter = AverageMeter() - for ibatch, batch_data in enumerate(loader): - timestamps, (batch_seq_inputs, batch_seq_targets) = batch_data - timestamps = timestamps.squeeze(dim=-1).to(device) - batch_seq_inputs = batch_seq_inputs.to(device) - batch_seq_targets = batch_seq_targets.to(device) - - batch_seq_containers = meta_model(timestamps) - losses = [] - for seq_containers, seq_inputs, seq_targets in zip( - batch_seq_containers, batch_seq_inputs, batch_seq_targets - ): - for container, inputs, targets in zip( - seq_containers, seq_inputs, seq_targets - ): - predictions = base_model.forward_with_container(inputs, container) - loss = criterion(predictions, targets) - losses.append(loss) - final_loss = torch.stack(losses).mean() - loss_meter.update(final_loss.item()) - return loss_meter - - def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): logger.log("Online evaluate: {:}".format(env)) loss_meter = AverageMeter() @@ -133,7 +76,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F return w_containers, loss_meter -def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): +def 
meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): base_model.train() meta_model.train() optimizer = torch.optim.Adam( @@ -152,6 +95,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): logger.log("Directly load the best model from {:}".format(final_best_name)) return + total_indexes = list(range(meta_model.meta_length)) meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh per_epoch_time, start_time = AverageMeter(), time.time() @@ -160,47 +104,50 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): left_time = "Time Left: {:}".format( convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) ) - total_future_losses, total_present_losses, total_regu_losses = [], [], [] optimizer.zero_grad() - for ibatch in range(args.meta_batch): - rand_index = random.randint(0, meta_model.meta_length - 1) - timestamp = meta_model.meta_timestamps[rand_index] - _, [container], time_embed = meta_model( - torch.unsqueeze(timestamp, dim=0), None, False - ) - _, (inputs, targets) = xenv(timestamp.item()) + generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True) + + batch_indexes = random.choices(total_indexes, k=args.meta_batch) + + raw_time_steps = meta_model.meta_timestamps[batch_indexes] + + regularization_loss = F.l1_loss( + generated_time_embeds, meta_model.super_meta_embed, reduction="mean" + ) + # future loss + total_future_losses, total_present_losses = [], [] + _, future_containers, _ = meta_model( + None, generated_time_embeds[batch_indexes], False + ) + _, present_containers, _ = meta_model( + None, meta_model.super_meta_embed[batch_indexes], False + ) + for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): + _, (inputs, targets) = xenv(time_step) inputs, targets = inputs.to(device), targets.to(device) - # generate models one step ahead - predictions = base_model.forward_with_container(inputs, container) - total_future_losses.append(criterion(predictions, targets)) - # randomly sample - rand_index = random.randint(0, meta_model.meta_length - 1) - timestamp = meta_model.meta_timestamps[rand_index] - meta_embed = meta_model.super_meta_embed[rand_index] - time_embed = meta_model(torch.unsqueeze(timestamp, dim=0), None, True) - total_regu_losses.append( - F.mse_loss( - torch.squeeze(time_embed, dim=0), meta_embed, reduction="mean" - ) + predictions = base_model.forward_with_container( + inputs, future_containers[ibatch] + ) + total_future_losses.append(criterion(predictions, targets)) + + predictions = base_model.forward_with_container( + inputs, present_containers[ibatch] ) - # generate models via memory - _, [container], _ = meta_model(None, meta_embed.view(1, 1, -1), False) - predictions = base_model.forward_with_container(inputs, container) total_present_losses.append(criterion(predictions, targets)) + with torch.no_grad(): meta_std = torch.stack(total_future_losses).std().item() loss_future = torch.stack(total_future_losses).mean() loss_present = torch.stack(total_present_losses).mean() - regularization_loss = torch.stack(total_regu_losses).mean() total_loss = loss_future + loss_present + regularization_loss total_loss.backward() optimizer.step() # success success, best_score = meta_model.save_best(-total_loss.item()) logger.log( - "{:} [Pre-V2 {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format( + "{:} [META {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format( 
time_string(), iepoch, args.epochs, @@ -264,7 +211,7 @@ def main(args): logger.log("The base-model is\n{:}".format(base_model)) logger.log("The meta-model is\n{:}".format(meta_model)) - pretrain_v2(base_model, meta_model, criterion, trainval_env, args, logger) + meta_train_procedure(base_model, meta_model, criterion, trainval_env, args, logger) # try to evaluate once # online_evaluate(train_env, meta_model, base_model, criterion, args, logger) diff --git a/exps/GMOA/vis-synthetic.py b/exps/GeMOSA/vis-synthetic.py similarity index 100% rename from exps/GMOA/vis-synthetic.py rename to exps/GeMOSA/vis-synthetic.py
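
Note on the restructured meta-training step in exps/GeMOSA/main.py: the per-sample loop of pretrain_v2 is replaced by meta_train_procedure, which generates time embeddings for all known timestamps in a single forward pass, samples a batch of indexes with random.choices, and optimizes the sum of a "future" loss (containers decoded from the generated embeddings), a "present" loss (containers decoded from the learnt meta embeddings), and an L1 regularization between the two embedding sets. The sketch below mirrors only that loss structure with hypothetical toy stand-ins (a linear time encoder, a linear container decoder, a synthetic data function); it is an illustration under stated assumptions, not the repository's MetaModelV1 / xautodl API.

    # Illustrative sketch only -- hypothetical toy stand-ins, not the repo's MetaModelV1.
    import random
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    meta_length, time_dim, in_dim, meta_batch = 100, 16, 2, 8
    meta_timestamps = torch.linspace(0.0, 1.0, meta_length)                     # known timestamps
    super_meta_embed = torch.randn(meta_length, time_dim, requires_grad=True)   # learnt per-timestamp embeddings
    time_encoder = torch.nn.Linear(1, time_dim)          # generates an embedding from a raw timestamp
    decoder = torch.nn.Linear(time_dim, in_dim + 1)       # embedding -> (weight, bias) of a toy base model
    optimizer = torch.optim.Adam(
        [super_meta_embed] + list(time_encoder.parameters()) + list(decoder.parameters()),
        lr=0.002, weight_decay=1e-5, amsgrad=True,
    )
    criterion = torch.nn.MSELoss()
    total_indexes = list(range(meta_length))

    def fake_env(t):
        # Synthetic data whose target depends on the timestamp t.
        x = torch.randn(32, in_dim)
        y = x.sum(dim=1, keepdim=True) * t
        return x, y

    def forward_with_container(x, container):
        # Apply the toy base model with generated weights (weight, bias).
        w, b = container[:in_dim], container[in_dim:]
        return x @ w.view(in_dim, 1) + b

    for step in range(3):
        optimizer.zero_grad()
        # 1) generate time embeddings for every known timestamp in one pass
        generated = time_encoder(meta_timestamps.view(-1, 1))
        # 2) sample a batch of timestamp indexes (with replacement, as random.choices does)
        batch_indexes = random.choices(total_indexes, k=meta_batch)
        # 3) L1 regularization pulls generated embeddings towards the learnt ones
        regularization_loss = F.l1_loss(generated, super_meta_embed, reduction="mean")
        future_losses, present_losses = [], []
        for idx in batch_indexes:
            x, y = fake_env(meta_timestamps[idx].item())
            # future loss: container decoded from the *generated* embedding
            future_losses.append(criterion(forward_with_container(x, decoder(generated[idx])), y))
            # present loss: container decoded from the *learnt* embedding
            present_losses.append(criterion(forward_with_container(x, decoder(super_meta_embed[idx])), y))
        total_loss = (
            torch.stack(future_losses).mean()
            + torch.stack(present_losses).mean()
            + regularization_loss
        )
        total_loss.backward()
        optimizer.step()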