diff --git a/exps/LFNA/lfna-debug.py b/exps/LFNA/backup/lfna-debug.py similarity index 89% rename from exps/LFNA/lfna-debug.py rename to exps/LFNA/backup/lfna-debug.py index b5a3963..969d9d1 100644 --- a/exps/LFNA/lfna-debug.py +++ b/exps/LFNA/backup/lfna-debug.py @@ -25,6 +25,7 @@ from xlayers import super_core from lfna_utils import lfna_setup, train_model, TimeData +from lfna_models import HyperNet class LFNAmlp: @@ -77,17 +78,40 @@ def main(args): nkey = "{:}-{:}".format(i, xkey) assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys())) train_time_bar = total_time // 2 - network = get_model(dict(model_type="simple_mlp"), **model_kwargs) criterion = torch.nn.MSELoss() - logger.log("There are {:} weights.".format(network.get_w_container().numel())) + logger.log("There are {:} weights.".format(model.get_w_container().numel())) adaptor = LFNAmlp(args.meta_seq, (200, 200), "leaky_relu", criterion) # pre-train the model - init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"]) - init_loss = train_model(network, init_dataset, args.init_lr, args.epochs) + dataset = init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"]) + + shape_container = model.get_w_container().to_shape_container() + hypernet = HyperNet(shape_container, 16) + + optimizer = torch.optim.Adam(hypernet.parameters(), lr=args.init_lr, amsgrad=True) + + best_loss, best_param = None, None + for _iepoch in range(args.epochs): + container = hypernet(None) + + preds = model.forward_with_container(dataset.x, container) + optimizer.zero_grad() + loss = criterion(preds, dataset.y) + loss.backward() + optimizer.step() + # save best + if best_loss is None or best_loss > loss.item(): + best_loss = loss.item() + best_param = copy.deepcopy(model.state_dict()) + print("hyper-net : best={:.4f}".format(best_loss)) + + init_loss = train_model(model, init_dataset, args.init_lr, args.epochs) logger.log("The pre-training loss is {:.4f}".format(init_loss)) + import pdb + + pdb.set_trace() 
class HyperNet(super_core.SuperModule):
    """A hyper-network that generates the weights of another model.

    One learnable embedding vector is kept per layer of the target model; a
    small generator MLP maps each embedding to a flat weight vector, which is
    then reshaped into that layer's parameter shapes through the given
    ``shape_container``.
    """

    def __init__(self, shape_container, input_embeding, return_container=True):
        # NOTE(review): "input_embeding" is a misspelling of "embedding" but is
        # part of the public constructor signature, so it is kept as-is.
        super(HyperNet, self).__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        # Number of scalar weights required by each layer of the target model.
        self._numel_per_layer = [
            shape_container[index].numel() for index in range(self._num_layers)
        ]

        layer_embed = torch.nn.Parameter(
            torch.Tensor(self._num_layers, input_embeding)
        )
        self.register_parameter("_super_layer_embed", layer_embed)
        trunc_normal_(self._super_layer_embed, std=0.02)

        # The generator emits enough scalars for the LARGEST layer; smaller
        # layers consume a truncated prefix (see ShapeContainer.translate).
        self._generator = get_model(
            dict(model_type="simple_mlp"),
            input_dim=input_embeding,
            output_dim=max(self._numel_per_layer),
            hidden_dim=input_embeding * 4,
            act_cls="sigmoid",
            norm_cls="identity",
        )
        self._return_container = return_container
        print("generator: {:}".format(self._generator))

    def forward_raw(self, input):
        # ``input`` is deliberately unused: the generated weights depend only
        # on the learned per-layer embeddings.
        flat_weights = self._generator(self._super_layer_embed)
        if not self._return_container:
            return flat_weights
        # One (1, output_dim) row per layer, reshaped to the layer's shapes.
        per_layer = torch.split(flat_weights, 1)
        return self._shape_container.translate(per_layer)

    def forward_candidate(self, input):
        raise NotImplementedError

    def extra_repr(self) -> str:
        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
