diff --git a/exps/LFNA/lfna-v1.py b/exps/LFNA/lfna-v1.py index 90ac10a..c604d59 100644 --- a/exps/LFNA/lfna-v1.py +++ b/exps/LFNA/lfna-v1.py @@ -21,6 +21,57 @@ from procedures.advanced_main import basic_train_fn, basic_eval_fn from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric from datasets.synthetic_core import get_synthetic_env from models.xcore import get_model +from xlayers import super_core + + +class LFNAmlp: + """A LFNA meta-model that uses the MLP as delta-net.""" + + def __init__(self, obs_dim, hidden_sizes, act_name): + self.delta_net = super_core.SuperSequential( + super_core.SuperLinear(obs_dim, hidden_sizes[0]), + super_core.super_name2activation[act_name](), + super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]), + super_core.super_name2activation[act_name](), + super_core.SuperLinear(hidden_sizes[1], 1), + ) + self.meta_optimizer = torch.optim.Adam( + self.delta_net.parameters(), lr=0.01, amsgrad=True + ) + + def adapt(self, model, criterion, w_container, xs, ys): + containers = [w_container] + for idx, (x, y) in enumerate(zip(xs, ys)): + y_hat = model.forward_with_container(x, containers[-1]) + loss = criterion(y_hat, y) + gradients = torch.autograd.grad(loss, containers[-1].tensors) + with torch.no_grad(): + flatten_w = containers[-1].flatten().view(-1, 1) + flatten_g = containers[-1].flatten(gradients).view(-1, 1) + input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2) + input_statistics = input_statistics.expand(flatten_w.numel(), -1) + delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) + delta = self.delta_net(delta_inputs).view(-1) + # delta = torch.clamp(delta, -0.5, 0.5) + unflatten_delta = containers[-1].unflatten(delta) + future_container = containers[-1].additive(unflatten_delta) + containers.append(future_container) + # containers = containers[1:] + meta_loss = [] + for idx, (x, y) in enumerate(zip(xs, ys)): + if idx == 0: + continue + current_container = containers[idx] + y_hat = model.forward_with_container(x, current_container) + loss = criterion(y_hat, y) + meta_loss.append(loss) + meta_loss = sum(meta_loss) + meta_loss.backward() + self.meta_optimizer.step() + + def zero_grad(self): + self.meta_optimizer.zero_grad() + self.delta_net.zero_grad() class Population: @@ -28,11 +79,23 @@ class Population: def __init__(self): self._time2model = dict() + self._time2score = dict() # higher is better - def append(self, timestamp, model): + def append(self, timestamp, model, score): if timestamp in self._time2model: raise ValueError("This timestamp has been added.") self._time2model[timestamp] = model + self._time2score[timestamp] = score + + def query(self, timestamp): + closet_timestamp = None + for xtime, model in self._time2model.items(): + if ( + closet_timestamp is None + or timestamp - closet_timestamp >= timestamp - xtime + ): + closet_timestamp = xtime + return self._time2model[closet_timestamp], closet_timestamp def main(args): @@ -70,100 +133,39 @@ def main(args): ) w_container = base_model.named_parameters_buffers() + criterion = torch.nn.MSELoss() print("There are {:} weights.".format(w_container.numel())) + adaptor = LFNAmlp(4, (50, 20), "leaky_relu") + pool = Population() pool.append(0, w_container) # LFNA meta-training per_epoch_time, start_time = AverageMeter(), time.time() for iepoch in range(args.epochs): - import pdb - - pdb.set_trace() - print("-") - - for i, idx in enumerate(to_evaluate_indexes): need_time = "Time Left: {:}".format( - convert_secs2time( - per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True - ) + convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) ) logger.log( - "[{:}]".format(time_string()) - + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx) - + " " + "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) + need_time ) - # train the same data - assert idx != 0 - historical_x = env_info["{:}-x".format(idx)] - historical_y = env_info["{:}-y".format(idx)] - # build model - mean, std = historical_x.mean().item(), historical_x.std().item() - model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) - model = get_model(dict(model_type="simple_mlp"), **model_kwargs) - # build optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) - criterion = torch.nn.MSELoss() - lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, - milestones=[ - int(args.epochs * 0.25), - int(args.epochs * 0.5), - int(args.epochs * 0.75), - ], - gamma=0.3, - ) - train_metric = MSEMetric() - best_loss, best_param = None, None - for _iepoch in range(args.epochs): - preds = model(historical_x) - optimizer.zero_grad() - loss = criterion(preds, historical_y) - loss.backward() - optimizer.step() - lr_scheduler.step() - # save best - if best_loss is None or best_loss > loss.item(): - best_loss = loss.item() - best_param = copy.deepcopy(model.state_dict()) - model.load_state_dict(best_param) - with torch.no_grad(): - train_metric(preds, historical_y) - train_results = train_metric.get_info() - metric = ComposeMetric(MSEMetric(), SaveMetric()) - eval_dataset = torch.utils.data.TensorDataset( - env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] - ) - eval_loader = torch.utils.data.DataLoader( - eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 - ) - results = basic_eval_fn(eval_loader, model, metric, logger) - log_str = ( - "[{:}]".format(time_string()) - + " [{:04d}/{:04d}]".format(idx, env_info["total"]) - + " train-mse: {:.5f}, eval-mse: {:.5f}".format( - train_results["mse"], results["mse"] - ) - ) - logger.log(log_str) + for ibatch in range(args.meta_batch): + sampled_timestamp = random.randint(0, train_time_bar) + query_w_container, query_timestamp = pool.query(sampled_timestamp) + # def adapt(self, model, w_container, xs, ys): + xs, ys = [], [] + for it in range(sampled_timestamp, sampled_timestamp + args.max_seq): + xs.append(env_info["{:}-x".format(it)]) + ys.append(env_info["{:}-y".format(it)]) + adaptor.adapt(base_model, criterion, query_w_container, xs, ys) + import pdb - save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( - idx, env_info["total"] - ) - save_checkpoint( - { - "model_state_dict": model.state_dict(), - "model": model, - "index": idx, - "timestamp": env_info["{:}-timestamp".format(idx)], - }, - save_path, - logger, - ) + pdb.set_trace() + print("-") logger.log("") per_timestamp_time.update(time.time() - start_time) @@ -188,10 +190,10 @@ if __name__ == "__main__": help="The initial learning rate for the optimizer (default is Adam)", ) parser.add_argument( - "--batch_size", + "--meta_batch", type=int, - default=512, - help="The batch size", + default=2, + help="The batch size for the meta-model", ) parser.add_argument( "--epochs", @@ -199,6 +201,12 @@ if __name__ == "__main__": default=1000, help="The total number of epochs.", ) + parser.add_argument( + "--max_seq", + type=int, + default=5, + help="The maximum length of the sequence.", + ) parser.add_argument( "--workers", type=int, diff --git a/lib/models/xcore.py b/lib/models/xcore.py index 14e826d..a8196a0 100644 --- a/lib/models/xcore.py +++ b/lib/models/xcore.py @@ -34,3 +34,4 @@ def get_model(config: Dict[Text, Any], **kwargs): else: raise TypeError("Unkonwn model type: {:}".format(model_type)) return model + diff --git a/lib/xlayers/super_activations.py b/lib/xlayers/super_activations.py index 336dff3..bf3f3e8 100644 --- a/lib/xlayers/super_activations.py +++ b/lib/xlayers/super_activations.py @@ -31,6 +31,9 @@ class SuperReLU(SuperModule): def forward_raw(self, input: torch.Tensor) -> torch.Tensor: return F.relu(input, inplace=self._inplace) + def forward_with_container(self, input, container, prefix=[]): + return self.forward_raw(input) + def extra_repr(self) -> str: return "inplace=True" if self._inplace else "" @@ -53,6 +56,29 @@ class SuperLeakyReLU(SuperModule): def forward_raw(self, input: torch.Tensor) -> torch.Tensor: return F.leaky_relu(input, self._negative_slope, self._inplace) + def forward_with_container(self, input, container, prefix=[]): + return self.forward_raw(input) + def extra_repr(self) -> str: inplace_str = "inplace=True" if self._inplace else "" return "negative_slope={}{}".format(self._negative_slope, inplace_str) + + +class SuperTanh(SuperModule): + """Applies a the Tanh function element-wise.""" + + def __init__(self) -> None: + super(SuperTanh, self).__init__() + + @property + def abstract_search_space(self): + return spaces.VirtualNode(id(self)) + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + return self.forward_raw(input) + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + return torch.tanh(input) + + def forward_with_container(self, input, container, prefix=[]): + return self.forward_raw(input) diff --git a/lib/xlayers/super_container.py b/lib/xlayers/super_container.py index 0b1f9e6..5d21e5f 100644 --- a/lib/xlayers/super_container.py +++ b/lib/xlayers/super_container.py @@ -111,3 +111,10 @@ class SuperSequential(SuperModule): for module in self: input = module(input) return input + + def forward_with_container(self, input, container, prefix=[]): + for index, module in enumerate(self): + input = module.forward_with_container( + input, container, prefix + [str(index)] + ) + return input diff --git a/lib/xlayers/super_core.py b/lib/xlayers/super_core.py index fceb3e7..3e1d04f 100644 --- a/lib/xlayers/super_core.py +++ b/lib/xlayers/super_core.py @@ -27,8 +27,13 @@ from .super_transformer import SuperTransformerEncoderLayer from .super_activations import SuperReLU from .super_activations import SuperLeakyReLU +from .super_activations import SuperTanh -super_name2activation = {"relu": SuperReLU, "leaky_relu": SuperLeakyReLU} +super_name2activation = { + "relu": SuperReLU, + "leaky_relu": SuperLeakyReLU, + "tanh": SuperTanh, +} from .super_trade_stem import SuperAlphaEBDv1 diff --git a/lib/xlayers/super_linear.py b/lib/xlayers/super_linear.py index 5d2f005..803555f 100644 --- a/lib/xlayers/super_linear.py +++ b/lib/xlayers/super_linear.py @@ -115,6 +115,16 @@ class SuperLinear(SuperModule): self._in_features, self._out_features, self._bias ) + def forward_with_container(self, input, container, prefix=[]): + super_weight_name = ".".join(prefix + ["_super_weight"]) + super_weight = container.query(super_weight_name) + super_bias_name = ".".join(prefix + ["_super_bias"]) + if container.has(super_bias_name): + super_bias = container.query(super_bias_name) + else: + super_bias = None + return F.linear(input, super_weight, super_bias) + class SuperMLPv1(SuperModule): """An MLP layer: FC -> Activation -> Drop -> FC -> Drop.""" diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index d4e07eb..d9fde80 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -39,6 +39,41 @@ class TensorContainer: self._param_or_buffers = [] self._name2index = dict() + def additive(self, tensors): + result = TensorContainer() + for index, name in enumerate(self._names): + new_tensor = self._tensors[index] + tensors[index] + result.append(name, new_tensor, self._param_or_buffers[index]) + return result + + def no_grad_clone(self): + result = TensorContainer() + with torch.no_grad(): + for index, name in enumerate(self._names): + result.append( + name, self._tensors[index].clone(), self._param_or_buffers[index] + ) + return result + + @property + def tensors(self): + return self._tensors + + def flatten(self, tensors=None): + if tensors is None: + tensors = self._tensors + tensors = [tensor.view(-1) for tensor in tensors] + return torch.cat(tensors) + + def unflatten(self, tensor): + tensors, s = [], 0 + for raw_tensor in self._tensors: + length = raw_tensor.numel() + x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape) + tensors.append(x) + s += length + return tensors + def append(self, name, tensor, param_or_buffer): if not isinstance(tensor, torch.Tensor): raise TypeError( @@ -54,6 +89,23 @@ class TensorContainer: ) self._name2index[name] = len(self._names) - 1 + def query(self, name): + if not self.has(name): + raise ValueError( + "The {:} is not in {:}".format(name, list(self._name2index.keys())) + ) + index = self._name2index[name] + return self._tensors[index] + + def has(self, name): + return name in self._name2index + + def has_prefix(self, prefix): + for name, idx in self._name2index.items(): + if name.startswith(prefix): + return name + return False + def numel(self): total = 0 for tensor in self._tensors: @@ -181,3 +233,6 @@ class SuperModule(abc.ABC, nn.Module): ) ) return outputs + + def forward_with_container(self, inputs, container, prefix=[]): + raise NotImplementedError diff --git a/lib/xlayers/super_norm.py b/lib/xlayers/super_norm.py index 5671c0f..b745e9d 100644 --- a/lib/xlayers/super_norm.py +++ b/lib/xlayers/super_norm.py @@ -161,6 +161,21 @@ class SuperSimpleLearnableNorm(SuperModule): mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0) return tensor.sub_(mean).div_(std) + def forward_with_container(self, input, container, prefix=[]): + if not self._inplace: + tensor = input.clone() + else: + tensor = input + mean_name = ".".join(prefix + ["_mean"]) + std_name = ".".join(prefix + ["_std"]) + mean, std = ( + container.query(mean_name).to(tensor.device), + torch.abs(container.query(std_name).to(tensor.device)) + self._eps, + ) + while mean.ndim < tensor.ndim: + mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0) + return tensor.sub_(mean).div_(std) + def extra_repr(self) -> str: return "mean={mean}, std={std}, inplace={inplace}".format( mean=self._mean.item(), std=self._std.item(), inplace=self._inplace @@ -191,3 +206,6 @@ class SuperIdentity(SuperModule): def extra_repr(self) -> str: return "inplace={inplace}".format(inplace=self._inplace) + + def forward_with_container(self, input, container, prefix=[]): + return self.forward_raw(input) diff --git a/lib/xlayers/super_rl_actor.py b/lib/xlayers/super_rl_actor.py new file mode 100644 index 0000000..5725fed --- /dev/null +++ b/lib/xlayers/super_rl_actor.py @@ -0,0 +1,120 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +# DISABLED / NOT-FINISHED +##################################################### +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math +from typing import Optional, Callable + +import spaces +from .super_container import SuperSequential +from .super_linear import SuperLinear + + +class SuperActor(SuperModule): + """A Actor in RL.""" + + def _distribution(self, obs): + raise NotImplementedError + + def _log_prob_from_distribution(self, pi, act): + raise NotImplementedError + + def forward_candidate(self, **kwargs): + return self.forward_raw(**kwargs) + + def forward_raw(self, obs, act=None): + # Produce action distributions for given observations, and + # optionally compute the log likelihood of given actions under + # those distributions. + pi = self._distribution(obs) + logp_a = None + if act is not None: + logp_a = self._log_prob_from_distribution(pi, act) + return pi, logp_a + + +class SuperLfnaMetaMLP(SuperModule): + def __init__(self, obs_dim, hidden_sizes, act_cls): + super(SuperLfnaMetaMLP).__init__() + self.delta_net = SuperSequential( + SuperLinear(obs_dim, hidden_sizes[0]), + act_cls(), + SuperLinear(hidden_sizes[0], hidden_sizes[1]), + act_cls(), + SuperLinear(hidden_sizes[1], 1), + ) + + +class SuperLfnaMetaMLP(SuperModule): + def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls): + super(SuperLfnaMetaMLP).__init__() + log_std = -0.5 * np.ones(act_dim, dtype=np.float32) + self.log_std = torch.nn.Parameter(torch.as_tensor(log_std)) + self.mu_net = SuperSequential( + SuperLinear(obs_dim, hidden_sizes[0]), + act_cls(), + SuperLinear(hidden_sizes[0], hidden_sizes[1]), + act_cls(), + SuperLinear(hidden_sizes[1], act_dim), + ) + + def _distribution(self, obs): + mu = self.mu_net(obs) + std = torch.exp(self.log_std) + return Normal(mu, std) + + def _log_prob_from_distribution(self, pi, act): + return pi.log_prob(act).sum(axis=-1) + + def forward_candidate(self, **kwargs): + return self.forward_raw(**kwargs) + + def forward_raw(self, obs, act=None): + # Produce action distributions for given observations, and + # optionally compute the log likelihood of given actions under + # those distributions. + pi = self._distribution(obs) + logp_a = None + if act is not None: + logp_a = self._log_prob_from_distribution(pi, act) + return pi, logp_a + + +class SuperMLPGaussianActor(SuperModule): + def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls): + super(SuperMLPGaussianActor).__init__() + log_std = -0.5 * np.ones(act_dim, dtype=np.float32) + self.log_std = torch.nn.Parameter(torch.as_tensor(log_std)) + self.mu_net = SuperSequential( + SuperLinear(obs_dim, hidden_sizes[0]), + act_cls(), + SuperLinear(hidden_sizes[0], hidden_sizes[1]), + act_cls(), + SuperLinear(hidden_sizes[1], act_dim), + ) + + def _distribution(self, obs): + mu = self.mu_net(obs) + std = torch.exp(self.log_std) + return Normal(mu, std) + + def _log_prob_from_distribution(self, pi, act): + return pi.log_prob(act).sum(axis=-1) + + def forward_candidate(self, **kwargs): + return self.forward_raw(**kwargs) + + def forward_raw(self, obs, act=None): + # Produce action distributions for given observations, and + # optionally compute the log likelihood of given actions under + # those distributions. + pi = self._distribution(obs) + logp_a = None + if act is not None: + logp_a = self._log_prob_from_distribution(pi, act) + return pi, logp_a