Update xlayers
@@ -21,6 +21,57 @@ from procedures.advanced_main import basic_train_fn, basic_eval_fn
 from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
 from datasets.synthetic_core import get_synthetic_env
 from models.xcore import get_model
+from xlayers import super_core
+
+
+class LFNAmlp:
+    """A LFNA meta-model that uses the MLP as delta-net."""
+
+    def __init__(self, obs_dim, hidden_sizes, act_name):
+        self.delta_net = super_core.SuperSequential(
+            super_core.SuperLinear(obs_dim, hidden_sizes[0]),
+            super_core.super_name2activation[act_name](),
+            super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]),
+            super_core.super_name2activation[act_name](),
+            super_core.SuperLinear(hidden_sizes[1], 1),
+        )
+        self.meta_optimizer = torch.optim.Adam(
+            self.delta_net.parameters(), lr=0.01, amsgrad=True
+        )
+
+    def adapt(self, model, criterion, w_container, xs, ys):
+        containers = [w_container]
+        for idx, (x, y) in enumerate(zip(xs, ys)):
+            y_hat = model.forward_with_container(x, containers[-1])
+            loss = criterion(y_hat, y)
+            gradients = torch.autograd.grad(loss, containers[-1].tensors)
+            with torch.no_grad():
+                flatten_w = containers[-1].flatten().view(-1, 1)
+                flatten_g = containers[-1].flatten(gradients).view(-1, 1)
+                input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2)
+                input_statistics = input_statistics.expand(flatten_w.numel(), -1)
+            delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
+            delta = self.delta_net(delta_inputs).view(-1)
+            # delta = torch.clamp(delta, -0.5, 0.5)
+            unflatten_delta = containers[-1].unflatten(delta)
+            future_container = containers[-1].additive(unflatten_delta)
+            containers.append(future_container)
+        # containers = containers[1:]
+        meta_loss = []
+        for idx, (x, y) in enumerate(zip(xs, ys)):
+            if idx == 0:
+                continue
+            current_container = containers[idx]
+            y_hat = model.forward_with_container(x, current_container)
+            loss = criterion(y_hat, y)
+            meta_loss.append(loss)
+        meta_loss = sum(meta_loss)
+        meta_loss.backward()
+        self.meta_optimizer.step()
+
+    def zero_grad(self):
+        self.meta_optimizer.zero_grad()
+        self.delta_net.zero_grad()
 
 
 class Population:
@@ -28,11 +79,23 @@ class Population:
 
     def __init__(self):
         self._time2model = dict()
+        self._time2score = dict()  # higher is better
 
-    def append(self, timestamp, model):
+    def append(self, timestamp, model, score):
         if timestamp in self._time2model:
             raise ValueError("This timestamp has been added.")
         self._time2model[timestamp] = model
+        self._time2score[timestamp] = score
+
+    def query(self, timestamp):
+        closet_timestamp = None
+        for xtime, model in self._time2model.items():
+            if (
+                closet_timestamp is None
+                or timestamp - closet_timestamp >= timestamp - xtime
+            ):
+                closet_timestamp = xtime
+        return self._time2model[closet_timestamp], closet_timestamp
 
 
 def main(args):
@@ -70,100 +133,39 @@ def main(args):
     )
 
     w_container = base_model.named_parameters_buffers()
+    criterion = torch.nn.MSELoss()
     print("There are {:} weights.".format(w_container.numel()))
 
+    adaptor = LFNAmlp(4, (50, 20), "leaky_relu")
+
     pool = Population()
     pool.append(0, w_container)
 
     # LFNA meta-training
     per_epoch_time, start_time = AverageMeter(), time.time()
     for iepoch in range(args.epochs):
-        import pdb
-
-        pdb.set_trace()
-        print("-")
-
-    for i, idx in enumerate(to_evaluate_indexes):
-
         need_time = "Time Left: {:}".format(
-            convert_secs2time(
-                per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
-            )
+            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
         logger.log(
-            "[{:}]".format(time_string())
-            + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx)
-            + " "
+            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
            + need_time
         )
-        # train the same data
-        assert idx != 0
-        historical_x = env_info["{:}-x".format(idx)]
-        historical_y = env_info["{:}-y".format(idx)]
-        # build model
-        mean, std = historical_x.mean().item(), historical_x.std().item()
-        model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
-        model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
-        # build optimizer
-        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
-        criterion = torch.nn.MSELoss()
-        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer,
-            milestones=[
-                int(args.epochs * 0.25),
-                int(args.epochs * 0.5),
-                int(args.epochs * 0.75),
-            ],
-            gamma=0.3,
-        )
-        train_metric = MSEMetric()
-        best_loss, best_param = None, None
-        for _iepoch in range(args.epochs):
-            preds = model(historical_x)
-            optimizer.zero_grad()
-            loss = criterion(preds, historical_y)
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-            # save best
-            if best_loss is None or best_loss > loss.item():
-                best_loss = loss.item()
-                best_param = copy.deepcopy(model.state_dict())
-        model.load_state_dict(best_param)
-        with torch.no_grad():
-            train_metric(preds, historical_y)
-        train_results = train_metric.get_info()
-
-        metric = ComposeMetric(MSEMetric(), SaveMetric())
-        eval_dataset = torch.utils.data.TensorDataset(
-            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
-        )
-        eval_loader = torch.utils.data.DataLoader(
-            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
-        )
-        results = basic_eval_fn(eval_loader, model, metric, logger)
-        log_str = (
-            "[{:}]".format(time_string())
-            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
-            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
-                train_results["mse"], results["mse"]
-            )
-        )
-        logger.log(log_str)
-
-        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
-            idx, env_info["total"]
-        )
-        save_checkpoint(
-            {
-                "model_state_dict": model.state_dict(),
-                "model": model,
-                "index": idx,
-                "timestamp": env_info["{:}-timestamp".format(idx)],
-            },
-            save_path,
-            logger,
-        )
+        for ibatch in range(args.meta_batch):
+            sampled_timestamp = random.randint(0, train_time_bar)
+            query_w_container, query_timestamp = pool.query(sampled_timestamp)
+            # def adapt(self, model, w_container, xs, ys):
+            xs, ys = [], []
+            for it in range(sampled_timestamp, sampled_timestamp + args.max_seq):
+                xs.append(env_info["{:}-x".format(it)])
+                ys.append(env_info["{:}-y".format(it)])
+            adaptor.adapt(base_model, criterion, query_w_container, xs, ys)
+            import pdb
+
+            pdb.set_trace()
+        print("-")
         logger.log("")
 
         per_timestamp_time.update(time.time() - start_time)
@@ -188,10 +190,10 @@ if __name__ == "__main__":
         help="The initial learning rate for the optimizer (default is Adam)",
     )
     parser.add_argument(
-        "--batch_size",
+        "--meta_batch",
         type=int,
-        default=512,
-        help="The batch size",
+        default=2,
+        help="The batch size for the meta-model",
     )
     parser.add_argument(
         "--epochs",
@@ -199,6 +201,12 @@ if __name__ == "__main__":
         default=1000,
         help="The total number of epochs.",
     )
+    parser.add_argument(
+        "--max_seq",
+        type=int,
+        default=5,
+        help="The maximum length of the sequence.",
+    )
     parser.add_argument(
         "--workers",
         type=int,
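Note on the new meta-model: inside LFNAmlp.adapt, every scalar weight contributes one row to the delta-net input, made of its current value, its gradient, and the mean/std of the current batch; that per-weight feature width is why the adaptor above is built with obs_dim=4. A minimal, self-contained sketch of that feature construction, using toy tensors in place of the real weight container:

import torch

flatten_w = torch.randn(10, 1)  # one row per scalar weight (toy values)
flatten_g = torch.randn(10, 1)  # the matching per-weight gradients (toy values)
x = torch.randn(32, 1)          # one batch of inputs from a sampled timestamp

input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2)
input_statistics = input_statistics.expand(flatten_w.numel(), -1)

delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
print(delta_inputs.shape)  # torch.Size([10, 4]) -> matches obs_dim == 4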
@@ -34,3 +34,4 @@ def get_model(config: Dict[Text, Any], **kwargs):
     else:
         raise TypeError("Unkonwn model type: {:}".format(model_type))
     return model
+
@@ -31,6 +31,9 @@ class SuperReLU(SuperModule):
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
         return F.relu(input, inplace=self._inplace)
 
+    def forward_with_container(self, input, container, prefix=[]):
+        return self.forward_raw(input)
+
     def extra_repr(self) -> str:
         return "inplace=True" if self._inplace else ""
 
@@ -53,6 +56,29 @@ class SuperLeakyReLU(SuperModule):
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
         return F.leaky_relu(input, self._negative_slope, self._inplace)
 
+    def forward_with_container(self, input, container, prefix=[]):
+        return self.forward_raw(input)
+
     def extra_repr(self) -> str:
         inplace_str = "inplace=True" if self._inplace else ""
         return "negative_slope={}{}".format(self._negative_slope, inplace_str)
+
+
+class SuperTanh(SuperModule):
+    """Applies a the Tanh function element-wise."""
+
+    def __init__(self) -> None:
+        super(SuperTanh, self).__init__()
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(input)
+
+    def forward_with_container(self, input, container, prefix=[]):
+        return self.forward_raw(input)
@@ -111,3 +111,10 @@ class SuperSequential(SuperModule):
         for module in self:
             input = module(input)
         return input
+
+    def forward_with_container(self, input, container, prefix=[]):
+        for index, module in enumerate(self):
+            input = module.forward_with_container(
+                input, container, prefix + [str(index)]
+            )
+        return input
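The container-based forward of SuperSequential threads a growing prefix of name components down to its children, so each child can locate its own tensors in the container. A small sketch of the key names this produces for a three-layer stack (the exact keys depend on how the container was populated; these are illustrative):

# For SuperSequential(SuperLinear(...), SuperTanh(), SuperLinear(...)), the call
# forward_with_container(input, container) visits children with prefixes
# ["0"], ["1"], ["2"]; linear layers then query dotted names such as:
for index in range(3):
    prefix = [str(index)]
    print(".".join(prefix + ["_super_weight"]))  # "0._super_weight", "1._super_weight", ...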
@@ -27,8 +27,13 @@ from .super_transformer import SuperTransformerEncoderLayer
 
 from .super_activations import SuperReLU
 from .super_activations import SuperLeakyReLU
+from .super_activations import SuperTanh
 
-super_name2activation = {"relu": SuperReLU, "leaky_relu": SuperLeakyReLU}
+super_name2activation = {
+    "relu": SuperReLU,
+    "leaky_relu": SuperLeakyReLU,
+    "tanh": SuperTanh,
+}
 
 
 from .super_trade_stem import SuperAlphaEBDv1
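With SuperTanh registered, an activation class can be picked by its string name, which is how LFNAmlp selects its non-linearity above. A minimal usage sketch:

import torch
from xlayers import super_core

act_cls = super_core.super_name2activation["tanh"]   # -> SuperTanh
act = act_cls()
y = act.forward_raw(torch.zeros(2, 3))               # element-wise torch.tanh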
@@ -115,6 +115,16 @@ class SuperLinear(SuperModule):
             self._in_features, self._out_features, self._bias
         )
 
+    def forward_with_container(self, input, container, prefix=[]):
+        super_weight_name = ".".join(prefix + ["_super_weight"])
+        super_weight = container.query(super_weight_name)
+        super_bias_name = ".".join(prefix + ["_super_bias"])
+        if container.has(super_bias_name):
+            super_bias = container.query(super_bias_name)
+        else:
+            super_bias = None
+        return F.linear(input, super_weight, super_bias)
+
 
 class SuperMLPv1(SuperModule):
     """An MLP layer: FC -> Activation -> Drop -> FC -> Drop."""
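SuperLinear.forward_with_container rebuilds the affine map purely from tensors stored under dotted names ending in _super_weight / _super_bias. A standalone sketch of the same lookup-and-apply logic with a plain dict standing in for TensorContainer (keys are hypothetical, chosen as SuperSequential would assign them to its first child):

import torch
import torch.nn.functional as F

fake_container = {
    "0._super_weight": torch.randn(2, 4),  # weights of a 4 -> 2 linear layer
    "0._super_bias": torch.randn(2),
}

prefix = ["0"]
weight = fake_container[".".join(prefix + ["_super_weight"])]
bias = fake_container.get(".".join(prefix + ["_super_bias"]))  # may be absent
out = F.linear(torch.randn(8, 4), weight, bias)
print(out.shape)  # torch.Size([8, 2])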
@@ -39,6 +39,41 @@ class TensorContainer:
         self._param_or_buffers = []
         self._name2index = dict()
 
+    def additive(self, tensors):
+        result = TensorContainer()
+        for index, name in enumerate(self._names):
+            new_tensor = self._tensors[index] + tensors[index]
+            result.append(name, new_tensor, self._param_or_buffers[index])
+        return result
+
+    def no_grad_clone(self):
+        result = TensorContainer()
+        with torch.no_grad():
+            for index, name in enumerate(self._names):
+                result.append(
+                    name, self._tensors[index].clone(), self._param_or_buffers[index]
+                )
+        return result
+
+    @property
+    def tensors(self):
+        return self._tensors
+
+    def flatten(self, tensors=None):
+        if tensors is None:
+            tensors = self._tensors
+        tensors = [tensor.view(-1) for tensor in tensors]
+        return torch.cat(tensors)
+
+    def unflatten(self, tensor):
+        tensors, s = [], 0
+        for raw_tensor in self._tensors:
+            length = raw_tensor.numel()
+            x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape)
+            tensors.append(x)
+            s += length
+        return tensors
+
     def append(self, name, tensor, param_or_buffer):
         if not isinstance(tensor, torch.Tensor):
             raise TypeError(
@@ -54,6 +89,23 @@ class TensorContainer:
         )
         self._name2index[name] = len(self._names) - 1
 
+    def query(self, name):
+        if not self.has(name):
+            raise ValueError(
+                "The {:} is not in {:}".format(name, list(self._name2index.keys()))
+            )
+        index = self._name2index[name]
+        return self._tensors[index]
+
+    def has(self, name):
+        return name in self._name2index
+
+    def has_prefix(self, prefix):
+        for name, idx in self._name2index.items():
+            if name.startswith(prefix):
+                return name
+        return False
+
     def numel(self):
         total = 0
         for tensor in self._tensors:
@@ -181,3 +233,6 @@ class SuperModule(abc.ABC, nn.Module):
                 )
             )
         return outputs
+
+    def forward_with_container(self, inputs, container, prefix=[]):
+        raise NotImplementedError
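These TensorContainer helpers are what let the meta-model treat all weights as a single flat vector: flatten() concatenates every tensor, unflatten() slices a flat delta back into per-tensor shapes, and additive() returns a new container with the delta applied. A rough sketch of the round trip on plain tensors (two toy tensors standing in for a container's contents):

import torch

tensors = [torch.zeros(2, 3), torch.zeros(4)]            # 6 + 4 = 10 scalars

flat = torch.cat([t.view(-1) for t in tensors])          # flatten -> shape (10,)

delta = torch.ones_like(flat)                            # a flat update from the delta-net
pieces, s = [], 0
for t in tensors:                                        # unflatten the update
    n = t.numel()
    pieces.append(delta[s : s + n].reshape(t.shape))
    s += n

updated = [t + d for t, d in zip(tensors, pieces)]       # additive -> "future" weights
print([tuple(u.shape) for u in updated])                 # [(2, 3), (4,)]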
@@ -161,6 +161,21 @@ class SuperSimpleLearnableNorm(SuperModule):
             mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
         return tensor.sub_(mean).div_(std)
 
+    def forward_with_container(self, input, container, prefix=[]):
+        if not self._inplace:
+            tensor = input.clone()
+        else:
+            tensor = input
+        mean_name = ".".join(prefix + ["_mean"])
+        std_name = ".".join(prefix + ["_std"])
+        mean, std = (
+            container.query(mean_name).to(tensor.device),
+            torch.abs(container.query(std_name).to(tensor.device)) + self._eps,
+        )
+        while mean.ndim < tensor.ndim:
+            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
+        return tensor.sub_(mean).div_(std)
+
     def extra_repr(self) -> str:
         return "mean={mean}, std={std}, inplace={inplace}".format(
             mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
@@ -191,3 +206,6 @@ class SuperIdentity(SuperModule):
 
     def extra_repr(self) -> str:
         return "inplace={inplace}".format(inplace=self._inplace)
+
+    def forward_with_container(self, input, container, prefix=[]):
+        return self.forward_raw(input)
lib/xlayers/super_rl_actor.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+# DISABLED / NOT-FINISHED
+#####################################################
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+from typing import Optional, Callable
+
+import spaces
+from .super_container import SuperSequential
+from .super_linear import SuperLinear
+
+
+class SuperActor(SuperModule):
+    """A Actor in RL."""
+
+    def _distribution(self, obs):
+        raise NotImplementedError
+
+    def _log_prob_from_distribution(self, pi, act):
+        raise NotImplementedError
+
+    def forward_candidate(self, **kwargs):
+        return self.forward_raw(**kwargs)
+
+    def forward_raw(self, obs, act=None):
+        # Produce action distributions for given observations, and
+        # optionally compute the log likelihood of given actions under
+        # those distributions.
+        pi = self._distribution(obs)
+        logp_a = None
+        if act is not None:
+            logp_a = self._log_prob_from_distribution(pi, act)
+        return pi, logp_a
+
+
+class SuperLfnaMetaMLP(SuperModule):
+    def __init__(self, obs_dim, hidden_sizes, act_cls):
+        super(SuperLfnaMetaMLP).__init__()
+        self.delta_net = SuperSequential(
+            SuperLinear(obs_dim, hidden_sizes[0]),
+            act_cls(),
+            SuperLinear(hidden_sizes[0], hidden_sizes[1]),
+            act_cls(),
+            SuperLinear(hidden_sizes[1], 1),
+        )
+
+
+class SuperLfnaMetaMLP(SuperModule):
+    def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls):
+        super(SuperLfnaMetaMLP).__init__()
+        log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
+        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
+        self.mu_net = SuperSequential(
+            SuperLinear(obs_dim, hidden_sizes[0]),
+            act_cls(),
+            SuperLinear(hidden_sizes[0], hidden_sizes[1]),
+            act_cls(),
+            SuperLinear(hidden_sizes[1], act_dim),
+        )
+
+    def _distribution(self, obs):
+        mu = self.mu_net(obs)
+        std = torch.exp(self.log_std)
+        return Normal(mu, std)
+
+    def _log_prob_from_distribution(self, pi, act):
+        return pi.log_prob(act).sum(axis=-1)
+
+    def forward_candidate(self, **kwargs):
+        return self.forward_raw(**kwargs)
+
+    def forward_raw(self, obs, act=None):
+        # Produce action distributions for given observations, and
+        # optionally compute the log likelihood of given actions under
+        # those distributions.
+        pi = self._distribution(obs)
+        logp_a = None
+        if act is not None:
+            logp_a = self._log_prob_from_distribution(pi, act)
+        return pi, logp_a
+
+
+class SuperMLPGaussianActor(SuperModule):
+    def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls):
+        super(SuperMLPGaussianActor).__init__()
+        log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
+        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
+        self.mu_net = SuperSequential(
+            SuperLinear(obs_dim, hidden_sizes[0]),
+            act_cls(),
+            SuperLinear(hidden_sizes[0], hidden_sizes[1]),
+            act_cls(),
+            SuperLinear(hidden_sizes[1], act_dim),
+        )
+
+    def _distribution(self, obs):
+        mu = self.mu_net(obs)
+        std = torch.exp(self.log_std)
+        return Normal(mu, std)
+
+    def _log_prob_from_distribution(self, pi, act):
+        return pi.log_prob(act).sum(axis=-1)
+
+    def forward_candidate(self, **kwargs):
+        return self.forward_raw(**kwargs)
+
+    def forward_raw(self, obs, act=None):
+        # Produce action distributions for given observations, and
+        # optionally compute the log likelihood of given actions under
+        # those distributions.
+        pi = self._distribution(obs)
+        logp_a = None
+        if act is not None:
+            logp_a = self._log_prob_from_distribution(pi, act)
+        return pi, logp_a
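super_rl_actor.py is explicitly marked DISABLED / NOT-FINISHED (for example, np and Normal are referenced but never imported), but the structure it sketches is the usual diagonal-Gaussian policy head. A self-contained sketch of that pattern with plain torch modules, independent of the Super* classes:

import torch
import torch.nn as nn
from torch.distributions import Normal

obs_dim, act_dim = 8, 2
mu_net = nn.Sequential(nn.Linear(obs_dim, 32), nn.Tanh(), nn.Linear(32, act_dim))
log_std = nn.Parameter(-0.5 * torch.ones(act_dim))        # learnable per-dim log std

obs = torch.randn(4, obs_dim)
pi = Normal(mu_net(obs), torch.exp(log_std))              # one Gaussian per observation
act = pi.sample()
logp_a = pi.log_prob(act).sum(dim=-1)                     # summed over action dims
print(logp_a.shape)                                       # torch.Size([4])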