From 755c7c90cf3e45baaa01b7e03e42ae52535506d3 Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Mon, 10 May 2021 09:42:42 +0800
Subject: [PATCH] Prototype MAML

---
 exps/LFNA/basic-maml.py | 142 +++++++++++++++++++---------------------
 1 file changed, 67 insertions(+), 75 deletions(-)

diff --git a/exps/LFNA/basic-maml.py b/exps/LFNA/basic-maml.py
index e0ead12..3dcd7b2 100644
--- a/exps/LFNA/basic-maml.py
+++ b/exps/LFNA/basic-maml.py
@@ -1,8 +1,8 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/basic-maml.py --env_version v1   #
-# python exps/LFNA/basic-maml.py --env_version v2   #
+# python exps/LFNA/basic-maml.py --env_version v1 --hidden_dim 16 --inner_step 5
+# python exps/LFNA/basic-maml.py --env_version v2
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
@@ -30,73 +30,71 @@ from lfna_utils import lfna_setup, TimeData
 class MAML:
     """A LFNA meta-model that uses the MLP as delta-net."""
 
-    def __init__(self, container, criterion, meta_lr, inner_lr=0.01, inner_step=1):
+    def __init__(
+        self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
+    ):
         self.criterion = criterion
-        self.container = container
+        # self.container = container
+        self.network = network
         self.meta_optimizer = torch.optim.Adam(
-            self.container.parameters(), lr=meta_lr, amsgrad=True
+            self.network.parameters(), lr=meta_lr, amsgrad=True
+        )
+        self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            self.meta_optimizer,
+            milestones=[
+                int(epochs * 0.25),
+                int(epochs * 0.5),
+                int(epochs * 0.75),
+            ],
+            gamma=0.3,
         )
         self.inner_lr = inner_lr
         self.inner_step = inner_step
+        self._best_info = dict(state_dict=None, score=None)
+        print("There are {:} weights.".format(self.network.get_w_container().numel()))
 
-    def adapt(self, model, dataset):
+    def adapt(self, dataset):
         # create a container for the future timestamp
-        y_hat = model.forward_with_container(dataset.x, self.container)
-        loss = self.criterion(y_hat, dataset.y)
-        grads = torch.autograd.grad(loss, self.container.parameters())
+        container = self.network.get_w_container()
 
-        fast_container = self.container.additive(
-            [-self.inner_lr * grad for grad in grads]
-        )
-        import pdb
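+        # MAML inner loop: take `inner_step` gradient steps on this dataset;
+        # each step rebuilds the weight container, so the network's own
+        # parameters stay untouched until the meta-update.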
+        for k in range(0, self.inner_step):
+            y_hat = self.network.forward_with_container(dataset.x, container)
+            loss = self.criterion(y_hat, dataset.y)
+            grads = torch.autograd.grad(loss, container.parameters())
 
-        pdb.set_trace()
-        w_container.requires_grad_(True)
-        containers = [w_container]
-        for idx, dataset in enumerate(seq_datasets):
-            x, y = dataset.x, dataset.y
-            y_hat = model.forward_with_container(x, containers[-1])
-            loss = criterion(y_hat, y)
-            gradients = torch.autograd.grad(loss, containers[-1].tensors)
-            with torch.no_grad():
-                flatten_w = containers[-1].flatten().view(-1, 1)
-                flatten_g = containers[-1].flatten(gradients).view(-1, 1)
-                input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2)
-                input_statistics = input_statistics.expand(flatten_w.numel(), -1)
-            delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
-            delta = self.delta_net(delta_inputs).view(-1)
-            delta = torch.clamp(delta, -0.5, 0.5)
-            unflatten_delta = containers[-1].unflatten(delta)
-            future_container = containers[-1].no_grad_clone().additive(unflatten_delta)
-            # future_container = containers[-1].additive(unflatten_delta)
-            containers.append(future_container)
-        # containers = containers[1:]
-        meta_loss = []
-        temp_containers = []
-        for idx, dataset in enumerate(seq_datasets):
-            if idx == 0:
-                continue
-            current_container = containers[idx]
-            y_hat = model.forward_with_container(dataset.x, current_container)
-            loss = criterion(y_hat, dataset.y)
-            meta_loss.append(loss)
-            temp_containers.append((dataset.timestamp, current_container, -loss.item()))
-        meta_loss = sum(meta_loss)
-        w_container.requires_grad_(False)
-        # meta_loss.backward()
-        # self.meta_optimizer.step()
-        return meta_loss, temp_containers
+            container = container.additive([-self.inner_lr * grad for grad in grads])
+        return container
+
+    def predict(self, x, container=None):
+        if container is not None:
+            y_hat = self.network.forward_with_container(x, container)
+        else:
+            y_hat = self.network(x)
+        return y_hat
 
     def step(self):
-        torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0)
+        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
         self.meta_optimizer.step()
+        self.meta_lr_scheduler.step()
 
     def zero_grad(self):
         self.meta_optimizer.zero_grad()
 
+    def save_best(self, network, score):
+        if self._best_info["score"] is None or self._best_info["score"] < score:
+            state_dict = dict(
+                criterion=self.criterion,
+                network=network.state_dict(),
+                meta_optimizer=self.meta_optimizer.state_dict(),
+                meta_lr_scheduler=self.meta_lr_scheduler.state_dict(),
+            )
+            self._best_info["state_dict"] = state_dict
+            self._best_info["score"] = score
+
 
 def main(args):
-    logger, env_info = lfna_setup(args)
+    logger, env_info, model_kwargs = lfna_setup(args)
+    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
 
     total_time = env_info["total"]
     for i in range(total_time):
@@ -104,19 +102,12 @@ def main(args):
         nkey = "{:}-{:}".format(i, xkey)
         assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
     train_time_bar = total_time // 2
-    base_model = get_model(
-        dict(model_type="simple_mlp"),
-        act_cls="leaky_relu",
-        norm_cls="identity",
-        input_dim=1,
-        output_dim=1,
-    )
 
-    w_container = base_model.get_w_container()
     criterion = torch.nn.MSELoss()
-    print("There are {:} weights.".format(w_container.numel()))
 
-    maml = MAML(w_container, criterion, args.meta_lr, args.inner_lr, args.inner_step)
+    maml = MAML(
+        model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
+    )
 
     # meta-training
     per_epoch_time, start_time = AverageMeter(), time.time()
@@ -131,8 +122,7 @@ def main(args):
         )
 
         maml.zero_grad()
-
-        all_meta_losses = []
+        meta_losses = []
        for ibatch in range(args.meta_batch):
             sampled_timestamp = random.randint(0, train_time_bar)
             past_dataset = TimeData(
@@ -145,21 +135,23 @@ def main(args):
                 env_info["{:}-x".format(sampled_timestamp + 1)],
                 env_info["{:}-y".format(sampled_timestamp + 1)],
             )
-            maml.adapt(base_model, past_dataset)
-            import pdb
-
-            pdb.set_trace()
-
-        meta_loss = torch.stack(all_meta_losses).mean()
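+            # Adapt on the sampled timestamp, then evaluate the adapted
+            # weights on the next timestamp; the averaged future loss is
+            # the meta-objective for the meta-update.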
+            future_container = maml.adapt(past_dataset)
+            future_y_hat = maml.predict(future_dataset.x, future_container)
+            future_loss = maml.criterion(future_y_hat, future_dataset.y)
+            meta_losses.append(future_loss)
+        meta_loss = torch.stack(meta_losses).mean()
         meta_loss.backward()
-        adaptor.step()
+        maml.step()
 
-        debug_str = pool.debug_info(debug_timestamp)
         logger.log("meta-loss: {:.4f}".format(meta_loss.item()))
 
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
 
+    import pdb
+
+    pdb.set_trace()
+
     logger.log("-" * 200 + "\n")
     logger.close()
 
@@ -187,7 +179,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--meta_lr",
         type=float,
-        default=0.01,
+        default=0.1,
         help="The learning rate for the MAML optimizer (default is Adam)",
     )
     parser.add_argument(
@@ -202,7 +194,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--meta_batch",
         type=int,
-        default=5,
+        default=10,
         help="The batch size for the meta-model",
     )
     parser.add_argument(
@@ -223,7 +215,7 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-{:}-d{:}".format(
-        args.save_dir, args.env_version, args.hidden_dim
+    args.save_dir = "{:}-s{:}-{:}-d{:}".format(
+        args.save_dir, args.inner_step, args.env_version, args.hidden_dim
     )
     main(args)