From 8109ed166aac9b42b7a44e8666b8be3790a14699 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sat, 22 May 2021 23:04:24 +0800 Subject: [PATCH] Update xlayers --- CHANGE-LOG.md | 14 +++++----- exps/LFNA/lfna.py | 33 ++++++++++++++++++----- exps/LFNA/lfna_meta_model.py | 18 +++++++++---- xautodl/xlayers/super_attention.py | 40 ++++++++++++++++++++++------ xautodl/xlayers/super_module.py | 30 ++++++++++++++++----- xautodl/xlayers/super_transformer.py | 2 ++ 6 files changed, 104 insertions(+), 33 deletions(-) diff --git a/CHANGE-LOG.md b/CHANGE-LOG.md index 717ffb6..aaeae65 100644 --- a/CHANGE-LOG.md +++ b/CHANGE-LOG.md @@ -1,12 +1,12 @@ # This file shows the major updates of this repo. -- [2020.04.11] [4ef9531] Add change log as `CHANGE-LOG.md`. -- [2019.12.20] [69ca086] Release NAS-Bench-201. -- [2019.09.28] [f8f3f38] TAS and SETN codes were publicly released. -- [2019.01.31] [13e908f] GDAS codes were publicly released. -- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version. -- [2020.09.16] [7052265] Create NATS-BENCH. +- [2020.04.11] [4ef9531](https://github.com/D-X-Y/AutoDL-Projects/tree/4ef9531) Add change log as `CHANGE-LOG.md`. +- [2019.12.20] [69ca086](https://github.com/D-X-Y/AutoDL-Projects/tree/69ca086) Release NAS-Bench-201. +- [2019.09.28] [f8f3f38](https://github.com/D-X-Y/AutoDL-Projects/tree/f8f3f38) TAS and SETN codes were publicly released. +- [2019.01.31] [13e908f](https://github.com/D-X-Y/AutoDL-Projects/tree/13e908f) GDAS codes were publicly released. +- [2020.07.01] [a45808b](https://github.com/D-X-Y/AutoDL-Projects/tree/a45808b) Upgrade NAS-API to the 2.0 version. +- [2020.09.16] [7052265](https://github.com/D-X-Y/AutoDL-Projects/tree/7052265) Create NATS-BENCH. - [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0 - [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1 - [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl` -- [2021.05.21] [b4e8eae](https://github.com/D-X-Y/AutoDL-Projects/tree/b4e8eae) `xautodl` is close to ready +- [2021.05.21] [5b09f05](https://github.com/D-X-Y/AutoDL-Projects/tree/5b09f05) `xautodl` is close to ready diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index db7f36b..73e56f1 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -106,8 +106,13 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): logger.log("Using the optimizer: {:}".format(optimizer)) meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2") + final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed) + if meta_model.has_best(final_best_name): + meta_model.load_best(final_best_name) + return + meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) - last_success_epoch = 0 + last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh per_epoch_time, start_time = AverageMeter(), time.time() for iepoch in range(args.epochs): left_time = "Time Left: {:}".format( @@ -164,14 +169,21 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): ) + ", batch={:}".format(len(total_meta_losses)) + ", success={:}, best_score={:.4f}".format(success, -best_score) - + " {:}".format(left_time) + + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh) + + ", {:}".format(left_time) ) - if iepoch - last_success_epoch >= args.early_stop_thresh * 5: + if success: + last_success_epoch = iepoch + if iepoch - last_success_epoch >= early_stop_thresh: logger.log("Early stop the pre-training at {:}".format(iepoch)) break per_epoch_time.update(time.time() - start_time) start_time = time.time() meta_model.load_best() + # save to the final model + meta_model.set_best_name(final_best_name) + success, _ = meta_model.save_best(best_score + 1e-6) + assert success def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger): @@ -189,7 +201,7 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger): meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1") meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed)) per_epoch_time, start_time = AverageMeter(), time.time() - last_success_epoch = 0 + last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh for iepoch in range(args.epochs): left_time = "Time Left: {:}".format( convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) @@ -232,9 +244,12 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger): ) + ", batch={:}".format(len(losses)) + ", success={:}, best_score={:.4f}".format(success, -best_score) + + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh) + " {:}".format(left_time) ) - if iepoch - last_success_epoch >= args.early_stop_thresh * 5: + if success: + last_success_epoch = iepoch + if iepoch - last_success_epoch >= early_stop_thresh: logger.log("Early stop the pre-training at {:}".format(iepoch)) break per_epoch_time.update(time.time() - start_time) @@ -521,7 +536,7 @@ if __name__ == "__main__": parser.add_argument( "--refine_lr", type=float, - default=0.005, + default=0.001, help="The learning rate for the optimizer, during refine", ) parser.add_argument( @@ -533,6 +548,12 @@ if __name__ == "__main__": default=20, help="The #epochs for early stop.", ) + parser.add_argument( + "--pretrain_early_stop_thresh", + type=int, + default=200, + help="The #epochs for early stop.", + ) parser.add_argument( "--seq_length", type=int, default=10, help="The sequence length." ) diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py index 10cfd51..19f80ad 100644 --- a/exps/LFNA/lfna_meta_model.py +++ b/exps/LFNA/lfna_meta_model.py @@ -70,6 +70,7 @@ class LFNA_Meta(super_core.SuperModule): dropout, norm_affine=False, order=super_core.LayerOrder.PostNorm, + use_mask=True, ) ) layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding)) @@ -162,11 +163,14 @@ class LFNA_Meta(super_core.SuperModule): def _obtain_time_embed(self, timestamps): # timestamps is a batch of sequence of timestamps batch, seq = timestamps.shape + meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed timestamp_q_embed = self._tscalar_embed(timestamps) - timestamp_k_embed = self._tscalar_embed(self.meta_timestamps.view(1, -1)) - timestamp_v_embed = self.super_meta_embed.unsqueeze(dim=0) + timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) + timestamp_v_embed = meta_embeds.unsqueeze(dim=0) + # create the mask + mask = torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) timestamp_embeds = self._trans_att( - timestamp_q_embed, timestamp_k_embed, timestamp_v_embed + timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask ) # relative_timestamps = timestamps - timestamps[:, :1] # relative_pos_embeds = self._tscalar_embed(relative_timestamps) @@ -186,8 +190,12 @@ class LFNA_Meta(super_core.SuperModule): layer_embed = self._super_layer_embed.view(1, 1, num_layer, -1).expand( batch, seq, -1, -1 ) - joint_embed = torch.cat((meta_embed, layer_embed), dim=-1) - batch_weights = self._generator(joint_embed) + joint_embed = torch.cat( + (meta_embed, layer_embed), dim=-1 + ) # batch, seq, num-layers, input-dim + batch_weights = self._generator( + joint_embed + ) # batch, seq, num-layers, num-weights batch_containers = [] for seq_weights in torch.split(batch_weights, 1): seq_containers = [] diff --git a/xautodl/xlayers/super_attention.py b/xautodl/xlayers/super_attention.py index a1eb317..1adfae6 100644 --- a/xautodl/xlayers/super_attention.py +++ b/xautodl/xlayers/super_attention.py @@ -31,12 +31,15 @@ class SuperSelfAttention(SuperModule): qkv_bias: BoolSpaceType = False, attn_drop: Optional[float] = None, proj_drop: Optional[float] = None, + use_mask=False, ): super(SuperSelfAttention, self).__init__() self._input_dim = input_dim self._proj_dim = proj_dim self._num_heads = num_heads self._qkv_bias = qkv_bias + self._use_mask = use_mask + self._infinity = 1e9 self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) @@ -113,6 +116,12 @@ class SuperSelfAttention(SuperModule): .permute(0, 2, 1, 3) ) attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim) + if self._use_mask: + mask = torch.triu( + torch.ones((N, N), dtype=torch.bool, device=input.device), 1 + ) + mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0) + attn_v1 = attn_v1.masked_fill(mask, -self._infinity) attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * N attn_v1 = self.attn_drop(attn_v1) feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1) @@ -147,8 +156,14 @@ class SuperSelfAttention(SuperModule): return outs def extra_repr(self) -> str: - return "input_dim={:}, proj_dim={:}, num_heads={:}".format( - self._input_dim, self._proj_dim, self._num_heads + return ( + "input_dim={:}, proj_dim={:}, num_heads={:}, mask={:}, infinity={:}".format( + self._input_dim, + self._proj_dim, + self._num_heads, + self._use_mask, + self._infinity, + ) ) @@ -181,6 +196,7 @@ class SuperQKVAttention(SuperModule): self.attn_drop = nn.Dropout(attn_drop or 0.0) self.proj = SuperLinear(proj_dim, proj_dim) self.proj_drop = nn.Dropout(proj_drop or 0.0) + self._infinity = 1e9 @property def num_heads(self): @@ -232,7 +248,9 @@ class SuperQKVAttention(SuperModule): if "proj" in abstract_child: self.proj.apply_candidate(abstract_child["proj"]) - def forward_qkv(self, q_tensor, k_tensor, v_tensor, num_head: int) -> torch.Tensor: + def forward_qkv( + self, q_tensor, k_tensor, v_tensor, num_head: int, mask=None + ) -> torch.Tensor: q = self.q_fc(q_tensor) B, N, C = q.shape @@ -257,6 +275,9 @@ class SuperQKVAttention(SuperModule): ) # compute the attention map attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim) + if mask is not None: + mask = torch.unsqueeze(mask, dim=1) + attn_v1 = attn_v1.masked_fill(mask, -self._infinity) attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * S attn_v1 = self.attn_drop(attn_v1) @@ -281,26 +302,29 @@ class SuperQKVAttention(SuperModule): feats = torch.cat([feats_v1, feats_v2], dim=-1) return feats - def forward_candidate(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor: + def forward_candidate( + self, q_tensor, k_tensor, v_tensor, mask=None + ) -> torch.Tensor: # check the num_heads: if not spaces.is_determined(self._num_heads): num_heads = self.abstract_child["_num_heads"].value else: num_heads = spaces.get_determined_value(self._num_heads) - feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads) + feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads, mask) outs = self.proj(feats) outs = self.proj_drop(outs) return outs - def forward_raw(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor: - feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads) + def forward_raw(self, q_tensor, k_tensor, v_tensor, mask=None) -> torch.Tensor: + feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads, mask) outs = self.proj(feats) outs = self.proj_drop(outs) return outs def extra_repr(self) -> str: - return "input_dim={:}, proj_dim={:}, num_heads={:}".format( + return "input_dim={:}, proj_dim={:}, num_heads={:}, infinity={:}".format( (self.in_q_dim, self.in_k_dim, self.in_v_dim), self._proj_dim, self._num_heads, + self._infinity, ) diff --git a/xautodl/xlayers/super_module.py b/xautodl/xlayers/super_module.py index b0fb8eb..ff56f34 100644 --- a/xautodl/xlayers/super_module.py +++ b/xautodl/xlayers/super_module.py @@ -117,16 +117,32 @@ class SuperModule(abc.ABC, nn.Module): else: return False, self._meta_info[BEST_SCORE_KEY] - def load_best(self): - if BEST_DIR_KEY not in self._meta_info or BEST_SCORE_KEY not in self._meta_info: - raise ValueError("Please call save_best at first") - best_save_path = os.path.join( - self._meta_info[BEST_DIR_KEY], - "best-{:}.pth".format(self.__class__.__name__), - ) + def load_best(self, best_save_path=None): + if best_save_path is None: + if ( + BEST_DIR_KEY not in self._meta_info + or BEST_SCORE_KEY not in self._meta_info + ): + raise ValueError("Please call save_best at first") + best_save_name = self._meta_info.get( + BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__) + ) + best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name) state_dict = torch.load(best_save_path) self.load_state_dict(state_dict) + def has_best(self, best_name=None): + if BEST_DIR_KEY not in self._meta_info: + raise ValueError("Please set BEST_DIR_KEY at first") + if best_name is None: + best_save_name = self._meta_info.get( + BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__) + ) + else: + best_save_name = best_name + best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name) + return os.path.exists(best_save_path) + @property def abstract_search_space(self): raise NotImplementedError diff --git a/xautodl/xlayers/super_transformer.py b/xautodl/xlayers/super_transformer.py index ef43879..8e2e3e9 100644 --- a/xautodl/xlayers/super_transformer.py +++ b/xautodl/xlayers/super_transformer.py @@ -45,6 +45,7 @@ class SuperTransformerEncoderLayer(SuperModule): norm_affine: bool = True, act_layer: Callable[[], nn.Module] = nn.GELU, order: LayerOrder = LayerOrder.PreNorm, + use_mask: bool = False, ): super(SuperTransformerEncoderLayer, self).__init__() mha = SuperSelfAttention( @@ -54,6 +55,7 @@ class SuperTransformerEncoderLayer(SuperModule): qkv_bias=qkv_bias, attn_drop=drop, proj_drop=drop, + use_mask=use_mask, ) mlp = SuperMLPv2( d_model,