Add super/norm layers in xcore
This commit is contained in:
		
							
								
								
									
										212
									
								
								exps/LFNA/lfna-v1.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										212
									
								
								exps/LFNA/lfna-v1.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,212 @@
 | 
			
		||||
#####################################################
 | 
			
		||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 | 
			
		||||
#####################################################
 | 
			
		||||
# python exps/LFNA/lfna-v1.py
 | 
			
		||||
#####################################################
 | 
			
		||||
import sys, time, copy, torch, random, argparse
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 | 
			
		||||
if str(lib_dir) not in sys.path:
 | 
			
		||||
    sys.path.insert(0, str(lib_dir))
 | 
			
		||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
 | 
			
		||||
from log_utils import time_string
 | 
			
		||||
from log_utils import AverageMeter, convert_secs2time
 | 
			
		||||
 | 
			
		||||
from utils import split_str2indexes
 | 
			
		||||
 | 
			
		||||
from procedures.advanced_main import basic_train_fn, basic_eval_fn
 | 
			
		||||
from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
 | 
			
		||||
from datasets.synthetic_core import get_synthetic_env
 | 
			
		||||
from models.xcore import get_model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Population:
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self._time2model = dict()
 | 
			
		||||
 | 
			
		||||
    def append(self, timestamp, model):
 | 
			
		||||
        if timestamp in self._time2model:
 | 
			
		||||
            raise ValueError("This timestamp has been added.")
 | 
			
		||||
        self._time2model[timestamp] = model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(args):
 | 
			
		||||
    prepare_seed(args.rand_seed)
 | 
			
		||||
    logger = prepare_logger(args)
 | 
			
		||||
 | 
			
		||||
    cache_path = (logger.path(None) / ".." / "env-info.pth").resolve()
 | 
			
		||||
    if cache_path.exists():
 | 
			
		||||
        env_info = torch.load(cache_path)
 | 
			
		||||
    else:
 | 
			
		||||
        env_info = dict()
 | 
			
		||||
        dynamic_env = get_synthetic_env()
 | 
			
		||||
        env_info["total"] = len(dynamic_env)
 | 
			
		||||
        for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
 | 
			
		||||
            env_info["{:}-timestamp".format(idx)] = timestamp
 | 
			
		||||
            env_info["{:}-x".format(idx)] = _allx
 | 
			
		||||
            env_info["{:}-y".format(idx)] = _ally
 | 
			
		||||
        env_info["dynamic_env"] = dynamic_env
 | 
			
		||||
        torch.save(env_info, cache_path)
 | 
			
		||||
 | 
			
		||||
    total_time = env_info["total"]
 | 
			
		||||
    for i in range(total_time):
 | 
			
		||||
        for xkey in ("timestamp", "x", "y"):
 | 
			
		||||
            nkey = "{:}-{:}".format(i, xkey)
 | 
			
		||||
            assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
 | 
			
		||||
    train_time_bar = total_time // 2
 | 
			
		||||
    base_model = get_model(
 | 
			
		||||
        dict(model_type="simple_mlp"),
 | 
			
		||||
        act_cls="leaky_relu",
 | 
			
		||||
        norm_cls="simple_learn_norm",
 | 
			
		||||
        mean=0,
 | 
			
		||||
        std=1,
 | 
			
		||||
        input_dim=1,
 | 
			
		||||
        output_dim=1,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    w_container = base_model.named_parameters_buffers()
 | 
			
		||||
    print("There are {:} weights.".format(w_container.numel()))
 | 
			
		||||
 | 
			
		||||
    pool = Population()
 | 
			
		||||
    pool.append(0, w_container)
 | 
			
		||||
 | 
			
		||||
    # LFNA meta-training
 | 
			
		||||
    per_epoch_time, start_time = AverageMeter(), time.time()
 | 
			
		||||
    for iepoch in range(args.epochs):
 | 
			
		||||
        import pdb
 | 
			
		||||
 | 
			
		||||
        pdb.set_trace()
 | 
			
		||||
        print("-")
 | 
			
		||||
 | 
			
		||||
    for i, idx in enumerate(to_evaluate_indexes):
 | 
			
		||||
 | 
			
		||||
        need_time = "Time Left: {:}".format(
 | 
			
		||||
            convert_secs2time(
 | 
			
		||||
                per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        logger.log(
 | 
			
		||||
            "[{:}]".format(time_string())
 | 
			
		||||
            + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx)
 | 
			
		||||
            + " "
 | 
			
		||||
            + need_time
 | 
			
		||||
        )
 | 
			
		||||
        # train the same data
 | 
			
		||||
        assert idx != 0
 | 
			
		||||
        historical_x = env_info["{:}-x".format(idx)]
 | 
			
		||||
        historical_y = env_info["{:}-y".format(idx)]
 | 
			
		||||
        # build model
 | 
			
		||||
        mean, std = historical_x.mean().item(), historical_x.std().item()
 | 
			
		||||
        model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
 | 
			
		||||
        model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
 | 
			
		||||
        # build optimizer
 | 
			
		||||
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
 | 
			
		||||
        criterion = torch.nn.MSELoss()
 | 
			
		||||
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
 | 
			
		||||
            optimizer,
 | 
			
		||||
            milestones=[
 | 
			
		||||
                int(args.epochs * 0.25),
 | 
			
		||||
                int(args.epochs * 0.5),
 | 
			
		||||
                int(args.epochs * 0.75),
 | 
			
		||||
            ],
 | 
			
		||||
            gamma=0.3,
 | 
			
		||||
        )
 | 
			
		||||
        train_metric = MSEMetric()
 | 
			
		||||
        best_loss, best_param = None, None
 | 
			
		||||
        for _iepoch in range(args.epochs):
 | 
			
		||||
            preds = model(historical_x)
 | 
			
		||||
            optimizer.zero_grad()
 | 
			
		||||
            loss = criterion(preds, historical_y)
 | 
			
		||||
            loss.backward()
 | 
			
		||||
            optimizer.step()
 | 
			
		||||
            lr_scheduler.step()
 | 
			
		||||
            # save best
 | 
			
		||||
            if best_loss is None or best_loss > loss.item():
 | 
			
		||||
                best_loss = loss.item()
 | 
			
		||||
                best_param = copy.deepcopy(model.state_dict())
 | 
			
		||||
        model.load_state_dict(best_param)
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            train_metric(preds, historical_y)
 | 
			
		||||
        train_results = train_metric.get_info()
 | 
			
		||||
 | 
			
		||||
        metric = ComposeMetric(MSEMetric(), SaveMetric())
 | 
			
		||||
        eval_dataset = torch.utils.data.TensorDataset(
 | 
			
		||||
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
 | 
			
		||||
        )
 | 
			
		||||
        eval_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
 | 
			
		||||
        )
 | 
			
		||||
        results = basic_eval_fn(eval_loader, model, metric, logger)
 | 
			
		||||
        log_str = (
 | 
			
		||||
            "[{:}]".format(time_string())
 | 
			
		||||
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
 | 
			
		||||
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
 | 
			
		||||
                train_results["mse"], results["mse"]
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        logger.log(log_str)
 | 
			
		||||
 | 
			
		||||
        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
 | 
			
		||||
            idx, env_info["total"]
 | 
			
		||||
        )
 | 
			
		||||
        save_checkpoint(
 | 
			
		||||
            {
 | 
			
		||||
                "model_state_dict": model.state_dict(),
 | 
			
		||||
                "model": model,
 | 
			
		||||
                "index": idx,
 | 
			
		||||
                "timestamp": env_info["{:}-timestamp".format(idx)],
 | 
			
		||||
            },
 | 
			
		||||
            save_path,
 | 
			
		||||
            logger,
 | 
			
		||||
        )
 | 
			
		||||
        logger.log("")
 | 
			
		||||
 | 
			
		||||
        per_timestamp_time.update(time.time() - start_time)
 | 
			
		||||
        start_time = time.time()
 | 
			
		||||
 | 
			
		||||
    logger.log("-" * 200 + "\n")
 | 
			
		||||
    logger.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser("Use the data in the past.")
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--save_dir",
 | 
			
		||||
        type=str,
 | 
			
		||||
        default="./outputs/lfna-synthetic/lfna-v1",
 | 
			
		||||
        help="The checkpoint directory.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--init_lr",
 | 
			
		||||
        type=float,
 | 
			
		||||
        default=0.1,
 | 
			
		||||
        help="The initial learning rate for the optimizer (default is Adam)",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--batch_size",
 | 
			
		||||
        type=int,
 | 
			
		||||
        default=512,
 | 
			
		||||
        help="The batch size",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--epochs",
 | 
			
		||||
        type=int,
 | 
			
		||||
        default=1000,
 | 
			
		||||
        help="The total number of epochs.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--workers",
 | 
			
		||||
        type=int,
 | 
			
		||||
        default=4,
 | 
			
		||||
        help="The number of data loading workers (default: 4)",
 | 
			
		||||
    )
 | 
			
		||||
    # Random Seed
 | 
			
		||||
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    if args.rand_seed is None or args.rand_seed < 0:
 | 
			
		||||
        args.rand_seed = random.randint(1, 100000)
 | 
			
		||||
    assert args.save_dir is not None, "The save dir argument can not be None"
 | 
			
		||||
    main(args)
 | 
			
		||||
