Re-org GeMOSA codes
This commit is contained in:
		| @@ -1,117 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import copy | ||||
| import torch | ||||
|  | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xlayers import super_core | ||||
| from xlayers import trunc_normal_ | ||||
| from models.xcore import get_model | ||||
|  | ||||
|  | ||||
| class HyperNet(super_core.SuperModule): | ||||
|     """The hyper-network.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         shape_container, | ||||
|         layer_embeding, | ||||
|         task_embedding, | ||||
|         num_tasks, | ||||
|         return_container=True, | ||||
|     ): | ||||
|         super(HyperNet, self).__init__() | ||||
|         self._shape_container = shape_container | ||||
|         self._num_layers = len(shape_container) | ||||
|         self._numel_per_layer = [] | ||||
|         for ilayer in range(self._num_layers): | ||||
|             self._numel_per_layer.append(shape_container[ilayer].numel()) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)), | ||||
|         ) | ||||
|         self.register_parameter( | ||||
|             "_super_task_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(num_tasks, task_embedding)), | ||||
|         ) | ||||
|         trunc_normal_(self._super_layer_embed, std=0.02) | ||||
|         trunc_normal_(self._super_task_embed, std=0.02) | ||||
|  | ||||
|         model_kwargs = dict( | ||||
|             config=dict(model_type="dual_norm_mlp"), | ||||
|             input_dim=layer_embeding + task_embedding, | ||||
|             output_dim=max(self._numel_per_layer), | ||||
|             hidden_dims=[(layer_embeding + task_embedding) * 2] * 3, | ||||
|             act_cls="gelu", | ||||
|             norm_cls="layer_norm_1d", | ||||
|             dropout=0.2, | ||||
|         ) | ||||
|         self._generator = get_model(**model_kwargs) | ||||
|         self._return_container = return_container | ||||
|         print("generator: {:}".format(self._generator)) | ||||
|  | ||||
|     def forward_raw(self, task_embed_id): | ||||
|         layer_embed = self._super_layer_embed | ||||
|         task_embed = ( | ||||
|             self._super_task_embed[task_embed_id] | ||||
|             .view(1, -1) | ||||
|             .expand(self._num_layers, -1) | ||||
|         ) | ||||
|  | ||||
|         joint_embed = torch.cat((task_embed, layer_embed), dim=-1) | ||||
|         weights = self._generator(joint_embed) | ||||
|         if self._return_container: | ||||
|             weights = torch.split(weights, 1) | ||||
|             return self._shape_container.translate(weights) | ||||
|         else: | ||||
|             return weights | ||||
|  | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape)) | ||||
|  | ||||
|  | ||||
| class HyperNet_VX(super_core.SuperModule): | ||||
|     def __init__(self, shape_container, input_embeding, return_container=True): | ||||
|         super(HyperNet_VX, self).__init__() | ||||
|         self._shape_container = shape_container | ||||
|         self._num_layers = len(shape_container) | ||||
|         self._numel_per_layer = [] | ||||
|         for ilayer in range(self._num_layers): | ||||
|             self._numel_per_layer.append(shape_container[ilayer].numel()) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)), | ||||
|         ) | ||||
|         trunc_normal_(self._super_layer_embed, std=0.02) | ||||
|  | ||||
|         model_kwargs = dict( | ||||
|             input_dim=input_embeding, | ||||
|             output_dim=max(self._numel_per_layer), | ||||
|             hidden_dim=input_embeding * 4, | ||||
|             act_cls="sigmoid", | ||||
|             norm_cls="identity", | ||||
|         ) | ||||
|         self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         self._return_container = return_container | ||||
|         print("generator: {:}".format(self._generator)) | ||||
|  | ||||
|     def forward_raw(self, input): | ||||
|         weights = self._generator(self._super_layer_embed) | ||||
|         if self._return_container: | ||||
|             weights = torch.split(weights, 1) | ||||
|             return self._shape_container.translate(weights) | ||||
|         else: | ||||
|             return weights | ||||
|  | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape)) | ||||
| @@ -35,7 +35,7 @@ from xautodl.models.xcore import get_model | ||||
| from xautodl.xlayers import super_core, trunc_normal_ | ||||
|  | ||||
| from lfna_utils import lfna_setup, train_model, TimeData | ||||
| from lfna_meta_model import MetaModelV1 | ||||
| from meta_model import MetaModelV1 | ||||
|  | ||||
|  | ||||
| def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): | ||||
| @@ -106,7 +106,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | ||||
|         ) | ||||
|         optimizer.zero_grad() | ||||
|  | ||||
|         generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True) | ||||
|         generated_time_embeds = gen_time_embed(meta_model.meta_timestamps) | ||||
|  | ||||
|         batch_indexes = random.choices(total_indexes, k=args.meta_batch) | ||||
|  | ||||
| @@ -219,11 +219,11 @@ def main(args): | ||||
|     w_containers, loss_meter = online_evaluate( | ||||
|         all_env, meta_model, base_model, criterion, args, logger, True | ||||
|     ) | ||||
|     logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter)) | ||||
|     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_containers": w_containers}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         {"all_w_containers": w_containers}, | ||||
|         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed), | ||||
|         logger, | ||||
|     ) | ||||
|  | ||||
|   | ||||
| @@ -154,8 +154,9 @@ class MetaModelV1(super_core.SuperModule): | ||||
|                     (self._append_meta_embed["fixed"], meta_embed), dim=0 | ||||
|                 ) | ||||
| 
 | ||||
|     def _obtain_time_embed(self, timestamps): | ||||
|         # timestamps is a batch of sequence of timestamps | ||||
|     def gen_time_embed(self, timestamps): | ||||
|         # timestamps is a batch of timestamps | ||||
|         [B] = timestamps.shape | ||||
|         # batch, seq = timestamps.shape | ||||
|         timestamps = timestamps.view(-1, 1) | ||||
|         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||
| @@ -179,15 +180,8 @@ class MetaModelV1(super_core.SuperModule): | ||||
|         ) | ||||
|         return timestamp_embeds[:, -1, :] | ||||
| 
 | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         if time_embeds is None: | ||||
|             [B] = timestamps.shape | ||||
|             time_embeds = self._obtain_time_embed(timestamps) | ||||
|         else:  # use the hyper-net only | ||||
|             time_seq = None | ||||
|             B, _ = time_embeds.shape | ||||
|         if tembed_only: | ||||
|             return time_embeds | ||||
|     def gen_model(self, time_embeds): | ||||
|         B, _ = time_embeds.shape | ||||
|         # create joint embed | ||||
|         num_layer, _ = self._super_layer_embed.shape | ||||
|         # The shape of `joint_embed` is batch * num-layers * input-dim | ||||
| @@ -206,6 +200,9 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             ) | ||||
|         return batch_containers, time_embeds | ||||
| 
 | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
| @@ -1,8 +1,9 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ############################################################################ | ||||
| # python exps/GMOA/vis-synthetic.py --env_version v1                       # | ||||
| # python exps/GMOA/vis-synthetic.py --env_version v2                       # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v1                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||
| ############################################################################ | ||||
| import os, sys, copy, random | ||||
| import torch | ||||
| @@ -181,7 +182,7 @@ def compare_cl(save_dir): | ||||
| def visualize_env(save_dir, version): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     for substr in ("pdf", "png"): | ||||
|         sub_save_dir = save_dir / substr | ||||
|         sub_save_dir = save_dir / "{:}-{:}".format(substr, version) | ||||
|         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dynamic_env = get_synthetic_env(version=version) | ||||
| @@ -190,6 +191,8 @@ def visualize_env(save_dir, version): | ||||
|         allxs.append(allx) | ||||
|         allys.append(ally) | ||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|     print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) | ||||
|     print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         dpi, width, height = 30, 1800, 1400 | ||||
|         figsize = width / float(dpi), height / float(dpi) | ||||
| @@ -210,14 +213,22 @@ def visualize_env(save_dir, version): | ||||
|         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|         pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx) | ||||
|         pdf_save_path = ( | ||||
|             save_dir | ||||
|             / "pdf-{:}".format(version) | ||||
|             / "v{:}-{:05d}.pdf".format(version, idx) | ||||
|         ) | ||||
|         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|         png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx) | ||||
|         png_save_path = ( | ||||
|             save_dir | ||||
|             / "png-{:}".format(version) | ||||
|             / "v{:}-{:05d}.png".format(version, idx) | ||||
|         ) | ||||
|         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") | ||||
|         plt.close("all") | ||||
|     save_dir = save_dir.resolve() | ||||
|     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( | ||||
|         xdir=save_dir / "png", version=version | ||||
|         xdir=save_dir / "png-{:}".format(version), version=version | ||||
|     ) | ||||
|     print(base_cmd) | ||||
|     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) | ||||
| @@ -367,7 +378,7 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||
|     visualize_env(os.path.join(args.save_dir, "vis-env"), args.env_version) | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||
|     compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||
|     # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user