diff --git a/exps/LFNA/lfna-debug.py b/exps/LFNA/lfna-debug.py new file mode 100644 index 0000000..0802e7e --- /dev/null +++ b/exps/LFNA/lfna-debug.py @@ -0,0 +1,476 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # +##################################################### +# python exps/LFNA/lfna-debug.py --env_version v1 --workers 0 +# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.001 +# python exps/LFNA/lfna.py --env_version v1 --device cuda --lr 0.002 +##################################################### +import sys, time, copy, torch, random, argparse +from tqdm import tqdm +from copy import deepcopy +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / "..").resolve() +print("LIB-DIR: {:}".format(lib_dir)) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) + +from xautodl.procedures import ( + prepare_seed, + prepare_logger, + save_checkpoint, + copy_checkpoint, +) +from xautodl.log_utils import time_string +from xautodl.log_utils import AverageMeter, convert_secs2time + +from xautodl.utils import split_str2indexes + +from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn +from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric +from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler +from xautodl.models.xcore import get_model +from xautodl.xlayers import super_core, trunc_normal_ + +from lfna_utils import lfna_setup, train_model, TimeData +from lfna_meta_model import LFNA_Meta + + +def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger): + base_model.train() + meta_model.train() + loss_meter = AverageMeter() + for ibatch, batch_data in enumerate(loader): + timestamps, (batch_seq_inputs, batch_seq_targets) = batch_data + timestamps = timestamps.squeeze(dim=-1).to(device) + batch_seq_inputs = batch_seq_inputs.to(device) + batch_seq_targets = batch_seq_targets.to(device) + + optimizer.zero_grad() + + batch_seq_containers = meta_model(timestamps) + losses = [] + for seq_containers, seq_inputs, seq_targets in zip( + batch_seq_containers, batch_seq_inputs, batch_seq_targets + ): + for container, inputs, targets in zip( + seq_containers, seq_inputs, seq_targets + ): + predictions = base_model.forward_with_container(inputs, container) + loss = criterion(predictions, targets) + losses.append(loss) + final_loss = torch.stack(losses).mean() + final_loss.backward() + optimizer.step() + loss_meter.update(final_loss.item()) + return loss_meter + + +def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): + with torch.no_grad(): + base_model.eval() + meta_model.eval() + loss_meter = AverageMeter() + for ibatch, batch_data in enumerate(loader): + timestamps, (batch_seq_inputs, batch_seq_targets) = batch_data + timestamps = timestamps.squeeze(dim=-1).to(device) + batch_seq_inputs = batch_seq_inputs.to(device) + batch_seq_targets = batch_seq_targets.to(device) + + batch_seq_containers = meta_model(timestamps) + losses = [] + for seq_containers, seq_inputs, seq_targets in zip( + batch_seq_containers, batch_seq_inputs, batch_seq_targets + ): + for container, inputs, targets in zip( + seq_containers, seq_inputs, seq_targets + ): + predictions = base_model.forward_with_container(inputs, container) + loss = criterion(predictions, targets) + losses.append(loss) + final_loss = torch.stack(losses).mean() + loss_meter.update(final_loss.item()) + return loss_meter + + +def pretrain(base_model, meta_model, criterion, xenv, args, logger): + base_model.train() + meta_model.train() + + optimizer = torch.optim.Adam( + meta_model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay, + amsgrad=True, + ) + logger.log("Pre-train the meta-model") + logger.log("Using the optimizer: {:}".format(optimizer)) + + meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain") + rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) + for iepoch in range(args.epochs): + left_time = "Time Left: {:}".format( + convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) + ) + losses = [] + for ibatch in range(args.meta_batch): + timestamps = meta_model.meta_timestamps[ + rand_index : rand_index + xenv.seq_length + ] + seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) + time_embeds = meta_model.super_meta_embed[ + rand_index : rand_index + xenv.seq_length + ] + [seq_containers], time_embeds = meta_model( + None, torch.unsqueeze(time_embeds, dim=0) + ) + seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to( + args.device + ) + for container, inputs, targets in zip( + seq_containers, seq_inputs, seq_targets + ): + predictions = base_model.forward_with_container(inputs, container) + loss = criterion(predictions, targets) + losses.append(loss) + final_loss = torch.stack(losses).mean() + final_loss.backward() + optimizer.step() + # success + success, best_score = meta_model.save_best(-final_loss.item()) + logger.log( + "{:} [{:04d}/{:}] loss : {:.5f}".format( + time_string(), + iepoch, + args.epochs, + final_loss.item(), + ) + + ", batch={:}".format(len(losses)) + + ", success={:}, best_score={:.4f}".format(success, -best_score) + + " {:}".format(left_time) + ) + per_epoch_time.update(time.time() - start_time) + start_time = time.time() + + +def main(args): + logger, env_info, model_kwargs = lfna_setup(args) + train_env = get_synthetic_env(mode="train", version=args.env_version) + valid_env = get_synthetic_env(mode="valid", version=args.env_version) + logger.log("training enviornment: {:}".format(train_env)) + logger.log("validation enviornment: {:}".format(valid_env)) + + base_model = get_model(**model_kwargs) + base_model = base_model.to(args.device) + criterion = torch.nn.MSELoss() + + shape_container = base_model.get_w_container().to_shape_container() + + # pre-train the hypernetwork + timestamps = train_env.get_timestamp(None) + meta_model = LFNA_Meta(shape_container, args.layer_dim, args.time_dim, timestamps) + meta_model = meta_model.to(args.device) + + logger.log("The base-model has {:} weights.".format(base_model.numel())) + logger.log("The meta-model has {:} weights.".format(meta_model.numel())) + + batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) + train_env.reset_max_seq_length(args.seq_length) + valid_env.reset_max_seq_length(args.seq_length) + valid_env_loader = torch.utils.data.DataLoader( + valid_env, + batch_size=args.meta_batch, + shuffle=True, + num_workers=args.workers, + pin_memory=True, + ) + train_env_loader = torch.utils.data.DataLoader( + train_env, + batch_sampler=batch_sampler, + num_workers=args.workers, + pin_memory=True, + ) + + optimizer = torch.optim.Adam( + meta_model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay, + amsgrad=True, + ) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[1, 2, 3, 4, 5], + gamma=0.2, + ) + logger.log("The base-model is\n{:}".format(base_model)) + logger.log("The meta-model is\n{:}".format(meta_model)) + logger.log("The optimizer is\n{:}".format(optimizer)) + logger.log("The scheduler is\n{:}".format(lr_scheduler)) + logger.log("Per epoch iterations = {:}".format(len(train_env_loader))) + + pretrain(base_model, meta_model, criterion, train_env, args, logger) + + if logger.path("model").exists(): + ckp_data = torch.load(logger.path("model")) + base_model.load_state_dict(ckp_data["base_model"]) + meta_model.load_state_dict(ckp_data["meta_model"]) + optimizer.load_state_dict(ckp_data["optimizer"]) + lr_scheduler.load_state_dict(ckp_data["lr_scheduler"]) + last_success_epoch = ckp_data["last_success_epoch"] + start_epoch = ckp_data["iepoch"] + 1 + check_strs = [ + "epochs", + "env_version", + "hidden_dim", + "lr", + "layer_dim", + "time_dim", + "seq_length", + ] + for xstr in check_strs: + cx = getattr(args, xstr) + px = getattr(ckp_data["args"], xstr) + assert cx == px, "[{:}] {:} vs {:}".format(xstr, cx, ps) + success, _ = meta_model.save_best(ckp_data["cur_score"]) + logger.log("Load ckp from {:}".format(logger.path("model"))) + if success: + logger.log( + "Re-save the best model with score={:}".format(ckp_data["cur_score"]) + ) + else: + start_epoch, last_success_epoch = 0, 0 + + # LFNA meta-train + meta_model.set_best_dir(logger.path(None) / "checkpoint") + per_epoch_time, start_time = AverageMeter(), time.time() + for iepoch in range(start_epoch, args.epochs): + + head_str = "[{:}] [{:04d}/{:04d}] ".format( + time_string(), iepoch, args.epochs + ) + "Time Left: {:}".format( + convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) + ) + + loss_meter = epoch_train( + train_env_loader, + meta_model, + base_model, + optimizer, + criterion, + args.device, + logger, + ) + + valid_loss_meter = epoch_evaluate( + valid_env_loader, meta_model, base_model, criterion, args.device, logger + ) + logger.log( + head_str + + " meta-train-loss: {meter.avg:.4f} ({meter.count:.0f})".format( + meter=loss_meter + ) + + " meta-valid-loss: {meter.val:.4f}".format(meter=valid_loss_meter) + + " :: lr={:.5f}".format(min(lr_scheduler.get_last_lr())) + + " :: last-success={:}".format(last_success_epoch) + ) + success, best_score = meta_model.save_best(-loss_meter.avg) + if success: + logger.log("Achieve the best with best-score = {:.5f}".format(best_score)) + last_success_epoch = iepoch + save_checkpoint( + { + "meta_model": meta_model.state_dict(), + "base_model": base_model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "last_success_epoch": last_success_epoch, + "cur_score": -loss_meter.avg, + "iepoch": iepoch, + "args": args, + }, + logger.path("model"), + logger, + ) + if iepoch - last_success_epoch >= args.early_stop_thresh: + if lr_scheduler.last_epoch > 4: + logger.log("Early stop at {:}".format(iepoch)) + break + else: + last_success_epoch = iepoch + lr_scheduler.step() + logger.log("Decay the lr [{:}]".format(lr_scheduler.last_epoch)) + + per_epoch_time.update(time.time() - start_time) + start_time = time.time() + + # meta-test + meta_model.load_best() + eval_env = env_info["dynamic_env"] + w_container_per_epoch = dict() + for idx in range(args.seq_length, len(eval_env)): + # build-timestamp + future_time = env_info["{:}-timestamp".format(idx)].item() + time_seqs = [] + for iseq in range(args.seq_length): + time_seqs.append(future_time - iseq * eval_env.timestamp_interval) + time_seqs.reverse() + with torch.no_grad(): + meta_model.eval() + base_model.eval() + time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device) + [seq_containers] = meta_model(time_seqs) + future_container = seq_containers[-1] + w_container_per_epoch[idx] = future_container.no_grad_clone() + # evaluation + future_x = env_info["{:}-x".format(idx)].to(args.device) + future_y = env_info["{:}-y".format(idx)].to(args.device) + future_y_hat = base_model.forward_with_container( + future_x, w_container_per_epoch[idx] + ) + future_loss = criterion(future_y_hat, future_y) + logger.log( + "meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()) + ) + + # creating the new meta-time-embedding + distance = meta_model.get_closest_meta_distance(future_time) + if distance < eval_env.timestamp_interval: + continue + # + new_param = meta_model.create_meta_embed() + optimizer = torch.optim.Adam( + [new_param], lr=args.refine_lr, weight_decay=1e-5, amsgrad=True + ) + meta_model.replace_append_learnt( + torch.Tensor([future_time]).to(args.device), new_param + ) + meta_model.eval() + base_model.train() + for iepoch in range(args.refine_epochs): + optimizer.zero_grad() + [seq_containers] = meta_model(time_seqs) + future_container = seq_containers[-1] + future_y_hat = base_model.forward_with_container(future_x, future_container) + future_loss = criterion(future_y_hat, future_y) + future_loss.backward() + optimizer.step() + logger.log( + "post-meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()) + ) + with torch.no_grad(): + meta_model.replace_append_learnt(None, None) + meta_model.append_fixed(torch.Tensor([future_time]), new_param) + + save_checkpoint( + {"w_container_per_epoch": w_container_per_epoch}, + logger.path(None) / "final-ckp.pth", + logger, + ) + + logger.log("-" * 200 + "\n") + logger.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(".") + parser.add_argument( + "--save_dir", + type=str, + default="./outputs/lfna-synthetic/lfna-battle", + help="The checkpoint directory.", + ) + parser.add_argument( + "--env_version", + type=str, + required=True, + help="The synthetic enviornment version.", + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=16, + help="The hidden dimension.", + ) + parser.add_argument( + "--layer_dim", + type=int, + default=16, + help="The layer chunk dimension.", + ) + parser.add_argument( + "--time_dim", + type=int, + default=16, + help="The timestamp dimension.", + ) + ##### + parser.add_argument( + "--lr", + type=float, + default=0.002, + help="The initial learning rate for the optimizer (default is Adam)", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.00001, + help="The weight decay for the optimizer (default is Adam)", + ) + parser.add_argument( + "--meta_batch", + type=int, + default=64, + help="The batch size for the meta-model", + ) + parser.add_argument( + "--sampler_enlarge", + type=int, + default=5, + help="Enlarge the #iterations for an epoch", + ) + parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.") + parser.add_argument( + "--refine_lr", + type=float, + default=0.005, + help="The learning rate for the optimizer, during refine", + ) + parser.add_argument( + "--refine_epochs", type=int, default=1000, help="The final refine #epochs." + ) + parser.add_argument( + "--early_stop_thresh", + type=int, + default=20, + help="The #epochs for early stop.", + ) + parser.add_argument( + "--seq_length", type=int, default=10, help="The sequence length." + ) + parser.add_argument( + "--workers", type=int, default=4, help="The number of workers in parallel." + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="", + ) + # Random Seed + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "The save dir argument can not be None" + args.save_dir = "{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format( + args.save_dir, + args.hidden_dim, + args.layer_dim, + args.time_dim, + args.seq_length, + args.lr, + args.weight_decay, + args.epochs, + args.env_version, + ) + main(args) diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index 3264c3e..00a998b 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -93,7 +93,7 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): return loss_meter -def pretrain(base_model, meta_model, criterion, xenv, args, logger): +def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): optimizer = torch.optim.Adam( meta_model.parameters(), lr=args.lr, @@ -164,6 +164,69 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger): start_time = time.time() +def pretrain(base_model, meta_model, criterion, xenv, args, logger): + base_model.train() + meta_model.train() + optimizer = torch.optim.Adam( + meta_model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay, + amsgrad=True, + ) + logger.log("Pre-train the meta-model") + logger.log("Using the optimizer: {:}".format(optimizer)) + + meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain") + meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) + per_epoch_time, start_time = AverageMeter(), time.time() + for iepoch in range(args.epochs): + left_time = "Time Left: {:}".format( + convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) + ) + losses = [] + optimizer.zero_grad() + for ibatch in range(args.meta_batch): + rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) + timestamps = meta_model.meta_timestamps[ + rand_index : rand_index + xenv.seq_length + ] + + seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps) + time_embeds = meta_model.super_meta_embed[ + rand_index : rand_index + xenv.seq_length + ] + [seq_containers], time_embeds = meta_model( + None, torch.unsqueeze(time_embeds, dim=0) + ) + seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to( + args.device + ) + for container, inputs, targets in zip( + seq_containers, seq_inputs, seq_targets + ): + predictions = base_model.forward_with_container(inputs, container) + loss = criterion(predictions, targets) + losses.append(loss) + final_loss = torch.stack(losses).mean() + final_loss.backward() + optimizer.step() + # success + success, best_score = meta_model.save_best(-final_loss.item()) + logger.log( + "{:} [{:04d}/{:}] loss : {:.5f}".format( + time_string(), + iepoch, + args.epochs, + final_loss.item(), + ) + + ", batch={:}".format(len(losses)) + + ", success={:}, best_score={:.4f}".format(success, -best_score) + + " {:}".format(left_time) + ) + per_epoch_time.update(time.time() - start_time) + start_time = time.time() + + def main(args): logger, env_info, model_kwargs = lfna_setup(args) train_env = get_synthetic_env(mode="train", version=args.env_version) diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py index b9aa87a..10cdeb2 100644 --- a/exps/LFNA/lfna_meta_model.py +++ b/exps/LFNA/lfna_meta_model.py @@ -85,7 +85,6 @@ class LFNA_Meta(super_core.SuperModule): dropout=dropout, ) self._generator = get_model(**model_kwargs) - # print("generator: {:}".format(self._generator)) # initialization trunc_normal_( @@ -163,9 +162,12 @@ class LFNA_Meta(super_core.SuperModule): corrected_embeds = self.meta_corrector(init_timestamp_embeds) return corrected_embeds - def forward_raw(self, timestamps): - batch, seq = timestamps.shape - time_embed = self._obtain_time_embed(timestamps) + def forward_raw(self, timestamps, time_embed): + if time_embed is None: + batch, seq = timestamps.shape + time_embed = self._obtain_time_embed(timestamps) + else: + batch, seq, _ = time_embed.shape # create joint embed num_layer, _ = self._super_layer_embed.shape meta_embed = time_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1) diff --git a/xautodl/xlayers/super_module.py b/xautodl/xlayers/super_module.py index 82d4587..b0fb8eb 100644 --- a/xautodl/xlayers/super_module.py +++ b/xautodl/xlayers/super_module.py @@ -19,6 +19,7 @@ from .super_utils import TensorContainer from .super_utils import ShapeContainer BEST_DIR_KEY = "best_model_dir" +BEST_NAME_KEY = "best_model_name" BEST_SCORE_KEY = "best_model_score" @@ -94,6 +95,9 @@ class SuperModule(abc.ABC, nn.Module): self._meta_info[BEST_DIR_KEY] = str(xdir) Path(xdir).mkdir(parents=True, exist_ok=True) + def set_best_name(self, xname): + self._meta_info[BEST_NAME_KEY] = str(xname) + def save_best(self, score): if BEST_DIR_KEY not in self._meta_info: tempdir = tempfile.mkdtemp("-xlayers") @@ -102,10 +106,11 @@ class SuperModule(abc.ABC, nn.Module): self._meta_info[BEST_SCORE_KEY] = None best_score = self._meta_info[BEST_SCORE_KEY] if best_score is None or best_score <= score: - best_save_path = os.path.join( - self._meta_info[BEST_DIR_KEY], - "best-{:}.pth".format(self.__class__.__name__), + best_save_name = self._meta_info.get( + BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__) ) + + best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name) self._meta_info[BEST_SCORE_KEY] = score torch.save(self.state_dict(), best_save_path) return True, self._meta_info[BEST_SCORE_KEY]