Move to xautodl
		| @@ -1,20 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # general config related functions | ||||
| from .config_utils import load_config, dict2config, configure2str | ||||
|  | ||||
| # the args setting for different experiments | ||||
| from .basic_args import obtain_basic_args | ||||
| from .attention_args import obtain_attention_args | ||||
| from .random_baseline import obtain_RandomSearch_args | ||||
| from .cls_kd_args import obtain_cls_kd_args | ||||
| from .cls_init_args import obtain_cls_init_args | ||||
| from .search_single_args import obtain_search_single_args | ||||
| from .search_args import obtain_search_args | ||||
|  | ||||
| # for network pruning | ||||
| from .pruning_args import obtain_pruning_args | ||||
|  | ||||
| # utils for args | ||||
| from .args_utils import arg_str2bool | ||||
| @@ -1,12 +0,0 @@ | ||||
| import argparse | ||||
|  | ||||
|  | ||||
| def arg_str2bool(v): | ||||
|     if isinstance(v, bool): | ||||
|         return v | ||||
|     elif v.lower() in ("yes", "true", "t", "y", "1"): | ||||
|         return True | ||||
|     elif v.lower() in ("no", "false", "f", "n", "0"): | ||||
|         return False | ||||
|     else: | ||||
|         raise argparse.ArgumentTypeError("Boolean value expected.") | ||||
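
Note: `arg_str2bool` is meant to be passed as the `type` callable of an argparse flag. A minimal usage sketch (the `--use_bn` flag name is illustrative, not from this repo):

```python
import argparse

# arg_str2bool as defined above; the flag name --use_bn is hypothetical.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_bn",
    type=arg_str2bool,
    default=True,
    help="Accepts yes/no, true/false, t/f, y/n, 1/0 (case-insensitive).",
)
args = parser.parse_args(["--use_bn", "no"])
assert args.use_bn is False
```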
| @@ -1,32 +0,0 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_attention_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument("--att_channel", type=int, help=".") | ||||
|     parser.add_argument("--att_spatial", type=str, help=".") | ||||
|     parser.add_argument("--att_active", type=str, help=".") | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save_dir argument cannot be None" | ||||
|     return args | ||||
| @@ -1,44 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################## | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_basic_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--model_source", | ||||
|         type=str, | ||||
|         default="normal", | ||||
|         help="The source of model defination.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--extra_model_path", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help="The extra model ckp file (help to indicate the searched architecture).", | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save_dir argument cannot be None" | ||||
|     return args | ||||
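
Note: `obtain_basic_args` calls `parser.parse_args()` with no arguments, so it reads `sys.argv`; calling it programmatically means staging the command line first. A sketch with placeholder paths:

```python
import sys

# All paths are placeholders. --rand_seed is omitted, so the function
# draws a random positive seed in its post-processing step.
sys.argv = [
    "train.py",
    "--dataset", "cifar10",
    "--data_path", "./data",
    "--model_config", "configs/model.config",
    "--optim_config", "configs/optim.config",
    "--procedure", "basic",
    "--save_dir", "./output/basic",
    "--batch_size", "256",
]
args = obtain_basic_args()
print(args.rand_seed)  # an int in [1, 100000]
```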
| @@ -1,32 +0,0 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_cls_init_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--init_checkpoint", type=str, help="The checkpoint path to the initial model." | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save_dir argument cannot be None" | ||||
|     return args | ||||
| @@ -1,43 +0,0 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_cls_kd_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--KD_checkpoint", | ||||
|         type=str, | ||||
|         help="The teacher checkpoint in knowledge distillation.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--KD_temperature", | ||||
|         type=float, | ||||
|         help="The temperature parameter in knowledge distillation.", | ||||
|     ) | ||||
|     # parser.add_argument('--KD_feature',       type=float,                 help='Knowledge distillation at the feature level.') | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save_dir argument cannot be None" | ||||
|     return args | ||||
| @@ -1,135 +0,0 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| import os, json | ||||
| from os import path as osp | ||||
| from pathlib import Path | ||||
| from collections import namedtuple | ||||
|  | ||||
| support_types = ("str", "int", "bool", "float", "none") | ||||
|  | ||||
|  | ||||
| def convert_param(original_lists): | ||||
|     assert isinstance(original_lists, list), "The type is not right : {:}".format( | ||||
|         original_lists | ||||
|     ) | ||||
|     ctype, value = original_lists[0], original_lists[1] | ||||
|     assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types) | ||||
|     is_list = isinstance(value, list) | ||||
|     if not is_list: | ||||
|         value = [value] | ||||
|     outs = [] | ||||
|     for x in value: | ||||
|         if ctype == "int": | ||||
|             x = int(x) | ||||
|         elif ctype == "str": | ||||
|             x = str(x) | ||||
|         elif ctype == "bool": | ||||
|             x = bool(int(x)) | ||||
|         elif ctype == "float": | ||||
|             x = float(x) | ||||
|         elif ctype == "none": | ||||
|             if x.lower() != "none": | ||||
|                 raise ValueError( | ||||
|                     "For the none type, the value must be none instead of {:}".format(x) | ||||
|                 ) | ||||
|             x = None | ||||
|         else: | ||||
|             raise TypeError("Does not know this type : {:}".format(ctype)) | ||||
|         outs.append(x) | ||||
|     if not is_list: | ||||
|         outs = outs[0] | ||||
|     return outs | ||||
|  | ||||
|  | ||||
| def load_config(path, extra, logger): | ||||
|     path = str(path) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log(path) | ||||
|     assert os.path.exists(path), "Can not find {:}".format(path) | ||||
|     # Reading data back | ||||
|     with open(path, "r") as f: | ||||
|         data = json.load(f) | ||||
|     content = {k: convert_param(v) for k, v in data.items()} | ||||
|     assert extra is None or isinstance( | ||||
|         extra, dict | ||||
|     ), "invalid type of extra : {:}".format(extra) | ||||
|     if isinstance(extra, dict): | ||||
|         content = {**content, **extra} | ||||
|     Arguments = namedtuple("Configure", " ".join(content.keys())) | ||||
|     content = Arguments(**content) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("{:}".format(content)) | ||||
|     return content | ||||
|  | ||||
|  | ||||
| def configure2str(config, xpath=None): | ||||
|     if not isinstance(config, dict): | ||||
|         config = config._asdict() | ||||
|  | ||||
|     def cstring(x): | ||||
|         return '"{:}"'.format(x) | ||||
|  | ||||
|     def gtype(x): | ||||
|         if isinstance(x, list): | ||||
|             x = x[0] | ||||
|         if isinstance(x, str): | ||||
|             return "str" | ||||
|         elif isinstance(x, bool): | ||||
|             return "bool" | ||||
|         elif isinstance(x, int): | ||||
|             return "int" | ||||
|         elif isinstance(x, float): | ||||
|             return "float" | ||||
|         elif x is None: | ||||
|             return "none" | ||||
|         else: | ||||
|             raise ValueError("invalid : {:}".format(x)) | ||||
|  | ||||
|     def cvalue(x, xtype): | ||||
|         if isinstance(x, list): | ||||
|             is_list = True | ||||
|         else: | ||||
|             is_list, x = False, [x] | ||||
|         temps = [] | ||||
|         for temp in x: | ||||
|             if xtype == "bool": | ||||
|                 temp = cstring(int(temp)) | ||||
|             elif xtype == "none": | ||||
|                 temp = cstring("None") | ||||
|             else: | ||||
|                 temp = cstring(temp) | ||||
|             temps.append(temp) | ||||
|         if is_list: | ||||
|             return "[{:}]".format(", ".join(temps)) | ||||
|         else: | ||||
|             return temps[0] | ||||
|  | ||||
|     xstrings = [] | ||||
|     for key, value in config.items(): | ||||
|         xtype = gtype(value) | ||||
|         string = "  {:20s} : [{:8s}, {:}]".format( | ||||
|             cstring(key), cstring(xtype), cvalue(value, xtype) | ||||
|         ) | ||||
|         xstrings.append(string) | ||||
|     Fstring = "{\n" + ",\n".join(xstrings) + "\n}" | ||||
|     if xpath is not None: | ||||
|         parent = Path(xpath).resolve().parent | ||||
|         parent.mkdir(parents=True, exist_ok=True) | ||||
|         if osp.isfile(xpath): | ||||
|             os.remove(xpath) | ||||
|         with open(xpath, "w") as text_file: | ||||
|             text_file.write("{:}".format(Fstring)) | ||||
|     return Fstring | ||||
|  | ||||
|  | ||||
| def dict2config(xdict, logger): | ||||
|     assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict)) | ||||
|     Arguments = namedtuple("Configure", " ".join(xdict.keys())) | ||||
|     content = Arguments(**xdict) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("{:}".format(content)) | ||||
|     return content | ||||
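
Note: `load_config` expects a JSON file in which every value is a `[type, value]` pair with the type drawn from `support_types`; list values are converted element-wise. A self-contained round-trip sketch:

```python
import json, os, tempfile

# Build a typed-JSON config, load it into a namedtuple, serialize it back.
config = {
    "class_num": ["int", "10"],
    "lr": ["float", "0.1"],
    "milestones": ["int", ["80", "120"]],
    "note": ["none", "None"],
}
path = os.path.join(tempfile.mkdtemp(), "demo.config")
with open(path, "w") as f:
    json.dump(config, f)

cfg = load_config(path, None, None)  # a logger without .log is simply skipped
print(cfg.class_num, cfg.lr, cfg.milestones)  # 10 0.1 [80, 120]
print(configure2str(cfg))  # the typed-JSON text representation
```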
| @@ -1,48 +0,0 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_pruning_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--keep_ratio", | ||||
|         type=float, | ||||
|         help="The left channel ratio compared to the original network.", | ||||
|     ) | ||||
|     parser.add_argument("--model_version", type=str, help="The network version.") | ||||
|     parser.add_argument( | ||||
|         "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--KD_temperature", | ||||
|         type=float, | ||||
|         help="The temperature parameter in knowledge distillation.", | ||||
|     ) | ||||
|     parser.add_argument("--Regular_W_feat", type=float, help="The .") | ||||
|     parser.add_argument("--Regular_W_conv", type=float, help="The .") | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save_dir argument cannot be None" | ||||
|     assert ( | ||||
|         args.keep_ratio > 0 and args.keep_ratio <= 1 | ||||
|     ), "invalid keep ratio : {:}".format(args.keep_ratio) | ||||
|     return args | ||||
| @@ -1,44 +0,0 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_RandomSearch_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--expect_flop", type=float, help="The expected flop keep ratio." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nums", | ||||
|         type=int, | ||||
|         help="The maximum number of running random arch generating..", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--random_mode", | ||||
|         type=str, | ||||
|         choices=["random", "fix"], | ||||
|         help="The path to the optimizer configuration", | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save_dir argument cannot be None" | ||||
|     # assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) | ||||
|     return args | ||||
| @@ -1,53 +0,0 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_search_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--split_path", type=str, help="The split file path.") | ||||
|     # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.") | ||||
|     parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.") | ||||
|     parser.add_argument( | ||||
|         "--FLOP_tolerant", type=float, help="The tolerant range for FLOP." | ||||
|     ) | ||||
|     # ablation studies | ||||
|     parser.add_argument( | ||||
|         "--ablation_num_select", | ||||
|         type=int, | ||||
|         help="The number of randomly selected channels.", | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save_dir argument cannot be None" | ||||
|     assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|     assert ( | ||||
|         args.FLOP_tolerant is not None and args.FLOP_tolerant > 0 | ||||
|     ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant) | ||||
|     # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|     # args.arch_para_pure = bool(args.arch_para_pure) | ||||
|     return args | ||||
| @@ -1,48 +0,0 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_search_single_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--split_path", type=str, help="The split file path.") | ||||
|     parser.add_argument("--search_shape", type=str, help="The shape to be searched.") | ||||
|     # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.") | ||||
|     parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.") | ||||
|     parser.add_argument( | ||||
|         "--FLOP_tolerant", type=float, help="The tolerant range for FLOP." | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save_dir argument cannot be None" | ||||
|     assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|     assert ( | ||||
|         args.FLOP_tolerant is not None and args.FLOP_tolerant > 0 | ||||
|     ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant) | ||||
|     # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|     # args.arch_para_pure = bool(args.arch_para_pure) | ||||
|     return args | ||||
| @@ -1,39 +0,0 @@ | ||||
| import os, sys, time, random, argparse | ||||
|  | ||||
|  | ||||
| def add_shared_args(parser): | ||||
|     # Data Generation | ||||
|     parser.add_argument("--dataset", type=str, help="The dataset name.") | ||||
|     parser.add_argument("--data_path", type=str, help="The dataset name.") | ||||
|     parser.add_argument( | ||||
|         "--cutout_length", type=int, help="The cutout length, negative means not use." | ||||
|     ) | ||||
|     # Printing | ||||
|     parser.add_argument( | ||||
|         "--print_freq", type=int, default=100, help="print frequency (default: 200)" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--print_freq_eval", | ||||
|         type=int, | ||||
|         default=100, | ||||
|         help="print frequency (default: 200)", | ||||
|     ) | ||||
|     # Checkpoints | ||||
|     parser.add_argument( | ||||
|         "--eval_frequency", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         help="evaluation frequency (default: 200)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     # Acceleration | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="number of data loading workers (default: 8)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
| @@ -1,148 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, hashlib, torch | ||||
| import numpy as np | ||||
| from PIL import Image | ||||
| import torch.utils.data as data | ||||
|  | ||||
| if sys.version_info[0] == 2: | ||||
|     import cPickle as pickle | ||||
| else: | ||||
|     import pickle | ||||
|  | ||||
|  | ||||
| def calculate_md5(fpath, chunk_size=1024 * 1024): | ||||
|     md5 = hashlib.md5() | ||||
|     with open(fpath, "rb") as f: | ||||
|         for chunk in iter(lambda: f.read(chunk_size), b""): | ||||
|             md5.update(chunk) | ||||
|     return md5.hexdigest() | ||||
|  | ||||
|  | ||||
| def check_md5(fpath, md5, **kwargs): | ||||
|     return md5 == calculate_md5(fpath, **kwargs) | ||||
|  | ||||
|  | ||||
| def check_integrity(fpath, md5=None): | ||||
|     if not os.path.isfile(fpath): | ||||
|         return False | ||||
|     if md5 is None: | ||||
|         return True | ||||
|     else: | ||||
|         return check_md5(fpath, md5) | ||||
|  | ||||
|  | ||||
| class ImageNet16(data.Dataset): | ||||
|     # http://image-net.org/download-images | ||||
|     # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets | ||||
|     # https://arxiv.org/pdf/1707.08819.pdf | ||||
|  | ||||
|     train_list = [ | ||||
|         ["train_data_batch_1", "27846dcaa50de8e21a7d1a35f30f0e91"], | ||||
|         ["train_data_batch_2", "c7254a054e0e795c69120a5727050e3f"], | ||||
|         ["train_data_batch_3", "4333d3df2e5ffb114b05d2ffc19b1e87"], | ||||
|         ["train_data_batch_4", "1620cdf193304f4a92677b695d70d10f"], | ||||
|         ["train_data_batch_5", "348b3c2fdbb3940c4e9e834affd3b18d"], | ||||
|         ["train_data_batch_6", "6e765307c242a1b3d7d5ef9139b48945"], | ||||
|         ["train_data_batch_7", "564926d8cbf8fc4818ba23d2faac7564"], | ||||
|         ["train_data_batch_8", "f4755871f718ccb653440b9dd0ebac66"], | ||||
|         ["train_data_batch_9", "bb6dd660c38c58552125b1a92f86b5d4"], | ||||
|         ["train_data_batch_10", "8f03f34ac4b42271a294f91bf480f29b"], | ||||
|     ] | ||||
|     valid_list = [ | ||||
|         ["val_data", "3410e3017fdaefba8d5073aaa65e4bd6"], | ||||
|     ] | ||||
|  | ||||
|     def __init__(self, root, train, transform, use_num_of_class_only=None): | ||||
|         self.root = root | ||||
|         self.transform = transform | ||||
|         self.train = train  # training set or valid set | ||||
|         if not self._check_integrity(): | ||||
|             raise RuntimeError("Dataset not found or corrupted.") | ||||
|  | ||||
|         if self.train: | ||||
|             downloaded_list = self.train_list | ||||
|         else: | ||||
|             downloaded_list = self.valid_list | ||||
|         self.data = [] | ||||
|         self.targets = [] | ||||
|  | ||||
|         # now load the picked numpy arrays | ||||
|         for i, (file_name, checksum) in enumerate(downloaded_list): | ||||
|             file_path = os.path.join(self.root, file_name) | ||||
|             # print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) | ||||
|             with open(file_path, "rb") as f: | ||||
|                 if sys.version_info[0] == 2: | ||||
|                     entry = pickle.load(f) | ||||
|                 else: | ||||
|                     entry = pickle.load(f, encoding="latin1") | ||||
|                 self.data.append(entry["data"]) | ||||
|                 self.targets.extend(entry["labels"]) | ||||
|         self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) | ||||
|         self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC | ||||
|         if use_num_of_class_only is not None: | ||||
|             assert ( | ||||
|                 isinstance(use_num_of_class_only, int) | ||||
|                 and use_num_of_class_only > 0 | ||||
|                 and use_num_of_class_only < 1000 | ||||
|             ), "invalid use_num_of_class_only : {:}".format(use_num_of_class_only) | ||||
|             new_data, new_targets = [], [] | ||||
|             for I, L in zip(self.data, self.targets): | ||||
|                 if 1 <= L <= use_num_of_class_only: | ||||
|                     new_data.append(I) | ||||
|                     new_targets.append(L) | ||||
|             self.data = new_data | ||||
|             self.targets = new_targets | ||||
|         #    self.mean.append(entry['mean']) | ||||
|         # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) | ||||
|         # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) | ||||
|         # print ('Mean : {:}'.format(self.mean)) | ||||
|         # temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3)) | ||||
|         # std_data  = np.std(temp, axis=0) | ||||
|         # std_data  = np.mean(np.mean(std_data, axis=0), axis=0) | ||||
|         # print ('Std  : {:}'.format(std_data)) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({num} images, {classes} classes)".format( | ||||
|             name=self.__class__.__name__, | ||||
|             num=len(self.data), | ||||
|             classes=len(set(self.targets)), | ||||
|         ) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         img, target = self.data[index], self.targets[index] - 1 | ||||
|  | ||||
|         img = Image.fromarray(img) | ||||
|  | ||||
|         if self.transform is not None: | ||||
|             img = self.transform(img) | ||||
|  | ||||
|         return img, target | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.data) | ||||
|  | ||||
|     def _check_integrity(self): | ||||
|         root = self.root | ||||
|         for fentry in self.train_list + self.valid_list: | ||||
|             filename, md5 = fentry[0], fentry[1] | ||||
|             fpath = os.path.join(root, filename) | ||||
|             if not check_integrity(fpath, md5): | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
|  | ||||
| """ | ||||
| if __name__ == '__main__': | ||||
|   train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None)  | ||||
|   valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None)  | ||||
|  | ||||
|   print ( len(train) ) | ||||
|   print ( len(valid) ) | ||||
|   image, label = train[111] | ||||
|   trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200) | ||||
|   validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200) | ||||
|   print ( len(trainX) ) | ||||
|   print ( len(validX) ) | ||||
| """ | ||||
| @@ -1,301 +0,0 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| from os import path as osp | ||||
| from copy import deepcopy as copy | ||||
| from tqdm import tqdm | ||||
| import warnings, time, random, numpy as np | ||||
|  | ||||
| from pts_utils import generate_label_map | ||||
| from xvision import denormalize_points | ||||
| from xvision import identity2affine, solve2theta, affine2image | ||||
| from .dataset_utils import pil_loader | ||||
| from .landmark_utils import PointMeta2V | ||||
| # apply_affine2point and apply_boundary are called in __process_affine below but | ||||
| # were never imported; their source module is not shown in this diff, so the | ||||
| # import path below is an assumption. | ||||
| from .landmark_utils import apply_affine2point, apply_boundary | ||||
| from .augmentation_utils import CutOut | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class LandmarkDataset(data.Dataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         transform, | ||||
|         sigma, | ||||
|         downsample, | ||||
|         heatmap_type, | ||||
|         shape, | ||||
|         use_gray, | ||||
|         mean_file, | ||||
|         data_indicator, | ||||
|         cache_images=None, | ||||
|     ): | ||||
|  | ||||
|         self.transform = transform | ||||
|         self.sigma = sigma | ||||
|         self.downsample = downsample | ||||
|         self.heatmap_type = heatmap_type | ||||
|         self.dataset_name = data_indicator | ||||
|         self.shape = shape  # [H,W] | ||||
|         self.use_gray = use_gray | ||||
|         assert transform is not None, "transform : {:}".format(transform) | ||||
|         self.mean_file = mean_file | ||||
|         if mean_file is None: | ||||
|             self.mean_data = None | ||||
|             warnings.warn("LandmarkDataset initialized with mean_data = None") | ||||
|         else: | ||||
|             assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file) | ||||
|             self.mean_data = torch.load(mean_file) | ||||
|         self.reset() | ||||
|         self.cutout = None | ||||
|         self.cache_images = cache_images | ||||
|         print("The general dataset initialization done : {:}".format(self)) | ||||
|         warnings.simplefilter("once") | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def set_cutout(self, length): | ||||
|         if length is not None and length >= 1: | ||||
|             self.cutout = CutOut(int(length)) | ||||
|         else: | ||||
|             self.cutout = None | ||||
|  | ||||
|     def reset(self, num_pts=-1, boxid="default", only_pts=False): | ||||
|         self.NUM_PTS = num_pts | ||||
|         if only_pts: | ||||
|             return | ||||
|         self.length = 0 | ||||
|         self.datas = [] | ||||
|         self.labels = [] | ||||
|         self.NormDistances = [] | ||||
|         self.BOXID = boxid | ||||
|         if self.mean_data is None: | ||||
|             self.mean_face = None | ||||
|         else: | ||||
|             self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T) | ||||
|             assert (self.mean_face >= -1).all() and ( | ||||
|                 self.mean_face <= 1 | ||||
|             ).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face) | ||||
|         # assert self.dataset_name is not None, 'The dataset name is None' | ||||
|  | ||||
|     def __len__(self): | ||||
|         assert len(self.datas) == self.length, "The length is not correct : {}".format( | ||||
|             self.length | ||||
|         ) | ||||
|         return self.length | ||||
|  | ||||
|     def append(self, data, label, distance): | ||||
|         assert osp.isfile(data), "The image path is not a file : {:}".format(data) | ||||
|         self.datas.append(data) | ||||
|         self.labels.append(label) | ||||
|         self.NormDistances.append(distance) | ||||
|         self.length = self.length + 1 | ||||
|  | ||||
|     def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset): | ||||
|         if reset: | ||||
|             self.reset(num_pts, boxindicator) | ||||
|         else: | ||||
|             assert ( | ||||
|                 self.NUM_PTS == num_pts and self.BOXID == boxindicator | ||||
|             ), "The number of point is inconsistance : {:} vs {:}".format( | ||||
|                 self.NUM_PTS, num_pts | ||||
|             ) | ||||
|         if isinstance(file_lists, str): | ||||
|             file_lists = [file_lists] | ||||
|         samples = [] | ||||
|         for idx, file_path in enumerate(file_lists): | ||||
|             print( | ||||
|                 ":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path) | ||||
|             ) | ||||
|             xdata = torch.load(file_path) | ||||
|             if isinstance(xdata, list): | ||||
|                 data = xdata  # image or video dataset list | ||||
|             elif isinstance(xdata, dict): | ||||
|                 data = xdata["datas"]  # multi-view dataset list | ||||
|             else: | ||||
|                 raise ValueError("Invalid Type Error : {:}".format(type(xdata))) | ||||
|             samples = samples + data | ||||
|         # samples is a list of annotations; each annotation is a dict that | ||||
|         # contains 'points' (3, num_pts) and various boxes | ||||
|         print("GeneralDataset-V2 : {:} samples".format(len(samples))) | ||||
|  | ||||
|         # for index, annotation in enumerate(samples): | ||||
|         for index in tqdm(range(len(samples))): | ||||
|             annotation = samples[index] | ||||
|             image_path = annotation["current_frame"] | ||||
|             points, box = ( | ||||
|                 annotation["points"], | ||||
|                 annotation["box-{:}".format(boxindicator)], | ||||
|             ) | ||||
|             label = PointMeta2V( | ||||
|                 self.NUM_PTS, points, box, image_path, self.dataset_name | ||||
|             ) | ||||
|             if normalizeL is None: | ||||
|                 normDistance = None | ||||
|             else: | ||||
|                 normDistance = annotation["normalizeL-{:}".format(normalizeL)] | ||||
|             self.append(image_path, label, normDistance) | ||||
|  | ||||
|         assert ( | ||||
|             len(self.datas) == self.length | ||||
|         ), "The length and the data is not right {} vs {}".format( | ||||
|             self.length, len(self.datas) | ||||
|         ) | ||||
|         assert ( | ||||
|             len(self.labels) == self.length | ||||
|         ), "The length and the labels is not right {} vs {}".format( | ||||
|             self.length, len(self.labels) | ||||
|         ) | ||||
|         assert ( | ||||
|             len(self.NormDistances) == self.length | ||||
|         ), "The length and the NormDistances is not right {} vs {}".format( | ||||
|             self.length, len(self.NormDistance) | ||||
|         ) | ||||
|         print( | ||||
|             "Load data done for LandmarkDataset, which has {:} images.".format( | ||||
|                 self.length | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         assert index >= 0 and index < self.length, "Invalid index : {:}".format(index) | ||||
|         if self.cache_images is not None and self.datas[index] in self.cache_images: | ||||
|             image = self.cache_images[self.datas[index]].clone() | ||||
|         else: | ||||
|             image = pil_loader(self.datas[index], self.use_gray) | ||||
|         target = self.labels[index].copy() | ||||
|         return self._process_(image, target, index) | ||||
|  | ||||
|     def _process_(self, image, target, index): | ||||
|  | ||||
|         # transform the image and points | ||||
|         image, target, theta = self.transform(image, target) | ||||
|         (C, H, W), (height, width) = image.size(), self.shape | ||||
|  | ||||
|         # obtain the visible indicator vector | ||||
|         if target.is_none(): | ||||
|             nopoints = True | ||||
|         else: | ||||
|             nopoints = False | ||||
|         if index == -1: | ||||
|             __path = None | ||||
|         else: | ||||
|             __path = self.datas[index] | ||||
|         if isinstance(theta, list) or isinstance(theta, tuple): | ||||
|             affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = ( | ||||
|                 [], | ||||
|                 [], | ||||
|                 [], | ||||
|                 [], | ||||
|                 [], | ||||
|                 [], | ||||
|             ) | ||||
|             for _theta in theta: | ||||
|                 ( | ||||
|                     _affineImage, | ||||
|                     _heatmaps, | ||||
|                     _mask, | ||||
|                     _norm_trans_points, | ||||
|                     _theta, | ||||
|                     _transpose_theta, | ||||
|                 ) = self.__process_affine( | ||||
|                     image, target, _theta, nopoints, "P[{:}]@{:}".format(index, __path) | ||||
|                 ) | ||||
|                 affineImage.append(_affineImage) | ||||
|                 heatmaps.append(_heatmaps) | ||||
|                 mask.append(_mask) | ||||
|                 norm_trans_points.append(_norm_trans_points) | ||||
|                 THETA.append(_theta) | ||||
|                 transpose_theta.append(_transpose_theta) | ||||
|             affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = ( | ||||
|                 torch.stack(affineImage), | ||||
|                 torch.stack(heatmaps), | ||||
|                 torch.stack(mask), | ||||
|                 torch.stack(norm_trans_points), | ||||
|                 torch.stack(THETA), | ||||
|                 torch.stack(transpose_theta), | ||||
|             ) | ||||
|         else: | ||||
|             ( | ||||
|                 affineImage, | ||||
|                 heatmaps, | ||||
|                 mask, | ||||
|                 norm_trans_points, | ||||
|                 THETA, | ||||
|                 transpose_theta, | ||||
|             ) = self.__process_affine( | ||||
|                 image, target, theta, nopoints, "S[{:}]@{:}".format(index, __path) | ||||
|             ) | ||||
|  | ||||
|         torch_index = torch.IntTensor([index]) | ||||
|         torch_nopoints = torch.ByteTensor([nopoints]) | ||||
|         torch_shape = torch.IntTensor([H, W]) | ||||
|  | ||||
|         return ( | ||||
|             affineImage, | ||||
|             heatmaps, | ||||
|             mask, | ||||
|             norm_trans_points, | ||||
|             THETA, | ||||
|             transpose_theta, | ||||
|             torch_index, | ||||
|             torch_nopoints, | ||||
|             torch_shape, | ||||
|         ) | ||||
|  | ||||
|     def __process_affine(self, image, target, theta, nopoints, aux_info=None): | ||||
|         image, target, theta = image.clone(), target.copy(), theta.clone() | ||||
|         (C, H, W), (height, width) = image.size(), self.shape | ||||
|         if nopoints:  # do not have label | ||||
|             norm_trans_points = torch.zeros((3, self.NUM_PTS)) | ||||
|             heatmaps = torch.zeros( | ||||
|                 (self.NUM_PTS + 1, height // self.downsample, width // self.downsample) | ||||
|             ) | ||||
|             mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8) | ||||
|             transpose_theta = identity2affine(False) | ||||
|         else: | ||||
|             norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W)) | ||||
|             norm_trans_points = apply_boundary(norm_trans_points) | ||||
|             real_trans_points = norm_trans_points.clone() | ||||
|             real_trans_points[:2, :] = denormalize_points( | ||||
|                 self.shape, real_trans_points[:2, :] | ||||
|             ) | ||||
|             heatmaps, mask = generate_label_map( | ||||
|                 real_trans_points.numpy(), | ||||
|                 height // self.downsample, | ||||
|                 width // self.downsample, | ||||
|                 self.sigma, | ||||
|                 self.downsample, | ||||
|                 nopoints, | ||||
|                 self.heatmap_type, | ||||
|             )  # H*W*C | ||||
|             heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type( | ||||
|                 torch.FloatTensor | ||||
|             ) | ||||
|             mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor) | ||||
|             if self.mean_face is None: | ||||
|                 # warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.') | ||||
|                 transpose_theta = identity2affine(False) | ||||
|             else: | ||||
|                 if torch.sum(norm_trans_points[2, :] == 1) < 3: | ||||
|                     warnings.warn( | ||||
|                         "In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}".format( | ||||
|                             aux_info | ||||
|                         ) | ||||
|                     ) | ||||
|                     transpose_theta = identity2affine(False) | ||||
|                 else: | ||||
|                     transpose_theta = solve2theta( | ||||
|                         norm_trans_points, self.mean_face.clone() | ||||
|                     ) | ||||
|  | ||||
|         affineImage = affine2image(image, theta, self.shape) | ||||
|         if self.cutout is not None: | ||||
|             affineImage = self.cutout(affineImage) | ||||
|  | ||||
|         return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta | ||||
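
Note: `__getitem__` returns a nine-element tuple, so the default collate stacks each field along the batch dimension. A consumption sketch, assuming `dataset` is a `LandmarkDataset` already populated via `load_list`:

```python
import torch.utils.data

loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
for (affine_image, heatmaps, mask, norm_trans_points, theta,
     transpose_theta, index, nopoints, shape) in loader:
    # affine_image: (B, C, H, W); heatmaps: (B, NUM_PTS+1, H/ds, W/ds)
    # norm_trans_points: (B, 3, NUM_PTS), third row is a visibility flag
    break
```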
| @@ -1,54 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch, copy, random | ||||
| import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class SearchDataset(data.Dataset): | ||||
|     def __init__(self, name, data, train_split, valid_split, check=True): | ||||
|         self.datasetname = name | ||||
|         if isinstance(data, (list, tuple)):  # new type of SearchDataset | ||||
|             assert len(data) == 2, "invalid length: {:}".format(len(data)) | ||||
|             self.train_data = data[0] | ||||
|             self.valid_data = data[1] | ||||
|             self.train_split = train_split.copy() | ||||
|             self.valid_split = valid_split.copy() | ||||
|             self.mode_str = "V2"  # new mode | ||||
|         else: | ||||
|             self.mode_str = "V1"  # old mode | ||||
|             self.data = data | ||||
|             self.train_split = train_split.copy() | ||||
|             self.valid_split = valid_split.copy() | ||||
|             if check: | ||||
|                 intersection = set(train_split).intersection(set(valid_split)) | ||||
|                 assert ( | ||||
|                     len(intersection) == 0 | ||||
|                 ), "the splitted train and validation sets should have no intersection" | ||||
|         self.length = len(self.train_split) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             datasetname=self.datasetname, | ||||
|             tr_L=len(self.train_split), | ||||
|             val_L=len(self.valid_split), | ||||
|             ver=self.mode_str, | ||||
|         ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.length | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         assert index >= 0 and index < self.length, "invalid index = {:}".format(index) | ||||
|         train_index = self.train_split[index] | ||||
|         valid_index = random.choice(self.valid_split) | ||||
|         if self.mode_str == "V1": | ||||
|             train_image, train_label = self.data[train_index] | ||||
|             valid_image, valid_label = self.data[valid_index] | ||||
|         elif self.mode_str == "V2": | ||||
|             train_image, train_label = self.train_data[train_index] | ||||
|             valid_image, valid_label = self.valid_data[valid_index] | ||||
|         else: | ||||
|             raise ValueError("invalid mode : {:}".format(self.mode_str)) | ||||
|         return train_image, train_label, valid_image, valid_label | ||||
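
Note: in "V1" mode, `SearchDataset` wraps a single dataset with two disjoint index splits and pairs each training sample with a randomly drawn validation sample. A toy sketch with a stand-in `TensorDataset`:

```python
import torch
import torch.utils.data

base = torch.utils.data.TensorDataset(
    torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,))
)
search_data = SearchDataset("toy", base, list(range(50)), list(range(50, 100)))
train_img, train_lab, valid_img, valid_lab = search_data[0]
print(search_data)  # SearchDataset(name=toy, train=50, valid=50, version=V1)
```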
| @@ -1,5 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||
| from .SearchDatasetWrap import SearchDataset | ||||
| @@ -1,362 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, torch | ||||
| import os.path as osp | ||||
| import numpy as np | ||||
| import torchvision.datasets as dset | ||||
| import torchvision.transforms as transforms | ||||
| from copy import deepcopy | ||||
| from PIL import Image | ||||
|  | ||||
| from .DownsampledImageNet import ImageNet16 | ||||
| from .SearchDatasetWrap import SearchDataset | ||||
| from config_utils import load_config | ||||
|  | ||||
|  | ||||
| Dataset2Class = { | ||||
|     "cifar10": 10, | ||||
|     "cifar100": 100, | ||||
|     "imagenet-1k-s": 1000, | ||||
|     "imagenet-1k": 1000, | ||||
|     "ImageNet16": 1000, | ||||
|     "ImageNet16-150": 150, | ||||
|     "ImageNet16-120": 120, | ||||
|     "ImageNet16-200": 200, | ||||
| } | ||||
|  | ||||
|  | ||||
| class CUTOUT(object): | ||||
|     def __init__(self, length): | ||||
|         self.length = length | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(length={length})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, img): | ||||
|         h, w = img.size(1), img.size(2) | ||||
|         mask = np.ones((h, w), np.float32) | ||||
|         y = np.random.randint(h) | ||||
|         x = np.random.randint(w) | ||||
|  | ||||
|         y1 = np.clip(y - self.length // 2, 0, h) | ||||
|         y2 = np.clip(y + self.length // 2, 0, h) | ||||
|         x1 = np.clip(x - self.length // 2, 0, w) | ||||
|         x2 = np.clip(x + self.length // 2, 0, w) | ||||
|  | ||||
|         mask[y1:y2, x1:x2] = 0.0 | ||||
|         mask = torch.from_numpy(mask) | ||||
|         mask = mask.expand_as(img) | ||||
|         img *= mask | ||||
|         return img | ||||
|  | ||||
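
Note: `CUTOUT` zeroes one randomly placed `length x length` square (clipped at the image borders) in a CHW tensor. A quick sketch:

```python
import torch

img = torch.ones(3, 32, 32)
img = CUTOUT(8)(img)
print((img == 0).any().item())  # True: a square region was masked out
```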
|  | ||||
| imagenet_pca = { | ||||
|     "eigval": np.asarray([0.2175, 0.0188, 0.0045]), | ||||
|     "eigvec": np.asarray( | ||||
|         [ | ||||
|             [-0.5675, 0.7192, 0.4009], | ||||
|             [-0.5808, -0.0045, -0.8140], | ||||
|             [-0.5836, -0.6948, 0.4203], | ||||
|         ] | ||||
|     ), | ||||
| } | ||||
|  | ||||
|  | ||||
| class Lighting(object): | ||||
|     def __init__( | ||||
|         self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"] | ||||
|     ): | ||||
|         self.alphastd = alphastd | ||||
|         assert eigval.shape == (3,) | ||||
|         assert eigvec.shape == (3, 3) | ||||
|         self.eigval = eigval | ||||
|         self.eigvec = eigvec | ||||
|  | ||||
|     def __call__(self, img): | ||||
|         if self.alphastd == 0.0: | ||||
|             return img | ||||
|         rnd = np.random.randn(3) * self.alphastd | ||||
|         rnd = rnd.astype("float32") | ||||
|         v = rnd | ||||
|         old_dtype = np.asarray(img).dtype | ||||
|         v = v * self.eigval | ||||
|         v = v.reshape((3, 1)) | ||||
|         inc = np.dot(self.eigvec, v).reshape((3,)) | ||||
|         img = np.add(img, inc) | ||||
|         if old_dtype == np.uint8: | ||||
|             img = np.clip(img, 0, 255) | ||||
|         img = Image.fromarray(img.astype(old_dtype), "RGB") | ||||
|         return img | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return self.__class__.__name__ + "()" | ||||
|  | ||||
|  | ||||
| def get_datasets(name, root, cutout): | ||||
|  | ||||
|     if name == "cifar10": | ||||
|         mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|         std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|     elif name == "cifar100": | ||||
|         mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|         std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|     elif name.startswith("imagenet-1k"): | ||||
|         mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||
|     elif name.startswith("ImageNet16"): | ||||
|         mean = [x / 255 for x in [122.68, 116.66, 104.01]] | ||||
|         std = [x / 255 for x in [63.22, 61.26, 65.09]] | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|     # Data Augmentation | ||||
|     if name == "cifar10" or name == "cifar100": | ||||
|         lists = [ | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.RandomCrop(32, padding=4), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ] | ||||
|         if cutout > 0: | ||||
|             lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         ) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith("ImageNet16"): | ||||
|         lists = [ | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.RandomCrop(16, padding=2), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ] | ||||
|         if cutout > 0: | ||||
|             lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         ) | ||||
|         xshape = (1, 3, 16, 16) | ||||
|     elif name == "tiered": | ||||
|         lists = [ | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.RandomCrop(80, padding=4), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ] | ||||
|         if cutout > 0: | ||||
|             lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [ | ||||
|                 transforms.CenterCrop(80), | ||||
|                 transforms.ToTensor(), | ||||
|                 transforms.Normalize(mean, std), | ||||
|             ] | ||||
|         ) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith("imagenet-1k"): | ||||
|         normalize = transforms.Normalize( | ||||
|             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | ||||
|         ) | ||||
|         if name == "imagenet-1k": | ||||
|             xlists = [transforms.RandomResizedCrop(224)] | ||||
|             xlists.append( | ||||
|                 transforms.ColorJitter( | ||||
|                     brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2 | ||||
|                 ) | ||||
|             ) | ||||
|             xlists.append(Lighting(0.1)) | ||||
|         elif name == "imagenet-1k-s": | ||||
|             xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))] | ||||
|         else: | ||||
|             raise ValueError("invalid name : {:}".format(name)) | ||||
|         xlists.append(transforms.RandomHorizontalFlip(p=0.5)) | ||||
|         xlists.append(transforms.ToTensor()) | ||||
|         xlists.append(normalize) | ||||
|         train_transform = transforms.Compose(xlists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [ | ||||
|                 transforms.Resize(256), | ||||
|                 transforms.CenterCrop(224), | ||||
|                 transforms.ToTensor(), | ||||
|                 normalize, | ||||
|             ] | ||||
|         ) | ||||
|         xshape = (1, 3, 224, 224) | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|     if name == "cifar10": | ||||
|         train_data = dset.CIFAR10( | ||||
|             root, train=True, transform=train_transform, download=True | ||||
|         ) | ||||
|         test_data = dset.CIFAR10( | ||||
|             root, train=False, transform=test_transform, download=True | ||||
|         ) | ||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|     elif name == "cifar100": | ||||
|         train_data = dset.CIFAR100( | ||||
|             root, train=True, transform=train_transform, download=True | ||||
|         ) | ||||
|         test_data = dset.CIFAR100( | ||||
|             root, train=False, transform=test_transform, download=True | ||||
|         ) | ||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|     elif name.startswith("imagenet-1k"): | ||||
|         train_data = dset.ImageFolder(osp.join(root, "train"), train_transform) | ||||
|         test_data = dset.ImageFolder(osp.join(root, "val"), test_transform) | ||||
|         assert ( | ||||
|             len(train_data) == 1281167 and len(test_data) == 50000 | ||||
|         ), "invalid number of images : {:} & {:} vs {:} & {:}".format( | ||||
|             len(train_data), len(test_data), 1281167, 50000 | ||||
|         ) | ||||
|     elif name == "ImageNet16": | ||||
|         train_data = ImageNet16(root, True, train_transform) | ||||
|         test_data = ImageNet16(root, False, test_transform) | ||||
|         assert len(train_data) == 1281167 and len(test_data) == 50000 | ||||
|     elif name == "ImageNet16-120": | ||||
|         train_data = ImageNet16(root, True, train_transform, 120) | ||||
|         test_data = ImageNet16(root, False, test_transform, 120) | ||||
|         assert len(train_data) == 151700 and len(test_data) == 6000 | ||||
|     elif name == "ImageNet16-150": | ||||
|         train_data = ImageNet16(root, True, train_transform, 150) | ||||
|         test_data = ImageNet16(root, False, test_transform, 150) | ||||
|         assert len(train_data) == 190272 and len(test_data) == 7500 | ||||
|     elif name == "ImageNet16-200": | ||||
|         train_data = ImageNet16(root, True, train_transform, 200) | ||||
|         test_data = ImageNet16(root, False, test_transform, 200) | ||||
|         assert len(train_data) == 254775 and len(test_data) == 10000 | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|     class_num = Dataset2Class[name] | ||||
|     return train_data, test_data, xshape, class_num | ||||
|  | ||||
|  | ||||
| def get_nas_search_loaders( | ||||
|     train_data, valid_data, dataset, config_root, batch_size, workers | ||||
| ): | ||||
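|     # Returns (search_loader, train_loader, valid_loader): the search | ||||
|     # loader yields paired train/valid batches for bi-level NAS | ||||
|     # optimization, while the other two iterate each split on its own. | ||||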
|     if isinstance(batch_size, (list, tuple)): | ||||
|         batch, test_batch = batch_size | ||||
|     else: | ||||
|         batch, test_batch = batch_size, batch_size | ||||
|     if dataset == "cifar10": | ||||
|         # split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|         cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None) | ||||
|         train_split, valid_split = ( | ||||
|             cifar_split.train, | ||||
|             cifar_split.valid, | ||||
|         )  # search over the proposed training and validation set | ||||
|         # logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set | ||||
|         # To split data | ||||
|         xvalid_data = deepcopy(train_data) | ||||
|         if hasattr(xvalid_data, "transforms"):  # to avoid a print issue | ||||
|             xvalid_data.transforms = valid_data.transform | ||||
|         xvalid_data.transform = deepcopy(valid_data.transform) | ||||
|         search_data = SearchDataset(dataset, train_data, train_split, valid_split) | ||||
|         # data loader | ||||
|         search_loader = torch.utils.data.DataLoader( | ||||
|             search_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             xvalid_data, | ||||
|             batch_size=test_batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|     elif dataset == "cifar100": | ||||
|         cifar100_test_split = load_config( | ||||
|             "{:}/cifar100-test-split.txt".format(config_root), None, None | ||||
|         ) | ||||
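|         # search over the full CIFAR-100 training set and the `xvalid` | ||||
|         # portion of the official test split loaded above | ||||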
|         search_train_data = train_data | ||||
|         search_valid_data = deepcopy(valid_data) | ||||
|         search_valid_data.transform = train_data.transform | ||||
|         search_data = SearchDataset( | ||||
|             dataset, | ||||
|             [search_train_data, search_valid_data], | ||||
|             list(range(len(search_train_data))), | ||||
|             cifar100_test_split.xvalid, | ||||
|         ) | ||||
|         search_loader = torch.utils.data.DataLoader( | ||||
|             search_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=test_batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                 cifar100_test_split.xvalid | ||||
|             ), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|     elif dataset == "ImageNet16-120": | ||||
|         imagenet_test_split = load_config( | ||||
|             "{:}/imagenet-16-120-test-split.txt".format(config_root), None, None | ||||
|         ) | ||||
|         search_train_data = train_data | ||||
|         search_valid_data = deepcopy(valid_data) | ||||
|         search_valid_data.transform = train_data.transform | ||||
|         search_data = SearchDataset( | ||||
|             dataset, | ||||
|             [search_train_data, search_valid_data], | ||||
|             list(range(len(search_train_data))), | ||||
|             imagenet_test_split.xvalid, | ||||
|         ) | ||||
|         search_loader = torch.utils.data.DataLoader( | ||||
|             search_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=test_batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                 imagenet_test_split.xvalid | ||||
|             ), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|     return search_loader, train_loader, valid_loader | ||||
|  | ||||
|  | ||||
| # if __name__ == '__main__': | ||||
| #  train_data, test_data, xshape, class_num = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1) | ||||
| #  import pdb; pdb.set_trace() | ||||
| @@ -1 +0,0 @@ | ||||
| from .point_meta import PointMeta2V, apply_affine2point, apply_boundary | ||||
| @@ -1,219 +0,0 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| import copy, math, torch, numpy as np | ||||
| from xvision import normalize_points | ||||
| from xvision import denormalize_points | ||||
|  | ||||
|  | ||||
| class PointMeta: | ||||
|     # points    : 3 x num_pts (x, y, occlusion) | ||||
|     # image_size: original [width, height] | ||||
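|     # box       : [x1, y1, x2, y2] (see get_box below) | ||||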
|     def __init__(self, num_point, points, box, image_path, dataset_name): | ||||
|  | ||||
|         self.num_point = num_point | ||||
|         if box is not None: | ||||
|             assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4 | ||||
|             self.box = torch.Tensor(box) | ||||
|         else: | ||||
|             self.box = None | ||||
|         if points is None: | ||||
|             self.points = points | ||||
|         else: | ||||
|             assert ( | ||||
|                 len(points.shape) == 2 | ||||
|                 and points.shape[0] == 3 | ||||
|                 and points.shape[1] == self.num_point | ||||
|             ), "The shape of point is not right : {}".format(points) | ||||
|             self.points = torch.Tensor(points.copy()) | ||||
|         self.image_path = image_path | ||||
|         self.datasets = dataset_name | ||||
|  | ||||
|     def __repr__(self): | ||||
|         if self.box is None: | ||||
|             boxstr = "None" | ||||
|         else: | ||||
|             boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist()) | ||||
|         return ( | ||||
|             "{name}(points={num_point}, ".format( | ||||
|                 name=self.__class__.__name__, **self.__dict__ | ||||
|             ) | ||||
|             + boxstr | ||||
|             + ")" | ||||
|         ) | ||||
|  | ||||
|     def get_box(self, return_diagonal=False): | ||||
|         if self.box is None: | ||||
|             return None | ||||
|         if not return_diagonal: | ||||
|             return self.box.clone() | ||||
|         else: | ||||
|             W = (self.box[2] - self.box[0]).item() | ||||
|             H = (self.box[3] - self.box[1]).item() | ||||
|             return math.sqrt(H * H + W * W) | ||||
|  | ||||
|     def get_points(self, ignore_indicator=False): | ||||
|         if ignore_indicator: | ||||
|             last = 2 | ||||
|         else: | ||||
|             last = 3 | ||||
|         if self.points is not None: | ||||
|             return self.points.clone()[:last, :] | ||||
|         else: | ||||
|             return torch.zeros((last, self.num_point)) | ||||
|  | ||||
|     def is_none(self): | ||||
|         # assert self.box is not None, 'The box should not be None' | ||||
|         return self.points is None | ||||
|         # if self.box is None: return True | ||||
|         # else               : return self.points is None | ||||
|  | ||||
|     def copy(self): | ||||
|         return copy.deepcopy(self) | ||||
|  | ||||
|     def visiable_pts_num(self): | ||||
|         with torch.no_grad(): | ||||
|             ans = self.points[2, :] > 0 | ||||
|             ans = torch.sum(ans) | ||||
|             ans = ans.item() | ||||
|         return ans | ||||
|  | ||||
|     def special_fun(self, indicator): | ||||
|         if ( | ||||
|             indicator == "68to49" | ||||
|         ):  # For 300W or 300VW, convert the default 68 points to 49 points. | ||||
|             assert self.num_point == 68, "num-point must be 68 vs. {:}".format( | ||||
|                 self.num_point | ||||
|             ) | ||||
|             self.num_point = 49 | ||||
|             out = torch.ones(68, dtype=torch.bool)  # keep-mask over the 68 points | ||||
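|             # zero out the 17 jaw-contour points plus the two inner mouth | ||||
|             # corners (0-based indices 60 and 64) to obtain the 49-point set | ||||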
|             out[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 60, 64]] = 0 | ||||
|             if self.points is not None: | ||||
|                 self.points = self.points.clone()[:, out] | ||||
|         else: | ||||
|             raise ValueError("Invalid indicator : {:}".format(indicator)) | ||||
|  | ||||
|     def apply_horizontal_flip(self): | ||||
|         # self.points[0, :] = width - self.points[0, :] - 1 | ||||
|         # Mugsy-specific or Synthetic | ||||
|         if self.datasets.startswith("HandsyROT"): | ||||
|             ori = np.array(list(range(0, 42))) | ||||
|             pos = np.array(list(range(21, 42)) + list(range(0, 21))) | ||||
|             self.points[:, pos] = self.points[:, ori] | ||||
|         elif self.datasets.startswith("face68"): | ||||
|             ori = np.array(list(range(0, 68))) | ||||
|             pos = ( | ||||
|                 np.array( | ||||
|                     [ | ||||
|                         17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, | ||||
|                         27, 26, 25, 24, 23, 22, 21, 20, 19, 18, | ||||
|                         28, 29, 30, 31, | ||||
|                         36, 35, 34, 33, 32, | ||||
|                         46, 45, 44, 43, 48, 47, 40, 39, 38, 37, 42, 41, | ||||
|                         55, 54, 53, 52, 51, 50, 49, | ||||
|                         60, 59, 58, 57, 56, | ||||
|                         65, 64, 63, 62, 61, | ||||
|                         68, 67, 66, | ||||
|                     ] | ||||
|                 ) | ||||
|                 - 1 | ||||
|             ) | ||||
|             self.points[:, ori] = self.points[:, pos] | ||||
|         else: | ||||
|             raise ValueError("Does not support {:}".format(self.datasets)) | ||||
|  | ||||
|  | ||||
| # shape = (H,W) | ||||
| def apply_affine2point(points, theta, shape): | ||||
|     assert points.size(0) == 3, "invalid points shape : {:}".format(points.size()) | ||||
|     with torch.no_grad(): | ||||
|         ok_points = points[2, :] == 1 | ||||
|         assert torch.sum(ok_points).item() > 0, "there is no visible point" | ||||
|         points[:2, :] = normalize_points(shape, points[:2, :]) | ||||
|  | ||||
|         norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float() | ||||
|  | ||||
|         # solve theta @ x' = x, i.e. map the visible points through the | ||||
|         # inverse affine transform (torch.gesv is deprecated; use linalg.solve) | ||||
|         trans_points = torch.linalg.solve(theta, points[:, ok_points]) | ||||
|  | ||||
|         norm_trans_points[:, ok_points] = trans_points | ||||
|  | ||||
|     return norm_trans_points | ||||
|  | ||||
|  | ||||
| def apply_boundary(norm_trans_points): | ||||
|     with torch.no_grad(): | ||||
|         norm_trans_points = norm_trans_points.clone() | ||||
|         oks = torch.stack( | ||||
|             ( | ||||
|                 norm_trans_points[0] > -1, | ||||
|                 norm_trans_points[0] < 1, | ||||
|                 norm_trans_points[1] > -1, | ||||
|                 norm_trans_points[1] < 1, | ||||
|                 norm_trans_points[2] > 0, | ||||
|             ) | ||||
|         ) | ||||
|         oks = torch.sum(oks, dim=0) == 5 | ||||
|         norm_trans_points[2, :] = oks | ||||
|     return norm_trans_points | ||||
| @@ -1,100 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import copy | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
| from .math_base_funcs import FitFunc | ||||
| from .math_base_funcs import QuadraticFunc | ||||
| from .math_base_funcs import QuarticFunc | ||||
|  | ||||
|  | ||||
| class ConstantFunc(FitFunc): | ||||
|     """The constant function: f(x) = c.""" | ||||
|  | ||||
|     def __init__(self, constant=None): | ||||
|         param = dict() | ||||
|         param[0] = constant | ||||
|         super(ConstantFunc, self).__init__(0, None, param) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] | ||||
|  | ||||
|     def fit(self, **kwargs): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) | ||||
|  | ||||
|  | ||||
| class ComposedSinFunc(FitFunc): | ||||
|     """The composed sin function that outputs: | ||||
|       f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) | ||||
|     - the amplitude scale is a quadratic function of x | ||||
|     - the period-phase-shift is a quartic function of x | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         super(ComposedSinFunc, self).__init__(0, None) | ||||
|         self.fit(**kwargs) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         scale = self._params["amplitude_scale"](x) | ||||
|         period_phase = self._params["period_phase_shift"](x) | ||||
|         return scale * math.sin(period_phase) | ||||
|  | ||||
|     def fit(self, **kwargs): | ||||
|         num_sin_phase = kwargs.get("num_sin_phase", 7) | ||||
|         sin_speed_use_power = kwargs.get("sin_speed_use_power", True) | ||||
|         min_amplitude = kwargs.get("min_amplitude", 1) | ||||
|         max_amplitude = kwargs.get("max_amplitude", 4) | ||||
|         phase_shift = kwargs.get("phase_shift", 0.0) | ||||
|         # create parameters | ||||
|         if kwargs.get("amplitude_scale", None) is None: | ||||
|             amplitude_scale = QuadraticFunc( | ||||
|                 [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] | ||||
|             ) | ||||
|         else: | ||||
|             amplitude_scale = kwargs.get("amplitude_scale") | ||||
|         if kwargs.get("period_phase_shift", None) is None: | ||||
|             fitting_data = [] | ||||
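|             # build (x, phase) anchor pairs so the fitted quartic maps | ||||
|             # x in [0, 1] to a monotonically increasing phase covering | ||||
|             # roughly `num_sin_phase` sine periods | ||||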
|             if sin_speed_use_power: | ||||
|                 temp_max_scalar = 2 ** (num_sin_phase - 1) | ||||
|             else: | ||||
|                 temp_max_scalar = num_sin_phase - 1 | ||||
|             for i in range(num_sin_phase): | ||||
|                 if sin_speed_use_power: | ||||
|                     value = (2 ** i) / temp_max_scalar | ||||
|                     next_value = (2 ** (i + 1)) / temp_max_scalar | ||||
|                 else: | ||||
|                     value = i / temp_max_scalar | ||||
|                     next_value = (i + 1) / temp_max_scalar | ||||
|                 for _phase in (0, 0.25, 0.5, 0.75): | ||||
|                     inter_value = value + (next_value - value) * _phase | ||||
|                     fitting_data.append((inter_value, math.pi * (2 * i + _phase))) | ||||
|             period_phase_shift = QuarticFunc(fitting_data) | ||||
|         else: | ||||
|             period_phase_shift = kwargs.get("period_phase_shift") | ||||
|         self.set( | ||||
|             dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift) | ||||
|         ) | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({amplitude_scale} * sin({period_phase_shift}))".format( | ||||
|             name=self.__class__.__name__, | ||||
|             amplitude_scale=self._params["amplitude_scale"], | ||||
|             period_phase_shift=self._params["period_phase_shift"], | ||||
|         ) | ||||
| @@ -1,210 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import copy | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class FitFunc(abc.ABC): | ||||
|     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, freedom: int, list_of_points=None, params=None): | ||||
|         self._params = dict() | ||||
|         for i in range(freedom): | ||||
|             self._params[i] = None | ||||
|         self._freedom = freedom | ||||
|         if list_of_points is not None and params is not None: | ||||
|             raise ValueError("list_of_points and params can not be set simultaneously") | ||||
|         if list_of_points is not None: | ||||
|             self.fit(list_of_points=list_of_points) | ||||
|         if params is not None: | ||||
|             self.set(params) | ||||
|  | ||||
|     def set(self, params): | ||||
|         self._params = copy.deepcopy(params) | ||||
|  | ||||
|     def check_valid(self): | ||||
|         for key, value in self._params.items(): | ||||
|             if value is None: | ||||
|                 raise ValueError("The {:} is None".format(key)) | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def __call__(self, x): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def noise_call(self, x, std=0.1): | ||||
|         clean_y = self.__call__(x) | ||||
|         if isinstance(clean_y, np.ndarray): | ||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||
|         else: | ||||
|             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||
|         return noise_y | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def _getitem(self, x): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def fit(self, **kwargs): | ||||
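|         # Fit the `self._freedom` free parameters to the given (x, y) | ||||
|         # points by minimizing the mean absolute error with Adam and a | ||||
|         # step LR schedule, keeping the parameters of the best iteration. | ||||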
|         list_of_points = kwargs["list_of_points"] | ||||
|         max_iter, lr_max, verbose = ( | ||||
|             kwargs.get("max_iter", 900), | ||||
|             kwargs.get("lr_max", 1.0), | ||||
|             kwargs.get("verbose", False), | ||||
|         ) | ||||
|         with torch.no_grad(): | ||||
|             data = torch.Tensor(list_of_points).type(torch.float32) | ||||
|             assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format( | ||||
|                 data.shape | ||||
|             ) | ||||
|             x, y = data[:, 0], data[:, 1] | ||||
|         weights = torch.nn.Parameter(torch.Tensor(self._freedom)) | ||||
|         torch.nn.init.normal_(weights, mean=0.0, std=1.0) | ||||
|         optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|             optimizer, | ||||
|             milestones=[ | ||||
|                 int(max_iter * 0.25), | ||||
|                 int(max_iter * 0.5), | ||||
|                 int(max_iter * 0.75), | ||||
|             ], | ||||
|             gamma=0.1, | ||||
|         ) | ||||
|         if verbose: | ||||
|             print("The optimizer: {:}".format(optimizer)) | ||||
|  | ||||
|         best_loss = None | ||||
|         for _iter in range(max_iter): | ||||
|             y_hat = self._getitem(x, weights) | ||||
|             loss = torch.mean(torch.abs(y - y_hat)) | ||||
|             optimizer.zero_grad() | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             lr_scheduler.step() | ||||
|             if verbose: | ||||
|                 print( | ||||
|                     "In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format( | ||||
|                         _iter, max_iter, loss.item() | ||||
|                     ) | ||||
|                 ) | ||||
|             # Update the params | ||||
|             if best_loss is None or best_loss > loss.item(): | ||||
|                 best_loss = loss.item() | ||||
|                 for i in range(self._freedom): | ||||
|                     self._params[i] = weights[i].item() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(freedom={freedom})".format( | ||||
|             name=self.__class__.__name__, freedom=self._freedom | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class LinearFunc(FitFunc): | ||||
|     """The linear function that outputs f(x) = a * x + b.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None, params=None): | ||||
|         super(LinearFunc, self).__init__(2, list_of_points, params) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x + self._params[1] | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x + weights[1] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x + {b})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuadraticFunc(FitFunc): | ||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None, params=None): | ||||
|         super(QuadraticFunc, self).__init__(3, list_of_points, params) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x * x + self._params[1] * x + self._params[2] | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x * x + weights[1] * x + weights[2] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x^2 + {b} * x + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class CubicFunc(FitFunc): | ||||
|     """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(CubicFunc, self).__init__(4, list_of_points) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 3 | ||||
|             + self._params[1] * x ** 2 | ||||
|             + self._params[2] * x | ||||
|             + self._params[3] | ||||
|         ) | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             d=self._params[3], | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuarticFunc(FitFunc): | ||||
|     """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(QuarticFunc, self).__init__(5, list_of_points) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 4 | ||||
|             + self._params[1] * x ** 3 | ||||
|             + self._params[2] * x ** 2 | ||||
|             + self._params[3] * x | ||||
|             + self._params[4] | ||||
|         ) | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return ( | ||||
|             weights[0] * x ** 4 | ||||
|             + weights[1] * x ** 3 | ||||
|             + weights[2] * x ** 2 | ||||
|             + weights[3] * x | ||||
|             + weights[4] | ||||
|         ) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             d=self._params[3], | ||||
|             e=self._params[4], | ||||
|         ) | ||||
| @@ -1,8 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .math_dynamic_funcs import DynamicLinearFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc | ||||
| from .math_adv_funcs import ComposedSinFunc | ||||
| @@ -1,93 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import copy | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
| from .math_base_funcs import FitFunc | ||||
|  | ||||
|  | ||||
| class DynamicFunc(FitFunc): | ||||
|     """The dynamic quadratic function, where each param is a function.""" | ||||
|  | ||||
|     def __init__(self, freedom: int, params=None): | ||||
|         super(DynamicFunc, self).__init__(freedom, None, params) | ||||
|         self._timestamp = None | ||||
|  | ||||
|     def __call__(self, x, timestamp=None): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def set_timestamp(self, timestamp): | ||||
|         self._timestamp = timestamp | ||||
|  | ||||
|     def noise_call(self, x, timestamp=None, std=0.1): | ||||
|         clean_y = self.__call__(x, timestamp) | ||||
|         if isinstance(clean_y, np.ndarray): | ||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||
|         else: | ||||
|             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||
|         return noise_y | ||||
|  | ||||
|  | ||||
| class DynamicLinearFunc(DynamicFunc): | ||||
|     """The dynamic linear function that outputs f(x) = a * x + b. | ||||
|     Both a and b are functions of the timestamp. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params=None): | ||||
|         super(DynamicLinearFunc, self).__init__(2, params)  # two free parameters: a and b | ||||
|  | ||||
|     def __call__(self, x, timestamp=None): | ||||
|         self.check_valid() | ||||
|         if timestamp is None: | ||||
|             timestamp = self._timestamp | ||||
|         a = self._params[0](timestamp) | ||||
|         b = self._params[1](timestamp) | ||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||
|         a, b = convert_fn(a), convert_fn(b) | ||||
|         return a * x + b | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x + {b}, timestamp={timestamp})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             timestamp=self._timestamp, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class DynamicQuadraticFunc(DynamicFunc): | ||||
|     """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. | ||||
|     The a, b, and c are functions of the timestamp. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params=None): | ||||
|         super(DynamicQuadraticFunc, self).__init__(3, params) | ||||
|  | ||||
|     def __call__(self, x, timestamp=None): | ||||
|         self.check_valid() | ||||
|         if timestamp is None: | ||||
|             timestamp = self._timestamp | ||||
|         a = self._params[0](timestamp) | ||||
|         b = self._params[1](timestamp) | ||||
|         c = self._params[2](timestamp) | ||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||
|         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||
|         return a * x * x + b * x + c | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x^2 + {b} * x + {c}, timestamp={timestamp})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             timestamp=self._timestamp, | ||||
|         ) | ||||
| @@ -1,58 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import EnvSampler | ||||
| from .synthetic_env import SyntheticDEnv | ||||
| from .math_core import LinearFunc | ||||
| from .math_core import DynamicLinearFunc | ||||
| from .math_core import DynamicQuadraticFunc | ||||
| from .math_core import ConstantFunc, ComposedSinFunc | ||||
|  | ||||
|  | ||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||
|  | ||||
|  | ||||
| def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"): | ||||
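|     # v1 pairs a dynamic linear oracle with a standard-Gaussian input | ||||
|     # distribution; v2 pairs a dynamic quadratic oracle with inputs whose | ||||
|     # mean and std themselves vary sinusoidally over time. | ||||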
|     if version == "v1": | ||||
|         mean_generator = ConstantFunc(0) | ||||
|         std_generator = ConstantFunc(1) | ||||
|     elif version == "v2": | ||||
|         mean_generator = ComposedSinFunc() | ||||
|         std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) | ||||
|     else: | ||||
|         raise ValueError("Unknown version: {:}".format(version)) | ||||
|     dynamic_env = SyntheticDEnv( | ||||
|         [mean_generator], | ||||
|         [[std_generator]], | ||||
|         num_per_task=num_per_task, | ||||
|         timestamp_config=dict( | ||||
|             min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode | ||||
|         ), | ||||
|     ) | ||||
|     if version == "v1": | ||||
|         function = DynamicLinearFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             amplitude_scale=ConstantFunc(3.0), | ||||
|             num_sin_phase=9, | ||||
|             sin_speed_use_power=False, | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|     elif version == "v2": | ||||
|         function = DynamicQuadraticFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|         function_param[2] = ComposedSinFunc( | ||||
|             num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("Unknown version: {:}".format(version)) | ||||
|  | ||||
|     function.set(function_param) | ||||
|     # dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     dynamic_env.set_oracle_map(function) | ||||
|     return dynamic_env | ||||
| @@ -1,180 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import math | ||||
| import random | ||||
| import numpy as np | ||||
| from typing import List, Optional, Dict | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
| from .synthetic_utils import TimeStamp | ||||
|  | ||||
|  | ||||
| def is_list_tuple(x): | ||||
|     return isinstance(x, (tuple, list)) | ||||
|  | ||||
|  | ||||
| def zip_sequence(sequence): | ||||
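|     # Stack a sequence of (possibly nested) tensors along a new leading | ||||
|     # dimension, preserving the nesting structure of each element. | ||||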
|     def _combine(*alist): | ||||
|         if is_list_tuple(alist[0]): | ||||
|             return [_combine(*xlist) for xlist in zip(*alist)] | ||||
|         else: | ||||
|             return torch.cat(alist, dim=0) | ||||
|  | ||||
|     def unsqueeze(a): | ||||
|         if is_list_tuple(a): | ||||
|             return [unsqueeze(x) for x in a] | ||||
|         else: | ||||
|             return a.unsqueeze(dim=0) | ||||
|  | ||||
|     with torch.no_grad(): | ||||
|         sequence = [unsqueeze(a) for a in sequence] | ||||
|         return _combine(*sequence) | ||||
|  | ||||
|  | ||||
| class SyntheticDEnv(data.Dataset): | ||||
|     """The synethtic dynamic environment.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         mean_functors: List[data.Dataset], | ||||
|         cov_functors: List[List[data.Dataset]], | ||||
|         num_per_task: int = 5000, | ||||
|         timestamp_config: Optional[Dict] = None, | ||||
|         mode: Optional[str] = None, | ||||
|         timestamp_noise_scale: float = 0.3, | ||||
|     ): | ||||
|         self._ndim = len(mean_functors) | ||||
|         assert self._ndim == len( | ||||
|             cov_functors | ||||
|         ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors)) | ||||
|         for cov_functor in cov_functors: | ||||
|             assert self._ndim == len( | ||||
|                 cov_functor | ||||
|             ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) | ||||
|         self._num_per_task = num_per_task | ||||
|         if timestamp_config is None: | ||||
|             timestamp_config = dict(mode=mode) | ||||
|         elif "mode" not in timestamp_config: | ||||
|             timestamp_config["mode"] = mode | ||||
|  | ||||
|         self._timestamp_generator = TimeStamp(**timestamp_config) | ||||
|         self._timestamp_noise_scale = timestamp_noise_scale | ||||
|  | ||||
|         self._mean_functors = mean_functors | ||||
|         self._cov_functors = cov_functors | ||||
|  | ||||
|         self._oracle_map = None | ||||
|         self._seq_length = None | ||||
|  | ||||
|     @property | ||||
|     def min_timestamp(self): | ||||
|         return self._timestamp_generator.min_timestamp | ||||
|  | ||||
|     @property | ||||
|     def max_timestamp(self): | ||||
|         return self._timestamp_generator.max_timestamp | ||||
|  | ||||
|     @property | ||||
|     def timestamp_interval(self): | ||||
|         return self._timestamp_generator.interval | ||||
|  | ||||
|     def random_timestamp(self): | ||||
|         return ( | ||||
|             random.random() * (self.max_timestamp - self.min_timestamp) | ||||
|             + self.min_timestamp | ||||
|         ) | ||||
|  | ||||
|     def reset_max_seq_length(self, seq_length): | ||||
|         self._seq_length = seq_length | ||||
|  | ||||
|     def get_timestamp(self, index): | ||||
|         if index is None: | ||||
|             timestamps = [] | ||||
|             for index in range(len(self._timestamp_generator)): | ||||
|                 timestamps.append(self._timestamp_generator[index][1]) | ||||
|             return tuple(timestamps) | ||||
|         else: | ||||
|             index, timestamp = self._timestamp_generator[index] | ||||
|             return timestamp | ||||
|  | ||||
|     def set_oracle_map(self, functor): | ||||
|         self._oracle_map = functor | ||||
|  | ||||
|     def __iter__(self): | ||||
|         self._iter_num = 0 | ||||
|         return self | ||||
|  | ||||
|     def __next__(self): | ||||
|         if self._iter_num >= len(self): | ||||
|             raise StopIteration | ||||
|         self._iter_num += 1 | ||||
|         return self.__getitem__(self._iter_num - 1) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||
|         index, timestamp = self._timestamp_generator[index] | ||||
|         if self._seq_length is None: | ||||
|             return self.__call__(timestamp) | ||||
|         else: | ||||
|             noise = ( | ||||
|                 random.random() * self.timestamp_interval * self._timestamp_noise_scale | ||||
|             ) | ||||
|             timestamps = [ | ||||
|                 timestamp + i * self.timestamp_interval + noise | ||||
|                 for i in range(self._seq_length) | ||||
|             ] | ||||
|             xdata = [self.__call__(timestamp) for timestamp in timestamps] | ||||
|             return zip_sequence(xdata) | ||||
|  | ||||
|     def __call__(self, timestamp): | ||||
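|         # Sample `num_per_task` points from a Gaussian whose mean and | ||||
|         # covariance are functions of the timestamp; when an oracle map | ||||
|         # is set, also return its noisy targets for the samples. | ||||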
|         mean_list = [functor(timestamp) for functor in self._mean_functors] | ||||
|         cov_matrix = [ | ||||
|             [abs(cov_gen(timestamp)) for cov_gen in cov_functor] | ||||
|             for cov_functor in self._cov_functors | ||||
|         ] | ||||
|  | ||||
|         dataset = np.random.multivariate_normal( | ||||
|             mean_list, cov_matrix, size=self._num_per_task | ||||
|         ) | ||||
|         if self._oracle_map is None: | ||||
|             return torch.Tensor([timestamp]), torch.Tensor(dataset) | ||||
|         else: | ||||
|             targets = self._oracle_map.noise_call(dataset, timestamp) | ||||
|             return torch.Tensor([timestamp]), ( | ||||
|                 torch.Tensor(dataset), | ||||
|                 torch.Tensor(targets), | ||||
|             ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._timestamp_generator) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             cur_num=len(self), | ||||
|             total=len(self._timestamp_generator), | ||||
|             ndim=self._ndim, | ||||
|             num_per_task=self._num_per_task, | ||||
|             xrange_min=self.min_timestamp, | ||||
|             xrange_max=self.max_timestamp, | ||||
|             mode=self._timestamp_generator.mode, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class EnvSampler: | ||||
|     def __init__(self, env, batch, enlarge): | ||||
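|         # Repeat the environment indexes `enlarge` times and yield | ||||
|         # shuffled index batches of size `batch` (usable, e.g., as a | ||||
|         # DataLoader batch_sampler). | ||||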
|         indexes = list(range(len(env))) | ||||
|         self._indexes = indexes * enlarge | ||||
|         self._batch = batch | ||||
|         self._iterations = len(self._indexes) // self._batch | ||||
|  | ||||
|     def __iter__(self): | ||||
|         random.shuffle(self._indexes) | ||||
|         for it in range(self._iterations): | ||||
|             indexes = self._indexes[it * self._batch : (it + 1) * self._batch] | ||||
|             yield indexes | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self._iterations | ||||
| @@ -1,72 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import copy | ||||
|  | ||||
| from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||
| from .synthetic_env import SyntheticDEnv | ||||
|  | ||||
|  | ||||
| def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"): | ||||
|     if indicator == "v1": | ||||
|         return create_example_v1(timestamp_config, num_per_task) | ||||
|     elif indicator == "v2": | ||||
|         return create_example_v2(timestamp_config, num_per_task) | ||||
|     else: | ||||
|         raise ValueError("Unkonwn indicator: {:}".format(indicator)) | ||||
|  | ||||
|  | ||||
| def create_example_v1( | ||||
|     timestamp_config=None, | ||||
|     num_per_task=5000, | ||||
| ): | ||||
|     mean_generator = ComposedSinFunc() | ||||
|     std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) | ||||
|  | ||||
|     dynamic_env = SyntheticDEnv( | ||||
|         [mean_generator], | ||||
|         [[std_generator]], | ||||
|         num_per_task=num_per_task, | ||||
|         timestamp_config=timestamp_config, | ||||
|     ) | ||||
|  | ||||
|     function = DynamicQuadraticFunc() | ||||
|     function_param = dict() | ||||
|     function_param[0] = ComposedSinFunc( | ||||
|         num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|     ) | ||||
|     function_param[1] = ConstantFunc(constant=0.9) | ||||
|     function_param[2] = ComposedSinFunc( | ||||
|         num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|     ) | ||||
|     function.set(function_param) | ||||
|  | ||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     return dynamic_env, function | ||||
|  | ||||
|  | ||||
| def create_example_v2( | ||||
|     timestamp_config=None, | ||||
|     num_per_task=5000, | ||||
| ): | ||||
|     mean_generator = ConstantFunc(0) | ||||
|     std_generator = ConstantFunc(1) | ||||
|  | ||||
|     dynamic_env = SyntheticDEnv( | ||||
|         [mean_generator], | ||||
|         [[std_generator]], | ||||
|         num_per_task=num_per_task, | ||||
|         timestamp_config=timestamp_config, | ||||
|     ) | ||||
|  | ||||
|     function = DynamicLinearFunc() | ||||
|     function_param = dict() | ||||
|     function_param[0] = ComposedSinFunc( | ||||
|         amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) | ||||
|     ) | ||||
|     function_param[1] = ConstantFunc(constant=0.9) | ||||
|     function.set(function_param) | ||||
|  | ||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     return dynamic_env, function | ||||
| @@ -1,93 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class UnifiedSplit: | ||||
|     """A class to unify the split strategy.""" | ||||
|  | ||||
|     def __init__(self, total_num, mode): | ||||
|         # Training Set 60% | ||||
|         num_of_train = int(total_num * 0.6) | ||||
|         # Validation Set 20% | ||||
|         num_of_valid = int(total_num * 0.2) | ||||
|         # Test Set 20% | ||||
|         num_of_test = total_num - num_of_train - num_of_valid | ||||
|         all_indexes = list(range(total_num)) | ||||
|         if mode is None: | ||||
|             self._indexes = all_indexes | ||||
|         elif mode.lower() in ("train", "training"): | ||||
|             self._indexes = all_indexes[:num_of_train] | ||||
|         elif mode.lower() in ("valid", "validation"): | ||||
|             self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid] | ||||
|         elif mode.lower() in ("test", "testing"): | ||||
|             self._indexes = all_indexes[num_of_train + num_of_valid :] | ||||
|         else: | ||||
|             raise ValueError("Unkonwn mode of {:}".format(mode)) | ||||
|         self._all_indexes = all_indexes | ||||
|         self._mode = mode | ||||
|  | ||||
|     @property | ||||
|     def mode(self): | ||||
|         return self._mode | ||||
|  | ||||
|  | ||||
| class TimeStamp(UnifiedSplit, data.Dataset): | ||||
|     """The timestamp dataset.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         min_timestamp: float = 0.0, | ||||
|         max_timestamp: float = 1.0, | ||||
|         num: int = 100, | ||||
|         mode: Optional[str] = None, | ||||
|     ): | ||||
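|         # `num` timestamps evenly spaced over [min_timestamp, | ||||
|         # max_timestamp], inclusive of both endpoints | ||||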
|         self._min_timestamp = min_timestamp | ||||
|         self._max_timestamp = max_timestamp | ||||
|         self._interval = (max_timestamp - min_timestamp) / (float(num) - 1) | ||||
|         self._total_num = num | ||||
|         UnifiedSplit.__init__(self, self._total_num, mode) | ||||
|  | ||||
|     @property | ||||
|     def min_timestamp(self): | ||||
|         return self._min_timestamp + self._interval * min(self._indexes) | ||||
|  | ||||
|     @property | ||||
|     def max_timestamp(self): | ||||
|         return self._min_timestamp + self._interval * max(self._indexes) | ||||
|  | ||||
|     @property | ||||
|     def interval(self): | ||||
|         return self._interval | ||||
|  | ||||
|     def __iter__(self): | ||||
|         self._iter_num = 0 | ||||
|         return self | ||||
|  | ||||
|     def __next__(self): | ||||
|         if self._iter_num >= len(self): | ||||
|             raise StopIteration | ||||
|         self._iter_num += 1 | ||||
|         return self.__getitem__(self._iter_num - 1) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||
|         index = self._indexes[index] | ||||
|         timestamp = self._min_timestamp + self._interval * index | ||||
|         return index, timestamp | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._indexes) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({cur_num:}/{total} elements)".format( | ||||
|             name=self.__class__.__name__, | ||||
|             cur_num=len(self), | ||||
|             total=self._total_num, | ||||
|         ) | ||||
| @@ -1,24 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os | ||||
|  | ||||
|  | ||||
| def test_imagenet_data(imagenet): | ||||
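|     # Sanity-check an ImageFolder-style ImageNet: verify the dataset | ||||
|     # size, a consistent folder-to-class-id mapping, and that every file | ||||
|     # name starts with its synset folder name. | ||||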
|     total_length = len(imagenet) | ||||
|     assert ( | ||||
|         total_length == 1281167 or total_length == 50000 | ||||
|     ), "The length of ImageNet is wrong : {}".format(total_length) | ||||
|     map_id = {} | ||||
|     for index in range(total_length): | ||||
|         path, target = imagenet.imgs[index] | ||||
|         folder, image_name = os.path.split(path) | ||||
|         _, folder = os.path.split(folder) | ||||
|         if folder not in map_id: | ||||
|             map_id[folder] = target | ||||
|         else: | ||||
|             assert map_id[folder] == target, "Class : {} is not {}".format( | ||||
|                 folder, target | ||||
|             ) | ||||
|         assert image_name.find(folder) == 0, "{} is wrong.".format(path) | ||||
|     print("Check ImageNet Dataset OK") | ||||
| @@ -1,16 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # every package does not rely on pytorch or tensorflow | ||||
| # I tried to list all dependency here: os, sys, time, numpy, (possibly) matplotlib | ||||
| ################################################## | ||||
| from .logger import Logger, PrintLogger | ||||
| from .meter import AverageMeter | ||||
| from .time_utils import ( | ||||
|     time_for_file, | ||||
|     time_string, | ||||
|     time_string_short, | ||||
|     time_print, | ||||
|     convert_secs2time, | ||||
| ) | ||||
| from .pickle_wrap import pickle_save, pickle_load | ||||
| @@ -1,173 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from pathlib import Path | ||||
| import importlib.util, warnings  # importlib.util provides find_spec below | ||||
| import os, sys, time, numpy as np | ||||
|  | ||||
| if sys.version_info.major == 2:  # Python 2.x | ||||
|     from StringIO import StringIO as BIO | ||||
| else:  # Python 3.x | ||||
|     from io import BytesIO as BIO | ||||
|  | ||||
| if importlib.util.find_spec("tensorflow"): | ||||
|     import tensorflow as tf | ||||
|  | ||||
|  | ||||
| class PrintLogger(object): | ||||
|     def __init__(self): | ||||
|         """Create a summary writer logging to log_dir.""" | ||||
|         self.name = "PrintLogger" | ||||
|  | ||||
|     def log(self, string): | ||||
|         print(string) | ||||
|  | ||||
|     def close(self): | ||||
|         print("-" * 30 + " close printer " + "-" * 30) | ||||
|  | ||||
|  | ||||
| class Logger(object): | ||||
|     def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False): | ||||
|         """Create a summary writer logging to log_dir.""" | ||||
|         self.seed = int(seed) | ||||
|         self.log_dir = Path(log_dir) | ||||
|         self.model_dir = Path(log_dir) / "checkpoint" | ||||
|         self.log_dir.mkdir(parents=True, exist_ok=True) | ||||
|         if create_model_dir: | ||||
|             self.model_dir.mkdir(parents=True, exist_ok=True) | ||||
|         # self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True) | ||||
|  | ||||
|         self.use_tf = bool(use_tf) | ||||
|         self.tensorboard_dir = self.log_dir / ( | ||||
|             "tensorboard-{:}".format(time.strftime("%d-%h", time.gmtime(time.time()))) | ||||
|         ) | ||||
|         # self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) ))) | ||||
|         self.logger_path = self.log_dir / "seed-{:}-T-{:}.log".format( | ||||
|             self.seed, time.strftime("%d-%h-at-%H-%M-%S", time.gmtime(time.time())) | ||||
|         ) | ||||
|         self.logger_file = open(self.logger_path, "w") | ||||
|  | ||||
|         if self.use_tf: | ||||
|             self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True) | ||||
|             self.writer = tf.summary.FileWriter(str(self.tensorboard_dir)) | ||||
|         else: | ||||
|             self.writer = None | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def path(self, mode): | ||||
|         valids = ("model", "best", "info", "log", None) | ||||
|         if mode is None: | ||||
|             return self.log_dir | ||||
|         elif mode == "model": | ||||
|             return self.model_dir / "seed-{:}-basic.pth".format(self.seed) | ||||
|         elif mode == "best": | ||||
|             return self.model_dir / "seed-{:}-best.pth".format(self.seed) | ||||
|         elif mode == "info": | ||||
|             return self.log_dir / "seed-{:}-last-info.pth".format(self.seed) | ||||
|         elif mode == "log": | ||||
|             return self.log_dir | ||||
|         else: | ||||
|             raise TypeError("Unknow mode = {:}, valid modes = {:}".format(mode, valids)) | ||||
|  | ||||
|     def extract_log(self): | ||||
|         return self.logger_file | ||||
|  | ||||
|     def close(self): | ||||
|         self.logger_file.close() | ||||
|         if self.writer is not None: | ||||
|             self.writer.close() | ||||
|  | ||||
|     def log(self, string, save=True, stdout=False): | ||||
|         if stdout: | ||||
|             sys.stdout.write(string) | ||||
|             sys.stdout.flush() | ||||
|         else: | ||||
|             print(string) | ||||
|         if save: | ||||
|             self.logger_file.write("{:}\n".format(string)) | ||||
|             self.logger_file.flush() | ||||
|  | ||||
|     def scalar_summary(self, tags, values, step): | ||||
|         """Log a scalar variable.""" | ||||
|         if not self.use_tf: | ||||
|             warnings.warn("Do set use-tensorflow installed but call scalar_summary") | ||||
|         else: | ||||
|             assert isinstance(tags, list) == isinstance( | ||||
|                 values, list | ||||
|             ), "Type : {:} vs {:}".format(type(tags), type(values)) | ||||
|             if not isinstance(tags, list): | ||||
|                 tags, values = [tags], [values] | ||||
|             for tag, value in zip(tags, values): | ||||
|                 summary = tf.Summary( | ||||
|                     value=[tf.Summary.Value(tag=tag, simple_value=value)] | ||||
|                 ) | ||||
|                 self.writer.add_summary(summary, step) | ||||
|                 self.writer.flush() | ||||
|  | ||||
|     def image_summary(self, tag, images, step): | ||||
|         """Log a list of images.""" | ||||
|         import scipy.misc  # scipy.misc.toimage requires SciPy < 1.2 | ||||
|  | ||||
|         if not self.use_tf: | ||||
|             warnings.warn("Do set use-tensorflow installed but call scalar_summary") | ||||
|             return | ||||
|  | ||||
|         img_summaries = [] | ||||
|         for i, img in enumerate(images): | ||||
|             # Write the image to a string | ||||
|             s = BIO()  # BIO is BytesIO on Python 3 (StringIO on Python 2), imported above | ||||
|             scipy.misc.toimage(img).save(s, format="png") | ||||
|  | ||||
|             # Create an Image object | ||||
|             img_sum = tf.Summary.Image( | ||||
|                 encoded_image_string=s.getvalue(), | ||||
|                 height=img.shape[0], | ||||
|                 width=img.shape[1], | ||||
|             ) | ||||
|             # Create a Summary value | ||||
|             img_summaries.append( | ||||
|                 tf.Summary.Value(tag="{}/{}".format(tag, i), image=img_sum) | ||||
|             ) | ||||
|  | ||||
|         # Create and write Summary | ||||
|         summary = tf.Summary(value=img_summaries) | ||||
|         self.writer.add_summary(summary, step) | ||||
|         self.writer.flush() | ||||
|  | ||||
|     def histo_summary(self, tag, values, step, bins=1000): | ||||
|         """Log a histogram of the tensor of values.""" | ||||
|         if not self.use_tf: | ||||
|             raise ValueError("Do not have tensorflow") | ||||
|         import tensorflow as tf | ||||
|  | ||||
|         # Create a histogram using numpy | ||||
|         counts, bin_edges = np.histogram(values, bins=bins) | ||||
|  | ||||
|         # Fill the fields of the histogram proto | ||||
|         hist = tf.HistogramProto() | ||||
|         hist.min = float(np.min(values)) | ||||
|         hist.max = float(np.max(values)) | ||||
|         hist.num = int(np.prod(values.shape)) | ||||
|         hist.sum = float(np.sum(values)) | ||||
|         hist.sum_squares = float(np.sum(values ** 2)) | ||||
|  | ||||
|         # Drop the start of the first bin | ||||
|         bin_edges = bin_edges[1:] | ||||
|  | ||||
|         # Add bin edges and counts | ||||
|         for edge in bin_edges: | ||||
|             hist.bucket_limit.append(edge) | ||||
|         for c in counts: | ||||
|             hist.bucket.append(c) | ||||
|  | ||||
|         # Create and write Summary | ||||
|         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) | ||||
|         self.writer.add_summary(summary, step) | ||||
|         self.writer.flush() | ||||
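|  | ||||
|  | ||||
| # Usage sketch (illustrative; not part of the original file). The directory | ||||
| # "./output/demo" and the message are placeholders. | ||||
| #   logger = Logger("./output/demo", seed=0, use_tf=False) | ||||
| #   logger.log("epoch 0: loss=1.23")     # printed and appended to the log file | ||||
| #   best_ckpt = logger.path("best")      # .../checkpoint/seed-0-best.pth | ||||
| #   logger.close() | ||||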
| @@ -1,120 +0,0 @@ | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| class AverageMeter(object): | ||||
|     """Computes and stores the average and current value""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         self.val = 0.0 | ||||
|         self.avg = 0.0 | ||||
|         self.sum = 0.0 | ||||
|         self.count = 0.0 | ||||
|  | ||||
|     def update(self, val, n=1): | ||||
|         self.val = val | ||||
|         self.sum += val * n | ||||
|         self.count += n | ||||
|         self.avg = self.sum / self.count | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(val={val}, avg={avg}, count={count})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
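| # Usage sketch (illustrative; not part of the original file): track a | ||||
| # running mean over mini-batches; the loss values are made up. | ||||
| #   losses = AverageMeter() | ||||
| #   losses.update(0.9, n=32)   # a batch of 32 samples with mean loss 0.9 | ||||
| #   losses.update(0.7, n=32) | ||||
| #   print(losses.avg)          # -> 0.8 | ||||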
|  | ||||
| class RecorderMeter(object): | ||||
|     """Computes and stores the minimum loss value and its epoch index""" | ||||
|  | ||||
|     def __init__(self, total_epoch): | ||||
|         self.reset(total_epoch) | ||||
|  | ||||
|     def reset(self, total_epoch): | ||||
|         assert total_epoch > 0, "total_epoch should be greater than 0, but got {:}".format( | ||||
|             total_epoch | ||||
|         ) | ||||
|         self.total_epoch = total_epoch | ||||
|         self.current_epoch = 0 | ||||
|         self.epoch_losses = np.zeros( | ||||
|             (self.total_epoch, 2), dtype=np.float32 | ||||
|         )  # [epoch, train/val] | ||||
|         self.epoch_losses = self.epoch_losses - 1  # -1 marks epochs not yet recorded | ||||
|         self.epoch_accuracy = np.zeros( | ||||
|             (self.total_epoch, 2), dtype=np.float32 | ||||
|         )  # [epoch, train/val] | ||||
|         # epoch_accuracy keeps its zero initialization until updated | ||||
|  | ||||
|     def update(self, idx, train_loss, train_acc, val_loss, val_acc): | ||||
|         assert ( | ||||
|             idx >= 0 and idx < self.total_epoch | ||||
|         ), "total_epoch : {} , but update with the {} index".format( | ||||
|             self.total_epoch, idx | ||||
|         ) | ||||
|         self.epoch_losses[idx, 0] = train_loss | ||||
|         self.epoch_losses[idx, 1] = val_loss | ||||
|         self.epoch_accuracy[idx, 0] = train_acc | ||||
|         self.epoch_accuracy[idx, 1] = val_acc | ||||
|         self.current_epoch = idx + 1 | ||||
|         return self.max_accuracy(False) == self.epoch_accuracy[idx, 1] | ||||
|  | ||||
|     def max_accuracy(self, istrain): | ||||
|         if self.current_epoch <= 0: | ||||
|             return 0 | ||||
|         if istrain: | ||||
|             return self.epoch_accuracy[: self.current_epoch, 0].max() | ||||
|         else: | ||||
|             return self.epoch_accuracy[: self.current_epoch, 1].max() | ||||
|  | ||||
|     def plot_curve(self, save_path): | ||||
|         import matplotlib | ||||
|  | ||||
|         matplotlib.use("agg") | ||||
|         import matplotlib.pyplot as plt | ||||
|  | ||||
|         title = "the accuracy/loss curve of train/val" | ||||
|         dpi = 100 | ||||
|         width, height = 1600, 1000 | ||||
|         legend_fontsize = 10 | ||||
|         figsize = width / float(dpi), height / float(dpi) | ||||
|  | ||||
|         fig = plt.figure(figsize=figsize) | ||||
|         x_axis = np.arange(self.total_epoch)  # epochs | ||||
|         y_axis = np.zeros(self.total_epoch) | ||||
|  | ||||
|         plt.xlim(0, self.total_epoch) | ||||
|         plt.ylim(0, 100) | ||||
|         interval_y = 5 | ||||
|         interval_x = 5 | ||||
|         plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x)) | ||||
|         plt.yticks(np.arange(0, 100 + interval_y, interval_y)) | ||||
|         plt.grid() | ||||
|         plt.title(title, fontsize=20) | ||||
|         plt.xlabel("the training epoch", fontsize=16) | ||||
|         plt.ylabel("accuracy", fontsize=16) | ||||
|  | ||||
|         y_axis[:] = self.epoch_accuracy[:, 0] | ||||
|         plt.plot(x_axis, y_axis, color="g", linestyle="-", label="train-accuracy", lw=2) | ||||
|         plt.legend(loc=4, fontsize=legend_fontsize) | ||||
|  | ||||
|         y_axis[:] = self.epoch_accuracy[:, 1] | ||||
|         plt.plot(x_axis, y_axis, color="y", linestyle="-", label="valid-accuracy", lw=2) | ||||
|         plt.legend(loc=4, fontsize=legend_fontsize) | ||||
|  | ||||
|         y_axis[:] = self.epoch_losses[:, 0] | ||||
|         plt.plot( | ||||
|             x_axis, y_axis * 50, color="g", linestyle=":", label="train-loss-x50", lw=2 | ||||
|         ) | ||||
|         plt.legend(loc=4, fontsize=legend_fontsize) | ||||
|  | ||||
|         y_axis[:] = self.epoch_losses[:, 1] | ||||
|         plt.plot( | ||||
|             x_axis, y_axis * 50, color="y", linestyle=":", label="valid-loss-x50", lw=2 | ||||
|         ) | ||||
|         plt.legend(loc=4, fontsize=legend_fontsize) | ||||
|  | ||||
|         if save_path is not None: | ||||
|             fig.savefig(save_path, dpi=dpi, bbox_inches="tight") | ||||
|             print("---- save figure {} into {}".format(title, save_path)) | ||||
|         plt.close(fig) | ||||
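|  | ||||
|  | ||||
| # Usage sketch (illustrative; not part of the original file); the numbers | ||||
| # and the save path are placeholders. | ||||
| #   recorder = RecorderMeter(total_epoch=2) | ||||
| #   is_best = recorder.update(0, train_loss=1.0, train_acc=55.0, | ||||
| #                             val_loss=1.2, val_acc=50.0) | ||||
| #   print(recorder.max_accuracy(istrain=False))  # -> 50.0 | ||||
| #   recorder.plot_curve("./curve.png") | ||||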
| @@ -1,21 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import pickle | ||||
| from pathlib import Path | ||||
|  | ||||
|  | ||||
| def pickle_save(obj, path): | ||||
|     file_path = Path(path) | ||||
|     file_dir = file_path.parent | ||||
|     file_dir.mkdir(parents=True, exist_ok=True) | ||||
|     with file_path.open("wb") as f: | ||||
|         pickle.dump(obj, f) | ||||
|  | ||||
|  | ||||
| def pickle_load(path): | ||||
|     if not Path(path).exists(): | ||||
|         raise ValueError("{:} does not exists".format(path)) | ||||
|     with Path(path).open("rb") as f: | ||||
|         data = pickle.load(f) | ||||
|     return data | ||||
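|  | ||||
|  | ||||
| # Usage sketch (illustrative; not part of the original file): a save/load | ||||
| # round trip; the path is a placeholder (parent dirs are created as needed). | ||||
| #   pickle_save({"acc": 93.1}, "./output/stats.pkl") | ||||
| #   data = pickle_load("./output/stats.pkl") | ||||
| #   assert data["acc"] == 93.1 | ||||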
| @@ -1,49 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import time, sys | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def time_for_file(): | ||||
|     ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S" | ||||
|     return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||
|  | ||||
|  | ||||
| def time_string(): | ||||
|     ISOTIMEFORMAT = "%Y-%m-%d %X" | ||||
|     string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||
|     return string | ||||
|  | ||||
|  | ||||
| def time_string_short(): | ||||
|     ISOTIMEFORMAT = "%Y%m%d" | ||||
|     string = "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||
|     return string | ||||
|  | ||||
|  | ||||
| def time_print(string, is_print=True): | ||||
|     if is_print: | ||||
|         print("{} : {}".format(time_string(), string)) | ||||
|  | ||||
|  | ||||
| def convert_secs2time(epoch_time, return_str=False): | ||||
|     need_hour = int(epoch_time / 3600) | ||||
|     need_mins = int((epoch_time - 3600 * need_hour) / 60) | ||||
|     need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins) | ||||
|     if return_str: | ||||
|         str = "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs) | ||||
|         return str | ||||
|     else: | ||||
|         return need_hour, need_mins, need_secs | ||||
|  | ||||
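| # Usage sketch (illustrative; not part of the original file): | ||||
| #   convert_secs2time(3725)                    # -> (1, 2, 5) | ||||
| #   convert_secs2time(3725, return_str=True)   # -> '[01:02:05]' | ||||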
|  | ||||
| def print_log(print_string, log): | ||||
|     # if isinstance(log, Logger): log.log('{:}'.format(print_string)) | ||||
|     if hasattr(log, "log"): | ||||
|         log.log("{:}".format(print_string)) | ||||
|     else: | ||||
|         print("{:}".format(print_string)) | ||||
|         if log is not None: | ||||
|             log.write("{:}\n".format(print_string)) | ||||
|             log.flush() | ||||
| @@ -1,117 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class Bottleneck(nn.Module): | ||||
|     def __init__(self, nChannels, growthRate): | ||||
|         super(Bottleneck, self).__init__() | ||||
|         interChannels = 4 * growthRate | ||||
|         self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|         self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) | ||||
|         self.bn2 = nn.BatchNorm2d(interChannels) | ||||
|         self.conv2 = nn.Conv2d( | ||||
|             interChannels, growthRate, kernel_size=3, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv1(F.relu(self.bn1(x))) | ||||
|         out = self.conv2(F.relu(self.bn2(out))) | ||||
|         out = torch.cat((x, out), 1) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class SingleLayer(nn.Module): | ||||
|     def __init__(self, nChannels, growthRate): | ||||
|         super(SingleLayer, self).__init__() | ||||
|         self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|         self.conv1 = nn.Conv2d( | ||||
|             nChannels, growthRate, kernel_size=3, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv1(F.relu(self.bn1(x))) | ||||
|         out = torch.cat((x, out), 1) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class Transition(nn.Module): | ||||
|     def __init__(self, nChannels, nOutChannels): | ||||
|         super(Transition, self).__init__() | ||||
|         self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|         self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv1(F.relu(self.bn1(x))) | ||||
|         out = F.avg_pool2d(out, 2) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class DenseNet(nn.Module): | ||||
|     def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): | ||||
|         super(DenseNet, self).__init__() | ||||
|  | ||||
|         if bottleneck: | ||||
|             nDenseBlocks = int((depth - 4) / 6) | ||||
|         else: | ||||
|             nDenseBlocks = int((depth - 4) / 3) | ||||
|  | ||||
|         self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format( | ||||
|             "bottleneck" if bottleneck else "basic", | ||||
|             depth, | ||||
|             reduction, | ||||
|             growthRate, | ||||
|             nClasses, | ||||
|         ) | ||||
|  | ||||
|         nChannels = 2 * growthRate | ||||
|         self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) | ||||
|  | ||||
|         self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||
|         nChannels += nDenseBlocks * growthRate | ||||
|         nOutChannels = int(math.floor(nChannels * reduction)) | ||||
|         self.trans1 = Transition(nChannels, nOutChannels) | ||||
|  | ||||
|         nChannels = nOutChannels | ||||
|         self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||
|         nChannels += nDenseBlocks * growthRate | ||||
|         nOutChannels = int(math.floor(nChannels * reduction)) | ||||
|         self.trans2 = Transition(nChannels, nOutChannels) | ||||
|  | ||||
|         nChannels = nOutChannels | ||||
|         self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||
|         nChannels += nDenseBlocks * growthRate | ||||
|  | ||||
|         self.act = nn.Sequential( | ||||
|             nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8) | ||||
|         ) | ||||
|         self.fc = nn.Linear(nChannels, nClasses) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): | ||||
|         layers = [] | ||||
|         for i in range(int(nDenseBlocks)): | ||||
|             if bottleneck: | ||||
|                 layers.append(Bottleneck(nChannels, growthRate)) | ||||
|             else: | ||||
|                 layers.append(SingleLayer(nChannels, growthRate)) | ||||
|             nChannels += growthRate | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         out = self.conv1(inputs) | ||||
|         out = self.trans1(self.dense1(out)) | ||||
|         out = self.trans2(self.dense2(out)) | ||||
|         out = self.dense3(out) | ||||
|         features = self.act(out) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         out = self.fc(features) | ||||
|         return features, out | ||||
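|  | ||||
|  | ||||
| # Usage sketch (illustrative; not part of the original file): a DenseNet-BC | ||||
| # style model for CIFAR-10; depth=100 with bottleneck gives (100-4)/6 = 16 | ||||
| # bottleneck layers per dense stage. | ||||
| #   import torch | ||||
| #   net = DenseNet(growthRate=12, depth=100, reduction=0.5, | ||||
| #                  nClasses=10, bottleneck=True) | ||||
| #   features, logits = net(torch.randn(2, 3, 32, 32)) | ||||
| #   print(logits.shape)  # -> torch.Size([2, 10]) | ||||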
| @@ -1,180 +0,0 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from .initialization import initialize_resnet | ||||
| from .SharedUtils import additive_func | ||||
|  | ||||
|  | ||||
| class Downsample(nn.Module): | ||||
|     def __init__(self, nIn, nOut, stride): | ||||
|         super(Downsample, self).__init__() | ||||
|         assert stride == 2 and nOut == 2 * nIn, "stride:{} IO:{},{}".format( | ||||
|             stride, nIn, nOut | ||||
|         ) | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.avg(x) | ||||
|         out = self.conv(x) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias | ||||
|         ) | ||||
|         self.bn = nn.BatchNorm2d(nOut) | ||||
|         if relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.out_dim = nOut | ||||
|         self.num_conv = 1 | ||||
|  | ||||
|     def forward(self, x): | ||||
|         conv = self.conv(x) | ||||
|         bn = self.bn(conv) | ||||
|         if self.relu: | ||||
|             return self.relu(bn) | ||||
|         else: | ||||
|             return bn | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True) | ||||
|         self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False) | ||||
|         if stride == 2: | ||||
|             self.downsample = Downsample(inplanes, planes, stride) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.num_conv = 2 | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True) | ||||
|         self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, planes * self.expansion, 1, 1, 0, False, False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = Downsample(inplanes, planes * self.expansion, stride) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, planes * self.expansion, 1, 1, 0, False, False | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.num_conv = 3 | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class CifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, num_classes, zero_init_residual): | ||||
|         super(CifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should satisfy (depth-2)%9==0, e.g., 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = "CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}".format( | ||||
|             block_name, depth, layer_blocks | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList([ConvBNReLU(3, 16, 3, 1, 1, False, True)]) | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         assert ( | ||||
|             sum(x.num_conv for x in self.layers) + 1 == depth | ||||
|         ), "invalid depth check {:} vs {:}".format( | ||||
|             sum(x.num_conv for x in self.layers) + 1, depth | ||||
|         ) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
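|  | ||||
|  | ||||
| # Usage sketch (illustrative; not part of the original file): ResNet-20 for | ||||
| # CIFAR-10, i.e., (20 - 2) // 6 = 3 basic blocks per stage. | ||||
| #   import torch | ||||
| #   net = CifarResNet("ResNetBasicblock", depth=20, num_classes=10, | ||||
| #                     zero_init_residual=False) | ||||
| #   features, logits = net(torch.randn(2, 3, 32, 32)) | ||||
| #   print(logits.shape)  # -> torch.Size([2, 10]) | ||||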
| @@ -1,115 +0,0 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class WideBasicblock(nn.Module): | ||||
|     def __init__(self, inplanes, planes, stride, dropout=False): | ||||
|         super(WideBasicblock, self).__init__() | ||||
|  | ||||
|         self.bn_a = nn.BatchNorm2d(inplanes) | ||||
|         self.conv_a = nn.Conv2d( | ||||
|             inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|         self.bn_b = nn.BatchNorm2d(planes) | ||||
|         if dropout: | ||||
|             self.dropout = nn.Dropout2d(p=0.5, inplace=True) | ||||
|         else: | ||||
|             self.dropout = None | ||||
|         self.conv_b = nn.Conv2d( | ||||
|             planes, planes, kernel_size=3, stride=1, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|         if inplanes != planes: | ||||
|             self.downsample = nn.Conv2d( | ||||
|                 inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|  | ||||
|     def forward(self, x): | ||||
|  | ||||
|         basicblock = self.bn_a(x) | ||||
|         basicblock = F.relu(basicblock) | ||||
|         basicblock = self.conv_a(basicblock) | ||||
|  | ||||
|         basicblock = self.bn_b(basicblock) | ||||
|         basicblock = F.relu(basicblock) | ||||
|         if self.dropout is not None: | ||||
|             basicblock = self.dropout(basicblock) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             x = self.downsample(x) | ||||
|  | ||||
|         return x + basicblock | ||||
|  | ||||
|  | ||||
| class CifarWideResNet(nn.Module): | ||||
|     """ | ||||
|     Wide ResNet optimized for the CIFAR datasets, as specified in | ||||
|     https://arxiv.org/abs/1605.07146 | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, depth, widen_factor, num_classes, dropout): | ||||
|         super(CifarWideResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         assert (depth - 4) % 6 == 0, "depth should satisfy (depth-4)%6==0, e.g., 16, 22, 28, 40" | ||||
|         layer_blocks = (depth - 4) // 6 | ||||
|         print( | ||||
|             "CifarPreResNet : Depth : {} , Layers for each block : {}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         self.num_classes = num_classes | ||||
|         self.dropout = dropout | ||||
|         self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) | ||||
|  | ||||
|         self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format( | ||||
|             depth, widen_factor, num_classes | ||||
|         ) | ||||
|         self.inplanes = 16 | ||||
|         self.stage_1 = self._make_layer( | ||||
|             WideBasicblock, 16 * widen_factor, layer_blocks, 1 | ||||
|         ) | ||||
|         self.stage_2 = self._make_layer( | ||||
|             WideBasicblock, 32 * widen_factor, layer_blocks, 2 | ||||
|         ) | ||||
|         self.stage_3 = self._make_layer( | ||||
|             WideBasicblock, 64 * widen_factor, layer_blocks, 2 | ||||
|         ) | ||||
|         self.lastact = nn.Sequential( | ||||
|             nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True) | ||||
|         ) | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(64 * widen_factor, num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def _make_layer(self, block, planes, blocks, stride): | ||||
|  | ||||
|         layers = [] | ||||
|         layers.append(block(self.inplanes, planes, stride, self.dropout)) | ||||
|         self.inplanes = planes | ||||
|         for i in range(1, blocks): | ||||
|             layers.append(block(self.inplanes, planes, 1, self.dropout)) | ||||
|  | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.conv_3x3(x) | ||||
|         x = self.stage_1(x) | ||||
|         x = self.stage_2(x) | ||||
|         x = self.stage_3(x) | ||||
|         x = self.lastact(x) | ||||
|         x = self.avgpool(x) | ||||
|         features = x.view(x.size(0), -1) | ||||
|         outs = self.classifier(features) | ||||
|         return features, outs | ||||
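|  | ||||
|  | ||||
| # Usage sketch (illustrative; not part of the original file): WRN-28-10 for | ||||
| # CIFAR-100, i.e., (28 - 4) // 6 = 4 blocks per stage. | ||||
| #   import torch | ||||
| #   net = CifarWideResNet(depth=28, widen_factor=10, num_classes=100, dropout=False) | ||||
| #   features, outs = net(torch.randn(2, 3, 32, 32)) | ||||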
| @@ -1,117 +0,0 @@ | ||||
| # MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018 | ||||
| from torch import nn | ||||
| from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         padding = (kernel_size - 1) // 2 | ||||
|         self.conv = nn.Conv2d( | ||||
|             in_planes, | ||||
|             out_planes, | ||||
|             kernel_size, | ||||
|             stride, | ||||
|             padding, | ||||
|             groups=groups, | ||||
|             bias=False, | ||||
|         ) | ||||
|         self.bn = nn.BatchNorm2d(out_planes) | ||||
|         self.relu = nn.ReLU6(inplace=True) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv(x) | ||||
|         out = self.bn(out) | ||||
|         out = self.relu(out) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class InvertedResidual(nn.Module): | ||||
|     def __init__(self, inp, oup, stride, expand_ratio): | ||||
|         super(InvertedResidual, self).__init__() | ||||
|         self.stride = stride | ||||
|         assert stride in [1, 2] | ||||
|  | ||||
|         hidden_dim = int(round(inp * expand_ratio)) | ||||
|         self.use_res_connect = self.stride == 1 and inp == oup | ||||
|  | ||||
|         layers = [] | ||||
|         if expand_ratio != 1: | ||||
|             # pw | ||||
|             layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) | ||||
|         layers.extend( | ||||
|             [ | ||||
|                 # dw | ||||
|                 ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), | ||||
|                 # pw-linear | ||||
|                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), | ||||
|                 nn.BatchNorm2d(oup), | ||||
|             ] | ||||
|         ) | ||||
|         self.conv = nn.Sequential(*layers) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.use_res_connect: | ||||
|             return x + self.conv(x) | ||||
|         else: | ||||
|             return self.conv(x) | ||||
|  | ||||
|  | ||||
| class MobileNetV2(nn.Module): | ||||
|     def __init__( | ||||
|         self, num_classes, width_mult, input_channel, last_channel, block_name, dropout | ||||
|     ): | ||||
|         super(MobileNetV2, self).__init__() | ||||
|         if block_name == "InvertedResidual": | ||||
|             block = InvertedResidual | ||||
|         else: | ||||
|             raise ValueError("invalid block name : {:}".format(block_name)) | ||||
|         inverted_residual_setting = [ | ||||
|             # t, c,  n, s | ||||
|             [1, 16, 1, 1], | ||||
|             [6, 24, 2, 2], | ||||
|             [6, 32, 3, 2], | ||||
|             [6, 64, 4, 2], | ||||
|             [6, 96, 3, 1], | ||||
|             [6, 160, 3, 2], | ||||
|             [6, 320, 1, 1], | ||||
|         ] | ||||
|  | ||||
|         # building first layer | ||||
|         input_channel = int(input_channel * width_mult) | ||||
|         self.last_channel = int(last_channel * max(1.0, width_mult)) | ||||
|         features = [ConvBNReLU(3, input_channel, stride=2)] | ||||
|         # building inverted residual blocks | ||||
|         for t, c, n, s in inverted_residual_setting: | ||||
|             output_channel = int(c * width_mult) | ||||
|             for i in range(n): | ||||
|                 stride = s if i == 0 else 1 | ||||
|                 features.append( | ||||
|                     block(input_channel, output_channel, stride, expand_ratio=t) | ||||
|                 ) | ||||
|                 input_channel = output_channel | ||||
|         # building last several layers | ||||
|         features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) | ||||
|         # make it nn.Sequential | ||||
|         self.features = nn.Sequential(*features) | ||||
|  | ||||
|         # building classifier | ||||
|         self.classifier = nn.Sequential( | ||||
|             nn.Dropout(dropout), | ||||
|             nn.Linear(self.last_channel, num_classes), | ||||
|         ) | ||||
|         self.message = "MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}".format( | ||||
|             width_mult, input_channel, last_channel, block_name, dropout | ||||
|         ) | ||||
|  | ||||
|         # weight initialization | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         features = self.features(inputs) | ||||
|         vectors = features.mean([2, 3]) | ||||
|         predicts = self.classifier(vectors) | ||||
|         return features, predicts | ||||
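|  | ||||
|  | ||||
| # Usage sketch (illustrative; not part of the original file): the standard | ||||
| # 1.0x configuration; input_channel=32 and last_channel=1280 follow the paper. | ||||
| #   import torch | ||||
| #   net = MobileNetV2(num_classes=1000, width_mult=1.0, input_channel=32, | ||||
| #                     last_channel=1280, block_name="InvertedResidual", dropout=0.2) | ||||
| #   features, predicts = net(torch.randn(2, 3, 224, 224)) | ||||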
| @@ -1,217 +0,0 @@ | ||||
| # Deep Residual Learning for Image Recognition, CVPR 2016 | ||||
| import torch.nn as nn | ||||
| from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| def conv3x3(in_planes, out_planes, stride=1, groups=1): | ||||
|     return nn.Conv2d( | ||||
|         in_planes, | ||||
|         out_planes, | ||||
|         kernel_size=3, | ||||
|         stride=stride, | ||||
|         padding=1, | ||||
|         groups=groups, | ||||
|         bias=False, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def conv1x1(in_planes, out_planes, stride=1): | ||||
|     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||
|  | ||||
|  | ||||
| class BasicBlock(nn.Module): | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||
|     ): | ||||
|         super(BasicBlock, self).__init__() | ||||
|         if groups != 1 or base_width != 64: | ||||
|             raise ValueError("BasicBlock only supports groups=1 and base_width=64") | ||||
|         # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||||
|         self.conv1 = conv3x3(inplanes, planes, stride) | ||||
|         self.bn1 = nn.BatchNorm2d(planes) | ||||
|         self.relu = nn.ReLU(inplace=True) | ||||
|         self.conv2 = conv3x3(planes, planes) | ||||
|         self.bn2 = nn.BatchNorm2d(planes) | ||||
|         self.downsample = downsample | ||||
|         self.stride = stride | ||||
|  | ||||
|     def forward(self, x): | ||||
|         identity = x | ||||
|  | ||||
|         out = self.conv1(x) | ||||
|         out = self.bn1(out) | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         out = self.conv2(out) | ||||
|         out = self.bn2(out) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             identity = self.downsample(x) | ||||
|  | ||||
|         out += identity | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class Bottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|  | ||||
|     def __init__( | ||||
|         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||
|     ): | ||||
|         super(Bottleneck, self).__init__() | ||||
|         width = int(planes * (base_width / 64.0)) * groups | ||||
|         # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||||
|         self.conv1 = conv1x1(inplanes, width) | ||||
|         self.bn1 = nn.BatchNorm2d(width) | ||||
|         self.conv2 = conv3x3(width, width, stride, groups) | ||||
|         self.bn2 = nn.BatchNorm2d(width) | ||||
|         self.conv3 = conv1x1(width, planes * self.expansion) | ||||
|         self.bn3 = nn.BatchNorm2d(planes * self.expansion) | ||||
|         self.relu = nn.ReLU(inplace=True) | ||||
|         self.downsample = downsample | ||||
|         self.stride = stride | ||||
|  | ||||
|     def forward(self, x): | ||||
|         identity = x | ||||
|  | ||||
|         out = self.conv1(x) | ||||
|         out = self.bn1(out) | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         out = self.conv2(out) | ||||
|         out = self.bn2(out) | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         out = self.conv3(out) | ||||
|         out = self.bn3(out) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             identity = self.downsample(x) | ||||
|  | ||||
|         out += identity | ||||
|         out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNet(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         block_name, | ||||
|         layers, | ||||
|         deep_stem, | ||||
|         num_classes, | ||||
|         zero_init_residual, | ||||
|         groups, | ||||
|         width_per_group, | ||||
|     ): | ||||
|         super(ResNet, self).__init__() | ||||
|  | ||||
|         # planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] | ||||
|         if block_name == "BasicBlock": | ||||
|             block = BasicBlock | ||||
|         elif block_name == "Bottleneck": | ||||
|             block = Bottleneck | ||||
|         else: | ||||
|             raise ValueError("invalid block-name : {:}".format(block_name)) | ||||
|  | ||||
|         if not deep_stem: | ||||
|             self.conv = nn.Sequential( | ||||
|                 nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), | ||||
|                 nn.BatchNorm2d(64), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ) | ||||
|         else: | ||||
|             self.conv = nn.Sequential( | ||||
|                 nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|                 nn.BatchNorm2d(32), | ||||
|                 nn.ReLU(inplace=True), | ||||
|                 nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                 nn.BatchNorm2d(32), | ||||
|                 nn.ReLU(inplace=True), | ||||
|                 nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                 nn.BatchNorm2d(64), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ) | ||||
|         self.inplanes = 64 | ||||
|         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||
|         self.layer1 = self._make_layer( | ||||
|             block, 64, layers[0], stride=1, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer2 = self._make_layer( | ||||
|             block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer3 = self._make_layer( | ||||
|             block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer4 = self._make_layer( | ||||
|             block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.fc = nn.Linear(512 * block.expansion, num_classes) | ||||
|         self.message = ( | ||||
|             "block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}".format( | ||||
|                 block, layers, deep_stem, num_classes | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|         # Zero-initialize the last BN in each residual branch, | ||||
|         # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||||
|         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, Bottleneck): | ||||
|                     nn.init.constant_(m.bn3.weight, 0) | ||||
|                 elif isinstance(m, BasicBlock): | ||||
|                     nn.init.constant_(m.bn2.weight, 0) | ||||
|  | ||||
|     def _make_layer(self, block, planes, blocks, stride, groups, base_width): | ||||
|         downsample = None | ||||
|         if stride != 1 or self.inplanes != planes * block.expansion: | ||||
|             if stride == 2: | ||||
|                 downsample = nn.Sequential( | ||||
|                     nn.AvgPool2d(kernel_size=2, stride=2, padding=0), | ||||
|                     conv1x1(self.inplanes, planes * block.expansion, 1), | ||||
|                     nn.BatchNorm2d(planes * block.expansion), | ||||
|                 ) | ||||
|             elif stride == 1: | ||||
|                 downsample = nn.Sequential( | ||||
|                     conv1x1(self.inplanes, planes * block.expansion, stride), | ||||
|                     nn.BatchNorm2d(planes * block.expansion), | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid stride [{:}] for downsample".format(stride)) | ||||
|  | ||||
|         layers = [] | ||||
|         layers.append( | ||||
|             block(self.inplanes, planes, stride, downsample, groups, base_width) | ||||
|         ) | ||||
|         self.inplanes = planes * block.expansion | ||||
|         for _ in range(1, blocks): | ||||
|             layers.append(block(self.inplanes, planes, 1, None, groups, base_width)) | ||||
|  | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.conv(x) | ||||
|         x = self.maxpool(x) | ||||
|  | ||||
|         x = self.layer1(x) | ||||
|         x = self.layer2(x) | ||||
|         x = self.layer3(x) | ||||
|         x = self.layer4(x) | ||||
|  | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.fc(features) | ||||
|  | ||||
|         return features, logits | ||||
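|  | ||||
|  | ||||
| # Usage sketch (illustrative; not part of the original file): ResNet-50, | ||||
| # i.e., Bottleneck blocks with layers=[3, 4, 6, 3]. | ||||
| #   import torch | ||||
| #   net = ResNet("Bottleneck", [3, 4, 6, 3], deep_stem=False, num_classes=1000, | ||||
| #                zero_init_residual=True, groups=1, width_per_group=64) | ||||
| #   features, logits = net(torch.randn(2, 3, 224, 224)) | ||||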
| @@ -1,37 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| def additive_func(A, B): | ||||
|     assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format( | ||||
|         A.size(), B.size() | ||||
|     ) | ||||
|     C = min(A.size(1), B.size(1)) | ||||
|     if A.size(1) == B.size(1): | ||||
|         return A + B | ||||
|     elif A.size(1) < B.size(1): | ||||
|         out = B.clone() | ||||
|         out[:, :C] += A | ||||
|         return out | ||||
|     else: | ||||
|         out = A.clone() | ||||
|         out[:, :C] += B | ||||
|         return out | ||||
|  | ||||
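| # Usage sketch (illustrative; not part of the original file): when channel | ||||
| # counts differ, the narrower tensor is added into the first C channels. | ||||
| #   import torch | ||||
| #   A, B = torch.ones(2, 16, 8, 8), torch.ones(2, 32, 8, 8) | ||||
| #   out = additive_func(A, B)  # (2, 32, 8, 8); channels [:16] equal 2, rest 1 | ||||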
|  | ||||
| def change_key(key, value): | ||||
|     def func(m): | ||||
|         if hasattr(m, key): | ||||
|             setattr(m, key, value) | ||||
|  | ||||
|     return func | ||||
|  | ||||
|  | ||||
| def parse_channel_info(xstring): | ||||
|     blocks = xstring.split(" ") | ||||
|     blocks = [x.split("-") for x in blocks] | ||||
|     blocks = [[int(_) for _ in x] for x in blocks] | ||||
|     return blocks | ||||
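|  | ||||
|  | ||||
| # Usage sketch (illustrative; not part of the original file): | ||||
| #   parse_channel_info("16-16 32-32-64")  # -> [[16, 16], [32, 32, 64]] | ||||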
| @@ -1,326 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from os import path as osp | ||||
| from typing import List, Text | ||||
| import torch | ||||
|  | ||||
| __all__ = [ | ||||
|     "change_key", | ||||
|     "get_cell_based_tiny_net", | ||||
|     "get_search_spaces", | ||||
|     "get_cifar_models", | ||||
|     "get_imagenet_models", | ||||
|     "obtain_model", | ||||
|     "obtain_search_model", | ||||
|     "load_net_from_checkpoint", | ||||
|     "CellStructure", | ||||
|     "CellArchitectures", | ||||
| ] | ||||
|  | ||||
| # useful modules | ||||
| from config_utils import dict2config | ||||
| from models.SharedUtils import change_key | ||||
| from models.cell_searchs import CellStructure, CellArchitectures | ||||
|  | ||||
|  | ||||
| # Cell-based NAS Models | ||||
| def get_cell_based_tiny_net(config): | ||||
|     if isinstance(config, dict): | ||||
|         config = dict2config(config, None)  # to support the argument being a dict | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     group_names = ["DARTS-V1", "DARTS-V2", "GDAS", "SETN", "ENAS", "RANDOM", "generic"] | ||||
|     if super_type == "basic" and config.name in group_names: | ||||
|         from .cell_searchs import nas201_super_nets as nas_super_nets | ||||
|  | ||||
|         try: | ||||
|             return nas_super_nets[config.name]( | ||||
|                 config.C, | ||||
|                 config.N, | ||||
|                 config.max_nodes, | ||||
|                 config.num_classes, | ||||
|                 config.space, | ||||
|                 config.affine, | ||||
|                 config.track_running_stats, | ||||
|             ) | ||||
|         except TypeError:  # fall back to the older 5-argument constructor | ||||
|             return nas_super_nets[config.name]( | ||||
|                 config.C, config.N, config.max_nodes, config.num_classes, config.space | ||||
|             ) | ||||
|     elif super_type == "search-shape": | ||||
|         from .shape_searchs import GenericNAS301Model | ||||
|  | ||||
|         genotype = CellStructure.str2structure(config.genotype) | ||||
|         return GenericNAS301Model( | ||||
|             config.candidate_Cs, | ||||
|             config.max_num_Cs, | ||||
|             genotype, | ||||
|             config.num_classes, | ||||
|             config.affine, | ||||
|             config.track_running_stats, | ||||
|         ) | ||||
|     elif super_type == "nasnet-super": | ||||
|         from .cell_searchs import nasnet_super_nets as nas_super_nets | ||||
|  | ||||
|         return nas_super_nets[config.name]( | ||||
|             config.C, | ||||
|             config.N, | ||||
|             config.steps, | ||||
|             config.multiplier, | ||||
|             config.stem_multiplier, | ||||
|             config.num_classes, | ||||
|             config.space, | ||||
|             config.affine, | ||||
|             config.track_running_stats, | ||||
|         ) | ||||
|     elif config.name == "infer.tiny": | ||||
|         from .cell_infers import TinyNetwork | ||||
|  | ||||
|         if hasattr(config, "genotype"): | ||||
|             genotype = config.genotype | ||||
|         elif hasattr(config, "arch_str"): | ||||
|             genotype = CellStructure.str2structure(config.arch_str) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "Can not find genotype from this config : {:}".format(config) | ||||
|             ) | ||||
|         return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||
|     elif config.name == "infer.shape.tiny": | ||||
|         from .shape_infers import DynamicShapeTinyNet | ||||
|  | ||||
|         if isinstance(config.channels, str): | ||||
|             channels = tuple([int(x) for x in config.channels.split(":")]) | ||||
|         else: | ||||
|             channels = config.channels | ||||
|         genotype = CellStructure.str2structure(config.genotype) | ||||
|         return DynamicShapeTinyNet(channels, genotype, config.num_classes) | ||||
|     elif config.name == "infer.nasnet-cifar": | ||||
|         from .cell_infers import NASNetonCIFAR | ||||
|  | ||||
|         raise NotImplementedError | ||||
|     else: | ||||
|         raise ValueError("invalid network name : {:}".format(config.name)) | ||||
|  | ||||
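| # Usage sketch (illustrative; not part of the original file): build a | ||||
| # NAS-Bench-201-style supernet from a plain dict; all field values below | ||||
| # are placeholders. | ||||
| #   config = {"name": "GDAS", "super_type": "basic", "C": 16, "N": 5, | ||||
| #             "max_nodes": 4, "num_classes": 10, "space": ["none", "skip_connect"], | ||||
| #             "affine": False, "track_running_stats": True} | ||||
| #   net = get_cell_based_tiny_net(config) | ||||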
|  | ||||
| # obtain the search space: a list of operation names (topology space) or a dict of width candidates (size space) | ||||
| def get_search_spaces(xtype, name) -> List[Text]: | ||||
|     if xtype == "cell" or xtype == "tss":  # The topology search space. | ||||
|         from .cell_operations import SearchSpaceNames | ||||
|  | ||||
|         assert name in SearchSpaceNames, "invalid name [{:}] in {:}".format( | ||||
|             name, SearchSpaceNames.keys() | ||||
|         ) | ||||
|         return SearchSpaceNames[name] | ||||
|     elif xtype == "sss":  # The size search space. | ||||
|         if name in ["nats-bench", "nats-bench-size"]: | ||||
|             return {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5} | ||||
|         else: | ||||
|             raise ValueError("Invalid name : {:}".format(name)) | ||||
|     else: | ||||
|         raise ValueError("invalid search-space type is {:}".format(xtype)) | ||||
|  | ||||
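| # Usage sketch (illustrative; not part of the original file); the topology | ||||
| # space name below is a placeholder key of SearchSpaceNames. | ||||
| #   ops = get_search_spaces("tss", "some-space-name")  # list of operation names | ||||
| #   sss = get_search_spaces("sss", "nats-bench")       # dict of width candidates | ||||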
|  | ||||
| def get_cifar_models(config, extra_path=None): | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     if super_type == "basic": | ||||
|         from .CifarResNet import CifarResNet | ||||
|         from .CifarDenseNet import DenseNet | ||||
|         from .CifarWideResNet import CifarWideResNet | ||||
|  | ||||
|         if config.arch == "resnet": | ||||
|             return CifarResNet( | ||||
|                 config.module, config.depth, config.class_num, config.zero_init_residual | ||||
|             ) | ||||
|         elif config.arch == "densenet": | ||||
|             return DenseNet( | ||||
|                 config.growthRate, | ||||
|                 config.depth, | ||||
|                 config.reduction, | ||||
|                 config.class_num, | ||||
|                 config.bottleneck, | ||||
|             ) | ||||
|         elif config.arch == "wideresnet": | ||||
|             return CifarWideResNet( | ||||
|                 config.depth, config.wide_factor, config.class_num, config.dropout | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid module type : {:}".format(config.arch)) | ||||
|     elif super_type.startswith("infer"): | ||||
|         from .shape_infers import InferWidthCifarResNet | ||||
|         from .shape_infers import InferDepthCifarResNet | ||||
|         from .shape_infers import InferCifarResNet | ||||
|         from .cell_infers import NASNetonCIFAR | ||||
|  | ||||
|         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||
|             super_type | ||||
|         ) | ||||
|         infer_mode = super_type.split("-")[1] | ||||
|         if infer_mode == "width": | ||||
|             return InferWidthCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xchannels, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "depth": | ||||
|             return InferDepthCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xblocks, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "shape": | ||||
|             return InferCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xblocks, | ||||
|                 config.xchannels, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "nasnet.cifar": | ||||
|             genotype = config.genotype | ||||
|             if extra_path is not None:  # reload genotype by extra_path | ||||
|                 if not osp.isfile(extra_path): | ||||
|                     raise ValueError("invalid extra_path : {:}".format(extra_path)) | ||||
|                 xdata = torch.load(extra_path) | ||||
|                 current_epoch = xdata["epoch"] | ||||
|                 genotype = xdata["genotypes"][current_epoch - 1] | ||||
|             C = config.C if hasattr(config, "C") else config.ichannel | ||||
|             N = config.N if hasattr(config, "N") else config.layers | ||||
|             return NASNetonCIFAR( | ||||
|                 C, N, config.stem_multi, config.class_num, genotype, config.auxiliary | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||
|     else: | ||||
|         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||
|  | ||||
|  | ||||
| def get_imagenet_models(config): | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     if super_type == "basic": | ||||
|         from .ImageNet_ResNet import ResNet | ||||
|         from .ImageNet_MobileNetV2 import MobileNetV2 | ||||
|  | ||||
|         if config.arch == "resnet": | ||||
|             return ResNet( | ||||
|                 config.block_name, | ||||
|                 config.layers, | ||||
|                 config.deep_stem, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|                 config.groups, | ||||
|                 config.width_per_group, | ||||
|             ) | ||||
|         elif config.arch == "mobilenet_v2": | ||||
|             return MobileNetV2( | ||||
|                 config.class_num, | ||||
|                 config.width_multi, | ||||
|                 config.input_channel, | ||||
|                 config.last_channel, | ||||
|                 "InvertedResidual", | ||||
|                 config.dropout, | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid arch : {:}".format(config.arch)) | ||||
|     elif super_type.startswith("infer"):  # NAS searched architecture | ||||
|         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||
|             super_type | ||||
|         ) | ||||
|         infer_mode = super_type.split("-")[1] | ||||
|         if infer_mode == "shape": | ||||
|             from .shape_infers import InferImagenetResNet | ||||
|             from .shape_infers import InferMobileNetV2 | ||||
|  | ||||
|             if config.arch == "resnet": | ||||
|                 return InferImagenetResNet( | ||||
|                     config.block_name, | ||||
|                     config.layers, | ||||
|                     config.xblocks, | ||||
|                     config.xchannels, | ||||
|                     config.deep_stem, | ||||
|                     config.class_num, | ||||
|                     config.zero_init_residual, | ||||
|                 ) | ||||
|             elif config.arch == "MobileNetV2": | ||||
|                 return InferMobileNetV2( | ||||
|                     config.class_num, config.xchannels, config.xblocks, config.dropout | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid arch-mode : {:}".format(config.arch)) | ||||
|         else: | ||||
|             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||
|     else: | ||||
|         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||
|  | ||||
|  | ||||
| # Obtain the network specified by the model config. | ||||
| def obtain_model(config, extra_path=None): | ||||
|     if config.dataset == "cifar": | ||||
|         return get_cifar_models(config, extra_path) | ||||
|     elif config.dataset == "imagenet": | ||||
|         return get_imagenet_models(config) | ||||
|     else: | ||||
|         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||
|  | ||||
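| # Usage sketch (illustrative, not from the original file): "config" is a | ||||
| # dict2config-style object; the attribute values below are hypothetical | ||||
| # and only mirror the accesses made in get_imagenet_models above. | ||||
| #   cfg = dict2config({"dataset": "imagenet", "super_type": "basic", | ||||
| #                      "arch": "resnet", "block_name": "BasicBlock", | ||||
| #                      "layers": [2, 2, 2, 2], "deep_stem": False, | ||||
| #                      "class_num": 1000, "zero_init_residual": False, | ||||
| #                      "groups": 1, "width_per_group": 64}, None) | ||||
| #   net = obtain_model(cfg) | ||||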
|  | ||||
| def obtain_search_model(config): | ||||
|     if config.dataset == "cifar": | ||||
|         if config.arch == "resnet": | ||||
|             from .shape_searchs import SearchWidthCifarResNet | ||||
|             from .shape_searchs import SearchDepthCifarResNet | ||||
|             from .shape_searchs import SearchShapeCifarResNet | ||||
|  | ||||
|             if config.search_mode == "width": | ||||
|                 return SearchWidthCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             elif config.search_mode == "depth": | ||||
|                 return SearchDepthCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             elif config.search_mode == "shape": | ||||
|                 return SearchShapeCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||
|         elif config.arch == "simres": | ||||
|             from .shape_searchs import SearchWidthSimResNet | ||||
|  | ||||
|             if config.search_mode == "width": | ||||
|                 return SearchWidthSimResNet(config.depth, config.class_num) | ||||
|             else: | ||||
|                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "invalid arch : {:} for dataset [{:}]".format( | ||||
|                     config.arch, config.dataset | ||||
|                 ) | ||||
|             ) | ||||
|     elif config.dataset == "imagenet": | ||||
|         from .shape_searchs import SearchShapeImagenetResNet | ||||
|  | ||||
|         assert config.search_mode == "shape", "invalid search-mode : {:}".format( | ||||
|             config.search_mode | ||||
|         ) | ||||
|         if config.arch == "resnet": | ||||
|             return SearchShapeImagenetResNet( | ||||
|                 config.block_name, config.layers, config.deep_stem, config.class_num | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid model config : {:}".format(config)) | ||||
|     else: | ||||
|         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||
|  | ||||
|  | ||||
| def load_net_from_checkpoint(checkpoint): | ||||
|     assert osp.isfile(checkpoint), "checkpoint {:} does not exist".format(checkpoint) | ||||
|     checkpoint = torch.load(checkpoint) | ||||
|     model_config = dict2config(checkpoint["model-config"], None) | ||||
|     model = obtain_model(model_config) | ||||
|     model.load_state_dict(checkpoint["base-model"]) | ||||
|     return model | ||||
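|  | ||||
| # Usage sketch (illustrative; the checkpoint path is hypothetical): | ||||
| #   net = load_net_from_checkpoint("./output/checkpoints/seed-777-best.pth") | ||||
| #   net.eval()  # weights come from the checkpoint's "base-model" entry | ||||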
| @@ -1,5 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| from .tiny_network import TinyNetwork | ||||
| from .nasnet_cifar import NASNetonCIFAR | ||||
| @@ -1,155 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
|  | ||||
| from models.cell_operations import OPS | ||||
|  | ||||
|  | ||||
| # Cell for NAS-Bench-201 | ||||
| class InferCell(nn.Module): | ||||
|     def __init__( | ||||
|         self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True | ||||
|     ): | ||||
|         super(InferCell, self).__init__() | ||||
|  | ||||
|         self.layers = nn.ModuleList() | ||||
|         self.node_IN = [] | ||||
|         self.node_IX = [] | ||||
|         self.genotype = deepcopy(genotype) | ||||
|         for i in range(1, len(genotype)): | ||||
|             node_info = genotype[i - 1] | ||||
|             cur_index = [] | ||||
|             cur_innod = [] | ||||
|             for (op_name, op_in) in node_info: | ||||
|                 if op_in == 0: | ||||
|                     layer = OPS[op_name]( | ||||
|                         C_in, C_out, stride, affine, track_running_stats | ||||
|                     ) | ||||
|                 else: | ||||
|                     layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats) | ||||
|                 cur_index.append(len(self.layers)) | ||||
|                 cur_innod.append(op_in) | ||||
|                 self.layers.append(layer) | ||||
|             self.node_IX.append(cur_index) | ||||
|             self.node_IN.append(cur_innod) | ||||
|         self.nodes = len(genotype) | ||||
|         self.in_dim = C_in | ||||
|         self.out_dim = C_out | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         string = "info :: nodes={nodes}, inC={in_dim}, outC={out_dim}".format( | ||||
|             **self.__dict__ | ||||
|         ) | ||||
|         laystr = [] | ||||
|         for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)): | ||||
|             y = [ | ||||
|                 "I{:}-L{:}".format(_ii, _il) | ||||
|                 for _il, _ii in zip(node_layers, node_innods) | ||||
|             ] | ||||
|             x = "{:}<-({:})".format(i + 1, ",".join(y)) | ||||
|             laystr.append(x) | ||||
|         return ( | ||||
|             string | ||||
|             + ", [{:}]".format(" | ".join(laystr)) | ||||
|             + ", {:}".format(self.genotype.tostr()) | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         nodes = [inputs] | ||||
|         for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)): | ||||
|             node_feature = sum( | ||||
|                 self.layers[_il](nodes[_ii]) | ||||
|                 for _il, _ii in zip(node_layers, node_innods) | ||||
|             ) | ||||
|             nodes.append(node_feature) | ||||
|         return nodes[-1] | ||||
|  | ||||
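| # Usage sketch (illustrative): the genotype is indexable, and entry i-1 | ||||
| # lists the (op_name, input_node) pairs that feed node i, for example: | ||||
| #   genotype[0] == (("nor_conv_3x3", 0),) | ||||
| #   genotype[1] == (("skip_connect", 0), ("nor_conv_1x1", 1)) | ||||
| #   cell = InferCell(genotype, C_in=16, C_out=16, stride=1)  # sizes hypothetical | ||||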
|  | ||||
| # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 | ||||
| class NASNetInferCell(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         genotype, | ||||
|         C_prev_prev, | ||||
|         C_prev, | ||||
|         C, | ||||
|         reduction, | ||||
|         reduction_prev, | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ): | ||||
|         super(NASNetInferCell, self).__init__() | ||||
|         self.reduction = reduction | ||||
|         if reduction_prev: | ||||
|             self.preprocess0 = OPS["skip_connect"]( | ||||
|                 C_prev_prev, C, 2, affine, track_running_stats | ||||
|             ) | ||||
|         else: | ||||
|             self.preprocess0 = OPS["nor_conv_1x1"]( | ||||
|                 C_prev_prev, C, 1, affine, track_running_stats | ||||
|             ) | ||||
|         self.preprocess1 = OPS["nor_conv_1x1"]( | ||||
|             C_prev, C, 1, affine, track_running_stats | ||||
|         ) | ||||
|  | ||||
|         if not reduction: | ||||
|             nodes, concats = genotype["normal"], genotype["normal_concat"] | ||||
|         else: | ||||
|             nodes, concats = genotype["reduce"], genotype["reduce_concat"] | ||||
|         self._multiplier = len(concats) | ||||
|         self._concats = concats | ||||
|         self._steps = len(nodes) | ||||
|         self._nodes = nodes | ||||
|         self.edges = nn.ModuleDict() | ||||
|         for i, node in enumerate(nodes): | ||||
|             for in_node in node: | ||||
|                 name, j = in_node[0], in_node[1] | ||||
|                 stride = 2 if reduction and j < 2 else 1 | ||||
|                 node_str = "{:}<-{:}".format(i + 2, j) | ||||
|                 self.edges[node_str] = OPS[name]( | ||||
|                     C, C, stride, affine, track_running_stats | ||||
|                 ) | ||||
|  | ||||
|     # [TODO] support drop_prob in this function. | ||||
|     def forward(self, s0, s1, unused_drop_prob): | ||||
|         s0 = self.preprocess0(s0) | ||||
|         s1 = self.preprocess1(s1) | ||||
|  | ||||
|         states = [s0, s1] | ||||
|         for i, node in enumerate(self._nodes): | ||||
|             clist = [] | ||||
|             for in_node in node: | ||||
|                 name, j = in_node[0], in_node[1] | ||||
|                 node_str = "{:}<-{:}".format(i + 2, j) | ||||
|                 op = self.edges[node_str] | ||||
|                 clist.append(op(states[j])) | ||||
|             states.append(sum(clist)) | ||||
|         return torch.cat([states[x] for x in self._concats], dim=1) | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadCIFAR(nn.Module): | ||||
|     def __init__(self, C, num_classes): | ||||
|         """assuming input size 8x8""" | ||||
|         super(AuxiliaryHeadCIFAR, self).__init__() | ||||
|         self.features = nn.Sequential( | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.AvgPool2d( | ||||
|                 5, stride=3, padding=0, count_include_pad=False | ||||
|             ),  # image size = 2 x 2 | ||||
|             nn.Conv2d(C, 128, 1, bias=False), | ||||
|             nn.BatchNorm2d(128), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(128, 768, 2, bias=False), | ||||
|             nn.BatchNorm2d(768), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.features(x) | ||||
|         x = self.classifier(x.view(x.size(0), -1)) | ||||
|         return x | ||||
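|  | ||||
| # Shape walk-through (assuming the documented 8x8 input): | ||||
| #   8x8 --AvgPool2d(5, stride=3)--> 2x2 --Conv2d 1x1--> 128@2x2 | ||||
| #   --Conv2d 2x2--> 768@1x1 --flatten--> 768 --Linear--> num_classes | ||||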
| @@ -1,117 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetonCIFAR(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C, | ||||
|         N, | ||||
|         stem_multiplier, | ||||
|         num_classes, | ||||
|         genotype, | ||||
|         auxiliary, | ||||
|         affine=True, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(NASNetonCIFAR, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|         self.auxiliary_index = None | ||||
|         self.auxiliary_head = None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             cell = InferCell( | ||||
|                 genotype, | ||||
|                 C_prev_prev, | ||||
|                 C_prev, | ||||
|                 C_curr, | ||||
|                 reduction, | ||||
|                 reduction_prev, | ||||
|                 affine, | ||||
|                 track_running_stats, | ||||
|             ) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = ( | ||||
|                 C_prev, | ||||
|                 cell._multiplier * C_curr, | ||||
|                 reduction, | ||||
|             ) | ||||
|             if reduction and C_curr == C * 4 and auxiliary: | ||||
|                 self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) | ||||
|                 self.auxiliary_index = index | ||||
|         self._Layer = len(self.cells) | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.drop_path_prob = -1 | ||||
|  | ||||
|     def update_drop_path(self, drop_path_prob): | ||||
|         self.drop_path_prob = drop_path_prob | ||||
|  | ||||
|     def auxiliary_param(self): | ||||
|         if self.auxiliary_head is None: | ||||
|             return [] | ||||
|         else: | ||||
|             return list(self.auxiliary_head.parameters()) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         stem_feature, logits_aux = self.stem(inputs), None | ||||
|         cell_results = [stem_feature, stem_feature] | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) | ||||
|             cell_results.append(cell_feature) | ||||
|             if ( | ||||
|                 self.auxiliary_index is not None | ||||
|                 and i == self.auxiliary_index | ||||
|                 and self.training | ||||
|             ): | ||||
|                 logits_aux = self.auxiliary_head(cell_results[-1]) | ||||
|         out = self.lastact(cell_results[-1]) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|         if logits_aux is None: | ||||
|             return out, logits | ||||
|         else: | ||||
|             return out, [logits, logits_aux] | ||||
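|  | ||||
| # Usage sketch (illustrative): with the auxiliary head active during | ||||
| # training, the second return value is a [logits, logits_aux] pair: | ||||
| #   out, logits = net(images) | ||||
| #   if isinstance(logits, list): | ||||
| #       # 0.4 is a common auxiliary-loss weight, not fixed by this file | ||||
| #       loss = criterion(logits[0], y) + 0.4 * criterion(logits[1], y) | ||||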
| @@ -1,63 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .cells import InferCell | ||||
|  | ||||
|  | ||||
| # The macro structure for architectures in NAS-Bench-201 | ||||
| class TinyNetwork(nn.Module): | ||||
|     def __init__(self, C, N, genotype, num_classes): | ||||
|         super(TinyNetwork, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|  | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev = C | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2, True) | ||||
|             else: | ||||
|                 cell = InferCell(genotype, C_prev, C_curr, 1) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self._Layer = len(self.cells) | ||||
|  | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
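|  | ||||
| # Usage sketch (illustrative, hypothetical sizes): | ||||
| #   net = TinyNetwork(C=16, N=5, genotype=genotype, num_classes=10) | ||||
| #   features, logits = net(torch.randn(2, 3, 32, 32)) | ||||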
| @@ -1,553 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| __all__ = ["OPS", "RAW_OP_CLASSES", "ResNetBasicblock", "SearchSpaceNames"] | ||||
|  | ||||
| OPS = { | ||||
|     "none": lambda C_in, C_out, stride, affine, track_running_stats: Zero( | ||||
|         C_in, C_out, stride | ||||
|     ), | ||||
|     "avg_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING( | ||||
|         C_in, C_out, stride, "avg", affine, track_running_stats | ||||
|     ), | ||||
|     "max_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING( | ||||
|         C_in, C_out, stride, "max", affine, track_running_stats | ||||
|     ), | ||||
|     "nor_conv_7x7": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (7, 7), | ||||
|         (stride, stride), | ||||
|         (3, 3), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "nor_conv_3x3": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (3, 3), | ||||
|         (stride, stride), | ||||
|         (1, 1), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "nor_conv_1x1": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (1, 1), | ||||
|         (stride, stride), | ||||
|         (0, 0), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "dua_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (3, 3), | ||||
|         (stride, stride), | ||||
|         (1, 1), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "dua_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (5, 5), | ||||
|         (stride, stride), | ||||
|         (2, 2), | ||||
|         (1, 1), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "dil_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: SepConv( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (3, 3), | ||||
|         (stride, stride), | ||||
|         (2, 2), | ||||
|         (2, 2), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "dil_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: SepConv( | ||||
|         C_in, | ||||
|         C_out, | ||||
|         (5, 5), | ||||
|         (stride, stride), | ||||
|         (4, 4), | ||||
|         (2, 2), | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ), | ||||
|     "skip_connect": lambda C_in, C_out, stride, affine, track_running_stats: Identity() | ||||
|     if stride == 1 and C_in == C_out | ||||
|     else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), | ||||
| } | ||||
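|  | ||||
| # Usage sketch (illustrative): every OPS entry shares one signature, | ||||
| # (C_in, C_out, stride, affine, track_running_stats) -> nn.Module: | ||||
| #   op = OPS["nor_conv_3x3"](16, 32, 2, True, True) | ||||
| #   y = op(torch.randn(2, 16, 32, 32))  # y.shape == (2, 32, 16, 16) | ||||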
|  | ||||
| CONNECT_NAS_BENCHMARK = ["none", "skip_connect", "nor_conv_3x3"] | ||||
| NAS_BENCH_201 = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] | ||||
| DARTS_SPACE = [ | ||||
|     "none", | ||||
|     "skip_connect", | ||||
|     "dua_sepc_3x3", | ||||
|     "dua_sepc_5x5", | ||||
|     "dil_sepc_3x3", | ||||
|     "dil_sepc_5x5", | ||||
|     "avg_pool_3x3", | ||||
|     "max_pool_3x3", | ||||
| ] | ||||
|  | ||||
| SearchSpaceNames = { | ||||
|     "connect-nas": CONNECT_NAS_BENCHMARK, | ||||
|     "nats-bench": NAS_BENCH_201, | ||||
|     "nas-bench-201": NAS_BENCH_201, | ||||
|     "darts": DARTS_SPACE, | ||||
| } | ||||
|  | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C_in, | ||||
|         C_out, | ||||
|         kernel_size, | ||||
|         stride, | ||||
|         padding, | ||||
|         dilation, | ||||
|         affine, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(ReLUConvBN, self).__init__() | ||||
|         self.op = nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, | ||||
|                 C_out, | ||||
|                 kernel_size, | ||||
|                 stride=stride, | ||||
|                 padding=padding, | ||||
|                 dilation=dilation, | ||||
|                 bias=not affine, | ||||
|             ), | ||||
|             nn.BatchNorm2d( | ||||
|                 C_out, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class SepConv(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C_in, | ||||
|         C_out, | ||||
|         kernel_size, | ||||
|         stride, | ||||
|         padding, | ||||
|         dilation, | ||||
|         affine, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(SepConv, self).__init__() | ||||
|         self.op = nn.Sequential( | ||||
|             nn.ReLU(inplace=False), | ||||
|             nn.Conv2d( | ||||
|                 C_in, | ||||
|                 C_in, | ||||
|                 kernel_size=kernel_size, | ||||
|                 stride=stride, | ||||
|                 padding=padding, | ||||
|                 dilation=dilation, | ||||
|                 groups=C_in, | ||||
|                 bias=False, | ||||
|             ), | ||||
|             nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine), | ||||
|             nn.BatchNorm2d( | ||||
|                 C_out, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class DualSepConv(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C_in, | ||||
|         C_out, | ||||
|         kernel_size, | ||||
|         stride, | ||||
|         padding, | ||||
|         dilation, | ||||
|         affine, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(DualSepConv, self).__init__() | ||||
|         self.op_a = SepConv( | ||||
|             C_in, | ||||
|             C_in, | ||||
|             kernel_size, | ||||
|             stride, | ||||
|             padding, | ||||
|             dilation, | ||||
|             affine, | ||||
|             track_running_stats, | ||||
|         ) | ||||
|         self.op_b = SepConv( | ||||
|             C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.op_a(x) | ||||
|         x = self.op_b(x) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ReLUConvBN( | ||||
|             inplanes, planes, 3, stride, 1, 1, affine, track_running_stats | ||||
|         ) | ||||
|         self.conv_b = ReLUConvBN( | ||||
|             planes, planes, 3, 1, 1, 1, affine, track_running_stats | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = nn.Sequential( | ||||
|                 nn.AvgPool2d(kernel_size=2, stride=2, padding=0), | ||||
|                 nn.Conv2d( | ||||
|                     inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False | ||||
|                 ), | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ReLUConvBN( | ||||
|                 inplanes, planes, 1, 1, 0, 1, affine, track_running_stats | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.in_dim = inplanes | ||||
|         self.out_dim = planes | ||||
|         self.stride = stride | ||||
|         self.num_conv = 2 | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         string = "{name}(inC={in_dim}, outC={out_dim}, stride={stride})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|         return string | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         return residual + basicblock | ||||
|  | ||||
|  | ||||
| class POOLING(nn.Module): | ||||
|     def __init__( | ||||
|         self, C_in, C_out, stride, mode, affine=True, track_running_stats=True | ||||
|     ): | ||||
|         super(POOLING, self).__init__() | ||||
|         if C_in == C_out: | ||||
|             self.preprocess = None | ||||
|         else: | ||||
|             self.preprocess = ReLUConvBN( | ||||
|                 C_in, C_out, 1, 1, 0, 1, affine, track_running_stats | ||||
|             ) | ||||
|         if mode == "avg": | ||||
|             self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) | ||||
|         elif mode == "max": | ||||
|             self.op = nn.MaxPool2d(3, stride=stride, padding=1) | ||||
|         else: | ||||
|             raise ValueError("Invalid mode={:} in POOLING".format(mode)) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.preprocess: | ||||
|             x = self.preprocess(inputs) | ||||
|         else: | ||||
|             x = inputs | ||||
|         return self.op(x) | ||||
|  | ||||
|  | ||||
| class Identity(nn.Module): | ||||
|     def __init__(self): | ||||
|         super(Identity, self).__init__() | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class Zero(nn.Module): | ||||
|     def __init__(self, C_in, C_out, stride): | ||||
|         super(Zero, self).__init__() | ||||
|         self.C_in = C_in | ||||
|         self.C_out = C_out | ||||
|         self.stride = stride | ||||
|         self.is_zero = True | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.C_in == self.C_out: | ||||
|             if self.stride == 1: | ||||
|                 return x.mul(0.0) | ||||
|             else: | ||||
|                 return x[:, :, :: self.stride, :: self.stride].mul(0.0) | ||||
|         else: | ||||
|             # note: when C_in != C_out, the zero tensor keeps the input's | ||||
|             # spatial size; the stride is not applied in this branch | ||||
|             shape = list(x.shape) | ||||
|             shape[1] = self.C_out | ||||
|             return x.new_zeros(shape)  # new_zeros inherits dtype and device | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__) | ||||
|  | ||||
|  | ||||
| class FactorizedReduce(nn.Module): | ||||
|     def __init__(self, C_in, C_out, stride, affine, track_running_stats): | ||||
|         super(FactorizedReduce, self).__init__() | ||||
|         self.stride = stride | ||||
|         self.C_in = C_in | ||||
|         self.C_out = C_out | ||||
|         self.relu = nn.ReLU(inplace=False) | ||||
|         if stride == 2: | ||||
|             # assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) | ||||
|             C_outs = [C_out // 2, C_out - C_out // 2] | ||||
|             self.convs = nn.ModuleList() | ||||
|             for i in range(2): | ||||
|                 self.convs.append( | ||||
|                     nn.Conv2d( | ||||
|                         C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine | ||||
|                     ) | ||||
|                 ) | ||||
|             self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) | ||||
|         elif stride == 1: | ||||
|             self.conv = nn.Conv2d( | ||||
|                 C_in, C_out, 1, stride=stride, padding=0, bias=not affine | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("Invalid stride : {:}".format(stride)) | ||||
|         self.bn = nn.BatchNorm2d( | ||||
|             C_out, affine=affine, track_running_stats=track_running_stats | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         if self.stride == 2: | ||||
|             x = self.relu(x) | ||||
|             y = self.pad(x) | ||||
|             out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1) | ||||
|         else: | ||||
|             out = self.conv(x) | ||||
|         out = self.bn(out) | ||||
|         return out | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__) | ||||
|  | ||||
|  | ||||
| # Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019 | ||||
| class PartAwareOp(nn.Module): | ||||
|     def __init__(self, C_in, C_out, stride, part=4): | ||||
|         super().__init__() | ||||
|         self.part = part  # fix: honor the "part" argument (was hard-coded to 4) | ||||
|         self.hidden = C_in // 3 | ||||
|         self.avg_pool = nn.AdaptiveAvgPool2d(1) | ||||
|         self.local_conv_list = nn.ModuleList() | ||||
|         for i in range(self.part): | ||||
|             self.local_conv_list.append( | ||||
|                 nn.Sequential( | ||||
|                     nn.ReLU(), | ||||
|                     nn.Conv2d(C_in, self.hidden, 1), | ||||
|                     nn.BatchNorm2d(self.hidden, affine=True), | ||||
|                 ) | ||||
|             ) | ||||
|         self.W_K = nn.Linear(self.hidden, self.hidden) | ||||
|         self.W_Q = nn.Linear(self.hidden, self.hidden) | ||||
|  | ||||
|         # FactorizedReduce takes (C_in, C_out, stride, affine, track_running_stats); | ||||
|         # the last two arguments were missing here, so defaults (assumed True) are passed. | ||||
|         if stride == 2: | ||||
|             self.last = FactorizedReduce(C_in + self.hidden, C_out, 2, True, True) | ||||
|         elif stride == 1: | ||||
|             self.last = FactorizedReduce(C_in + self.hidden, C_out, 1, True, True) | ||||
|         else: | ||||
|             raise ValueError("Invalid Stride : {:}".format(stride)) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         batch, C, H, W = x.size() | ||||
|         assert H >= self.part, "input size too small : {:} vs {:}".format( | ||||
|             x.shape, self.part | ||||
|         ) | ||||
|         IHs = [0] | ||||
|         for i in range(self.part): | ||||
|             IHs.append(min(H, int((i + 1) * (float(H) / self.part)))) | ||||
|         local_feat_list = [] | ||||
|         for i in range(self.part): | ||||
|             feature = x[:, :, IHs[i] : IHs[i + 1], :] | ||||
|             xfeax = self.avg_pool(feature) | ||||
|             xfea = self.local_conv_list[i](xfeax) | ||||
|             local_feat_list.append(xfea) | ||||
|         part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part) | ||||
|         part_feature = part_feature.transpose(1, 2).contiguous() | ||||
|         part_K = self.W_K(part_feature) | ||||
|         part_Q = self.W_Q(part_feature).transpose(1, 2).contiguous() | ||||
|         weight_att = torch.bmm(part_K, part_Q) | ||||
|         attention = torch.softmax(weight_att, dim=2) | ||||
|         aggreateF = torch.bmm(attention, part_feature).transpose(1, 2).contiguous() | ||||
|         features = [] | ||||
|         for i in range(self.part): | ||||
|             feature = aggreateF[:, :, i : i + 1].expand( | ||||
|                 batch, self.hidden, IHs[i + 1] - IHs[i] | ||||
|             ) | ||||
|             feature = feature.view(batch, self.hidden, IHs[i + 1] - IHs[i], 1) | ||||
|             features.append(feature) | ||||
|         features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W) | ||||
|         final_fea = torch.cat((x, features), dim=1) | ||||
|         outputs = self.last(final_fea) | ||||
|         return outputs | ||||
|  | ||||
|  | ||||
| def drop_path(x, drop_prob): | ||||
|     if drop_prob > 0.0: | ||||
|         keep_prob = 1.0 - drop_prob | ||||
|         mask = x.new_zeros(x.size(0), 1, 1, 1) | ||||
|         mask = mask.bernoulli_(keep_prob) | ||||
|         x = torch.div(x, keep_prob) | ||||
|         x.mul_(mask) | ||||
|     return x | ||||
|  | ||||
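| # Note on the scaling in drop_path: dividing by keep_prob before masking | ||||
| # keeps the expected activation unchanged, since E[mask] = keep_prob and | ||||
| # hence E[x / keep_prob * mask] == x. Quick check (illustrative): | ||||
| #   x = torch.ones(10000, 1, 1, 1) | ||||
| #   drop_path(x, 0.5).mean()  # approximately 1.0 | ||||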
|  | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours | ||||
| class GDAS_Reduction_Cell(nn.Module): | ||||
|     def __init__( | ||||
|         self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats | ||||
|     ): | ||||
|         super(GDAS_Reduction_Cell, self).__init__() | ||||
|         if reduction_prev: | ||||
|             self.preprocess0 = FactorizedReduce( | ||||
|                 C_prev_prev, C, 2, affine, track_running_stats | ||||
|             ) | ||||
|         else: | ||||
|             self.preprocess0 = ReLUConvBN( | ||||
|                 C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats | ||||
|             ) | ||||
|         self.preprocess1 = ReLUConvBN( | ||||
|             C_prev, C, 1, 1, 0, 1, affine, track_running_stats | ||||
|         ) | ||||
|  | ||||
|         self.reduction = True | ||||
|         self.ops1 = nn.ModuleList( | ||||
|             [ | ||||
|                 nn.Sequential( | ||||
|                     nn.ReLU(inplace=False), | ||||
|                     nn.Conv2d( | ||||
|                         C, | ||||
|                         C, | ||||
|                         (1, 3), | ||||
|                         stride=(1, 2), | ||||
|                         padding=(0, 1), | ||||
|                         groups=8, | ||||
|                         bias=not affine, | ||||
|                     ), | ||||
|                     nn.Conv2d( | ||||
|                         C, | ||||
|                         C, | ||||
|                         (3, 1), | ||||
|                         stride=(2, 1), | ||||
|                         padding=(1, 0), | ||||
|                         groups=8, | ||||
|                         bias=not affine, | ||||
|                     ), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                     nn.ReLU(inplace=False), | ||||
|                     nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                 ), | ||||
|                 nn.Sequential( | ||||
|                     nn.ReLU(inplace=False), | ||||
|                     nn.Conv2d( | ||||
|                         C, | ||||
|                         C, | ||||
|                         (1, 3), | ||||
|                         stride=(1, 2), | ||||
|                         padding=(0, 1), | ||||
|                         groups=8, | ||||
|                         bias=not affine, | ||||
|                     ), | ||||
|                     nn.Conv2d( | ||||
|                         C, | ||||
|                         C, | ||||
|                         (3, 1), | ||||
|                         stride=(2, 1), | ||||
|                         padding=(1, 0), | ||||
|                         groups=8, | ||||
|                         bias=not affine, | ||||
|                     ), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                     nn.ReLU(inplace=False), | ||||
|                     nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                 ), | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         self.ops2 = nn.ModuleList( | ||||
|             [ | ||||
|                 nn.Sequential( | ||||
|                     nn.MaxPool2d(3, stride=2, padding=1), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                 ), | ||||
|                 nn.Sequential( | ||||
|                     nn.MaxPool2d(3, stride=2, padding=1), | ||||
|                     nn.BatchNorm2d( | ||||
|                         C, affine=affine, track_running_stats=track_running_stats | ||||
|                     ), | ||||
|                 ), | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def multiplier(self): | ||||
|         return 4 | ||||
|  | ||||
|     def forward(self, s0, s1, drop_prob=-1): | ||||
|         s0 = self.preprocess0(s0) | ||||
|         s1 = self.preprocess1(s1) | ||||
|  | ||||
|         X0 = self.ops1[0](s0) | ||||
|         X1 = self.ops1[1](s1) | ||||
|         if self.training and drop_prob > 0.0: | ||||
|             X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob) | ||||
|  | ||||
|         # X2 = self.ops2[0] (X0+X1) | ||||
|         X2 = self.ops2[0](s0) | ||||
|         X3 = self.ops2[1](s1) | ||||
|         if self.training and drop_prob > 0.0: | ||||
|             X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob) | ||||
|         return torch.cat([X0, X1, X2, X3], dim=1) | ||||
|  | ||||
|  | ||||
| # Registry of the raw cell classes defined in this file. | ||||
| RAW_OP_CLASSES = {"gdas_reduction": GDAS_Reduction_Cell} | ||||
| @@ -1,33 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| # The macro structure is defined in NAS-Bench-201 | ||||
| from .search_model_darts import TinyNetworkDarts | ||||
| from .search_model_gdas import TinyNetworkGDAS | ||||
| from .search_model_setn import TinyNetworkSETN | ||||
| from .search_model_enas import TinyNetworkENAS | ||||
| from .search_model_random import TinyNetworkRANDOM | ||||
| from .generic_model import GenericNAS201Model | ||||
| from .genotypes import Structure as CellStructure, architectures as CellArchitectures | ||||
|  | ||||
| # NASNet-based macro structure | ||||
| from .search_model_gdas_nasnet import NASNetworkGDAS | ||||
| from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC | ||||
| from .search_model_darts_nasnet import NASNetworkDARTS | ||||
|  | ||||
|  | ||||
| nas201_super_nets = { | ||||
|     "DARTS-V1": TinyNetworkDarts, | ||||
|     "DARTS-V2": TinyNetworkDarts, | ||||
|     "GDAS": TinyNetworkGDAS, | ||||
|     "SETN": TinyNetworkSETN, | ||||
|     "ENAS": TinyNetworkENAS, | ||||
|     "RANDOM": TinyNetworkRANDOM, | ||||
|     "generic": GenericNAS201Model, | ||||
| } | ||||
|  | ||||
| nasnet_super_nets = { | ||||
|     "GDAS": NASNetworkGDAS, | ||||
|     "GDAS_FRC": NASNetworkGDAS_FRC, | ||||
|     "DARTS": NASNetworkDARTS, | ||||
| } | ||||
| @@ -1,14 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| from search_model_enas_utils import Controller | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     # smoke test: instantiate the controller and run one forward pass | ||||
|     controller = Controller(6, 4) | ||||
|     predictions = controller() | ||||
|     print(predictions) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
| @@ -1,362 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 # | ||||
| ##################################################### | ||||
| import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import Text | ||||
| from torch.distributions.categorical import Categorical | ||||
|  | ||||
| from ..cell_operations import ResNetBasicblock, drop_path | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class Controller(nn.Module): | ||||
|     # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py | ||||
|     def __init__( | ||||
|         self, | ||||
|         edge2index, | ||||
|         op_names, | ||||
|         max_nodes, | ||||
|         lstm_size=32, | ||||
|         lstm_num_layers=2, | ||||
|         tanh_constant=2.5, | ||||
|         temperature=5.0, | ||||
|     ): | ||||
|         super(Controller, self).__init__() | ||||
|         # assign the attributes | ||||
|         self.max_nodes = max_nodes | ||||
|         self.num_edge = len(edge2index) | ||||
|         self.edge2index = edge2index | ||||
|         self.num_ops = len(op_names) | ||||
|         self.op_names = op_names | ||||
|         self.lstm_size = lstm_size | ||||
|         self.lstm_N = lstm_num_layers | ||||
|         self.tanh_constant = tanh_constant | ||||
|         self.temperature = temperature | ||||
|         # create parameters | ||||
|         self.register_parameter( | ||||
|             "input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size)) | ||||
|         ) | ||||
|         self.w_lstm = nn.LSTM( | ||||
|             input_size=self.lstm_size, | ||||
|             hidden_size=self.lstm_size, | ||||
|             num_layers=self.lstm_N, | ||||
|         ) | ||||
|         self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) | ||||
|         self.w_pred = nn.Linear(self.lstm_size, self.num_ops) | ||||
|  | ||||
|         nn.init.uniform_(self.input_vars, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_embd.weight, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_pred.weight, -0.1, 0.1) | ||||
|  | ||||
|     def convert_structure(self, _arch): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_index = _arch[self.edge2index[node_str]] | ||||
|                 op_name = self.op_names[op_index] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def forward(self): | ||||
|         inputs, h0 = self.input_vars, None | ||||
|         log_probs, entropys, sampled_arch = [], [], [] | ||||
|         for iedge in range(self.num_edge): | ||||
|             outputs, h0 = self.w_lstm(inputs, h0) | ||||
|  | ||||
|             logits = self.w_pred(outputs) | ||||
|             logits = logits / self.temperature | ||||
|             logits = self.tanh_constant * torch.tanh(logits) | ||||
|             # sample one operation for this edge from the categorical distribution | ||||
|             op_distribution = Categorical(logits=logits) | ||||
|             op_index = op_distribution.sample() | ||||
|             sampled_arch.append(op_index.item()) | ||||
|  | ||||
|             op_log_prob = op_distribution.log_prob(op_index) | ||||
|             log_probs.append(op_log_prob.view(-1)) | ||||
|             op_entropy = op_distribution.entropy() | ||||
|             entropys.append(op_entropy.view(-1)) | ||||
|  | ||||
|             # obtain the input embedding for the next step | ||||
|             inputs = self.w_embd(op_index) | ||||
|         return ( | ||||
|             torch.sum(torch.cat(log_probs)), | ||||
|             torch.sum(torch.cat(entropys)), | ||||
|             self.convert_structure(sampled_arch), | ||||
|         ) | ||||
|  | ||||
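| # Usage sketch (illustrative, hypothetical 4-node search space): | ||||
| #   edge2index = {"1<-0": 0, "2<-0": 1, "2<-1": 2, | ||||
| #                 "3<-0": 3, "3<-1": 4, "3<-2": 5} | ||||
| #   ops = ["none", "skip_connect", "nor_conv_1x1", | ||||
| #          "nor_conv_3x3", "avg_pool_3x3"] | ||||
| #   ctrl = Controller(edge2index, ops, max_nodes=4) | ||||
| #   log_prob, entropy, arch = ctrl()  # one sampled Structure per call | ||||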
|  | ||||
| class GenericNAS201Model(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(GenericNAS201Model, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._max_nodes = max_nodes | ||||
|         self._stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self._cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self._cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self._op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self._cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential( | ||||
|             nn.BatchNorm2d( | ||||
|                 C_prev, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self._num_edge = num_edge | ||||
|         # algorithm related | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self._mode = None | ||||
|         self.dynamic_cell = None | ||||
|         self._tau = None | ||||
|         self._algo = None | ||||
|         self._drop_path = None | ||||
|         self.verbose = False | ||||
|  | ||||
|     def set_algo(self, algo: Text): | ||||
|         # used for searching | ||||
|         assert self._algo is None, "This function can only be called once." | ||||
|         self._algo = algo | ||||
|         if algo == "enas": | ||||
|             self.controller = Controller( | ||||
|                 self.edge2index, self._op_names, self._max_nodes | ||||
|             ) | ||||
|         else: | ||||
|             self.arch_parameters = nn.Parameter( | ||||
|                 1e-3 * torch.randn(self._num_edge, len(self._op_names)) | ||||
|             ) | ||||
|             if algo == "gdas": | ||||
|                 self._tau = 10 | ||||
|  | ||||
|     def set_cal_mode(self, mode, dynamic_cell=None): | ||||
|         assert mode in ["gdas", "enas", "urs", "joint", "select", "dynamic"] | ||||
|         self._mode = mode | ||||
|         if mode == "dynamic": | ||||
|             self.dynamic_cell = deepcopy(dynamic_cell) | ||||
|         else: | ||||
|             self.dynamic_cell = None | ||||
|  | ||||
|     def set_drop_path(self, progress, drop_path_rate): | ||||
|         if drop_path_rate is None: | ||||
|             self._drop_path = None | ||||
|         elif progress is None: | ||||
|             self._drop_path = drop_path_rate | ||||
|         else: | ||||
|             self._drop_path = progress * drop_path_rate | ||||
|  | ||||
|     @property | ||||
|     def mode(self): | ||||
|         return self._mode | ||||
|  | ||||
|     @property | ||||
|     def drop_path(self): | ||||
|         return self._drop_path | ||||
|  | ||||
|     @property | ||||
|     def weights(self): | ||||
|         xlist = list(self._stem.parameters()) | ||||
|         xlist += list(self._cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) | ||||
|         xlist += list(self.global_pooling.parameters()) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self._tau = tau | ||||
|  | ||||
|     @property | ||||
|     def tau(self): | ||||
|         return self._tau | ||||
|  | ||||
|     @property | ||||
|     def alphas(self): | ||||
|         if self._algo == "enas": | ||||
|             return list(self.controller.parameters()) | ||||
|         else: | ||||
|             return [self.arch_parameters] | ||||
|  | ||||
|     @property | ||||
|     def message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self._cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self._cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             if self._algo == "enas": | ||||
|                 return "w_pred :\n{:}".format(self.controller.w_pred.weight) | ||||
|             else: | ||||
|                 return "arch-parameters :\n{:}".format( | ||||
|                     nn.functional.softmax(self.arch_parameters, dim=-1).cpu() | ||||
|                 ) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self._max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self._op_names[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def dync_genotype(self, use_random=False): | ||||
|         genotypes = [] | ||||
|         with torch.no_grad(): | ||||
|             alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         for i in range(1, self._max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 if use_random: | ||||
|                     op_name = random.choice(self._op_names) | ||||
|                 else: | ||||
|                     weights = alphas_cpu[self.edge2index[node_str]] | ||||
|                     op_index = torch.multinomial(weights, 1).item() | ||||
|                     op_name = self._op_names[op_index] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def get_log_prob(self, arch): | ||||
|         with torch.no_grad(): | ||||
|             logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) | ||||
|         select_logits = [] | ||||
|         for i, node_info in enumerate(arch.nodes): | ||||
|             for op, xin in node_info: | ||||
|                 node_str = "{:}<-{:}".format(i + 1, xin) | ||||
|                 op_index = self._op_names.index(op) | ||||
|                 select_logits.append(logits[self.edge2index[node_str], op_index]) | ||||
|         return sum(select_logits).item() | ||||
|  | ||||
|     def return_topK(self, K, use_random=False): | ||||
|         archs = Structure.gen_all(self._op_names, self._max_nodes, False) | ||||
|         if K < 0 or K >= len(archs): | ||||
|             K = len(archs) | ||||
|         if use_random: | ||||
|             return random.sample(archs, K) | ||||
|         else: | ||||
|             # rank all candidate architectures by log-probability | ||||
|             pairs = [(self.get_log_prob(arch), arch) for arch in archs] | ||||
|             sorted_pairs = sorted(pairs, key=lambda x: -x[0]) | ||||
|             return [sorted_pairs[_][1] for _ in range(K)] | ||||
|  | ||||
|     def normalize_archp(self): | ||||
|         if self.mode == "gdas": | ||||
|             while True: | ||||
|                 # Gumbel-softmax sampling: add Gumbel(0, 1) noise to the | ||||
|                 # log-softmax of the architecture parameters | ||||
|                 gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() | ||||
|                 logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau | ||||
|                 probs = nn.functional.softmax(logits, dim=1) | ||||
|                 index = probs.max(-1, keepdim=True)[1] | ||||
|                 one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|                 # straight-through trick: hardwts is one-hot in the forward | ||||
|                 # pass but carries the gradient of the soft probabilities | ||||
|                 hardwts = one_h - probs.detach() + probs | ||||
|                 # retry if the noise produced inf/nan values | ||||
|                 if ( | ||||
|                     (torch.isinf(gumbels).any()) | ||||
|                     or (torch.isinf(probs).any()) | ||||
|                     or (torch.isnan(probs).any()) | ||||
|                 ): | ||||
|                     continue | ||||
|                 else: | ||||
|                     break | ||||
|             with torch.no_grad(): | ||||
|                 hardwts_cpu = hardwts.detach().cpu() | ||||
|             return hardwts, hardwts_cpu, index, "GUMBEL" | ||||
|         else: | ||||
|             alphas = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|             index = alphas.max(-1, keepdim=True)[1] | ||||
|             with torch.no_grad(): | ||||
|                 alphas_cpu = alphas.detach().cpu() | ||||
|             return alphas, alphas_cpu, index, "SOFTMAX" | ||||
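|     # Note: in the "gdas" branch above, hardwts is the straight-through | ||||
|     # Gumbel-softmax estimator: one_h is exactly one-hot in the forward pass, | ||||
|     # while the "- probs.detach() + probs" term routes gradients through the soft | ||||
|     # probabilities; the while-loop simply re-samples whenever the Gumbel noise | ||||
|     # yields inf/nan values. | ||||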
|  | ||||
|     def forward(self, inputs): | ||||
|         alphas, alphas_cpu, index, verbose_str = self.normalize_archp() | ||||
|         feature = self._stem(inputs) | ||||
|         for i, cell in enumerate(self._cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 if self.mode == "urs": | ||||
|                     feature = cell.forward_urs(feature) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_urs" | ||||
|                 elif self.mode == "select": | ||||
|                     feature = cell.forward_select(feature, alphas_cpu) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_select" | ||||
|                 elif self.mode == "joint": | ||||
|                     feature = cell.forward_joint(feature, alphas) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_joint" | ||||
|                 elif self.mode == "dynamic": | ||||
|                     feature = cell.forward_dynamic(feature, self.dynamic_cell) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_dynamic" | ||||
|                 elif self.mode == "gdas": | ||||
|                     feature = cell.forward_gdas(feature, alphas, index) | ||||
|                     if self.verbose: | ||||
|                         verbose_str += "-forward_gdas" | ||||
|                 else: | ||||
|                     raise ValueError("invalid mode={:}".format(self.mode)) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|             if self.drop_path is not None: | ||||
|                 feature = drop_path(feature, self.drop_path) | ||||
|         if self.verbose and random.random() < 0.001: | ||||
|             print(verbose_str) | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|         return out, logits | ||||
| @@ -1,274 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from copy import deepcopy | ||||
|  | ||||
|  | ||||
| def get_combination(space, num): | ||||
|     combs = [] | ||||
|     for i in range(num): | ||||
|         if i == 0: | ||||
|             for func in space: | ||||
|                 combs.append([(func, i)]) | ||||
|         else: | ||||
|             new_combs = [] | ||||
|             for string in combs: | ||||
|                 for func in space: | ||||
|                     xstring = string + [(func, i)] | ||||
|                     new_combs.append(xstring) | ||||
|             combs = new_combs | ||||
|     return combs | ||||
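| # Example (illustrative): with a space of two ops, get_combination(["a", "b"], 2) | ||||
| # enumerates every way to label the two incoming edges of a node: | ||||
| #   [[("a", 0), ("a", 1)], [("a", 0), ("b", 1)], | ||||
| #    [("b", 0), ("a", 1)], [("b", 0), ("b", 1)]] | ||||
| # i.e., len(space) ** num combinations in total. | ||||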
|  | ||||
|  | ||||
| class Structure: | ||||
|     def __init__(self, genotype): | ||||
|         assert isinstance(genotype, list) or isinstance( | ||||
|             genotype, tuple | ||||
|         ), "invalid class of genotype : {:}".format(type(genotype)) | ||||
|         self.node_num = len(genotype) + 1 | ||||
|         self.nodes = [] | ||||
|         self.node_N = [] | ||||
|         for idx, node_info in enumerate(genotype): | ||||
|             assert isinstance(node_info, list) or isinstance( | ||||
|                 node_info, tuple | ||||
|             ), "invalid class of node_info : {:}".format(type(node_info)) | ||||
|             assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info)) | ||||
|             for node_in in node_info: | ||||
|                 assert isinstance(node_in, list) or isinstance( | ||||
|                     node_in, tuple | ||||
|                 ), "invalid class of in-node : {:}".format(type(node_in)) | ||||
|                 assert ( | ||||
|                     len(node_in) == 2 and node_in[1] <= idx | ||||
|                 ), "invalid in-node : {:}".format(node_in) | ||||
|             self.node_N.append(len(node_info)) | ||||
|             self.nodes.append(tuple(deepcopy(node_info))) | ||||
|  | ||||
|     def tolist(self, remove_str): | ||||
|         # Convert this class to a list; if remove_str is 'none', remove the 'none' operation. | ||||
|         # Note that the input nodes are re-ordered within this function. | ||||
|         # Return (genotype-list, success); success is False when removing the operation | ||||
|         # leaves some node without any input, i.e., the cell is no longer connected. | ||||
|         genotypes = [] | ||||
|         for node_info in self.nodes: | ||||
|             node_info = list(node_info) | ||||
|             node_info = sorted(node_info, key=lambda x: (x[1], x[0])) | ||||
|             node_info = tuple(filter(lambda x: x[0] != remove_str, node_info)) | ||||
|             if len(node_info) == 0: | ||||
|                 return None, False | ||||
|             genotypes.append(node_info) | ||||
|         return genotypes, True | ||||
|  | ||||
|     def node(self, index): | ||||
|         assert 0 < index < len(self), "invalid index={:}, expected 1 <= index < {:}".format( | ||||
|             index, len(self) | ||||
|         ) | ||||
|         return self.nodes[index - 1] | ||||
|  | ||||
|     def tostr(self): | ||||
|         strings = [] | ||||
|         for node_info in self.nodes: | ||||
|             string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info]) | ||||
|             string = "|{:}|".format(string) | ||||
|             strings.append(string) | ||||
|         return "+".join(strings) | ||||
|  | ||||
|     def check_valid(self): | ||||
|         nodes = {0: True} | ||||
|         for i, node_info in enumerate(self.nodes): | ||||
|             sums = [] | ||||
|             for op, xin in node_info: | ||||
|                 if op == "none" or nodes[xin] is False: | ||||
|                     x = False | ||||
|                 else: | ||||
|                     x = True | ||||
|                 sums.append(x) | ||||
|             nodes[i + 1] = sum(sums) > 0 | ||||
|         return nodes[len(self.nodes)] | ||||
|  | ||||
|     def to_unique_str(self, consider_zero=False): | ||||
|         # This is used to identify isomorphic cells, which requires prior knowledge of the operations: | ||||
|         # two operations are treated specially, i.e., none (the zero op) and skip_connect. | ||||
|         nodes = {0: "0"} | ||||
|         for i_node, node_info in enumerate(self.nodes): | ||||
|             cur_node = [] | ||||
|             for op, xin in node_info: | ||||
|                 if consider_zero is None: | ||||
|                     x = "(" + nodes[xin] + ")" + "@{:}".format(op) | ||||
|                 elif consider_zero: | ||||
|                     if op == "none" or nodes[xin] == "#": | ||||
|                         x = "#"  # zero | ||||
|                     elif op == "skip_connect": | ||||
|                         x = nodes[xin] | ||||
|                     else: | ||||
|                         x = "(" + nodes[xin] + ")" + "@{:}".format(op) | ||||
|                 else: | ||||
|                     if op == "skip_connect": | ||||
|                         x = nodes[xin] | ||||
|                     else: | ||||
|                         x = "(" + nodes[xin] + ")" + "@{:}".format(op) | ||||
|                 cur_node.append(x) | ||||
|             nodes[i_node + 1] = "+".join(sorted(cur_node)) | ||||
|         return nodes[len(self.nodes)] | ||||
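|     # Example (illustrative): with the default consider_zero=False, both | ||||
|     # "|nor_conv_3x3~0|" and "|skip_connect~0|+|nor_conv_3x3~1|" reduce to the | ||||
|     # unique string "(0)@nor_conv_3x3", so the two cells are detected as isomorphic. | ||||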
|  | ||||
|     def check_valid_op(self, op_names): | ||||
|         for node_info in self.nodes: | ||||
|             for inode_edge in node_info: | ||||
|                 # assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0]) | ||||
|                 if inode_edge[0] not in op_names: | ||||
|                     return False | ||||
|         return True | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({node_num} nodes with {node_info})".format( | ||||
|             name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.nodes) + 1 | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         return self.nodes[index] | ||||
|  | ||||
|     @staticmethod | ||||
|     def str2structure(xstr): | ||||
|         if isinstance(xstr, Structure): | ||||
|             return xstr | ||||
|         assert isinstance(xstr, str), "must take string (not {:}) as input".format( | ||||
|             type(xstr) | ||||
|         ) | ||||
|         nodestrs = xstr.split("+") | ||||
|         genotypes = [] | ||||
|         for i, node_str in enumerate(nodestrs): | ||||
|             inputs = list(filter(lambda x: x != "", node_str.split("|"))) | ||||
|             for xinput in inputs: | ||||
|                 assert len(xinput.split("~")) == 2, "invalid input length : {:}".format( | ||||
|                     xinput | ||||
|                 ) | ||||
|             inputs = (xi.split("~") for xi in inputs) | ||||
|             input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs) | ||||
|             genotypes.append(input_infos) | ||||
|         return Structure(genotypes) | ||||
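|     # Example (illustrative): str2structure round-trips with tostr: | ||||
|     #   arch = Structure.str2structure("|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|") | ||||
|     #   assert arch.tostr() == "|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|" | ||||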
|  | ||||
|     @staticmethod | ||||
|     def str2fullstructure(xstr, default_name="none"): | ||||
|         assert isinstance(xstr, str), "must take string (not {:}) as input".format( | ||||
|             type(xstr) | ||||
|         ) | ||||
|         nodestrs = xstr.split("+") | ||||
|         genotypes = [] | ||||
|         for i, node_str in enumerate(nodestrs): | ||||
|             inputs = list(filter(lambda x: x != "", node_str.split("|"))) | ||||
|             for xinput in inputs: | ||||
|                 assert len(xinput.split("~")) == 2, "invalid input length : {:}".format( | ||||
|                     xinput | ||||
|                 ) | ||||
|             inputs = (xi.split("~") for xi in inputs) | ||||
|             input_infos = list((op, int(IDX)) for (op, IDX) in inputs) | ||||
|             all_in_nodes = list(x[1] for x in input_infos) | ||||
|             for j in range(i): | ||||
|                 if j not in all_in_nodes: | ||||
|                     input_infos.append((default_name, j)) | ||||
|             node_info = sorted(input_infos, key=lambda x: (x[1], x[0])) | ||||
|             genotypes.append(tuple(node_info)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     @staticmethod | ||||
|     def gen_all(search_space, num, return_ori): | ||||
|         assert isinstance(search_space, list) or isinstance( | ||||
|             search_space, tuple | ||||
|         ), "invalid class of search-space : {:}".format(type(search_space)) | ||||
|         assert ( | ||||
|             num >= 2 | ||||
|         ), "There should be at least two nodes in a neural cell, but got {:}".format( | ||||
|             num | ||||
|         ) | ||||
|         all_archs = get_combination(search_space, 1) | ||||
|         for i, arch in enumerate(all_archs): | ||||
|             all_archs[i] = [tuple(arch)] | ||||
|  | ||||
|         for inode in range(2, num): | ||||
|             cur_nodes = get_combination(search_space, inode) | ||||
|             new_all_archs = [] | ||||
|             for previous_arch in all_archs: | ||||
|                 for cur_node in cur_nodes: | ||||
|                     new_all_archs.append(previous_arch + [tuple(cur_node)]) | ||||
|             all_archs = new_all_archs | ||||
|         if return_ori: | ||||
|             return all_archs | ||||
|         else: | ||||
|             return [Structure(x) for x in all_archs] | ||||
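|     # Note: gen_all enumerates len(search_space) ** (1 + 2 + ... + (num - 1)) | ||||
|     # candidates; for the NAS-Bench-201 setting of 5 operations and num=4 nodes | ||||
|     # (6 edges), that is 5 ** 6 = 15625 architectures, so exhaustive enumeration | ||||
|     # is only feasible for such tiny cells. | ||||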
|  | ||||
|  | ||||
| ResNet_CODE = Structure( | ||||
|     [ | ||||
|         (("nor_conv_3x3", 0),),  # node-1 | ||||
|         (("nor_conv_3x3", 1),),  # node-2 | ||||
|         (("skip_connect", 0), ("skip_connect", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| AllConv3x3_CODE = Structure( | ||||
|     [ | ||||
|         (("nor_conv_3x3", 0),),  # node-1 | ||||
|         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1)),  # node-2 | ||||
|         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| AllFull_CODE = Structure( | ||||
|     [ | ||||
|         ( | ||||
|             ("skip_connect", 0), | ||||
|             ("nor_conv_1x1", 0), | ||||
|             ("nor_conv_3x3", 0), | ||||
|             ("avg_pool_3x3", 0), | ||||
|         ),  # node-1 | ||||
|         ( | ||||
|             ("skip_connect", 0), | ||||
|             ("nor_conv_1x1", 0), | ||||
|             ("nor_conv_3x3", 0), | ||||
|             ("avg_pool_3x3", 0), | ||||
|             ("skip_connect", 1), | ||||
|             ("nor_conv_1x1", 1), | ||||
|             ("nor_conv_3x3", 1), | ||||
|             ("avg_pool_3x3", 1), | ||||
|         ),  # node-2 | ||||
|         ( | ||||
|             ("skip_connect", 0), | ||||
|             ("nor_conv_1x1", 0), | ||||
|             ("nor_conv_3x3", 0), | ||||
|             ("avg_pool_3x3", 0), | ||||
|             ("skip_connect", 1), | ||||
|             ("nor_conv_1x1", 1), | ||||
|             ("nor_conv_3x3", 1), | ||||
|             ("avg_pool_3x3", 1), | ||||
|             ("skip_connect", 2), | ||||
|             ("nor_conv_1x1", 2), | ||||
|             ("nor_conv_3x3", 2), | ||||
|             ("avg_pool_3x3", 2), | ||||
|         ), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| AllConv1x1_CODE = Structure( | ||||
|     [ | ||||
|         (("nor_conv_1x1", 0),),  # node-1 | ||||
|         (("nor_conv_1x1", 0), ("nor_conv_1x1", 1)),  # node-2 | ||||
|         (("nor_conv_1x1", 0), ("nor_conv_1x1", 1), ("nor_conv_1x1", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| AllIdentity_CODE = Structure( | ||||
|     [ | ||||
|         (("skip_connect", 0),),  # node-1 | ||||
|         (("skip_connect", 0), ("skip_connect", 1)),  # node-2 | ||||
|         (("skip_connect", 0), ("skip_connect", 1), ("skip_connect", 2)), | ||||
|     ]  # node-3 | ||||
| ) | ||||
|  | ||||
| architectures = { | ||||
|     "resnet": ResNet_CODE, | ||||
|     "all_c3x3": AllConv3x3_CODE, | ||||
|     "all_c1x1": AllConv1x1_CODE, | ||||
|     "all_idnt": AllIdentity_CODE, | ||||
|     "all_full": AllFull_CODE, | ||||
| } | ||||
| @@ -1,251 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, random, torch | ||||
| import warnings | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import OPS | ||||
|  | ||||
|  | ||||
| # This module is used for NAS-Bench-201; it represents a small search space with a complete DAG. | ||||
| class NAS201SearchCell(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C_in, | ||||
|         C_out, | ||||
|         stride, | ||||
|         max_nodes, | ||||
|         op_names, | ||||
|         affine=False, | ||||
|         track_running_stats=True, | ||||
|     ): | ||||
|         super(NAS201SearchCell, self).__init__() | ||||
|  | ||||
|         self.op_names = deepcopy(op_names) | ||||
|         self.edges = nn.ModuleDict() | ||||
|         self.max_nodes = max_nodes | ||||
|         self.in_dim = C_in | ||||
|         self.out_dim = C_out | ||||
|         for i in range(1, max_nodes): | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 if j == 0: | ||||
|                     xlists = [ | ||||
|                         OPS[op_name](C_in, C_out, stride, affine, track_running_stats) | ||||
|                         for op_name in op_names | ||||
|                     ] | ||||
|                 else: | ||||
|                     xlists = [ | ||||
|                         OPS[op_name](C_in, C_out, 1, affine, track_running_stats) | ||||
|                         for op_name in op_names | ||||
|                     ] | ||||
|                 self.edges[node_str] = nn.ModuleList(xlists) | ||||
|         self.edge_keys = sorted(list(self.edges.keys())) | ||||
|         self.edge2index = {key: i for i, key in enumerate(self.edge_keys)} | ||||
|         self.num_edges = len(self.edges) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format( | ||||
|             **self.__dict__ | ||||
|         ) | ||||
|         return string | ||||
|  | ||||
|     def forward(self, inputs, weightss): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             inter_nodes = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 inter_nodes.append( | ||||
|                     sum( | ||||
|                         layer(nodes[j]) * w | ||||
|                         for layer, w in zip(self.edges[node_str], weights) | ||||
|                     ) | ||||
|                 ) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
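|     # Note: the forward above is the DARTS-style continuous relaxation: every edge | ||||
|     # computes a weighted sum over all candidate operations, sum_k w_k * op_k(x), | ||||
|     # and each node sums its incoming edges, so the supernet output stays | ||||
|     # differentiable with respect to the architecture weights. | ||||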
|  | ||||
|     # GDAS | ||||
|     def forward_gdas(self, inputs, hardwts, index): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             inter_nodes = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 weights = hardwts[self.edge2index[node_str]] | ||||
|                 argmaxs = index[self.edge2index[node_str]].item() | ||||
|                 weigsum = sum( | ||||
|                     weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] | ||||
|                     for _ie, edge in enumerate(self.edges[node_str]) | ||||
|                 ) | ||||
|                 inter_nodes.append(weigsum) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
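|     # Note: with the straight-through one-hot weights from GDAS, only the argmax | ||||
|     # operation on each edge is actually executed; every other term contributes | ||||
|     # weights[_ie] alone (value zero), which preserves the gradient path to all | ||||
|     # architecture weights at the compute cost of one operation per edge. | ||||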
|  | ||||
|     # joint | ||||
|     def forward_joint(self, inputs, weightss): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             inter_nodes = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 # aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() | ||||
|                 aggregation = sum( | ||||
|                     layer(nodes[j]) * w | ||||
|                     for layer, w in zip(self.edges[node_str], weights) | ||||
|                 ) | ||||
|                 inter_nodes.append(aggregation) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|     # uniform random sampling per iteration, SETN | ||||
|     def forward_urs(self, inputs): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             while True:  # re-sample to avoid selecting the zero op on every incoming edge | ||||
|                 sops, has_non_zero = [], False | ||||
|                 for j in range(i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     candidates = self.edges[node_str] | ||||
|                     select_op = random.choice(candidates) | ||||
|                     sops.append(select_op) | ||||
|                     if not hasattr(select_op, "is_zero") or select_op.is_zero is False: | ||||
|                         has_non_zero = True | ||||
|                 if has_non_zero: | ||||
|                     break | ||||
|             inter_nodes = [] | ||||
|             for j, select_op in enumerate(sops): | ||||
|                 inter_nodes.append(select_op(nodes[j])) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|     # select the argmax | ||||
|     def forward_select(self, inputs, weightss): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             inter_nodes = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 inter_nodes.append( | ||||
|                     self.edges[node_str][weights.argmax().item()](nodes[j]) | ||||
|                 ) | ||||
|                 # inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|     # forward with a specific structure | ||||
|     def forward_dynamic(self, inputs, structure): | ||||
|         nodes = [inputs] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             cur_op_node = structure.nodes[i - 1] | ||||
|             inter_nodes = [] | ||||
|             for op_name, j in cur_op_node: | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_index = self.op_names.index(op_name) | ||||
|                 inter_nodes.append(self.edges[node_str][op_index](nodes[j])) | ||||
|             nodes.append(sum(inter_nodes)) | ||||
|         return nodes[-1] | ||||
|  | ||||
|  | ||||
| class MixedOp(nn.Module): | ||||
|     def __init__(self, space, C, stride, affine, track_running_stats): | ||||
|         super(MixedOp, self).__init__() | ||||
|         self._ops = nn.ModuleList() | ||||
|         for primitive in space: | ||||
|             op = OPS[primitive](C, C, stride, affine, track_running_stats) | ||||
|             self._ops.append(op) | ||||
|  | ||||
|     def forward_gdas(self, x, weights, index): | ||||
|         return self._ops[index](x) * weights[index] | ||||
|  | ||||
|     def forward_darts(self, x, weights): | ||||
|         return sum(w * op(x) for w, op in zip(weights, self._ops)) | ||||
|  | ||||
|  | ||||
| # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 | ||||
| class NASNetSearchCell(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         space, | ||||
|         steps, | ||||
|         multiplier, | ||||
|         C_prev_prev, | ||||
|         C_prev, | ||||
|         C, | ||||
|         reduction, | ||||
|         reduction_prev, | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ): | ||||
|         super(NASNetSearchCell, self).__init__() | ||||
|         self.reduction = reduction | ||||
|         self.op_names = deepcopy(space) | ||||
|         if reduction_prev: | ||||
|             self.preprocess0 = OPS["skip_connect"]( | ||||
|                 C_prev_prev, C, 2, affine, track_running_stats | ||||
|             ) | ||||
|         else: | ||||
|             self.preprocess0 = OPS["nor_conv_1x1"]( | ||||
|                 C_prev_prev, C, 1, affine, track_running_stats | ||||
|             ) | ||||
|         self.preprocess1 = OPS["nor_conv_1x1"]( | ||||
|             C_prev, C, 1, affine, track_running_stats | ||||
|         ) | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|  | ||||
|         self._ops = nn.ModuleList() | ||||
|         self.edges = nn.ModuleDict() | ||||
|         for i in range(self._steps): | ||||
|             for j in range(2 + i): | ||||
|                 node_str = "{:}<-{:}".format( | ||||
|                     i, j | ||||
|                 )  # indicate the edge from node-(j) to node-(i+2) | ||||
|                 stride = 2 if reduction and j < 2 else 1 | ||||
|                 op = MixedOp(space, C, stride, affine, track_running_stats) | ||||
|                 self.edges[node_str] = op | ||||
|         self.edge_keys = sorted(list(self.edges.keys())) | ||||
|         self.edge2index = {key: i for i, key in enumerate(self.edge_keys)} | ||||
|         self.num_edges = len(self.edges) | ||||
|  | ||||
|     @property | ||||
|     def multiplier(self): | ||||
|         return self._multiplier | ||||
|  | ||||
|     def forward_gdas(self, s0, s1, weightss, indexs): | ||||
|         s0 = self.preprocess0(s0) | ||||
|         s1 = self.preprocess1(s1) | ||||
|  | ||||
|         states = [s0, s1] | ||||
|         for i in range(self._steps): | ||||
|             clist = [] | ||||
|             for j, h in enumerate(states): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op = self.edges[node_str] | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 index = indexs[self.edge2index[node_str]].item() | ||||
|                 clist.append(op.forward_gdas(h, weights, index)) | ||||
|             states.append(sum(clist)) | ||||
|  | ||||
|         return torch.cat(states[-self._multiplier :], dim=1) | ||||
|  | ||||
|     def forward_darts(self, s0, s1, weightss): | ||||
|         s0 = self.preprocess0(s0) | ||||
|         s1 = self.preprocess1(s1) | ||||
|  | ||||
|         states = [s0, s1] | ||||
|         for i in range(self._steps): | ||||
|             clist = [] | ||||
|             for j, h in enumerate(states): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op = self.edges[node_str] | ||||
|                 weights = weightss[self.edge2index[node_str]] | ||||
|                 clist.append(op.forward_darts(h, weights)) | ||||
|             states.append(sum(clist)) | ||||
|  | ||||
|         return torch.cat(states[-self._multiplier :], dim=1) | ||||
| @@ -1,122 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ######################################################## | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 # | ||||
| ######################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class TinyNetworkDarts(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkDarts, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             return "arch-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self.op_names[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         alphas = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 feature = cell(feature, alphas) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
| @@ -1,178 +0,0 @@ | ||||
| #################### | ||||
| # DARTS, ICLR 2019 # | ||||
| #################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import List, Text, Dict | ||||
| from .search_cells import NASNetSearchCell as SearchCell | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkDARTS(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C: int, | ||||
|         N: int, | ||||
|         steps: int, | ||||
|         multiplier: int, | ||||
|         stem_multiplier: int, | ||||
|         num_classes: int, | ||||
|         search_space: List[Text], | ||||
|         affine: bool, | ||||
|         track_running_stats: bool, | ||||
|     ): | ||||
|         super(NASNetworkDARTS, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         num_edge, edge2index = None, None | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|  | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             cell = SearchCell( | ||||
|                 search_space, | ||||
|                 steps, | ||||
|                 multiplier, | ||||
|                 C_prev_prev, | ||||
|                 C_prev, | ||||
|                 C_curr, | ||||
|                 reduction, | ||||
|                 reduction_prev, | ||||
|                 affine, | ||||
|                 track_running_stats, | ||||
|             ) | ||||
|             if num_edge is None: | ||||
|                 num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|             else: | ||||
|                 assert ( | ||||
|                     num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                 ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_normal_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.arch_reduce_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|  | ||||
|     def get_weights(self) -> List[torch.nn.Parameter]: | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def get_alphas(self) -> List[torch.nn.Parameter]: | ||||
|         return [self.arch_normal_parameters, self.arch_reduce_parameters] | ||||
|  | ||||
|     def show_alphas(self) -> Text: | ||||
|         with torch.no_grad(): | ||||
|             A = "arch-normal-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|             B = "arch-reduce-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|         return "{:}\n{:}".format(A, B) | ||||
|  | ||||
|     def get_message(self) -> Text: | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self) -> Text: | ||||
|         return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self) -> Dict[Text, List]: | ||||
|         def _parse(weights): | ||||
|             gene = [] | ||||
|             for i in range(self._steps): | ||||
|                 edges = [] | ||||
|                 for j in range(2 + i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     ws = weights[self.edge2index[node_str]] | ||||
|                     for k, op_name in enumerate(self.op_names): | ||||
|                         if op_name == "none": | ||||
|                             continue | ||||
|                         edges.append((op_name, j, ws[k])) | ||||
|                 # (TODO) xuanyidong: | ||||
|                 # Here the two selected edges might come from the same input node. | ||||
|                 # This case can be a problem, since the two edges would collapse into a | ||||
|                 # single one under our assumption that at most one edge comes from each | ||||
|                 # input node during evaluation. | ||||
|                 edges = sorted(edges, key=lambda x: -x[-1]) | ||||
|                 selected_edges = edges[:2] | ||||
|                 gene.append(tuple(selected_edges)) | ||||
|             return gene | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             gene_normal = _parse( | ||||
|                 torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|             gene_reduce = _parse( | ||||
|                 torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|         return { | ||||
|             "normal": gene_normal, | ||||
|             "normal_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|             "reduce": gene_reduce, | ||||
|             "reduce_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|         } | ||||
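|     # Example (illustrative): with steps=4 and multiplier=4, both concat lists are | ||||
|     # list(range(2, 6)), i.e., the outputs of the four intermediate nodes are | ||||
|     # concatenated along the channel dimension to form the cell output. | ||||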
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1) | ||||
|         reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1) | ||||
|  | ||||
|         s0 = s1 = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if cell.reduction: | ||||
|                 ww = reduce_w | ||||
|             else: | ||||
|                 ww = normal_w | ||||
|             s0, s1 = s1, cell.forward_darts(s0, s1, ww) | ||||
|         out = self.lastact(s1) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
| @@ -1,114 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ########################################################################## | ||||
| # Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # | ||||
| ########################################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
| from .search_model_enas_utils import Controller | ||||
|  | ||||
|  | ||||
| class TinyNetworkENAS(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkENAS, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         # to maintain the sampled architecture | ||||
|         self.sampled_arch = None | ||||
|  | ||||
|     def update_arch(self, _arch): | ||||
|         if _arch is None: | ||||
|             self.sampled_arch = None | ||||
|         elif isinstance(_arch, Structure): | ||||
|             self.sampled_arch = _arch | ||||
|         elif isinstance(_arch, (list, tuple)): | ||||
|             genotypes = [] | ||||
|             for i in range(1, self.max_nodes): | ||||
|                 xlist = [] | ||||
|                 for j in range(i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     op_index = _arch[self.edge2index[node_str]] | ||||
|                     op_name = self.op_names[op_index] | ||||
|                     xlist.append((op_name, j)) | ||||
|                 genotypes.append(tuple(xlist)) | ||||
|             self.sampled_arch = Structure(genotypes) | ||||
|         else: | ||||
|             raise ValueError("invalid type of input architecture : {:}".format(_arch)) | ||||
|         return self.sampled_arch | ||||
|  | ||||
|     def create_controller(self): | ||||
|         return Controller(len(self.edge2index), len(self.op_names)) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 feature = cell.forward_dynamic(feature, self.sampled_arch) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
| @@ -1,74 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ########################################################################## | ||||
| # Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # | ||||
| ########################################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.distributions.categorical import Categorical | ||||
|  | ||||
|  | ||||
| class Controller(nn.Module): | ||||
|     # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_edge, | ||||
|         num_ops, | ||||
|         lstm_size=32, | ||||
|         lstm_num_layers=2, | ||||
|         tanh_constant=2.5, | ||||
|         temperature=5.0, | ||||
|     ): | ||||
|         super(Controller, self).__init__() | ||||
|         # assign the attributes | ||||
|         self.num_edge = num_edge | ||||
|         self.num_ops = num_ops | ||||
|         self.lstm_size = lstm_size | ||||
|         self.lstm_N = lstm_num_layers | ||||
|         self.tanh_constant = tanh_constant | ||||
|         self.temperature = temperature | ||||
|         # create parameters | ||||
|         self.register_parameter( | ||||
|             "input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size)) | ||||
|         ) | ||||
|         self.w_lstm = nn.LSTM( | ||||
|             input_size=self.lstm_size, | ||||
|             hidden_size=self.lstm_size, | ||||
|             num_layers=self.lstm_N, | ||||
|         ) | ||||
|         self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) | ||||
|         self.w_pred = nn.Linear(self.lstm_size, self.num_ops) | ||||
|  | ||||
|         nn.init.uniform_(self.input_vars, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_embd.weight, -0.1, 0.1) | ||||
|         nn.init.uniform_(self.w_pred.weight, -0.1, 0.1) | ||||
|  | ||||
|     def forward(self): | ||||
|  | ||||
|         inputs, h0 = self.input_vars, None | ||||
|         log_probs, entropys, sampled_arch = [], [], [] | ||||
|         for iedge in range(self.num_edge): | ||||
|             outputs, h0 = self.w_lstm(inputs, h0) | ||||
|  | ||||
|             logits = self.w_pred(outputs) | ||||
|             logits = logits / self.temperature | ||||
|             logits = self.tanh_constant * torch.tanh(logits) | ||||
|             # distribution | ||||
|             op_distribution = Categorical(logits=logits) | ||||
|             op_index = op_distribution.sample() | ||||
|             sampled_arch.append(op_index.item()) | ||||
|  | ||||
|             op_log_prob = op_distribution.log_prob(op_index) | ||||
|             log_probs.append(op_log_prob.view(-1)) | ||||
|             op_entropy = op_distribution.entropy() | ||||
|             entropys.append(op_entropy.view(-1)) | ||||
|  | ||||
|             # obtain the input embedding for the next step | ||||
|             inputs = self.w_embd(op_index) | ||||
|         return ( | ||||
|             torch.sum(torch.cat(log_probs)), | ||||
|             torch.sum(torch.cat(entropys)), | ||||
|             sampled_arch, | ||||
|         ) | ||||
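|     # Usage sketch (illustrative), assuming a NAS-Bench-201 cell with 6 edges and 5 ops: | ||||
|     #   controller = Controller(num_edge=6, num_ops=5) | ||||
|     #   log_prob, entropy, sampled_arch = controller() | ||||
|     #   # sampled_arch is a list of 6 op indices; log_prob and entropy are the | ||||
|     #   # scalar tensors used in the REINFORCE update of the controller. | ||||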
| @@ -1,142 +0,0 @@ | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class TinyNetworkGDAS(nn.Module): | ||||
|  | ||||
|     # def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkGDAS, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.tau = 10 | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_tau(self): | ||||
|         return self.tau | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             return "arch-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self.op_names[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         while True: | ||||
|             gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() | ||||
|             logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau | ||||
|             probs = nn.functional.softmax(logits, dim=1) | ||||
|             index = probs.max(-1, keepdim=True)[1] | ||||
|             one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|             hardwts = one_h - probs.detach() + probs | ||||
|             if ( | ||||
|                 (torch.isinf(gumbels).any()) | ||||
|                 or (torch.isinf(probs).any()) | ||||
|                 or (torch.isnan(probs).any()) | ||||
|             ): | ||||
|                 continue | ||||
|             else: | ||||
|                 break | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 feature = cell.forward_gdas(feature, hardwts, index) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
| @@ -1,199 +0,0 @@ | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from models.cell_searchs.search_cells import NASNetSearchCell as SearchCell | ||||
| from models.cell_operations import RAW_OP_CLASSES | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkGDAS_FRC(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C, | ||||
|         N, | ||||
|         steps, | ||||
|         multiplier, | ||||
|         stem_multiplier, | ||||
|         num_classes, | ||||
|         search_space, | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ): | ||||
|         super(NASNetworkGDAS_FRC, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         num_edge, edge2index = None, None | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|  | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = RAW_OP_CLASSES["gdas_reduction"]( | ||||
|                     C_prev_prev, | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     reduction_prev, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     search_space, | ||||
|                     steps, | ||||
|                     multiplier, | ||||
|                     C_prev_prev, | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     reduction, | ||||
|                     reduction_prev, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|             if num_edge is None: | ||||
|                 num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|             else: | ||||
|                 # Reduction cells have their own fixed topology, so the edge | ||||
|                 # bookkeeping is only checked for normal (searchable) cells. | ||||
|                 assert reduction or ( | ||||
|                     num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                 ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = ( | ||||
|                 C_prev, | ||||
|                 cell.multiplier * C_curr, | ||||
|                 reduction, | ||||
|             ) | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.tau = 10 | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_tau(self): | ||||
|         return self.tau | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             A = "arch-normal-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|         return "{:}".format(A) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         def _parse(weights): | ||||
|             gene = [] | ||||
|             for i in range(self._steps): | ||||
|                 edges = [] | ||||
|                 for j in range(2 + i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     ws = weights[self.edge2index[node_str]] | ||||
|                     for k, op_name in enumerate(self.op_names): | ||||
|                         if op_name == "none": | ||||
|                             continue | ||||
|                         edges.append((op_name, j, ws[k])) | ||||
|                 edges = sorted(edges, key=lambda x: -x[-1]) | ||||
|                 selected_edges = edges[:2] | ||||
|                 gene.append(tuple(selected_edges)) | ||||
|             return gene | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             gene_normal = _parse( | ||||
|                 torch.softmax(self.arch_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|         return { | ||||
|             "normal": gene_normal, | ||||
|             "normal_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|         } | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         def get_gumbel_prob(xins): | ||||
|             while True: | ||||
|                 gumbels = -torch.empty_like(xins).exponential_().log() | ||||
|                 logits = (xins.log_softmax(dim=1) + gumbels) / self.tau | ||||
|                 probs = nn.functional.softmax(logits, dim=1) | ||||
|                 index = probs.max(-1, keepdim=True)[1] | ||||
|                 one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|                 hardwts = one_h - probs.detach() + probs | ||||
|                 if ( | ||||
|                     (torch.isinf(gumbels).any()) | ||||
|                     or (torch.isinf(probs).any()) | ||||
|                     or (torch.isnan(probs).any()) | ||||
|                 ): | ||||
|                     continue | ||||
|                 else: | ||||
|                     break | ||||
|             return hardwts, index | ||||
|  | ||||
|         hardwts, index = get_gumbel_prob(self.arch_parameters) | ||||
|  | ||||
|         s0 = s1 = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if cell.reduction: | ||||
|                 s0, s1 = s1, cell(s0, s1) | ||||
|             else: | ||||
|                 s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) | ||||
|         out = self.lastact(s1) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
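|  | ||||
| `set_tau` is meant to be driven by an external annealing schedule: GDAS starts | ||||
| from a high temperature (initialized to 10 above) and decays it as the search | ||||
| proceeds, so the relaxation hardens over time. A sketch of a linear schedule, | ||||
| assuming an already-constructed `network` and an illustrative epoch budget: | ||||
|  | ||||
| # Linear temperature decay from tau_max to tau_min over the search budget. | ||||
| tau_max, tau_min, total_epochs = 10.0, 0.1, 250 | ||||
| for epoch in range(total_epochs): | ||||
|     tau = tau_max - (tau_max - tau_min) * epoch / (total_epochs - 1) | ||||
|     network.set_tau(tau)  # one epoch of weight / architecture updates follows | ||||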
| @@ -1,197 +0,0 @@ | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from .search_cells import NASNetSearchCell as SearchCell | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkGDAS(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C, | ||||
|         N, | ||||
|         steps, | ||||
|         multiplier, | ||||
|         stem_multiplier, | ||||
|         num_classes, | ||||
|         search_space, | ||||
|         affine, | ||||
|         track_running_stats, | ||||
|     ): | ||||
|         super(NASNetworkGDAS, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         num_edge, edge2index = None, None | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|  | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             cell = SearchCell( | ||||
|                 search_space, | ||||
|                 steps, | ||||
|                 multiplier, | ||||
|                 C_prev_prev, | ||||
|                 C_prev, | ||||
|                 C_curr, | ||||
|                 reduction, | ||||
|                 reduction_prev, | ||||
|                 affine, | ||||
|                 track_running_stats, | ||||
|             ) | ||||
|             if num_edge is None: | ||||
|                 num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|             else: | ||||
|                 assert ( | ||||
|                     num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                 ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_normal_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.arch_reduce_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.tau = 10 | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_tau(self): | ||||
|         return self.tau | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_normal_parameters, self.arch_reduce_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             A = "arch-normal-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|             B = "arch-reduce-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|         return "{:}\n{:}".format(A, B) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         def _parse(weights): | ||||
|             gene = [] | ||||
|             for i in range(self._steps): | ||||
|                 edges = [] | ||||
|                 for j in range(2 + i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     ws = weights[self.edge2index[node_str]] | ||||
|                     for k, op_name in enumerate(self.op_names): | ||||
|                         if op_name == "none": | ||||
|                             continue | ||||
|                         edges.append((op_name, j, ws[k])) | ||||
|                 edges = sorted(edges, key=lambda x: -x[-1]) | ||||
|                 selected_edges = edges[:2] | ||||
|                 gene.append(tuple(selected_edges)) | ||||
|             return gene | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             gene_normal = _parse( | ||||
|                 torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|             gene_reduce = _parse( | ||||
|                 torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|         return { | ||||
|             "normal": gene_normal, | ||||
|             "normal_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|             "reduce": gene_reduce, | ||||
|             "reduce_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|         } | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         def get_gumbel_prob(xins): | ||||
|             while True: | ||||
|                 gumbels = -torch.empty_like(xins).exponential_().log() | ||||
|                 logits = (xins.log_softmax(dim=1) + gumbels) / self.tau | ||||
|                 probs = nn.functional.softmax(logits, dim=1) | ||||
|                 index = probs.max(-1, keepdim=True)[1] | ||||
|                 one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|                 hardwts = one_h - probs.detach() + probs | ||||
|                 if ( | ||||
|                     (torch.isinf(gumbels).any()) | ||||
|                     or (torch.isinf(probs).any()) | ||||
|                     or (torch.isnan(probs).any()) | ||||
|                 ): | ||||
|                     continue | ||||
|                 else: | ||||
|                     break | ||||
|             return hardwts, index | ||||
|  | ||||
|         normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters) | ||||
|         reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters) | ||||
|  | ||||
|         s0 = s1 = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if cell.reduction: | ||||
|                 hardwts, index = reduce_hardwts, reduce_index | ||||
|             else: | ||||
|                 hardwts, index = normal_hardwts, normal_index | ||||
|             s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) | ||||
|         out = self.lastact(s1) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
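|  | ||||
| The `_parse` helper in `genotype` decodes a DARTS-style cell: for each | ||||
| intermediate node it ranks every incoming (op, edge) candidate by its softmax | ||||
| weight, skips `none`, and keeps the two strongest. A self-contained sketch of | ||||
| that rule on toy data (the op names and the tiny two-step cell are | ||||
| illustrative): | ||||
|  | ||||
| import torch | ||||
|  | ||||
| ops = ["none", "skip_connect", "conv_3x3"]  # illustrative op names | ||||
| steps = 2 | ||||
| edge2index = {"0<-0": 0, "0<-1": 1, "1<-0": 2, "1<-1": 3, "1<-2": 4} | ||||
| weights = torch.softmax(torch.randn(5, len(ops)), dim=-1).numpy() | ||||
|  | ||||
| gene = [] | ||||
| for i in range(steps):  # node i receives from 2 + i predecessors | ||||
|     edges = [] | ||||
|     for j in range(2 + i): | ||||
|         ws = weights[edge2index["{:}<-{:}".format(i, j)]] | ||||
|         for k, op in enumerate(ops): | ||||
|             if op == "none":  # "none" is never allowed to win a slot | ||||
|                 continue | ||||
|             edges.append((op, j, ws[k])) | ||||
|     gene.append(tuple(sorted(edges, key=lambda x: -x[-1])[:2])) | ||||
| print(gene) | ||||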
| @@ -1,102 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ############################################################################## | ||||
| # Random Search and Reproducibility for Neural Architecture Search, UAI 2019 # | ||||
| ############################################################################## | ||||
| import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class TinyNetworkRANDOM(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkRANDOM, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_cache = None | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def random_genotype(self, set_cache): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_name = random.choice(self.op_names) | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         arch = Structure(genotypes) | ||||
|         if set_cache: | ||||
|             self.arch_cache = arch | ||||
|         return arch | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 feature = cell.forward_dynamic(feature, self.arch_cache) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|         return out, logits | ||||
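|  | ||||
| In use, `random_genotype(set_cache=True)` both samples a uniformly random | ||||
| architecture and caches it, so the next `forward` call evaluates exactly that | ||||
| candidate. A sketch of the surrounding random-search loop, assuming `network` | ||||
| is an instance of the model above and `evaluate` stands in for a | ||||
| validation-accuracy measurement: | ||||
|  | ||||
| best_acc, best_arch = -1.0, None | ||||
| for _ in range(100):  # search budget (illustrative) | ||||
|     arch = network.random_genotype(set_cache=True)  # sample + cache a candidate | ||||
|     acc = evaluate(network)  # forward() runs the cached architecture | ||||
|     if acc > best_acc: | ||||
|         best_acc, best_arch = acc, arch | ||||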
| @@ -1,178 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ###################################################################################### | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # | ||||
| ###################################################################################### | ||||
| import torch, random | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .genotypes import Structure | ||||
|  | ||||
|  | ||||
| class TinyNetworkSETN(nn.Module): | ||||
|     def __init__( | ||||
|         self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats | ||||
|     ): | ||||
|         super(TinyNetworkSETN, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self.max_nodes = max_nodes | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) | ||||
|         ) | ||||
|  | ||||
|         layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         C_prev, num_edge, edge2index = C, None, None | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|             else: | ||||
|                 cell = SearchCell( | ||||
|                     C_prev, | ||||
|                     C_curr, | ||||
|                     1, | ||||
|                     max_nodes, | ||||
|                     search_space, | ||||
|                     affine, | ||||
|                     track_running_stats, | ||||
|                 ) | ||||
|                 if num_edge is None: | ||||
|                     num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|                 else: | ||||
|                     assert ( | ||||
|                         num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                     ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev = cell.out_dim | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.mode = "urs" | ||||
|         self.dynamic_cell = None | ||||
|  | ||||
|     def set_cal_mode(self, mode, dynamic_cell=None): | ||||
|         assert mode in ["urs", "joint", "select", "dynamic"] | ||||
|         self.mode = mode | ||||
|         if mode == "dynamic": | ||||
|             self.dynamic_cell = deepcopy(dynamic_cell) | ||||
|         else: | ||||
|             self.dynamic_cell = None | ||||
|  | ||||
|     def get_cal_mode(self): | ||||
|         return self.mode | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_parameters] | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self.op_names[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def dync_genotype(self, use_random=False): | ||||
|         genotypes = [] | ||||
|         with torch.no_grad(): | ||||
|             alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 if use_random: | ||||
|                     op_name = random.choice(self.op_names) | ||||
|                 else: | ||||
|                     weights = alphas_cpu[self.edge2index[node_str]] | ||||
|                     op_index = torch.multinomial(weights, 1).item() | ||||
|                     op_name = self.op_names[op_index] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def get_log_prob(self, arch): | ||||
|         with torch.no_grad(): | ||||
|             logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) | ||||
|         select_logits = [] | ||||
|         for i, node_info in enumerate(arch.nodes): | ||||
|             for op, xin in node_info: | ||||
|                 node_str = "{:}<-{:}".format(i + 1, xin) | ||||
|                 op_index = self.op_names.index(op) | ||||
|                 select_logits.append(logits[self.edge2index[node_str], op_index]) | ||||
|         return sum(select_logits).item() | ||||
|  | ||||
|     def return_topK(self, K): | ||||
|         archs = Structure.gen_all(self.op_names, self.max_nodes, False) | ||||
|         pairs = [(self.get_log_prob(arch), arch) for arch in archs] | ||||
|         if K < 0 or K >= len(archs): | ||||
|             K = len(archs) | ||||
|         sorted_pairs = sorted(pairs, key=lambda x: -x[0]) | ||||
|         return_pairs = [pair[1] for pair in sorted_pairs[:K]] | ||||
|         return return_pairs | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         alphas = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         with torch.no_grad(): | ||||
|             alphas_cpu = alphas.detach().cpu() | ||||
|  | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             if isinstance(cell, SearchCell): | ||||
|                 if self.mode == "urs": | ||||
|                     feature = cell.forward_urs(feature) | ||||
|                 elif self.mode == "select": | ||||
|                     feature = cell.forward_select(feature, alphas_cpu) | ||||
|                 elif self.mode == "joint": | ||||
|                     feature = cell.forward_joint(feature, alphas) | ||||
|                 elif self.mode == "dynamic": | ||||
|                     feature = cell.forward_dynamic(feature, self.dynamic_cell) | ||||
|                 else: | ||||
|                     raise ValueError("invalid mode={:}".format(self.mode)) | ||||
|             else: | ||||
|                 feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
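|  | ||||
| `get_log_prob` treats the softmax over `arch_parameters` as an independent | ||||
| categorical distribution per edge, so an architecture's log-probability is the | ||||
| sum of the selected entries of the per-edge log-softmax; `return_topK` then | ||||
| ranks all enumerable candidates by this score. A self-contained sketch on a | ||||
| toy three-edge cell (op names and shapes are illustrative): | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| ops = ["none", "skip_connect", "conv_3x3"]  # illustrative op names | ||||
| edge2index = {"1<-0": 0, "2<-0": 1, "2<-1": 2}  # toy cell with three edges | ||||
| arch_parameters = torch.randn(3, len(ops)) | ||||
| logits = nn.functional.log_softmax(arch_parameters, dim=-1) | ||||
|  | ||||
| # One op per edge; the architecture's score sums its edge log-probabilities. | ||||
| arch = {"1<-0": "conv_3x3", "2<-0": "skip_connect", "2<-1": "none"} | ||||
| log_prob = sum(logits[edge2index[e], ops.index(op)] for e, op in arch.items()) | ||||
| print(float(log_prob)) | ||||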
| @@ -1,205 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ###################################################################################### | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # | ||||
| ###################################################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from typing import List, Text, Dict | ||||
| from .search_cells import NASNetSearchCell as SearchCell | ||||
|  | ||||
|  | ||||
| # The macro structure is based on NASNet | ||||
| class NASNetworkSETN(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         C: int, | ||||
|         N: int, | ||||
|         steps: int, | ||||
|         multiplier: int, | ||||
|         stem_multiplier: int, | ||||
|         num_classes: int, | ||||
|         search_space: List[Text], | ||||
|         affine: bool, | ||||
|         track_running_stats: bool, | ||||
|     ): | ||||
|         super(NASNetworkSETN, self).__init__() | ||||
|         self._C = C | ||||
|         self._layerN = N | ||||
|         self._steps = steps | ||||
|         self._multiplier = multiplier | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(C * stem_multiplier), | ||||
|         ) | ||||
|  | ||||
|         # config for each layer | ||||
|         layer_channels = ( | ||||
|             [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1) | ||||
|         ) | ||||
|         layer_reductions = ( | ||||
|             [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1) | ||||
|         ) | ||||
|  | ||||
|         num_edge, edge2index = None, None | ||||
|         C_prev_prev, C_prev, C_curr, reduction_prev = ( | ||||
|             C * stem_multiplier, | ||||
|             C * stem_multiplier, | ||||
|             C, | ||||
|             False, | ||||
|         ) | ||||
|  | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (C_curr, reduction) in enumerate( | ||||
|             zip(layer_channels, layer_reductions) | ||||
|         ): | ||||
|             cell = SearchCell( | ||||
|                 search_space, | ||||
|                 steps, | ||||
|                 multiplier, | ||||
|                 C_prev_prev, | ||||
|                 C_prev, | ||||
|                 C_curr, | ||||
|                 reduction, | ||||
|                 reduction_prev, | ||||
|                 affine, | ||||
|                 track_running_stats, | ||||
|             ) | ||||
|             if num_edge is None: | ||||
|                 num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|             else: | ||||
|                 assert ( | ||||
|                     num_edge == cell.num_edges and edge2index == cell.edge2index | ||||
|                 ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges) | ||||
|             self.cells.append(cell) | ||||
|             C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction | ||||
|         self.op_names = deepcopy(search_space) | ||||
|         self._Layer = len(self.cells) | ||||
|         self.edge2index = edge2index | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(C_prev, num_classes) | ||||
|         self.arch_normal_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.arch_reduce_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(num_edge, len(search_space)) | ||||
|         ) | ||||
|         self.mode = "urs" | ||||
|         self.dynamic_cell = None | ||||
|  | ||||
|     def set_cal_mode(self, mode, dynamic_cell=None): | ||||
|         assert mode in ["urs", "joint", "select", "dynamic"] | ||||
|         self.mode = mode | ||||
|         if mode == "dynamic": | ||||
|             self.dynamic_cell = deepcopy(dynamic_cell) | ||||
|         else: | ||||
|             self.dynamic_cell = None | ||||
|  | ||||
|     def get_weights(self): | ||||
|         xlist = list(self.stem.parameters()) + list(self.cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) + list( | ||||
|             self.global_pooling.parameters() | ||||
|         ) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     def get_alphas(self): | ||||
|         return [self.arch_normal_parameters, self.arch_reduce_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             A = "arch-normal-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|             B = "arch-reduce-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|         return "{:}\n{:}".format(A, B) | ||||
|  | ||||
|     def get_message(self): | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def dync_genotype(self, use_random=False): | ||||
|         # [TODO] this mirrors the NAS-Bench-201 SETN variant above, but it | ||||
|         # references attributes this NASNet-space model never defines | ||||
|         # (self.arch_parameters, self.max_nodes) and names not imported here | ||||
|         # (random, Structure), so it would crash if called. The body is kept | ||||
|         # below as a sketch pending a NASNet-specific rewrite. | ||||
|         raise NotImplementedError | ||||
|         genotypes = [] | ||||
|         with torch.no_grad(): | ||||
|             alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 if use_random: | ||||
|                     op_name = random.choice(self.op_names) | ||||
|                 else: | ||||
|                     weights = alphas_cpu[self.edge2index[node_str]] | ||||
|                     op_index = torch.multinomial(weights, 1).item() | ||||
|                     op_name = self.op_names[op_index] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return Structure(genotypes) | ||||
|  | ||||
|     def genotype(self): | ||||
|         def _parse(weights): | ||||
|             gene = [] | ||||
|             for i in range(self._steps): | ||||
|                 edges = [] | ||||
|                 for j in range(2 + i): | ||||
|                     node_str = "{:}<-{:}".format(i, j) | ||||
|                     ws = weights[self.edge2index[node_str]] | ||||
|                     for k, op_name in enumerate(self.op_names): | ||||
|                         if op_name == "none": | ||||
|                             continue | ||||
|                         edges.append((op_name, j, ws[k])) | ||||
|                 edges = sorted(edges, key=lambda x: -x[-1]) | ||||
|                 selected_edges = edges[:2] | ||||
|                 gene.append(tuple(selected_edges)) | ||||
|             return gene | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             gene_normal = _parse( | ||||
|                 torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|             gene_reduce = _parse( | ||||
|                 torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy() | ||||
|             ) | ||||
|         return { | ||||
|             "normal": gene_normal, | ||||
|             "normal_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|             "reduce": gene_reduce, | ||||
|             "reduce_concat": list( | ||||
|                 range(2 + self._steps - self._multiplier, self._steps + 2) | ||||
|             ), | ||||
|         } | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1) | ||||
|         reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1) | ||||
|  | ||||
|         s0 = s1 = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             # [TODO] the SETN forward on the NASNet space is unfinished: the | ||||
|             # dispatch below is dead code (normal_index / reduce_index are never | ||||
|             # defined) and borrows the GDAS-style cell call only as a sketch. | ||||
|             raise NotImplementedError | ||||
|             if cell.reduction: | ||||
|                 hardwts, index = reduce_hardwts, reduce_index | ||||
|             else: | ||||
|                 hardwts, index = normal_hardwts, normal_index | ||||
|             s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index) | ||||
|         out = self.lastact(s1) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
| @@ -1,74 +0,0 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| def copy_conv(module, init): | ||||
|     assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module) | ||||
|     assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init) | ||||
|     new_i, new_o = module.in_channels, module.out_channels | ||||
|     module.weight.copy_(init.weight.detach()[:new_o, :new_i]) | ||||
|     if module.bias is not None: | ||||
|         module.bias.copy_(init.bias.detach()[:new_o]) | ||||
|  | ||||
|  | ||||
| def copy_bn(module, init): | ||||
|     assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module) | ||||
|     assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init) | ||||
|     num_features = module.num_features | ||||
|     if module.weight is not None: | ||||
|         module.weight.copy_(init.weight.detach()[:num_features]) | ||||
|     if module.bias is not None: | ||||
|         module.bias.copy_(init.bias.detach()[:num_features]) | ||||
|     if module.running_mean is not None: | ||||
|         module.running_mean.copy_(init.running_mean.detach()[:num_features]) | ||||
|     if module.running_var is not None: | ||||
|         module.running_var.copy_(init.running_var.detach()[:num_features]) | ||||
|  | ||||
|  | ||||
| def copy_fc(module, init): | ||||
|     assert isinstance(module, nn.Linear), "invalid module : {:}".format(module) | ||||
|     assert isinstance(init, nn.Linear), "invalid module : {:}".format(init) | ||||
|     new_i, new_o = module.in_features, module.out_features | ||||
|     module.weight.copy_(init.weight.detach()[:new_o, :new_i]) | ||||
|     if module.bias is not None: | ||||
|         module.bias.copy_(init.bias.detach()[:new_o]) | ||||
|  | ||||
|  | ||||
| def copy_base(module, init): | ||||
|     assert type(module).__name__ in [ | ||||
|         "ConvBNReLU", | ||||
|         "Downsample", | ||||
|     ], "invalid module : {:}".format(module) | ||||
|     assert type(init).__name__ in [ | ||||
|         "ConvBNReLU", | ||||
|         "Downsample", | ||||
|     ], "invalid module : {:}".format(init) | ||||
|     if module.conv is not None: | ||||
|         copy_conv(module.conv, init.conv) | ||||
|     if module.bn is not None: | ||||
|         copy_bn(module.bn, init.bn) | ||||
|  | ||||
|  | ||||
| def copy_basic(module, init): | ||||
|     copy_base(module.conv_a, init.conv_a) | ||||
|     copy_base(module.conv_b, init.conv_b) | ||||
|     if module.downsample is not None: | ||||
|         if init.downsample is not None: | ||||
|             copy_base(module.downsample, init.downsample) | ||||
|  | ||||
|  | ||||
| def init_from_model(network, init_model): | ||||
|     with torch.no_grad(): | ||||
|         copy_fc(network.classifier, init_model.classifier) | ||||
|         for base, target in zip(init_model.layers, network.layers): | ||||
|             assert ( | ||||
|                 type(base).__name__ == type(target).__name__ | ||||
|             ), "invalid type : {:} vs {:}".format(base, target) | ||||
|             if type(base).__name__ == "ConvBNReLU": | ||||
|                 copy_base(target, base) | ||||
|             elif type(base).__name__ == "ResNetBasicblock": | ||||
|                 copy_basic(target, base) | ||||
|             else: | ||||
|                 raise ValueError("unknown type name : {:}".format(type(base).__name__)) | ||||
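|  | ||||
| These helpers shrink a wider pretrained module into a narrower one by copying | ||||
| the leading slice of each tensor, which is how a pruned network inherits its | ||||
| initialization. A minimal usage sketch (the layer shapes are illustrative; the | ||||
| in-place `copy_` calls need `torch.no_grad`, just as `init_from_model` uses): | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| wide = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # pretrained, wider layer | ||||
| narrow = nn.Conv2d(8, 24, kernel_size=3, padding=1)  # pruned, narrower layer | ||||
| with torch.no_grad(): | ||||
|     copy_conv(narrow, wide)  # copies wide.weight[:24, :8] and wide.bias[:24] | ||||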
| @@ -1,16 +0,0 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| def initialize_resnet(m): | ||||
|     if isinstance(m, nn.Conv2d): | ||||
|         nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") | ||||
|         if m.bias is not None: | ||||
|             nn.init.constant_(m.bias, 0) | ||||
|     elif isinstance(m, nn.BatchNorm2d): | ||||
|         nn.init.constant_(m.weight, 1) | ||||
|         if m.bias is not None: | ||||
|             nn.init.constant_(m.bias, 0) | ||||
|     elif isinstance(m, nn.Linear): | ||||
|         nn.init.normal_(m.weight, 0, 0.01) | ||||
|         nn.init.constant_(m.bias, 0) | ||||
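|  | ||||
| The function is written for `Module.apply`, which visits every submodule | ||||
| recursively. A small usage sketch on a toy model (the model is illustrative): | ||||
|  | ||||
| import torch.nn as nn | ||||
|  | ||||
| model = nn.Sequential( | ||||
|     nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False), | ||||
|     nn.BatchNorm2d(16), | ||||
|     nn.ReLU(inplace=True), | ||||
|     nn.Flatten(), | ||||
|     nn.Linear(16 * 32 * 32, 10), | ||||
| ) | ||||
| model.apply(initialize_resnet)  # He-init convs, unit BN scale, small-std FC | ||||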
| @@ -1,286 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         out = self.avg(inputs) if self.avg is not None else inputs | ||||
|         out = self.conv(out) | ||||
|         if self.bn is not None: | ||||
|             out = self.bn(out) | ||||
|         if self.relu is not None: | ||||
|             out = self.relu(out) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) | ||||
|  | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             iCs[0], | ||||
|             iCs[1], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[2] | ||||
|         elif iCs[0] != iCs[2]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim  = max(residual_in, iCs[2]) | ||||
|         self.out_dim = iCs[2] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             iCs[1], | ||||
|             iCs[2], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         elif iCs[0] != iCs[3]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim = max(residual_in, iCs[3]) | ||||
|         self.out_dim = iCs[3] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|  | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferCifarResNet(nn.Module): | ||||
|     def __init__( | ||||
|         self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual | ||||
|     ): | ||||
|         super(InferCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|         assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks) | ||||
|  | ||||
|         self.message = ( | ||||
|             "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.xchannels = xchannels | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     xchannels[0], | ||||
|                     xchannels[1], | ||||
|                     3, | ||||
|                     1, | ||||
|                     1, | ||||
|                     False, | ||||
|                     has_avg=False, | ||||
|                     has_bn=True, | ||||
|                     has_relu=True, | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         last_channel_idx = 1 | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 num_conv = block.num_conv | ||||
|                 iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iCs, stride) | ||||
|                 last_channel_idx += num_conv | ||||
|                 self.xchannels[last_channel_idx] = module.out_dim | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iCs, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 if iL + 1 == xblocks[stage]:  # reach the maximum depth | ||||
|                     out_channel = module.out_dim | ||||
|                     for iiL in range(iL + 1, layer_blocks): | ||||
|                         last_channel_idx += num_conv | ||||
|                     self.xchannels[last_channel_idx] = module.out_dim | ||||
|                     break | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
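|  | ||||
| As a sanity check, the unpruned ResNet-20 configuration can be rebuilt by | ||||
| hand: (20 - 2) // 6 = 3 basic blocks per stage, and `xchannels` holds one | ||||
| entry per convolution plus the 3-channel input, consumed two at a time by the | ||||
| constructor. A sketch using the standard 16/32/64 CIFAR widths (illustrative, | ||||
| not the output of a pruning run): | ||||
|  | ||||
| import torch | ||||
|  | ||||
| xchannels = [3] + [16] * 7 + [32] * 6 + [64] * 6  # stem + 9 basic blocks | ||||
| net = InferCifarResNet( | ||||
|     "ResNetBasicblock", depth=20, xblocks=[3, 3, 3], | ||||
|     xchannels=xchannels, num_classes=10, zero_init_residual=True, | ||||
| ) | ||||
| features, logits = net(torch.randn(2, 3, 32, 32)) | ||||
| print(logits.shape)  # torch.Size([2, 10]) | ||||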
| @@ -1,263 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         out = self.avg(inputs) if self.avg is not None else inputs | ||||
|         out = self.conv(out) | ||||
|         if self.bn is not None: | ||||
|             out = self.bn(out) | ||||
|         if self.relu is not None: | ||||
|             out = self.relu(out) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|  | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferDepthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual): | ||||
|         super(InferDepthCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be 164 (i.e., (depth - 2) % 9 == 0)" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|         assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks) | ||||
|  | ||||
|         self.message = ( | ||||
|             "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.channels = [16] | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     planes, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 if iL + 1 == xblocks[stage]:  # reach the maximum depth | ||||
|                     break | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(self.channels[-1], num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
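|  | ||||
| # A minimal usage sketch (hypothetical values, not part of the original file): | ||||
| #   net = InferDepthCifarResNet("ResNetBasicblock", 20, [2, 3, 3], 10, True) | ||||
| #   features, logits = net(torch.randn(2, 3, 32, 32)) | ||||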
| @@ -1,277 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) | ||||
|  | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             iCs[0], | ||||
|             iCs[1], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[2] | ||||
|         elif iCs[0] != iCs[2]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim  = max(residual_in, iCs[2]) | ||||
|         self.out_dim = iCs[2] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             iCs[1], | ||||
|             iCs[2], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         elif iCs[0] != iCs[3]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim = max(residual_in, iCs[3]) | ||||
|         self.out_dim = iCs[3] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferWidthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual): | ||||
|         super(InferWidthCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be 164 (i.e., (depth - 2) % 9 == 0)" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.xchannels = xchannels | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     xchannels[0], | ||||
|                     xchannels[1], | ||||
|                     3, | ||||
|                     1, | ||||
|                     1, | ||||
|                     False, | ||||
|                     has_avg=False, | ||||
|                     has_bn=True, | ||||
|                     has_relu=True, | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         last_channel_idx = 1 | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 num_conv = block.num_conv | ||||
|                 iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iCs, stride) | ||||
|                 last_channel_idx += num_conv | ||||
|                 self.xchannels[last_channel_idx] = module.out_dim | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iCs, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
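|  | ||||
| # A minimal usage sketch (hypothetical channel list, not part of the original file); | ||||
| # xchannels is a flat per-conv channel list, e.g., for depth 20 with basic blocks: | ||||
| #   xchannels = [3, 16] + [16] * 6 + [32] * 6 + [64] * 6 | ||||
| #   net = InferWidthCifarResNet("ResNetBasicblock", 20, xchannels, 10, True) | ||||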
| @@ -1,324 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|  | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     num_conv = 2 | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs) | ||||
|  | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             iCs[0], | ||||
|             iCs[1], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[2] | ||||
|         elif iCs[0] != iCs[2]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[2], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim  = max(residual_in, iCs[2]) | ||||
|         self.out_dim = iCs[2] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + basicblock | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, iCs, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         assert isinstance(iCs, tuple) or isinstance( | ||||
|             iCs, list | ||||
|         ), "invalid type of iCs : {:}".format(iCs) | ||||
|         assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             iCs[1], | ||||
|             iCs[2], | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         residual_in = iCs[0] | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         elif iCs[0] != iCs[3]: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 iCs[0], | ||||
|                 iCs[3], | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|             residual_in = iCs[3] | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         # self.out_dim = max(residual_in, iCs[3]) | ||||
|         self.out_dim = iCs[3] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = residual + bottleneck | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class InferImagenetResNet(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         block_name, | ||||
|         layers, | ||||
|         xblocks, | ||||
|         xchannels, | ||||
|         deep_stem, | ||||
|         num_classes, | ||||
|         zero_init_residual, | ||||
|     ): | ||||
|         super(InferImagenetResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "BasicBlock": | ||||
|             block = ResNetBasicblock | ||||
|         elif block_name == "Bottleneck": | ||||
|             block = ResNetBottleneck | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|         assert len(xblocks) == len( | ||||
|             layers | ||||
|         ), "invalid layers : {:} vs xblocks : {:}".format(layers, xblocks) | ||||
|  | ||||
|         self.message = "InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}".format( | ||||
|             sum(layers) * block.num_conv, sum(xblocks) * block.num_conv, xblocks | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.xchannels = xchannels | ||||
|         if not deep_stem: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         xchannels[0], | ||||
|                         xchannels[1], | ||||
|                         7, | ||||
|                         2, | ||||
|                         3, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                     ) | ||||
|                 ] | ||||
|             ) | ||||
|             last_channel_idx = 1 | ||||
|         else: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         xchannels[0], | ||||
|                         xchannels[1], | ||||
|                         3, | ||||
|                         2, | ||||
|                         1, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                     ), | ||||
|                     ConvBNReLU( | ||||
|                         xchannels[1], | ||||
|                         xchannels[2], | ||||
|                         3, | ||||
|                         1, | ||||
|                         1, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                     ), | ||||
|                 ] | ||||
|             ) | ||||
|             last_channel_idx = 2 | ||||
|         self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) | ||||
|         for stage, layer_blocks in enumerate(layers): | ||||
|             for iL in range(layer_blocks): | ||||
|                 num_conv = block.num_conv | ||||
|                 iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1] | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iCs, stride) | ||||
|                 last_channel_idx += num_conv | ||||
|                 self.xchannels[last_channel_idx] = module.out_dim | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iCs, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 if iL + 1 == xblocks[stage]:  # reach the maximum depth | ||||
|                     last_channel_idx += num_conv * (layer_blocks - iL - 1) | ||||
|                     self.xchannels[last_channel_idx] = module.out_dim | ||||
|                     break | ||||
|         assert last_channel_idx + 1 == len(self.xchannels), "{:} vs {:}".format( | ||||
|             last_channel_idx + 1, len(self.xchannels) | ||||
|         ) | ||||
|         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.classifier = nn.Linear(self.xchannels[-1], num_classes) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|             for m in self.modules(): | ||||
|                 if isinstance(m, ResNetBasicblock): | ||||
|                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||
|                 elif isinstance(m, ResNetBottleneck): | ||||
|                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
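|  | ||||
| # A minimal usage sketch (hypothetical ResNet-18-style arguments, not part of the original file): | ||||
| #   xchannels = [3, 64] + [64] * 4 + [128] * 4 + [256] * 4 + [512] * 4 | ||||
| #   net = InferImagenetResNet("BasicBlock", [2, 2, 2, 2], [2, 2, 2, 2], xchannels, | ||||
| #                             False, 1000, True) | ||||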
| @@ -1,174 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| # MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018 | ||||
| from torch import nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import parse_channel_info | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_planes, | ||||
|         out_planes, | ||||
|         kernel_size, | ||||
|         stride, | ||||
|         groups, | ||||
|         has_bn=True, | ||||
|         has_relu=True, | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         padding = (kernel_size - 1) // 2 | ||||
|         self.conv = nn.Conv2d( | ||||
|             in_planes, | ||||
|             out_planes, | ||||
|             kernel_size, | ||||
|             stride, | ||||
|             padding, | ||||
|             groups=groups, | ||||
|             bias=False, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(out_planes) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU6(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv(x) | ||||
|         if self.bn: | ||||
|             out = self.bn(out) | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class InvertedResidual(nn.Module): | ||||
|     def __init__(self, channels, stride, expand_ratio, additive): | ||||
|         super(InvertedResidual, self).__init__() | ||||
|         self.stride = stride | ||||
|         assert stride in [1, 2], "invalid stride : {:}".format(stride) | ||||
|         assert len(channels) in [2, 3], "invalid channels : {:}".format(channels) | ||||
|  | ||||
|         if len(channels) == 2: | ||||
|             layers = [] | ||||
|         else: | ||||
|             layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)] | ||||
|         layers.extend( | ||||
|             [ | ||||
|                 # dw | ||||
|                 ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]), | ||||
|                 # pw-linear | ||||
|                 ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False), | ||||
|             ] | ||||
|         ) | ||||
|         self.conv = nn.Sequential(*layers) | ||||
|         self.additive = additive | ||||
|         if self.additive and channels[0] != channels[-1]: | ||||
|             self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False) | ||||
|         else: | ||||
|             self.shortcut = None | ||||
|         self.out_dim = channels[-1] | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv(x) | ||||
|         # if self.additive: return additive_func(out, x) | ||||
|         if self.shortcut: | ||||
|             return out + self.shortcut(x) | ||||
|         else: | ||||
|             return out | ||||
|  | ||||
|  | ||||
| class InferMobileNetV2(nn.Module): | ||||
|     def __init__(self, num_classes, xchannels, xblocks, dropout): | ||||
|         super(InferMobileNetV2, self).__init__() | ||||
|         block = InvertedResidual | ||||
|         inverted_residual_setting = [ | ||||
|             # t: expansion ratio, c: output channels, n: number of blocks, s: stride of the first block | ||||
|             [1, 16, 1, 1], | ||||
|             [6, 24, 2, 2], | ||||
|             [6, 32, 3, 2], | ||||
|             [6, 64, 4, 2], | ||||
|             [6, 96, 3, 1], | ||||
|             [6, 160, 3, 2], | ||||
|             [6, 320, 1, 1], | ||||
|         ] | ||||
|         assert len(inverted_residual_setting) == len( | ||||
|             xblocks | ||||
|         ), "invalid number of layers : {:} vs {:}".format( | ||||
|             len(inverted_residual_setting), len(xblocks) | ||||
|         ) | ||||
|         for block_num, ir_setting in zip(xblocks, inverted_residual_setting): | ||||
|             assert block_num <= ir_setting[2], "{:} vs {:}".format( | ||||
|                 block_num, ir_setting | ||||
|             ) | ||||
|         xchannels = parse_channel_info(xchannels) | ||||
|         # for i, chs in enumerate(xchannels): | ||||
|         #  if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs) | ||||
|         self.xchannels = xchannels | ||||
|         self.message = "InferMobileNetV2 : xblocks={:}".format(xblocks) | ||||
|         # building first layer | ||||
|         features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)] | ||||
|         last_channel_idx = 1 | ||||
|  | ||||
|         # building inverted residual blocks | ||||
|         for stage, (t, c, n, s) in enumerate(inverted_residual_setting): | ||||
|             for i in range(n): | ||||
|                 stride = s if i == 0 else 1 | ||||
|                 additive = i > 0  # only non-first blocks in a stage use the additive shortcut | ||||
|                 module = block(self.xchannels[last_channel_idx], stride, t, additive) | ||||
|                 features.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format( | ||||
|                     stage, | ||||
|                     i, | ||||
|                     n, | ||||
|                     len(features), | ||||
|                     self.xchannels[last_channel_idx], | ||||
|                     stride, | ||||
|                     t, | ||||
|                     c, | ||||
|                 ) | ||||
|                 last_channel_idx += 1 | ||||
|                 if i + 1 == xblocks[stage]: | ||||
|                     last_channel_idx += n - (i + 1) | ||||
|                     self.xchannels[last_channel_idx][0] = module.out_dim | ||||
|                     break | ||||
|         # building last several layers | ||||
|         features.append( | ||||
|             ConvBNReLU( | ||||
|                 self.xchannels[last_channel_idx][0], | ||||
|                 self.xchannels[last_channel_idx][1], | ||||
|                 1, | ||||
|                 1, | ||||
|                 1, | ||||
|             ) | ||||
|         ) | ||||
|         assert last_channel_idx + 2 == len(self.xchannels), "{:} vs {:}".format( | ||||
|             last_channel_idx + 2, len(self.xchannels) | ||||
|         ) | ||||
|         # make it nn.Sequential | ||||
|         self.features = nn.Sequential(*features) | ||||
|  | ||||
|         # building classifier | ||||
|         self.classifier = nn.Sequential( | ||||
|             nn.Dropout(dropout), | ||||
|             nn.Linear(self.xchannels[last_channel_idx][1], num_classes), | ||||
|         ) | ||||
|  | ||||
|         # weight initialization | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         features = self.features(inputs) | ||||
|         vectors = features.mean([2, 3]) | ||||
|         predicts = self.classifier(vectors) | ||||
|         return features, predicts | ||||
| @@ -1,64 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| from typing import List, Text, Any | ||||
| import torch.nn as nn | ||||
| from models.cell_operations import ResNetBasicblock | ||||
| from models.cell_infers.cells import InferCell | ||||
|  | ||||
|  | ||||
| class DynamicShapeTinyNet(nn.Module): | ||||
|     def __init__(self, channels: List[int], genotype: Any, num_classes: int): | ||||
|         super(DynamicShapeTinyNet, self).__init__() | ||||
|         self._channels = channels | ||||
|         if len(channels) % 3 != 2: | ||||
|             raise ValueError("invalid number of layers : {:}".format(len(channels))) | ||||
|         self._num_stage = N = len(channels) // 3 | ||||
|  | ||||
|         self.stem = nn.Sequential( | ||||
|             nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False), | ||||
|             nn.BatchNorm2d(channels[0]), | ||||
|         ) | ||||
|  | ||||
|         # layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         c_prev = channels[0] | ||||
|         self.cells = nn.ModuleList() | ||||
|         for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(c_prev, c_curr, 2, True) | ||||
|             else: | ||||
|                 cell = InferCell(genotype, c_prev, c_curr, 1) | ||||
|             self.cells.append(cell) | ||||
|             c_prev = cell.out_dim | ||||
|         self._num_layer = len(self.cells) | ||||
|  | ||||
|         self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True)) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(c_prev, num_classes) | ||||
|  | ||||
|     def get_message(self) -> Text: | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self.cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(C={_channels}, N={_num_stage}, L={_num_layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         feature = self.stem(inputs) | ||||
|         for i, cell in enumerate(self.cells): | ||||
|             feature = cell(feature) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits | ||||
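|  | ||||
| # A minimal usage sketch (hypothetical arguments, not part of the original file; | ||||
| # `genotype` stands for a cell genotype object from this repo's search space): | ||||
| #   channels = [16] * 5 + [32] + [32] * 5 + [64] + [64] * 5   # len(channels) % 3 == 2 | ||||
| #   net = DynamicShapeTinyNet(channels, genotype, num_classes=10) | ||||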
| @@ -1,9 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| from .InferCifarResNet_width import InferWidthCifarResNet | ||||
| from .InferImagenetResNet import InferImagenetResNet | ||||
| from .InferCifarResNet_depth import InferDepthCifarResNet | ||||
| from .InferCifarResNet import InferCifarResNet | ||||
| from .InferMobileNetV2 import InferMobileNetV2 | ||||
| from .InferTinyCellNet import DynamicShapeTinyNet | ||||
| @@ -1,5 +0,0 @@ | ||||
| def parse_channel_info(xstring): | ||||
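|     # e.g., "3-16 16-24-24" -> [[3, 16], [16, 24, 24]] | ||||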
|     blocks = xstring.split(" ") | ||||
|     blocks = [x.split("-") for x in blocks] | ||||
|     blocks = [[int(_) for _ in x] for x in blocks] | ||||
|     return blocks | ||||
| @@ -1,760 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| from collections import OrderedDict | ||||
| from bisect import bisect_right | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices | ||||
|  | ||||
|  | ||||
| def get_depth_choices(nDepth, return_num): | ||||
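|     # Enumerate the candidate depths for a stage of nDepth blocks, e.g., | ||||
|     # nDepth=5 -> [1, 3, 5] and nDepth=4 -> [1, 3, 4]; when return_num is | ||||
|     # True, only the number of choices is returned. | ||||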
|     if nDepth == 2: | ||||
|         choices = (1, 2) | ||||
|     elif nDepth == 3: | ||||
|         choices = (1, 2, 3) | ||||
|     elif nDepth > 3: | ||||
|         choices = list(range(1, nDepth + 1, 2)) | ||||
|         if choices[-1] < nDepth: | ||||
|             choices.append(nDepth) | ||||
|     else: | ||||
|         raise ValueError("invalid nDepth : {:}".format(nDepth)) | ||||
|     if return_num: | ||||
|         return len(choices) | ||||
|     else: | ||||
|         return choices | ||||
|  | ||||
|  | ||||
| def conv_forward(inputs, conv, choices): | ||||
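|     # Zero-pad the input along channels up to conv.in_channels, run the full-width | ||||
|     # convolution once, and slice the output to each candidate width in `choices`. | ||||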
|     iC = conv.in_channels | ||||
|     fill_size = list(inputs.size()) | ||||
|     fill_size[1] = iC - fill_size[1] | ||||
|     filled = torch.zeros(fill_size, device=inputs.device, dtype=inputs.dtype) | ||||
|     xinputs = torch.cat((inputs, filled), dim=1) | ||||
|     outputs = conv(xinputs) | ||||
|     selecteds = [outputs[:, :oC] for oC in choices] | ||||
|     return selecteds | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_width_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         # if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|         # else       : self.bn  = None | ||||
|         self.has_bn = has_bn | ||||
|         self.BNs = nn.ModuleList() | ||||
|         for i, _out in enumerate(self.choices): | ||||
|             self.BNs.append(nn.BatchNorm2d(_out)) | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, channels, check_range=True, divide=1): | ||||
|         iC, oC = channels | ||||
|         if check_range: | ||||
|             assert ( | ||||
|                 iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|             ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|                 iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|             ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
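|         # iC * oC is factored out of the per-position cost so this expression also | ||||
|         # accepts expected (fractional) channel counts during search | ||||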
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
|  | ||||
|     def get_range(self): | ||||
|         return [self.choices] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, index, prob = tuple_inputs | ||||
|         index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) | ||||
|         probability = torch.squeeze(probability) | ||||
|         assert len(index) == 2, "invalid length : {:}".format(index) | ||||
|         # compute expected flop | ||||
|         # coordinates   = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) | ||||
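|         # expected output width = sum_i prob_i * width_i, which keeps the FLOP | ||||
|         # estimate below differentiable w.r.t. the width probabilities (divide=1e6 -> MFLOPs) | ||||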
|         expected_outC = (self.choices_tensor * probability).sum() | ||||
|         expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         # convolutional layer | ||||
|         out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) | ||||
|         out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] | ||||
|         # merge: interpolate both sampled-width outputs to the larger channel count, | ||||
|         # then mix them with their selection probabilities | ||||
|         out_channel = max([x.size(1) for x in out_bns]) | ||||
|         outA = ChannelWiseInter(out_bns[0], out_channel) | ||||
|         outB = ChannelWiseInter(out_bns[1], out_channel) | ||||
|         out = outA * prob[0] + outB * prob[1] | ||||
|         # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) | ||||
|  | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         return out, expected_outC, expected_flop | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.has_bn: | ||||
|             out = self.BNs[-1](conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 2 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return self.conv_a.get_range() + self.conv_b.get_range() | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 3, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_a.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_b.get_flops([channels[1], channels[2]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this short-cut will be added during the infer-train | ||||
|             flop_C = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_b.OutShape[0] | ||||
|                 * self.conv_b.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 | ||||
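|         # each ConvBNReLU consumes (x, expected input width, width probabilities, | ||||
|         # sampled width indexes, their probabilities) and returns | ||||
|         # (output, expected output width, expected FLOPs) | ||||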
|         out_a, expected_inC_a, expected_flop_a = self.conv_a( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_b, expected_inC_b, expected_flop_b = self.conv_b( | ||||
|             (out_a, expected_inC_a, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[1], indexes[1], probs[1]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_b) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_b, | ||||
|             sum([expected_flop_a, expected_flop_b, expected_flop_c]), | ||||
|         ) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return ( | ||||
|             self.conv_1x1.get_range() | ||||
|             + self.conv_3x3.get_range() | ||||
|             + self.conv_1x4.get_range() | ||||
|         ) | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 4, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) | ||||
|         flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_D = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_D = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this shortcut is added when the inferred network is trained | ||||
|             flop_D = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_1x4.OutShape[0] | ||||
|                 * self.conv_1x4.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C + flop_D | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 | ||||
|         out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( | ||||
|             (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( | ||||
|             (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[2], indexes[2], probs[2]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_1x4) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_1x4, | ||||
|             sum( | ||||
|                 [ | ||||
|                     expected_flop_1x1, | ||||
|                     expected_flop_3x3, | ||||
|                     expected_flop_1x4, | ||||
|                     expected_flop_c, | ||||
|                 ] | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SearchShapeCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, num_classes): | ||||
|         super(SearchShapeCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.InShape = None | ||||
|         self.depth_info = OrderedDict() | ||||
|         self.depth_at_i = OrderedDict() | ||||
|         for stage in range(3): | ||||
|             cur_block_choices = get_depth_choices(layer_blocks, False) | ||||
|             assert ( | ||||
|                 cur_block_choices[-1] == layer_blocks | ||||
|             ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks) | ||||
|             self.message += ( | ||||
|                 "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format( | ||||
|                     stage, cur_block_choices, layer_blocks | ||||
|                 ) | ||||
|             ) | ||||
|             block_choices, xstart = [], len(self.layers) | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 # added for depth | ||||
|                 layer_index = len(self.layers) - 1 | ||||
|                 if iL + 1 in cur_block_choices: | ||||
|                     block_choices.append(layer_index) | ||||
|                 if iL + 1 == layer_blocks: | ||||
|                     self.depth_info[layer_index] = { | ||||
|                         "choices": block_choices, | ||||
|                         "stage": stage, | ||||
|                         "xstart": xstart, | ||||
|                     } | ||||
|         self.depth_info_list = [] | ||||
|         for xend, info in self.depth_info.items(): | ||||
|             self.depth_info_list.append((xend, info)) | ||||
|             xstart, xstage = info["xstart"], info["stage"] | ||||
|             for ilayer in range(xstart, xend + 1): | ||||
|                 idx = bisect_right(info["choices"], ilayer - 1) | ||||
|                 self.depth_at_i[ilayer] = (xstage, idx) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|         # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|  | ||||
|         # parameters for width | ||||
|         self.Ranges = [] | ||||
|         self.layer2indexRange = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             start_index = len(self.Ranges) | ||||
|             self.Ranges += layer.get_range() | ||||
|             self.layer2indexRange.append((start_index, len(self.Ranges))) | ||||
|         assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format( | ||||
|             len(self.Ranges) + 1, depth | ||||
|         ) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "width_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))), | ||||
|         ) | ||||
|         self.register_parameter( | ||||
|             "depth_attentions", | ||||
|             nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))), | ||||
|         ) | ||||
|         nn.init.normal_(self.width_attentions, 0, 0.01) | ||||
|         nn.init.normal_(self.depth_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self, LR=None): | ||||
|         if LR is None: | ||||
|             return [self.width_attentions, self.depth_attentions] | ||||
|         else: | ||||
|             return [ | ||||
|                 {"params": self.width_attentions, "lr": LR}, | ||||
|                 {"params": self.depth_attentions, "lr": LR}, | ||||
|             ] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
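|  | ||||
|     # Editor's sketch (an assumed two-optimizer setup, not shown in this | ||||
|     # diff): weight and architecture parameters typically get separate | ||||
|     # optimizers, e.g. | ||||
|     #   w_optim = torch.optim.SGD(net.base_parameters(), lr=0.1, momentum=0.9) | ||||
|     #   a_optim = torch.optim.Adam(net.arch_parameters(LR=3e-4)) | ||||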
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         # select channels | ||||
|         channels = [3] | ||||
|         for i, weight in enumerate(self.width_attentions): | ||||
|             if mode == "genotype": | ||||
|                 with torch.no_grad(): | ||||
|                     probe = nn.functional.softmax(weight, dim=0) | ||||
|                     C = self.Ranges[i][torch.argmax(probe).item()] | ||||
|             elif mode == "max": | ||||
|                 C = self.Ranges[i][-1] | ||||
|             elif mode == "fix": | ||||
|                 C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|             elif mode == "random": | ||||
|                 assert isinstance(extra_info, float), "invalid extra_info : {:}".format( | ||||
|                     extra_info | ||||
|                 ) | ||||
|                 with torch.no_grad(): | ||||
|                     prob = nn.functional.softmax(weight, dim=0) | ||||
|                     approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|                     for j in range(prob.size(0)): | ||||
|                         prob[j] = 1 / ( | ||||
|                             abs(j - (approximate_C - self.Ranges[i][j])) + 0.2 | ||||
|                         ) | ||||
|                     C = self.Ranges[i][torch.multinomial(prob, 1, False).item()] | ||||
|             else: | ||||
|                 raise ValueError("invalid mode : {:}".format(mode)) | ||||
|             channels.append(C) | ||||
|         # select depth | ||||
|         if mode == "genotype": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 choices = torch.argmax(depth_probs, dim=1).cpu().tolist() | ||||
|         elif mode == "max" or mode == "fix": | ||||
|             choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))] | ||||
|         elif mode == "random": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 # flatten the (stages, 1) sample so each choice is a plain int | ||||
|                 choices = ( | ||||
|                     torch.multinomial(depth_probs, 1, False).view(-1).cpu().tolist() | ||||
|                 ) | ||||
|         else: | ||||
|             raise ValueError("invalid mode : {:}".format(mode)) | ||||
|         selected_layers = [] | ||||
|         for choice, xvalue in zip(choices, self.depth_info_list): | ||||
|             xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1 | ||||
|             selected_layers.append(xtemp) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             s, e = self.layer2indexRange[i] | ||||
|             xchl = tuple(channels[s : e + 1]) | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 if xatti <= choices[xstagei]:  # leave this depth | ||||
|                     flop += layer.get_flops(xchl) | ||||
|                 else: | ||||
|                     flop += 0  # do not use this layer | ||||
|             else: | ||||
|                 flop += layer.get_flops(xchl) | ||||
|         # the last fc layer | ||||
|         flop += channels[-1] * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xchannels"] = channels | ||||
|             config_dict["xblocks"] = selected_layers | ||||
|             config_dict["super_type"] = "infer-shape" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
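|  | ||||
|     # Editor's sketch: querying the expected cost in MFLOPs; the config-dict | ||||
|     # contents here are hypothetical. With a dict, the method also returns an | ||||
|     # "infer-shape" config recording the selected widths and depths. | ||||
|     #   mflops = net.get_flop("genotype", None, None) | ||||
|     #   mflops, cfg = net.get_flop("genotype", {"dataset": "cifar"}, None) | ||||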
|  | ||||
|     def get_arch_info(self): | ||||
|         string = ( | ||||
|             "for depth and width, there are {:} + {:} attention probabilities.".format( | ||||
|                 len(self.depth_attentions), len(self.width_attentions) | ||||
|             ) | ||||
|         ) | ||||
|         string += "\n{:}".format(self.depth_info) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.depth_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.depth_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.4f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:17s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || discrepancy={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|             string += "\n-----------------------------------------------" | ||||
|             for i, att in enumerate(self.width_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.width_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:52s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || dis={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
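|  | ||||
|     # Editor's sketch of the cosine schedule this implements, assuming tau is | ||||
|     # annealed from 10 to 0.1 over a hypothetical 'total_epochs': | ||||
|     #   for epoch in range(total_epochs): | ||||
|     #       net.set_tau(10.0, 0.1, epoch / (total_epochs - 1)) | ||||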
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
|         flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1) | ||||
|         flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|         flop_depth_probs = torch.flip( | ||||
|             torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1] | ||||
|         ) | ||||
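|         # (editor's note) flip + cumsum + flip turns the per-choice softmax | ||||
|         # into tail sums: entry k becomes P(chosen depth >= k-th choice), | ||||
|         # i.e. the probability that the layers covered by that choice are | ||||
|         # kept and should contribute their expected FLOPs. | ||||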
|         selected_widths, selected_width_probs = select2withP( | ||||
|             self.width_attentions, self.tau | ||||
|         ) | ||||
|         selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) | ||||
|         with torch.no_grad(): | ||||
|             selected_widths = selected_widths.cpu() | ||||
|  | ||||
|         x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] | ||||
|         feature_maps = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             selected_w_index = selected_widths[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             selected_w_probs = selected_width_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             layer_prob = flop_width_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             x, expected_inC, expected_flop = layer( | ||||
|                 (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) | ||||
|             ) | ||||
|             feature_maps.append(x) | ||||
|             last_channel_idx += layer.num_conv | ||||
|             if i in self.depth_info:  # aggregate the information | ||||
|                 choices = self.depth_info[i]["choices"] | ||||
|                 xstagei = self.depth_info[i]["stage"] | ||||
|                 # print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist())) | ||||
|                 # for A, W in zip(choices, selected_depth_probs[xstagei]): | ||||
|                 #  print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W)) | ||||
|                 possible_tensors = [] | ||||
|                 max_C = max(feature_maps[A].size(1) for A in choices) | ||||
|                 for tempi, A in enumerate(choices): | ||||
|                     xtensor = ChannelWiseInter(feature_maps[A], max_C) | ||||
|                     # drop_ratio = 1-(tempi+1.0)/len(choices) | ||||
|                     # xtensor = drop_path(xtensor, drop_ratio) | ||||
|                     possible_tensors.append(xtensor) | ||||
|                 weighted_sum = sum( | ||||
|                     xtensor * W | ||||
|                     for xtensor, W in zip( | ||||
|                         possible_tensors, selected_depth_probs[xstagei] | ||||
|                     ) | ||||
|                 ) | ||||
|                 x = weighted_sum | ||||
|  | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop | ||||
|             else: | ||||
|                 x_expected_flop = expected_flop | ||||
|             flops.append(x_expected_flop) | ||||
|         flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
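|  | ||||
|     # Editor's sketch ('criterion', 'images', 'labels' and 'flop_weight' are | ||||
|     # assumptions): in search mode the network returns logits plus an | ||||
|     # expected-FLOP estimate, which can be penalized alongside the task loss: | ||||
|     #   net.search_mode = "search" | ||||
|     #   logits, expected_flop = net(images) | ||||
|     #   loss = criterion(logits, labels) + flop_weight * expected_flop.mean() | ||||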
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
| @@ -1,515 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| from collections import OrderedDict | ||||
| from bisect import bisect_right | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices | ||||
|  | ||||
|  | ||||
| def get_depth_choices(nDepth, return_num): | ||||
|     if nDepth == 2: | ||||
|         choices = (1, 2) | ||||
|     elif nDepth == 3: | ||||
|         choices = (1, 2, 3) | ||||
|     elif nDepth > 3: | ||||
|         choices = list(range(1, nDepth + 1, 2)) | ||||
|         if choices[-1] < nDepth: | ||||
|             choices.append(nDepth) | ||||
|     else: | ||||
|         raise ValueError("invalid nDepth : {:}".format(nDepth)) | ||||
|     if return_num: | ||||
|         return len(choices) | ||||
|     else: | ||||
|         return choices | ||||
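|  | ||||
| # (editor's note) worked examples: get_depth_choices(3, False) == (1, 2, 3); | ||||
| # get_depth_choices(18, False) == [1, 3, 5, 7, 9, 11, 13, 15, 17, 18]; with | ||||
| # return_num=True the same calls return 3 and 10. | ||||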
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_width_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         if has_bn: | ||||
|             self.bn = nn.BatchNorm2d(nOut) | ||||
|         else: | ||||
|             self.bn = None | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=False) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|  | ||||
|     def get_flops(self, divide=1): | ||||
|         iC, oC = self.in_dim, self.out_dim | ||||
|         assert ( | ||||
|             iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|         ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|             iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|         ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
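|  | ||||
|     # (editor's worked example) a 3x3 conv, groups=1, no bias, 16 -> 32 | ||||
|     # channels on a 32x32 output map: | ||||
|     #   (3 * 3 / 1) * (32 * 32) * 16 * 32 = 4,718,592 FLOPs, | ||||
|     #   i.e. ~4.72 with divide=1e6. | ||||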
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.bn: | ||||
|             out = self.bn(conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         else: | ||||
|             out = out | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 2 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, divide=1): | ||||
|         flop_A = self.conv_a.get_flops(divide) | ||||
|         flop_B = self.conv_b.get_flops(divide) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops(divide) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         return flop_A + flop_B + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return ( | ||||
|             self.conv_1x1.get_range() | ||||
|             + self.conv_3x3.get_range() | ||||
|             + self.conv_1x4.get_range() | ||||
|         ) | ||||
|  | ||||
|     def get_flops(self, divide): | ||||
|         flop_A = self.conv_1x1.get_flops(divide) | ||||
|         flop_B = self.conv_3x3.get_flops(divide) | ||||
|         flop_C = self.conv_1x4.get_flops(divide) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_D = self.downsample.get_flops(divide) | ||||
|         else: | ||||
|             flop_D = 0 | ||||
|         return flop_A + flop_B + flop_C + flop_D | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class SearchDepthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, num_classes): | ||||
|         super(SearchDepthCifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should be 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.InShape = None | ||||
|         self.depth_info = OrderedDict() | ||||
|         self.depth_at_i = OrderedDict() | ||||
|         for stage in range(3): | ||||
|             cur_block_choices = get_depth_choices(layer_blocks, False) | ||||
|             assert ( | ||||
|                 cur_block_choices[-1] == layer_blocks | ||||
|             ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks) | ||||
|             self.message += ( | ||||
|                 "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format( | ||||
|                     stage, cur_block_choices, layer_blocks | ||||
|                 ) | ||||
|             ) | ||||
|             block_choices, xstart = [], len(self.layers) | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 # added for depth | ||||
|                 layer_index = len(self.layers) - 1 | ||||
|                 if iL + 1 in cur_block_choices: | ||||
|                     block_choices.append(layer_index) | ||||
|                 if iL + 1 == layer_blocks: | ||||
|                     self.depth_info[layer_index] = { | ||||
|                         "choices": block_choices, | ||||
|                         "stage": stage, | ||||
|                         "xstart": xstart, | ||||
|                     } | ||||
|         self.depth_info_list = [] | ||||
|         for xend, info in self.depth_info.items(): | ||||
|             self.depth_info_list.append((xend, info)) | ||||
|             xstart, xstage = info["xstart"], info["stage"] | ||||
|             for ilayer in range(xstart, xend + 1): | ||||
|                 idx = bisect_right(info["choices"], ilayer - 1) | ||||
|                 self.depth_at_i[ilayer] = (xstage, idx) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|         # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "depth_attentions", | ||||
|             nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))), | ||||
|         ) | ||||
|         nn.init.normal_(self.depth_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self): | ||||
|         return [self.depth_attentions] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         # select depth | ||||
|         if mode == "genotype": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 choices = torch.argmax(depth_probs, dim=1).cpu().tolist() | ||||
|         elif mode == "max": | ||||
|             choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))] | ||||
|         elif mode == "random": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 # flatten the (stages, 1) sample so each choice is a plain int | ||||
|                 choices = ( | ||||
|                     torch.multinomial(depth_probs, 1, False).view(-1).cpu().tolist() | ||||
|                 ) | ||||
|         else: | ||||
|             raise ValueError("invalid mode : {:}".format(mode)) | ||||
|         selected_layers = [] | ||||
|         for choice, xvalue in zip(choices, self.depth_info_list): | ||||
|             xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1 | ||||
|             selected_layers.append(xtemp) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 if xatti <= choices[xstagei]:  # leave this depth | ||||
|                     flop += layer.get_flops() | ||||
|                 else: | ||||
|                     flop += 0  # do not use this layer | ||||
|             else: | ||||
|                 flop += layer.get_flops() | ||||
|         # the last fc layer | ||||
|         flop += self.classifier.in_features * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xblocks"] = selected_layers | ||||
|             config_dict["super_type"] = "infer-depth" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
|  | ||||
|     def get_arch_info(self): | ||||
|         string = "for depth, there are {:} attention probabilities.".format( | ||||
|             len(self.depth_attentions) | ||||
|         ) | ||||
|         string += "\n{:}".format(self.depth_info) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.depth_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.depth_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.4f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:17s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || discrepancy={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
|         flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|         flop_depth_probs = torch.flip( | ||||
|             torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1] | ||||
|         ) | ||||
|         selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) | ||||
|  | ||||
|         x, flops = inputs, [] | ||||
|         feature_maps = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             layer_i = layer(x) | ||||
|             feature_maps.append(layer_i) | ||||
|             if i in self.depth_info:  # aggregate the information | ||||
|                 choices = self.depth_info[i]["choices"] | ||||
|                 xstagei = self.depth_info[i]["stage"] | ||||
|                 possible_tensors = [] | ||||
|                 for tempi, A in enumerate(choices): | ||||
|                     xtensor = feature_maps[A] | ||||
|                     possible_tensors.append(xtensor) | ||||
|                 weighted_sum = sum( | ||||
|                     xtensor * W | ||||
|                     for xtensor, W in zip( | ||||
|                         possible_tensors, selected_depth_probs[xstagei] | ||||
|                     ) | ||||
|                 ) | ||||
|                 x = weighted_sum | ||||
|             else: | ||||
|                 x = layer_i | ||||
|  | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 # print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6))) | ||||
|                 x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops( | ||||
|                     1e6 | ||||
|                 ) | ||||
|             else: | ||||
|                 x_expected_flop = layer.get_flops(1e6) | ||||
|             flops.append(x_expected_flop) | ||||
|         flops.append( | ||||
|             (self.classifier.in_features * self.classifier.out_features * 1.0 / 1e6) | ||||
|         ) | ||||
|  | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
| @@ -1,619 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices as get_choices | ||||
|  | ||||
|  | ||||
| def conv_forward(inputs, conv, choices): | ||||
|     iC = conv.in_channels | ||||
|     fill_size = list(inputs.size()) | ||||
|     fill_size[1] = iC - fill_size[1] | ||||
|     filled = torch.zeros(fill_size, device=inputs.device) | ||||
|     xinputs = torch.cat((inputs, filled), dim=1) | ||||
|     outputs = conv(xinputs) | ||||
|     selecteds = [outputs[:, :oC] for oC in choices] | ||||
|     return selecteds | ||||
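|  | ||||
| # Editor's sketch (shapes are assumptions): conv_forward zero-pads the input | ||||
| # up to conv.in_channels, runs one full convolution, then slices the first oC | ||||
| # output channels for each candidate width, so all candidates share a single | ||||
| # conv call: | ||||
| #   conv = nn.Conv2d(8, 8, 3, padding=1, bias=False) | ||||
| #   x = torch.randn(1, 5, 16, 16)                 # 5 of 8 input channels used | ||||
| #   small, large = conv_forward(x, conv, [4, 8])  # (1,4,16,16) and (1,8,16,16) | ||||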
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         # if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|         # else       : self.bn  = None | ||||
|         self.has_bn = has_bn | ||||
|         self.BNs = nn.ModuleList() | ||||
|         for i, _out in enumerate(self.choices): | ||||
|             self.BNs.append(nn.BatchNorm2d(_out)) | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, channels, check_range=True, divide=1): | ||||
|         iC, oC = channels | ||||
|         if check_range: | ||||
|             assert ( | ||||
|                 iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|             ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|                 iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|             ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
|  | ||||
|     def get_range(self): | ||||
|         return [self.choices] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, index, prob = tuple_inputs | ||||
|         index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) | ||||
|         probability = torch.squeeze(probability) | ||||
|         assert len(index) == 2, "invalid length : {:}".format(index) | ||||
|         # compute expected flop | ||||
|         # coordinates   = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) | ||||
|         expected_outC = (self.choices_tensor * probability).sum() | ||||
|         expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         # convolutional layer | ||||
|         out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) | ||||
|         out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] | ||||
|         # merge | ||||
|         out_channel = max([x.size(1) for x in out_bns]) | ||||
|         outA = ChannelWiseInter(out_bns[0], out_channel) | ||||
|         outB = ChannelWiseInter(out_bns[1], out_channel) | ||||
|         out = outA * prob[0] + outB * prob[1] | ||||
|         # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) | ||||
|  | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         else: | ||||
|             out = out | ||||
|         return out, expected_outC, expected_flop | ||||
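|  | ||||
|     # (editor's note) only the top-2 sampled widths are evaluated: both | ||||
|     # candidates are interpolated to the wider channel count with | ||||
|     # ChannelWiseInter and blended by their soft-selection weights 'prob', | ||||
|     # giving a differentiable relaxation of the discrete width choice. | ||||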
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.has_bn: | ||||
|             out = self.BNs[-1](conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         else: | ||||
|             out = out | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 2 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return self.conv_a.get_range() + self.conv_b.get_range() | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 3, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_a.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_b.get_flops([channels[1], channels[2]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this shortcut is added when the inferred network is trained | ||||
|             flop_C = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_b.OutShape[0] | ||||
|                 * self.conv_b.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 | ||||
|         out_a, expected_inC_a, expected_flop_a = self.conv_a( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_b, expected_inC_b, expected_flop_b = self.conv_b( | ||||
|             (out_a, expected_inC_a, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[1], indexes[1], probs[1]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_b) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_b, | ||||
|             sum([expected_flop_a, expected_flop_b, expected_flop_c]), | ||||
|         ) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
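|         # note: despite its name, conv_1x4 is a 1x1 convolution that expands the | ||||
|         # width to planes * expansion (the standard bottleneck output projection) | ||||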
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return ( | ||||
|             self.conv_1x1.get_range() | ||||
|             + self.conv_3x3.get_range() | ||||
|             + self.conv_1x4.get_range() | ||||
|         ) | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 4, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) | ||||
|         flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_D = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_D = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this shortcut is added when the derived (inferred) network is trained | ||||
|             flop_D = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_1x4.OutShape[0] | ||||
|                 * self.conv_1x4.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C + flop_D | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 | ||||
|         out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( | ||||
|             (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( | ||||
|             (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[2], indexes[2], probs[2]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_1x4) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_1x4, | ||||
|             sum( | ||||
|                 [ | ||||
|                     expected_flop_1x1, | ||||
|                     expected_flop_3x3, | ||||
|                     expected_flop_1x4, | ||||
|                     expected_flop_c, | ||||
|                 ] | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SearchWidthCifarResNet(nn.Module): | ||||
|     def __init__(self, block_name, depth, num_classes): | ||||
|         super(SearchWidthCifarResNet, self).__init__() | ||||
|  | ||||
|         # The block type and depth determine the number of layers for the CIFAR-10 / CIFAR-100 models | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|             assert (depth - 2) % 9 == 0, "depth should satisfy 9n+2, e.g., 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.InShape = None | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|         # parameters for width | ||||
|         self.Ranges = [] | ||||
|         self.layer2indexRange = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             start_index = len(self.Ranges) | ||||
|             self.Ranges += layer.get_range() | ||||
|             self.layer2indexRange.append((start_index, len(self.Ranges))) | ||||
|         assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format( | ||||
|             len(self.Ranges) + 1, depth | ||||
|         ) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "width_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))), | ||||
|         ) | ||||
|         nn.init.normal_(self.width_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self): | ||||
|         return [self.width_attentions] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
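|         # channel-selection modes: `genotype` takes the argmax of each width softmax, | ||||
|         # `max` the widest candidate, `fix` scales the widest by sqrt(extra_info) | ||||
|         # (presumably since conv FLOPs grow roughly quadratically with width), and | ||||
|         # `random` samples widths with weights shaped by that scaled target | ||||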
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         channels = [3] | ||||
|         for i, weight in enumerate(self.width_attentions): | ||||
|             if mode == "genotype": | ||||
|                 with torch.no_grad(): | ||||
|                     probe = nn.functional.softmax(weight, dim=0) | ||||
|                     C = self.Ranges[i][torch.argmax(probe).item()] | ||||
|             elif mode == "max": | ||||
|                 C = self.Ranges[i][-1] | ||||
|             elif mode == "fix": | ||||
|                 C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|             elif mode == "random": | ||||
|                 assert isinstance(extra_info, float), "invalid extra_info : {:}".format( | ||||
|                     extra_info | ||||
|                 ) | ||||
|                 with torch.no_grad(): | ||||
|                     prob = nn.functional.softmax(weight, dim=0) | ||||
|                     approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|                     for j in range(prob.size(0)): | ||||
|                         prob[j] = 1 / ( | ||||
|                             abs(j - (approximate_C - self.Ranges[i][j])) + 0.2 | ||||
|                         ) | ||||
|                     C = self.Ranges[i][torch.multinomial(prob, 1, False).item()] | ||||
|             else: | ||||
|                 raise ValueError("invalid mode : {:}".format(mode)) | ||||
|             channels.append(C) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             s, e = self.layer2indexRange[i] | ||||
|             xchl = tuple(channels[s : e + 1]) | ||||
|             flop += layer.get_flops(xchl) | ||||
|         # the last fc layer | ||||
|         flop += channels[-1] * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xchannels"] = channels | ||||
|             config_dict["super_type"] = "infer-width" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
|  | ||||
|     def get_arch_info(self): | ||||
|         string = "for width, there are {:} attention probabilities.".format( | ||||
|             len(self.width_attentions) | ||||
|         ) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.width_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.width_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:52s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || dis={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
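|         # rows of width_attentions are per-conv width logits: the softmax gives the | ||||
|         # distribution used for the expected-FLOP term, while select2withP draws two | ||||
|         # candidate widths per conv (with temperature self.tau) plus their probabilities | ||||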
|         flop_probs = nn.functional.softmax(self.width_attentions, dim=1) | ||||
|         selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) | ||||
|         with torch.no_grad(): | ||||
|             selected_widths = selected_widths.cpu() | ||||
|  | ||||
|         x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] | ||||
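|         # each layer consumes `layer.num_conv` consecutive rows of the attentions | ||||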
|         for i, layer in enumerate(self.layers): | ||||
|             selected_w_index = selected_widths[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             selected_w_probs = selected_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             layer_prob = flop_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             x, expected_inC, expected_flop = layer( | ||||
|                 (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) | ||||
|             ) | ||||
|             last_channel_idx += layer.num_conv | ||||
|             flops.append(expected_flop) | ||||
|         flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
| @@ -1,766 +0,0 @@ | ||||
| import math, torch | ||||
| from collections import OrderedDict | ||||
| from bisect import bisect_right | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices | ||||
|  | ||||
|  | ||||
| def get_depth_choices(layers): | ||||
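|     # for every stage, build min(layers) evenly spaced depth candidates; the largest | ||||
|     # candidate always equals the stage's full depth | ||||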
|     min_depth = min(layers) | ||||
|     info = {"num": min_depth} | ||||
|     for i, depth in enumerate(layers): | ||||
|         choices = [] | ||||
|         for j in range(1, min_depth + 1): | ||||
|             choices.append(int(float(depth) * j / min_depth)) | ||||
|         info[i] = choices | ||||
|     return info | ||||
|  | ||||
|  | ||||
| def conv_forward(inputs, conv, choices): | ||||
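|     # zero-pad the input up to the conv's full input width so that one shared | ||||
|     # convolution serves every sampled width, then slice the first oC output | ||||
|     # channels for each candidate width in `choices` | ||||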
|     iC = conv.in_channels | ||||
|     fill_size = list(inputs.size()) | ||||
|     fill_size[1] = iC - fill_size[1] | ||||
|     filled = torch.zeros(fill_size, device=inputs.device) | ||||
|     xinputs = torch.cat((inputs, filled), dim=1) | ||||
|     outputs = conv(xinputs) | ||||
|     selecteds = [outputs[:, :oC] for oC in choices] | ||||
|     return selecteds | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         nIn, | ||||
|         nOut, | ||||
|         kernel, | ||||
|         stride, | ||||
|         padding, | ||||
|         bias, | ||||
|         has_avg, | ||||
|         has_bn, | ||||
|         has_relu, | ||||
|         last_max_pool=False, | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_width_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         # if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|         # else       : self.bn  = None | ||||
|         self.has_bn = has_bn | ||||
|         self.BNs = nn.ModuleList() | ||||
|         for i, _out in enumerate(self.choices): | ||||
|             self.BNs.append(nn.BatchNorm2d(_out)) | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|  | ||||
|         if last_max_pool: | ||||
|             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||
|         else: | ||||
|             self.maxpool = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, channels, check_range=True, divide=1): | ||||
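|         # multiply-accumulate estimate: (kH * kW / groups) per output position, | ||||
|         # times all output positions and iC * oC; divide=1e6 yields MFLOPs | ||||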
|         iC, oC = channels | ||||
|         if check_range: | ||||
|             assert ( | ||||
|                 iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|             ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|                 iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|             ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
|  | ||||
|     def get_range(self): | ||||
|         return [self.choices] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, index, prob = tuple_inputs | ||||
|         index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) | ||||
|         probability = torch.squeeze(probability) | ||||
|         assert len(index) == 2, "invalid length : {:}".format(index) | ||||
|         # compute expected flop | ||||
|         # coordinates   = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) | ||||
|         expected_outC = (self.choices_tensor * probability).sum() | ||||
|         expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         # convolutional layer | ||||
|         out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) | ||||
|         out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] | ||||
|         # merge | ||||
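|         # interpolate both candidate outputs to the wider channel count, then mix | ||||
|         # them with the sampled probabilities: a soft choice between two widths | ||||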
|         out_channel = max([x.size(1) for x in out_bns]) | ||||
|         outA = ChannelWiseInter(out_bns[0], out_channel) | ||||
|         outB = ChannelWiseInter(out_bns[1], out_channel) | ||||
|         out = outA * prob[0] + outB * prob[1] | ||||
|         # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) | ||||
|  | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.maxpool: | ||||
|             out = self.maxpool(out) | ||||
|         return out, expected_outC, expected_flop | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.has_bn: | ||||
|             out = self.BNs[-1](conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         if self.maxpool: | ||||
|             out = self.maxpool(out) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 2 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_b = ConvBNReLU( | ||||
|             planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return self.conv_a.get_range() + self.conv_b.get_range() | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 3, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_a.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_b.get_flops([channels[1], channels[2]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this shortcut is added when the derived (inferred) network is trained | ||||
|             flop_C = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_b.OutShape[0] | ||||
|                 * self.conv_b.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2 | ||||
|         out_a, expected_inC_a, expected_flop_a = self.conv_a( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_b, expected_inC_b, expected_flop_b = self.conv_b( | ||||
|             (out_a, expected_inC_a, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[1], indexes[1], probs[1]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_b) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_b, | ||||
|             sum([expected_flop_a, expected_flop_b, expected_flop_c]), | ||||
|         ) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         basicblock = self.conv_a(inputs) | ||||
|         basicblock = self.conv_b(basicblock) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|     num_conv = 3 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU( | ||||
|             inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True | ||||
|         ) | ||||
|         self.conv_3x3 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, | ||||
|             planes * self.expansion, | ||||
|             1, | ||||
|             1, | ||||
|             0, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=False, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes * self.expansion, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return ( | ||||
|             self.conv_1x1.get_range() | ||||
|             + self.conv_3x3.get_range() | ||||
|             + self.conv_1x4.get_range() | ||||
|         ) | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 4, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv_1x1.get_flops([channels[0], channels[1]]) | ||||
|         flop_B = self.conv_3x3.get_flops([channels[1], channels[2]]) | ||||
|         flop_C = self.conv_1x4.get_flops([channels[2], channels[3]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_D = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_D = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this shortcut is added when the derived (inferred) network is trained | ||||
|             flop_D = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv_1x4.OutShape[0] | ||||
|                 * self.conv_1x4.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_B + flop_C + flop_D | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         bottleneck = self.conv_1x1(inputs) | ||||
|         bottleneck = self.conv_3x3(bottleneck) | ||||
|         bottleneck = self.conv_1x4(bottleneck) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, bottleneck) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3 | ||||
|         out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3( | ||||
|             (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1]) | ||||
|         ) | ||||
|         out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4( | ||||
|             (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[2], indexes[2], probs[2]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out_1x4) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_inC_1x4, | ||||
|             sum( | ||||
|                 [ | ||||
|                     expected_flop_1x1, | ||||
|                     expected_flop_3x3, | ||||
|                     expected_flop_1x4, | ||||
|                     expected_flop_c, | ||||
|                 ] | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SearchShapeImagenetResNet(nn.Module): | ||||
|     def __init__(self, block_name, layers, deep_stem, num_classes): | ||||
|         super(SearchShapeImagenetResNet, self).__init__() | ||||
|  | ||||
|         # The block type selects the residual block used for this ImageNet model | ||||
|         if block_name == "BasicBlock": | ||||
|             block = ResNetBasicblock | ||||
|         elif block_name == "Bottleneck": | ||||
|             block = ResNetBottleneck | ||||
|         else: | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|         self.message = ( | ||||
|             "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 sum(layers) * block.num_conv, layers | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         if not deep_stem: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         3, | ||||
|                         64, | ||||
|                         7, | ||||
|                         2, | ||||
|                         3, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                         last_max_pool=True, | ||||
|                     ) | ||||
|                 ] | ||||
|             ) | ||||
|             self.channels = [64] | ||||
|         else: | ||||
|             self.layers = nn.ModuleList( | ||||
|                 [ | ||||
|                     ConvBNReLU( | ||||
|                         3, 32, 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                     ), | ||||
|                     ConvBNReLU( | ||||
|                         32, | ||||
|                         64, | ||||
|                         3, | ||||
|                         1, | ||||
|                         1, | ||||
|                         False, | ||||
|                         has_avg=False, | ||||
|                         has_bn=True, | ||||
|                         has_relu=True, | ||||
|                         last_max_pool=True, | ||||
|                     ), | ||||
|                 ] | ||||
|             ) | ||||
|             self.channels = [32, 64] | ||||
|  | ||||
|         meta_depth_info = get_depth_choices(layers) | ||||
|         self.InShape = None | ||||
|         self.depth_info = OrderedDict() | ||||
|         self.depth_at_i = OrderedDict() | ||||
|         for stage, layer_blocks in enumerate(layers): | ||||
|             cur_block_choices = meta_depth_info[stage] | ||||
|             assert ( | ||||
|                 cur_block_choices[-1] == layer_blocks | ||||
|             ), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks) | ||||
|             block_choices, xstart = [], len(self.layers) | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 64 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|                 # record depth candidates: the layer indexes at which this stage may end | ||||
|                 layer_index = len(self.layers) - 1 | ||||
|                 if iL + 1 in cur_block_choices: | ||||
|                     block_choices.append(layer_index) | ||||
|                 if iL + 1 == layer_blocks: | ||||
|                     self.depth_info[layer_index] = { | ||||
|                         "choices": block_choices, | ||||
|                         "stage": stage, | ||||
|                         "xstart": xstart, | ||||
|                     } | ||||
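|         # map each residual layer index to (stage, smallest depth-choice index that | ||||
|         # still contains the layer); used to decide whether a layer survives a depth choice | ||||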
|         self.depth_info_list = [] | ||||
|         for xend, info in self.depth_info.items(): | ||||
|             self.depth_info_list.append((xend, info)) | ||||
|             xstart, xstage = info["xstart"], info["stage"] | ||||
|             for ilayer in range(xstart, xend + 1): | ||||
|                 idx = bisect_right(info["choices"], ilayer - 1) | ||||
|                 self.depth_at_i[ilayer] = (xstage, idx) | ||||
|  | ||||
|         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|         # parameters for width | ||||
|         self.Ranges = [] | ||||
|         self.layer2indexRange = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             start_index = len(self.Ranges) | ||||
|             self.Ranges += layer.get_range() | ||||
|             self.layer2indexRange.append((start_index, len(self.Ranges))) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "width_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))), | ||||
|         ) | ||||
|         self.register_parameter( | ||||
|             "depth_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(layers), meta_depth_info["num"])), | ||||
|         ) | ||||
|         nn.init.normal_(self.width_attentions, 0, 0.01) | ||||
|         nn.init.normal_(self.depth_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self, LR=None): | ||||
|         if LR is None: | ||||
|             return [self.width_attentions, self.depth_attentions] | ||||
|         else: | ||||
|             return [ | ||||
|                 {"params": self.width_attentions, "lr": LR}, | ||||
|                 {"params": self.depth_attentions, "lr": LR}, | ||||
|             ] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
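|         # unlike the width-only variant, only the `genotype` mode (argmax of the | ||||
|         # attention softmax) is supported here, for both channels and depth | ||||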
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         # select channels | ||||
|         channels = [3] | ||||
|         for i, weight in enumerate(self.width_attentions): | ||||
|             if mode == "genotype": | ||||
|                 with torch.no_grad(): | ||||
|                     probe = nn.functional.softmax(weight, dim=0) | ||||
|                     C = self.Ranges[i][torch.argmax(probe).item()] | ||||
|             else: | ||||
|                 raise ValueError("invalid mode : {:}".format(mode)) | ||||
|             channels.append(C) | ||||
|         # select depth | ||||
|         if mode == "genotype": | ||||
|             with torch.no_grad(): | ||||
|                 depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
|                 choices = torch.argmax(depth_probs, dim=1).cpu().tolist() | ||||
|         else: | ||||
|             raise ValueError("invalid mode : {:}".format(mode)) | ||||
|         selected_layers = [] | ||||
|         for choice, xvalue in zip(choices, self.depth_info_list): | ||||
|             xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1 | ||||
|             selected_layers.append(xtemp) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             s, e = self.layer2indexRange[i] | ||||
|             xchl = tuple(channels[s : e + 1]) | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 if xatti <= choices[xstagei]:  # the chosen depth keeps this layer | ||||
|                     flop += layer.get_flops(xchl) | ||||
|                 else: | ||||
|                     flop += 0  # do not use this layer | ||||
|             else: | ||||
|                 flop += layer.get_flops(xchl) | ||||
|         # the last fc layer | ||||
|         flop += channels[-1] * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xchannels"] = channels | ||||
|             config_dict["xblocks"] = selected_layers | ||||
|             config_dict["super_type"] = "infer-shape" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
|  | ||||
|     def get_arch_info(self): | ||||
|         string = ( | ||||
|             "for depth and width, there are {:} + {:} attention probabilities.".format( | ||||
|                 len(self.depth_attentions), len(self.width_attentions) | ||||
|             ) | ||||
|         ) | ||||
|         string += "\n{:}".format(self.depth_info) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.depth_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.depth_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.4f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:17s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || discrepancy={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|             string += "\n-----------------------------------------------" | ||||
|             for i, att in enumerate(self.width_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.width_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:52s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || dis={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
|         flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1) | ||||
|         flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1) | ||||
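|         # reversed cumulative sum: entry [s, j] becomes P(depth choice of stage s >= j), | ||||
|         # i.e. the probability that the j-th layer group of the stage is kept | ||||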
|         flop_depth_probs = torch.flip( | ||||
|             torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1] | ||||
|         ) | ||||
|         selected_widths, selected_width_probs = select2withP( | ||||
|             self.width_attentions, self.tau | ||||
|         ) | ||||
|         selected_depth_probs = select2withP(self.depth_attentions, self.tau, True) | ||||
|         with torch.no_grad(): | ||||
|             selected_widths = selected_widths.cpu() | ||||
|  | ||||
|         x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] | ||||
|         feature_maps = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             selected_w_index = selected_widths[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             selected_w_probs = selected_width_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             layer_prob = flop_width_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             x, expected_inC, expected_flop = layer( | ||||
|                 (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) | ||||
|             ) | ||||
|             feature_maps.append(x) | ||||
|             last_channel_idx += layer.num_conv | ||||
|             if i in self.depth_info:  # aggregate the information | ||||
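|                 # last candidate layer of a stage: mix the feature maps of all candidate | ||||
|                 # exits (interpolated to a common width) with the stage's depth probabilities | ||||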
|                 choices = self.depth_info[i]["choices"] | ||||
|                 xstagei = self.depth_info[i]["stage"] | ||||
|                 possible_tensors = [] | ||||
|                 max_C = max(feature_maps[A].size(1) for A in choices) | ||||
|                 for tempi, A in enumerate(choices): | ||||
|                     xtensor = ChannelWiseInter(feature_maps[A], max_C) | ||||
|                     possible_tensors.append(xtensor) | ||||
|                 weighted_sum = sum( | ||||
|                     xtensor * W | ||||
|                     for xtensor, W in zip( | ||||
|                         possible_tensors, selected_depth_probs[xstagei] | ||||
|                     ) | ||||
|                 ) | ||||
|                 x = weighted_sum | ||||
|  | ||||
|             if i in self.depth_at_i: | ||||
|                 xstagei, xatti = self.depth_at_i[i] | ||||
|                 x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop | ||||
|             else: | ||||
|                 x_expected_flop = expected_flop | ||||
|             flops.append(x_expected_flop) | ||||
|         flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
| @@ -1,466 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
| from ..initialization import initialize_resnet | ||||
| from ..SharedUtils import additive_func | ||||
| from .SoftSelect import select2withP, ChannelWiseInter | ||||
| from .SoftSelect import linear_forward | ||||
| from .SoftSelect import get_width_choices as get_choices | ||||
|  | ||||
|  | ||||
| def conv_forward(inputs, conv, choices): | ||||
|     iC = conv.in_channels | ||||
|     fill_size = list(inputs.size()) | ||||
|     fill_size[1] = iC - fill_size[1] | ||||
|     filled = torch.zeros(fill_size, device=inputs.device) | ||||
|     xinputs = torch.cat((inputs, filled), dim=1) | ||||
|     outputs = conv(xinputs) | ||||
|     selecteds = [outputs[:, :oC] for oC in choices] | ||||
|     return selecteds | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__( | ||||
|         self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu | ||||
|     ): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         self.InShape = None | ||||
|         self.OutShape = None | ||||
|         self.choices = get_choices(nOut) | ||||
|         self.register_buffer("choices_tensor", torch.Tensor(self.choices)) | ||||
|  | ||||
|         if has_avg: | ||||
|             self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
|         else: | ||||
|             self.avg = None | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, | ||||
|             nOut, | ||||
|             kernel_size=kernel, | ||||
|             stride=stride, | ||||
|             padding=padding, | ||||
|             dilation=1, | ||||
|             groups=1, | ||||
|             bias=bias, | ||||
|         ) | ||||
|         # if has_bn  : self.bn  = nn.BatchNorm2d(nOut) | ||||
|         # else       : self.bn  = None | ||||
|         self.has_bn = has_bn | ||||
|         self.BNs = nn.ModuleList() | ||||
|         for i, _out in enumerate(self.choices): | ||||
|             self.BNs.append(nn.BatchNorm2d(_out)) | ||||
|         if has_relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_flops(self, channels, check_range=True, divide=1): | ||||
|         iC, oC = channels | ||||
|         if check_range: | ||||
|             assert ( | ||||
|                 iC <= self.conv.in_channels and oC <= self.conv.out_channels | ||||
|             ), "{:} vs {:}  |  {:} vs {:}".format( | ||||
|                 iC, self.conv.in_channels, oC, self.conv.out_channels | ||||
|             ) | ||||
|         assert ( | ||||
|             isinstance(self.InShape, tuple) and len(self.InShape) == 2 | ||||
|         ), "invalid in-shape : {:}".format(self.InShape) | ||||
|         assert ( | ||||
|             isinstance(self.OutShape, tuple) and len(self.OutShape) == 2 | ||||
|         ), "invalid out-shape : {:}".format(self.OutShape) | ||||
|         # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups | ||||
|         conv_per_position_flops = ( | ||||
|             self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups | ||||
|         ) | ||||
|         all_positions = self.OutShape[0] * self.OutShape[1] | ||||
|         flops = (conv_per_position_flops * all_positions / divide) * iC * oC | ||||
|         if self.conv.bias is not None: | ||||
|             flops += all_positions / divide | ||||
|         return flops | ||||
|  | ||||
|     def get_range(self): | ||||
|         return [self.choices] | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, index, prob = tuple_inputs | ||||
|         index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) | ||||
|         probability = torch.squeeze(probability) | ||||
|         assert len(index) == 2, "invalid length : {:}".format(index) | ||||
|         # compute expected flop | ||||
|         # coordinates   = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) | ||||
|         expected_outC = (self.choices_tensor * probability).sum() | ||||
|         expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         # convolutional layer | ||||
|         out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) | ||||
|         out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] | ||||
|         # merge | ||||
|         out_channel = max([x.size(1) for x in out_bns]) | ||||
|         outA = ChannelWiseInter(out_bns[0], out_channel) | ||||
|         outB = ChannelWiseInter(out_bns[1], out_channel) | ||||
|         out = outA * prob[0] + outB * prob[1] | ||||
|         # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) | ||||
|  | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         return out, expected_outC, expected_flop | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.avg: | ||||
|             out = self.avg(inputs) | ||||
|         else: | ||||
|             out = inputs | ||||
|         conv = self.conv(out) | ||||
|         if self.has_bn: | ||||
|             out = self.BNs[-1](conv) | ||||
|         else: | ||||
|             out = conv | ||||
|         if self.relu: | ||||
|             out = self.relu(out) | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|             self.OutShape = (out.size(-2), out.size(-1)) | ||||
|         return out | ||||
|  | ||||
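| # Sketch: in "basic" mode ConvBNReLU behaves like a plain Conv-BN-ReLU at its | ||||
| # maximum width, normalizing with the last (full-width) entry of self.BNs. | ||||
| def _demo_convbnrelu(): | ||||
|     layer = ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) | ||||
|     out = layer(torch.randn(2, 3, 32, 32)) | ||||
|     return tuple(out.shape)  # (2, 16, 32, 32) | ||||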
|  | ||||
| class SimBlock(nn.Module): | ||||
|     expansion = 1 | ||||
|     num_conv = 1 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(SimBlock, self).__init__() | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv = ConvBNReLU( | ||||
|             inplanes, | ||||
|             planes, | ||||
|             3, | ||||
|             stride, | ||||
|             1, | ||||
|             False, | ||||
|             has_avg=False, | ||||
|             has_bn=True, | ||||
|             has_relu=True, | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=True, | ||||
|                 has_bn=False, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         elif inplanes != planes: | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, | ||||
|                 planes, | ||||
|                 1, | ||||
|                 1, | ||||
|                 0, | ||||
|                 False, | ||||
|                 has_avg=False, | ||||
|                 has_bn=True, | ||||
|                 has_relu=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes | ||||
|         self.search_mode = "basic" | ||||
|  | ||||
|     def get_range(self): | ||||
|         return self.conv.get_range() | ||||
|  | ||||
|     def get_flops(self, channels): | ||||
|         assert len(channels) == 2, "invalid channels : {:}".format(channels) | ||||
|         flop_A = self.conv.get_flops([channels[0], channels[1]]) | ||||
|         if hasattr(self.downsample, "get_flops"): | ||||
|             flop_C = self.downsample.get_flops([channels[0], channels[-1]]) | ||||
|         else: | ||||
|             flop_C = 0 | ||||
|         if ( | ||||
|             channels[0] != channels[-1] and self.downsample is None | ||||
|         ):  # this short-cut will be added during the infer-train | ||||
|             flop_C = ( | ||||
|                 channels[0] | ||||
|                 * channels[-1] | ||||
|                 * self.conv.OutShape[0] | ||||
|                 * self.conv.OutShape[1] | ||||
|             ) | ||||
|         return flop_A + flop_C | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, tuple_inputs): | ||||
|         assert ( | ||||
|             isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5 | ||||
|         ), "invalid type input : {:}".format(type(tuple_inputs)) | ||||
|         inputs, expected_inC, probability, indexes, probs = tuple_inputs | ||||
|         assert ( | ||||
|             indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1 | ||||
|         ), "invalid size : {:}, {:}, {:}".format( | ||||
|             indexes.size(), probs.size(), probability.size() | ||||
|         ) | ||||
|         out, expected_next_inC, expected_flop = self.conv( | ||||
|             (inputs, expected_inC, probability[0], indexes[0], probs[0]) | ||||
|         ) | ||||
|         if self.downsample is not None: | ||||
|             residual, _, expected_flop_c = self.downsample( | ||||
|                 (inputs, expected_inC, probability[-1], indexes[-1], probs[-1]) | ||||
|             ) | ||||
|         else: | ||||
|             residual, expected_flop_c = inputs, 0 | ||||
|         out = additive_func(residual, out) | ||||
|         return ( | ||||
|             nn.functional.relu(out, inplace=True), | ||||
|             expected_next_inC, | ||||
|             sum([expected_flop, expected_flop_c]), | ||||
|         ) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         basicblock = self.conv(inputs) | ||||
|         if self.downsample is not None: | ||||
|             residual = self.downsample(inputs) | ||||
|         else: | ||||
|             residual = inputs | ||||
|         out = additive_func(residual, basicblock) | ||||
|         return nn.functional.relu(out, inplace=True) | ||||
|  | ||||
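| # Sketch: a stride-2 SimBlock halves the spatial resolution and uses an | ||||
| # avg-pool + 1x1-conv shortcut so the residual addition has matching shapes. | ||||
| def _demo_simblock(): | ||||
|     block = SimBlock(16, 32, stride=2) | ||||
|     out = block(torch.randn(2, 16, 32, 32)) | ||||
|     return tuple(out.shape)  # (2, 32, 16, 16) | ||||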
|  | ||||
| class SearchWidthSimResNet(nn.Module): | ||||
|     def __init__(self, depth, num_classes): | ||||
|         super(SearchWidthSimResNet, self).__init__() | ||||
|  | ||||
|         assert ( | ||||
|             depth - 2 | ||||
|         ) % 3 == 0, "depth should be one of 5, 8, 11, 14, ... instead of {:}".format( | ||||
|             depth | ||||
|         ) | ||||
|         layer_blocks = (depth - 2) // 3 | ||||
|         self.message = ( | ||||
|             "SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 ConvBNReLU( | ||||
|                     3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True | ||||
|                 ) | ||||
|             ] | ||||
|         ) | ||||
|         self.InShape = None | ||||
|         for stage in range(3): | ||||
|             for iL in range(layer_blocks): | ||||
|                 iC = self.channels[-1] | ||||
|                 planes = 16 * (2 ** stage) | ||||
|                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||
|                 module = SimBlock(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|         self.InShape = None | ||||
|         self.tau = -1 | ||||
|         self.search_mode = "basic" | ||||
|         # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|  | ||||
|         # parameters for width | ||||
|         self.Ranges = [] | ||||
|         self.layer2indexRange = [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             start_index = len(self.Ranges) | ||||
|             self.Ranges += layer.get_range() | ||||
|             self.layer2indexRange.append((start_index, len(self.Ranges))) | ||||
|         assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format( | ||||
|             len(self.Ranges) + 1, depth | ||||
|         ) | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "width_attentions", | ||||
|             nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))), | ||||
|         ) | ||||
|         nn.init.normal_(self.width_attentions, 0, 0.01) | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
|     def arch_parameters(self): | ||||
|         return [self.width_attentions] | ||||
|  | ||||
|     def base_parameters(self): | ||||
|         return ( | ||||
|             list(self.layers.parameters()) | ||||
|             + list(self.avgpool.parameters()) | ||||
|             + list(self.classifier.parameters()) | ||||
|         ) | ||||
|  | ||||
|     def get_flop(self, mode, config_dict, extra_info): | ||||
|         if config_dict is not None: | ||||
|             config_dict = config_dict.copy() | ||||
|         # weights = [F.softmax(x, dim=0) for x in self.width_attentions] | ||||
|         channels = [3] | ||||
|         for i, weight in enumerate(self.width_attentions): | ||||
|             if mode == "genotype": | ||||
|                 with torch.no_grad(): | ||||
|                     probe = nn.functional.softmax(weight, dim=0) | ||||
|                     C = self.Ranges[i][torch.argmax(probe).item()] | ||||
|             elif mode == "max": | ||||
|                 C = self.Ranges[i][-1] | ||||
|             elif mode == "fix": | ||||
|                 C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|             elif mode == "random": | ||||
|                 assert isinstance(extra_info, float), "invalid extra_info : {:}".format( | ||||
|                     extra_info | ||||
|                 ) | ||||
|                 with torch.no_grad(): | ||||
|                     prob = nn.functional.softmax(weight, dim=0) | ||||
|                     approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1]) | ||||
|                     for j in range(prob.size(0)): | ||||
|                         # weight each candidate inversely by its distance to the target width | ||||
|                         prob[j] = 1 / ( | ||||
|                             abs(self.Ranges[i][j] - approximate_C) + 0.2 | ||||
|                         ) | ||||
|                     C = self.Ranges[i][torch.multinomial(prob, 1, False).item()] | ||||
|             else: | ||||
|                 raise ValueError("invalid mode : {:}".format(mode)) | ||||
|             channels.append(C) | ||||
|         flop = 0 | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             s, e = self.layer2indexRange[i] | ||||
|             xchl = tuple(channels[s : e + 1]) | ||||
|             flop += layer.get_flops(xchl) | ||||
|         # the last fc layer | ||||
|         flop += channels[-1] * self.classifier.out_features | ||||
|         if config_dict is None: | ||||
|             return flop / 1e6 | ||||
|         else: | ||||
|             config_dict["xchannels"] = channels | ||||
|             config_dict["super_type"] = "infer-width" | ||||
|             config_dict["estimated_FLOP"] = flop / 1e6 | ||||
|             return flop / 1e6, config_dict | ||||
|  | ||||
|     def get_arch_info(self): | ||||
|         string = "for width, there are {:} attention probabilities.".format( | ||||
|             len(self.width_attentions) | ||||
|         ) | ||||
|         discrepancy = [] | ||||
|         with torch.no_grad(): | ||||
|             for i, att in enumerate(self.width_attentions): | ||||
|                 prob = nn.functional.softmax(att, dim=0) | ||||
|                 prob = prob.cpu() | ||||
|                 selc = prob.argmax().item() | ||||
|                 prob = prob.tolist() | ||||
|                 prob = ["{:.3f}".format(x) for x in prob] | ||||
|                 xstring = "{:03d}/{:03d}-th : {:}".format( | ||||
|                     i, len(self.width_attentions), " ".join(prob) | ||||
|                 ) | ||||
|                 logt = ["{:.3f}".format(x) for x in att.cpu().tolist()] | ||||
|                 xstring += "  ||  {:52s}".format(" ".join(logt)) | ||||
|                 prob = sorted([float(x) for x in prob]) | ||||
|                 disc = prob[-1] - prob[-2] | ||||
|                 xstring += "  || dis={:.2f} || select={:}/{:}".format( | ||||
|                     disc, selc, len(prob) | ||||
|                 ) | ||||
|                 discrepancy.append(disc) | ||||
|                 string += "\n{:}".format(xstring) | ||||
|         return string, discrepancy | ||||
|  | ||||
|     def set_tau(self, tau_max, tau_min, epoch_ratio): | ||||
|         assert ( | ||||
|             epoch_ratio >= 0 and epoch_ratio <= 1 | ||||
|         ), "invalid epoch-ratio : {:}".format(epoch_ratio) | ||||
|         tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 | ||||
|         self.tau = tau | ||||
|  | ||||
|     def get_message(self): | ||||
|         return self.message | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         if self.search_mode == "basic": | ||||
|             return self.basic_forward(inputs) | ||||
|         elif self.search_mode == "search": | ||||
|             return self.search_forward(inputs) | ||||
|         else: | ||||
|             raise ValueError("invalid search_mode = {:}".format(self.search_mode)) | ||||
|  | ||||
|     def search_forward(self, inputs): | ||||
|         flop_probs = nn.functional.softmax(self.width_attentions, dim=1) | ||||
|         selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) | ||||
|         with torch.no_grad(): | ||||
|             selected_widths = selected_widths.cpu() | ||||
|  | ||||
|         x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             selected_w_index = selected_widths[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             selected_w_probs = selected_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             layer_prob = flop_probs[ | ||||
|                 last_channel_idx : last_channel_idx + layer.num_conv | ||||
|             ] | ||||
|             x, expected_inC, expected_flop = layer( | ||||
|                 (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) | ||||
|             ) | ||||
|             last_channel_idx += layer.num_conv | ||||
|             flops.append(expected_flop) | ||||
|         flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6)) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = linear_forward(features, self.classifier) | ||||
|         return logits, torch.stack([sum(flops)]) | ||||
|  | ||||
|     def basic_forward(self, inputs): | ||||
|         if self.InShape is None: | ||||
|             self.InShape = (inputs.size(-2), inputs.size(-1)) | ||||
|         x = inputs | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             x = layer(x) | ||||
|         features = self.avgpool(x) | ||||
|         features = features.view(features.size(0), -1) | ||||
|         logits = self.classifier(features) | ||||
|         return features, logits | ||||
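|  | ||||
|  | ||||
| # Usage sketch (CIFAR-style 32x32 inputs assumed): a forward pass in "basic" | ||||
| # mode returns (features, logits); once that pass has recorded the feature-map | ||||
| # shapes, get_flop("max", None, None) reports the FLOPs (in M) of the widest net. | ||||
| def _demo_search_width_sim_resnet(): | ||||
|     net = SearchWidthSimResNet(depth=20, num_classes=10) | ||||
|     features, logits = net(torch.randn(2, 3, 32, 32)) | ||||
|     return tuple(features.shape), tuple(logits.shape), net.get_flop("max", None, None) | ||||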
| @@ -1,128 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7): | ||||
|     if tau <= 0: | ||||
|         new_logits = logits | ||||
|         probs = nn.functional.softmax(new_logits, dim=1) | ||||
|     else: | ||||
|         while True:  # resample until the Gumbel noise and the resulting probabilities are finite (avoids inf/NaN) | ||||
|             gumbels = -torch.empty_like(logits).exponential_().log() | ||||
|             new_logits = (logits.log_softmax(dim=1) + gumbels) / tau | ||||
|             probs = nn.functional.softmax(new_logits, dim=1) | ||||
|             if ( | ||||
|                 (not torch.isinf(gumbels).any()) | ||||
|                 and (not torch.isinf(probs).any()) | ||||
|                 and (not torch.isnan(probs).any()) | ||||
|             ): | ||||
|                 break | ||||
|  | ||||
|     if just_prob: | ||||
|         return probs | ||||
|  | ||||
|     # with torch.no_grad(): # add eps for unexpected torch error | ||||
|     #  probs = nn.functional.softmax(new_logits, dim=1) | ||||
|     #  selected_index = torch.multinomial(probs + eps, 2, False) | ||||
|     with torch.no_grad():  # sample on CPU; the eps keeps torch.multinomial away from all-zero rows | ||||
|         probs = probs.cpu() | ||||
|         selected_index = torch.multinomial(probs + eps, num, False).to(logits.device) | ||||
|     selected_logit = torch.gather(new_logits, 1, selected_index) | ||||
|     selected_probs = nn.functional.softmax(selected_logit, dim=1) | ||||
|     return selected_index, selected_probs | ||||
|  | ||||
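| # Sketch: with tau > 0, sample two width indexes per row from Gumbel-perturbed | ||||
| # logits, together with their re-normalized probabilities. | ||||
| def _demo_select2withP(): | ||||
|     logits = torch.randn(4, 8)  # 4 searchable layers, 8 width choices each | ||||
|     index, probs = select2withP(logits, tau=1.0) | ||||
|     return tuple(index.shape), tuple(probs.shape)  # (4, 2) and (4, 2) | ||||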
|  | ||||
| def ChannelWiseInter(inputs, oC, mode="v2"): | ||||
|     if mode == "v1": | ||||
|         return ChannelWiseInterV1(inputs, oC) | ||||
|     elif mode == "v2": | ||||
|         return ChannelWiseInterV2(inputs, oC) | ||||
|     else: | ||||
|         raise ValueError("invalid mode : {:}".format(mode)) | ||||
|  | ||||
|  | ||||
| def ChannelWiseInterV1(inputs, oC): | ||||
|     assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size()) | ||||
|  | ||||
|     def start_index(a, b, c): | ||||
|         return int(math.floor(float(a * c) / b)) | ||||
|  | ||||
|     def end_index(a, b, c): | ||||
|         return int(math.ceil(float((a + 1) * c) / b)) | ||||
|  | ||||
|     batch, iC, H, W = inputs.size() | ||||
|     outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device) | ||||
|     if iC == oC: | ||||
|         return inputs | ||||
|     for ot in range(oC): | ||||
|         istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC) | ||||
|         values = inputs[:, istartT:iendT].mean(dim=1) | ||||
|         outputs[:, ot, :, :] = values | ||||
|     return outputs | ||||
|  | ||||
|  | ||||
| def ChannelWiseInterV2(inputs, oC): | ||||
|     assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size()) | ||||
|     batch, C, H, W = inputs.size() | ||||
|     if C == oC: | ||||
|         return inputs | ||||
|     else: | ||||
|         return nn.functional.adaptive_avg_pool3d(inputs, (oC, H, W)) | ||||
|     # inputs_5D = inputs.view(batch, 1, C, H, W) | ||||
|     # otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None) | ||||
|     # otputs    = otputs_5D.view(batch, oC, H, W) | ||||
|     # otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False) | ||||
|     # return otputs | ||||
|  | ||||
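| # Sketch: both modes map (N, C, H, W) to (N, oC, H, W); v2 treats the channel | ||||
| # axis as a pooling dimension via a single adaptive_avg_pool3d call. | ||||
| def _demo_channelwise_inter(): | ||||
|     x = torch.rand(2, 12, 7, 7) | ||||
|     return tuple(ChannelWiseInter(x, 20).shape)  # (2, 20, 7, 7) | ||||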
|  | ||||
| def linear_forward(inputs, linear): | ||||
|     if linear is None: | ||||
|         return inputs | ||||
|     iC = inputs.size(1) | ||||
|     weight = linear.weight[:, :iC] | ||||
|     if linear.bias is None: | ||||
|         bias = None | ||||
|     else: | ||||
|         bias = linear.bias | ||||
|     return nn.functional.linear(inputs, weight, bias) | ||||
|  | ||||
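| # Sketch: linear_forward lets a narrower feature vector reuse the first iC | ||||
| # columns of a full-width nn.Linear weight matrix. | ||||
| def _demo_linear_forward(): | ||||
|     fc = nn.Linear(64, 10) | ||||
|     feats = torch.randn(2, 48)  # only 48 of the 64 input features survive pruning | ||||
|     return tuple(linear_forward(feats, fc).shape)  # (2, 10) | ||||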
|  | ||||
| def get_width_choices(nOut): | ||||
|     xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] | ||||
|     if nOut is None: | ||||
|         return len(xsrange) | ||||
|     else: | ||||
|         Xs = [int(nOut * i) for i in xsrange] | ||||
|         # xs = [ int(nOut * i // 10) for i in range(2, 11)] | ||||
|         # Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1] | ||||
|         Xs = sorted(list(set(Xs))) | ||||
|         return tuple(Xs) | ||||
|  | ||||
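| # For example (computed from the fractions above): | ||||
| #   get_width_choices(16)   -> (4, 6, 8, 9, 11, 12, 14, 16) | ||||
| #   get_width_choices(None) -> 8  (the number of width fractions) | ||||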
|  | ||||
| def get_depth_choices(nDepth): | ||||
|     if nDepth is None: | ||||
|         return 3 | ||||
|     else: | ||||
|         if nDepth == 1: | ||||
|             return (1, 1, 1) | ||||
|         elif nDepth == 2: | ||||
|             return (1, 1, 2) | ||||
|         elif nDepth >= 3: | ||||
|             return (nDepth // 3, nDepth * 2 // 3, nDepth) | ||||
|         else: | ||||
|             raise ValueError("invalid Depth : {:}".format(nDepth)) | ||||
|  | ||||
|  | ||||
| def drop_path(x, drop_prob): | ||||
|     if drop_prob > 0.0: | ||||
|         keep_prob = 1.0 - drop_prob | ||||
|         mask = x.new_zeros(x.size(0), 1, 1, 1) | ||||
|         mask = mask.bernoulli_(keep_prob) | ||||
|         x = x * (mask / keep_prob) | ||||
|         # x.div_(keep_prob) | ||||
|         # x.mul_(mask) | ||||
|     return x | ||||
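|  | ||||
|  | ||||
| # Sketch: per-sample stochastic depth; each sample in the batch is either | ||||
| # zeroed out or rescaled by 1/keep_prob so the output expectation is unchanged. | ||||
| def _demo_drop_path(): | ||||
|     x = torch.ones(4, 3, 2, 2) | ||||
|     y = drop_path(x, drop_prob=0.5) | ||||
|     return y[:, 0, 0, 0]  # each entry is 0.0 or 2.0 | ||||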
| @@ -1,9 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .SearchCifarResNet_width import SearchWidthCifarResNet | ||||
| from .SearchCifarResNet_depth import SearchDepthCifarResNet | ||||
| from .SearchCifarResNet import SearchShapeCifarResNet | ||||
| from .SearchSimResNet_width import SearchWidthSimResNet | ||||
| from .SearchImagenetResNet import SearchShapeImagenetResNet | ||||
| from .generic_size_tiny_cell_model import GenericNAS301Model | ||||
| @@ -1,209 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| # Here, we utilized three techniques to search for the number of channels: | ||||
| # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" | ||||
| # - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" | ||||
| # - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" | ||||
| from typing import List, Text, Any | ||||
| import random, torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from models.cell_operations import ResNetBasicblock | ||||
| from models.cell_infers.cells import InferCell | ||||
| from models.shape_searchs.SoftSelect import select2withP, ChannelWiseInter | ||||
|  | ||||
|  | ||||
| class GenericNAS301Model(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         candidate_Cs: List[int], | ||||
|         max_num_Cs: int, | ||||
|         genotype: Any, | ||||
|         num_classes: int, | ||||
|         affine: bool, | ||||
|         track_running_stats: bool, | ||||
|     ): | ||||
|         super(GenericNAS301Model, self).__init__() | ||||
|         self._max_num_Cs = max_num_Cs | ||||
|         self._candidate_Cs = candidate_Cs | ||||
|         if max_num_Cs % 3 != 2: | ||||
|             raise ValueError("invalid number of layers : {:}".format(max_num_Cs)) | ||||
|         self._num_stage = N = max_num_Cs // 3 | ||||
|         self._max_C = max(candidate_Cs) | ||||
|  | ||||
|         stem = nn.Sequential( | ||||
|             nn.Conv2d(3, self._max_C, kernel_size=3, padding=1, bias=not affine), | ||||
|             nn.BatchNorm2d( | ||||
|                 self._max_C, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|         layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|         c_prev = self._max_C | ||||
|         self._cells = nn.ModuleList() | ||||
|         self._cells.append(stem) | ||||
|         for index, reduction in enumerate(layer_reductions): | ||||
|             if reduction: | ||||
|                 cell = ResNetBasicblock(c_prev, self._max_C, 2, True) | ||||
|             else: | ||||
|                 cell = InferCell( | ||||
|                     genotype, c_prev, self._max_C, 1, affine, track_running_stats | ||||
|                 ) | ||||
|             self._cells.append(cell) | ||||
|             c_prev = cell.out_dim | ||||
|         self._num_layer = len(self._cells) | ||||
|  | ||||
|         self.lastact = nn.Sequential( | ||||
|             nn.BatchNorm2d( | ||||
|                 c_prev, affine=affine, track_running_stats=track_running_stats | ||||
|             ), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|         self.classifier = nn.Linear(c_prev, num_classes) | ||||
|         # algorithm related | ||||
|         self.register_buffer("_tau", torch.zeros(1)) | ||||
|         self._algo = None | ||||
|         self._warmup_ratio = None | ||||
|  | ||||
|     def set_algo(self, algo: Text): | ||||
|         # used for searching | ||||
|         assert self._algo is None, "This function can only be called once." | ||||
|         assert algo in ["mask_gumbel", "mask_rl", "tas"], "invalid algo : {:}".format( | ||||
|             algo | ||||
|         ) | ||||
|         self._algo = algo | ||||
|         self._arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(self._max_num_Cs, len(self._candidate_Cs)) | ||||
|         ) | ||||
|         # if algo == 'mask_gumbel' or algo == 'mask_rl': | ||||
|         self.register_buffer( | ||||
|             "_masks", torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs)) | ||||
|         ) | ||||
|         for i in range(len(self._candidate_Cs)): | ||||
|             self._masks.data[i, : self._candidate_Cs[i]] = 1 | ||||
|  | ||||
|     @property | ||||
|     def tau(self): | ||||
|         return self._tau | ||||
|  | ||||
|     def set_tau(self, tau): | ||||
|         self._tau.data[:] = tau | ||||
|  | ||||
|     @property | ||||
|     def warmup_ratio(self): | ||||
|         return self._warmup_ratio | ||||
|  | ||||
|     def set_warmup_ratio(self, ratio: float): | ||||
|         self._warmup_ratio = ratio | ||||
|  | ||||
|     @property | ||||
|     def weights(self): | ||||
|         xlist = list(self._cells.parameters()) | ||||
|         xlist += list(self.lastact.parameters()) | ||||
|         xlist += list(self.global_pooling.parameters()) | ||||
|         xlist += list(self.classifier.parameters()) | ||||
|         return xlist | ||||
|  | ||||
|     @property | ||||
|     def alphas(self): | ||||
|         return [self._arch_parameters] | ||||
|  | ||||
|     def show_alphas(self): | ||||
|         with torch.no_grad(): | ||||
|             return "arch-parameters :\n{:}".format( | ||||
|                 nn.functional.softmax(self._arch_parameters, dim=-1).cpu() | ||||
|             ) | ||||
|  | ||||
|     @property | ||||
|     def random(self): | ||||
|         cs = [] | ||||
|         for i in range(self._max_num_Cs): | ||||
|             index = random.randint(0, len(self._candidate_Cs) - 1) | ||||
|             cs.append(str(self._candidate_Cs[index])) | ||||
|         return ":".join(cs) | ||||
|  | ||||
|     @property | ||||
|     def genotype(self): | ||||
|         cs = [] | ||||
|         for i in range(self._max_num_Cs): | ||||
|             with torch.no_grad(): | ||||
|                 index = self._arch_parameters[i].argmax().item() | ||||
|                 cs.append(str(self._candidate_Cs[index])) | ||||
|         return ":".join(cs) | ||||
|  | ||||
|     def get_message(self) -> Text: | ||||
|         string = self.extra_repr() | ||||
|         for i, cell in enumerate(self._cells): | ||||
|             string += "\n {:02d}/{:02d} :: {:}".format( | ||||
|                 i, len(self._cells), cell.extra_repr() | ||||
|             ) | ||||
|         return string | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "{name}(candidates={_candidate_Cs}, num={_max_num_Cs}, N={_num_stage}, L={_num_layer})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     def forward(self, inputs): | ||||
|         feature = inputs | ||||
|  | ||||
|         log_probs = [] | ||||
|         for i, cell in enumerate(self._cells): | ||||
|             feature = cell(feature) | ||||
|             # apply different searching algorithms | ||||
|             idx = max(0, i - 1) | ||||
|             if self._warmup_ratio is not None: | ||||
|                 if random.random() < self._warmup_ratio: | ||||
|                     mask = self._masks[-1] | ||||
|                 else: | ||||
|                     mask = self._masks[random.randint(0, len(self._masks) - 1)] | ||||
|                 feature = feature * mask.view(1, -1, 1, 1) | ||||
|             elif self._algo == "mask_gumbel": | ||||
|                 weights = nn.functional.gumbel_softmax( | ||||
|                     self._arch_parameters[idx : idx + 1], tau=self.tau, dim=-1 | ||||
|                 ) | ||||
|                 mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1) | ||||
|                 feature = feature * mask | ||||
|             elif self._algo == "tas": | ||||
|                 selected_cs, selected_probs = select2withP( | ||||
|                     self._arch_parameters[idx : idx + 1], self.tau, num=2 | ||||
|                 ) | ||||
|                 with torch.no_grad(): | ||||
|                     i1, i2 = selected_cs.cpu().view(-1).tolist() | ||||
|                 c1, c2 = self._candidate_Cs[i1], self._candidate_Cs[i2] | ||||
|                 out_channel = max(c1, c2) | ||||
|                 out1 = ChannelWiseInter(feature[:, :c1], out_channel) | ||||
|                 out2 = ChannelWiseInter(feature[:, :c2], out_channel) | ||||
|                 out = out1 * selected_probs[0, 0] + out2 * selected_probs[0, 1] | ||||
|                 if feature.shape[1] == out.shape[1]: | ||||
|                     feature = out | ||||
|                 else: | ||||
|                     miss = torch.zeros( | ||||
|                         feature.shape[0], | ||||
|                         feature.shape[1] - out.shape[1], | ||||
|                         feature.shape[2], | ||||
|                         feature.shape[3], | ||||
|                         device=feature.device, | ||||
|                     ) | ||||
|                     feature = torch.cat((out, miss), dim=1) | ||||
|             elif self._algo == "mask_rl": | ||||
|                 prob = nn.functional.softmax( | ||||
|                     self._arch_parameters[idx : idx + 1], dim=-1 | ||||
|                 ) | ||||
|                 dist = torch.distributions.Categorical(prob) | ||||
|                 action = dist.sample() | ||||
|                 log_probs.append(dist.log_prob(action)) | ||||
|                 mask = self._masks[action.item()].view(1, -1, 1, 1) | ||||
|                 feature = feature * mask | ||||
|             else: | ||||
|                 raise ValueError("invalid algorithm : {:}".format(self._algo)) | ||||
|  | ||||
|         out = self.lastact(feature) | ||||
|         out = self.global_pooling(out) | ||||
|         out = out.view(out.size(0), -1) | ||||
|         logits = self.classifier(out) | ||||
|  | ||||
|         return out, logits, log_probs | ||||
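|  | ||||
|  | ||||
| # Sketch of the "mask_gumbel" step in isolation (candidate widths (8, 16, 32) | ||||
| # and the padded feature map are illustrative assumptions): a Gumbel-softmax | ||||
| # over the arch parameters mixes the binary channel masks differentiably. | ||||
| def _demo_mask_gumbel_step(): | ||||
|     candidate_Cs, max_C = [8, 16, 32], 32 | ||||
|     masks = torch.zeros(len(candidate_Cs), max_C) | ||||
|     for i, c in enumerate(candidate_Cs): | ||||
|         masks[i, :c] = 1 | ||||
|     arch_param = 1e-3 * torch.randn(1, len(candidate_Cs)) | ||||
|     weights = nn.functional.gumbel_softmax(arch_param, tau=1.0, dim=-1) | ||||
|     feature = torch.randn(2, max_C, 8, 8) | ||||
|     mask = torch.matmul(weights, masks).view(1, -1, 1, 1) | ||||
|     return tuple((feature * mask).shape)  # (2, 32, 8, 8) | ||||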
| @@ -1,20 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from SoftSelect import ChannelWiseInter | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     tensors = torch.rand((16, 128, 7, 7)) | ||||
|  | ||||
|     for oc in range(200, 210): | ||||
|         out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||
|         out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||
|         assert torch.allclose(out_v1, out_v2) | ||||
|     for oc in range(48, 160): | ||||
|         out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||
|         out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||
|         assert torch.allclose(out_v1, out_v2) | ||||
| @@ -1,67 +0,0 @@ | ||||
| ####################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04   # | ||||
| ####################################################### | ||||
| # Use module in xlayers to construct different models # | ||||
| ####################################################### | ||||
| from typing import List, Text, Dict, Any | ||||
| import torch | ||||
|  | ||||
| __all__ = ["get_model"] | ||||
|  | ||||
|  | ||||
| from xlayers.super_core import SuperSequential | ||||
| from xlayers.super_core import SuperLinear | ||||
| from xlayers.super_core import SuperDropout | ||||
| from xlayers.super_core import super_name2norm | ||||
| from xlayers.super_core import super_name2activation | ||||
|  | ||||
|  | ||||
| def get_model(config: Dict[Text, Any], **kwargs): | ||||
|     model_type = config.get("model_type", "simple_mlp") | ||||
|     if model_type == "simple_mlp": | ||||
|         act_cls = super_name2activation[kwargs["act_cls"]] | ||||
|         norm_cls = super_name2norm[kwargs["norm_cls"]] | ||||
|         mean, std = kwargs.get("mean", None), kwargs.get("std", None) | ||||
|         if "hidden_dim" in kwargs: | ||||
|             hidden_dim1 = kwargs.get("hidden_dim") | ||||
|             hidden_dim2 = kwargs.get("hidden_dim") | ||||
|         else: | ||||
|             hidden_dim1 = kwargs.get("hidden_dim1", 200) | ||||
|             hidden_dim2 = kwargs.get("hidden_dim2", 100) | ||||
|         model = SuperSequential( | ||||
|             norm_cls(mean=mean, std=std), | ||||
|             SuperLinear(kwargs["input_dim"], hidden_dim1), | ||||
|             act_cls(), | ||||
|             SuperLinear(hidden_dim1, hidden_dim2), | ||||
|             act_cls(), | ||||
|             SuperLinear(hidden_dim2, kwargs["output_dim"]), | ||||
|         ) | ||||
|     elif model_type == "norm_mlp": | ||||
|         act_cls = super_name2activation[kwargs["act_cls"]] | ||||
|         norm_cls = super_name2norm[kwargs["norm_cls"]] | ||||
|         sub_layers, last_dim = [], kwargs["input_dim"] | ||||
|         for i, hidden_dim in enumerate(kwargs["hidden_dims"]): | ||||
|             if last_dim > 1: | ||||
|                 sub_layers.append(norm_cls(last_dim, elementwise_affine=False)) | ||||
|             sub_layers.append(SuperLinear(last_dim, hidden_dim)) | ||||
|             sub_layers.append(act_cls()) | ||||
|             last_dim = hidden_dim | ||||
|         sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"])) | ||||
|         model = SuperSequential(*sub_layers) | ||||
|     elif model_type == "dual_norm_mlp": | ||||
|         act_cls = super_name2activation[kwargs["act_cls"]] | ||||
|         norm_cls = super_name2norm[kwargs["norm_cls"]] | ||||
|         sub_layers, last_dim = [], kwargs["input_dim"] | ||||
|         for i, hidden_dim in enumerate(kwargs["hidden_dims"]): | ||||
|             if i > 0: | ||||
|                 sub_layers.append(norm_cls(last_dim, elementwise_affine=False)) | ||||
|             sub_layers.append(SuperLinear(last_dim, hidden_dim)) | ||||
|             sub_layers.append(SuperDropout(kwargs["dropout"])) | ||||
|             sub_layers.append(SuperLinear(hidden_dim, hidden_dim)) | ||||
|             sub_layers.append(act_cls()) | ||||
|             last_dim = hidden_dim | ||||
|         sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"])) | ||||
|         model = SuperSequential(*sub_layers) | ||||
|     else: | ||||
|         raise TypeError("Unknown model type: {:}".format(model_type)) | ||||
|     return model | ||||
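|  | ||||
|  | ||||
| # Usage sketch (kept as a comment: the exact keys of super_name2activation / | ||||
| # super_name2norm live in xlayers.super_core, so "relu" and "simple_norm" | ||||
| # below are assumptions): | ||||
| #   model = get_model( | ||||
| #       dict(model_type="simple_mlp"), | ||||
| #       input_dim=32, output_dim=1, hidden_dim=64, | ||||
| #       act_cls="relu", norm_cls="simple_norm", mean=0.0, std=1.0, | ||||
| #   ) | ||||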
| @@ -1,15 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################################### | ||||
| # This API will not be updated after 2020.09.16.                    # | ||||
| # Please use our new API in NATS-Bench, which is                    # | ||||
| # more efficient and contains info of more architecture candidates. # | ||||
| ##################################################################### | ||||
| from .api_utils import ArchResults, ResultsCount | ||||
| from .api_201 import NASBench201API | ||||
|  | ||||
| # NAS_BENCH_201_API_VERSION="v1.1"  # [2020.02.25] | ||||
| # NAS_BENCH_201_API_VERSION="v1.2"  # [2020.03.09] | ||||
| # NAS_BENCH_201_API_VERSION="v1.3"  # [2020.03.16] | ||||
| NAS_BENCH_201_API_VERSION="v2.0"    # [2020.06.30] | ||||
|  | ||||
| @@ -1,274 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ############################################################################################ | ||||
| # NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 # | ||||
| ############################################################################################ | ||||
| # The history of benchmark files: | ||||
| # [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID. | ||||
| # [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 architectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameter setting with 12 total epochs, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice. | ||||
| # | ||||
| # I'm still actively enhancing our benchmark, while for the future benchmark file, please follow news from NATS-Bench (an extended version of NAS-Bench-201). | ||||
| # | ||||
| import os, copy, random, torch, numpy as np | ||||
| from pathlib import Path | ||||
| from typing import List, Text, Union, Dict, Optional | ||||
| from collections import OrderedDict, defaultdict | ||||
|  | ||||
| from .api_utils import ArchResults | ||||
| from .api_utils import NASBenchMetaAPI | ||||
| from .api_utils import remap_dataset_set_names | ||||
|  | ||||
|  | ||||
| ALL_BENCHMARK_FILES = ['NAS-Bench-201-v1_0-e61699.pth', 'NAS-Bench-201-v1_1-096897.pth'] | ||||
| ALL_ARCHIVE_DIRS = ['NAS-Bench-201-v1_1-archive'] | ||||
|  | ||||
|  | ||||
| def print_information(information, extra_info=None, show=False): | ||||
|   dataset_names = information.get_dataset_names() | ||||
|   strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)] | ||||
|   def metric2str(loss, acc): | ||||
|     return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc) | ||||
|  | ||||
|   for ida, dataset in enumerate(dataset_names): | ||||
|     metric = information.get_compute_costs(dataset) | ||||
|     flop, param, latency = metric['flops'], metric['params'], metric['latency'] | ||||
|     str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None) | ||||
|     train_info = information.get_metrics(dataset, 'train') | ||||
|     if dataset == 'cifar10-valid': | ||||
|       valid_info = information.get_metrics(dataset, 'x-valid') | ||||
|       str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy'])) | ||||
|     elif dataset == 'cifar10': | ||||
|       test__info = information.get_metrics(dataset, 'ori-test') | ||||
|       str2 = '{:14s} train : [{:}], test  : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy'])) | ||||
|     else: | ||||
|       valid_info = information.get_metrics(dataset, 'x-valid') | ||||
|       test__info = information.get_metrics(dataset, 'x-test') | ||||
|       str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy'])) | ||||
|     strings += [str1, str2] | ||||
|   if show: print('\n'.join(strings)) | ||||
|   return strings | ||||
|  | ||||
|  | ||||
| """ | ||||
| This is the class for the API of NAS-Bench-201. | ||||
| """ | ||||
| class NASBench201API(NASBenchMetaAPI): | ||||
|  | ||||
|   """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ | ||||
|   def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, | ||||
|                verbose: bool=True): | ||||
|     self.filename = None | ||||
|     self.reset_time() | ||||
|     if file_path_or_dict is None: | ||||
|       file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1]) | ||||
|       print ('Try to use the default NAS-Bench-201 path from {:}.'.format(file_path_or_dict)) | ||||
|     if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path): | ||||
|       file_path_or_dict = str(file_path_or_dict) | ||||
|       if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict)) | ||||
|       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) | ||||
|       self.filename = Path(file_path_or_dict).name | ||||
|       file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu') | ||||
|     elif isinstance(file_path_or_dict, dict): | ||||
|       file_path_or_dict = copy.deepcopy(file_path_or_dict) | ||||
|     else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict))) | ||||
|     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict)) | ||||
|     self.verbose = verbose # [TODO] a flag indicating whether to print more logs | ||||
|     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') | ||||
|     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) | ||||
|     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] ) | ||||
|     # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults | ||||
|     self.arch2infos_dict = OrderedDict() | ||||
|     self._avaliable_hps = set(['12', '200']) | ||||
|     for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): | ||||
|       all_info = file_path_or_dict['arch2infos'][xkey] | ||||
|       hp2archres = OrderedDict() | ||||
|       # self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] ) | ||||
|       # self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] ) | ||||
|       hp2archres['12'] = ArchResults.create_from_state_dict(all_info['less']) | ||||
|       hp2archres['200'] = ArchResults.create_from_state_dict(all_info['full']) | ||||
|       self.arch2infos_dict[xkey] = hp2archres | ||||
|     self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes'])) | ||||
|     self.archstr2index = {} | ||||
|     for idx, arch in enumerate(self.meta_archs): | ||||
|       assert arch not in self.archstr2index, 'The [{:}]-th arch {:} is already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch]) | ||||
|       self.archstr2index[ arch ] = idx | ||||
|  | ||||
|   def reload(self, archive_root: Text = None, index: int = None): | ||||
|     """Overwrite all information of the 'index'-th architecture in the search space. | ||||
|          It will load its data from 'archive_root'. | ||||
|     """ | ||||
|     if archive_root is None: | ||||
|       archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1]) | ||||
|     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) | ||||
|     if index is None: | ||||
|       indexes = list(range(len(self))) | ||||
|     else: | ||||
|       indexes = [index] | ||||
|     for idx in indexes: | ||||
|       assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx) | ||||
|       xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx)) | ||||
|       assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) | ||||
|       xdata = torch.load(xfile_path, map_location='cpu') | ||||
|       assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path) | ||||
|       hp2archres = OrderedDict() | ||||
|       hp2archres['12'] = ArchResults.create_from_state_dict(xdata['less']) | ||||
|       hp2archres['200'] = ArchResults.create_from_state_dict(xdata['full']) | ||||
|       self.arch2infos_dict[idx] = hp2archres | ||||
|  | ||||
|   def query_info_str_by_arch(self, arch, hp: Text='12'): | ||||
|     """ This function is used to query the information of a specific architecture | ||||
|         'arch' can be an architecture index or an architecture string | ||||
|         When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config' | ||||
|         When hp=200, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/200E.config' | ||||
|         The difference between these two configurations is the number of training epochs. | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp)) | ||||
|     return self._query_info_str_by_arch(arch, hp, print_information) | ||||
|  | ||||
|   # obtain the metric for the `index`-th architecture | ||||
|   # `dataset` indicates the dataset: | ||||
|   #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set | ||||
|   #   'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set | ||||
|   #   'cifar100'       : using the proposed train set of CIFAR-100 as the training set | ||||
|   #   'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set | ||||
|   # `iepoch` indicates the index of training epochs from 0 to 11/199. | ||||
|   #   When iepoch=None, it will return the metric for the last training epoch | ||||
|   #   When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0) | ||||
|   # `hp` indicates the hyper-parameter setting used for training | ||||
|   #   When hp='12', the network is trained for 12 epochs and the LR decays from 0.1 to 0 within 12 epochs | ||||
|   #   When hp='200', the network is trained for 200 epochs and the LR decays from 0.1 to 0 within 200 epochs | ||||
|   # `is_random` | ||||
|   #   When is_random=True, the performance of a random architecture will be returned | ||||
|   #   When is_random=False, the performance of all trials will be averaged. | ||||
|   def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True): | ||||
|     if self.verbose: | ||||
|       print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random)) | ||||
|     index = self.query_index_by_arch(index)  # in case the input is an arch string or object rather than an index | ||||
|     if index not in self.arch2infos_dict: | ||||
|       raise ValueError('Did not find {:} from arch2infos_dict.'.format(index)) | ||||
|     archresult = self.arch2infos_dict[index][str(hp)] | ||||
|     # if randomly select one trial, select the seed at first | ||||
|     if isinstance(is_random, bool) and is_random: | ||||
|       seeds = archresult.get_dataset_seeds(dataset) | ||||
|       is_random = random.choice(seeds) | ||||
|     # collect the training information | ||||
|     train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random) | ||||
|     total = train_info['iepoch'] + 1 | ||||
|     xinfo = {'train-loss'    : train_info['loss'], | ||||
|              'train-accuracy': train_info['accuracy'], | ||||
|              'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None, | ||||
|              'train-all-time': train_info['all_time']} | ||||
|     # collect the evaluation information | ||||
|     if dataset == 'cifar10-valid': | ||||
|       valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) | ||||
|       try: | ||||
|         test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||
|       except Exception: | ||||
|         test_info = None | ||||
|       valtest_info = None | ||||
|     else: | ||||
|       try: # collect results on the proposed test set | ||||
|         if dataset == 'cifar10': | ||||
|           test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||
|         else: | ||||
|           test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) | ||||
|       except Exception: | ||||
|         test_info = None | ||||
|       try: # collect results on the proposed validation set | ||||
|         valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) | ||||
|       except Exception: | ||||
|         valid_info = None | ||||
|       try: | ||||
|         if dataset != 'cifar10': | ||||
|           valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||
|         else: | ||||
|           valtest_info = None | ||||
|       except Exception: | ||||
|         valtest_info = None | ||||
|     if valid_info is not None: | ||||
|       xinfo['valid-loss'] = valid_info['loss'] | ||||
|       xinfo['valid-accuracy'] = valid_info['accuracy'] | ||||
|       xinfo['valid-per-time'] = valid_info['all_time'] / total if valid_info['all_time'] is not None else None | ||||
|       xinfo['valid-all-time'] = valid_info['all_time'] | ||||
|     if test_info is not None: | ||||
|       xinfo['test-loss'] = test_info['loss'] | ||||
|       xinfo['test-accuracy'] = test_info['accuracy'] | ||||
|       xinfo['test-per-time'] = test_info['all_time'] / total if test_info['all_time'] is not None else None | ||||
|       xinfo['test-all-time'] = test_info['all_time'] | ||||
|     if valtest_info is not None: | ||||
|       xinfo['valtest-loss'] = valtest_info['loss'] | ||||
|       xinfo['valtest-accuracy'] = valtest_info['accuracy'] | ||||
|       xinfo['valtest-per-time'] = valtest_info['all_time'] / total if valtest_info['all_time'] is not None else None | ||||
|       xinfo['valtest-all-time'] = valtest_info['all_time'] | ||||
|     return xinfo | ||||
|  | ||||
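|   # Usage sketch (the local benchmark file path is hypothetical): | ||||
|   #   api  = NASBench201API('NAS-Bench-201-v1_1-096897.pth') | ||||
|   #   info = api.get_more_info(123, 'cifar10', iepoch=None, hp='200', is_random=False) | ||||
|   #   print(info['test-accuracy']) | ||||
|  | ||||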
|   def show(self, index: int = -1) -> None: | ||||
|     """This function will print the information of a specific (or all) architecture(s).""" | ||||
|     self._show(index, print_information) | ||||
|  | ||||
|   @staticmethod | ||||
|   def str2lists(arch_str: Text) -> List[tuple]: | ||||
|     """ | ||||
|     This function shows how to read the string-based architecture encoding. | ||||
|       It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py` | ||||
|  | ||||
|     :param | ||||
|       arch_str: the input is a string that indicates the architecture topology, such as | ||||
|                     |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2| | ||||
|     :return: a list of tuple, contains multiple (op, input_node_index) pairs. | ||||
|  | ||||
|     :usage | ||||
|       arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) | ||||
|       print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list | ||||
|       for i, node in enumerate(arch): | ||||
|         print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node)) | ||||
|     """ | ||||
|     node_strs = arch_str.split('+') | ||||
|     genotypes = [] | ||||
|     for i, node_str in enumerate(node_strs): | ||||
|       inputs = list(filter(lambda x: x != '', node_str.split('|'))) | ||||
|       for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) | ||||
|       inputs = ( xi.split('~') for xi in inputs ) | ||||
|       input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs) | ||||
|       genotypes.append( input_infos ) | ||||
|     return genotypes | ||||
|  | ||||
|   @staticmethod | ||||
|   def str2matrix(arch_str: Text, | ||||
|                  search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray: | ||||
|     """ | ||||
|     This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101. | ||||
|  | ||||
|     :param | ||||
|       arch_str: the input is a string that indicates the architecture topology, such as | ||||
|                     |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2| | ||||
|       search_space: a list of operation string, the default list is the search space for NAS-Bench-201 | ||||
|         the default value should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/models/cell_operations.py#L24 | ||||
|     :return | ||||
|       the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology | ||||
|     :usage | ||||
|       matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' ) | ||||
|       This is a 4-by-4 matrix representing a cell with 4 nodes (only the lower-left triangle is used). | ||||
|          [ [0, 0, 0, 0],  # the first row represents the input (0-th) node | ||||
|            [2, 0, 0, 0],  # the second row represents the 1-st node, computed as 2-th-op( 0-th-node ) | ||||
|            [0, 0, 0, 0],  # the third row represents the 2-nd node, computed as 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) | ||||
|            [0, 0, 1, 0] ] # the fourth row represents the 3-rd node, computed as 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node ) | ||||
|       In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect', | ||||
|          2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'. | ||||
|     :(NOTE) | ||||
|       If a node has two input edges from the same preceding node, this function does not work correctly: one edge will overwrite the other. | ||||
|     """ | ||||
|     node_strs = arch_str.split('+') | ||||
|     num_nodes = len(node_strs) + 1 | ||||
|     matrix = np.zeros((num_nodes, num_nodes)) | ||||
|     for i, node_str in enumerate(node_strs): | ||||
|       inputs = list(filter(lambda x: x != '', node_str.split('|'))) | ||||
|       for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) | ||||
|       for xi in inputs: | ||||
|         op, idx = xi.split('~') | ||||
|         if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space)) | ||||
|         op_idx, node_idx = search_space.index(op), int(idx) | ||||
|         matrix[i+1, node_idx] = op_idx | ||||
|     return matrix | ||||
|  | ||||
| @@ -1,748 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ############################################################################################ | ||||
| # NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 # | ||||
| ############################################################################################ | ||||
| # In this Python file, we define NASBenchMetaAPI, the abstract class for benchmark APIs. | ||||
| # We also define the class ArchResults, which contains all information of a single architecture trained with one set of hyper-parameters on three datasets. | ||||
| # We also define the class ResultsCount, which contains all information of a single trial for a single architecture. | ||||
| ############################################################################################ | ||||
| # | ||||
| import os, abc, copy, random, torch, numpy as np | ||||
| from pathlib import Path | ||||
| from typing import List, Text, Union, Dict, Optional | ||||
| from collections import OrderedDict, defaultdict | ||||
|  | ||||
|  | ||||
| def remap_dataset_set_names(dataset, metric_on_set, verbose=False): | ||||
|   """re-map the metric_on_set to internal keys""" | ||||
|   if verbose: | ||||
|     print('Call the internal function remap_dataset_set_names with dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set)) | ||||
|   if dataset == 'cifar10' and metric_on_set == 'valid': | ||||
|     dataset, metric_on_set = 'cifar10-valid', 'x-valid' | ||||
|   elif dataset == 'cifar10' and metric_on_set == 'test': | ||||
|     dataset, metric_on_set = 'cifar10', 'ori-test' | ||||
|   elif dataset == 'cifar10' and metric_on_set == 'train': | ||||
|     dataset, metric_on_set = 'cifar10', 'train' | ||||
|   elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'valid': | ||||
|     metric_on_set = 'x-valid' | ||||
|   elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'test': | ||||
|     metric_on_set = 'x-test' | ||||
|   if verbose: | ||||
|     print('  return dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set)) | ||||
|   return dataset, metric_on_set | ||||
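| # Illustrative example (added for clarity, not in the original file): this helper | ||||
| # maps the user-facing (dataset, set) pair onto the keys used by the internal storage: | ||||
| #   remap_dataset_set_names('cifar10', 'valid')  # -> ('cifar10-valid', 'x-valid') | ||||
| #   remap_dataset_set_names('cifar100', 'test')  # -> ('cifar100', 'x-test') | ||||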
|  | ||||
|  | ||||
| class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|  | ||||
|   @abc.abstractmethod | ||||
|   def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True): | ||||
|     """The initialization function that takes the dataset file path (or a dict loaded from that path) as input.""" | ||||
|  | ||||
|   def __getitem__(self, index: int): | ||||
|     return copy.deepcopy(self.meta_archs[index]) | ||||
|  | ||||
|   def arch(self, index: int): | ||||
|     """Return the topology structure of the `index`-th architecture.""" | ||||
|     if self.verbose: | ||||
|       print('Call the arch function with index={:}'.format(index)) | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||
|     return copy.deepcopy(self.meta_archs[index]) | ||||
|  | ||||
|   def __len__(self): | ||||
|     return len(self.meta_archs) | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename)) | ||||
|  | ||||
|   @property | ||||
|   def avaliable_hps(self): | ||||
|     return list(copy.deepcopy(self._avaliable_hps)) | ||||
|  | ||||
|   @property | ||||
|   def used_time(self): | ||||
|     return self._used_time | ||||
|  | ||||
|   def reset_time(self): | ||||
|     self._used_time = 0 | ||||
|  | ||||
|   def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True): | ||||
|     index = self.query_index_by_arch(arch) | ||||
|     all_names = ('cifar10', 'cifar100', 'ImageNet16-120') | ||||
|     assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names) | ||||
|     if dataset == 'cifar10': | ||||
|       info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True) | ||||
|     else: | ||||
|       info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True) | ||||
|     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] | ||||
|     latency = self.get_latency(index, dataset) | ||||
|     if account_time: | ||||
|       self._used_time += time_cost | ||||
|     return valid_acc, latency, time_cost, self._used_time | ||||
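|   # Usage sketch (added for illustration; assumes `api` is a loaded benchmark | ||||
|   # instance). A simulated search loop under a 12-hour budget could look like: | ||||
|   #   api.reset_time() | ||||
|   #   while api.used_time < 12 * 3600: | ||||
|   #     arch = api.random()  # or an architecture proposed by a search algorithm | ||||
|   #     acc, latency, cost, total_time = api.simulate_train_eval(arch, 'cifar10', hp='12') | ||||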
|  | ||||
|   def random(self): | ||||
|     """Return a random index of all architectures.""" | ||||
|     return random.randint(0, len(self.meta_archs)-1) | ||||
|  | ||||
|   def query_index_by_arch(self, arch): | ||||
|     """ This function is used to query the index of an architecture in the search space. | ||||
|         In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'; | ||||
|           or an instance that has the 'tostr' function that can generate the architecture string; | ||||
|           or it is directly an architecture index, in this case, we will check whether it is valid or not. | ||||
|         This function will return the index. | ||||
|         If return -1, it means this architecture is not in the search space. | ||||
|         Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space). | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('Call query_index_by_arch with arch={:}'.format(arch)) | ||||
|     if isinstance(arch, int): | ||||
|       if 0 <= arch < len(self): | ||||
|         return arch | ||||
|       else: | ||||
|         raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format(arch, 0, len(self))) | ||||
|     elif isinstance(arch, str): | ||||
|       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] | ||||
|       else                         : arch_index = -1 | ||||
|     elif hasattr(arch, 'tostr'): | ||||
|       if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ] | ||||
|       else                                 : arch_index = -1 | ||||
|     else: arch_index = -1 | ||||
|     return arch_index | ||||
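|   # For example (illustrative; all three input forms are accepted): | ||||
|   #   idx = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|') | ||||
|   #   idx = api.query_index_by_arch(123)             # an index is validated and returned as-is | ||||
|   #   idx = api.query_index_by_arch(some_structure)  # any object with a .tostr() method | ||||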
|  | ||||
|   def query_by_arch(self, arch, hp): | ||||
|     # This keeps the current version compatible with the old API. | ||||
|     return self.query_info_str_by_arch(arch, hp) | ||||
|  | ||||
|   @abc.abstractmethod | ||||
|   def reload(self, archive_root: Text = None, index: int = None): | ||||
|     """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'. | ||||
|        If index is None, overwrite all ckps. | ||||
|     """ | ||||
|  | ||||
|   def clear_params(self, index: int, hp: Optional[Text]=None): | ||||
|     """Remove the architecture's weights to save memory. | ||||
|     :arg | ||||
|       index: the index of the target architecture | ||||
|       hp: a flag that controls how to clear the parameters. | ||||
|         -- None: clear all the weights in '01'/'12'/'90', which indicate the number of training epochs. | ||||
|         -- '01' or '12' or '90': clear only the weights in arch2infos_dict[index][hp]. | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('Call clear_params with index={:} and hp={:}'.format(index, hp)) | ||||
|     if hp is None: | ||||
|       for key, result in self.arch2infos_dict[index].items(): | ||||
|         result.clear_params() | ||||
|     else: | ||||
|       if str(hp) not in self.arch2infos_dict[index]: | ||||
|         raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[index].keys()), hp)) | ||||
|       self.arch2infos_dict[index][str(hp)].clear_params() | ||||
|  | ||||
|   @abc.abstractmethod | ||||
|   def query_info_str_by_arch(self, arch, hp: Text='12'): | ||||
|     """This function is used to query the information of a specific architecture.""" | ||||
|  | ||||
|   def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None): | ||||
|     arch_index = self.query_index_by_arch(arch) | ||||
|     if arch_index in self.arch2infos_dict: | ||||
|       if hp not in self.arch2infos_dict[arch_index]: | ||||
|         raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp)) | ||||
|       info = self.arch2infos_dict[arch_index][hp] | ||||
|       strings = print_information(info, 'arch-index={:}'.format(arch_index)) | ||||
|       return '\n'.join(strings) | ||||
|     else: | ||||
|       print('Found arch-index {:}, but this architecture has not been evaluated.'.format(arch_index)) | ||||
|       return None | ||||
|  | ||||
|   def query_meta_info_by_index(self, arch_index, hp: Text = '12'): | ||||
|     """Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index.""" | ||||
|     if self.verbose: | ||||
|       print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp)) | ||||
|     if arch_index in self.arch2infos_dict: | ||||
|       if hp not in self.arch2infos_dict[arch_index]: | ||||
|         raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp)) | ||||
|       info = self.arch2infos_dict[arch_index][hp] | ||||
|     else: | ||||
|       raise ValueError('arch_index [{:}] is not in arch2infos'.format(arch_index)) | ||||
|     return copy.deepcopy(info) | ||||
|  | ||||
|   def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = '12'): | ||||
|     """ This 'query_by_index' function is used to query information with the training of 01 epochs, 12 epochs, 90 epochs, or 200 epochs. | ||||
|         ------ | ||||
|         If hp=01, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/01E.config) | ||||
|         If hp=12, we train the model by 12 epochs (see config in configs/nas-benchmark/hyper-opts/12E.config) | ||||
|         If hp=90, we train the model by 90 epochs (see config in configs/nas-benchmark/hyper-opts/90E.config) | ||||
|         If hp=200, we train the model by 200 epochs (see config in configs/nas-benchmark/hyper-opts/200E.config) | ||||
|         ------ | ||||
|         If dataname is None, return the ArchResults | ||||
|           else, return a dict with all trials on that dataset (the key is the seed) | ||||
|         Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'. | ||||
|         -- cifar10-valid : training the model on the CIFAR-10 training set. | ||||
|         -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||
|         -- cifar100 : training the model on the CIFAR-100 training set. | ||||
|         -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(arch_index, dataname, hp)) | ||||
|     info = self.query_meta_info_by_index(arch_index, hp) | ||||
|     if dataname is None: return info | ||||
|     else: | ||||
|       if dataname not in info.get_dataset_names(): | ||||
|         raise ValueError('invalid dataset-name : {:} vs. {:}'.format(dataname, info.get_dataset_names())) | ||||
|       return info.query(dataname) | ||||
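|   # Usage sketch (illustrative; assumes `api` is a loaded benchmark instance): | ||||
|   #   results = api.query_by_index(1, 'cifar100')  # dict: seed -> ResultsCount | ||||
|   #   for seed, result in results.items(): | ||||
|   #     print(seed, result.get_eval('x-valid')['accuracy']) | ||||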
|  | ||||
|   def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'): | ||||
|     """Find the architecture with the highest accuracy based on some constraints.""" | ||||
|     if self.verbose: | ||||
|       print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max)) | ||||
|     dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose) | ||||
|     best_index, highest_accuracy = -1, None | ||||
|     for i, arch_index in enumerate(self.evaluated_indexes): | ||||
|       arch_info = self.arch2infos_dict[arch_index][hp] | ||||
|       info = arch_info.get_compute_costs(dataset)  # the information of costs | ||||
|       flop, param, latency = info['flops'], info['params'], info['latency'] | ||||
|       if FLOP_max  is not None and flop  > FLOP_max : continue | ||||
|       if Param_max is not None and param > Param_max: continue | ||||
|       xinfo = arch_info.get_metrics(dataset, metric_on_set)  # the information of loss and accuracy | ||||
|       loss, accuracy = xinfo['loss'], xinfo['accuracy'] | ||||
|       if best_index == -1: | ||||
|         best_index, highest_accuracy = arch_index, accuracy | ||||
|       elif highest_accuracy < accuracy: | ||||
|         best_index, highest_accuracy = arch_index, accuracy | ||||
|     if self.verbose: | ||||
|       print('  the best architecture : [{:}] {:} with accuracy={:.3f}%'.format(best_index, self.arch(best_index), highest_accuracy)) | ||||
|     return best_index, highest_accuracy | ||||
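|   # Usage sketch (illustrative; both constraints are optional and the units | ||||
|   # follow get_cost_info, i.e., MFLOPs and MB): | ||||
|   #   index, accuracy = api.find_best('cifar10', 'test', FLOP_max=50, Param_max=1.0, hp='12') | ||||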
|  | ||||
|   def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = '12'): | ||||
|     """ | ||||
|       This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed` | ||||
|       Args [seed]: | ||||
|         -- None : return a dict containing the trained weights of all trials, where each key is a seed and its value is the weights. | ||||
|         -- an integer : return the weights of the specific trial whose seed is this integer. | ||||
|       Args [hp]: | ||||
|         -- 01 : train the model by 01 epochs | ||||
|         -- 12 : train the model by 12 epochs | ||||
|         -- 90 : train the model by 90 epochs | ||||
|         -- 200 : train the model by 200 epochs | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(index, dataset, seed, hp)) | ||||
|     info = self.query_meta_info_by_index(index, hp) | ||||
|     return info.get_net_param(dataset, seed) | ||||
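|   # Usage sketch (added for illustration, not part of the original file). It | ||||
|   # assumes `api` is a loaded benchmark instance; `get_cell_based_tiny_net` is | ||||
|   # the model builder shipped with this repo's models package (the import path | ||||
|   # may differ across versions): | ||||
|   #   config  = api.get_net_config(index, 'cifar10') | ||||
|   #   from models import get_cell_based_tiny_net | ||||
|   #   network = get_cell_based_tiny_net(config) | ||||
|   #   params  = api.get_net_param(index, 'cifar10', None)  # dict: seed -> weights | ||||
|   #   network.load_state_dict(next(iter(params.values()))) | ||||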
|  | ||||
|   def get_net_config(self, index: int, dataset: Text): | ||||
|     """ | ||||
|       This function is used to obtain the configuration for the `index`-th architecture on `dataset`. | ||||
|       Args [dataset] (4 possible options): | ||||
|         -- cifar10-valid : training the model on the CIFAR-10 training set. | ||||
|         -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||
|         -- cifar100 : training the model on the CIFAR-100 training set. | ||||
|         -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||
|       This function will return a dict. | ||||
|       ========= An example of using this function: | ||||
|       config = api.get_net_config(128, 'cifar10') | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset)) | ||||
|     if index in self.arch2infos_dict: | ||||
|       info = self.arch2infos_dict[index] | ||||
|     else: | ||||
|       raise ValueError('The index={:} is not in arch2infos_dict.'.format(index)) | ||||
|     info = next(iter(info.values())) | ||||
|     results = info.query(dataset, None) | ||||
|     results = next(iter(results.values())) | ||||
|     return results.get_config(None) | ||||
|    | ||||
|   def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]: | ||||
|     """To obtain the cost metric for the `index`-th architecture on a dataset.""" | ||||
|     if self.verbose: | ||||
|       print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp)) | ||||
|     info = self.query_meta_info_by_index(index, hp) | ||||
|     return info.get_compute_costs(dataset) | ||||
|  | ||||
|   def get_latency(self, index: int, dataset: Text, hp: Text = '12') -> float: | ||||
|     """ | ||||
|     To obtain the latency of the network (by default it will return the latency with the batch size of 256). | ||||
|     :param index: the index of the target architecture | ||||
|     :param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120) | ||||
|     :return: return a float value in seconds | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp)) | ||||
|     cost_dict = self.get_cost_info(index, dataset, hp) | ||||
|     return cost_dict['latency'] | ||||
|  | ||||
|   @abc.abstractmethod | ||||
|   def show(self, index=-1): | ||||
|     """This function will print the information of a specific (or all) architecture(s).""" | ||||
|  | ||||
|   def _show(self, index=-1, print_information=None) -> None: | ||||
|     """ | ||||
|     This function will print the information of a specific (or all) architecture(s). | ||||
|  | ||||
|     :param index: if index < 0, loop over all architectures and print their information one by one; | ||||
|                   otherwise, print the information of the 'index'-th architecture. | ||||
|     :return: nothing | ||||
|     """ | ||||
|     if index < 0: # show all architectures | ||||
|       print(self) | ||||
|       for i, idx in enumerate(self.evaluated_indexes): | ||||
|         print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10) | ||||
|         print('arch : {:}'.format(self.meta_archs[idx])) | ||||
|         for key, result in self.arch2infos_dict[idx].items(): | ||||
|           strings = print_information(result) | ||||
|           print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40) | ||||
|           print('\n'.join(strings)) | ||||
|         print('<' * 40 + '------------' + '<' * 40) | ||||
|     else: | ||||
|       if 0 <= index < len(self.meta_archs): | ||||
|         if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated (or its results are not saved).'.format(index)) | ||||
|         else: | ||||
|           arch_info = self.arch2infos_dict[index] | ||||
|           for key, result in arch_info.items(): | ||||
|             strings = print_information(result) | ||||
|             print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40) | ||||
|             print('\n'.join(strings)) | ||||
|           print('<' * 40 + '------------' + '<' * 40) | ||||
|       else: | ||||
|         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) | ||||
|  | ||||
|   def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]: | ||||
|     """This function will count the number of total trials.""" | ||||
|     if self.verbose: | ||||
|       print('Call the statistics function with dataset={:} and hp={:}.'.format(dataset, hp)) | ||||
|     valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'] | ||||
|     if dataset not in valid_datasets: | ||||
|       raise ValueError('{:} not in {:}'.format(dataset, valid_datasets)) | ||||
|     nums, hp = defaultdict(lambda: 0), str(hp) | ||||
|     for index in range(len(self)): | ||||
|       archInfo = self.arch2infos_dict[index][hp] | ||||
|       dataset_seed = archInfo.dataset_seed | ||||
|       if dataset not in dataset_seed: | ||||
|         nums[0] += 1 | ||||
|       else: | ||||
|         nums[len(dataset_seed[dataset])] += 1 | ||||
|     return dict(nums) | ||||
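|   # Usage sketch (illustrative): the returned dict maps #trials -> #architectures, | ||||
|   # e.g. a result like {3: 6111, 1: 9514} would mean 6111 architectures have three | ||||
|   # trials each on this dataset: | ||||
|   #   api.statistics('cifar100', '12') | ||||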
|  | ||||
|  | ||||
| class ArchResults(object): | ||||
|  | ||||
|   def __init__(self, arch_index, arch_str): | ||||
|     self.arch_index   = int(arch_index) | ||||
|     self.arch_str     = copy.deepcopy(arch_str) | ||||
|     self.all_results  = dict() | ||||
|     self.dataset_seed = dict() | ||||
|     self.clear_net_done = False | ||||
|  | ||||
|   def get_compute_costs(self, dataset): | ||||
|     x_seeds = self.dataset_seed[dataset] | ||||
|     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] | ||||
|  | ||||
|     flops     = [result.flop for result in results] | ||||
|     params    = [result.params for result in results] | ||||
|     latencies = [result.get_latency() for result in results] | ||||
|     latencies = [x for x in latencies if x > 0] | ||||
|     mean_latency = np.mean(latencies) if len(latencies) > 0 else None | ||||
|     time_infos = defaultdict(list) | ||||
|     for result in results: | ||||
|       time_info = result.get_times() | ||||
|       for key, value in time_info.items(): time_infos[key].append( value ) | ||||
|       | ||||
|     info = {'flops'  : np.mean(flops), | ||||
|             'params' : np.mean(params), | ||||
|             'latency': mean_latency} | ||||
|     for key, value in time_infos.items(): | ||||
|       if len(value) > 0 and value[0] is not None: | ||||
|         info[key] = np.mean(value) | ||||
|       else: info[key] = None | ||||
|     return info | ||||
|  | ||||
|   def get_metrics(self, dataset, setname, iepoch=None, is_random=False): | ||||
|     """ | ||||
|       This `get_metrics` function is used to obtain the loss, accuracy, etc. on a specific dataset. | ||||
|       If not specified, each set refers to the proposed split in the NAS-Bench-201 paper. | ||||
|       If a query returns None or raises an error, that information is not available. | ||||
|       ======================================== | ||||
|       Args [dataset] (4 possible options): | ||||
|         -- cifar10-valid : training the model on the CIFAR-10 training set. | ||||
|         -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||
|         -- cifar100 : training the model on the CIFAR-100 training set. | ||||
|         -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||
|       Args [setname] (each dataset has different setnames): | ||||
|         -- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test' | ||||
|         ------ 'train' : the metric on the training set. | ||||
|         ------ 'x-valid' : the metric on the validation set. | ||||
|         ------ 'ori-test' : the metric on the test set. | ||||
|         -- When dataset = cifar10, you can use 'train', 'ori-test'. | ||||
|         ------ 'train' : the metric on the training + validation set. | ||||
|         ------ 'ori-test' : the metric on the test set. | ||||
|         -- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test' | ||||
|         ------ 'train' : the metric on the training set. | ||||
|         ------ 'x-valid' : the metric on the validation set. | ||||
|         ------ 'x-test' : the metric on the test set. | ||||
|         ------ 'ori-test' : the metric on the validation + test set. | ||||
|       Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)): | ||||
|         ------ None : return the metric after the last training epoch. | ||||
|         ------ an integer i : return the metric after the i-th training epoch. | ||||
|       Args [is_random]: | ||||
|         ------ True : return the metric of a randomly selected trial. | ||||
|         ------ False : return the averaged metric of all available trials. | ||||
|         ------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random'). | ||||
|     """ | ||||
|     x_seeds = self.dataset_seed[dataset] | ||||
|     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] | ||||
|     infos   = defaultdict(list) | ||||
|     for result in results: | ||||
|       if setname == 'train': | ||||
|         info = result.get_train(iepoch) | ||||
|       else: | ||||
|         info = result.get_eval(setname, iepoch) | ||||
|       for key, value in info.items(): infos[key].append( value ) | ||||
|     return_info = dict() | ||||
|     if isinstance(is_random, bool) and is_random: # randomly select one | ||||
|       index = random.randint(0, len(results)-1) | ||||
|       for key, value in infos.items(): return_info[key] = value[index] | ||||
|     elif isinstance(is_random, bool) and not is_random: # average | ||||
|       for key, value in infos.items(): | ||||
|         if len(value) > 0 and value[0] is not None: | ||||
|           return_info[key] = np.mean(value) | ||||
|         else: return_info[key] = None | ||||
|     elif isinstance(is_random, int): # specify the seed | ||||
|       if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds)) | ||||
|       index = x_seeds.index(is_random) | ||||
|       for key, value in infos.items(): return_info[key] = value[index] | ||||
|     else: | ||||
|       raise ValueError('invalid value for is_random: {:}'.format(is_random)) | ||||
|     return return_info | ||||
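|   # Usage sketch (illustrative; `arch_result` is an ArchResults instance): | ||||
|   #   info = arch_result.get_metrics('cifar10-valid', 'x-valid', iepoch=None, is_random=False) | ||||
|   #   # -> an averaged dict with keys such as 'loss', 'accuracy', 'cur_time', 'all_time' | ||||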
|  | ||||
|   def show(self, is_print=False): | ||||
|     return print_information(self, None, is_print) | ||||
|  | ||||
|   def get_dataset_names(self): | ||||
|     return list(self.dataset_seed.keys()) | ||||
|  | ||||
|   def get_dataset_seeds(self, dataset): | ||||
|     return copy.deepcopy( self.dataset_seed[dataset] ) | ||||
|  | ||||
|   def get_net_param(self, dataset: Text, seed: Union[None, int] =None): | ||||
|     """ | ||||
|     This function will return the trained network's weights on the 'dataset'. | ||||
|     :arg | ||||
|       dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'. | ||||
|       seed: an integer indicating the seed value, or None to return all trials. | ||||
|     """ | ||||
|     if seed is None: | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
|       return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds} | ||||
|     else: | ||||
|       xkey = (dataset, seed) | ||||
|       if xkey in self.all_results: | ||||
|         return self.all_results[xkey].get_net_param() | ||||
|       else: | ||||
|         raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys()))) | ||||
|  | ||||
|   def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None: | ||||
|     """This function is used to reset the latency in all corresponding ResultsCount(s).""" | ||||
|     if seed is None: | ||||
|       for seed in self.dataset_seed[dataset]: | ||||
|         self.all_results[(dataset, seed)].update_latency([latency]) | ||||
|     else: | ||||
|       self.all_results[(dataset, seed)].update_latency([latency]) | ||||
|  | ||||
|   def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None: | ||||
|     """This function is used to reset the train-times in all corresponding ResultsCount(s).""" | ||||
|     if seed is None: | ||||
|       for seed in self.dataset_seed[dataset]: | ||||
|         self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time) | ||||
|     else: | ||||
|       self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time) | ||||
|  | ||||
|   def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None: | ||||
|     """This function is used to reset the eval-times in all corresponding ResultsCount(s).""" | ||||
|     if seed is None: | ||||
|       for seed in self.dataset_seed[dataset]: | ||||
|         self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time) | ||||
|     else: | ||||
|       self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time) | ||||
|  | ||||
|   def get_latency(self, dataset: Text) -> float: | ||||
|     """Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]""" | ||||
|     latencies = [] | ||||
|     for seed in self.dataset_seed[dataset]: | ||||
|       latency = self.all_results[(dataset, seed)].get_latency() | ||||
|       if not isinstance(latency, float) or latency <= 0: | ||||
|         raise ValueError('invalid latency of {:} with seed={:} : {:}'.format(dataset, seed, latency)) | ||||
|       latencies.append(latency) | ||||
|     return sum(latencies) / len(latencies) | ||||
|  | ||||
|   def get_total_epoch(self, dataset=None): | ||||
|     """Return the total number of training epochs.""" | ||||
|     if dataset is None: | ||||
|       epochss = [] | ||||
|       for xdata, x_seeds in self.dataset_seed.items(): | ||||
|         epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds] | ||||
|     elif isinstance(dataset, str): | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
|       epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds] | ||||
|     else: | ||||
|       raise ValueError('invalid dataset={:}'.format(dataset)) | ||||
|     if len(set(epochss)) > 1: raise ValueError('Each trial must have the same number of training epochs : {:}'.format(epochss)) | ||||
|     return epochss[-1] | ||||
|  | ||||
|   def query(self, dataset, seed=None): | ||||
|     """Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'""" | ||||
|     if seed is None: | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
|       return {seed: self.all_results[(dataset, seed)] for seed in x_seeds} | ||||
|     else: | ||||
|       return self.all_results[(dataset, seed)] | ||||
|  | ||||
|   def arch_idx_str(self): | ||||
|     return '{:06d}'.format(self.arch_index) | ||||
|  | ||||
|   def update(self, dataset_name, seed, result): | ||||
|     if dataset_name not in self.dataset_seed: | ||||
|       self.dataset_seed[dataset_name] = [] | ||||
|     assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name) | ||||
|     self.dataset_seed[ dataset_name ].append( seed ) | ||||
|     self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] ) | ||||
|     assert (dataset_name, seed) not in self.all_results | ||||
|     self.all_results[ (dataset_name, seed) ] = result | ||||
|     self.clear_net_done = False | ||||
|  | ||||
|   def state_dict(self): | ||||
|     state_dict = dict() | ||||
|     for key, value in self.__dict__.items(): | ||||
|       if key == 'all_results': # contain the class of ResultsCount | ||||
|         xvalue = dict() | ||||
|         assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value)) | ||||
|         for _k, _v in value.items(): | ||||
|           assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v)) | ||||
|           xvalue[_k] = _v.state_dict() | ||||
|       else: | ||||
|         xvalue = value | ||||
|       state_dict[key] = xvalue | ||||
|     return state_dict | ||||
|  | ||||
|   def load_state_dict(self, state_dict): | ||||
|     new_state_dict = dict() | ||||
|     for key, value in state_dict.items(): | ||||
|       if key == 'all_results': # to convert to the class of ResultsCount | ||||
|         xvalue = dict() | ||||
|         assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value)) | ||||
|         for _k, _v in value.items(): | ||||
|           xvalue[_k] = ResultsCount.create_from_state_dict(_v) | ||||
|       else: xvalue = value | ||||
|       new_state_dict[key] = xvalue | ||||
|     self.__dict__.update(new_state_dict) | ||||
|  | ||||
|   @staticmethod | ||||
|   def create_from_state_dict(state_dict_or_file): | ||||
|     x = ArchResults(-1, -1) | ||||
|     if isinstance(state_dict_or_file, str): # a file path | ||||
|       state_dict = torch.load(state_dict_or_file, map_location='cpu') | ||||
|     elif isinstance(state_dict_or_file, dict): | ||||
|       state_dict = state_dict_or_file | ||||
|     else: | ||||
|       raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file))) | ||||
|     x.load_state_dict(state_dict) | ||||
|     return x | ||||
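|   # Round-trip sketch (illustrative): state_dict() flattens the nested ResultsCount | ||||
|   # objects into plain dicts, so an ArchResults can be saved and rebuilt later: | ||||
|   #   torch.save(arch_result.state_dict(), 'arch-000001.pth')  # hypothetical file name | ||||
|   #   restored = ArchResults.create_from_state_dict('arch-000001.pth') | ||||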
|  | ||||
|   # This function is used to clear the weights saved in each 'result' | ||||
|   # This can help reduce the memory footprint. | ||||
|   def clear_params(self): | ||||
|     for key, result in self.all_results.items(): | ||||
|       del result.net_state_dict | ||||
|       result.net_state_dict = None | ||||
|     self.clear_net_done = True | ||||
|  | ||||
|   def debug_test(self): | ||||
|     """This function is used for me to debug and test, which will call most methods.""" | ||||
|     all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'] | ||||
|     for dataset in all_dataset: | ||||
|       print('---->>>> {:}'.format(dataset)) | ||||
|       print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset))) | ||||
|       for seed in self.dataset_seed[dataset]: | ||||
|         result = self.all_results[(dataset, seed)] | ||||
|         print('  ==>> result = {:}'.format(result)) | ||||
|         print('  ==>> cost = {:}'.format(result.get_times())) | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done)) | ||||
|  | ||||
|  | ||||
| """ | ||||
| This class (ResultsCount) is used to save the information of one trial for a single architecture. | ||||
| This class is not heavily commented because it is the lowest-level class in the NAS-Bench-201 API and is rarely called directly. | ||||
| If you have any question regarding this class, please open an issue or email me. | ||||
| """ | ||||
| class ResultsCount(object): | ||||
|  | ||||
|   def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency): | ||||
|     self.name           = name | ||||
|     self.net_state_dict = state_dict | ||||
|     self.train_acc1es = copy.deepcopy(train_accs) | ||||
|     self.train_acc5es = None | ||||
|     self.train_losses = copy.deepcopy(train_losses) | ||||
|     self.train_times  = None | ||||
|     self.arch_config  = copy.deepcopy(arch_config) | ||||
|     self.params     = params | ||||
|     self.flop       = flop | ||||
|     self.seed       = seed | ||||
|     self.epochs     = epochs | ||||
|     self.latency    = latency | ||||
|     # evaluation results | ||||
|     self.reset_eval() | ||||
|  | ||||
|   def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None: | ||||
|     self.train_acc1es = train_acc1es | ||||
|     self.train_acc5es = train_acc5es | ||||
|     self.train_losses = train_losses | ||||
|     self.train_times  = train_times | ||||
|  | ||||
|   def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None: | ||||
|     """Assign the training times.""" | ||||
|     train_times = OrderedDict() | ||||
|     for i in range(self.epochs): | ||||
|       train_times[i] = estimated_per_epoch_time | ||||
|     self.train_times = train_times | ||||
|  | ||||
|   def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None: | ||||
|     """Assign the evaluation times.""" | ||||
|     if eval_name not in self.eval_names: raise ValueError('invalid eval name : {:}'.format(eval_name)) | ||||
|     for i in range(self.epochs): | ||||
|       self.eval_times['{:}@{:}'.format(eval_name,i)] = estimated_per_epoch_time | ||||
|  | ||||
|   def reset_eval(self): | ||||
|     self.eval_names  = [] | ||||
|     self.eval_acc1es = {} | ||||
|     self.eval_times  = {} | ||||
|     self.eval_losses = {} | ||||
|  | ||||
|   def update_latency(self, latency): | ||||
|     self.latency = copy.deepcopy( latency ) | ||||
|  | ||||
|   def get_latency(self) -> float: | ||||
|     """Return the latency value in seconds. -1 represents not avaliable ; otherwise it should be a float value""" | ||||
|     if self.latency is None: return -1.0 | ||||
|     else: return sum(self.latency) / len(self.latency) | ||||
|  | ||||
|   def update_eval(self, accs, losses, times):  # new version | ||||
|     data_names = set([x.split('@')[0] for x in accs.keys()]) | ||||
|     for data_name in data_names: | ||||
|       assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name) | ||||
|       self.eval_names.append( data_name ) | ||||
|       for iepoch in range(self.epochs): | ||||
|         xkey = '{:}@{:}'.format(data_name, iepoch) | ||||
|         self.eval_acc1es[ xkey ] = accs[ xkey ] | ||||
|         self.eval_losses[ xkey ] = losses[ xkey ] | ||||
|         self.eval_times [ xkey ] = times[ xkey ] | ||||
|  | ||||
|   def update_OLD_eval(self, name, accs, losses): # old version | ||||
|     assert name not in self.eval_names, '{:} has already been added'.format(name) | ||||
|     self.eval_names.append( name ) | ||||
|     for iepoch in range(self.epochs): | ||||
|       if iepoch in accs: | ||||
|         self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch] | ||||
|         self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch] | ||||
|  | ||||
|   def __repr__(self): | ||||
|     num_eval = len(self.eval_names) | ||||
|     set_name = '[' + ', '.join(self.eval_names) + ']' | ||||
|     return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name)) | ||||
|  | ||||
|   def get_total_epoch(self): | ||||
|     return copy.deepcopy(self.epochs) | ||||
|  | ||||
|   def get_times(self): | ||||
|     """Obtain the information regarding both training and evaluation time.""" | ||||
|     if self.train_times is not None and isinstance(self.train_times, dict): | ||||
|       train_times = list( self.train_times.values() ) | ||||
|       time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)} | ||||
|     else: | ||||
|       time_info = {'T-train@epoch':                 None, 'T-train@total':               None } | ||||
|     for name in self.eval_names: | ||||
|       try: | ||||
|         xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)] | ||||
|         time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes) | ||||
|         time_info['T-{:}@total'.format(name)] = np.sum(xtimes) | ||||
|       except KeyError:  # evaluation times are missing for results saved by older versions | ||||
|         time_info['T-{:}@epoch'.format(name)] = None | ||||
|         time_info['T-{:}@total'.format(name)] = None | ||||
|     return time_info | ||||
|  | ||||
|   def get_eval_set(self): | ||||
|     return self.eval_names | ||||
|  | ||||
|   # get the training information | ||||
|   def get_train(self, iepoch=None): | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} vs. total epochs={:}'.format(iepoch, self.epochs) | ||||
|     if self.train_times is not None: | ||||
|       xtime = self.train_times[iepoch] | ||||
|       atime = sum([self.train_times[i] for i in range(iepoch+1)]) | ||||
|     else: xtime, atime = None, None | ||||
|     return {'iepoch'  : iepoch, | ||||
|             'loss'    : self.train_losses[iepoch], | ||||
|             'accuracy': self.train_acc1es[iepoch], | ||||
|             'cur_time': xtime, | ||||
|             'all_time': atime} | ||||
|  | ||||
|   def get_eval(self, name, iepoch=None): | ||||
|     """Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).""" | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} vs. total epochs={:}'.format(iepoch, self.epochs) | ||||
|     def _internal_query(xname): | ||||
|       if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: | ||||
|         xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)] | ||||
|         atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)]) | ||||
|       else: | ||||
|         xtime, atime = None, None | ||||
|       return {'iepoch'  : iepoch, | ||||
|               'loss'    : self.eval_losses['{:}@{:}'.format(xname, iepoch)], | ||||
|               'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)], | ||||
|               'cur_time': xtime, | ||||
|               'all_time': atime} | ||||
|     if name == 'valid': | ||||
|       return _internal_query('x-valid') | ||||
|     else: | ||||
|       return _internal_query(name) | ||||
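|   # For example (illustrative): 'valid' is an alias for 'x-valid', so both | ||||
|   # calls below query the 'x-valid' set and return a dict of the form | ||||
|   # {'iepoch': ..., 'loss': ..., 'accuracy': ..., 'cur_time': ..., 'all_time': ...}: | ||||
|   #   result.get_eval('valid') | ||||
|   #   result.get_eval('x-valid', iepoch=5) | ||||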
|  | ||||
|   def get_net_param(self, clone=False): | ||||
|     if clone: return copy.deepcopy(self.net_state_dict) | ||||
|     else: return self.net_state_dict | ||||
|  | ||||
|   def get_config(self, str2structure): | ||||
|     """This function is used to obtain the config dict for this architecture.""" | ||||
|     if str2structure is None: | ||||
|       # In this case, this is to handle the size search space. | ||||
|       if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny': | ||||
|         return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'], | ||||
|                 'genotype': self.arch_config['genotype'], 'num_classes': self.arch_config['class_num']} | ||||
|       # In this case, this is NAS-Bench-201 | ||||
|       else: | ||||
|         return {'name': 'infer.tiny', 'C': self.arch_config['channel'], | ||||
|                 'N'   : self.arch_config['num_cells'], | ||||
|                 'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']} | ||||
|     else: | ||||
|       # In this case, this is to handle the size search space. | ||||
|       if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny': | ||||
|         return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'], | ||||
|                 'genotype': str2structure(self.arch_config['genotype']), 'num_classes': self.arch_config['class_num']} | ||||
|       # In this case, this is NAS-Bench-201 | ||||
|       else: | ||||
|         return {'name': 'infer.tiny', 'C': self.arch_config['channel'], | ||||
|                 'N'   : self.arch_config['num_cells'], | ||||
|                 'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']} | ||||
|  | ||||
|   def state_dict(self): | ||||
|     _state_dict = {key: value for key, value in self.__dict__.items()} | ||||
|     return _state_dict | ||||
|  | ||||
|   def load_state_dict(self, state_dict): | ||||
|     self.__dict__.update(state_dict) | ||||
|  | ||||
|   @staticmethod | ||||
|   def create_from_state_dict(state_dict): | ||||
|     x = ResultsCount(None, None, None, None, None, None, None, None, None, None) | ||||
|     x.load_state_dict(state_dict) | ||||
|     return x | ||||
| @@ -1,76 +0,0 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from .construct_utils import drop_path | ||||
| from .head_utils      import CifarHEAD, AuxiliaryHeadCIFAR | ||||
| from .base_cells      import InferCell | ||||
|  | ||||
|  | ||||
| class NetworkCIFAR(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, stem_multiplier, auxiliary, genotype, num_classes): | ||||
|     super(NetworkCIFAR, self).__init__() | ||||
|     self._C               = C | ||||
|     self._layerN          = N | ||||
|     self._stem_multiplier = stem_multiplier | ||||
|  | ||||
|     C_curr = self._stem_multiplier * C | ||||
|     self.stem = CifarHEAD(C_curr) | ||||
|    | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|     block_indexs     = [0    ] * N + [-1  ] + [1    ] * N + [-1  ] + [2    ] * N | ||||
|     block2index      = {0:[], 1:[], 2:[]} | ||||
|  | ||||
|     C_prev_prev, C_prev, C_curr = C_curr, C_curr, C | ||||
|     reduction_prev, spatial, dims = False, 1, [] | ||||
|     self.auxiliary_index = None | ||||
|     self.auxiliary_head  = None | ||||
|     self.cells = nn.ModuleList() | ||||
|     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) | ||||
|       reduction_prev = reduction | ||||
|       self.cells.append( cell ) | ||||
|       C_prev_prev, C_prev = C_prev, cell._multiplier*C_curr | ||||
|       if reduction and C_curr == C*4: | ||||
|         if auxiliary: | ||||
|           self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) | ||||
|           self.auxiliary_index = index | ||||
|  | ||||
|       if reduction: spatial *= 2 | ||||
|       dims.append( (C_prev, spatial) ) | ||||
|        | ||||
|     self._Layer= len(self.cells) | ||||
|  | ||||
|  | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier = nn.Linear(C_prev, num_classes) | ||||
|     self.drop_path_prob = -1 | ||||
|  | ||||
|   def update_drop_path(self, drop_path_prob): | ||||
|     self.drop_path_prob = drop_path_prob | ||||
|  | ||||
|   def auxiliary_param(self): | ||||
|     if self.auxiliary_head is None: return [] | ||||
|     else: return list( self.auxiliary_head.parameters() ) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.extra_repr() | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N={_layerN}, L={_Layer}, stem={_stem_multiplier}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     stem_feature, logits_aux = self.stem(inputs), None | ||||
|     cell_results = [stem_feature, stem_feature] | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) | ||||
|       cell_results.append( cell_feature ) | ||||
|  | ||||
|       if self.auxiliary_index is not None and i == self.auxiliary_index and self.training: | ||||
|         logits_aux = self.auxiliary_head( cell_results[-1] ) | ||||
|     out = self.global_pooling( cell_results[-1] ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     if logits_aux is None: return out, logits | ||||
|     else                 : return out, [logits, logits_aux] | ||||
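| # Construction sketch (added for illustration; `some_genotype` is a placeholder | ||||
| # that must provide the normal/reduce fields consumed by InferCell, e.g. an | ||||
| # entry of the Networks dict from this package's genotypes module): | ||||
| #   net = NetworkCIFAR(C=36, N=6, stem_multiplier=3, auxiliary=True, | ||||
| #                      genotype=some_genotype, num_classes=10) | ||||
| #   net.update_drop_path(0.2) | ||||
| #   net.eval()  # in train mode with auxiliary=True, logits is [logits, logits_aux] | ||||
| #   out, logits = net(torch.randn(2, 3, 32, 32)) | ||||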
| @@ -1,77 +0,0 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from .construct_utils import drop_path | ||||
| from .base_cells import InferCell | ||||
| from .head_utils import ImageNetHEAD, AuxiliaryHeadImageNet | ||||
|  | ||||
|  | ||||
| class NetworkImageNet(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, auxiliary, genotype, num_classes): | ||||
|     super(NetworkImageNet, self).__init__() | ||||
|     self._C          = C | ||||
|     self._layerN     = N | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4] * N | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|     self.stem0 = nn.Sequential( | ||||
|       nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|       nn.BatchNorm2d(C // 2), | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False), | ||||
|       nn.BatchNorm2d(C), | ||||
|     ) | ||||
|  | ||||
|     self.stem1 = nn.Sequential( | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False), | ||||
|       nn.BatchNorm2d(C), | ||||
|     ) | ||||
|  | ||||
|     C_prev_prev, C_prev, C_curr, reduction_prev = C, C, C, True | ||||
|  | ||||
|     self.cells = nn.ModuleList() | ||||
|     self.auxiliary_index = None | ||||
|     for i, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) | ||||
|       reduction_prev = reduction | ||||
|       self.cells += [cell] | ||||
|       C_prev_prev, C_prev = C_prev, cell._multiplier * C_curr | ||||
|       if reduction and C_curr == C*4: | ||||
|         C_to_auxiliary = C_prev | ||||
|         self.auxiliary_index = i | ||||
|    | ||||
|     self._NNN = len(self.cells) | ||||
|     if auxiliary: | ||||
|       self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes) | ||||
|     else: | ||||
|       self.auxiliary_head = None | ||||
|     self.global_pooling = nn.AvgPool2d(7) | ||||
|     self.classifier     = nn.Linear(C_prev, num_classes) | ||||
|     self.drop_path_prob = -1 | ||||
|  | ||||
|   def update_drop_path(self, drop_path_prob): | ||||
|     self.drop_path_prob = drop_path_prob | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N=[{_layerN}, {_NNN}], aux-index={auxiliary_index}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.extra_repr() | ||||
|  | ||||
|   def auxiliary_param(self): | ||||
|     if self.auxiliary_head is None: return [] | ||||
|     else: return list( self.auxiliary_head.parameters() ) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     s0 = self.stem0(inputs) | ||||
|     s1 = self.stem1(s0) | ||||
|     logits_aux = None | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       s0, s1 = s1, cell(s0, s1, self.drop_path_prob) | ||||
|       if i == self.auxiliary_index and self.auxiliary_head and self.training: | ||||
|         logits_aux = self.auxiliary_head(s1) | ||||
|     out = self.global_pooling(s1) | ||||
|     logits = self.classifier(out.view(out.size(0), -1)) | ||||
|  | ||||
|     if logits_aux is None: return out, logits | ||||
|     else                 : return out, [logits, logits_aux] | ||||
| @@ -1,5 +0,0 @@ | ||||
| # Performance-Aware Template Network for One-Shot Neural Architecture Search | ||||
| from .CifarNet  import NetworkCIFAR as CifarNet | ||||
| from .ImageNet  import NetworkImageNet as ImageNet | ||||
| from .genotypes import Networks | ||||
| from .genotypes import build_genotype_from_dict | ||||
| @@ -1,173 +0,0 @@ | ||||
| import math | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from .construct_utils import drop_path | ||||
| from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN | ||||
|  | ||||
|  | ||||
| class MixedOp(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, stride, PRIMITIVES): | ||||
|     super(MixedOp, self).__init__() | ||||
|     self._ops = nn.ModuleList() | ||||
|     self.name2idx = {} | ||||
|     for idx, primitive in enumerate(PRIMITIVES): | ||||
|       op = OPS[primitive](C, C, stride, False) | ||||
|       self._ops.append(op) | ||||
|       assert primitive not in self.name2idx, '{:} is already in name2idx'.format(primitive) | ||||
|       self.name2idx[primitive] = idx | ||||
|  | ||||
|   def forward(self, x, weights, op_name): | ||||
|     if op_name is None: | ||||
|       if weights is None: | ||||
|         return [op(x) for op in self._ops] | ||||
|       else: | ||||
|         return sum(w * op(x) for w, op in zip(weights, self._ops)) | ||||
|     else: | ||||
|       op_index = self.name2idx[op_name] | ||||
|       return self._ops[op_index](x) | ||||
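|   # Behavior sketch (illustrative) for a MixedOp instance `mop`: | ||||
|   #   mop(x, None, 'skip_connect')  # run a single named operation | ||||
|   #   mop(x, weights, None)         # DARTS-style weighted sum over all candidate ops | ||||
|   #   mop(x, None, None)            # a list with every candidate op's output | ||||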
|  | ||||
|  | ||||
|  | ||||
| class SearchCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual): | ||||
|     super(SearchCell, self).__init__() | ||||
|     self.reduction  = reduction | ||||
|     self.PRIMITIVES = deepcopy(PRIMITIVES) | ||||
|    | ||||
|     if reduction_prev: | ||||
|       self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False) | ||||
|     else: | ||||
|       self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) | ||||
|     self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) | ||||
|     self._steps        = steps | ||||
|     self._multiplier   = multiplier | ||||
|     self._use_residual = use_residual | ||||
|  | ||||
|     self._ops = nn.ModuleList() | ||||
|     for i in range(self._steps): | ||||
|       for j in range(2+i): | ||||
|         stride = 2 if reduction and j < 2 else 1 | ||||
|         op = MixedOp(C, stride, self.PRIMITIVES) | ||||
|         self._ops.append(op) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes): | ||||
|     if modes[0] is None: | ||||
|       if modes[1] == 'normal': | ||||
|         output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob) | ||||
|       elif modes[1] == 'only_W': | ||||
|         output = self.__forwardOnlyW(S0, S1, drop_prob) | ||||
|     else: | ||||
|       test_genotype = modes[0] | ||||
|       if self.reduction: operations, concats = test_genotype.reduce, test_genotype.reduce_concat | ||||
|       else             : operations, concats = test_genotype.normal, test_genotype.normal_concat | ||||
|       s0, s1 = self.preprocess0(S0), self.preprocess1(S1) | ||||
|       states, offset = [s0, s1], 0 | ||||
|       assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations)) | ||||
|       for i, (opA, opB) in enumerate(operations): | ||||
|         A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0]) | ||||
|         B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0]) | ||||
|         state = A + B | ||||
|         offset += len(states) | ||||
|         states.append(state) | ||||
|       output = torch.cat([states[i] for i in concats], dim=1) | ||||
|     if self._use_residual and S1.size() == output.size(): | ||||
|       return S1 + output | ||||
|     else: return output | ||||
|    | ||||
|   def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob): | ||||
|     s0, s1 = self.preprocess0(S0), self.preprocess1(S1) | ||||
|     states, offset = [s0, s1], 0 | ||||
|     for i in range(self._steps): | ||||
|       clist = [] | ||||
|       for j, h in enumerate(states): | ||||
|         x = self._ops[offset+j](h, weights[offset+j], None) | ||||
|         if self.training and drop_prob > 0.: | ||||
|           x = drop_path(x, math.pow(drop_prob, 1./len(states))) | ||||
|         clist.append( x ) | ||||
|       connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0) | ||||
|       state = sum(w * node for w, node in zip(connection, clist)) | ||||
|       offset += len(states) | ||||
|       states.append(state) | ||||
|     return torch.cat(states[-self._multiplier:], dim=1) | ||||
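|  | ||||
|   # Note on the connection step above: connect[str(i)] is read as a 1 x m row | ||||
|   # of edge weights and adjacency[i] as an m x (i+2) mapping onto the current | ||||
|   # states, so mm(...).squeeze(0) yields one scalar weight per entry of clist | ||||
|   # (the same product that return_alphas_str prints for inspection). | ||||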
|  | ||||
|   def __forwardOnlyW(self, S0, S1, drop_prob): | ||||
|     s0, s1 = self.preprocess0(S0), self.preprocess1(S1) | ||||
|     states, offset = [s0, s1], 0 | ||||
|     for i in range(self._steps): | ||||
|       clist = [] | ||||
|       for j, h in enumerate(states): | ||||
|         xs = self._ops[offset+j](h, None, None) | ||||
|         clist += xs | ||||
|       if self.training and drop_prob > 0.: | ||||
|         xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist] | ||||
|       else: xlist = clist | ||||
|       state = sum(xlist) * 2 / len(xlist)  # uniform average, scaled by 2 to match the sum of two selected edges | ||||
|       offset += len(states) | ||||
|       states.append(state) | ||||
|     return torch.cat(states[-self._multiplier:], dim=1) | ||||
|  | ||||
|  | ||||
|  | ||||
| class InferCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): | ||||
|     super(InferCell, self).__init__() | ||||
|     print(C_prev_prev, C_prev, C)  # debug: show the channel configuration of this cell | ||||
|  | ||||
|     if reduction_prev is None: | ||||
|       self.preprocess0 = Identity() | ||||
|     elif reduction_prev: | ||||
|       self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2) | ||||
|     else: | ||||
|       self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) | ||||
|     self.preprocess1   = ReLUConvBN(C_prev, C, 1, 1, 0) | ||||
|      | ||||
|     if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat | ||||
|     else        : step_ops, concat = genotype.normal, genotype.normal_concat | ||||
|     self._steps        = len(step_ops) | ||||
|     self._concat       = concat | ||||
|     self._multiplier   = len(concat) | ||||
|     self._ops          = nn.ModuleList() | ||||
|     self._indices      = [] | ||||
|     for operations in step_ops: | ||||
|       for name, index in operations: | ||||
|         stride = 2 if reduction and index < 2 else 1 | ||||
|         if reduction_prev is None and index == 0: | ||||
|           op = OPS[name](C_prev_prev, C, stride, True) | ||||
|         else: | ||||
|           op = OPS[name](C          , C, stride, True) | ||||
|         self._ops.append( op ) | ||||
|         self._indices.append( index ) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, S0, S1, drop_prob): | ||||
|     s0 = self.preprocess0(S0) | ||||
|     s1 = self.preprocess1(S1) | ||||
|  | ||||
|     states = [s0, s1] | ||||
|     for i in range(self._steps): | ||||
|       h1 = states[self._indices[2*i]] | ||||
|       h2 = states[self._indices[2*i+1]] | ||||
|       op1 = self._ops[2*i] | ||||
|       op2 = self._ops[2*i+1] | ||||
|       h1 = op1(h1) | ||||
|       h2 = op2(h2) | ||||
|       if self.training and drop_prob > 0.: | ||||
|         if not isinstance(op1, Identity): | ||||
|           h1 = drop_path(h1, drop_prob) | ||||
|         if not isinstance(op2, Identity): | ||||
|           h2 = drop_path(h2, drop_prob) | ||||
|  | ||||
|       state = h1 + h2 | ||||
|       states += [state] | ||||
|     output = torch.cat([states[i] for i in self._concat], dim=1) | ||||
|     return output | ||||
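|  | ||||
| # A wiring sketch (using DARTS_V1's normal cell as an illustration): step 0 | ||||
| # applies its two chosen ops to states[_indices[0]] and states[_indices[1]], | ||||
| # each step appends exactly one new state, and the output concatenates the | ||||
| # states listed in _concat, giving len(_concat) * C output channels. | ||||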
| @@ -1,60 +0,0 @@ | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
|  | ||||
| def drop_path(x, drop_prob): | ||||
|   if drop_prob > 0.: | ||||
|     keep_prob = 1. - drop_prob | ||||
|     mask = x.new_zeros(x.size(0), 1, 1, 1) | ||||
|     mask = mask.bernoulli_(keep_prob) | ||||
|     x = torch.div(x, keep_prob) | ||||
|     x.mul_(mask) | ||||
|   return x | ||||
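|  | ||||
| # A minimal usage sketch (not part of the original module): drop_path zeroes | ||||
| # a whole sample's path and rescales the survivors by 1/keep_prob, so the | ||||
| # expected activation is unchanged. | ||||
| def _drop_path_sketch(): | ||||
|   x = torch.ones(8, 4, 1, 1) | ||||
|   y = drop_path(x.clone(), 0.2) | ||||
|   nonzero = y[y != 0] | ||||
|   # every surviving sample is scaled to 1 / 0.8 = 1.25 | ||||
|   assert torch.allclose(nonzero, torch.full_like(nonzero, 1.0 / 0.8)) | ||||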
|  | ||||
|  | ||||
| def return_alphas_str(basemodel): | ||||
|   if hasattr(basemodel, 'alphas_normal'): | ||||
|     string = 'normal [{:}] : \n-->>{:}'.format(basemodel.alphas_normal.size(), F.softmax(basemodel.alphas_normal, dim=-1) ) | ||||
|   else: string = '' | ||||
|   if hasattr(basemodel, 'alphas_reduce'): | ||||
|     string = string + '\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) ) | ||||
|  | ||||
|   if hasattr(basemodel, 'get_adjacency'): | ||||
|     adjacency = basemodel.get_adjacency() | ||||
|     for i in range( len(adjacency) ): | ||||
|       weight = F.softmax( basemodel.connect_normal[str(i)], dim=-1 ) | ||||
|       adj = torch.mm(weight, adjacency[i]).view(-1) | ||||
|       adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()] | ||||
|       string = string + '\nnormal--{:}-->{:}'.format(i, ', '.join(adj)) | ||||
|     for i in range( len(adjacency) ): | ||||
|       weight = F.softmax( basemodel.connect_reduce[str(i)], dim=-1 ) | ||||
|       adj = torch.mm(weight, adjacency[i]).view(-1) | ||||
|       adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()] | ||||
|       string = string + '\nreduce--{:}-->{:}'.format(i, ', '.join(adj)) | ||||
|  | ||||
|   if hasattr(basemodel, 'alphas_connect'): | ||||
|     weight = F.softmax(basemodel.alphas_connect, dim=-1).cpu() | ||||
|     ZERO = ['{:.3f}'.format(x) for x in weight[:,0].tolist()] | ||||
|     IDEN = ['{:.3f}'.format(x) for x in weight[:,1].tolist()] | ||||
|     string = string + '\nconnect [{:}] : \n ->{:}\n ->{:}'.format( list(basemodel.alphas_connect.size()), ZERO, IDEN ) | ||||
|   else: | ||||
|     string = string + '\nconnect = None' | ||||
|    | ||||
|   if hasattr(basemodel, 'get_gcn_out'): | ||||
|     outputs = basemodel.get_gcn_out(True) | ||||
|     for i, output in enumerate(outputs): | ||||
|       string = string + '\nnormal:[{:}] : {:}'.format(i, F.softmax(output, dim=-1) ) | ||||
|  | ||||
|   return string | ||||
|  | ||||
|  | ||||
| def remove_duplicate_archs(all_archs): | ||||
|   archs = [] | ||||
|   str_archs = ['{:}'.format(x) for x in all_archs] | ||||
|   for i, arch_x in enumerate(str_archs): | ||||
|     choose = True | ||||
|     for j in range(i): | ||||
|       if arch_x == str_archs[j]: | ||||
|         choose = False; break | ||||
|     if choose: archs.append(all_archs[i]) | ||||
|   return archs | ||||
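|  | ||||
| # Usage sketch: deduplication keys on '{:}'.format(arch), so any objects with | ||||
| # a stable string form work, and the first occurrence of each is kept: | ||||
| # | ||||
| #   >>> remove_duplicate_archs(['a', 'b', 'a', 'c', 'b']) | ||||
| #   ['a', 'b', 'c'] | ||||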
| @@ -1,182 +0,0 @@ | ||||
| from collections import namedtuple | ||||
|  | ||||
| Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat connectN connects') | ||||
| #Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') | ||||
|  | ||||
| PRIMITIVES_small = [ | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'skip_connect', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'conv_3x1_1x3', | ||||
| ] | ||||
|  | ||||
| PRIMITIVES_large = [ | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'skip_connect', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'dil_conv_3x3', | ||||
|     'dil_conv_5x5', | ||||
|     'conv_3x1_1x3', | ||||
| ] | ||||
|  | ||||
| PRIMITIVES_huge = [ | ||||
|     'skip_connect', | ||||
|     'nor_conv_1x1', | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'nor_conv_3x3', | ||||
|     'sep_conv_3x3', | ||||
|     'dil_conv_3x3', | ||||
|     'conv_3x1_1x3', | ||||
|     'sep_conv_5x5', | ||||
|     'dil_conv_5x5', | ||||
|     'sep_conv_7x7', | ||||
|     'conv_7x1_1x7', | ||||
|     'att_squeeze', | ||||
| ] | ||||
|  | ||||
| PRIMITIVES = {'small': PRIMITIVES_small, | ||||
|               'large': PRIMITIVES_large, | ||||
|               'huge' : PRIMITIVES_huge} | ||||
|  | ||||
| NASNet = Genotype( | ||||
|   normal = [ | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_3x3', 0)), | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_3x3', 0)), | ||||
|     (('avg_pool_3x3', 1), ('skip_connect', 0)), | ||||
|     (('avg_pool_3x3', 0), ('avg_pool_3x3', 0)), | ||||
|     (('sep_conv_3x3', 1), ('skip_connect', 1)), | ||||
|   ], | ||||
|   normal_concat = [2, 3, 4, 5, 6], | ||||
|   reduce = [ | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_7x7', 0)), | ||||
|     (('max_pool_3x3', 1), ('sep_conv_7x7', 0)), | ||||
|     (('avg_pool_3x3', 1), ('sep_conv_5x5', 0)), | ||||
|     (('skip_connect', 3), ('avg_pool_3x3', 2)), | ||||
|     (('sep_conv_3x3', 2), ('max_pool_3x3', 1)), | ||||
|   ], | ||||
|   reduce_concat = [4, 5, 6], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
| PNASNet = Genotype( | ||||
|   normal = [ | ||||
|     (('sep_conv_5x5', 0), ('max_pool_3x3', 0)), | ||||
|     (('sep_conv_7x7', 1), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_3x3', 4), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_3x3', 0), ('skip_connect', 1)), | ||||
|   ], | ||||
|   normal_concat = [2, 3, 4, 5, 6], | ||||
|   reduce = [ | ||||
|     (('sep_conv_5x5', 0), ('max_pool_3x3', 0)), | ||||
|     (('sep_conv_7x7', 1), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_3x3', 4), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_3x3', 0), ('skip_connect', 1)), | ||||
|   ], | ||||
|   reduce_concat = [2, 3, 4, 5, 6], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
|  | ||||
| DARTS_V1 = Genotype( | ||||
|   normal=[ | ||||
|     (('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1 | ||||
|     (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2 | ||||
|     (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3 | ||||
|     (('sep_conv_3x3', 0), ('skip_connect', 2))  # step 4 | ||||
|   ], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1 | ||||
|     (('skip_connect', 2), ('max_pool_3x3', 0)), # step 2 | ||||
|     (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3 | ||||
|     (('skip_connect', 2), ('avg_pool_3x3', 0))  # step 4 | ||||
|   ], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 | ||||
| DARTS_V2 = Genotype( | ||||
|   normal=[ | ||||
|     (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1 | ||||
|     (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2 | ||||
|     (('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3 | ||||
|     (('skip_connect', 0), ('dil_conv_3x3', 2))  # step 4 | ||||
|   ], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1 | ||||
|     (('skip_connect', 2), ('max_pool_3x3', 1)), # step 2 | ||||
|     (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3 | ||||
|     (('skip_connect', 2), ('max_pool_3x3', 1))  # step 4 | ||||
|   ], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
|  | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 | ||||
| SETN = Genotype( | ||||
|   normal=[ | ||||
|     (('skip_connect', 0), ('sep_conv_5x5', 1)), | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_5x5', 3)), | ||||
|     (('max_pool_3x3', 1), ('conv_3x1_1x3', 4))], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('sep_conv_3x3', 0), ('sep_conv_5x5', 1)), | ||||
|     (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)), | ||||
|     (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)), | ||||
|     (('avg_pool_3x3', 0), ('skip_connect', 1))], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None | ||||
| ) | ||||
|  | ||||
|  | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 | ||||
| GDAS_V1 = Genotype( | ||||
|   normal=[ | ||||
|     (('skip_connect', 0), ('skip_connect', 1)), | ||||
|     (('skip_connect', 0), ('sep_conv_5x5', 2)), | ||||
|     (('sep_conv_3x3', 3), ('skip_connect', 0)), | ||||
|     (('sep_conv_5x5', 4), ('sep_conv_3x3', 3))], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),  | ||||
|     (('sep_conv_5x5', 2), ('sep_conv_5x5', 1)), | ||||
|     (('dil_conv_5x5', 2), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_5x5', 1))], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None | ||||
| ) | ||||
|  | ||||
|  | ||||
|  | ||||
| Networks = {'DARTS_V1': DARTS_V1, | ||||
|             'DARTS_V2': DARTS_V2, | ||||
|             'DARTS'   : DARTS_V2, | ||||
|             'NASNet'  : NASNet, | ||||
|             'GDAS_V1' : GDAS_V1, | ||||
|             'PNASNet' : PNASNet, | ||||
|             'SETN'    : SETN, | ||||
|            } | ||||
|  | ||||
| # Build a Genotype from its dict form (e.g., a genotype stored in a search checkpoint). | ||||
| def build_genotype_from_dict(xdict): | ||||
|   def remove_value(nodes): | ||||
|     return [tuple([(x[0], x[1]) for x in node]) for node in nodes] | ||||
|   genotype = Genotype( | ||||
|       normal=remove_value(xdict['normal']), | ||||
|       normal_concat=xdict['normal_concat'], | ||||
|       reduce=remove_value(xdict['reduce']), | ||||
|       reduce_concat=xdict['reduce_concat'], | ||||
|       connectN=None, connects=None | ||||
|       ) | ||||
|   return genotype | ||||
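|  | ||||
| # A round-trip sketch: remove_value appears to exist because stored nodes may | ||||
| # carry extra per-edge values beyond the (op_name, input_index) pair, which | ||||
| # are dropped here. For a plain Genotype the conversion is lossless: | ||||
| # | ||||
| #   >>> build_genotype_from_dict(DARTS_V1._asdict()) == DARTS_V1 | ||||
| #   True | ||||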
| @@ -1,65 +0,0 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| class ImageNetHEAD(nn.Sequential): | ||||
|   def __init__(self, C, stride=2): | ||||
|     super(ImageNetHEAD, self).__init__() | ||||
|     self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False)) | ||||
|     self.add_module('bn1'  , nn.BatchNorm2d(C // 2)) | ||||
|     self.add_module('relu1', nn.ReLU(inplace=True)) | ||||
|     self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False)) | ||||
|     self.add_module('bn2'  , nn.BatchNorm2d(C)) | ||||
|  | ||||
|  | ||||
| class CifarHEAD(nn.Sequential): | ||||
|   def __init__(self, C): | ||||
|     super(CifarHEAD, self).__init__() | ||||
|     self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False)) | ||||
|     self.add_module('bn', nn.BatchNorm2d(C)) | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadCIFAR(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, num_classes): | ||||
|     """assuming input size 8x8""" | ||||
|     super(AuxiliaryHeadCIFAR, self).__init__() | ||||
|     self.features = nn.Sequential( | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 | ||||
|       nn.Conv2d(C, 128, 1, bias=False), | ||||
|       nn.BatchNorm2d(128), | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.Conv2d(128, 768, 2, bias=False), | ||||
|       nn.BatchNorm2d(768), | ||||
|       nn.ReLU(inplace=True) | ||||
|     ) | ||||
|     self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.features(x) | ||||
|     x = self.classifier(x.view(x.size(0),-1)) | ||||
|     return x | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadImageNet(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, num_classes): | ||||
|     """assuming input size 14x14""" | ||||
|     super(AuxiliaryHeadImageNet, self).__init__() | ||||
|     self.features = nn.Sequential( | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False), | ||||
|       nn.Conv2d(C, 128, 1, bias=False), | ||||
|       nn.BatchNorm2d(128), | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.Conv2d(128, 768, 2, bias=False), | ||||
|       nn.BatchNorm2d(768), | ||||
|       nn.ReLU(inplace=True) | ||||
|     ) | ||||
|     self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.features(x) | ||||
|     x = self.classifier(x.view(x.size(0),-1)) | ||||
|     return x | ||||
| @@ -1,31 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| # This package keeps AutoDL-Projects compatible with the old GDAS projects. | ||||
| # Ideally, it will be merged into lib/models/cell_infers in the future. | ||||
| # Currently, it is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019). | ||||
| ################################################## | ||||
|  | ||||
| import os, torch | ||||
|  | ||||
| def obtain_nas_infer_model(config, extra_model_path=None): | ||||
|    | ||||
|   if config.arch == 'dxys': | ||||
|     from .DXYs import CifarNet, ImageNet, Networks | ||||
|     from .DXYs import build_genotype_from_dict | ||||
|     if config.genotype is None: | ||||
|       if extra_model_path is not None and not os.path.isfile(extra_model_path): | ||||
|         raise ValueError('When genotype in config is None, extra_model_path must be a valid file path instead of {:}'.format(extra_model_path)) | ||||
|       xdata = torch.load(extra_model_path) | ||||
|       current_epoch = xdata['epoch'] | ||||
|       genotype_dict = xdata['genotypes'][current_epoch-1] | ||||
|       genotype = build_genotype_from_dict(genotype_dict) | ||||
|     else: | ||||
|       genotype = Networks[config.genotype] | ||||
|     if config.dataset == 'cifar': | ||||
|       return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num) | ||||
|     elif config.dataset == 'imagenet': | ||||
|       return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num) | ||||
|     else: raise ValueError('invalid dataset : {:}'.format(config.dataset)) | ||||
|   else: | ||||
|     raise ValueError('invalid nas arch type : {:}'.format(config.arch)) | ||||
| @@ -1,183 +0,0 @@ | ||||
| ############################################################################################## | ||||
| # This code is copied and modified from Hanxiao Liu's work (https://github.com/quark0/darts) # | ||||
| ############################################################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| OPS = { | ||||
|   'none'         : lambda C_in, C_out, stride, affine: Zero(stride), | ||||
|   'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'), | ||||
|   'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'), | ||||
|   'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), affine), | ||||
|   'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), affine), | ||||
|   'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), affine), | ||||
|   'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine), | ||||
|   'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine), | ||||
|   'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine), | ||||
|   'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine), | ||||
|   'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine), | ||||
|   'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine), | ||||
|   'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: Conv717(C_in, C_out, stride, affine), | ||||
|   'conv_3x1_1x3' : lambda C_in, C_out, stride, affine: Conv313(C_in, C_out, stride, affine) | ||||
| } | ||||
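|  | ||||
| # Usage sketch: every entry shares the (C_in, C_out, stride, affine) call | ||||
| # signature, so candidate operations can be built uniformly (sizes here are | ||||
| # illustrative): | ||||
| # | ||||
| #   op   = OPS['sep_conv_3x3'](16, 16, 1, True)  # 16 -> 16 channels, stride 1 | ||||
| #   down = OPS['skip_connect'](16, 32, 2, True)  # becomes a FactorizedReduce | ||||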
|  | ||||
|  | ||||
| class POOLING(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, mode): | ||||
|     super(POOLING, self).__init__() | ||||
|     if C_in == C_out: | ||||
|       self.preprocess = None | ||||
|     else: | ||||
|       self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0) | ||||
|     if mode == 'avg'  : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) | ||||
|     elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1) | ||||
|     else              : raise ValueError('Invalid pooling mode : {:}'.format(mode)) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     if self.preprocess is not None: | ||||
|       x = self.preprocess(inputs) | ||||
|     else: x = inputs | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class Conv313(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine): | ||||
|     super(Conv313, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in , C_out, (1,3), stride=(1, stride), padding=(0, 1), bias=False), | ||||
|       nn.Conv2d(C_out, C_out, (3,1), stride=(stride, 1), padding=(1, 0), bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class Conv717(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine): | ||||
|     super(Conv717, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in , C_out, (1,7), stride=(1, stride), padding=(0, 3), bias=False), | ||||
|       nn.Conv2d(C_out, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
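|  | ||||
| # Note: factoring a k x k convolution into 1 x k followed by k x 1 (Conv313, | ||||
| # Conv717) cuts the weights from k*k*C*C to 2*k*C*C when C_in == C_out, e.g. | ||||
| # 14*C^2 instead of 49*C^2 for k = 7, at the cost of a rank-limited filter. | ||||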
|  | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | ||||
|     super(ReLUConvBN, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class DilConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): | ||||
|     super(DilConv, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in,  kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine), | ||||
|       ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class SepConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | ||||
|     super(SepConv, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_in, affine=affine), | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=     1, padding=padding, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine), | ||||
|       ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class Identity(nn.Module): | ||||
|  | ||||
|   def __init__(self): | ||||
|     super(Identity, self).__init__() | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return x | ||||
|  | ||||
|  | ||||
| class Zero(nn.Module): | ||||
|  | ||||
|   def __init__(self, stride): | ||||
|     super(Zero, self).__init__() | ||||
|     self.stride = stride | ||||
|  | ||||
|   def forward(self, x): | ||||
|     if self.stride == 1: | ||||
|       return x.mul(0.) | ||||
|     return x[:,:,::self.stride,::self.stride].mul(0.) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'stride={stride}'.format(**self.__dict__) | ||||
|  | ||||
|  | ||||
| class FactorizedReduce(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine=True): | ||||
|     super(FactorizedReduce, self).__init__() | ||||
|     self.stride = stride | ||||
|     self.C_in   = C_in   | ||||
|     self.C_out  = C_out   | ||||
|     self.relu   = nn.ReLU(inplace=False) | ||||
|     if stride == 2: | ||||
|       #assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) | ||||
|       C_outs = [C_out // 2, C_out - C_out // 2] | ||||
|       self.convs = nn.ModuleList() | ||||
|       for i in range(2): | ||||
|         self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) ) | ||||
|       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) | ||||
|     elif stride == 4: | ||||
|       assert C_out % 4 == 0, 'C_out : {:}'.format(C_out) | ||||
|       self.convs = nn.ModuleList() | ||||
|       for i in range(4): | ||||
|         self.convs.append( nn.Conv2d(C_in, C_out // 4, 1, stride=stride, padding=0, bias=False) ) | ||||
|       self.pad = nn.ConstantPad2d((0, 3, 0, 3), 0) | ||||
|     else: | ||||
|       raise ValueError('Invalid stride : {:}'.format(stride)) | ||||
|      | ||||
|     self.bn = nn.BatchNorm2d(C_out, affine=affine) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.relu(x) | ||||
|     y = self.pad(x) | ||||
|     if self.stride == 2: | ||||
|       out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1) | ||||
|     else: | ||||
|       out = torch.cat([self.convs[0](x),            self.convs[1](y[:,:,1:-2,1:-2]), | ||||
|                        self.convs[2](y[:,:,2:-1,2:-1]), self.convs[3](y[:,:,3:,3:])], dim=1) | ||||
|     out = self.bn(out) | ||||
|     return out | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) | ||||
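|  | ||||
| # A minimal shape-check sketch (illustrative sizes, not part of the original | ||||
| # module): for stride 2, two 1x1 stride-2 convolutions act on x and on a | ||||
| # one-pixel-shifted copy, each producing about C_out // 2 channels, and their | ||||
| # concatenation halves the spatial resolution. | ||||
| def _factorized_reduce_sketch(): | ||||
|   fr = FactorizedReduce(16, 32, stride=2) | ||||
|   out = fr(torch.randn(2, 16, 8, 8)) | ||||
|   assert tuple(out.shape) == (2, 32, 4, 4) | ||||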
| @@ -1,36 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .starts import prepare_seed | ||||
| from .starts import prepare_logger | ||||
| from .starts import get_machine_info | ||||
| from .starts import save_checkpoint | ||||
| from .starts import copy_checkpoint | ||||
| from .optimizers import get_optim_scheduler | ||||
| from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed | ||||
| from .funcs_nasbench import pure_evaluate as bench_pure_evaluate | ||||
| from .funcs_nasbench import get_nas_bench_loaders | ||||
|  | ||||
|  | ||||
| def get_procedures(procedure): | ||||
|     from .basic_main import basic_train, basic_valid | ||||
|     from .search_main import search_train, search_valid | ||||
|     from .search_main_v2 import search_train_v2 | ||||
|     from .simple_KD_main import simple_KD_train, simple_KD_valid | ||||
|  | ||||
|     train_funcs = { | ||||
|         "basic": basic_train, | ||||
|         "search": search_train, | ||||
|         "Simple-KD": simple_KD_train, | ||||
|         "search-v2": search_train_v2, | ||||
|     } | ||||
|     valid_funcs = { | ||||
|         "basic": basic_valid, | ||||
|         "search": search_valid, | ||||
|         "Simple-KD": simple_KD_valid, | ||||
|         "search-v2": search_valid, | ||||
|     } | ||||
|  | ||||
|     train_func = train_funcs[procedure] | ||||
|     valid_func = valid_funcs[procedure] | ||||
|     return train_func, valid_func | ||||
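|  | ||||
| # Usage sketch: `procedure` typically comes from the --procedure argument of | ||||
| # the training scripts, e.g.: | ||||
| # | ||||
| #     train_func, valid_func = get_procedures("basic") | ||||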
| @@ -1,100 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| # To be finished. | ||||
| # | ||||
| import os, sys, time, torch | ||||
| from typing import Optional, Text, Callable | ||||
|  | ||||
| # modules in AutoDL | ||||
| from log_utils import AverageMeter | ||||
| from log_utils import time_string | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def get_device(tensors): | ||||
|     if isinstance(tensors, (list, tuple)): | ||||
|         return get_device(tensors[0]) | ||||
|     elif isinstance(tensors, dict): | ||||
|         for key, value in tensors.items(): | ||||
|             return get_device(value)  # the device of the first value found | ||||
|     else: | ||||
|         return tensors.device | ||||
|  | ||||
|  | ||||
| def basic_train_fn( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     optimizer, | ||||
|     metric, | ||||
|     logger, | ||||
| ): | ||||
|     results = procedure( | ||||
|         xloader, | ||||
|         network, | ||||
|         criterion, | ||||
|         optimizer, | ||||
|         metric, | ||||
|         "train", | ||||
|         logger, | ||||
|     ) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def basic_eval_fn(xloader, network, metric, logger): | ||||
|     with torch.no_grad(): | ||||
|         results = procedure( | ||||
|             xloader, | ||||
|             network, | ||||
|             None, | ||||
|             None, | ||||
|             metric, | ||||
|             "valid", | ||||
|             logger, | ||||
|         ) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def procedure( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     optimizer, | ||||
|     metric, | ||||
|     mode: Text, | ||||
|     logger_fn: Optional[Callable] = None, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     if mode.lower() == "train": | ||||
|         network.train() | ||||
|     elif mode.lower() == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|  | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|         outputs = network(inputs) | ||||
|         targets = targets.to(get_device(outputs)) | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss = criterion(outputs, targets) | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         with torch.no_grad(): | ||||
|             results = metric(outputs, targets) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|     return metric.get_info() | ||||
| @@ -1,155 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch | ||||
|  | ||||
| # modules in AutoDL | ||||
| from log_utils import AverageMeter | ||||
| from log_utils import time_string | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| def basic_train( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     optim_config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     loss, acc1, acc5 = procedure( | ||||
|         xloader, | ||||
|         network, | ||||
|         criterion, | ||||
|         scheduler, | ||||
|         optimizer, | ||||
|         "train", | ||||
|         optim_config, | ||||
|         extra_info, | ||||
|         print_freq, | ||||
|         logger, | ||||
|     ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def basic_valid( | ||||
|     xloader, network, criterion, optim_config, extra_info, print_freq, logger | ||||
| ): | ||||
|     with torch.no_grad(): | ||||
|         loss, acc1, acc5 = procedure( | ||||
|             xloader, | ||||
|             network, | ||||
|             criterion, | ||||
|             None, | ||||
|             None, | ||||
|             "valid", | ||||
|             None, | ||||
|             extra_info, | ||||
|             print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|     return loss, acc1, acc5 | ||||
|  | ||||
|  | ||||
| def procedure( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     mode, | ||||
|     config, | ||||
|     extra_info, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time, losses, top1, top5 = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|     ) | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|  | ||||
|     # logger.log('[{:5s}] config ::  auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message())) | ||||
|     logger.log( | ||||
|         "[{:5s}] config ::  auxiliary={:}".format( | ||||
|             mode, config.auxiliary if hasattr(config, "auxiliary") else -1 | ||||
|         ) | ||||
|     ) | ||||
|     end = time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|         # calculate prediction and loss | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|  | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|  | ||||
|         features, logits = network(inputs) | ||||
|         if isinstance(logits, list): | ||||
|             assert len(logits) == 2, "logits must have {:} items instead of {:}".format( | ||||
|                 2, len(logits) | ||||
|             ) | ||||
|             logits, logits_aux = logits | ||||
|         else: | ||||
|             logits, logits_aux = logits, None | ||||
|         loss = criterion(logits, targets) | ||||
|         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0: | ||||
|             loss_aux = criterion(logits_aux, targets) | ||||
|             loss += config.auxiliary * loss_aux | ||||
|  | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if i % print_freq == 0 or (i + 1) == len(xloader): | ||||
|             Sstr = ( | ||||
|                 " {:5s} ".format(mode.upper()) | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)) | ||||
|             ) | ||||
|             if scheduler is not None: | ||||
|                 Sstr += " {:}".format(scheduler.get_min_info()) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format( | ||||
|                 loss=losses, top1=top1, top5=top5 | ||||
|             ) | ||||
|             Istr = "Size={:}".format(list(inputs.size())) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr) | ||||
|  | ||||
|     logger.log( | ||||
|         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format( | ||||
|             mode=mode.upper(), | ||||
|             top1=top1, | ||||
|             top5=top5, | ||||
|             error1=100 - top1.avg, | ||||
|             error5=100 - top5.avg, | ||||
|             loss=losses.avg, | ||||
|         ) | ||||
|     ) | ||||
|     return losses.avg, top1.avg, top5.avg | ||||
| @@ -1,20 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| import abc | ||||
|  | ||||
|  | ||||
| def obtain_accuracy(output, target, topk=(1,)): | ||||
|     """Computes the precision@k for the specified values of k""" | ||||
|     maxk = max(topk) | ||||
|     batch_size = target.size(0) | ||||
|  | ||||
|     _, pred = output.topk(maxk, 1, True, True) | ||||
|     pred = pred.t() | ||||
|     correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||||
|  | ||||
|     res = [] | ||||
|     for k in topk: | ||||
|         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the transposed slice may be non-contiguous | ||||
|         res.append(correct_k.mul_(100.0 / batch_size)) | ||||
|     return res | ||||
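|  | ||||
| # A worked example with hypothetical tensors (not part of the original | ||||
| # module): the true class is ranked first for two of the three samples and | ||||
| # second for the remaining one, so top-1 is 66.67% and top-2 is 100%. | ||||
| def _obtain_accuracy_sketch(): | ||||
|     import torch | ||||
|  | ||||
|     output = torch.tensor( | ||||
|         [[0.9, 0.05, 0.03, 0.02], [0.1, 0.7, 0.1, 0.1], [0.4, 0.5, 0.05, 0.05]] | ||||
|     ) | ||||
|     target = torch.tensor([0, 1, 0]) | ||||
|     top1, top2 = obtain_accuracy(output, target, topk=(1, 2)) | ||||
|     assert abs(top1.item() - 200.0 / 3) < 1e-3 and abs(top2.item() - 100.0) < 1e-3 | ||||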
| @@ -1,438 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| import os, time, copy, torch, pathlib | ||||
|  | ||||
| # modules in AutoDL | ||||
| import datasets | ||||
| from config_utils import load_config | ||||
| from procedures import prepare_seed, get_optim_scheduler | ||||
| from log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from models import get_cell_based_tiny_net | ||||
| from utils import get_model_infos | ||||
| from .eval_funcs import obtain_accuracy | ||||
|  | ||||
|  | ||||
| __all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"] | ||||
|  | ||||
|  | ||||
| def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): | ||||
|     data_time, batch_time, batch = AverageMeter(), AverageMeter(), None | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     latencies, device = [], torch.cuda.current_device() | ||||
|     network.eval() | ||||
|     with torch.no_grad(): | ||||
|         end = time.time() | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             targets = targets.cuda(device=device, non_blocking=True) | ||||
|             inputs = inputs.cuda(device=device, non_blocking=True) | ||||
|             data_time.update(time.time() - end) | ||||
|             # forward | ||||
|             features, logits = network(inputs) | ||||
|             loss = criterion(logits, targets) | ||||
|             batch_time.update(time.time() - end) | ||||
|             if batch is None or batch == inputs.size(0):  # skip the trailing partial batch | ||||
|                 batch = inputs.size(0) | ||||
|                 latencies.append(batch_time.val - data_time.val) | ||||
|             # record loss and accuracy | ||||
|             prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             losses.update(loss.item(), inputs.size(0)) | ||||
|             top1.update(prec1.item(), inputs.size(0)) | ||||
|             top5.update(prec5.item(), inputs.size(0)) | ||||
|             end = time.time() | ||||
|     if len(latencies) > 2: | ||||
|         latencies = latencies[1:]  # drop the first (warm-up) measurement | ||||
|     return losses.avg, top1.avg, top5.avg, latencies | ||||
|  | ||||
|  | ||||
| def procedure(xloader, network, criterion, scheduler, optimizer, mode: str): | ||||
|     losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     if mode == "train": | ||||
|         network.train() | ||||
|     elif mode == "valid": | ||||
|         network.eval() | ||||
|     else: | ||||
|         raise ValueError("The mode is not right : {:}".format(mode)) | ||||
|     device = torch.cuda.current_device() | ||||
|     data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time() | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if mode == "train": | ||||
|             scheduler.update(None, 1.0 * i / len(xloader)) | ||||
|  | ||||
|         targets = targets.cuda(device=device, non_blocking=True) | ||||
|         if mode == "train": | ||||
|             optimizer.zero_grad() | ||||
|         # forward | ||||
|         features, logits = network(inputs) | ||||
|         loss = criterion(logits, targets) | ||||
|         # backward | ||||
|         if mode == "train": | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|         # record loss and accuracy | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|         top5.update(prec5.item(), inputs.size(0)) | ||||
|         # count time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|     return losses.avg, top1.avg, top5.avg, batch_time.sum | ||||
|  | ||||
|  | ||||
| def evaluate_for_seed( | ||||
|     arch_config, opt_config, train_loader, valid_loaders, seed: int, logger | ||||
| ): | ||||
|  | ||||
|     prepare_seed(seed)  # random seed | ||||
|     net = get_cell_based_tiny_net(arch_config) | ||||
|     # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) | ||||
|     flop, param = get_model_infos(net, opt_config.xshape) | ||||
|     logger.log("Network : {:}".format(net.get_message()), False) | ||||
|     logger.log( | ||||
|         "{:} Seed-------------------------- {:} --------------------------".format( | ||||
|             time_string(), seed | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param)) | ||||
|     # train and valid | ||||
|     optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config) | ||||
|     default_device = torch.cuda.current_device() | ||||
|     network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda( | ||||
|         device=default_device | ||||
|     ) | ||||
|     criterion = criterion.cuda(device=default_device) | ||||
|     # start training | ||||
|     start_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         opt_config.epochs + opt_config.warmup, | ||||
|     ) | ||||
|     ( | ||||
|         train_losses, | ||||
|         train_acc1es, | ||||
|         train_acc5es, | ||||
|         valid_losses, | ||||
|         valid_acc1es, | ||||
|         valid_acc5es, | ||||
|     ) = ({}, {}, {}, {}, {}, {}) | ||||
|     train_times, valid_times, lrs = {}, {}, {} | ||||
|     for epoch in range(total_epoch): | ||||
|         scheduler.update(epoch, 0.0) | ||||
|         lr = min(scheduler.get_lr()) | ||||
|         train_loss, train_acc1, train_acc5, train_tm = procedure( | ||||
|             train_loader, network, criterion, scheduler, optimizer, "train" | ||||
|         ) | ||||
|         train_losses[epoch] = train_loss | ||||
|         train_acc1es[epoch] = train_acc1 | ||||
|         train_acc5es[epoch] = train_acc5 | ||||
|         train_times[epoch] = train_tm | ||||
|         lrs[epoch] = lr | ||||
|         with torch.no_grad(): | ||||
|             for key, xloader in valid_loaders.items(): | ||||
|                 valid_loss, valid_acc1, valid_acc5, valid_tm = procedure( | ||||
|                     xloader, network, criterion, None, None, "valid" | ||||
|                 ) | ||||
|                 valid_losses["{:}@{:}".format(key, epoch)] = valid_loss | ||||
|                 valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1 | ||||
|                 valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5 | ||||
|                 valid_times["{:}@{:}".format(key, epoch)] = valid_tm | ||||
|  | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format( | ||||
|                 time_string(), | ||||
|                 need_time, | ||||
|                 epoch, | ||||
|                 total_epoch, | ||||
|                 train_loss, | ||||
|                 train_acc1, | ||||
|                 train_acc5, | ||||
|                 valid_loss, | ||||
|                 valid_acc1, | ||||
|                 valid_acc5, | ||||
|                 lr, | ||||
|             ) | ||||
|         ) | ||||
|     info_seed = { | ||||
|         "flop": flop, | ||||
|         "param": param, | ||||
|         "arch_config": arch_config._asdict(), | ||||
|         "opt_config": opt_config._asdict(), | ||||
|         "total_epoch": total_epoch, | ||||
|         "train_losses": train_losses, | ||||
|         "train_acc1es": train_acc1es, | ||||
|         "train_acc5es": train_acc5es, | ||||
|         "train_times": train_times, | ||||
|         "valid_losses": valid_losses, | ||||
|         "valid_acc1es": valid_acc1es, | ||||
|         "valid_acc5es": valid_acc5es, | ||||
|         "valid_times": valid_times, | ||||
|         "learning_rates": lrs, | ||||
|         "net_state_dict": net.state_dict(), | ||||
|         "net_string": "{:}".format(net), | ||||
|         "finish-train": True, | ||||
|     } | ||||
|     return info_seed | ||||
|  | ||||
|  | ||||
| def get_nas_bench_loaders(workers): | ||||
|  | ||||
|     torch.set_num_threads(workers) | ||||
|  | ||||
|     root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve() | ||||
|     torch_dir = pathlib.Path(os.environ["TORCH_HOME"]) | ||||
|     # cifar | ||||
|     cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config" | ||||
|     cifar_config = load_config(cifar_config_path, None, None) | ||||
|     get_datasets = datasets.get_datasets  # a function to return the dataset | ||||
|     break_line = "-" * 150 | ||||
|     print("{:} Create data-loader for all datasets".format(time_string())) | ||||
|     print(break_line) | ||||
|     TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets( | ||||
|         "cifar10", str(torch_dir / "cifar.python"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar10_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None | ||||
|     ) | ||||
|     assert cifar10_splits.train[:10] == [ | ||||
|         0, | ||||
|         5, | ||||
|         7, | ||||
|         11, | ||||
|         13, | ||||
|         15, | ||||
|         16, | ||||
|         17, | ||||
|         20, | ||||
|         24, | ||||
|     ] and cifar10_splits.valid[:10] == [ | ||||
|         1, | ||||
|         2, | ||||
|         3, | ||||
|         4, | ||||
|         6, | ||||
|         8, | ||||
|         9, | ||||
|         10, | ||||
|         12, | ||||
|         14, | ||||
|     ] | ||||
|     temp_dataset = copy.deepcopy(TRAIN_CIFAR10) | ||||
|     temp_dataset.transform = VALID_CIFAR10.transform | ||||
|     # data loader | ||||
|     trainval_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     train_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_cifar10_loader = torch.utils.data.DataLoader( | ||||
|         temp_dataset, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__cifar10_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR10, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=False, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : trval-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(trainval_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : train-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(train_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : valid-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(valid_cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-10  : test--loader has {:3d} batch with {:} per batch".format( | ||||
|             len(test__cifar10_loader), cifar_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print(break_line) | ||||
|     # CIFAR-100 | ||||
|     TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets( | ||||
|         "cifar100", str(torch_dir / "cifar.python"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     cifar100_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None | ||||
|     ) | ||||
|     assert cifar100_splits.xvalid[:10] == [ | ||||
|         1, | ||||
|         3, | ||||
|         4, | ||||
|         5, | ||||
|         8, | ||||
|         10, | ||||
|         13, | ||||
|         14, | ||||
|         15, | ||||
|         16, | ||||
|     ] and cifar100_splits.xtest[:10] == [ | ||||
|         0, | ||||
|         2, | ||||
|         6, | ||||
|         7, | ||||
|         9, | ||||
|         11, | ||||
|         12, | ||||
|         17, | ||||
|         20, | ||||
|         24, | ||||
|     ] | ||||
|     train_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_cifar100_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__cifar100_loader = torch.utils.data.DataLoader( | ||||
|         VALID_CIFAR100, | ||||
|         batch_size=cifar_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : train-loader has {:3d} batch".format(len(train_cifar100_loader)) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)) | ||||
|     ) | ||||
|     print( | ||||
|         "CIFAR-100  : test--loader has {:3d} batch".format(len(test__cifar100_loader)) | ||||
|     ) | ||||
|     print(break_line) | ||||
|  | ||||
|     imagenet16_config_path = root_dir / "configs" / "nas-benchmark" / "ImageNet-16.config" | ||||
|     imagenet16_config = load_config(imagenet16_config_path, None, None) | ||||
|     TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets( | ||||
|         "ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1 | ||||
|     ) | ||||
|     print( | ||||
|         "original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format( | ||||
|             len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num | ||||
|         ) | ||||
|     ) | ||||
|     imagenet_splits = load_config( | ||||
|         root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", | ||||
|         None, | ||||
|         None, | ||||
|     ) | ||||
|     assert imagenet_splits.xvalid[:10] == [ | ||||
|         1, | ||||
|         2, | ||||
|         3, | ||||
|         6, | ||||
|         7, | ||||
|         8, | ||||
|         9, | ||||
|         12, | ||||
|         16, | ||||
|         18, | ||||
|     ] and imagenet_splits.xtest[:10] == [ | ||||
|         0, | ||||
|         4, | ||||
|         5, | ||||
|         10, | ||||
|         11, | ||||
|         13, | ||||
|         14, | ||||
|         15, | ||||
|         17, | ||||
|         20, | ||||
|     ] | ||||
|     train_imagenet_loader = torch.utils.data.DataLoader( | ||||
|         TRAIN_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     valid_imagenet_loader = torch.utils.data.DataLoader( | ||||
|         VALID_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     test__imagenet_loader = torch.utils.data.DataLoader( | ||||
|         VALID_ImageNet16_120, | ||||
|         batch_size=imagenet16_config.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest), | ||||
|         num_workers=workers, | ||||
|         pin_memory=True, | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : train-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(train_imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : valid-loader has {:3d} batch with {:} per batch".format( | ||||
|             len(valid_imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "ImageNet-16-120  : test--loader has {:3d} batch with {:} per batch".format( | ||||
|             len(test__imagenet_loader), imagenet16_config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     # 'cifar10', 'cifar100', 'ImageNet16-120' | ||||
|     loaders = { | ||||
|         "cifar10@trainval": trainval_cifar10_loader, | ||||
|         "cifar10@train": train_cifar10_loader, | ||||
|         "cifar10@valid": valid_cifar10_loader, | ||||
|         "cifar10@test": test__cifar10_loader, | ||||
|         "cifar100@train": train_cifar100_loader, | ||||
|         "cifar100@valid": valid_cifar100_loader, | ||||
|         "cifar100@test": test__cifar100_loader, | ||||
|         "ImageNet16-120@train": train_imagenet_loader, | ||||
|         "ImageNet16-120@valid": valid_imagenet_loader, | ||||
|         "ImageNet16-120@test": test__imagenet_loader, | ||||
|     } | ||||
|     return loaders | ||||
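|  | ||||
|     # A hedged usage sketch (not part of the original file): the dict maps | ||||
|     # "dataset@split" keys to DataLoaders, so a caller can fetch one batch | ||||
|     # as below; `get_loaders` is an assumed name for the enclosing function. | ||||
|     # | ||||
|     #   loaders = get_loaders() | ||||
|     #   images, labels = next(iter(loaders["cifar10@train"])) | ||||
|     #   print(images.shape, labels.shape) | ||||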
| @@ -1,134 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| import abc | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class AverageMeter(object): | ||||
|     """Computes and stores the average and current value""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         self.val = 0.0 | ||||
|         self.avg = 0.0 | ||||
|         self.sum = 0.0 | ||||
|         self.count = 0.0 | ||||
|  | ||||
|     def update(self, val, n=1): | ||||
|         self.val = val | ||||
|         self.sum += val * n | ||||
|         self.count += n | ||||
|         self.avg = self.sum / self.count | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(val={val}, avg={avg}, count={count})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
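| # A minimal usage sketch of AverageMeter (not from the original file): | ||||
| # `update` takes a batch-mean value plus the batch size, so `avg` is the | ||||
| # sample-weighted running mean. | ||||
| # | ||||
| #   meter = AverageMeter() | ||||
| #   meter.update(0.5, n=4)  # batch of 4 samples with mean value 0.5 | ||||
| #   meter.update(1.0, n=1)  # a single sample with value 1.0 | ||||
| #   assert abs(meter.avg - 0.6) < 1e-8  # (0.5 * 4 + 1.0 * 1) / 5 | ||||
|  | ||||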
|  | ||||
| class Metric(abc.ABC): | ||||
|     """The default meta metric class.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({inner})".format( | ||||
|             name=self.__class__.__name__, inner=self.inner_repr() | ||||
|         ) | ||||
|  | ||||
|     def inner_repr(self): | ||||
|         return "" | ||||
|  | ||||
|  | ||||
| class ComposeMetric(Metric): | ||||
|     """The composed metric class.""" | ||||
|  | ||||
|     def __init__(self, *metric_list): | ||||
|         self.reset() | ||||
|         for metric in metric_list: | ||||
|             self.append(metric) | ||||
|  | ||||
|     def reset(self): | ||||
|         self._metric_list = [] | ||||
|  | ||||
|     def append(self, metric): | ||||
|         if not isinstance(metric, Metric): | ||||
|             raise ValueError( | ||||
|                 "The input metric is not correct: {:}".format(type(metric)) | ||||
|             ) | ||||
|         self._metric_list.append(metric) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._metric_list) | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         results = list() | ||||
|         for metric in self._metric_list: | ||||
|             results.append(metric(predictions, targets)) | ||||
|         return results | ||||
|  | ||||
|     def get_info(self): | ||||
|         results = dict() | ||||
|         for metric in self._metric_list: | ||||
|             for key, value in metric.get_info().items(): | ||||
|                 results[key] = value | ||||
|         return results | ||||
|  | ||||
|     def inner_repr(self): | ||||
|         xlist = [] | ||||
|         for metric in self._metric_list: | ||||
|             xlist.append(str(metric)) | ||||
|         return ",".join(xlist) | ||||
|  | ||||
|  | ||||
| class MSEMetric(Metric): | ||||
|     """The metric for mse.""" | ||||
|  | ||||
|     def reset(self): | ||||
|         self._mse = AverageMeter() | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||
|             batch = predictions.shape[0] | ||||
|             loss = torch.nn.functional.mse_loss(predictions.data, targets.data) | ||||
|             loss = loss.item() | ||||
|             self._mse.update(loss, batch) | ||||
|             return loss | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         return {"mse": self._mse.avg} | ||||
|  | ||||
|  | ||||
| class SaveMetric(Metric): | ||||
|     """The metric for mse.""" | ||||
|  | ||||
|     def reset(self): | ||||
|         self._predicts = [] | ||||
|  | ||||
|     def __call__(self, predictions, targets=None): | ||||
|         if isinstance(predictions, torch.Tensor): | ||||
|             predicts = predictions.cpu().numpy() | ||||
|             self._predicts.append(predicts) | ||||
|             return predicts | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         all_predicts = np.concatenate(self._predicts) | ||||
|         return {"predictions": all_predicts} | ||||
| @@ -1,263 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import math, torch | ||||
| import torch.nn as nn | ||||
| from bisect import bisect_right | ||||
| from torch.optim import Optimizer | ||||
|  | ||||
|  | ||||
| class _LRScheduler(object): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs): | ||||
|         if not isinstance(optimizer, Optimizer): | ||||
|             raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__)) | ||||
|         self.optimizer = optimizer | ||||
|         for group in optimizer.param_groups: | ||||
|             group.setdefault("initial_lr", group["lr"]) | ||||
|         self.base_lrs = list( | ||||
|             map(lambda group: group["initial_lr"], optimizer.param_groups) | ||||
|         ) | ||||
|         self.max_epochs = epochs | ||||
|         self.warmup_epochs = warmup_epochs | ||||
|         self.current_epoch = 0 | ||||
|         self.current_iter = 0 | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "" | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) + ", {:})".format( | ||||
|             self.extra_repr() | ||||
|         ) | ||||
|  | ||||
|     def state_dict(self): | ||||
|         return { | ||||
|             key: value for key, value in self.__dict__.items() if key != "optimizer" | ||||
|         } | ||||
|  | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.__dict__.update(state_dict) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def get_min_info(self): | ||||
|         lrs = self.get_lr() | ||||
|         return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format( | ||||
|             min(lrs), max(lrs), self.current_epoch, self.current_iter | ||||
|         ) | ||||
|  | ||||
|     def get_min_lr(self): | ||||
|         return min(self.get_lr()) | ||||
|  | ||||
|     def update(self, cur_epoch, cur_iter): | ||||
|         if cur_epoch is not None: | ||||
|             assert ( | ||||
|                 isinstance(cur_epoch, int) and cur_epoch >= 0 | ||||
|             ), "invalid cur-epoch : {:}".format(cur_epoch) | ||||
|             self.current_epoch = cur_epoch | ||||
|         if cur_iter is not None: | ||||
|             assert ( | ||||
|                 isinstance(cur_iter, float) and cur_iter >= 0 | ||||
|             ), "invalid cur-iter : {:}".format(cur_iter) | ||||
|             self.current_iter = cur_iter | ||||
|         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): | ||||
|             param_group["lr"] = lr | ||||
|  | ||||
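| # A hedged sketch of the update protocol (inferred from `update` above, not | ||||
| # from the original file): the trainer passes the integer epoch once per | ||||
| # epoch and a float fraction-of-epoch per iteration; each call rewrites the | ||||
| # lr of every param group in place. | ||||
| # | ||||
| #   for epoch in range(config.epochs): | ||||
| #       scheduler.update(epoch, 0.0) | ||||
| #       for i, (inputs, targets) in enumerate(loader): | ||||
| #           scheduler.update(None, float(i) / len(loader)) | ||||
| #           ...  # forward / backward / optimizer.step() | ||||
|  | ||||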
|  | ||||
| class CosineAnnealingLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min): | ||||
|         self.T_max = T_max | ||||
|         self.eta_min = eta_min | ||||
|         super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, T-max={:}, eta-min={:}".format( | ||||
|             "cosine", self.T_max, self.eta_min | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if ( | ||||
|                 self.current_epoch >= self.warmup_epochs | ||||
|                 and self.current_epoch < self.max_epochs | ||||
|             ): | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 lr = ( | ||||
|                     self.eta_min | ||||
|                     + (base_lr - self.eta_min) | ||||
|                     * (1 + math.cos(math.pi * last_epoch / self.T_max)) | ||||
|                     / 2 | ||||
|                 ) | ||||
|             elif self.current_epoch >= self.max_epochs: | ||||
|                 lr = self.eta_min | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
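| # A quick numeric check of the cosine rule above (values assumed): with | ||||
| # warmup_epochs=0, T_max=10, eta_min=0, and base_lr=0.1, epoch 5 gives | ||||
| # (1 + cos(pi * 5 / 10)) / 2 == 0.5, i.e. lr == 0.05, halfway between | ||||
| # base_lr and eta_min, as expected at the cosine midpoint. | ||||
|  | ||||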
|  | ||||
| class MultiStepLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas): | ||||
|         assert len(milestones) == len(gammas), "invalid {:} vs {:}".format( | ||||
|             len(milestones), len(gammas) | ||||
|         ) | ||||
|         self.milestones = milestones | ||||
|         self.gammas = gammas | ||||
|         super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format( | ||||
|             "multistep", self.milestones, self.gammas, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 idx = bisect_right(self.milestones, last_epoch) | ||||
|                 lr = base_lr | ||||
|                 for x in self.gammas[:idx]: | ||||
|                     lr *= x | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
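| # A small worked example of the multistep rule (values assumed): with | ||||
| # milestones=[30, 60] and gammas=[0.1, 0.1], bisect_right counts how many | ||||
| # milestones have passed, so lr is base_lr before epoch 30, base_lr * 0.1 | ||||
| # in [30, 60), and base_lr * 0.01 from epoch 60 onward (warmup excluded). | ||||
|  | ||||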
|  | ||||
| class ExponentialLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, gamma): | ||||
|         self.gamma = gamma | ||||
|         super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, gamma={:}, base-lrs={:}".format( | ||||
|             "exponential", self.gamma, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 lr = base_lr * (self.gamma ** last_epoch) | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
|  | ||||
|  | ||||
| class LinearLR(_LRScheduler): | ||||
|     def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR): | ||||
|         self.max_LR = max_LR | ||||
|         self.min_LR = min_LR | ||||
|         super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format( | ||||
|             "LinearLR", self.max_LR, self.min_LR, self.base_lrs | ||||
|         ) | ||||
|  | ||||
|     def get_lr(self): | ||||
|         lrs = [] | ||||
|         for base_lr in self.base_lrs: | ||||
|             if self.current_epoch >= self.warmup_epochs: | ||||
|                 last_epoch = self.current_epoch - self.warmup_epochs | ||||
|                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch) | ||||
|                 ratio = ( | ||||
|                     (self.max_LR - self.min_LR) | ||||
|                     * last_epoch | ||||
|                     / self.max_epochs | ||||
|                     / self.max_LR | ||||
|                 ) | ||||
|                 lr = base_lr * (1 - ratio) | ||||
|             else: | ||||
|                 lr = ( | ||||
|                     self.current_epoch / self.warmup_epochs | ||||
|                     + self.current_iter / self.warmup_epochs | ||||
|                 ) * base_lr | ||||
|             lrs.append(lr) | ||||
|         return lrs | ||||
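|  | ||||
| # A quick check of the linear rule above (assuming base_lr == max_LR, which | ||||
| # is how the "linear" branch of get_optim_scheduler below wires it): at | ||||
| # last_epoch == max_epochs the ratio is (max_LR - min_LR) / max_LR, so | ||||
| # lr == min_LR; the lr thus decays linearly from max_LR toward min_LR. | ||||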
|  | ||||
|  | ||||
| class CrossEntropyLabelSmooth(nn.Module): | ||||
|     def __init__(self, num_classes, epsilon): | ||||
|         super(CrossEntropyLabelSmooth, self).__init__() | ||||
|         self.num_classes = num_classes | ||||
|         self.epsilon = epsilon | ||||
|         self.logsoftmax = nn.LogSoftmax(dim=1) | ||||
|  | ||||
|     def forward(self, inputs, targets): | ||||
|         log_probs = self.logsoftmax(inputs) | ||||
|         targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) | ||||
|         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes | ||||
|         loss = (-targets * log_probs).mean(0).sum() | ||||
|         return loss | ||||
|  | ||||
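| # A minimal numeric sketch of the smoothing above (values assumed): with | ||||
| # num_classes=4 and epsilon=0.1, a one-hot target [0, 1, 0, 0] becomes | ||||
| # [0.025, 0.925, 0.025, 0.025]: every entry receives epsilon / num_classes, | ||||
| # and the true class keeps (1 - epsilon) plus that same share. | ||||
|  | ||||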
|  | ||||
| def get_optim_scheduler(parameters, config): | ||||
|     assert ( | ||||
|         hasattr(config, "optim") | ||||
|         and hasattr(config, "scheduler") | ||||
|         and hasattr(config, "criterion") | ||||
|     ), "config must have optim / scheduler / criterion attributes, but got: {:}".format( | ||||
|         config | ||||
|     ) | ||||
|     if config.optim == "SGD": | ||||
|         optim = torch.optim.SGD( | ||||
|             parameters, | ||||
|             config.LR, | ||||
|             momentum=config.momentum, | ||||
|             weight_decay=config.decay, | ||||
|             nesterov=config.nesterov, | ||||
|         ) | ||||
|     elif config.optim == "RMSprop": | ||||
|         optim = torch.optim.RMSprop( | ||||
|             parameters, config.LR, momentum=config.momentum, weight_decay=config.decay | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid optim : {:}".format(config.optim)) | ||||
|  | ||||
|     if config.scheduler == "cos": | ||||
|         T_max = getattr(config, "T_max", config.epochs) | ||||
|         scheduler = CosineAnnealingLR( | ||||
|             optim, config.warmup, config.epochs, T_max, config.eta_min | ||||
|         ) | ||||
|     elif config.scheduler == "multistep": | ||||
|         scheduler = MultiStepLR( | ||||
|             optim, config.warmup, config.epochs, config.milestones, config.gammas | ||||
|         ) | ||||
|     elif config.scheduler == "exponential": | ||||
|         scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma) | ||||
|     elif config.scheduler == "linear": | ||||
|         scheduler = LinearLR( | ||||
|             optim, config.warmup, config.epochs, config.LR, config.LR_min | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid scheduler : {:}".format(config.scheduler)) | ||||
|  | ||||
|     if config.criterion == "Softmax": | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|     elif config.criterion == "SmoothSoftmax": | ||||
|         criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth) | ||||
|     else: | ||||
|         raise ValueError("invalid criterion : {:}".format(config.criterion)) | ||||
|     return optim, scheduler, criterion | ||||
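|  | ||||
|  | ||||
| # A hedged usage sketch (not from the original file): `config` may be any | ||||
| # object exposing the attributes read above; the SimpleNamespace below is | ||||
| # illustrative, and `net` is an assumed model. | ||||
| # | ||||
| #   from types import SimpleNamespace | ||||
| #   config = SimpleNamespace( | ||||
| #       optim="SGD", LR=0.1, momentum=0.9, decay=5e-4, nesterov=True, | ||||
| #       scheduler="cos", warmup=5, epochs=200, T_max=200, eta_min=0.0, | ||||
| #       criterion="Softmax", | ||||
| #   ) | ||||
| #   optim, scheduler, criterion = get_optim_scheduler(net.parameters(), config) | ||||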