Update xlayers
This commit is contained in:
		@@ -1,12 +1,12 @@
 | 
			
		||||
# This file shows the major updates of this repo.
 | 
			
		||||
 | 
			
		||||
- [2020.04.11] [4ef9531] Add change log as `CHANGE-LOG.md`.
 | 
			
		||||
- [2019.12.20] [69ca086] Release NAS-Bench-201.
 | 
			
		||||
- [2019.09.28] [f8f3f38] TAS and SETN codes were publicly released.
 | 
			
		||||
- [2019.01.31] [13e908f] GDAS codes were publicly released.
 | 
			
		||||
- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
 | 
			
		||||
- [2020.09.16] [7052265] Create NATS-BENCH.
 | 
			
		||||
- [2020.04.11] [4ef9531](https://github.com/D-X-Y/AutoDL-Projects/tree/4ef9531) Add change log as `CHANGE-LOG.md`.
 | 
			
		||||
- [2019.12.20] [69ca086](https://github.com/D-X-Y/AutoDL-Projects/tree/69ca086) Release NAS-Bench-201.
 | 
			
		||||
- [2019.09.28] [f8f3f38](https://github.com/D-X-Y/AutoDL-Projects/tree/f8f3f38) TAS and SETN codes were publicly released.
 | 
			
		||||
- [2019.01.31] [13e908f](https://github.com/D-X-Y/AutoDL-Projects/tree/13e908f) GDAS codes were publicly released.
 | 
			
		||||
- [2020.07.01] [a45808b](https://github.com/D-X-Y/AutoDL-Projects/tree/a45808b) Upgrade NAS-API to the 2.0 version.
 | 
			
		||||
- [2020.09.16] [7052265](https://github.com/D-X-Y/AutoDL-Projects/tree/7052265) Create NATS-BENCH.
 | 
			
		||||
- [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
 | 
			
		||||
- [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
 | 
			
		||||
- [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
 | 
			
		||||
- [2021.05.21] [b4e8eae](https://github.com/D-X-Y/AutoDL-Projects/tree/b4e8eae) `xautodl` is close to ready
 | 
			
		||||
- [2021.05.21] [5b09f05](https://github.com/D-X-Y/AutoDL-Projects/tree/5b09f05) `xautodl` is close to ready
 | 
			
		||||
 
 | 
			
		||||
@@ -106,8 +106,13 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
 | 
			
		||||
    logger.log("Using the optimizer: {:}".format(optimizer))
 | 
			
		||||
 | 
			
		||||
    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
 | 
			
		||||
    final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
 | 
			
		||||
    if meta_model.has_best(final_best_name):
 | 
			
		||||
        meta_model.load_best(final_best_name)
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
 | 
			
		||||
    last_success_epoch = 0
 | 
			
		||||
    last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
 | 
			
		||||
    per_epoch_time, start_time = AverageMeter(), time.time()
 | 
			
		||||
    for iepoch in range(args.epochs):
 | 
			
		||||
        left_time = "Time Left: {:}".format(
 | 
			
		||||
@@ -164,14 +169,21 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
 | 
			
		||||
            )
 | 
			
		||||
            + ", batch={:}".format(len(total_meta_losses))
 | 
			
		||||
            + ", success={:}, best_score={:.4f}".format(success, -best_score)
 | 
			
		||||
            + " {:}".format(left_time)
 | 
			
		||||
            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
 | 
			
		||||
            + ", {:}".format(left_time)
 | 
			
		||||
        )
 | 
			
		||||
        if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
 | 
			
		||||
        if success:
 | 
			
		||||
            last_success_epoch = iepoch
 | 
			
		||||
        if iepoch - last_success_epoch >= early_stop_thresh:
 | 
			
		||||
            logger.log("Early stop the pre-training at {:}".format(iepoch))
 | 
			
		||||
            break
 | 
			
		||||
        per_epoch_time.update(time.time() - start_time)
 | 
			
		||||
        start_time = time.time()
 | 
			
		||||
    meta_model.load_best()
 | 
			
		||||
    # save to the final model
 | 
			
		||||
    meta_model.set_best_name(final_best_name)
 | 
			
		||||
    success, _ = meta_model.save_best(best_score + 1e-6)
 | 
			
		||||
    assert success
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
 | 
			
		||||
@@ -189,7 +201,7 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
 | 
			
		||||
    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1")
 | 
			
		||||
    meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
 | 
			
		||||
    per_epoch_time, start_time = AverageMeter(), time.time()
 | 
			
		||||
    last_success_epoch = 0
 | 
			
		||||
    last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
 | 
			
		||||
    for iepoch in range(args.epochs):
 | 
			
		||||
        left_time = "Time Left: {:}".format(
 | 
			
		||||
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
 | 
			
		||||
@@ -232,9 +244,12 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
 | 
			
		||||
            )
 | 
			
		||||
            + ", batch={:}".format(len(losses))
 | 
			
		||||
            + ", success={:}, best_score={:.4f}".format(success, -best_score)
 | 
			
		||||
            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
 | 
			
		||||
            + " {:}".format(left_time)
 | 
			
		||||
        )
 | 
			
		||||
        if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
 | 
			
		||||
        if success:
 | 
			
		||||
            last_success_epoch = iepoch
 | 
			
		||||
        if iepoch - last_success_epoch >= early_stop_thresh:
 | 
			
		||||
            logger.log("Early stop the pre-training at {:}".format(iepoch))
 | 
			
		||||
            break
 | 
			
		||||
        per_epoch_time.update(time.time() - start_time)
 | 
			
		||||
@@ -521,7 +536,7 @@ if __name__ == "__main__":
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--refine_lr",
 | 
			
		||||
        type=float,
 | 
			
		||||
        default=0.005,
 | 
			
		||||
        default=0.001,
 | 
			
		||||
        help="The learning rate for the optimizer, during refine",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
@@ -533,6 +548,12 @@ if __name__ == "__main__":
 | 
			
		||||
        default=20,
 | 
			
		||||
        help="The #epochs for early stop.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pretrain_early_stop_thresh",
 | 
			
		||||
        type=int,
 | 
			
		||||
        default=200,
 | 
			
		||||
        help="The #epochs for early stop.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--seq_length", type=int, default=10, help="The sequence length."
 | 
			
		||||
    )
 | 
			
		||||
 
 | 
			
		||||
@@ -70,6 +70,7 @@ class LFNA_Meta(super_core.SuperModule):
 | 
			
		||||
                    dropout,
 | 
			
		||||
                    norm_affine=False,
 | 
			
		||||
                    order=super_core.LayerOrder.PostNorm,
 | 
			
		||||
                    use_mask=True,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
 | 
			
		||||
@@ -162,11 +163,14 @@ class LFNA_Meta(super_core.SuperModule):
 | 
			
		||||
    def _obtain_time_embed(self, timestamps):
 | 
			
		||||
        # timestamps is a batch of sequence of timestamps
 | 
			
		||||
        batch, seq = timestamps.shape
 | 
			
		||||
        meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
 | 
			
		||||
        timestamp_q_embed = self._tscalar_embed(timestamps)
 | 
			
		||||
        timestamp_k_embed = self._tscalar_embed(self.meta_timestamps.view(1, -1))
 | 
			
		||||
        timestamp_v_embed = self.super_meta_embed.unsqueeze(dim=0)
 | 
			
		||||
        timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
 | 
			
		||||
        timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
 | 
			
		||||
        # create the mask
 | 
			
		||||
        mask = torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
 | 
			
		||||
        timestamp_embeds = self._trans_att(
 | 
			
		||||
            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
 | 
			
		||||
            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
 | 
			
		||||
        )
 | 
			
		||||
        # relative_timestamps = timestamps - timestamps[:, :1]
 | 
			
		||||
        # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
 | 
			
		||||
@@ -186,8 +190,12 @@ class LFNA_Meta(super_core.SuperModule):
 | 
			
		||||
        layer_embed = self._super_layer_embed.view(1, 1, num_layer, -1).expand(
 | 
			
		||||
            batch, seq, -1, -1
 | 
			
		||||
        )
 | 
			
		||||
        joint_embed = torch.cat((meta_embed, layer_embed), dim=-1)
 | 
			
		||||
        batch_weights = self._generator(joint_embed)
 | 
			
		||||
        joint_embed = torch.cat(
 | 
			
		||||
            (meta_embed, layer_embed), dim=-1
 | 
			
		||||
        )  # batch, seq, num-layers, input-dim
 | 
			
		||||
        batch_weights = self._generator(
 | 
			
		||||
            joint_embed
 | 
			
		||||
        )  # batch, seq, num-layers, num-weights
 | 
			
		||||
        batch_containers = []
 | 
			
		||||
        for seq_weights in torch.split(batch_weights, 1):
 | 
			
		||||
            seq_containers = []
 | 
			
		||||
 
 | 
			
		||||
@@ -31,12 +31,15 @@ class SuperSelfAttention(SuperModule):
 | 
			
		||||
        qkv_bias: BoolSpaceType = False,
 | 
			
		||||
        attn_drop: Optional[float] = None,
 | 
			
		||||
        proj_drop: Optional[float] = None,
 | 
			
		||||
        use_mask=False,
 | 
			
		||||
    ):
 | 
			
		||||
        super(SuperSelfAttention, self).__init__()
 | 
			
		||||
        self._input_dim = input_dim
 | 
			
		||||
        self._proj_dim = proj_dim
 | 
			
		||||
        self._num_heads = num_heads
 | 
			
		||||
        self._qkv_bias = qkv_bias
 | 
			
		||||
        self._use_mask = use_mask
 | 
			
		||||
        self._infinity = 1e9
 | 
			
		||||
 | 
			
		||||
        self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
 | 
			
		||||
        self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
 | 
			
		||||
@@ -113,6 +116,12 @@ class SuperSelfAttention(SuperModule):
 | 
			
		||||
            .permute(0, 2, 1, 3)
 | 
			
		||||
        )
 | 
			
		||||
        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
 | 
			
		||||
        if self._use_mask:
 | 
			
		||||
            mask = torch.triu(
 | 
			
		||||
                torch.ones((N, N), dtype=torch.bool, device=input.device), 1
 | 
			
		||||
            )
 | 
			
		||||
            mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0)
 | 
			
		||||
            attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
 | 
			
		||||
        attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * N
 | 
			
		||||
        attn_v1 = self.attn_drop(attn_v1)
 | 
			
		||||
        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
 | 
			
		||||
@@ -147,8 +156,14 @@ class SuperSelfAttention(SuperModule):
 | 
			
		||||
        return outs
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self) -> str:
 | 
			
		||||
        return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
 | 
			
		||||
            self._input_dim, self._proj_dim, self._num_heads
 | 
			
		||||
        return (
 | 
			
		||||
            "input_dim={:}, proj_dim={:}, num_heads={:}, mask={:}, infinity={:}".format(
 | 
			
		||||
                self._input_dim,
 | 
			
		||||
                self._proj_dim,
 | 
			
		||||
                self._num_heads,
 | 
			
		||||
                self._use_mask,
 | 
			
		||||
                self._infinity,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -181,6 +196,7 @@ class SuperQKVAttention(SuperModule):
 | 
			
		||||
        self.attn_drop = nn.Dropout(attn_drop or 0.0)
 | 
			
		||||
        self.proj = SuperLinear(proj_dim, proj_dim)
 | 
			
		||||
        self.proj_drop = nn.Dropout(proj_drop or 0.0)
 | 
			
		||||
        self._infinity = 1e9
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def num_heads(self):
 | 
			
		||||
@@ -232,7 +248,9 @@ class SuperQKVAttention(SuperModule):
 | 
			
		||||
        if "proj" in abstract_child:
 | 
			
		||||
            self.proj.apply_candidate(abstract_child["proj"])
 | 
			
		||||
 | 
			
		||||
    def forward_qkv(self, q_tensor, k_tensor, v_tensor, num_head: int) -> torch.Tensor:
 | 
			
		||||
    def forward_qkv(
 | 
			
		||||
        self, q_tensor, k_tensor, v_tensor, num_head: int, mask=None
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        q = self.q_fc(q_tensor)
 | 
			
		||||
        B, N, C = q.shape
 | 
			
		||||
 | 
			
		||||
@@ -257,6 +275,9 @@ class SuperQKVAttention(SuperModule):
 | 
			
		||||
        )
 | 
			
		||||
        # compute the attention map
 | 
			
		||||
        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
 | 
			
		||||
        if mask is not None:
 | 
			
		||||
            mask = torch.unsqueeze(mask, dim=1)
 | 
			
		||||
            attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
 | 
			
		||||
        attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * S
 | 
			
		||||
        attn_v1 = self.attn_drop(attn_v1)
 | 
			
		||||
 | 
			
		||||
@@ -281,26 +302,29 @@ class SuperQKVAttention(SuperModule):
 | 
			
		||||
            feats = torch.cat([feats_v1, feats_v2], dim=-1)
 | 
			
		||||
        return feats
 | 
			
		||||
 | 
			
		||||
    def forward_candidate(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
 | 
			
		||||
    def forward_candidate(
 | 
			
		||||
        self, q_tensor, k_tensor, v_tensor, mask=None
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        # check the num_heads:
 | 
			
		||||
        if not spaces.is_determined(self._num_heads):
 | 
			
		||||
            num_heads = self.abstract_child["_num_heads"].value
 | 
			
		||||
        else:
 | 
			
		||||
            num_heads = spaces.get_determined_value(self._num_heads)
 | 
			
		||||
        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads)
 | 
			
		||||
        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads, mask)
 | 
			
		||||
        outs = self.proj(feats)
 | 
			
		||||
        outs = self.proj_drop(outs)
 | 
			
		||||
        return outs
 | 
			
		||||
 | 
			
		||||
    def forward_raw(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
 | 
			
		||||
        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads)
 | 
			
		||||
    def forward_raw(self, q_tensor, k_tensor, v_tensor, mask=None) -> torch.Tensor:
 | 
			
		||||
        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads, mask)
 | 
			
		||||
        outs = self.proj(feats)
 | 
			
		||||
        outs = self.proj_drop(outs)
 | 
			
		||||
        return outs
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self) -> str:
 | 
			
		||||
        return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
 | 
			
		||||
        return "input_dim={:}, proj_dim={:}, num_heads={:}, infinity={:}".format(
 | 
			
		||||
            (self.in_q_dim, self.in_k_dim, self.in_v_dim),
 | 
			
		||||
            self._proj_dim,
 | 
			
		||||
            self._num_heads,
 | 
			
		||||
            self._infinity,
 | 
			
		||||
        )
 | 
			
		||||
 
 | 
			
		||||
@@ -117,16 +117,32 @@ class SuperModule(abc.ABC, nn.Module):
 | 
			
		||||
        else:
 | 
			
		||||
            return False, self._meta_info[BEST_SCORE_KEY]
 | 
			
		||||
 | 
			
		||||
    def load_best(self):
 | 
			
		||||
        if BEST_DIR_KEY not in self._meta_info or BEST_SCORE_KEY not in self._meta_info:
 | 
			
		||||
            raise ValueError("Please call save_best at first")
 | 
			
		||||
        best_save_path = os.path.join(
 | 
			
		||||
            self._meta_info[BEST_DIR_KEY],
 | 
			
		||||
            "best-{:}.pth".format(self.__class__.__name__),
 | 
			
		||||
        )
 | 
			
		||||
    def load_best(self, best_save_path=None):
 | 
			
		||||
        if best_save_path is None:
 | 
			
		||||
            if (
 | 
			
		||||
                BEST_DIR_KEY not in self._meta_info
 | 
			
		||||
                or BEST_SCORE_KEY not in self._meta_info
 | 
			
		||||
            ):
 | 
			
		||||
                raise ValueError("Please call save_best at first")
 | 
			
		||||
            best_save_name = self._meta_info.get(
 | 
			
		||||
                BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
 | 
			
		||||
            )
 | 
			
		||||
            best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
 | 
			
		||||
        state_dict = torch.load(best_save_path)
 | 
			
		||||
        self.load_state_dict(state_dict)
 | 
			
		||||
 | 
			
		||||
    def has_best(self, best_name=None):
 | 
			
		||||
        if BEST_DIR_KEY not in self._meta_info:
 | 
			
		||||
            raise ValueError("Please set BEST_DIR_KEY at first")
 | 
			
		||||
        if best_name is None:
 | 
			
		||||
            best_save_name = self._meta_info.get(
 | 
			
		||||
                BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            best_save_name = best_name
 | 
			
		||||
        best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
 | 
			
		||||
        return os.path.exists(best_save_path)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def abstract_search_space(self):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 
 | 
			
		||||
@@ -45,6 +45,7 @@ class SuperTransformerEncoderLayer(SuperModule):
 | 
			
		||||
        norm_affine: bool = True,
 | 
			
		||||
        act_layer: Callable[[], nn.Module] = nn.GELU,
 | 
			
		||||
        order: LayerOrder = LayerOrder.PreNorm,
 | 
			
		||||
        use_mask: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        super(SuperTransformerEncoderLayer, self).__init__()
 | 
			
		||||
        mha = SuperSelfAttention(
 | 
			
		||||
@@ -54,6 +55,7 @@ class SuperTransformerEncoderLayer(SuperModule):
 | 
			
		||||
            qkv_bias=qkv_bias,
 | 
			
		||||
            attn_drop=drop,
 | 
			
		||||
            proj_drop=drop,
 | 
			
		||||
            use_mask=use_mask,
 | 
			
		||||
        )
 | 
			
		||||
        mlp = SuperMLPv2(
 | 
			
		||||
            d_model,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user