class SuperGELU(SuperModule):
    """Applies the Gaussian Error Linear Units function element-wise."""

    def __init__(self) -> None:
        super(SuperGELU, self).__init__()

    @property
    def abstract_search_space(self):
        # A plain activation exposes nothing searchable.
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return F.gelu(input)

    def forward_with_container(self, input, container, prefix=[]):
        # Stateless layer: nothing to look up in the weight container.
        # (prefix=[] mirrors the sibling activations' signatures; it is
        # never mutated here, so the mutable default is harmless.)
        return self.forward_raw(input)


class SuperSigmoid(SuperModule):
    """Applies the Sigmoid function element-wise."""

    def __init__(self) -> None:
        super(SuperSigmoid, self).__init__()

    @property
    def abstract_search_space(self):
        # A plain activation exposes nothing searchable.
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(input)

    def forward_with_container(self, input, container, prefix=[]):
        # Stateless layer: nothing to look up in the weight container.
        return self.forward_raw(input)
SuperLeakyReLU from .super_activations import SuperTanh +from .super_activations import SuperGELU +from .super_activations import SuperSigmoid super_name2activation = { "relu": SuperReLU, + "sigmoid": SuperSigmoid, + "gelu": SuperGELU, "leaky_relu": SuperLeakyReLU, "tanh": SuperTanh, } diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index 5a85c51..58a6993 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -11,128 +11,10 @@ from enum import Enum import spaces -IntSpaceType = Union[int, spaces.Integer, spaces.Categorical] -BoolSpaceType = Union[bool, spaces.Categorical] - - -class LayerOrder(Enum): - """This class defines the enumerations for order of operation in a residual or normalization-based layer.""" - - PreNorm = "pre-norm" - PostNorm = "post-norm" - - -class SuperRunMode(Enum): - """This class defines the enumerations for Super Model Running Mode.""" - - FullModel = "fullmodel" - Candidate = "candidate" - Default = "fullmodel" - - -class TensorContainer: - """A class to maintain both parameters and buffers for a model.""" - - def __init__(self): - self._names = [] - self._tensors = [] - self._param_or_buffers = [] - self._name2index = dict() - - def additive(self, tensors): - result = TensorContainer() - for index, name in enumerate(self._names): - new_tensor = self._tensors[index] + tensors[index] - result.append(name, new_tensor, self._param_or_buffers[index]) - return result - - def create_container(self, tensors): - result = TensorContainer() - for index, name in enumerate(self._names): - new_tensor = tensors[index] - result.append(name, new_tensor, self._param_or_buffers[index]) - return result - - def no_grad_clone(self): - result = TensorContainer() - with torch.no_grad(): - for index, name in enumerate(self._names): - result.append( - name, self._tensors[index].clone(), self._param_or_buffers[index] - ) - return result - - def requires_grad_(self, requires_grad=True): - for tensor in self._tensors: - 
tensor.requires_grad_(requires_grad) - - def parameters(self): - return self._tensors - - @property - def tensors(self): - return self._tensors - - def flatten(self, tensors=None): - if tensors is None: - tensors = self._tensors - tensors = [tensor.view(-1) for tensor in tensors] - return torch.cat(tensors) - - def unflatten(self, tensor): - tensors, s = [], 0 - for raw_tensor in self._tensors: - length = raw_tensor.numel() - x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape) - tensors.append(x) - s += length - return tensors - - def append(self, name, tensor, param_or_buffer): - if not isinstance(tensor, torch.Tensor): - raise TypeError( - "The input tensor must be torch.Tensor instead of {:}".format( - type(tensor) - ) - ) - self._names.append(name) - self._tensors.append(tensor) - self._param_or_buffers.append(param_or_buffer) - assert name not in self._name2index, "The [{:}] has already been added.".format( - name - ) - self._name2index[name] = len(self._names) - 1 - - def query(self, name): - if not self.has(name): - raise ValueError( - "The {:} is not in {:}".format(name, list(self._name2index.keys())) - ) - index = self._name2index[name] - return self._tensors[index] - - def has(self, name): - return name in self._name2index - - def has_prefix(self, prefix): - for name, idx in self._name2index.items(): - if name.startswith(prefix): - return name - return False - - def numel(self): - total = 0 - for tensor in self._tensors: - total += tensor.numel() - return total - - def __len__(self): - return len(self._names) - - def __repr__(self): - return "{name}({num} tensors)".format( - name=self.__class__.__name__, num=len(self) - ) +from .super_utils import IntSpaceType, BoolSpaceType +from .super_utils import LayerOrder, SuperRunMode +from .super_utils import TensorContainer +from .super_utils import ShapeContainer class SuperModule(abc.ABC, nn.Module): diff --git a/lib/xlayers/super_utils.py b/lib/xlayers/super_utils.py new file mode 100644 index 
class ShapeContainer:
    """Maintain the shape (``torch.Size``) of each weight tensor of a model.

    It mirrors :class:`TensorContainer` but stores only shapes, so that a
    hyper-network can emit flat weight vectors and have them folded back into
    a model's parameter layout via :meth:`translate`.
    """

    def __init__(self):
        self._names = []
        self._shapes = []
        self._name2index = dict()
        self._param_or_buffers = []

    @property
    def shapes(self):
        return self._shapes

    def __getitem__(self, index):
        return self._shapes[index]

    def translate(self, tensors, all_none_match=True):
        """Reshape ``tensors`` into a TensorContainer matching the stored shapes.

        Args:
            tensors: one tensor per stored shape. Each must hold at least as
                many elements as its shape requires; surplus elements are
                silently dropped (flatten then truncate before reshaping).
            all_none_match: when True (default), tensors with MORE elements
                than required are accepted; when False any mismatch raises.

        Returns:
            TensorContainer: the reshaped tensors, preserving the
            parameter-or-buffer flag of each entry.

        Raises:
            ValueError: if a tensor holds too few elements, or (with
                ``all_none_match=False``) a non-exact number of elements.
        """
        result = TensorContainer()
        for index, name in enumerate(self._names):
            cur_num = tensors[index].numel()
            expected_num = self._shapes[index].numel()
            # Too few elements can never fill the shape; too many is only
            # tolerated in the (default) permissive mode.
            if cur_num < expected_num or (
                cur_num > expected_num and not all_none_match
            ):
                raise ValueError("Invalid {:} vs {:}".format(cur_num, expected_num))
            cur_tensor = tensors[index].view(-1)[:expected_num]
            new_tensor = torch.reshape(cur_tensor, self._shapes[index])
            result.append(name, new_tensor, self._param_or_buffers[index])
        return result

    def append(self, name, shape, param_or_buffer):
        """Register ``shape`` under the unique key ``name``."""
        if not isinstance(shape, torch.Size):
            # BUG FIX: the message previously said "input tensor" although this
            # method validates a shape, not a tensor.
            raise TypeError(
                "The input shape must be torch.Size instead of {:}".format(type(shape))
            )
        self._names.append(name)
        self._shapes.append(shape)
        self._param_or_buffers.append(param_or_buffer)
        assert name not in self._name2index, "The [{:}] has already been added.".format(
            name
        )
        self._name2index[name] = len(self._names) - 1

    def query(self, name):
        """Return the shape registered under ``name`` (raise if absent)."""
        if not self.has(name):
            raise ValueError(
                "The {:} is not in {:}".format(name, list(self._name2index.keys()))
            )
        index = self._name2index[name]
        return self._shapes[index]

    def has(self, name):
        return name in self._name2index

    def has_prefix(self, prefix):
        # Returns the FIRST matching name, or False — callers rely on the
        # truthy-name / falsy-miss convention.
        for name, idx in self._name2index.items():
            if name.startswith(prefix):
                return name
        return False

    def numel(self, index=None):
        """Total element count over all shapes, or over ``self[index]`` only."""
        if index is None:
            shapes = self._shapes
        else:
            shapes = [self._shapes[index]]
        total = 0
        for shape in shapes:
            total += shape.numel()
        return total

    def __len__(self):
        return len(self._names)

    def __repr__(self):
        return "{name}({num} tensors)".format(
            name=self.__class__.__name__, num=len(self)
        )


