Refine Transformer
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 
 from xautodl.xlayers import super_core
 from xautodl.xlayers import trunc_normal_
-from xautodl.models.xcore import get_model
+from xautodl.xmodels.xcore import get_model
 
 
 class MetaModelV1(super_core.SuperModule):
@@ -8,6 +8,9 @@
 import os, sys, time, torch
 import pickle
 import tempfile
+from pathlib import Path
+
+root_dir = (Path(__file__).parent / ".." / "..").resolve()
 
 from xautodl.trade_models.quant_transformer import QuantTransformer
 
@@ -17,7 +20,7 @@ def test_create():
     if not torch.cuda.is_available():
         return
     quant_model = QuantTransformer(GPU=0)
-    temp_dir = lib_dir / ".." / "tests" / ".pytest_cache"
+    temp_dir = root_dir / "tests" / ".pytest_cache"
     temp_dir.mkdir(parents=True, exist_ok=True)
     temp_file = temp_dir / "quant-model.pkl"
     with temp_file.open("wb") as f:
@@ -30,7 +33,7 @@ def test_create():
 
 
 def test_load():
-    temp_file = lib_dir / ".." / "tests" / ".pytest_cache" / "quant-model.pkl"
+    temp_file = root_dir / "tests" / ".pytest_cache" / "quant-model.pkl"
     with temp_file.open("rb") as f:
         model = pickle.load(f)
         print(model.model)
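The three test hunks above replace the old lib_dir-relative paths with a repository root resolved from the test file itself. A minimal sketch of that convention, assuming the test module sits two directories below the repository root; the cache directory and file name follow the diff:

    from pathlib import Path

    # Resolve the repository root from this file's own location (two levels up),
    # so the cache path no longer depends on the current working directory.
    root_dir = (Path(__file__).parent / ".." / "..").resolve()

    temp_dir = root_dir / "tests" / ".pytest_cache"
    temp_dir.mkdir(parents=True, exist_ok=True)
    temp_file = temp_dir / "quant-model.pkl"
    print(temp_file)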
@@ -21,10 +21,10 @@ import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as th_data
 
-from log_utils import AverageMeter
-from utils import count_parameters
+from xautodl.xmisc import AverageMeter
+from xautodl.xmisc import count_parameters
 
-from xlayers import super_core
+from xautodl.xlayers import super_core
 from .transformers import DEFAULT_NET_CONFIG
 from .transformers import get_transformer
 
@@ -13,7 +13,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from xautodl import spaces
-from xautodl.xlayers import trunc_normal_
+from xautodl.xlayers import weight_init
 from xautodl.xlayers import super_core
 
 
@@ -104,7 +104,7 @@ class SuperTransformer(super_core.SuperModule):
         self.head = super_core.SuperSequential(
             super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1)
         )
-        trunc_normal_(self.cls_token, std=0.02)
+        weight_init.trunc_normal_(self.cls_token, std=0.02)
         self.apply(self._init_weights)
 
     @property
@@ -136,11 +136,11 @@ class SuperTransformer(super_core.SuperModule):
 
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
+            weight_init.trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, super_core.SuperLinear):
-            trunc_normal_(m._super_weight, std=0.02)
+            weight_init.trunc_normal_(m._super_weight, std=0.02)
            if m._super_bias is not None:
                 nn.init.constant_(m._super_bias, 0)
         elif isinstance(m, super_core.SuperLayerNorm1D):
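These hunks route the truncated-normal initializer through the weight_init module instead of the removed top-level re-export. A minimal sketch of the renamed call, with an arbitrary tensor shape chosen only for illustration; the signature trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0) matches the definition shown in the weight_init hunk below:

    import torch
    from xautodl.xlayers import weight_init

    # In-place truncated-normal initialization of a cls-token-style parameter,
    # now reached via the weight_init module.
    cls_token = torch.nn.Parameter(torch.zeros(1, 1, 32))
    weight_init.trunc_normal_(cls_token, std=0.02)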
@@ -4,5 +4,4 @@
 # This file is expected to be self-contained, expect
 # for importing from spaces to include search space.
 #####################################################
-from .weight_init import trunc_normal_
 from .super_core import *
@@ -1,8 +1,12 @@
 # Borrowed from https://github.com/rwightman/pytorch-image-models
 import torch
+import torch.nn as nn
 import math
 import warnings
 
+# setup for xlayers
+from . import super_core
+
 
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
@@ -64,3 +68,17 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
         return [_no_grad_trunc_normal_(x, mean, std, a, b) for x in tensor]
     else:
         return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def init_transformer(m):
+    if isinstance(m, nn.Linear):
+        trunc_normal_(m.weight, std=0.02)
+        if isinstance(m, nn.Linear) and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, super_core.SuperLinear):
+        trunc_normal_(m._super_weight, std=0.02)
+        if m._super_bias is not None:
+            nn.init.constant_(m._super_bias, 0)
+    elif isinstance(m, super_core.SuperLayerNorm1D):
+        nn.init.constant_(m.weight, 1.0)
+        nn.init.constant_(m.bias, 0)
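The new init_transformer helper centralizes the initialization rules that the models previously duplicated. A minimal usage sketch, assuming an xautodl installation where super_core and weight_init import as in the hunks above; the layer sizes are arbitrary:

    from xautodl.xlayers import super_core, weight_init

    # Hypothetical toy head built from the same super_core layers the diff touches;
    # apply() walks the module tree and re-initializes weights and biases.
    head = super_core.SuperSequential(
        super_core.SuperLayerNorm1D(32),
        super_core.SuperLinear(32, 1),
    )
    head.apply(weight_init.init_transformer)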
@@ -4,4 +4,4 @@
 # The models in this folder is written with xlayers #
 #####################################################
 
-from .transformers import get_transformer
+from .core import *
@@ -15,7 +15,7 @@ from xautodl.xlayers.super_core import super_name2activation
 
 
 def get_model(config: Dict[Text, Any], **kwargs):
-    model_type = config.get("model_type", "simple_mlp")
+    model_type = config.get("model_type", "simple_mlp").lower()
     if model_type == "simple_mlp":
         act_cls = super_name2activation[kwargs["act_cls"]]
         norm_cls = super_name2norm[kwargs["norm_cls"]]
@@ -60,6 +60,8 @@ def get_model(config: Dict[Text, Any], **kwargs):
             last_dim = hidden_dim
         sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
         model = SuperSequential(*sub_layers)
+    elif model_type == "quant_transformer":
+        raise NotImplementedError
     else:
         raise TypeError("Unkonwn model type: {:}".format(model_type))
     return model
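With the lowercasing and the reserved branch above, the factory now recognizes a quant-transformer request regardless of case but defers its construction. A hedged sketch of the observable behavior; the import path follows the first hunk of this commit:

    from xautodl.xmodels.xcore import get_model

    # "Quant_Transformer" is normalized by .lower(); the branch exists but is
    # deliberately unimplemented, so the call raises NotImplementedError.
    try:
        get_model(dict(model_type="Quant_Transformer"))
    except NotImplementedError:
        print("quant_transformer is registered but not wired up in this commit")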
@@ -20,20 +20,6 @@ def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
 
-def _init_weights(m):
-    if isinstance(m, nn.Linear):
-        weight_init.trunc_normal_(m.weight, std=0.02)
-        if isinstance(m, nn.Linear) and m.bias is not None:
-            nn.init.constant_(m.bias, 0)
-    elif isinstance(m, xlayers.SuperLinear):
-        weight_init.trunc_normal_(m._super_weight, std=0.02)
-        if m._super_bias is not None:
-            nn.init.constant_(m._super_bias, 0)
-    elif isinstance(m, xlayers.SuperLayerNorm1D):
-        nn.init.constant_(m.weight, 1.0)
-        nn.init.constant_(m.bias, 0)
-
-
 name2config = {
     "vit-cifar10-p4-d4-h4-c32": dict(
         type="vit",
@@ -155,7 +141,7 @@ class SuperViT(xlayers.SuperModule):
         )
 
         weight_init.trunc_normal_(self.cls_token, std=0.02)
-        self.apply(_init_weights)
+        self.apply(weight_init.init_transformer)
 
     @property
     def abstract_search_space(self):
							
								
								
									
xautodl/xmodels/transformers_quantum.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
+#####################################################
+# Vision Transformer: arxiv.org/pdf/2010.11929.pdf  #
+#####################################################
+import copy, math
+from functools import partial
+from typing import Optional, Text, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from xautodl import spaces
+from xautodl import xlayers
+from xautodl.xlayers import weight_init
+
+
+class SuperQuaT(xlayers.SuperModule):
+    """The super transformer for transformer."""
+
+    def __init__(
+        self,
+        image_size,
+        patch_size,
+        num_classes,
+        dim,
+        depth,
+        heads,
+        mlp_multiplier=4,
+        channels=3,
+        dropout=0.0,
+        att_dropout=0.0,
+    ):
+        super(SuperQuaT, self).__init__()
+        image_height, image_width = pair(image_size)
+        patch_height, patch_width = pair(patch_size)
+
+        if image_height % patch_height != 0 or image_width % patch_width != 0:
+            raise ValueError("Image dimensions must be divisible by the patch size.")
+
+        num_patches = (image_height // patch_height) * (image_width // patch_width)
+        patch_dim = channels * patch_height * patch_width
+        self.to_patch_embedding = xlayers.SuperSequential(
+            xlayers.SuperReArrange(
+                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+                p1=patch_height,
+                p2=patch_width,
+            ),
+            xlayers.SuperLinear(patch_dim, dim),
+        )
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.dropout = nn.Dropout(dropout)
+
+        # build the transformer encode layers
+        layers = []
+        for ilayer in range(depth):
+            layers.append(
+                xlayers.SuperTransformerEncoderLayer(
+                    dim,
+                    heads,
+                    False,
+                    mlp_multiplier,
+                    dropout=dropout,
+                    att_dropout=att_dropout,
+                )
+            )
+        self.backbone = xlayers.SuperSequential(*layers)
+        self.cls_head = xlayers.SuperSequential(
+            xlayers.SuperLayerNorm1D(dim), xlayers.SuperLinear(dim, num_classes)
+        )
+
+        weight_init.trunc_normal_(self.cls_token, std=0.02)
+        self.apply(_init_weights)
+
+    @property
+    def abstract_search_space(self):
+        raise NotImplementedError
+
+    def apply_candidate(self, abstract_child: spaces.VirtualNode):
+        super(SuperQuaT, self).apply_candidate(abstract_child)
+        raise NotImplementedError
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        tensors = self.to_patch_embedding(input)
+        batch, seq, _ = tensors.shape
+
+        cls_tokens = self.cls_token.expand(batch, -1, -1)
+        feats = torch.cat((cls_tokens, tensors), dim=1)
+        feats = feats + self.pos_embedding[:, : seq + 1, :]
+        feats = self.dropout(feats)
+
+        feats = self.backbone(feats)
+
+        x = feats[:, 0]  # the features for cls-token
+
+        return self.cls_head(x)
+
+
+def get_transformer(config):
+    if isinstance(config, str) and config.lower() in name2config:
+        config = name2config[config.lower()]
+    if not isinstance(config, dict):
+        raise ValueError("Invalid Configuration: {:}".format(config))
+    model_type = config.get("type", "vit").lower()
+    if model_type == "vit":
+        model = SuperQuaT(
+            image_size=config.get("image_size"),
+            patch_size=config.get("patch_size"),
+            num_classes=config.get("num_classes"),
+            dim=config.get("dim"),
+            depth=config.get("depth"),
+            heads=config.get("heads"),
+            dropout=config.get("dropout"),
+            att_dropout=config.get("att_dropout"),
+        )
+    else:
+        raise ValueError("Unknown model type: {:}".format(model_type))
+    return model