@@ -10,21 +10,26 @@ __all__ = ["get_model"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from xlayers.super_core import SuperSequential
 | 
			
		||||
from xlayers.super_core import SuperSimpleNorm
 | 
			
		||||
from xlayers.super_core import SuperLeakyReLU
 | 
			
		||||
from xlayers.super_core import SuperLinear
 | 
			
		||||
from xlayers.super_core import super_name2norm
 | 
			
		||||
from xlayers.super_core import super_name2activation
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_model(config: Dict[Text, Any], **kwargs):
 | 
			
		||||
    model_type = config.get("model_type", "simple_mlp")
 | 
			
		||||
    if model_type == "simple_mlp":
 | 
			
		||||
        act_cls = super_name2activation[kwargs["act_cls"]]
 | 
			
		||||
        norm_cls = super_name2norm[kwargs["norm_cls"]]
 | 
			
		||||
        mean, std = kwargs.get("mean", None), kwargs.get("std", None)
 | 
			
		||||
        hidden_dim1 = kwargs.get("hidden_dim1", 200)
 | 
			
		||||
        hidden_dim2 = kwargs.get("hidden_dim2", 100)
 | 
			
		||||
        model = SuperSequential(
 | 
			
		||||
            SuperSimpleNorm(kwargs["mean"], kwargs["std"]),
 | 
			
		||||
            SuperLinear(kwargs["input_dim"], 200),
 | 
			
		||||
            SuperLeakyReLU(),
 | 
			
		||||
            SuperLinear(200, 100),
 | 
			
		||||
            SuperLeakyReLU(),
 | 
			
		||||
            SuperLinear(100, kwargs["output_dim"]),
 | 
			
		||||
            norm_cls(mean=mean, std=std),
 | 
			
		||||
            SuperLinear(kwargs["input_dim"], hidden_dim1),
 | 
			
		||||
            act_cls(),
 | 
			
		||||
            SuperLinear(hidden_dim1, hidden_dim2),
 | 
			
		||||
            act_cls(),
 | 
			
		||||
            SuperLinear(hidden_dim2, kwargs["output_dim"]),
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        raise TypeError("Unkonwn model type: {:}".format(model_type))
 | 
			
		||||
 
 | 
			
		||||
@@ -9,13 +9,27 @@ from .super_module import SuperModule
 | 
			
		||||
from .super_container import SuperSequential
 | 
			
		||||
from .super_linear import SuperLinear
 | 
			
		||||
from .super_linear import SuperMLPv1, SuperMLPv2
 | 
			
		||||
 | 
			
		||||
from .super_norm import SuperSimpleNorm
 | 
			
		||||
from .super_norm import SuperLayerNorm1D
 | 
			
		||||
from .super_norm import SuperSimpleLearnableNorm
 | 
			
		||||
from .super_norm import SuperIdentity
 | 
			
		||||
 | 
			
		||||
super_name2norm = {
 | 
			
		||||
    "simple_norm": SuperSimpleNorm,
 | 
			
		||||
    "simple_learn_norm": SuperSimpleLearnableNorm,
 | 
			
		||||
    "layer_norm_1d": SuperLayerNorm1D,
 | 
			
		||||
    "identity": SuperIdentity,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
from .super_attention import SuperAttention
 | 
			
		||||
from .super_transformer import SuperTransformerEncoderLayer
 | 
			
		||||
 | 
			
		||||
from .super_activations import SuperReLU
 | 
			
		||||
from .super_activations import SuperLeakyReLU
 | 
			
		||||
 | 
			
		||||
super_name2activation = {"relu": SuperReLU, "leaky_relu": SuperLeakyReLU}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from .super_trade_stem import SuperAlphaEBDv1
 | 
			
		||||
from .super_positional_embedding import SuperPositionalEncoder
 | 
			
		||||
 
 | 
			
		||||
@@ -30,6 +30,45 @@ class SuperRunMode(Enum):
 | 
			
		||||
    Default = "fullmodel"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TensorContainer:
 | 
			
		||||
    """A class to maintain both parameters and buffers for a model."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self._names = []
 | 
			
		||||
        self._tensors = []
 | 
			
		||||
        self._param_or_buffers = []
 | 
			
		||||
        self._name2index = dict()
 | 
			
		||||
 | 
			
		||||
    def append(self, name, tensor, param_or_buffer):
 | 
			
		||||
        if not isinstance(tensor, torch.Tensor):
 | 
			
		||||
            raise TypeError(
 | 
			
		||||
                "The input tensor must be torch.Tensor instead of {:}".format(
 | 
			
		||||
                    type(tensor)
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        self._names.append(name)
 | 
			
		||||
        self._tensors.append(tensor)
 | 
			
		||||
        self._param_or_buffers.append(param_or_buffer)
 | 
			
		||||
        assert name not in self._name2index, "The [{:}] has already been added.".format(
 | 
			
		||||
            name
 | 
			
		||||
        )
 | 
			
		||||
        self._name2index[name] = len(self._names) - 1
 | 
			
		||||
 | 
			
		||||
    def numel(self):
 | 
			
		||||
        total = 0
 | 
			
		||||
        for tensor in self._tensors:
 | 
			
		||||
            total += tensor.numel()
 | 
			
		||||
        return total
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self._names)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return "{name}({num} tensors)".format(
 | 
			
		||||
            name=self.__class__.__name__, num=len(self)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SuperModule(abc.ABC, nn.Module):
 | 
			
		||||
    """This class equips the nn.Module class with the ability to apply AutoDL."""
 | 
			
		||||
 | 
			
		||||
@@ -71,6 +110,14 @@ class SuperModule(abc.ABC, nn.Module):
 | 
			
		||||
            )
 | 
			
		||||
        self._abstract_child = abstract_child
 | 
			
		||||
 | 
			
		||||
    def named_parameters_buffers(self):
 | 
			
		||||
        container = TensorContainer()
 | 
			
		||||
        for name, param in self.named_parameters():
 | 
			
		||||
            container.append(name, param, True)
 | 
			
		||||
        for name, buf in self.named_buffers():
 | 
			
		||||
            container.append(name, buf, False)
 | 
			
		||||
        return container
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def abstract_search_space(self):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 
 | 
			
		||||
@@ -89,8 +89,8 @@ class SuperSimpleNorm(SuperModule):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, mean, std, inplace=False) -> None:
 | 
			
		||||
        super(SuperSimpleNorm, self).__init__()
 | 
			
		||||
        self._mean = mean
 | 
			
		||||
        self._std = std
 | 
			
		||||
        self.register_buffer("_mean", torch.tensor(mean, dtype=torch.float))
 | 
			
		||||
        self.register_buffer("_std", torch.tensor(std, dtype=torch.float))
 | 
			
		||||
        self._inplace = inplace
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
@@ -111,7 +111,7 @@ class SuperSimpleNorm(SuperModule):
 | 
			
		||||
        if (std == 0).any():
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "std evaluated to zero after conversion to {}, leading to division by zero.".format(
 | 
			
		||||
                    dtype
 | 
			
		||||
                    tensor.dtype
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        while mean.ndim < tensor.ndim:
 | 
			
		||||
@@ -119,6 +119,75 @@ class SuperSimpleNorm(SuperModule):
 | 
			
		||||
        return tensor.sub_(mean).div_(std)
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self) -> str:
 | 
			
		||||
        return "mean={mean}, std={mean}, inplace={inplace}".format(
 | 
			
		||||
            mean=self._mean, std=self._std, inplace=self._inplace
 | 
			
		||||
        return "mean={mean}, std={std}, inplace={inplace}".format(
 | 
			
		||||
            mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SuperSimpleLearnableNorm(SuperModule):
 | 
			
		||||
    """Super simple normalization."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, mean=0, std=1, eps=1e-6, inplace=False) -> None:
 | 
			
		||||
        super(SuperSimpleLearnableNorm, self).__init__()
 | 
			
		||||
        self.register_parameter(
 | 
			
		||||
            "_mean", nn.Parameter(torch.tensor(mean, dtype=torch.float))
 | 
			
		||||
        )
 | 
			
		||||
        self.register_parameter(
 | 
			
		||||
            "_std", nn.Parameter(torch.tensor(std, dtype=torch.float))
 | 
			
		||||
        )
 | 
			
		||||
        self._eps = eps
 | 
			
		||||
        self._inplace = inplace
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def abstract_search_space(self):
 | 
			
		||||
        return spaces.VirtualNode(id(self))
 | 
			
		||||
 | 
			
		||||
    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        # check inputs ->
 | 
			
		||||
        return self.forward_raw(input)
 | 
			
		||||
 | 
			
		||||
    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        if not self._inplace:
 | 
			
		||||
            tensor = input.clone()
 | 
			
		||||
        else:
 | 
			
		||||
            tensor = input
 | 
			
		||||
        mean, std = (
 | 
			
		||||
            self._mean.to(tensor.device),
 | 
			
		||||
            torch.abs(self._std.to(tensor.device)) + self._eps,
 | 
			
		||||
        )
 | 
			
		||||
        if (std == 0).any():
 | 
			
		||||
            raise ValueError("std leads to division by zero.")
 | 
			
		||||
        while mean.ndim < tensor.ndim:
 | 
			
		||||
            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
 | 
			
		||||
        return tensor.sub_(mean).div_(std)
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self) -> str:
 | 
			
		||||
        return "mean={mean}, std={std}, inplace={inplace}".format(
 | 
			
		||||
            mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SuperIdentity(SuperModule):
 | 
			
		||||
    """Super identity mapping layer."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, inplace=False, **kwargs) -> None:
 | 
			
		||||
        super(SuperIdentity, self).__init__()
 | 
			
		||||
        self._inplace = inplace
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def abstract_search_space(self):
 | 
			
		||||
        return spaces.VirtualNode(id(self))
 | 
			
		||||
 | 
			
		||||
    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        # check inputs ->
 | 
			
		||||
        return self.forward_raw(input)
 | 
			
		||||
 | 
			
		||||
    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        if not self._inplace:
 | 
			
		||||
            tensor = input.clone()
 | 
			
		||||
        else:
 | 
			
		||||
            tensor = input
 | 
			
		||||
        return tensor
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self) -> str:
 | 
			
		||||
        return "inplace={inplace}".format(inplace=self._inplace)
 | 
			
		||||
 
 | 
			
		||||
@@ -51,3 +51,35 @@ class TestSuperSimpleNorm(unittest.TestCase):
 | 
			
		||||
        output_shape = (20, abstract_child["1"]["_out_features"].value)
 | 
			
		||||
        outputs = model(inputs)
 | 
			
		||||
        self.assertEqual(tuple(outputs.shape), output_shape)
 | 
			
		||||
 | 
			
		||||
    def test_super_simple_learn_norm(self):
 | 
			
		||||
        out_features = spaces.Categorical(12, 24, 36)
 | 
			
		||||
        bias = spaces.Categorical(True, False)
 | 
			
		||||
        model = super_core.SuperSequential(
 | 
			
		||||
            super_core.SuperSimpleLearnableNorm(),
 | 
			
		||||
            super_core.SuperIdentity(),
 | 
			
		||||
            super_core.SuperLinear(10, out_features, bias=bias),
 | 
			
		||||
        )
 | 
			
		||||
        print("The simple super module is:\n{:}".format(model))
 | 
			
		||||
        model.apply_verbose(True)
 | 
			
		||||
 | 
			
		||||
        print(model.super_run_type)
 | 
			
		||||
        self.assertTrue(model[1].bias)
 | 
			
		||||
 | 
			
		||||
        inputs = torch.rand(20, 10)
 | 
			
		||||
        print("Input shape: {:}".format(inputs.shape))
 | 
			
		||||
        outputs = model(inputs)
 | 
			
		||||
        self.assertEqual(tuple(outputs.shape), (20, 36))
 | 
			
		||||
 | 
			
		||||
        abstract_space = model.abstract_search_space
 | 
			
		||||
        abstract_space.clean_last()
 | 
			
		||||
        abstract_child = abstract_space.random()
 | 
			
		||||
        print("The abstract searc space:\n{:}".format(abstract_space))
 | 
			
		||||
        print("The abstract child program:\n{:}".format(abstract_child))
 | 
			
		||||
 | 
			
		||||
        model.set_super_run_type(super_core.SuperRunMode.Candidate)
 | 
			
		||||
        model.apply_candidate(abstract_child)
 | 
			
		||||
 | 
			
		||||
        output_shape = (20, abstract_child["1"]["_out_features"].value)
 | 
			
		||||
        outputs = model(inputs)
 | 
			
		||||
        self.assertEqual(tuple(outputs.shape), output_shape)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user