Update super-activation layers
@@ -25,6 +25,7 @@ from xlayers import super_core
 
 
 from lfna_utils import lfna_setup, train_model, TimeData
+from lfna_models import HyperNet
 
 
 class LFNAmlp:
@@ -77,17 +78,40 @@ def main(args):
             nkey = "{:}-{:}".format(i, xkey)
             assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
     train_time_bar = total_time // 2
-    network = get_model(dict(model_type="simple_mlp"), **model_kwargs)
 
     criterion = torch.nn.MSELoss()
-    logger.log("There are {:} weights.".format(network.get_w_container().numel()))
+    logger.log("There are {:} weights.".format(model.get_w_container().numel()))
 
     adaptor = LFNAmlp(args.meta_seq, (200, 200), "leaky_relu", criterion)
 
     # pre-train the model
-    init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"])
-    init_loss = train_model(network, init_dataset, args.init_lr, args.epochs)
+    dataset = init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"])
+
+    shape_container = model.get_w_container().to_shape_container()
+    hypernet = HyperNet(shape_container, 16)
+
+    optimizer = torch.optim.Adam(hypernet.parameters(), lr=args.init_lr, amsgrad=True)
+
+    best_loss, best_param = None, None
+    for _iepoch in range(args.epochs):
+        container = hypernet(None)
+
+        preds = model.forward_with_container(dataset.x, container)
+        optimizer.zero_grad()
+        loss = criterion(preds, dataset.y)
+        loss.backward()
+        optimizer.step()
+        # save best
+        if best_loss is None or best_loss > loss.item():
+            best_loss = loss.item()
+            best_param = copy.deepcopy(model.state_dict())
+    print("hyper-net : best={:.4f}".format(best_loss))
+
+    init_loss = train_model(model, init_dataset, args.init_lr, args.epochs)
     logger.log("The pre-training loss is {:.4f}".format(init_loss))
+    import pdb
+
+    pdb.set_trace()
 
     all_past_containers = []
     ground_truth_path = (
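
The hunk above swaps direct pre-training of the target network for a hyper-network: one learned embedding per target layer goes through a small generator MLP, the generator's output is reshaped into the target model's weights, and only the generator is trained. Below is a minimal, self-contained sketch of that idea in plain PyTorch; the toy shapes, data, and names are illustrative and not from this repo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Target architecture described only by its parameter shapes: a 1 -> 8 -> 1 MLP.
target_shapes = [torch.Size([8, 1]), torch.Size([8]), torch.Size([1, 8]), torch.Size([1])]
embed_dim = 16
max_numel = max(s.numel() for s in target_shapes)

# One learned embedding per target layer plus a small generator MLP
# (mirrors HyperNet's _super_layer_embed and _generator in this commit).
layer_embed = nn.Parameter(torch.randn(len(target_shapes), embed_dim) * 0.02)
generator = nn.Sequential(
    nn.Linear(embed_dim, embed_dim * 4),
    nn.Sigmoid(),
    nn.Linear(embed_dim * 4, max_numel),
)


def forward_with_generated(x):
    # Generate one row of weights per layer, then slice and reshape each row
    # to its layer's shape (roughly what ShapeContainer.translate does below).
    flat = generator(layer_embed)  # (num_layers, max_numel)
    w1 = flat[0, : target_shapes[0].numel()].reshape(target_shapes[0])
    b1 = flat[1, : target_shapes[1].numel()].reshape(target_shapes[1])
    w2 = flat[2, : target_shapes[2].numel()].reshape(target_shapes[2])
    b2 = flat[3, : target_shapes[3].numel()].reshape(target_shapes[3])
    h = F.relu(F.linear(x, w1, b1))
    return F.linear(h, w2, b2)


# Train the generator (not the target net) with MSE on toy data,
# tracking the best loss as the loop above does.
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = torch.sin(3 * x)
optimizer = torch.optim.Adam(
    list(generator.parameters()) + [layer_embed], lr=0.01, amsgrad=True
)
best_loss = None
for _ in range(200):
    preds = forward_with_generated(x)
    loss = F.mse_loss(preds, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if best_loss is None or best_loss > loss.item():
        best_loss = loss.item()
print("hyper-net sketch : best={:.4f}".format(best_loss))
```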
							
								
								
									
exps/LFNA/backup/lfna_models.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
+#####################################################
+import copy
+import torch
+
+from xlayers import super_core
+from xlayers import trunc_normal_
+from models.xcore import get_model
+
+
+class HyperNet(super_core.SuperModule):
+    def __init__(self, shape_container, input_embeding, return_container=True):
+        super(HyperNet, self).__init__()
+        self._shape_container = shape_container
+        self._num_layers = len(shape_container)
+        self._numel_per_layer = []
+        for ilayer in range(self._num_layers):
+            self._numel_per_layer.append(shape_container[ilayer].numel())
+
+        self.register_parameter(
+            "_super_layer_embed",
+            torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)),
+        )
+        trunc_normal_(self._super_layer_embed, std=0.02)
+
+        model_kwargs = dict(
+            input_dim=input_embeding,
+            output_dim=max(self._numel_per_layer),
+            hidden_dim=input_embeding * 4,
+            act_cls="sigmoid",
+            norm_cls="identity",
+        )
+        self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+        self._return_container = return_container
+        print("generator: {:}".format(self._generator))
+
+    def forward_raw(self, input):
+        weights = self._generator(self._super_layer_embed)
+        if self._return_container:
+            weights = torch.split(weights, 1)
+            return self._shape_container.translate(weights)
+        else:
+            return weights
+
+    def forward_candidate(self, input):
+        raise NotImplementedError
+
+    def extra_repr(self) -> str:
+        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
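
For reference, the snippet below strings together the calls this class is used with in main() earlier in this commit; `model` and `batch_x` are placeholders for the repo's target network and a batch of inputs, so this is a hedged usage sketch rather than code that runs on its own.

```python
# Hedged usage sketch assembled from the calls visible in main() above;
# `model` and `batch_x` are placeholders, not defined here.
shape_container = model.get_w_container().to_shape_container()
hypernet = HyperNet(shape_container, 16)          # 16-dim embedding per layer
container = hypernet(None)                        # a TensorContainer of generated weights
preds = model.forward_with_container(batch_x, container)
```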
@@ -38,6 +38,46 @@ class SuperReLU(SuperModule):
         return "inplace=True" if self._inplace else ""
 
 
+class SuperGELU(SuperModule):
+    """Applies a the Gaussian Error Linear Units function element-wise."""
+
+    def __init__(self) -> None:
+        super(SuperGELU, self).__init__()
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        return F.gelu(input)
+
+    def forward_with_container(self, input, container, prefix=[]):
+        return self.forward_raw(input)
+
+
+class SuperSigmoid(SuperModule):
+    """Applies a the Sigmoid function element-wise."""
+
+    def __init__(self) -> None:
+        super(SuperSigmoid, self).__init__()
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        return torch.sigmoid(input)
+
+    def forward_with_container(self, input, container, prefix=[]):
+        return self.forward_raw(input)
+
+
 class SuperLeakyReLU(SuperModule):
     """https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#LeakyReLU"""
 
@@ -28,9 +28,13 @@ from .super_transformer import SuperTransformerEncoderLayer
 from .super_activations import SuperReLU
 from .super_activations import SuperLeakyReLU
 from .super_activations import SuperTanh
+from .super_activations import SuperGELU
+from .super_activations import SuperSigmoid
 
 super_name2activation = {
     "relu": SuperReLU,
+    "sigmoid": SuperSigmoid,
+    "gelu": SuperGELU,
     "leaky_relu": SuperLeakyReLU,
     "tanh": SuperTanh,
 }
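
With these entries registered, the new activations can be resolved by the same string keys used in model configs (for example the act_cls="sigmoid" passed to the generator in lfna_models.py). A hedged sketch, assuming this registry lives in xlayers.super_core and lib/ is on PYTHONPATH:

```python
# Hedged sketch: look up the new activations by name and apply them.
# The import path is an assumption; adjust to wherever this registry lives.
import torch
from xlayers import super_core

x = torch.randn(4, 3)
for name in ("gelu", "sigmoid"):
    act = super_core.super_name2activation[name]()
    out = act.forward_raw(x)                              # stateless, no weights
    same = act.forward_with_container(x, container=None)  # container is ignored
    print(name, out.shape, torch.equal(out, same))
```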
@@ -11,128 +11,10 @@ from enum import Enum
 
 import spaces
 
-IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
-BoolSpaceType = Union[bool, spaces.Categorical]
-
-
-class LayerOrder(Enum):
-    """This class defines the enumerations for order of operation in a residual or normalization-based layer."""
-
-    PreNorm = "pre-norm"
-    PostNorm = "post-norm"
-
-
-class SuperRunMode(Enum):
-    """This class defines the enumerations for Super Model Running Mode."""
-
-    FullModel = "fullmodel"
-    Candidate = "candidate"
-    Default = "fullmodel"
-
-
-class TensorContainer:
-    """A class to maintain both parameters and buffers for a model."""
-
-    def __init__(self):
-        self._names = []
-        self._tensors = []
-        self._param_or_buffers = []
-        self._name2index = dict()
-
-    def additive(self, tensors):
-        result = TensorContainer()
-        for index, name in enumerate(self._names):
-            new_tensor = self._tensors[index] + tensors[index]
-            result.append(name, new_tensor, self._param_or_buffers[index])
-        return result
-
-    def create_container(self, tensors):
-        result = TensorContainer()
-        for index, name in enumerate(self._names):
-            new_tensor = tensors[index]
-            result.append(name, new_tensor, self._param_or_buffers[index])
-        return result
-
-    def no_grad_clone(self):
-        result = TensorContainer()
-        with torch.no_grad():
-            for index, name in enumerate(self._names):
-                result.append(
-                    name, self._tensors[index].clone(), self._param_or_buffers[index]
-                )
-        return result
-
-    def requires_grad_(self, requires_grad=True):
-        for tensor in self._tensors:
-            tensor.requires_grad_(requires_grad)
-
-    def parameters(self):
-        return self._tensors
-
-    @property
-    def tensors(self):
-        return self._tensors
-
-    def flatten(self, tensors=None):
-        if tensors is None:
-            tensors = self._tensors
-        tensors = [tensor.view(-1) for tensor in tensors]
-        return torch.cat(tensors)
-
-    def unflatten(self, tensor):
-        tensors, s = [], 0
-        for raw_tensor in self._tensors:
-            length = raw_tensor.numel()
-            x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape)
-            tensors.append(x)
-            s += length
-        return tensors
-
-    def append(self, name, tensor, param_or_buffer):
-        if not isinstance(tensor, torch.Tensor):
-            raise TypeError(
-                "The input tensor must be torch.Tensor instead of {:}".format(
-                    type(tensor)
-                )
-            )
-        self._names.append(name)
-        self._tensors.append(tensor)
-        self._param_or_buffers.append(param_or_buffer)
-        assert name not in self._name2index, "The [{:}] has already been added.".format(
-            name
-        )
-        self._name2index[name] = len(self._names) - 1
-
-    def query(self, name):
-        if not self.has(name):
-            raise ValueError(
-                "The {:} is not in {:}".format(name, list(self._name2index.keys()))
-            )
-        index = self._name2index[name]
-        return self._tensors[index]
-
-    def has(self, name):
-        return name in self._name2index
-
-    def has_prefix(self, prefix):
-        for name, idx in self._name2index.items():
-            if name.startswith(prefix):
-                return name
-        return False
-
-    def numel(self):
-        total = 0
-        for tensor in self._tensors:
-            total += tensor.numel()
-        return total
-
-    def __len__(self):
-        return len(self._names)
-
-    def __repr__(self):
-        return "{name}({num} tensors)".format(
-            name=self.__class__.__name__, num=len(self)
-        )
+from .super_utils import IntSpaceType, BoolSpaceType
+from .super_utils import LayerOrder, SuperRunMode
+from .super_utils import TensorContainer
+from .super_utils import ShapeContainer
 
 
 class SuperModule(abc.ABC, nn.Module):
							
								
								
									
lib/xlayers/super_utils.py (new file, 222 lines)
@@ -0,0 +1,222 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+
+import abc
+import warnings
+from typing import Optional, Union, Callable
+import torch
+import torch.nn as nn
+from enum import Enum
+
+import spaces
+
+IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
+BoolSpaceType = Union[bool, spaces.Categorical]
+
+
+class LayerOrder(Enum):
+    """This class defines the enumerations for order of operation in a residual or normalization-based layer."""
+
+    PreNorm = "pre-norm"
+    PostNorm = "post-norm"
+
+
+class SuperRunMode(Enum):
+    """This class defines the enumerations for Super Model Running Mode."""
+
+    FullModel = "fullmodel"
+    Candidate = "candidate"
+    Default = "fullmodel"
+
+
+class ShapeContainer:
+    """A class to maintain the shape of each weight tensor for a model."""
+
+    def __init__(self):
+        self._names = []
+        self._shapes = []
+        self._name2index = dict()
+        self._param_or_buffers = []
+
+    @property
+    def shapes(self):
+        return self._shapes
+
+    def __getitem__(self, index):
+        return self._shapes[index]
+
+    def translate(self, tensors, all_none_match=True):
+        result = TensorContainer()
+        for index, name in enumerate(self._names):
+            cur_num = tensors[index].numel()
+            expected_num = self._shapes[index].numel()
+            if cur_num < expected_num or (
+                cur_num > expected_num and not all_none_match
+            ):
+                raise ValueError("Invalid {:} vs {:}".format(cur_num, expected_num))
+            cur_tensor = tensors[index].view(-1)[:expected_num]
+            new_tensor = torch.reshape(cur_tensor, self._shapes[index])
+            result.append(name, new_tensor, self._param_or_buffers[index])
+        return result
+
+    def append(self, name, shape, param_or_buffer):
+        if not isinstance(shape, torch.Size):
+            raise TypeError(
+                "The input tensor must be torch.Size instead of {:}".format(type(shape))
+            )
+        self._names.append(name)
+        self._shapes.append(shape)
+        self._param_or_buffers.append(param_or_buffer)
+        assert name not in self._name2index, "The [{:}] has already been added.".format(
+            name
+        )
+        self._name2index[name] = len(self._names) - 1
+
+    def query(self, name):
+        if not self.has(name):
+            raise ValueError(
+                "The {:} is not in {:}".format(name, list(self._name2index.keys()))
+            )
+        index = self._name2index[name]
+        return self._shapes[index]
+
+    def has(self, name):
+        return name in self._name2index
+
+    def has_prefix(self, prefix):
+        for name, idx in self._name2index.items():
+            if name.startswith(prefix):
+                return name
+        return False
+
+    def numel(self, index=None):
+        if index is None:
+            shapes = self._shapes
+        else:
+            shapes = [self._shapes[index]]
+        total = 0
+        for shape in shapes:
+            total += shape.numel()
+        return total
+
+    def __len__(self):
+        return len(self._names)
+
+    def __repr__(self):
+        return "{name}({num} tensors)".format(
+            name=self.__class__.__name__, num=len(self)
+        )
+
+
+class TensorContainer:
+    """A class to maintain both parameters and buffers for a model."""
+
+    def __init__(self):
+        self._names = []
+        self._tensors = []
+        self._param_or_buffers = []
+        self._name2index = dict()
+
+    def additive(self, tensors):
+        result = TensorContainer()
+        for index, name in enumerate(self._names):
+            new_tensor = self._tensors[index] + tensors[index]
+            result.append(name, new_tensor, self._param_or_buffers[index])
+        return result
+
+    def create_container(self, tensors):
+        result = TensorContainer()
+        for index, name in enumerate(self._names):
+            new_tensor = tensors[index]
+            result.append(name, new_tensor, self._param_or_buffers[index])
+        return result
+
+    def no_grad_clone(self):
+        result = TensorContainer()
+        with torch.no_grad():
+            for index, name in enumerate(self._names):
+                result.append(
+                    name, self._tensors[index].clone(), self._param_or_buffers[index]
+                )
+        return result
+
+    def to_shape_container(self):
+        result = ShapeContainer()
+        for index, name in enumerate(self._names):
+            result.append(
+                name, self._tensors[index].shape, self._param_or_buffers[index]
+            )
+        return result
+
+    def requires_grad_(self, requires_grad=True):
+        for tensor in self._tensors:
+            tensor.requires_grad_(requires_grad)
+
+    def parameters(self):
+        return self._tensors
+
+    @property
+    def tensors(self):
+        return self._tensors
+
+    def flatten(self, tensors=None):
+        if tensors is None:
+            tensors = self._tensors
+        tensors = [tensor.view(-1) for tensor in tensors]
+        return torch.cat(tensors)
+
+    def unflatten(self, tensor):
+        tensors, s = [], 0
+        for raw_tensor in self._tensors:
+            length = raw_tensor.numel()
+            x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape)
+            tensors.append(x)
+            s += length
+        return tensors
+
+    def append(self, name, tensor, param_or_buffer):
+        if not isinstance(tensor, torch.Tensor):
+            raise TypeError(
+                "The input tensor must be torch.Tensor instead of {:}".format(
+                    type(tensor)
+                )
+            )
+        self._names.append(name)
+        self._tensors.append(tensor)
+        self._param_or_buffers.append(param_or_buffer)
+        assert name not in self._name2index, "The [{:}] has already been added.".format(
+            name
+        )
+        self._name2index[name] = len(self._names) - 1
+
+    def query(self, name):
+        if not self.has(name):
+            raise ValueError(
+                "The {:} is not in {:}".format(name, list(self._name2index.keys()))
+            )
+        index = self._name2index[name]
+        return self._tensors[index]
+
+    def has(self, name):
+        return name in self._name2index
+
+    def has_prefix(self, prefix):
+        for name, idx in self._name2index.items():
+            if name.startswith(prefix):
+                return name
+        return False
+
+    def numel(self):
+        total = 0
+        for tensor in self._tensors:
+            total += tensor.numel()
+        return total
+
+    def __len__(self):
+        return len(self._names)
+
+    def __repr__(self):
+        return "{name}({num} tensors)".format(
+            name=self.__class__.__name__, num=len(self)
+        )
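
The two containers above are the glue between the hyper-network output and the target model: TensorContainer.to_shape_container records each tensor's name and shape, and ShapeContainer.translate turns generated weight chunks back into named, correctly shaped tensors. A hedged round-trip sketch using only the classes in this file, assuming lib/ (this repo's library root) is on PYTHONPATH:

```python
# Hedged sketch of the container round trip added in this file; the import
# path is an assumption about how lib/xlayers is exposed on PYTHONPATH.
import torch
from xlayers.super_utils import TensorContainer

tc = TensorContainer()
tc.append("fc.weight", torch.randn(4, 2), True)   # True -> parameter (not buffer)
tc.append("fc.bias", torch.randn(4), True)

shapes = tc.to_shape_container()                  # ShapeContainer with the same names
flat = torch.randn(shapes.numel())                # e.g. flattened hyper-network output
chunks = [flat[:8], flat[8:12]]                   # one chunk per tensor, matching its numel
new_tc = shapes.translate(chunks)                 # back to a named TensorContainer
print(new_tc, new_tc.query("fc.weight").shape)    # TensorContainer(2 tensors) torch.Size([4, 2])
```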