Refine Transformer
		| @@ -4,7 +4,7 @@ import torch.nn.functional as F | ||||
|  | ||||
| from xautodl.xlayers import super_core | ||||
| from xautodl.xlayers import trunc_normal_ | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.xmodels.xcore import get_model | ||||
|  | ||||
|  | ||||
| class MetaModelV1(super_core.SuperModule): | ||||
|   | ||||
| @@ -8,6 +8,9 @@ | ||||
| import os, sys, time, torch | ||||
| import pickle | ||||
| import tempfile | ||||
| from pathlib import Path | ||||
|  | ||||
| root_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
|  | ||||
| from xautodl.trade_models.quant_transformer import QuantTransformer | ||||
|  | ||||
| @@ -17,7 +20,7 @@ def test_create(): | ||||
|     if not torch.cuda.is_available(): | ||||
|         return | ||||
|     quant_model = QuantTransformer(GPU=0) | ||||
|     temp_dir = lib_dir / ".." / "tests" / ".pytest_cache" | ||||
|     temp_dir = root_dir / "tests" / ".pytest_cache" | ||||
|     temp_dir.mkdir(parents=True, exist_ok=True) | ||||
|     temp_file = temp_dir / "quant-model.pkl" | ||||
|     with temp_file.open("wb") as f: | ||||
| @@ -30,7 +33,7 @@ def test_create(): | ||||
|  | ||||
|  | ||||
| def test_load(): | ||||
|     temp_file = lib_dir / ".." / "tests" / ".pytest_cache" / "quant-model.pkl" | ||||
|     temp_file = root_dir / "tests" / ".pytest_cache" / "quant-model.pkl" | ||||
|     with temp_file.open("rb") as f: | ||||
|         model = pickle.load(f) | ||||
|         print(model.model) | ||||
|   | ||||
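The test no longer depends on a pre-set lib_dir; it derives the repository root from the test file itself. A minimal sketch of the same pathlib pattern (the directory layout mirrors the hunk above; the number of ".." hops depends on where the file actually sits):

    from pathlib import Path

    # Walk two levels up from this file, resolve to an absolute path,
    # then build the pytest cache directory used by the test.
    root_dir = (Path(__file__).parent / ".." / "..").resolve()
    cache_dir = root_dir / "tests" / ".pytest_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    print(cache_dir)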
| @@ -21,10 +21,10 @@ import torch.nn.functional as F | ||||
| import torch.optim as optim | ||||
| import torch.utils.data as th_data | ||||
|  | ||||
| from log_utils import AverageMeter | ||||
| from utils import count_parameters | ||||
| from xautodl.xmisc import AverageMeter | ||||
| from xautodl.xmisc import count_parameters | ||||
|  | ||||
| from xlayers import super_core | ||||
| from xautodl.xlayers import super_core | ||||
| from .transformers import DEFAULT_NET_CONFIG | ||||
| from .transformers import get_transformer | ||||
|  | ||||
|   | ||||
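The quant-transformer trade model now pulls its utilities from the consolidated xautodl.xmisc package instead of the loose log_utils/utils modules. A rough usage sketch, assuming AverageMeter follows the usual val/avg/update running-average interface and count_parameters accepts an nn.Module; neither signature is shown in this hunk:

    import torch.nn as nn
    from xautodl.xmisc import AverageMeter, count_parameters

    net = nn.Linear(8, 2)
    print(count_parameters(net))   # assumed: returns a parameter count for the module

    meter = AverageMeter()
    for loss in (0.9, 0.7, 0.4):
        meter.update(loss, 1)      # assumed: update(value, n) accumulates a running average
    print(meter.avg)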
| @@ -13,7 +13,7 @@ import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xautodl import spaces | ||||
| from xautodl.xlayers import trunc_normal_ | ||||
| from xautodl.xlayers import weight_init | ||||
| from xautodl.xlayers import super_core | ||||
|  | ||||
|  | ||||
| @@ -104,7 +104,7 @@ class SuperTransformer(super_core.SuperModule): | ||||
|         self.head = super_core.SuperSequential( | ||||
|             super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1) | ||||
|         ) | ||||
|         trunc_normal_(self.cls_token, std=0.02) | ||||
|         weight_init.trunc_normal_(self.cls_token, std=0.02) | ||||
|         self.apply(self._init_weights) | ||||
|  | ||||
|     @property | ||||
| @@ -136,11 +136,11 @@ class SuperTransformer(super_core.SuperModule): | ||||
|  | ||||
|     def _init_weights(self, m): | ||||
|         if isinstance(m, nn.Linear): | ||||
|             trunc_normal_(m.weight, std=0.02) | ||||
|             weight_init.trunc_normal_(m.weight, std=0.02) | ||||
|             if isinstance(m, nn.Linear) and m.bias is not None: | ||||
|                 nn.init.constant_(m.bias, 0) | ||||
|         elif isinstance(m, super_core.SuperLinear): | ||||
|             trunc_normal_(m._super_weight, std=0.02) | ||||
|             weight_init.trunc_normal_(m._super_weight, std=0.02) | ||||
|             if m._super_bias is not None: | ||||
|                 nn.init.constant_(m._super_bias, 0) | ||||
|         elif isinstance(m, super_core.SuperLayerNorm1D): | ||||
|   | ||||
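Since trunc_normal_ is no longer re-exported from the xlayers package (the re-export is dropped in the hunk that follows), call sites now go through the weight_init module. A quick sketch of the equivalent in-place initialization on a token parameter:

    import torch
    from xautodl.xlayers import weight_init

    cls_token = torch.empty(1, 1, 64)
    weight_init.trunc_normal_(cls_token, std=0.02)  # fills the tensor in place
    print(cls_token.mean().item(), cls_token.std().item())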
| @@ -4,5 +4,4 @@ | ||||
| # This file is expected to be self-contained, except | ||||
| # for importing from spaces to include the search space. | ||||
| ##################################################### | ||||
| from .weight_init import trunc_normal_ | ||||
| from .super_core import * | ||||
|   | ||||
| @@ -1,8 +1,12 @@ | ||||
| # Borrowed from https://github.com/rwightman/pytorch-image-models | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import math | ||||
| import warnings | ||||
|  | ||||
| # super_core is imported so that init_transformer below can | ||||
| # recognize the Super* layer types when initializing them. | ||||
| from . import super_core | ||||
|  | ||||
|  | ||||
| def _no_grad_trunc_normal_(tensor, mean, std, a, b): | ||||
|     # Cut & paste from PyTorch official master until it's in a few official releases - RW | ||||
| @@ -64,3 +68,17 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): | ||||
|         return [_no_grad_trunc_normal_(x, mean, std, a, b) for x in tensor] | ||||
|     else: | ||||
|         return _no_grad_trunc_normal_(tensor, mean, std, a, b) | ||||
|  | ||||
|  | ||||
| def init_transformer(m): | ||||
|     if isinstance(m, nn.Linear): | ||||
|         trunc_normal_(m.weight, std=0.02) | ||||
|         if isinstance(m, nn.Linear) and m.bias is not None: | ||||
|             nn.init.constant_(m.bias, 0) | ||||
|     elif isinstance(m, super_core.SuperLinear): | ||||
|         trunc_normal_(m._super_weight, std=0.02) | ||||
|         if m._super_bias is not None: | ||||
|             nn.init.constant_(m._super_bias, 0) | ||||
|     elif isinstance(m, super_core.SuperLayerNorm1D): | ||||
|         nn.init.constant_(m.weight, 1.0) | ||||
|         nn.init.constant_(m.bias, 0) | ||||
|   | ||||
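The new init_transformer helper centralizes the initialization logic that individual models previously duplicated as a private _init_weights. A minimal sketch of applying it to a plain module tree; a vanilla nn.Sequential only exercises the nn.Linear branch, while Super* layers would hit the other branches:

    import torch.nn as nn
    from xautodl.xlayers import weight_init

    mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    # nn.Module.apply visits every submodule, so each Linear gets
    # truncated-normal weights and a zero bias from init_transformer.
    mlp.apply(weight_init.init_transformer)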
| @@ -4,4 +4,4 @@ | ||||
| # The models in this folder are written with xlayers # | ||||
| ##################################################### | ||||
|  | ||||
| from .transformers import get_transformer | ||||
| from .core import * | ||||
|   | ||||
| @@ -15,7 +15,7 @@ from xautodl.xlayers.super_core import super_name2activation | ||||
| 
 | ||||
| 
 | ||||
| def get_model(config: Dict[Text, Any], **kwargs): | ||||
|     model_type = config.get("model_type", "simple_mlp") | ||||
|     model_type = config.get("model_type", "simple_mlp").lower() | ||||
|     if model_type == "simple_mlp": | ||||
|         act_cls = super_name2activation[kwargs["act_cls"]] | ||||
|         norm_cls = super_name2norm[kwargs["norm_cls"]] | ||||
| @@ -60,6 +60,8 @@ def get_model(config: Dict[Text, Any], **kwargs): | ||||
|             last_dim = hidden_dim | ||||
|         sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"])) | ||||
|         model = SuperSequential(*sub_layers) | ||||
|     elif model_type == "quant_transformer": | ||||
|         raise NotImplementedError | ||||
|     else: | ||||
|         raise TypeError("Unknown model type: {:}".format(model_type)) | ||||
|     return model | ||||
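get_model now lower-cases the requested model_type and reserves a quant_transformer branch that is still a placeholder. A hedged sketch of the dispatch behaviour; only the placeholder branch is exercised here, since the full kwargs contract of the simple_mlp path is not visible in this hunk:

    from xautodl.xmodels.xcore import get_model

    try:
        # the lookup is case-insensitive after the .lower() change
        get_model(dict(model_type="Quant_Transformer"))
    except NotImplementedError:
        print("quant_transformer is registered but not implemented yet")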
| @@ -20,20 +20,6 @@ def pair(t): | ||||
|     return t if isinstance(t, tuple) else (t, t) | ||||
|  | ||||
|  | ||||
| def _init_weights(m): | ||||
|     if isinstance(m, nn.Linear): | ||||
|         weight_init.trunc_normal_(m.weight, std=0.02) | ||||
|         if isinstance(m, nn.Linear) and m.bias is not None: | ||||
|             nn.init.constant_(m.bias, 0) | ||||
|     elif isinstance(m, xlayers.SuperLinear): | ||||
|         weight_init.trunc_normal_(m._super_weight, std=0.02) | ||||
|         if m._super_bias is not None: | ||||
|             nn.init.constant_(m._super_bias, 0) | ||||
|     elif isinstance(m, xlayers.SuperLayerNorm1D): | ||||
|         nn.init.constant_(m.weight, 1.0) | ||||
|         nn.init.constant_(m.bias, 0) | ||||
|  | ||||
|  | ||||
| name2config = { | ||||
|     "vit-cifar10-p4-d4-h4-c32": dict( | ||||
|         type="vit", | ||||
| @@ -155,7 +141,7 @@ class SuperViT(xlayers.SuperModule): | ||||
|         ) | ||||
|  | ||||
|         weight_init.trunc_normal_(self.cls_token, std=0.02) | ||||
|         self.apply(_init_weights) | ||||
|         self.apply(weight_init.init_transformer) | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|   | ||||
xautodl/xmodels/transformers_quantum.py (new file, 124 lines added)
| @@ -0,0 +1,124 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| # Vision Transformer: arxiv.org/pdf/2010.11929.pdf  # | ||||
| ##################################################### | ||||
| import copy, math | ||||
| from functools import partial | ||||
| from typing import Optional, Text, List | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xautodl import spaces | ||||
| from xautodl import xlayers | ||||
| from xautodl.xlayers import weight_init | ||||
|  | ||||
| # pair() and name2config are reused from the sibling ViT module; the | ||||
| # import below assumes they live in xautodl.xmodels.transformers. | ||||
| from xautodl.xmodels.transformers import name2config, pair | ||||
|  | ||||
|  | ||||
| class SuperQuaT(xlayers.SuperModule): | ||||
|     """The super transformer for transformer.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         image_size, | ||||
|         patch_size, | ||||
|         num_classes, | ||||
|         dim, | ||||
|         depth, | ||||
|         heads, | ||||
|         mlp_multiplier=4, | ||||
|         channels=3, | ||||
|         dropout=0.0, | ||||
|         att_dropout=0.0, | ||||
|     ): | ||||
|         super(SuperQuaT, self).__init__() | ||||
|         image_height, image_width = pair(image_size) | ||||
|         patch_height, patch_width = pair(patch_size) | ||||
|  | ||||
|         if image_height % patch_height != 0 or image_width % patch_width != 0: | ||||
|             raise ValueError("Image dimensions must be divisible by the patch size.") | ||||
|  | ||||
|         num_patches = (image_height // patch_height) * (image_width // patch_width) | ||||
|         patch_dim = channels * patch_height * patch_width | ||||
|         self.to_patch_embedding = xlayers.SuperSequential( | ||||
|             xlayers.SuperReArrange( | ||||
|                 "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", | ||||
|                 p1=patch_height, | ||||
|                 p2=patch_width, | ||||
|             ), | ||||
|             xlayers.SuperLinear(patch_dim, dim), | ||||
|         ) | ||||
|  | ||||
|         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) | ||||
|         self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) | ||||
|         self.dropout = nn.Dropout(dropout) | ||||
|  | ||||
|         # build the transformer encode layers | ||||
|         layers = [] | ||||
|         for ilayer in range(depth): | ||||
|             layers.append( | ||||
|                 xlayers.SuperTransformerEncoderLayer( | ||||
|                     dim, | ||||
|                     heads, | ||||
|                     False, | ||||
|                     mlp_multiplier, | ||||
|                     dropout=dropout, | ||||
|                     att_dropout=att_dropout, | ||||
|                 ) | ||||
|             ) | ||||
|         self.backbone = xlayers.SuperSequential(*layers) | ||||
|         self.cls_head = xlayers.SuperSequential( | ||||
|             xlayers.SuperLayerNorm1D(dim), xlayers.SuperLinear(dim, num_classes) | ||||
|         ) | ||||
|  | ||||
|         weight_init.trunc_normal_(self.cls_token, std=0.02) | ||||
|         self.apply(weight_init.init_transformer) | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||
|         super(SuperQuaT, self).apply_candidate(abstract_child) | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         tensors = self.to_patch_embedding(input) | ||||
|         batch, seq, _ = tensors.shape | ||||
|  | ||||
|         cls_tokens = self.cls_token.expand(batch, -1, -1) | ||||
|         feats = torch.cat((cls_tokens, tensors), dim=1) | ||||
|         feats = feats + self.pos_embedding[:, : seq + 1, :] | ||||
|         feats = self.dropout(feats) | ||||
|  | ||||
|         feats = self.backbone(feats) | ||||
|  | ||||
|         x = feats[:, 0]  # the features for cls-token | ||||
|  | ||||
|         return self.cls_head(x) | ||||
|  | ||||
|  | ||||
| def get_transformer(config): | ||||
|     if isinstance(config, str) and config.lower() in name2config: | ||||
|         config = name2config[config.lower()] | ||||
|     if not isinstance(config, dict): | ||||
|         raise ValueError("Invalid Configuration: {:}".format(config)) | ||||
|     model_type = config.get("type", "vit").lower() | ||||
|     if model_type == "vit": | ||||
|         model = SuperQuaT( | ||||
|             image_size=config.get("image_size"), | ||||
|             patch_size=config.get("patch_size"), | ||||
|             num_classes=config.get("num_classes"), | ||||
|             dim=config.get("dim"), | ||||
|             depth=config.get("depth"), | ||||
|             heads=config.get("heads"), | ||||
|             dropout=config.get("dropout"), | ||||
|             att_dropout=config.get("att_dropout"), | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("Unknown model type: {:}".format(model_type)) | ||||
|     return model | ||||
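A rough usage sketch for the new entry point. The configuration values below are illustrative (they echo the vit-cifar10-p4-d4-h4-c32 preset named in an earlier hunk) and assume the pair/name2config helpers are importable as noted in the import block above; SuperModule is assumed to dispatch plain calls to forward_raw:

    import torch
    from xautodl.xmodels.transformers_quantum import get_transformer

    model = get_transformer(dict(
        type="vit", image_size=32, patch_size=4, num_classes=10,
        dim=32, depth=4, heads=4, dropout=0.1, att_dropout=0.0,
    ))
    logits = model(torch.rand(2, 3, 32, 32))
    print(logits.shape)  # expected: torch.Size([2, 10])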