diff --git a/exps/LFNA/lfna-v1.py b/exps/LFNA/lfna-v1.py new file mode 100644 index 0000000..9c9b90a --- /dev/null +++ b/exps/LFNA/lfna-v1.py @@ -0,0 +1,212 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # +##################################################### +# python exps/LFNA/lfna-v1.py +##################################################### +import sys, time, copy, torch, random, argparse +from tqdm import tqdm +from copy import deepcopy +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from log_utils import time_string +from log_utils import AverageMeter, convert_secs2time + +from utils import split_str2indexes + +from procedures.advanced_main import basic_train_fn, basic_eval_fn +from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric +from datasets.synthetic_core import get_synthetic_env +from models.xcore import get_model + + +class Population: + def __init__(self): + self._time2model = dict() + + def append(self, timestamp, model): + if timestamp in self._time2model: + raise ValueError("This timestamp has been added.") + self._time2model[timestamp] = model + + +def main(args): + prepare_seed(args.rand_seed) + logger = prepare_logger(args) + + cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() + if cache_path.exists(): + env_info = torch.load(cache_path) + else: + env_info = dict() + dynamic_env = get_synthetic_env() + env_info["total"] = len(dynamic_env) + for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): + env_info["{:}-timestamp".format(idx)] = timestamp + env_info["{:}-x".format(idx)] = _allx + env_info["{:}-y".format(idx)] = _ally + env_info["dynamic_env"] = dynamic_env + torch.save(env_info, cache_path) + + total_time = env_info["total"] + for i in range(total_time): + for xkey in ("timestamp", "x", "y"): + nkey = "{:}-{:}".format(i, xkey) + assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys())) + train_time_bar = total_time // 2 + base_model = get_model( + dict(model_type="simple_mlp"), + act_cls="leaky_relu", + norm_cls="simple_learn_norm", + mean=0, + std=1, + input_dim=1, + output_dim=1, + ) + + w_container = base_model.named_parameters_buffers() + print("There are {:} weights.".format(w_container.numel())) + + pool = Population() + pool.append(0, w_container) + + # LFNA meta-training + per_epoch_time, start_time = AverageMeter(), time.time() + for iepoch in range(args.epochs): + import pdb + + pdb.set_trace() + print("-") + + for i, idx in enumerate(to_evaluate_indexes): + + need_time = "Time Left: {:}".format( + convert_secs2time( + per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True + ) + ) + logger.log( + "[{:}]".format(time_string()) + + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx) + + " " + + need_time + ) + # train the same data + assert idx != 0 + historical_x = env_info["{:}-x".format(idx)] + historical_y = env_info["{:}-y".format(idx)] + # build model + mean, std = historical_x.mean().item(), historical_x.std().item() + model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) + model = get_model(dict(model_type="simple_mlp"), **model_kwargs) + # build optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) + criterion = torch.nn.MSELoss() + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[ + int(args.epochs * 0.25), + int(args.epochs * 0.5), + int(args.epochs * 0.75), + ], + gamma=0.3, + ) + train_metric = MSEMetric() + best_loss, best_param = None, None + for _iepoch in range(args.epochs): + preds = model(historical_x) + optimizer.zero_grad() + loss = criterion(preds, historical_y) + loss.backward() + optimizer.step() + lr_scheduler.step() + # save best + if best_loss is None or best_loss > loss.item(): + best_loss = loss.item() + best_param = copy.deepcopy(model.state_dict()) + model.load_state_dict(best_param) + with torch.no_grad(): + train_metric(preds, historical_y) + train_results = train_metric.get_info() + + metric = ComposeMetric(MSEMetric(), SaveMetric()) + eval_dataset = torch.utils.data.TensorDataset( + env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] + ) + eval_loader = torch.utils.data.DataLoader( + eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 + ) + results = basic_eval_fn(eval_loader, model, metric, logger) + log_str = ( + "[{:}]".format(time_string()) + + " [{:04d}/{:04d}]".format(idx, env_info["total"]) + + " train-mse: {:.5f}, eval-mse: {:.5f}".format( + train_results["mse"], results["mse"] + ) + ) + logger.log(log_str) + + save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( + idx, env_info["total"] + ) + save_checkpoint( + { + "model_state_dict": model.state_dict(), + "model": model, + "index": idx, + "timestamp": env_info["{:}-timestamp".format(idx)], + }, + save_path, + logger, + ) + logger.log("") + + per_timestamp_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("-" * 200 + "\n") + logger.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Use the data in the past.") + parser.add_argument( + "--save_dir", + type=str, + default="./outputs/lfna-synthetic/lfna-v1", + help="The checkpoint directory.", + ) + parser.add_argument( + "--init_lr", + type=float, + default=0.1, + help="The initial learning rate for the optimizer (default is Adam)", + ) + parser.add_argument( + "--batch_size", + type=int, + default=512, + help="The batch size", + ) + parser.add_argument( + "--epochs", + type=int, + default=1000, + help="The total number of epochs.", + ) + parser.add_argument( + "--workers", + type=int, + default=4, + help="The number of data loading workers (default: 4)", + ) + # Random Seed + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "The save dir argument can not be None" + main(args) diff --git a/lib/models/xcore.py b/lib/models/xcore.py index 13be1cc..14e826d 100644 --- a/lib/models/xcore.py +++ b/lib/models/xcore.py @@ -10,21 +10,26 @@ __all__ = ["get_model"] from xlayers.super_core import SuperSequential -from xlayers.super_core import SuperSimpleNorm -from xlayers.super_core import SuperLeakyReLU from xlayers.super_core import SuperLinear +from xlayers.super_core import super_name2norm +from xlayers.super_core import super_name2activation def get_model(config: Dict[Text, Any], **kwargs): model_type = config.get("model_type", "simple_mlp") if model_type == "simple_mlp": + act_cls = super_name2activation[kwargs["act_cls"]] + norm_cls = super_name2norm[kwargs["norm_cls"]] + mean, std = kwargs.get("mean", None), kwargs.get("std", None) + hidden_dim1 = kwargs.get("hidden_dim1", 200) + hidden_dim2 = kwargs.get("hidden_dim2", 100) model = SuperSequential( - SuperSimpleNorm(kwargs["mean"], kwargs["std"]), - SuperLinear(kwargs["input_dim"], 200), - SuperLeakyReLU(), - SuperLinear(200, 100), - SuperLeakyReLU(), - SuperLinear(100, kwargs["output_dim"]), + norm_cls(mean=mean, std=std), + SuperLinear(kwargs["input_dim"], hidden_dim1), + act_cls(), + SuperLinear(hidden_dim1, hidden_dim2), + act_cls(), + SuperLinear(hidden_dim2, kwargs["output_dim"]), ) else: raise TypeError("Unkonwn model type: {:}".format(model_type)) diff --git a/lib/xlayers/super_core.py b/lib/xlayers/super_core.py index 58a0c2f..fceb3e7 100644 --- a/lib/xlayers/super_core.py +++ b/lib/xlayers/super_core.py @@ -9,13 +9,27 @@ from .super_module import SuperModule from .super_container import SuperSequential from .super_linear import SuperLinear from .super_linear import SuperMLPv1, SuperMLPv2 + from .super_norm import SuperSimpleNorm from .super_norm import SuperLayerNorm1D +from .super_norm import SuperSimpleLearnableNorm +from .super_norm import SuperIdentity + +super_name2norm = { + "simple_norm": SuperSimpleNorm, + "simple_learn_norm": SuperSimpleLearnableNorm, + "layer_norm_1d": SuperLayerNorm1D, + "identity": SuperIdentity, +} + from .super_attention import SuperAttention from .super_transformer import SuperTransformerEncoderLayer from .super_activations import SuperReLU from .super_activations import SuperLeakyReLU +super_name2activation = {"relu": SuperReLU, "leaky_relu": SuperLeakyReLU} + + from .super_trade_stem import SuperAlphaEBDv1 from .super_positional_embedding import SuperPositionalEncoder diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index 9004f34..d4e07eb 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -30,6 +30,45 @@ class SuperRunMode(Enum): Default = "fullmodel" +class TensorContainer: + """A class to maintain both parameters and buffers for a model.""" + + def __init__(self): + self._names = [] + self._tensors = [] + self._param_or_buffers = [] + self._name2index = dict() + + def append(self, name, tensor, param_or_buffer): + if not isinstance(tensor, torch.Tensor): + raise TypeError( + "The input tensor must be torch.Tensor instead of {:}".format( + type(tensor) + ) + ) + self._names.append(name) + self._tensors.append(tensor) + self._param_or_buffers.append(param_or_buffer) + assert name not in self._name2index, "The [{:}] has already been added.".format( + name + ) + self._name2index[name] = len(self._names) - 1 + + def numel(self): + total = 0 + for tensor in self._tensors: + total += tensor.numel() + return total + + def __len__(self): + return len(self._names) + + def __repr__(self): + return "{name}({num} tensors)".format( + name=self.__class__.__name__, num=len(self) + ) + + class SuperModule(abc.ABC, nn.Module): """This class equips the nn.Module class with the ability to apply AutoDL.""" @@ -71,6 +110,14 @@ class SuperModule(abc.ABC, nn.Module): ) self._abstract_child = abstract_child + def named_parameters_buffers(self): + container = TensorContainer() + for name, param in self.named_parameters(): + container.append(name, param, True) + for name, buf in self.named_buffers(): + container.append(name, buf, False) + return container + @property def abstract_search_space(self): raise NotImplementedError diff --git a/lib/xlayers/super_norm.py b/lib/xlayers/super_norm.py index 0103781..5671c0f 100644 --- a/lib/xlayers/super_norm.py +++ b/lib/xlayers/super_norm.py @@ -89,8 +89,8 @@ class SuperSimpleNorm(SuperModule): def __init__(self, mean, std, inplace=False) -> None: super(SuperSimpleNorm, self).__init__() - self._mean = mean - self._std = std + self.register_buffer("_mean", torch.tensor(mean, dtype=torch.float)) + self.register_buffer("_std", torch.tensor(std, dtype=torch.float)) self._inplace = inplace @property @@ -111,7 +111,7 @@ class SuperSimpleNorm(SuperModule): if (std == 0).any(): raise ValueError( "std evaluated to zero after conversion to {}, leading to division by zero.".format( - dtype + tensor.dtype ) ) while mean.ndim < tensor.ndim: @@ -119,6 +119,75 @@ class SuperSimpleNorm(SuperModule): return tensor.sub_(mean).div_(std) def extra_repr(self) -> str: - return "mean={mean}, std={mean}, inplace={inplace}".format( - mean=self._mean, std=self._std, inplace=self._inplace + return "mean={mean}, std={std}, inplace={inplace}".format( + mean=self._mean.item(), std=self._std.item(), inplace=self._inplace ) + + +class SuperSimpleLearnableNorm(SuperModule): + """Super simple normalization.""" + + def __init__(self, mean=0, std=1, eps=1e-6, inplace=False) -> None: + super(SuperSimpleLearnableNorm, self).__init__() + self.register_parameter( + "_mean", nn.Parameter(torch.tensor(mean, dtype=torch.float)) + ) + self.register_parameter( + "_std", nn.Parameter(torch.tensor(std, dtype=torch.float)) + ) + self._eps = eps + self._inplace = inplace + + @property + def abstract_search_space(self): + return spaces.VirtualNode(id(self)) + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + # check inputs -> + return self.forward_raw(input) + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + if not self._inplace: + tensor = input.clone() + else: + tensor = input + mean, std = ( + self._mean.to(tensor.device), + torch.abs(self._std.to(tensor.device)) + self._eps, + ) + if (std == 0).any(): + raise ValueError("std leads to division by zero.") + while mean.ndim < tensor.ndim: + mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0) + return tensor.sub_(mean).div_(std) + + def extra_repr(self) -> str: + return "mean={mean}, std={std}, inplace={inplace}".format( + mean=self._mean.item(), std=self._std.item(), inplace=self._inplace + ) + + +class SuperIdentity(SuperModule): + """Super identity mapping layer.""" + + def __init__(self, inplace=False, **kwargs) -> None: + super(SuperIdentity, self).__init__() + self._inplace = inplace + + @property + def abstract_search_space(self): + return spaces.VirtualNode(id(self)) + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + # check inputs -> + return self.forward_raw(input) + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + if not self._inplace: + tensor = input.clone() + else: + tensor = input + return tensor + + def extra_repr(self) -> str: + return "inplace={inplace}".format(inplace=self._inplace) diff --git a/tests/test_super_norm.py b/tests/test_super_norm.py index c50f397..d5a21d6 100644 --- a/tests/test_super_norm.py +++ b/tests/test_super_norm.py @@ -51,3 +51,35 @@ class TestSuperSimpleNorm(unittest.TestCase): output_shape = (20, abstract_child["1"]["_out_features"].value) outputs = model(inputs) self.assertEqual(tuple(outputs.shape), output_shape) + + def test_super_simple_learn_norm(self): + out_features = spaces.Categorical(12, 24, 36) + bias = spaces.Categorical(True, False) + model = super_core.SuperSequential( + super_core.SuperSimpleLearnableNorm(), + super_core.SuperIdentity(), + super_core.SuperLinear(10, out_features, bias=bias), + ) + print("The simple super module is:\n{:}".format(model)) + model.apply_verbose(True) + + print(model.super_run_type) + self.assertTrue(model[1].bias) + + inputs = torch.rand(20, 10) + print("Input shape: {:}".format(inputs.shape)) + outputs = model(inputs) + self.assertEqual(tuple(outputs.shape), (20, 36)) + + abstract_space = model.abstract_search_space + abstract_space.clean_last() + abstract_child = abstract_space.random() + print("The abstract searc space:\n{:}".format(abstract_space)) + print("The abstract child program:\n{:}".format(abstract_child)) + + model.set_super_run_type(super_core.SuperRunMode.Candidate) + model.apply_candidate(abstract_child) + + output_shape = (20, abstract_child["1"]["_out_features"].value) + outputs = model(inputs) + self.assertEqual(tuple(outputs.shape), output_shape)