Update super-activation layers
		| @@ -25,6 +25,7 @@ from xlayers import super_core | ||||
|  | ||||
|  | ||||
| from lfna_utils import lfna_setup, train_model, TimeData | ||||
| from lfna_models import HyperNet | ||||
|  | ||||
|  | ||||
| class LFNAmlp: | ||||
| @@ -77,17 +78,40 @@ def main(args): | ||||
|             nkey = "{:}-{:}".format(i, xkey) | ||||
|             assert nkey in env_info, "{:} not in {:}".format(nkey, list(env_info.keys())) | ||||
|     train_time_bar = total_time // 2 | ||||
|     network = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|  | ||||
|     criterion = torch.nn.MSELoss() | ||||
|     logger.log("There are {:} weights.".format(network.get_w_container().numel())) | ||||
|     logger.log("There are {:} weights.".format(model.get_w_container().numel())) | ||||
|  | ||||
|     adaptor = LFNAmlp(args.meta_seq, (200, 200), "leaky_relu", criterion) | ||||
|  | ||||
|     # pre-train the model | ||||
|     init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"]) | ||||
|     init_loss = train_model(network, init_dataset, args.init_lr, args.epochs) | ||||
|     dataset = init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"]) | ||||
|  | ||||
|     shape_container = model.get_w_container().to_shape_container() | ||||
|     hypernet = HyperNet(shape_container, 16) | ||||
|  | ||||
|     optimizer = torch.optim.Adam(hypernet.parameters(), lr=args.init_lr, amsgrad=True) | ||||
|  | ||||
|     best_loss, best_param = None, None | ||||
|     for _iepoch in range(args.epochs): | ||||
|         container = hypernet(None) | ||||
|  | ||||
|         preds = model.forward_with_container(dataset.x, container) | ||||
|         optimizer.zero_grad() | ||||
|         loss = criterion(preds, dataset.y) | ||||
|         loss.backward() | ||||
|         optimizer.step() | ||||
|         # save best | ||||
|         if best_loss is None or best_loss > loss.item(): | ||||
|             best_loss = loss.item() | ||||
|             best_param = copy.deepcopy(model.state_dict()) | ||||
|     print("hyper-net : best={:.4f}".format(best_loss)) | ||||
|  | ||||
|     init_loss = train_model(model, init_dataset, args.init_lr, args.epochs) | ||||
|     logger.log("The pre-training loss is {:.4f}".format(init_loss)) | ||||
|     import pdb | ||||
|  | ||||
|     pdb.set_trace() | ||||
|  | ||||
|     all_past_containers = [] | ||||
|     ground_truth_path = ( | ||||
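The hunk above pre-trains a hypernetwork: the HyperNet emits a weight container, the target model is evaluated with those externally supplied weights via `forward_with_container`, and only the hypernetwork's parameters are optimized. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; `TinyHyperNet`, `target_forward`, and every hyper-parameter are invented for illustration and are not the repository's code.

```python
import copy
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Target "model": a single linear map y = x @ W^T + b, described only by shapes.
target_shapes = [torch.Size([1, 1]), torch.Size([1])]
total_numel = sum(s.numel() for s in target_shapes)   # 2 values to generate


class TinyHyperNet(torch.nn.Module):
    """A learnable embedding pushed through an MLP that emits all target weights."""

    def __init__(self, embed_dim: int, out_dim: int):
        super().__init__()
        self.embed = torch.nn.Parameter(torch.randn(1, embed_dim) * 0.02)
        self.generator = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, 4 * embed_dim),
            torch.nn.Sigmoid(),
            torch.nn.Linear(4 * embed_dim, out_dim),
        )

    def forward(self) -> torch.Tensor:
        return self.generator(self.embed).view(-1)     # flat vector of generated weights


def target_forward(x: torch.Tensor, flat: torch.Tensor) -> torch.Tensor:
    # Reshape the generated vector into the target model's weight and bias.
    w = flat[:1].view(1, 1)
    b = flat[1:2].view(1)
    return F.linear(x, w, b)


hypernet = TinyHyperNet(embed_dim=16, out_dim=total_numel)
optimizer = torch.optim.Adam(hypernet.parameters(), lr=0.01, amsgrad=True)
criterion = torch.nn.MSELoss()

x = torch.linspace(-1, 1, 64).view(-1, 1)
y = 2.0 * x + 0.5                                      # toy regression target

best_loss, best_state = None, None
for _ in range(200):
    preds = target_forward(x, hypernet())
    optimizer.zero_grad()
    loss = criterion(preds, y)
    loss.backward()
    optimizer.step()
    if best_loss is None or loss.item() < best_loss:   # keep the best snapshot
        best_loss = loss.item()
        best_state = copy.deepcopy(hypernet.state_dict())
print("hyper-net (sketch): best={:.4f}".format(best_loss))
```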
exps/LFNA/backup/lfna_models.py (new file, 50 lines)
							| @@ -0,0 +1,50 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import copy | ||||
| import torch | ||||
|  | ||||
| from xlayers import super_core | ||||
| from xlayers import trunc_normal_ | ||||
| from models.xcore import get_model | ||||
|  | ||||
|  | ||||
| class HyperNet(super_core.SuperModule): | ||||
|     def __init__(self, shape_container, input_embeding, return_container=True): | ||||
|         super(HyperNet, self).__init__() | ||||
|         self._shape_container = shape_container | ||||
|         self._num_layers = len(shape_container) | ||||
|         self._numel_per_layer = [] | ||||
|         for ilayer in range(self._num_layers): | ||||
|             self._numel_per_layer.append(shape_container[ilayer].numel()) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)), | ||||
|         ) | ||||
|         trunc_normal_(self._super_layer_embed, std=0.02) | ||||
|  | ||||
|         model_kwargs = dict( | ||||
|             input_dim=input_embeding, | ||||
|             output_dim=max(self._numel_per_layer), | ||||
|             hidden_dim=input_embeding * 4, | ||||
|             act_cls="sigmoid", | ||||
|             norm_cls="identity", | ||||
|         ) | ||||
|         self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         self._return_container = return_container | ||||
|         print("generator: {:}".format(self._generator)) | ||||
|  | ||||
|     def forward_raw(self, input): | ||||
|         weights = self._generator(self._super_layer_embed) | ||||
|         if self._return_container: | ||||
|             weights = torch.split(weights, 1) | ||||
|             return self._shape_container.translate(weights) | ||||
|         else: | ||||
|             return weights | ||||
|  | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape)) | ||||
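A hypothetical usage sketch of this class, stitched together from the calls the training script above actually makes (`get_model`, `get_w_container().to_shape_container()`, `forward_with_container`); the concrete `model_kwargs` values below are guesses, not taken from the repository.

```python
import torch
from models.xcore import get_model
from lfna_models import HyperNet

# Build a target model; these kwargs mirror the generator's kwargs in HyperNet
# and are only an assumption about what the training script passes.
model = get_model(
    dict(model_type="simple_mlp"),
    input_dim=1, output_dim=1, hidden_dim=16,
    act_cls="leaky_relu", norm_cls="identity",
)

# Record the shape of every weight so HyperNet knows how much to generate per layer.
shape_container = model.get_w_container().to_shape_container()
hypernet = HyperNet(shape_container, 16)         # 16-dim per-layer embedding

container = hypernet(None)                        # TensorContainer of generated weights
preds = model.forward_with_container(torch.randn(8, 1), container)
```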
| @@ -38,6 +38,46 @@ class SuperReLU(SuperModule): | ||||
|         return "inplace=True" if self._inplace else "" | ||||
|  | ||||
|  | ||||
| class SuperGELU(SuperModule): | ||||
|     """Applies the Gaussian Error Linear Units function element-wise.""" | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         super(SuperGELU, self).__init__() | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         return spaces.VirtualNode(id(self)) | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return self.forward_raw(input) | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return F.gelu(input) | ||||
|  | ||||
|     def forward_with_container(self, input, container, prefix=[]): | ||||
|         return self.forward_raw(input) | ||||
|  | ||||
|  | ||||
| class SuperSigmoid(SuperModule): | ||||
|     """Applies the Sigmoid function element-wise.""" | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         super(SuperSigmoid, self).__init__() | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         return spaces.VirtualNode(id(self)) | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return self.forward_raw(input) | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return torch.sigmoid(input) | ||||
|  | ||||
|     def forward_with_container(self, input, container, prefix=[]): | ||||
|         return self.forward_raw(input) | ||||
|  | ||||
|  | ||||
| class SuperLeakyReLU(SuperModule): | ||||
|     """https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#LeakyReLU""" | ||||
|  | ||||
|   | ||||
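Both new activations are parameter-free, so `forward_candidate` simply defers to `forward_raw` and the abstract search space is an empty virtual node. A plain-PyTorch sketch of that dispatch pattern (this is not the repository's `SuperModule`, only an illustration of the idea):

```python
import torch
import torch.nn.functional as F


class SketchSuperGELU(torch.nn.Module):
    """Parameter-free activation: there is nothing to search, so both paths coincide."""

    def forward_raw(self, x: torch.Tensor) -> torch.Tensor:
        return F.gelu(x)

    def forward_candidate(self, x: torch.Tensor) -> torch.Tensor:
        # No architectural choices to sample, so the candidate (search-time)
        # path reuses the raw (full-model) path.
        return self.forward_raw(x)

    def forward(self, x: torch.Tensor, candidate: bool = False) -> torch.Tensor:
        return self.forward_candidate(x) if candidate else self.forward_raw(x)


x = torch.randn(4)
assert torch.allclose(SketchSuperGELU()(x), F.gelu(x))
```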
| @@ -28,9 +28,13 @@ from .super_transformer import SuperTransformerEncoderLayer | ||||
| from .super_activations import SuperReLU | ||||
| from .super_activations import SuperLeakyReLU | ||||
| from .super_activations import SuperTanh | ||||
| from .super_activations import SuperGELU | ||||
| from .super_activations import SuperSigmoid | ||||
|  | ||||
| super_name2activation = { | ||||
|     "relu": SuperReLU, | ||||
|     "sigmoid": SuperSigmoid, | ||||
|     "gelu": SuperGELU, | ||||
|     "leaky_relu": SuperLeakyReLU, | ||||
|     "tanh": SuperTanh, | ||||
| } | ||||
|   | ||||
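With the registry extended, an activation can be constructed from its string name, which is presumably how `act_cls="sigmoid"` in the HyperNet generator kwargs gets resolved. A short hypothetical lookup, assuming `super_name2activation` is reachable through the `xlayers.super_core` import used elsewhere in this commit:

```python
import torch
from xlayers import super_core

act_cls = super_core.super_name2activation["sigmoid"]   # -> SuperSigmoid
act = act_cls()
out = act(torch.randn(2, 3))                             # element-wise sigmoid
```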
| @@ -11,128 +11,10 @@ from enum import Enum | ||||
|  | ||||
| import spaces | ||||
|  | ||||
| IntSpaceType = Union[int, spaces.Integer, spaces.Categorical] | ||||
| BoolSpaceType = Union[bool, spaces.Categorical] | ||||
|  | ||||
|  | ||||
| class LayerOrder(Enum): | ||||
|     """This class defines the enumerations for order of operation in a residual or normalization-based layer.""" | ||||
|  | ||||
|     PreNorm = "pre-norm" | ||||
|     PostNorm = "post-norm" | ||||
|  | ||||
|  | ||||
| class SuperRunMode(Enum): | ||||
|     """This class defines the enumerations for Super Model Running Mode.""" | ||||
|  | ||||
|     FullModel = "fullmodel" | ||||
|     Candidate = "candidate" | ||||
|     Default = "fullmodel" | ||||
|  | ||||
|  | ||||
| class TensorContainer: | ||||
|     """A class to maintain both parameters and buffers for a model.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._names = [] | ||||
|         self._tensors = [] | ||||
|         self._param_or_buffers = [] | ||||
|         self._name2index = dict() | ||||
|  | ||||
|     def additive(self, tensors): | ||||
|         result = TensorContainer() | ||||
|         for index, name in enumerate(self._names): | ||||
|             new_tensor = self._tensors[index] + tensors[index] | ||||
|             result.append(name, new_tensor, self._param_or_buffers[index]) | ||||
|         return result | ||||
|  | ||||
|     def create_container(self, tensors): | ||||
|         result = TensorContainer() | ||||
|         for index, name in enumerate(self._names): | ||||
|             new_tensor = tensors[index] | ||||
|             result.append(name, new_tensor, self._param_or_buffers[index]) | ||||
|         return result | ||||
|  | ||||
|     def no_grad_clone(self): | ||||
|         result = TensorContainer() | ||||
|         with torch.no_grad(): | ||||
|             for index, name in enumerate(self._names): | ||||
|                 result.append( | ||||
|                     name, self._tensors[index].clone(), self._param_or_buffers[index] | ||||
|                 ) | ||||
|         return result | ||||
|  | ||||
|     def requires_grad_(self, requires_grad=True): | ||||
|         for tensor in self._tensors: | ||||
|             tensor.requires_grad_(requires_grad) | ||||
|  | ||||
|     def parameters(self): | ||||
|         return self._tensors | ||||
|  | ||||
|     @property | ||||
|     def tensors(self): | ||||
|         return self._tensors | ||||
|  | ||||
|     def flatten(self, tensors=None): | ||||
|         if tensors is None: | ||||
|             tensors = self._tensors | ||||
|         tensors = [tensor.view(-1) for tensor in tensors] | ||||
|         return torch.cat(tensors) | ||||
|  | ||||
|     def unflatten(self, tensor): | ||||
|         tensors, s = [], 0 | ||||
|         for raw_tensor in self._tensors: | ||||
|             length = raw_tensor.numel() | ||||
|             x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape) | ||||
|             tensors.append(x) | ||||
|             s += length | ||||
|         return tensors | ||||
|  | ||||
|     def append(self, name, tensor, param_or_buffer): | ||||
|         if not isinstance(tensor, torch.Tensor): | ||||
|             raise TypeError( | ||||
|                 "The input tensor must be torch.Tensor instead of {:}".format( | ||||
|                     type(tensor) | ||||
|                 ) | ||||
|             ) | ||||
|         self._names.append(name) | ||||
|         self._tensors.append(tensor) | ||||
|         self._param_or_buffers.append(param_or_buffer) | ||||
|         assert name not in self._name2index, "The [{:}] has already been added.".format( | ||||
|             name | ||||
|         ) | ||||
|         self._name2index[name] = len(self._names) - 1 | ||||
|  | ||||
|     def query(self, name): | ||||
|         if not self.has(name): | ||||
|             raise ValueError( | ||||
|                 "The {:} is not in {:}".format(name, list(self._name2index.keys())) | ||||
|             ) | ||||
|         index = self._name2index[name] | ||||
|         return self._tensors[index] | ||||
|  | ||||
|     def has(self, name): | ||||
|         return name in self._name2index | ||||
|  | ||||
|     def has_prefix(self, prefix): | ||||
|         for name, idx in self._name2index.items(): | ||||
|             if name.startswith(prefix): | ||||
|                 return name | ||||
|         return False | ||||
|  | ||||
|     def numel(self): | ||||
|         total = 0 | ||||
|         for tensor in self._tensors: | ||||
|             total += tensor.numel() | ||||
|         return total | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._names) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({num} tensors)".format( | ||||
|             name=self.__class__.__name__, num=len(self) | ||||
|         ) | ||||
| from .super_utils import IntSpaceType, BoolSpaceType | ||||
| from .super_utils import LayerOrder, SuperRunMode | ||||
| from .super_utils import TensorContainer | ||||
| from .super_utils import ShapeContainer | ||||
|  | ||||
|  | ||||
| class SuperModule(abc.ABC, nn.Module): | ||||
|   | ||||
lib/xlayers/super_utils.py (new file, 222 lines)
							| @@ -0,0 +1,222 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
|  | ||||
| import abc | ||||
| import warnings | ||||
| from typing import Optional, Union, Callable | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from enum import Enum | ||||
|  | ||||
| import spaces | ||||
|  | ||||
| IntSpaceType = Union[int, spaces.Integer, spaces.Categorical] | ||||
| BoolSpaceType = Union[bool, spaces.Categorical] | ||||
|  | ||||
|  | ||||
| class LayerOrder(Enum): | ||||
|     """This class defines the enumerations for order of operation in a residual or normalization-based layer.""" | ||||
|  | ||||
|     PreNorm = "pre-norm" | ||||
|     PostNorm = "post-norm" | ||||
|  | ||||
|  | ||||
| class SuperRunMode(Enum): | ||||
|     """This class defines the enumerations for Super Model Running Mode.""" | ||||
|  | ||||
|     FullModel = "fullmodel" | ||||
|     Candidate = "candidate" | ||||
|     Default = "fullmodel" | ||||
|  | ||||
|  | ||||
| class ShapeContainer: | ||||
|     """A class to maintain the shape of each weight tensor for a model.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._names = [] | ||||
|         self._shapes = [] | ||||
|         self._name2index = dict() | ||||
|         self._param_or_buffers = [] | ||||
|  | ||||
|     @property | ||||
|     def shapes(self): | ||||
|         return self._shapes | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         return self._shapes[index] | ||||
|  | ||||
|     def translate(self, tensors, all_none_match=True): | ||||
|         result = TensorContainer() | ||||
|         for index, name in enumerate(self._names): | ||||
|             cur_num = tensors[index].numel() | ||||
|             expected_num = self._shapes[index].numel() | ||||
|             if cur_num < expected_num or ( | ||||
|                 cur_num > expected_num and not all_none_match | ||||
|             ): | ||||
|                 raise ValueError("Invalid {:} vs {:}".format(cur_num, expected_num)) | ||||
|             cur_tensor = tensors[index].view(-1)[:expected_num] | ||||
|             new_tensor = torch.reshape(cur_tensor, self._shapes[index]) | ||||
|             result.append(name, new_tensor, self._param_or_buffers[index]) | ||||
|         return result | ||||
|  | ||||
|     def append(self, name, shape, param_or_buffer): | ||||
|         if not isinstance(shape, torch.Size): | ||||
|             raise TypeError( | ||||
|                 "The input shape must be torch.Size instead of {:}".format(type(shape)) | ||||
|             ) | ||||
|         self._names.append(name) | ||||
|         self._shapes.append(shape) | ||||
|         self._param_or_buffers.append(param_or_buffer) | ||||
|         assert name not in self._name2index, "The [{:}] has already been added.".format( | ||||
|             name | ||||
|         ) | ||||
|         self._name2index[name] = len(self._names) - 1 | ||||
|  | ||||
|     def query(self, name): | ||||
|         if not self.has(name): | ||||
|             raise ValueError( | ||||
|                 "The {:} is not in {:}".format(name, list(self._name2index.keys())) | ||||
|             ) | ||||
|         index = self._name2index[name] | ||||
|         return self._shapes[index] | ||||
|  | ||||
|     def has(self, name): | ||||
|         return name in self._name2index | ||||
|  | ||||
|     def has_prefix(self, prefix): | ||||
|         for name, idx in self._name2index.items(): | ||||
|             if name.startswith(prefix): | ||||
|                 return name | ||||
|         return False | ||||
|  | ||||
|     def numel(self, index=None): | ||||
|         if index is None: | ||||
|             shapes = self._shapes | ||||
|         else: | ||||
|             shapes = [self._shapes[index]] | ||||
|         total = 0 | ||||
|         for shape in shapes: | ||||
|             total += shape.numel() | ||||
|         return total | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._names) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({num} tensors)".format( | ||||
|             name=self.__class__.__name__, num=len(self) | ||||
|         ) | ||||
|  | ||||
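`translate` is the bridge between a generator's fixed-width output and the per-layer weight shapes: each supplied tensor is flattened, truncated to the layer's expected element count, and reshaped. A plain-torch illustration of that logic (all names and sizes below are made up):

```python
import torch

shapes = [torch.Size([3, 2]), torch.Size([4])]      # per-layer target shapes
max_numel = max(s.numel() for s in shapes)           # 6: width of the generated matrix
generated = torch.randn(len(shapes), max_numel)      # e.g. a hypernet's raw output
rows = torch.split(generated, 1)                      # one (1, 6) tensor per layer

translated = []
for row, shape in zip(rows, shapes):
    flat = row.view(-1)[: shape.numel()]              # drop the padded tail
    translated.append(flat.reshape(shape))

assert [t.shape for t in translated] == list(shapes)
```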
|  | ||||
| class TensorContainer: | ||||
|     """A class to maintain both parameters and buffers for a model.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._names = [] | ||||
|         self._tensors = [] | ||||
|         self._param_or_buffers = [] | ||||
|         self._name2index = dict() | ||||
|  | ||||
|     def additive(self, tensors): | ||||
|         result = TensorContainer() | ||||
|         for index, name in enumerate(self._names): | ||||
|             new_tensor = self._tensors[index] + tensors[index] | ||||
|             result.append(name, new_tensor, self._param_or_buffers[index]) | ||||
|         return result | ||||
|  | ||||
|     def create_container(self, tensors): | ||||
|         result = TensorContainer() | ||||
|         for index, name in enumerate(self._names): | ||||
|             new_tensor = tensors[index] | ||||
|             result.append(name, new_tensor, self._param_or_buffers[index]) | ||||
|         return result | ||||
|  | ||||
|     def no_grad_clone(self): | ||||
|         result = TensorContainer() | ||||
|         with torch.no_grad(): | ||||
|             for index, name in enumerate(self._names): | ||||
|                 result.append( | ||||
|                     name, self._tensors[index].clone(), self._param_or_buffers[index] | ||||
|                 ) | ||||
|         return result | ||||
|  | ||||
|     def to_shape_container(self): | ||||
|         result = ShapeContainer() | ||||
|         for index, name in enumerate(self._names): | ||||
|             result.append( | ||||
|                 name, self._tensors[index].shape, self._param_or_buffers[index] | ||||
|             ) | ||||
|         return result | ||||
|  | ||||
|     def requires_grad_(self, requires_grad=True): | ||||
|         for tensor in self._tensors: | ||||
|             tensor.requires_grad_(requires_grad) | ||||
|  | ||||
|     def parameters(self): | ||||
|         return self._tensors | ||||
|  | ||||
|     @property | ||||
|     def tensors(self): | ||||
|         return self._tensors | ||||
|  | ||||
|     def flatten(self, tensors=None): | ||||
|         if tensors is None: | ||||
|             tensors = self._tensors | ||||
|         tensors = [tensor.view(-1) for tensor in tensors] | ||||
|         return torch.cat(tensors) | ||||
|  | ||||
|     def unflatten(self, tensor): | ||||
|         tensors, s = [], 0 | ||||
|         for raw_tensor in self._tensors: | ||||
|             length = raw_tensor.numel() | ||||
|             x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape) | ||||
|             tensors.append(x) | ||||
|             s += length | ||||
|         return tensors | ||||
|  | ||||
|     def append(self, name, tensor, param_or_buffer): | ||||
|         if not isinstance(tensor, torch.Tensor): | ||||
|             raise TypeError( | ||||
|                 "The input tensor must be torch.Tensor instead of {:}".format( | ||||
|                     type(tensor) | ||||
|                 ) | ||||
|             ) | ||||
|         self._names.append(name) | ||||
|         self._tensors.append(tensor) | ||||
|         self._param_or_buffers.append(param_or_buffer) | ||||
|         assert name not in self._name2index, "The [{:}] has already been added.".format( | ||||
|             name | ||||
|         ) | ||||
|         self._name2index[name] = len(self._names) - 1 | ||||
|  | ||||
|     def query(self, name): | ||||
|         if not self.has(name): | ||||
|             raise ValueError( | ||||
|                 "The {:} is not in {:}".format(name, list(self._name2index.keys())) | ||||
|             ) | ||||
|         index = self._name2index[name] | ||||
|         return self._tensors[index] | ||||
|  | ||||
|     def has(self, name): | ||||
|         return name in self._name2index | ||||
|  | ||||
|     def has_prefix(self, prefix): | ||||
|         for name, idx in self._name2index.items(): | ||||
|             if name.startswith(prefix): | ||||
|                 return name | ||||
|         return False | ||||
|  | ||||
|     def numel(self): | ||||
|         total = 0 | ||||
|         for tensor in self._tensors: | ||||
|             total += tensor.numel() | ||||
|         return total | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._names) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({num} tensors)".format( | ||||
|             name=self.__class__.__name__, num=len(self) | ||||
|         ) | ||||
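A short hypothetical usage of the container classes with a plain `nn.Module` (the toy `nn.Linear` and the manual `append` loop are only illustrative; in the repository these containers are produced by the models' own `get_w_container` helper):

```python
import torch
import torch.nn as nn
from xlayers.super_utils import TensorContainer   # module path as added in this commit

net = nn.Linear(4, 2)
container = TensorContainer()
for name, param in net.named_parameters():
    container.append(name, param, True)             # True: entry is a parameter
for name, buf in net.named_buffers():
    container.append(name, buf, False)              # False: entry is a buffer

print(container)                                     # TensorContainer(2 tensors)
flat = container.flatten()                           # single 1-D vector, 10 values
pieces = container.unflatten(flat)                   # back to per-tensor shapes
assert all(p.shape == t.shape for p, t in zip(pieces, container.tensors))

shape_container = container.to_shape_container()     # shapes only, e.g. for a HyperNet
print(shape_container.numel())                       # 10
```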