class TensorContainer:
    """Maintain both parameters and buffers for a model, keyed by name.

    Entries keep their registration order; each carries a boolean flag
    telling whether it is a parameter (True) or a buffer (False).
    """

    def __init__(self):
        self._names = []
        self._tensors = []
        self._param_or_buffers = []
        self._name2index = dict()

    def additive(self, tensors):
        """Return a new container whose entries are ``self[i] + tensors[i]``."""
        result = TensorContainer()
        for index, name in enumerate(self._names):
            new_tensor = self._tensors[index] + tensors[index]
            result.append(name, new_tensor, self._param_or_buffers[index])
        return result

    def create_container(self, tensors):
        """Return a new container with ``tensors`` under this layout's names."""
        result = TensorContainer()
        for index, name in enumerate(self._names):
            new_tensor = tensors[index]
            result.append(name, new_tensor, self._param_or_buffers[index])
        return result

    def no_grad_clone(self):
        """Deep-copy all tensors outside the autograd graph."""
        result = TensorContainer()
        with torch.no_grad():
            for index, name in enumerate(self._names):
                result.append(
                    name, self._tensors[index].clone(), self._param_or_buffers[index]
                )
        return result

    def to_shape_container(self):
        """Project this container onto a ShapeContainer (shapes only)."""
        result = ShapeContainer()
        for index, name in enumerate(self._names):
            result.append(
                name, self._tensors[index].shape, self._param_or_buffers[index]
            )
        return result

    def requires_grad_(self, requires_grad=True):
        """Toggle autograd tracking on every stored tensor, in place."""
        for tensor in self._tensors:
            tensor.requires_grad_(requires_grad)

    def parameters(self):
        return self._tensors

    @property
    def tensors(self):
        return self._tensors

    def flatten(self, tensors=None):
        """Concatenate (a view of) all tensors into one 1-D tensor."""
        if tensors is None:
            tensors = self._tensors
        tensors = [tensor.view(-1) for tensor in tensors]
        return torch.cat(tensors)

    def unflatten(self, tensor):
        """Inverse of :meth:`flatten`: split ``tensor`` back into this layout."""
        tensors, s = [], 0
        for raw_tensor in self._tensors:
            length = raw_tensor.numel()
            x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape)
            tensors.append(x)
            s += length
        return tensors

    def append(self, name, tensor, param_or_buffer):
        """Register ``tensor`` under the unique key ``name``."""
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(
                "The input tensor must be torch.Tensor instead of {:}".format(
                    type(tensor)
                )
            )
        self._names.append(name)
        self._tensors.append(tensor)
        self._param_or_buffers.append(param_or_buffer)
        assert name not in self._name2index, "The [{:}] has already been added.".format(
            name
        )
        self._name2index[name] = len(self._names) - 1

    def query(self, name):
        """Return the tensor registered under ``name`` (raise if absent)."""
        if not self.has(name):
            raise ValueError(
                "The {:} is not in {:}".format(name, list(self._name2index.keys()))
            )
        index = self._name2index[name]
        return self._tensors[index]

    def has(self, name):
        return name in self._name2index

    def has_prefix(self, prefix):
        # Returns the FIRST matching name, or False — same convention as
        # ShapeContainer.has_prefix.
        for name, idx in self._name2index.items():
            if name.startswith(prefix):
                return name
        return False

    def numel(self):
        """Total number of scalar elements across all stored tensors."""
        total = 0
        for tensor in self._tensors:
            total += tensor.numel()
        return total

    def __len__(self):
        return len(self._names)

    def __repr__(self):
        return "{name}({num} tensors)".format(
            name=self.__class__.__name__, num=len(self)
        )